"""Mesa visualization space drawers.
This module provides the core logic for drawing spaces in Mesa, supporting
orthogonal grids, hexagonal grids, networks, continuous spaces, and Voronoi grids.
It includes implementations for both Matplotlib and Altair backends.
"""
import itertools
from itertools import pairwise
import altair as alt
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
from matplotlib.collections import LineCollection
import mesa
from mesa.discrete_space import (
OrthogonalMooreGrid,
OrthogonalVonNeumannGrid,
VoronoiGrid,
)
from mesa.space import (
ContinuousSpace,
HexMultiGrid,
HexSingleGrid,
MultiGrid,
NetworkGrid,
SingleGrid,
)
# Union aliases grouping the space classes that each drawer accepts. They are
# used as the type annotation of the ``space`` argument of the drawers below.

# Square-cell grids, taken from both mesa.space and mesa.discrete_space.
OrthogonalGrid = SingleGrid | MultiGrid | OrthogonalMooreGrid | OrthogonalVonNeumannGrid
# Hexagonal grids, taken from both mesa.space and mesa.discrete_space.
HexGrid = HexSingleGrid | HexMultiGrid | mesa.discrete_space.HexGrid
# Network spaces, taken from both mesa.space and mesa.discrete_space.
Network = NetworkGrid | mesa.discrete_space.Network
class BaseSpaceDrawer:
    """Common state shared by every space drawer.

    Keeps a reference to the space being drawn together with the four axis
    limits used for visualization. The base class only initializes the
    limits to ``None``; the concrete drawers in this module overwrite them
    in their own constructors.
    """

    def __init__(self, space):
        """Store the space and reset all visualization limits.

        Args:
            space: Grid/Space instance to draw.
        """
        self.space = space
        # No limits are known at this level, so all four start out unset.
        self.viz_xmin = self.viz_xmax = None
        self.viz_ymin = self.viz_ymax = None

    def get_viz_limits(self):
        """Return the current visualization limits.

        Returns:
            A tuple ``(xmin, xmax, ymin, ymax)``.
        """
        return self.viz_xmin, self.viz_xmax, self.viz_ymin, self.viz_ymax
