# Source code for mesa.visualization.mpl_space_drawing

"""Helper functions for drawing mesa spaces with matplotlib.

These functions are used by the provided matplotlib components, but can also be used to quickly visualize
a space with matplotlib, for example when creating an MP4 movie of a model run or when needing a figure
for a paper.

"""

import itertools
import os
import warnings
from collections.abc import Callable
from dataclasses import fields
from functools import lru_cache
from itertools import pairwise
from typing import Any

import networkx as nx
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.axes import Axes
from matplotlib.cm import ScalarMappable
from matplotlib.collections import LineCollection, PatchCollection, PolyCollection
from matplotlib.colors import LinearSegmentedColormap, Normalize, to_rgba
from matplotlib.offsetbox import AnnotationBbox, OffsetImage
from matplotlib.patches import Polygon
from PIL import Image

import mesa
from mesa.discrete_space import (
    OrthogonalMooreGrid,
    OrthogonalVonNeumannGrid,
    VoronoiGrid,
)
from mesa.space import (
    ContinuousSpace,
    HexMultiGrid,
    HexSingleGrid,
    MultiGrid,
    NetworkGrid,
    SingleGrid,
)

# Scale applied by _get_zoom_factor: an image marker at the base zoom spans this many
# data units along its limiting axis.
CORRECTION_FACTOR_MARKER_ZOOM = 0.6
# Marker size at which an image marker is drawn at exactly the base zoom (see _scatter).
DEFAULT_MARKER_SIZE = 50

# Union aliases used for isinstance checks and annotations. Each one groups the
# old-style ``mesa.space`` class(es) with the matching ``mesa.discrete_space`` class(es).
OrthogonalGrid = (
    SingleGrid | MultiGrid | OrthogonalMooreGrid | OrthogonalVonNeumannGrid
)
HexGrid = HexSingleGrid | HexMultiGrid | mesa.discrete_space.HexGrid
Network = NetworkGrid | mesa.discrete_space.Network


