"""Space rendering module for Mesa visualizations.
This module provides functionality to render Mesa model spaces with different
backends, supporting various space types and visualization components.
"""
import contextlib
import warnings
from collections.abc import Callable
from typing import Literal
import altair as alt
import numpy as np
import pandas as pd
import mesa
from mesa.discrete_space import (
OrthogonalMooreGrid,
OrthogonalVonNeumannGrid,
VoronoiGrid,
)
from mesa.space import (
ContinuousSpace,
HexMultiGrid,
HexSingleGrid,
MultiGrid,
NetworkGrid,
SingleGrid,
_HexGrid,
)
from mesa.visualization.backends import AltairBackend, MatplotlibBackend
from mesa.visualization.space_drawers import (
ContinuousSpaceDrawer,
HexSpaceDrawer,
NetworkSpaceDrawer,
OrthogonalSpaceDrawer,
VoronoiSpaceDrawer,
)
# Union aliases used with isinstance() to dispatch on the model's space type.
# Each groups the legacy ``mesa.space`` classes with their
# ``mesa.discrete_space`` counterparts.
OrthogonalGrid = (
    SingleGrid | MultiGrid | OrthogonalMooreGrid | OrthogonalVonNeumannGrid
)
HexGrid = HexSingleGrid | HexMultiGrid | mesa.discrete_space.HexGrid
Network = NetworkGrid | mesa.discrete_space.Network
class SpaceRenderer:
    """Renders Mesa spaces using different visualization backends.

    Supports multiple space types and backends for flexible visualization
    of agent-based models.
    """

    def __init__(
        self,
        model: mesa.Model,
        backend: Literal["matplotlib", "altair"] | None = "matplotlib",
    ):
        """Initialize the space renderer.

        Args:
            model (mesa.Model): The Mesa model to render. Its space is taken from
                ``model.grid`` if that attribute exists, otherwise ``model.space``.
            backend (Literal["matplotlib", "altair"] | None): The visualization backend to use.

        Raises:
            ValueError: If the model's space type or the backend is not supported.
        """
        self.space = getattr(model, "grid", getattr(model, "space", None))
        self.space_drawer = self._get_space_drawer()

        # Meshes produced by the draw_* methods; None means "not drawn yet".
        self.space_mesh = None
        self.agent_mesh = None
        self.propertylayer_mesh = None

        self.post_process_func = None
        # Keep track of whether post-processing has been applied
        # to avoid multiple applications on the same axis.
        self._post_process_applied = False

        self.backend = backend

        if backend == "matplotlib":
            self.backend_renderer = MatplotlibBackend(self.space_drawer)
        elif backend == "altair":
            self.backend_renderer = AltairBackend(self.space_drawer)
        else:
            raise ValueError(f"Unsupported backend: {backend}")

        self.backend_renderer.initialize_canvas()

    def _get_space_drawer(self):
        """Get appropriate space drawer based on space type.

        Returns:
            Space drawer instance for the model's space type.

        Raises:
            ValueError: If the space type is not supported.
        """
        # Hex grids are checked first: the legacy hex grids are also
        # SingleGrid/MultiGrid instances and would otherwise match OrthogonalGrid.
        if isinstance(self.space, HexGrid | _HexGrid):
            return HexSpaceDrawer(self.space)
        elif isinstance(self.space, OrthogonalGrid):
            return OrthogonalSpaceDrawer(self.space)
        elif isinstance(
            self.space,
            ContinuousSpace | mesa.experimental.continuous_space.ContinuousSpace,
        ):
            return ContinuousSpaceDrawer(self.space)
        elif isinstance(self.space, VoronoiGrid):
            return VoronoiSpaceDrawer(self.space)
        elif isinstance(self.space, Network):
            return NetworkSpaceDrawer(self.space)

        raise ValueError(
            f"Unsupported space type: {type(self.space).__name__}. "
            "Supported types are OrthogonalGrid, HexGrid, ContinuousSpace, VoronoiGrid, and Network."
        )

    def _map_coordinates(self, arguments):
        """Map agent coordinates to appropriate space coordinates.

        Args:
            arguments (dict): Dictionary containing agent data; its ``"loc"``
                entry is a numpy array of agent positions.

        Returns:
            dict: A shallow copy of ``arguments`` whose ``"loc"`` entry holds
            coordinates appropriate for the space type.
        """
        mapped_arguments = arguments.copy()

        # Hex grids must be tested before OrthogonalGrid (same order as
        # _get_space_drawer), since legacy hex grids subclass SingleGrid/MultiGrid.
        if isinstance(self.space, HexGrid | _HexGrid):
            # Map rectangular coordinates to hexagonal grid coordinates
            loc = arguments["loc"].astype(float)
            if loc.size > 0:
                x_spacing = self.space_drawer.x_spacing
                y_spacing = self.space_drawer.y_spacing
                # Calculate hexagon centers: alternate rows are shifted by half
                # a hexagon width. The x update reads the *unscaled* row index,
                # so it must happen before the y scaling below.
                loc[:, 0] = loc[:, 0] * x_spacing + ((loc[:, 1] - 1) % 2) * (
                    x_spacing / 2
                )
                loc[:, 1] = loc[:, 1] * y_spacing
            mapped_arguments["loc"] = loc

        elif isinstance(
            self.space,
            OrthogonalGrid
            | VoronoiGrid
            | ContinuousSpace
            | mesa.experimental.continuous_space.ContinuousSpace,
        ):
            # Use the coordinates directly for Orthogonal grids, Voronoi grids
            # and Continuous spaces (legacy and experimental).
            mapped_arguments["loc"] = arguments["loc"].astype(float)

        elif isinstance(self.space, Network):
            # Map node indices to the layout positions of the network drawer.
            loc = arguments["loc"].astype(float)
            if loc.size == 0:
                # Nothing to map; avoid column-indexing an empty array.
                mapped_arguments["loc"] = loc
            else:
                pos = np.asarray(list(self.space_drawer.pos.values()))
                # For networks both columns of ``loc`` carry the node index,
                # so the first column is used.
                # NOTE(review): assumes node ids coincide with the insertion
                # order of ``space_drawer.pos`` — confirm.
                node_index = loc[:, 0].astype(int)
                # FIXME: Find better way to handle this case
                # node_index updates before pos can, therefore gives us index
                # error that needs to be ignored (loc is then left unmapped).
                with contextlib.suppress(IndexError):
                    mapped_arguments["loc"] = pos[node_index]

        return mapped_arguments

    def draw_structure(self, **kwargs):
        """Draw the space structure.

        Args:
            **kwargs: Additional keyword arguments for the drawing function.
                Checkout respective `SpaceDrawer` class on details how to pass **kwargs.

        Returns:
            The visual representation of the space structure.
        """
        # Store space_kwargs so the structure can be redrawn (see ``canvas``).
        self.space_kwargs = kwargs
        self.space_mesh = self.backend_renderer.draw_structure(**self.space_kwargs)
        return self.space_mesh

    def draw_agents(self, agent_portrayal: Callable, **kwargs):
        """Draw agents on the space.

        Args:
            agent_portrayal (Callable): Function that takes an agent and returns AgentPortrayalStyle.
            **kwargs: Additional keyword arguments for the drawing function.
                Checkout respective `SpaceDrawer` class on details how to pass **kwargs.

        Returns:
            The visual representation of the agents.
        """
        # Store data so the agents can be redrawn (see ``canvas``).
        self.agent_portrayal = agent_portrayal
        self.agent_kwargs = kwargs

        # Prepare data for agent plotting
        arguments = self.backend_renderer.collect_agent_data(
            self.space, agent_portrayal, default_size=self.space_drawer.s_default
        )
        arguments = self._map_coordinates(arguments)

        self.agent_mesh = self.backend_renderer.draw_agents(
            arguments, **self.agent_kwargs
        )
        return self.agent_mesh

    def draw_propertylayer(self, propertylayer_portrayal: Callable | dict):
        """Draw property layers on the space.

        Args:
            propertylayer_portrayal (Callable | dict): Function that returns PropertyLayerStyle
                or dict with portrayal parameters (deprecated, emits a
                PendingDeprecationWarning).

        Returns:
            The visual representation of the property layers.

        Raises:
            Exception: If no property layers are found on the space.
        """
        # Import here to avoid circular imports
        from mesa.visualization.components import PropertyLayerStyle  # noqa: PLC0415

        def _dict_to_callable(portrayal_dict):
            """Convert legacy dict portrayal to callable.

            Args:
                portrayal_dict (dict): Mapping of layer name to portrayal parameters.

            Returns:
                Callable: Function that returns PropertyLayerStyle, or None for
                layers that have no entry in ``portrayal_dict``.
            """

            def style_callable(layer_object):
                params = portrayal_dict.get(layer_object.name)
                if params is None:
                    return None
                return PropertyLayerStyle(
                    color=params.get("color"),
                    colormap=params.get("colormap"),
                    alpha=params.get("alpha", PropertyLayerStyle.alpha),
                    vmin=params.get("vmin"),
                    vmax=params.get("vmax"),
                    colorbar=params.get("colorbar", PropertyLayerStyle.colorbar),
                )

            return style_callable

        # Get property layers: old style spaces expose ``properties``, new
        # style spaces ``_mesa_property_layers``. A space with neither is
        # treated as having no layers so the documented error below is raised.
        try:
            property_layers = self.space.properties
        except AttributeError:
            property_layers = getattr(self.space, "_mesa_property_layers", {})

        # Convert portrayal to callable if needed
        if isinstance(propertylayer_portrayal, dict):
            # Warn once here (pointing at the caller) rather than from inside
            # the generated callable, which runs once per layer per draw.
            warnings.warn(
                "Dict propertylayer_portrayal is deprecated. "
                "Use a callable returning PropertyLayerStyle instead.",
                PendingDeprecationWarning,
                stacklevel=2,
            )
            self.propertylayer_portrayal = _dict_to_callable(propertylayer_portrayal)
        else:
            self.propertylayer_portrayal = propertylayer_portrayal

        # A layer named "empty" does not count as a user property layer.
        if not any(layer != "empty" for layer in property_layers):
            raise Exception("No property layers were found on the space.")

        self.propertylayer_mesh = self.backend_renderer.draw_propertylayer(
            self.space, property_layers, self.propertylayer_portrayal
        )
        return self.propertylayer_mesh

    def render(
        self,
        agent_portrayal: Callable | None = None,
        propertylayer_portrayal: Callable | dict | None = None,
        post_process: Callable | None = None,
        **kwargs,
    ):
        """Render the complete space with structure, agents, and property layers.

        This all-in-one method draws everything required, so there is no need
        to call each draw method separately. The drawback is that keyword
        arguments used to customize the drawing must be split into
        ``space_kwargs`` and ``agent_kwargs``. Components that were already
        drawn are not drawn again.

        Args:
            agent_portrayal (Callable | None): Function that returns AgentPortrayalStyle.
                If None, agents won't be drawn.
            propertylayer_portrayal (Callable | dict | None): Function that returns
                PropertyLayerStyle or dict with portrayal parameters. If None,
                property layers won't be drawn.
            post_process (Callable | None): Function to apply post-processing to the canvas.
                If None, any previously configured post-processing function is kept.
            **kwargs: Additional keyword arguments for drawing functions.
                * ``space_kwargs`` (dict): Arguments for ``draw_structure()``.
                * ``agent_kwargs`` (dict): Arguments for ``draw_agents()``.

        Returns:
            SpaceRenderer: ``self``, to allow chaining.
        """
        space_kwargs = kwargs.pop("space_kwargs", {})
        agent_kwargs = kwargs.pop("agent_kwargs", {})

        if self.space_mesh is None:
            self.draw_structure(**space_kwargs)
        if self.agent_mesh is None and agent_portrayal is not None:
            self.draw_agents(agent_portrayal, **agent_kwargs)
        if self.propertylayer_mesh is None and propertylayer_portrayal is not None:
            self.draw_propertylayer(propertylayer_portrayal)

        # Only override when provided, so a function set through the
        # ``post_process`` property is not silently discarded.
        if post_process is not None:
            self.post_process_func = post_process
        return self

    @property
    def canvas(self):
        """Get the current canvas object.

        For the altair backend, every component that was drawn before is
        redrawn so the returned chart reflects the current model state.

        Returns:
            The backend-specific canvas object.
        """
        if self.backend == "matplotlib":
            if self.backend_renderer.ax is None:
                self.backend_renderer.initialize_canvas()
            # Re-read ``ax`` after a possible initialization; returning a value
            # captured beforehand would hand back the stale ``None``.
            return self.backend_renderer.ax

        elif self.backend == "altair":
            structure = (
                self.draw_structure(**self.space_kwargs) if self.space_mesh else None
            )
            agents = (
                self.draw_agents(self.agent_portrayal, **self.agent_kwargs)
                if self.agent_mesh
                else None
            )
            if self.propertylayer_mesh:
                prop_base, prop_cbar = self.draw_propertylayer(
                    self.propertylayer_portrayal
                )
            else:
                prop_base, prop_cbar = None, None

            spatial_charts = [
                chart for chart in (structure, prop_base, agents) if chart
            ]
            main_spatial = None
            if len(spatial_charts) == 1:
                main_spatial = spatial_charts[0]
            elif spatial_charts:
                main_spatial = alt.layer(*spatial_charts)

            # Determine final chart by combining with color bar if present
            final_chart = None
            if main_spatial and prop_cbar:
                final_chart = alt.vconcat(main_spatial, prop_cbar).configure_view(
                    stroke=None
                )
            elif main_spatial:  # Only main_spatial, no prop_cbar
                final_chart = main_spatial
            elif prop_cbar:  # Only prop_cbar, no main_spatial
                final_chart = prop_cbar.configure_view(grid=False)

            if final_chart is None:
                # If no charts are available, return an empty chart
                final_chart = (
                    alt.Chart(pd.DataFrame())
                    .mark_point()
                    .properties(width=450, height=350)
                )

            # NOTE(review): if Altair's configure_view replaces the whole view
            # config, this supersedes the stroke=None / grid=False settings
            # above — confirm that is intended.
            return final_chart.configure_view(stroke="black", strokeWidth=1.5)

    @property
    def post_process(self):
        """Get the current post-processing function.

        Returns:
            Callable | None: The post-processing function, or None if not set.
        """
        return self.post_process_func

    @post_process.setter
    def post_process(self, func: Callable | None):
        """Set the post-processing function.

        Args:
            func (Callable | None): Function to apply post-processing to the canvas.
                Should accept the canvas object as its first argument.
        """
        self.post_process_func = func