class OrthogonalSpaceDrawer(BaseSpaceDrawer):
    """Drawer for orthogonal grid spaces (SingleGrid, MultiGrid, Moore, VonNeumann)."""

    def __init__(self, space: OrthogonalGrid):
        """Set up the drawer and derive its limits from the grid size.

        Args:
            space: The orthogonal grid space to draw. Its ``width`` and
                ``height`` attributes are read.
        """
        super().__init__(space)
        n_cols, n_rows = self.space.width, self.space.height

        # Default size value, scaled down as the larger grid dimension grows.
        self.s_default = (180 / max(n_cols, n_rows)) ** 2

        # Limits sit half a unit outside the first and last integer
        # coordinate in each direction.
        self.viz_xmin = -0.5
        self.viz_xmax = n_cols - 0.5
        self.viz_ymin = -0.5
        self.viz_ymax = n_rows - 0.5

    def draw_matplotlib(self, ax=None, **space_kwargs):
        """Draw the grid lines of the orthogonal grid with matplotlib.

        Args:
            ax: Matplotlib axes object to draw on. A new figure and axes are
                created when omitted.
            **space_kwargs: ``figsize`` and ``dpi`` are used only when a new
                figure is created. All remaining entries override the default
                line styling passed to ``axvline``/``axhline``.
                Examples:
                    figsize=(10, 10), color="blue", linewidth=2.

        Returns:
            The axes object that was drawn on.
        """
        # Always strip the figure options, even when ``ax`` is given, so they
        # never leak into the line styling below.
        figsize = space_kwargs.pop("figsize", (8, 8))
        dpi = space_kwargs.pop("dpi", 100)
        if ax is None:
            _, ax = plt.subplots(figsize=figsize, dpi=dpi)

        # Defaults first, caller overrides last.
        line_style = {
            "color": "gray",
            "linestyle": ":",
            "linewidth": 1,
            "alpha": 1,
            **space_kwargs,
        }

        ax.set_xlim(self.viz_xmin, self.viz_xmax)
        ax.set_ylim(self.viz_ymin, self.viz_ymax)

        # One line per half-integer position, starting at -0.5.
        for x_edge in np.arange(-0.5, self.space.width, 1):
            ax.axvline(x_edge, **line_style)
        for y_edge in np.arange(-0.5, self.space.height, 1):
            ax.axhline(y_edge, **line_style)

        return ax

    def draw_altair(self, chart_width=450, chart_height=350, **chart_kwargs):
        """Build an Altair chart showing the orthogonal grid.

        Args:
            chart_width: Width for the shown chart
            chart_height: Height for the shown chart
            **chart_kwargs: ``xlabel``, ``ylabel``, ``grid_color``,
                ``grid_dash``, ``grid_width`` and ``grid_opacity`` style the
                axes and grid lines. Everything else is forwarded to
                ``Chart.properties`` and may override width/height.
                Examples:
                    width=500, height=500, title="Grid".

        Returns:
            Altair chart object
        """
        x_title = chart_kwargs.pop("xlabel", "X")
        y_title = chart_kwargs.pop("ylabel", "Y")
        grid_style = {
            "gridColor": chart_kwargs.pop("grid_color", "lightgray"),
            "gridDash": chart_kwargs.pop("grid_dash", [2, 2]),
            "gridWidth": chart_kwargs.pop("grid_width", 1),
            "gridOpacity": chart_kwargs.pop("grid_opacity", 1),
        }

        # Whatever is left after the pops above is a chart property.
        properties = {"width": chart_width, "height": chart_height, **chart_kwargs}

        def positional(channel, field, title, lower, upper):
            """Build one axis encoding with the shared grid styling."""
            return channel(
                field,
                title=title,
                scale=alt.Scale(domain=[lower, upper], nice=False),
                axis=alt.Axis(grid=True, **grid_style),
            )

        x_encoding = positional(alt.X, "X:Q", x_title, self.viz_xmin, self.viz_xmax)
        y_encoding = positional(alt.Y, "Y:Q", y_title, self.viz_ymin, self.viz_ymax)

        # A single empty record with a fully transparent mark: only the axes
        # and their grid lines end up visible.
        placeholder = pd.DataFrame([{}])
        return (
            alt.Chart(placeholder)
            .mark_point(opacity=0)
            .encode(x=x_encoding, y=y_encoding)
            .properties(**properties)
        )