def collect_agent_data(
    space: OrthogonalGrid | HexGrid | Network | ContinuousSpace | VoronoiGrid,
    agent_portrayal: Callable,
    default_size: float | None = None,
) -> dict:
    """Collect the plotting data for all agents in the space.

    Args:
        space: The space containing the Agents.
        agent_portrayal: A callable that is called with the agent and returns a
            AgentPortrayalStyle (returning a dict is deprecated but still supported).
        default_size: marker size used when the portrayal leaves ``size`` as None;
            if this is None as well, the AgentPortrayalStyle class default is used.

    Returns:
        A dict of numpy arrays with the keys ``loc``, ``s``, ``c``, ``marker``,
        ``zorder``, ``alpha``, ``edgecolors`` and ``linewidths``. ``marker`` is a
        1-D object array so tuple markers stay intact. ``edgecolors`` only gets an
        entry for agents whose resolved edge color is not None.

    agent_portrayal should return a AgentPortrayalStyle, limited to size (size of marker),
    color (color of marker), zorder (z-order), marker (marker style), alpha, linewidths,
    and edgecolors.

    The style object returned by ``agent_portrayal`` is not modified: the position and
    edge-color fallbacks are resolved locally, so one style instance may safely be
    shared between agents or reused across draws.
    """

    def get_agent_pos(agent, space):
        """Helper function to get the agent position depending on the grid type."""
        if isinstance(space, NetworkGrid):
            agent_x, agent_y = agent.pos, agent.pos
        elif isinstance(space, Network):
            agent_x, agent_y = agent.cell.coordinate, agent.cell.coordinate
        else:
            agent_x = (
                agent.pos[0] if agent.pos is not None else agent.cell.coordinate[0]
            )
            agent_y = (
                agent.pos[1] if agent.pos is not None else agent.cell.coordinate[1]
            )
        return agent_x, agent_y

    arguments = {
        "loc": [],
        "s": [],
        "c": [],
        "marker": [],
        "zorder": [],
        "alpha": [],
        "edgecolors": [],
        "linewidths": [],
    }

    # Importing AgentPortrayalStyle inside the function to prevent circular imports
    from mesa.visualization.components import AgentPortrayalStyle  # noqa: PLC0415

    # Get AgentPortrayalStyle defaults
    style_fields = {f.name: f.default for f in fields(AgentPortrayalStyle)}
    class_default_size = style_fields.get("size")

    for agent in space.agents:
        portray_input = agent_portrayal(agent)
        aps: AgentPortrayalStyle

        if isinstance(portray_input, dict):
            warnings.warn(
                "Returning a dict from agent_portrayal is deprecated and will be removed "
                "in a future version. Please return an AgentPortrayalStyle instance instead.",
                PendingDeprecationWarning,
                stacklevel=2,
            )
            dict_data = portray_input.copy()
            agent_x, agent_y = get_agent_pos(agent, space)

            # Extract values from the dict, using defaults if not provided
            size_val = dict_data.pop("size", style_fields.get("size"))
            color_val = dict_data.pop("color", style_fields.get("color"))
            marker_val = dict_data.pop("marker", style_fields.get("marker"))
            zorder_val = dict_data.pop("zorder", style_fields.get("zorder"))
            alpha_val = dict_data.pop("alpha", style_fields.get("alpha"))
            edgecolors_val = dict_data.pop("edgecolors", None)
            linewidths_val = dict_data.pop("linewidths", style_fields.get("linewidths"))

            aps = AgentPortrayalStyle(
                x=agent_x,
                y=agent_y,
                size=size_val,
                color=color_val,
                marker=marker_val,
                zorder=zorder_val,
                alpha=alpha_val,
                edgecolors=edgecolors_val,
                linewidths=linewidths_val,
            )

            # Report list of unused data
            if dict_data:
                ignored_keys = list(dict_data.keys())
                warnings.warn(
                    f"The following keys from the returned dict were ignored: {', '.join(ignored_keys)}",
                    UserWarning,
                    stacklevel=2,
                )
        else:
            aps = portray_input

        # Resolve fallbacks into locals instead of writing them back onto ``aps``:
        # the portrayal callable may return the same style instance for many agents
        # (or across steps), and mutating it would pin every later agent to the
        # position of the first one.
        edgecolors = aps.color if aps.edgecolors is None else aps.edgecolors
        if aps.x is None and aps.y is None:
            x, y = get_agent_pos(agent, space)
        else:
            x, y = aps.x, aps.y

        arguments["loc"].append((x, y))

        # Final size: explicit style size, then the caller's default, then the class default.
        size_to_collect = aps.size
        if size_to_collect is None:
            size_to_collect = default_size
        if size_to_collect is None:
            size_to_collect = class_default_size

        arguments["s"].append(size_to_collect)
        arguments["c"].append(aps.color)
        arguments["marker"].append(aps.marker)
        arguments["zorder"].append(aps.zorder)
        arguments["alpha"].append(aps.alpha)
        if edgecolors is not None:
            arguments["edgecolors"].append(edgecolors)
        arguments["linewidths"].append(aps.linewidths)

    data = {}
    for key, values in arguments.items():
        if key == "marker":
            # Element-wise assignment into a 1-D object array keeps tuple markers such
            # as (5, 0, 0) as single objects; np.asarray or slice assignment would turn
            # a list of equal-length tuples into a 2-D array (or fail to broadcast).
            markers = np.empty(len(values), dtype=object)
            for index, value in enumerate(values):
                markers[index] = value
            data[key] = markers
        else:
            data[key] = np.asarray(values)
    return data