class HexSpaceDrawer(BaseSpaceDrawer):
    """Drawer for hexagonal grid spaces."""

    def __init__(self, space: HexGrid):
        """Set up the drawer, precompute the hexagon mesh and the limits.

        Args:
            space: The hexagonal grid space to draw. Its ``width`` and
                ``height`` attributes are read.
        """
        super().__init__(space)
        n_cols, n_rows = self.space.width, self.space.height

        # Default size value, scaled down as the larger grid dimension grows.
        self.s_default = (180 / max(n_cols, n_rows)) ** 2

        size = 1.0
        self.x_spacing = np.sqrt(3) * size
        self.y_spacing = 1.5 * size

        # Vertex lists of every hexagon, computed once and reused by the
        # drawing methods.
        self.hexagons = self._get_hexmesh(n_cols, n_rows, size)

        x_padding = size * np.sqrt(3) / 2
        y_padding = size

        # Visualization limits. An extra half column is added on the right
        # when the number of rows is odd.
        self.viz_xmin = -1.8 * x_padding
        self.viz_xmax = n_cols * self.x_spacing + (n_rows % 2) * (self.x_spacing / 2)
        self.viz_ymin = -1.8 * y_padding
        self.viz_ymax = n_rows * self.y_spacing

    def _get_hexmesh(
        self, width: int, height: int, size: float = 1.0
    ) -> list[list[tuple[float, float]]]:
        """Compute the vertices of every hexagon in the mesh.

        Args:
            width: Number of hexagons per row.
            height: Number of rows.
            size: Distance from a hexagon's center to each of its vertices.

        Returns:
            A list with one entry per hexagon, ordered row by row. Each entry
            is a list of six ``(x, y)`` vertices starting at the top vertex
            and proceeding through the right-hand side first.
        """
        half_width = size * np.sqrt(3) / 2
        half_size = size / 2
        x_spacing = np.sqrt(3) * size
        y_spacing = 1.5 * size

        def corners(cx, cy):
            """Return the six vertices of the hexagon centered at (cx, cy)."""
            return [
                (cx, cy + size),  # top
                (cx + half_width, cy + half_size),  # top right
                (cx + half_width, cy - half_size),  # bottom right
                (cx, cy - size),  # bottom
                (cx - half_width, cy - half_size),  # bottom left
                (cx - half_width, cy + half_size),  # top left
            ]

        hexagons = []
        for row in range(height):
            # Even rows are shifted right by half a column.
            shift = x_spacing / 2 if row % 2 == 0 else 0.0
            cy = row * y_spacing
            for col in range(width):
                hexagons.append(corners(col * x_spacing + shift, cy))
        return hexagons

    def _get_unique_edges(self):
        """Collect every hexagon edge exactly once.

        Returns:
            A set of edges. Each edge is a tuple of two vertex tuples whose
            coordinates are rounded to 6 decimals and sorted, so an edge
            shared by two neighbouring hexagons collapses to one entry.
        """

        def snap(vertex):
            # Rounding lets the same vertex, computed from two different
            # hexagon centers, compare equal.
            return tuple(np.round(vertex, 6))

        unique = set()
        for ring in self.hexagons:
            # Repeat the first vertex at the end to close the outline.
            closed = [*ring, ring[0]]
            unique.update(
                tuple(sorted([snap(start), snap(end)]))
                for start, end in pairwise(closed)
            )
        return unique

    def draw_matplotlib(self, ax=None, **space_kwargs):
        """Draw the hexagon outlines with matplotlib.

        Args:
            ax: Matplotlib axes object to draw on. A new figure and axes are
                created when omitted.
            **space_kwargs: ``figsize`` and ``dpi`` are used only when a new
                figure is created. All remaining entries override the default
                styling passed to ``LineCollection``.
                Examples:
                    figsize=(8, 8), color="red", alpha=0.5.

        Returns:
            The axes object that was drawn on.
        """
        # Always strip the figure options so they never reach LineCollection.
        figsize = space_kwargs.pop("figsize", (8, 8))
        dpi = space_kwargs.pop("dpi", 100)
        if ax is None:
            _, ax = plt.subplots(figsize=figsize, dpi=dpi)

        line_style = {
            "color": "black",
            "linestyle": ":",
            "linewidth": 1,
            "alpha": 0.8,
            **space_kwargs,
        }

        ax.set_xlim(self.viz_xmin, self.viz_xmax)
        ax.set_ylim(self.viz_ymin, self.viz_ymax)
        ax.set_aspect("equal", adjustable="box")

        segments = list(self._get_unique_edges())
        ax.add_collection(LineCollection(segments, **line_style))
        return ax

    def draw_altair(self, chart_width=450, chart_height=350, **chart_kwargs):
        """Build an Altair chart showing the hexagon outlines.

        Args:
            chart_width: Width for the shown chart
            chart_height: Height for the shown chart
            **chart_kwargs: Additional keyword arguments for styling the chart.
                Examples:
                    * Line properties like color, strokeDash, strokeWidth, opacity.
                    * Other kwargs (e.g., width, title) apply to the chart.

        Returns:
            Altair chart object representing the hexagonal grid.
        """
        line_style = {
            "color": chart_kwargs.pop("color", "black"),
            "strokeDash": chart_kwargs.pop("strokeDash", [2, 2]),
            "strokeWidth": chart_kwargs.pop("strokeWidth", 1),
            "opacity": chart_kwargs.pop("opacity", 0.8),
        }
        # Whatever is left after the pops above is a chart property.
        properties = {"width": chart_width, "height": chart_height, **chart_kwargs}

        # Two rows per edge (start and end point). ``edge_id`` keeps every
        # edge a separate line and ``point_order`` fixes the drawing order.
        rows = [
            {"edge_id": edge_id, "point_order": order, "x": point[0], "y": point[1]}
            for edge_id, endpoints in enumerate(self._get_unique_edges())
            for order, point in enumerate(endpoints)
        ]
        source = pd.DataFrame(rows)

        x_encoding = alt.X(
            "x:Q",
            scale=alt.Scale(domain=[self.viz_xmin, self.viz_xmax], zero=False),
            axis=None,
        )
        y_encoding = alt.Y(
            "y:Q",
            scale=alt.Scale(domain=[self.viz_ymin, self.viz_ymax], zero=False),
            axis=None,
        )

        return (
            alt.Chart(source)
            .mark_line(**line_style)
            .encode(
                x=x_encoding,
                y=y_encoding,
                detail="edge_id:N",
                order="point_order:Q",
            )
            .properties(**properties)
        )