[docs] def draw_space( space, agent_portrayal: Callable, propertylayer_portrayal: Callable | None = None, ax: Axes | None = None, **space_drawing_kwargs, ): """Draw a Matplotlib-based visualization of the space. Args: space: the space of the mesa model agent_portrayal: A callable that returns a AgnetPortrayalStyle specifying how to show the agent propertylayer_portrayal: A callable that returns a PropertyLayerStyle specifying how to show the property layer ax: the axes upon which to draw the plot space_drawing_kwargs: any additional keyword arguments to be passed on to the underlying function for drawing the space. Returns: Returns the Axes object with the plot drawn onto it. ``agent_portrayal`` is called with an agent and should return a AgentPortrayalStyle. Valid fields in this object are "color", "size", "marker", "zorder", alpha, linewidths, and edgecolors. Other field are ignored and will result in a user warning. """ if ax is None: _, ax = plt.subplots() # https://stackoverflow.com/questions/67524641/convert-multiple-isinstance-checks-to-structural-pattern-matching match space: # order matters here given the class structure of old-style grid spaces case HexSingleGrid() | HexMultiGrid() | mesa.discrete_space.HexGrid(): draw_hex_grid(space, agent_portrayal, ax=ax, **space_drawing_kwargs) case ( mesa.space.SingleGrid() | OrthogonalMooreGrid() | OrthogonalVonNeumannGrid() | mesa.space.MultiGrid() ): draw_orthogonal_grid(space, agent_portrayal, ax=ax, **space_drawing_kwargs) case mesa.space.NetworkGrid() | mesa.discrete_space.Network(): draw_network(space, agent_portrayal, ax=ax, **space_drawing_kwargs) case ( mesa.space.ContinuousSpace() | mesa.experimental.continuous_space.ContinuousSpace() ): draw_continuous_space(space, agent_portrayal, ax=ax) case VoronoiGrid(): draw_voronoi_grid(space, agent_portrayal, ax=ax) case _: raise ValueError(f"Unknown space type: {type(space)}") if propertylayer_portrayal: draw_property_layers(space, propertylayer_portrayal, ax=ax) 
return ax
@lru_cache(maxsize=1024, typed=True)
def _get_hexmesh(
    width: int, height: int, size: float = 1.0
) -> list[list[tuple[float, float]]]:
    """Return the vertex lists of a ``width`` x ``height`` mesh of pointy-top hexagons.

    Hexagons are generated row by row (``height`` rows of ``width`` hexagons), so the
    hexagon at column ``col`` and row ``row`` is at index ``row * width + col``. Each
    entry holds the six vertices starting at the top vertex and going clockwise.
    Rows with an even index are shifted right by half a hexagon width.

    Args:
        width: number of hexagons per row.
        height: number of rows.
        size: distance from a hexagon's center to its top vertex.

    Returns:
        A list with one list of six ``(x, y)`` vertices per hexagon. The result is
        cached, so callers must treat it as read-only.
    """
    half_width = size * np.sqrt(3) / 2  # center to the left/right side of a hexagon
    x_spacing = np.sqrt(3) * size
    y_spacing = 1.5 * size

    def _hex_vertices(center_x: float, center_y: float) -> list[tuple[float, float]]:
        """Return the vertices of the hexagon centered at (center_x, center_y)."""
        return [
            (center_x, center_y + size),  # top
            (center_x + half_width, center_y + size / 2),  # top right
            (center_x + half_width, center_y - size / 2),  # bottom right
            (center_x, center_y - size),  # bottom
            (center_x - half_width, center_y - size / 2),  # bottom left
            (center_x - half_width, center_y + size / 2),  # top left
        ]

    hexagons = []
    for row, col in itertools.product(range(height), range(width)):
        # Calculate center position with offset for even rows
        x = col * x_spacing + (row % 2 == 0) * (x_spacing / 2)
        y = row * y_spacing
        hexagons.append(_hex_vertices(x, y))
    return hexagons
def draw_property_layers(
    space, propertylayer_portrayal: dict[str, dict[str, Any]] | Callable, ax: Axes
):
    """Draw PropertyLayers on the given axes.

    Args:
        space (mesa.space._Grid): The space containing the PropertyLayers.
        propertylayer_portrayal (Callable): A function that accepts a property layer object
            and returns either a `PropertyLayerStyle` object defining its visualization,
            or `None` to skip drawing this particular layer. A dict mapping layer names to
            style parameters is still accepted but deprecated.
        ax (matplotlib.axes.Axes): The axes to draw on.

    Raises:
        ValueError: if a layer's style defines neither ``color`` nor ``colormap``.
        NotImplementedError: if the space is neither an orthogonal nor a hex grid.

    Only ``None`` for ``vmin``/``vmax`` falls back to the data minimum/maximum, so an
    explicit bound of 0 is honoured. With a ``color`` style, a constant layer
    (``vmin == vmax``) is drawn fully transparent instead of producing NaNs.
    """
    # Importing here to avoid circular import issues
    from mesa.visualization.components import PropertyLayerStyle  # noqa: PLC0415

    def _propertylayer_portryal_dict_to_callable(
        propertylayer_portrayal: dict[str, dict[str, Any]],
    ):
        """Helper function to convert a propertylayer_portrayal dict to a callable that return a PropertyLayerStyle."""

        def style_callable(layer_object: Any):
            layer_name = layer_object.name
            params = propertylayer_portrayal.get(layer_name)

            warnings.warn(
                "The propertylayer_portrayal dict is deprecated. Use a callable that returns PropertyLayerStyle instead.",
                PendingDeprecationWarning,
                stacklevel=2,
            )

            if params is None:
                return None  # Layer not specified in the dict, so skip.

            return PropertyLayerStyle(
                color=params.get("color"),
                colormap=params.get("colormap"),
                alpha=params.get(
                    "alpha", PropertyLayerStyle.alpha
                ),  # Use defaults defined in the dataclass itself
                vmin=params.get("vmin"),
                vmax=params.get("vmax"),
                colorbar=params.get("colorbar", PropertyLayerStyle.colorbar),
            )

        return style_callable

    try:
        # old style spaces
        property_layers = space.properties
    except AttributeError:
        # new style spaces
        property_layers = space._mesa_property_layers

    callable_portrayal: Callable[[Any], PropertyLayerStyle | None]
    if isinstance(propertylayer_portrayal, dict):
        callable_portrayal = _propertylayer_portryal_dict_to_callable(
            propertylayer_portrayal
        )
    else:
        callable_portrayal = propertylayer_portrayal

    for layer_name in property_layers:
        if layer_name == "empty":
            # Skipping empty layer, automatically generated
            continue

        layer = property_layers.get(layer_name, None)
        portrayal = callable_portrayal(layer)

        if portrayal is None:
            # Not visualizing layers that do not have a defined visual encoding.
            continue

        data = layer.data.astype(float) if layer.data.dtype == bool else layer.data

        if (space.width, space.height) != data.shape:
            warnings.warn(
                f"Layer {layer_name} dimensions ({data.shape}) do not match space dimensions ({space.width}, {space.height}).",
                UserWarning,
                stacklevel=2,
            )

        color = portrayal.color
        colormap = portrayal.colormap
        alpha = portrayal.alpha
        # Compare against None rather than truthiness so an explicit bound of 0 is kept.
        vmin = portrayal.vmin if portrayal.vmin is not None else np.min(data)
        vmax = portrayal.vmax if portrayal.vmax is not None else np.max(data)

        if color:
            rgba_color = to_rgba(color)
            # Single-color layers fade from fully transparent to the given color.
            cmap = LinearSegmentedColormap.from_list(
                layer_name, [(0, 0, 0, 0), (*rgba_color[:3], alpha)]
            )
        elif colormap:
            cmap = colormap
            if isinstance(cmap, list):
                cmap = LinearSegmentedColormap.from_list(layer_name, cmap)
            elif isinstance(cmap, str):
                cmap = plt.get_cmap(cmap)
        else:
            raise ValueError(
                f"PropertyLayer {layer_name} portrayal must include 'color' or 'colormap'."
            )

        if isinstance(space, OrthogonalGrid):
            if color:
                # imshow expects rows = y, so transpose the [x, y] indexed data.
                data = data.T
                value_range = vmax - vmin
                if value_range == 0:
                    # Constant layer: avoid 0/0 and map every cell to 0, which is what
                    # matplotlib's Normalize does for a zero-width range (hex branch).
                    normalized_data = np.zeros(data.shape)
                else:
                    normalized_data = (data - vmin) / value_range
                rgba_data = np.full((*data.shape, 4), rgba_color)
                rgba_data[..., 3] *= normalized_data * alpha
                rgba_data = np.clip(rgba_data, 0, 1)
                ax.imshow(rgba_data, origin="lower")
            else:
                ax.imshow(
                    data.T,
                    cmap=cmap,
                    alpha=alpha,
                    vmin=vmin,
                    vmax=vmax,
                    origin="lower",
                )

        elif isinstance(space, HexGrid):
            width, height = data.shape

            # Generate hexagon mesh
            hexagons = _get_hexmesh(width, height)

            # Normalize colors
            norm = Normalize(vmin=vmin, vmax=vmax)
            # _get_hexmesh lists hexagons row by row (index = y * width + x) while the
            # layer data is indexed [x, y]; flatten the transpose so each hexagon gets
            # the value of its own cell (the one agents at (x, y) are drawn on).
            colors = data.T.ravel()

            if color:
                normalized_colors = np.clip(norm(colors), 0, 1)
                rgba_colors = np.full((len(colors), 4), rgba_color)
                rgba_colors[:, 3] = normalized_colors * alpha
            else:
                rgba_colors = cmap(norm(colors))
                rgba_colors[..., 3] *= alpha

            # Draw hexagons
            collection = PolyCollection(hexagons, facecolors=rgba_colors, zorder=-1)
            ax.add_collection(collection)

        else:
            raise NotImplementedError(
                f"PropertyLayer visualization not implemented for {type(space)}."
            )

        # Add colorbar if requested
        if portrayal.colorbar:
            norm = Normalize(vmin=vmin, vmax=vmax)
            sm = ScalarMappable(norm=norm, cmap=cmap)
            sm.set_array([])
            plt.colorbar(sm, ax=ax, label=layer_name)