class NetworkSpaceDrawer(BaseSpaceDrawer):
    """Drawer for network-based spaces."""

    def __init__(
        self,
        space: Network,
        layout_alg=nx.spring_layout,
        layout_kwargs=None,
    ):
        """Initialize the network space drawer.

        The node layout is computed once here and reused by both drawing
        backends.

        Args:
            space: The network space to draw. Its ``G`` attribute is used as
                the graph.
            layout_alg: NetworkX layout algorithm to use
            layout_kwargs: Keyword arguments for the layout algorithm.
                Defaults to ``{"seed": 0}`` when not given.
        """
        super().__init__(space)
        self.layout_alg = layout_alg
        self.layout_kwargs = layout_kwargs if layout_kwargs is not None else {"seed": 0}

        # gather locations for nodes in network
        self.graph = self.space.G
        self.pos = self.layout_alg(self.graph, **self.layout_kwargs)

        # Fall back to a single point at the origin when the layout is empty.
        x, y = list(zip(*self.pos.values())) if self.pos else ([0], [0])
        xmin, xmax = min(x), max(x)
        ymin, ymax = min(y), max(y)
        width = xmax - xmin
        height = ymax - ymin

        # Guard against a zero-extent layout (e.g. a single node).
        self.s_default = (
            (180 / max(width, height)) ** 2 if width > 0 or height > 0 else 1
        )

        # Visualization limits: layout bounds plus 5% padding on each side.
        self.viz_xmin = xmin - width / 20
        self.viz_xmax = xmax + width / 20
        self.viz_ymin = ymin - height / 20
        self.viz_ymax = ymax + height / 20

    def draw_matplotlib(self, ax=None, **space_kwargs):
        """Draw the network using matplotlib.

        Args:
            ax: Matplotlib axes object to draw on. A new figure and axes are
                created when omitted.
            **space_kwargs: Dictionaries of keyword arguments for styling.
                Can also handle zorder for both nodes and edges if passed.
                * ``node_kwargs``: A dict passed to nx.draw_networkx_nodes.
                * ``edge_kwargs``: A dict passed to nx.draw_networkx_edges.

        Returns:
            The modified axes object.
        """
        if ax is None:
            _, ax = plt.subplots()

        ax.set_axis_off()
        ax.set_xlim(self.viz_xmin, self.viz_xmax)
        ax.set_ylim(self.viz_ymin, self.viz_ymax)

        node_kwargs = {"alpha": 0.5}
        edge_kwargs = {"alpha": 0.5, "style": "--"}
        node_kwargs.update(space_kwargs.get("node_kwargs", {}))
        edge_kwargs.update(space_kwargs.get("edge_kwargs", {}))

        # zorder is applied to the returned artists afterwards instead of
        # being forwarded to the networkx drawing functions.
        node_zorder = node_kwargs.pop("zorder", 1)
        edge_zorder = edge_kwargs.pop("zorder", 0)

        nodes = nx.draw_networkx_nodes(self.graph, self.pos, ax=ax, **node_kwargs)
        edges = nx.draw_networkx_edges(self.graph, self.pos, ax=ax, **edge_kwargs)

        # Explicit None checks: do not rely on the truthiness of artists.
        if nodes is not None:
            nodes.set_zorder(node_zorder)

        # In some matplotlib versions, edges can be a list of collections
        if isinstance(edges, list):
            for edge_collection in edges:
                edge_collection.set_zorder(edge_zorder)
        elif edges is not None:
            edges.set_zorder(edge_zorder)

        return ax

    def draw_altair(self, chart_width=450, chart_height=350, **chart_kwargs):
        """Draw the network using Altair.

        Args:
            chart_width: Width for the shown chart
            chart_height: Height for the shown chart
            **chart_kwargs: Dictionaries for styling the chart.
                * ``node_kwargs``: A dict of properties for the node's mark_point.
                * ``edge_kwargs``: A dict of properties for the edge's mark_rule.
                * Other kwargs (e.g., title, width) are passed to chart.properties().

        Returns:
            Altair chart object representing the network.
        """
        # One row per node with its layout coordinates.
        nodes_df = pd.DataFrame(self.pos).T.reset_index()
        nodes_df.columns = ["node", "x", "y"]

        # Attach source and target coordinates to every edge.
        edges_df = pd.DataFrame(self.graph.edges(), columns=["source", "target"])
        edge_positions = edges_df.merge(
            nodes_df, how="left", left_on="source", right_on="node"
        ).merge(
            nodes_df,
            how="left",
            left_on="target",
            right_on="node",
            suffixes=("_source", "_target"),
        )

        node_mark_kwargs = {"filled": True, "opacity": 0.5, "size": 500}
        edge_mark_kwargs = {"opacity": 0.5, "strokeDash": [5, 3]}
        node_mark_kwargs.update(chart_kwargs.pop("node_kwargs", {}))
        edge_mark_kwargs.update(chart_kwargs.pop("edge_kwargs", {}))

        # Use a separate dict for the chart properties. Rebinding
        # ``chart_kwargs`` here would throw away every caller-supplied
        # property (title, width override, ...).
        chart_props = {
            "width": chart_width,
            "height": chart_height,
        }
        chart_props.update(chart_kwargs)

        edge_plot = (
            alt.Chart(edge_positions)
            .mark_rule(**edge_mark_kwargs)
            .encode(
                x=alt.X(
                    "x_source",
                    scale=alt.Scale(domain=[self.viz_xmin, self.viz_xmax]),
                    axis=None,
                ),
                y=alt.Y(
                    "y_source",
                    scale=alt.Scale(domain=[self.viz_ymin, self.viz_ymax]),
                    axis=None,
                ),
                x2="x_target",
                y2="y_target",
            )
        )

        node_plot = (
            alt.Chart(nodes_df)
            .mark_point(**node_mark_kwargs)
            .encode(x="x", y="y", tooltip=["node"])
        )

        # Edges first so the nodes are layered on top of them.
        chart = (edge_plot + node_plot).properties(**chart_props)
        return chart