def draw_orthogonal_grid(
    space: OrthogonalGrid,
    agent_portrayal: Callable,
    ax: Axes | None = None,
    draw_grid: bool = True,
    **kwargs,
):
    """Visualize an orthogonal grid.

    Args:
        space: the space to visualize
        agent_portrayal: a callable that is called with the agent and returns a AgentPortrayalStyle
        ax: a Matplotlib Axes instance. If none is provided a new figure and ax will be created using plt.subplots
        draw_grid: whether to draw dotted lines between the cells
        kwargs: additional keyword arguments passed to ax.scatter

    Returns:
        Returns the Axes object with the plot drawn onto it.

    ``agent_portrayal`` is called with an agent and should return a AgentPortrayalStyle. Valid fields in this object are "color",
    "size", "marker", "zorder", alpha, linewidths, and edgecolors. Other fields are ignored and will result in a user warning.
    """
    if ax is None:
        ax = plt.subplots()[1]

    width, height = space.width, space.height

    # The default marker size shrinks as the grid's longest side grows.
    default_marker_size = (180 / max(width, height)) ** 2
    agent_data = collect_agent_data(
        space, agent_portrayal, default_size=default_marker_size
    )

    # Agents sit on integer coordinates; show half a unit beyond the outermost ones.
    ax.set_xlim(-0.5, width - 0.5)
    ax.set_ylim(-0.5, height - 0.5)

    _scatter(ax, agent_data, **kwargs)

    if not draw_grid:
        return ax

    # Dotted separators between neighbouring cells.
    for boundary in np.arange(-0.5, width - 0.5, 1):
        ax.axvline(boundary, color="gray", linestyle=":")
    for boundary in np.arange(-0.5, height - 0.5, 1):
        ax.axhline(boundary, color="gray", linestyle=":")
    return ax