class ContinuousSpaceDrawer(BaseSpaceDrawer):
    """Drawer for continuous spaces."""

    def __init__(self, space: ContinuousSpace):
        """Set up the drawer and derive padded limits from the space bounds.

        Args:
            space: The continuous space to draw. Its ``x_min``, ``x_max``,
                ``y_min`` and ``y_max`` attributes are read.
        """
        super().__init__(space)
        bounds = self.space
        width = bounds.x_max - bounds.x_min
        height = bounds.y_max - bounds.y_min

        # Fall back to 1 when the space has no extent in either direction.
        has_extent = width > 0 or height > 0
        self.s_default = (180 / max(width, height)) ** 2 if has_extent else 1

        # Limits are the space bounds plus 5% padding on every side.
        x_padding = width / 20
        y_padding = height / 20
        self.viz_xmin = bounds.x_min - x_padding
        self.viz_xmax = bounds.x_max + x_padding
        self.viz_ymin = bounds.y_min - y_padding
        self.viz_ymax = bounds.y_max + y_padding

    def draw_matplotlib(self, ax=None, **space_kwargs):
        """Style the axes frame to represent the continuous space.

        Args:
            ax: Matplotlib axes object to draw on. A new figure and axes are
                created when omitted.
            **space_kwargs: Keyword arguments for styling the axis frame;
                they override the default spine settings.
                Examples:
                    linewidth=3, color="green"

        Returns:
            The axes object that was drawn on.
        """
        if ax is None:
            _, ax = plt.subplots()

        # A torus space gets a dashed frame, any other space a solid one.
        border_style = (0, (5, 10)) if self.space.torus else "solid"
        spine_style = {
            "linewidth": 1.5,
            "color": "black",
            "linestyle": border_style,
            **space_kwargs,
        }
        for spine in ax.spines.values():
            spine.set(**spine_style)

        ax.set_xlim(self.viz_xmin, self.viz_xmax)
        ax.set_ylim(self.viz_ymin, self.viz_ymax)
        return ax

    def draw_altair(self, chart_width=450, chart_height=350, **chart_kwargs):
        """Build an Altair chart whose scales span the continuous space.

        Args:
            chart_width: Width for the shown chart
            chart_height: Height for the shown chart
            **chart_kwargs: Extra keyword arguments forwarded to
                ``Chart.properties``; they may override width/height.

        Returns:
            An Altair Chart object representing the space.
        """
        properties = {"width": chart_width, "height": chart_height, **chart_kwargs}

        x_encoding = alt.X(scale=alt.Scale(domain=[self.viz_xmin, self.viz_xmax]))
        y_encoding = alt.Y(scale=alt.Scale(domain=[self.viz_ymin, self.viz_ymax]))

        # Single empty record with a transparent mark: the chart only carries
        # the scale domains.
        placeholder = pd.DataFrame([{}])
        return (
            alt.Chart(placeholder)
            .mark_rect(color="transparent")
            .encode(x=x_encoding, y=y_encoding)
            .properties(**properties)
        )