def draw_hex_grid(
    space: HexGrid,
    agent_portrayal: Callable,
    ax: Axes | None = None,
    draw_grid: bool = True,
    **kwargs,
):
    """Visualize a hex grid.

    Args:
        space: the space to visualize
        agent_portrayal: a callable that is called with the agent and returns a AgentPortrayalStyle
        ax: a Matplotlib Axes instance. If none is provided a new figure and ax will be created using plt.subplots
        draw_grid: whether to draw the hexagon outlines
        kwargs: additional keyword arguments passed to ax.scatter

    Returns:
        Returns the Axes object with the plot drawn onto it.

    ``agent_portrayal`` is called with an agent and should return a AgentPortrayalStyle. Valid fields in this object are "color",
    "size", "marker", "zorder", alpha, linewidths, and edgecolors. Other fields are ignored and will result in a user warning.
    """
    if ax is None:
        ax = plt.subplots()[1]

    marker_size = (180 / max(space.width, space.height)) ** 2
    agent_data = collect_agent_data(space, agent_portrayal, default_size=marker_size)

    # Geometry of a pointy-top hexagon whose top vertex is `radius` above its center.
    radius = 1.0
    col_step = np.sqrt(3) * radius
    row_step = 1.5 * radius

    # Extent of the mesh, including the half-column shift of offset rows.
    right_edge = space.width * col_step + (space.height % 2) * (col_step / 2)
    top_edge = space.height * row_step
    half_width = radius * np.sqrt(3) / 2  # center to the rightmost point of a hexagon
    half_height = radius  # center to the topmost point of a hexagon

    # Plot limits to contain the hexagonal grid; determined through physical testing.
    ax.set_xlim(-2 * half_width, right_edge + half_width)
    ax.set_ylim(-2 * half_height, top_edge + half_height)

    cells = agent_data["loc"].astype(float)
    # Only convert cell coordinates to hexagon centers (and plot) when there are agents.
    if cells.size > 0:
        cols = cells[:, 0]
        rows = cells[:, 1]
        # Rows with an even index are shifted right by half a column, as in _get_hexmesh.
        agent_data["loc"] = np.column_stack(
            (cols * col_step + ((rows - 1) % 2) * (col_step / 2), rows * row_step)
        )
        _scatter(ax, agent_data, **kwargs)

    if draw_grid:
        unique_edges = set()
        for hexagon in _get_hexmesh(space.width, space.height):
            closed_ring = [*hexagon, hexagon[0]]
            for start, end in pairwise(closed_ring):
                # Round and sort the endpoints so an edge shared by two hexagons is kept once.
                endpoints = sorted([tuple(np.round(start, 6)), tuple(np.round(end, 6))])
                unique_edges.add(tuple(endpoints))
        ax.add_collection(
            LineCollection(
                unique_edges, linestyle=":", color="black", linewidth=1, alpha=1
            )
        )

    return ax
def draw_network(
    space: Network,
    agent_portrayal: Callable,
    ax: Axes | None = None,
    draw_grid: bool = True,
    layout_alg=nx.spring_layout,
    layout_kwargs=None,
    **kwargs,
):
    """Visualize a network space.

    Args:
        space: the space to visualize
        agent_portrayal: a callable that is called with the agent and returns a AgentPortrayalStyle
        ax: a Matplotlib Axes instance. If none is provided a new figure and ax will be created using plt.subplots
        draw_grid: whether to draw the network edges
        layout_alg: a networkx layout algorithm or other callable with the same behavior
        layout_kwargs: a dictionary of keyword arguments for the layout algorithm (defaults to ``{"seed": 0}``)
        kwargs: additional keyword arguments passed to ax.scatter

    Returns:
        Returns the Axes object with the plot drawn onto it.

    ``agent_portrayal`` is called with an agent and should return a AgentPortrayalStyle. Valid fields in this object are "color",
    "size", "marker", "zorder", alpha, linewidths, and edgecolors. Other fields are ignored and will result in a user warning.

    Agent positions are looked up in the layout by node label, so nodes do not need to be the integers
    0..n-1 in iteration order.
    """
    if ax is None:
        _, ax = plt.subplots()
    if layout_kwargs is None:
        layout_kwargs = {"seed": 0}

    graph = space.G
    # node -> (x, y); kept as a mapping so positions are looked up by node label.
    layout = layout_alg(graph, **layout_kwargs)

    if not layout:
        # An empty graph has nothing to place (so no agents either).
        ax.set_axis_off()
        return ax

    xs, ys = zip(*layout.values())
    xmin, xmax = min(xs), max(xs)
    ymin, ymax = min(ys), max(ys)
    width = xmax - xmin
    height = ymax - ymin
    # A single node (or a layout flat in one direction) has zero extent; fall back
    # to a unit extent instead of dividing by zero or setting identical limits.
    extent = max(width, height) or 1.0
    x_padding = (width or extent) / 20
    y_padding = (height or extent) / 20

    # gather agent data
    s_default = (180 / extent) ** 2
    arguments = collect_agent_data(space, agent_portrayal, default_size=s_default)

    # For network spaces both loc columns hold the agent's node; map node -> layout position.
    loc = arguments["loc"]
    if len(loc) > 0:
        positions = []
        for node in loc[:, 0]:
            # tuple-labelled nodes come back from numpy as small arrays; make them hashable again
            key = tuple(node) if isinstance(node, np.ndarray) else node
            positions.append(layout[key])
        arguments["loc"] = np.asarray(positions, dtype=float)
    else:
        arguments["loc"] = np.empty((0, 2))

    # further styling
    ax.set_axis_off()
    ax.set_xlim(xmin=xmin - x_padding, xmax=xmax + x_padding)
    ax.set_ylim(ymin=ymin - y_padding, ymax=ymax + y_padding)

    # plot the agents
    _scatter(ax, arguments, **kwargs)

    if draw_grid:
        # fixme we need to draw the empty nodes as well
        edges = nx.draw_networkx_edges(graph, layout, ax=ax, alpha=0.5, style="--")
        # networkx returns a LineCollection for undirected graphs but a list of arrow
        # patches for directed ones (and possibly an empty list/None when there are no
        # edges); push whatever was drawn below the agents.
        if edges is None:
            edge_artists = []
        elif isinstance(edges, list):
            edge_artists = edges
        else:
            edge_artists = [edges]
        for artist in edge_artists:
            artist.set_zorder(0)

    return ax