class VoronoiSpaceDrawer(BaseSpaceDrawer):
    """Drawer for Voronoi diagram spaces."""

    def __init__(self, space: VoronoiGrid):
        """Set up the drawer and derive padded limits from the centroids.

        Args:
            space: The Voronoi grid space to draw. Its
                ``centroids_coordinates`` attribute is read.
        """
        super().__init__(space)
        centroids = self.space.centroids_coordinates
        if centroids:
            xs = [point[0] for point in centroids]
            ys = [point[1] for point in centroids]
            x_min, x_max = min(xs), max(xs)
            y_min, y_max = min(ys), max(ys)
        else:
            # No centroids: use the unit square.
            x_min, x_max, y_min, y_max = 0, 1, 0, 1

        width = x_max - x_min
        height = y_max - y_min

        # Fall back to 1 when all centroids coincide.
        has_extent = width > 0 or height > 0
        self.s_default = (180 / max(width, height)) ** 2 if has_extent else 1

        # Limits are the centroid bounds plus 5% padding on every side.
        self.viz_xmin = x_min - width / 20
        self.viz_xmax = x_max + width / 20
        self.viz_ymin = y_min - height / 20
        self.viz_ymax = y_max + height / 20

    def _clip_line(self, p1, p2, box):
        """Clip a line segment to a box with the Cohen-Sutherland algorithm.

        Args:
            p1: First endpoint as ``(x, y)``.
            p2: Second endpoint as ``(x, y)``.
            box: Clipping rectangle as ``(min_x, min_y, max_x, max_y)``.

        Returns:
            The clipped segment ``((x1, y1), (x2, y2))``, or ``None`` when
            the segment lies entirely outside the box.
        """
        x1, y1 = p1
        x2, y2 = p2
        min_x, min_y, max_x, max_y = box

        # Region codes: one bit per side of the box.
        INSIDE, LEFT, RIGHT, BOTTOM, TOP = 0, 1, 2, 4, 8  # noqa: N806

        def region_of(x, y):
            """Return the bit mask of box sides that (x, y) lies beyond."""
            horizontal = LEFT if x < min_x else RIGHT if x > max_x else INSIDE
            vertical = BOTTOM if y < min_y else TOP if y > max_y else INSIDE
            return horizontal | vertical

        code1 = region_of(x1, y1)
        code2 = region_of(x2, y2)

        # Keep clipping while at least one endpoint is outside the box.
        while code1 | code2:
            if code1 & code2:
                # Both endpoints are beyond the same side: nothing visible.
                return None

            # Move an outside endpoint; prefer the first one.
            move_first = bool(code1)
            code_out = code1 if move_first else code2
            x, y = 0.0, 0.0

            # Intersect with the top/bottom side (skipped for horizontal lines).
            if y1 != y2:
                if code_out & TOP:
                    x = x1 + (x2 - x1) * (max_y - y1) / (y2 - y1)
                    y = max_y
                elif code_out & BOTTOM:
                    x = x1 + (x2 - x1) * (min_y - y1) / (y2 - y1)
                    y = min_y

            # Intersect with the left/right side (skipped for vertical lines).
            # When the endpoint is beyond two sides, this result replaces the
            # one computed above.
            if x1 != x2:
                if code_out & RIGHT:
                    y = y1 + (y2 - y1) * (max_x - x1) / (x2 - x1)
                    x = max_x
                elif code_out & LEFT:
                    y = y1 + (y2 - y1) * (min_x - x1) / (x2 - x1)
                    x = min_x

            if move_first:
                x1, y1 = x, y
                code1 = region_of(x1, y1)
            else:
                x2, y2 = x, y
                code2 = region_of(x2, y2)

        # Both endpoints are inside the box.
        return (x1, y1), (x2, y2)

    def _get_clipped_segments(self):
        """Extract, de-duplicate and clip the polygon edges of all cells.

        Returns:
            A tuple ``(segments, clip_box)`` where ``segments`` is a list of
            clipped ``((x1, y1), (x2, y2))`` segments and ``clip_box`` is
            ``(viz_xmin, viz_ymin, viz_xmax, viz_ymax)``.
        """
        clip_box = (self.viz_xmin, self.viz_ymin, self.viz_xmax, self.viz_ymax)

        unique_segments = set()
        for cell in self.space.all_cells.cells:
            ring = [tuple(vertex) for vertex in cell.properties["polygon"]]
            # Close the polygon, and sort each edge's endpoints so the same
            # edge seen from two neighbouring cells is stored only once.
            unique_segments.update(
                tuple(sorted(edge)) for edge in pairwise([*ring, ring[0]])
            )

        # Keep only the segments that are at least partly inside the box.
        final_segments = [
            clipped
            for start, end in unique_segments
            if (clipped := self._clip_line(start, end, clip_box))
        ]
        return final_segments, clip_box

    def draw_matplotlib(self, ax=None, **space_kwargs):
        """Draw the Voronoi cell borders with matplotlib.

        Args:
            ax: Matplotlib axes object to draw on. A new figure and axes are
                created when omitted.
            **space_kwargs: Keyword arguments passed to matplotlib's
                LineCollection; they override the default styling.
                Examples:
                    lw=2, alpha=0.5, colors='red'

        Returns:
            The axes object that was drawn on.
        """
        if ax is None:
            _, ax = plt.subplots()

        final_segments, clip_box = self._get_clipped_segments()
        min_x, min_y, max_x, max_y = clip_box
        ax.set_xlim(min_x, max_x)
        ax.set_ylim(min_y, max_y)

        # Nothing to draw: leave the axes with just their limits set.
        if not final_segments:
            return ax

        style_args = {"colors": "k", "linestyle": "dotted", "lw": 1, **space_kwargs}
        ax.add_collection(LineCollection(final_segments, **style_args))
        return ax

    def draw_altair(self, chart_width=450, chart_height=350, **chart_kwargs):
        """Build an Altair chart showing the Voronoi cell borders.

        Args:
            chart_width: Width for the shown chart
            chart_height: Height for the shown chart
            **chart_kwargs: Additional keyword arguments for styling the chart.
                Examples:
                    * Line properties like color, strokeDash, strokeWidth, opacity.
                    * Other kwargs (e.g., width, title) apply to the chart.

        Returns:
            An Altair Chart object representing the Voronoi diagram.
        """
        final_segments, clip_box = self._get_clipped_segments()
        min_x, min_y, max_x, max_y = clip_box

        # Two rows per segment, tied together by a shared ``line_id``.
        rows = [
            {"x": px, "y": py, "line_id": line_id}
            for line_id, segment in enumerate(final_segments)
            for px, py in segment
        ]
        df = pd.DataFrame(rows)

        line_style = {
            "color": chart_kwargs.pop("color", "black"),
            "strokeDash": chart_kwargs.pop("strokeDash", [2, 2]),
            "strokeWidth": chart_kwargs.pop("strokeWidth", 1),
            "opacity": chart_kwargs.pop("opacity", 0.8),
        }
        # Whatever is left after the pops above is a chart property.
        properties = {"width": chart_width, "height": chart_height, **chart_kwargs}

        x_encoding = alt.X("x:Q", scale=alt.Scale(domain=[min_x, max_x]), axis=None)
        y_encoding = alt.Y("y:Q", scale=alt.Scale(domain=[min_y, max_y]), axis=None)

        return (
            alt.Chart(df)
            .mark_line(**line_style)
            .encode(x=x_encoding, y=y_encoding, detail="line_id:N")
            .properties(**properties)
        )