def draw_continuous_space(
    space: ContinuousSpace, agent_portrayal: Callable, ax: Axes | None = None, **kwargs
):
    """Visualize a continuous space.

    Args:
        space: the space to visualize
        agent_portrayal: a callable that is called with the agent and returns a AgentPortrayalStyle
        ax: a Matplotlib Axes instance. If none is provided a new figure and ax will be created using plt.subplots
        kwargs: additional keyword arguments passed to ax.scatter

    Returns:
        Returns the Axes object with the plot drawn onto it.

    ``agent_portrayal`` is called with an agent and should return a AgentPortrayalStyle. Valid fields in this object are "color",
    "size", "marker", "zorder", alpha, linewidths, and edgecolors. Other fields are ignored and will result in a user warning.
    """
    if ax is None:
        ax = plt.subplots()[1]

    x_min, x_max = space.x_min, space.x_max
    y_min, y_max = space.y_min, space.y_max
    span_x = x_max - x_min
    span_y = y_max - y_min
    # Leave a 5% margin around the space on every side.
    pad_x, pad_y = span_x / 20, span_y / 20

    agent_data = collect_agent_data(
        space, agent_portrayal, default_size=(180 / max(span_x, span_y)) ** 2
    )

    # A dashed border marks a space that wraps around (torus).
    border_style = (0, (5, 10)) if space.torus else "solid"
    for spine in ax.spines.values():
        spine.set_linewidth(1.5)
        spine.set_color("black")
        spine.set_linestyle(border_style)

    ax.set_xlim(x_min - pad_x, x_max + pad_x)
    ax.set_ylim(y_min - pad_y, y_max + pad_y)

    _scatter(ax, agent_data, **kwargs)
    return ax
def draw_voronoi_grid(
    space: VoronoiGrid,
    agent_portrayal: Callable,
    ax: Axes | None = None,
    draw_grid: bool = True,
    **kwargs,
):
    """Visualize a voronoi grid.

    Args:
        space: the space to visualize
        agent_portrayal: a callable that is called with the agent and returns a AgentPortrayalStyle
        ax: a Matplotlib Axes instance. If none is provided a new figure and ax will be created using plt.subplots
        draw_grid: whether to draw the cell outlines or not
        kwargs: additional keyword arguments passed to ax.scatter

    Returns:
        Returns the Axes object with the plot drawn onto it.

    ``agent_portrayal`` is called with an agent and should return a AgentPortrayalStyle. Valid fields in this object are "color",
    "size", "marker", "zorder", alpha, linewidths, and edgecolors. Other fields are ignored and will result in a user warning.
    """
    if ax is None:
        ax = plt.subplots()[1]

    # The view is sized from the bounding box of the cell centroids.
    centroids = space.centroids_coordinates
    xs = [point[0] for point in centroids]
    ys = [point[1] for point in centroids]
    x_min, x_max = min(xs), max(xs)
    y_min, y_max = min(ys), max(ys)
    span_x = x_max - x_min
    span_y = y_max - y_min
    pad_x, pad_y = span_x / 20, span_y / 20

    agent_data = collect_agent_data(
        space, agent_portrayal, default_size=(180 / max(span_x, span_y)) ** 2
    )

    ax.set_xlim(x_min - pad_x, x_max + pad_x)
    ax.set_ylim(y_min - pad_y, y_max + pad_y)

    _scatter(ax, agent_data, **kwargs)

    if draw_grid:
        # Unfilled dotted outline of every cell's polygon.
        cell_outlines = [
            Polygon(cell.properties["polygon"]) for cell in space.all_cells.cells
        ]
        ax.add_collection(
            PatchCollection(
                cell_outlines,
                edgecolor="k",
                facecolor=(1, 1, 1, 0),
                linestyle="dotted",
                lw=1,
            )
        )

    return ax
def _get_zoom_factor(ax, img): ax.get_figure().canvas.draw() bbox = ax.get_window_extent().transformed( ax.get_figure().dpi_scale_trans.inverted() ) # in inches width, height = ( bbox.width * ax.get_figure().dpi, bbox.height * ax.get_figure().dpi, ) # in pixel xr = ax.get_xlim() yr = ax.get_ylim() x_pixel_per_data = width / (xr[1] - xr[0]) y_pixel_per_data = height / (yr[1] - yr[0]) zoom_x = (x_pixel_per_data / img.width) * CORRECTION_FACTOR_MARKER_ZOOM zoom_y = (y_pixel_per_data / img.height) * CORRECTION_FACTOR_MARKER_ZOOM return min(zoom_x, zoom_y) def _scatter(ax: Axes, arguments, **kwargs): """Helper function for plotting the agents. Args: ax: a Matplotlib Axes instance arguments: the agents specific arguments for plotting kwargs: additional keyword arguments for ax.scatter """ loc = arguments.pop("loc") loc_x = loc[:, 0] loc_y = loc[:, 1] marker = arguments.pop("marker") zorder = arguments.pop("zorder") malpha = arguments.pop("alpha") msize = arguments.pop("s") # we check if edgecolor, linewidth, and alpha are specified # at the agent level, if not, we remove them from the arguments dict # and fallback to the default value in ax.scatter / use what is passed via **kwargs for entry in ["edgecolors", "linewidths"]: if len(arguments[entry]) == 0: arguments.pop(entry) else: if entry in kwargs: raise ValueError( f"{entry} is specified in agent portrayal and via plotting kwargs, you can only use one or the other" ) ax.get_figure().canvas.draw() for mark in set(marker): if isinstance(mark, (str | os.PathLike)) and os.path.isfile(mark): # images for m_size in np.unique(msize): image = Image.open(mark) im = OffsetImage( image, zoom=_get_zoom_factor(ax, image) * m_size / DEFAULT_MARKER_SIZE, ) im.image.axes = ax mask_marker = [m == mark for m in list(marker)] & (m_size == msize) for z_order in np.unique(zorder[mask_marker]): for m_alpha in np.unique(malpha[mask_marker]): mask = (z_order == zorder) & (m_alpha == malpha) & mask_marker for x, y in zip(loc_x[mask], 
loc_y[mask]): ab = AnnotationBbox( im, (x, y), frameon=False, pad=0.0, zorder=z_order, **kwargs, ) ax.add_artist(ab) else: # ordinary markers mask_marker = [m == mark for m in list(marker)] for z_order in np.unique(zorder[mask_marker]): zorder_mask = z_order == zorder & mask_marker ax.scatter( loc_x[zorder_mask], loc_y[zorder_mask], marker=mark, zorder=z_order, **{k: v[zorder_mask] for k, v in arguments.items()}, **kwargs, )