Source code for ecoli.analysis.colony.snapshots
import os
import math
import random
import itertools
import shutil
import concurrent.futures
from typing import Any
import cv2
from duckdb import DuckDBPyConnection
import numpy as np
import matplotlib
matplotlib.use("agg")
import matplotlib.pyplot as plt
from tqdm import tqdm
import matplotlib.patches as patches
from matplotlib.lines import Line2D
from matplotlib.colors import hsv_to_rgb, rgb_to_hsv
from mpl_toolkits.axes_grid1 import make_axes_locatable, anchored_artists
from vivarium.library.dict_utils import get_value_from_path
from vivarium.library.units import Quantity, units
from vivarium.plots.agents_multigen import plot_agents_multigen
DEFAULT_BOUNDS = [10, 10]
DEFAULT_HIGHLIGHT_COLOR = [0, 1, 1]
PLOT_WIDTH = 7
# constants
PI = math.pi
# colors for phylogeny initial agents
HUES = [hue / 360 for hue in np.linspace(0, 360, 30)]
DEFAULT_HUE = 45 / 360
DEFAULT_SV = [100.0 / 100.0, 70.0 / 100.0]
BASELINE_TAG_COLOR = [0, 0, 1] # HSV
FLOURESCENT_SV = [0.75, 1.0] # SV for fluorescent colors
[docs]
class LineWidthData(Line2D):
def __init__(self, *args, **kwargs):
_lw_data = kwargs.pop("linewidth", 1)
super().__init__(*args, **kwargs)
self._lw_data = _lw_data
[docs]
def _get_lw(self):
if self.axes is not None:
ppd = 72.0 / self.axes.figure.dpi
trans = self.axes.transData.transform
return ((trans((1, self._lw_data)) - trans((0, 0))) * ppd)[1]
else:
return 1
_linewidth = property(_get_lw, _set_lw)
[docs]
def init_axes(
fig,
edge_length_x,
edge_length_y,
grid,
row_idx,
col_idx,
time,
molecule,
ylabel_size=20,
title_size=12,
):
ax = fig.add_subplot(grid[row_idx, col_idx])
if row_idx == 0:
plot_title = "time: {:.4f} s".format(float(time))
plt.title(plot_title, y=1.08, fontsize=title_size)
if col_idx == 0:
ax.set_ylabel(
molecule,
fontsize=ylabel_size,
rotation="horizontal",
horizontalalignment="right",
)
ax.set(xlim=[0, edge_length_x], ylim=[0, edge_length_y], aspect=1)
ax.set_yticklabels([])
ax.set_xticklabels([])
return ax
[docs]
def add_time_axis(
fig, grid, n_rows, n_cols, n_snapshots, snapshot_times, time_unit="s"
):
# Add time axis across subplots
super_spec = matplotlib.gridspec.SubplotSpec(
grid,
(n_rows - 1) * n_cols,
(n_rows - 1) * n_cols + n_snapshots - 1,
)
grid_params = grid.get_subplot_params()
if n_snapshots > 1:
time_per_snapshot = (snapshot_times[-1] - snapshot_times[0]) / (
(n_snapshots - 1) * (grid_params.wspace + 1)
)
else:
time_per_snapshot = 1 # Arbitrary
super_ax = fig.add_subplot( # type: ignore
super_spec,
xticks=snapshot_times,
xlim=(
snapshot_times[0] - time_per_snapshot / 2,
snapshot_times[-1] + time_per_snapshot / 2,
),
yticks=[],
)
super_ax.set_xlabel(f"Time ({time_unit})", labelpad=50) # type: ignore
super_ax.xaxis.set_tick_params(width=2, length=8)
for spine_name in ("top", "right", "left"):
super_ax.spines[spine_name].set_visible(False)
super_ax.spines["bottom"].set_linewidth(2)
[docs]
def mutate_color(baseline_hsv):
mutation = 0.1
new_hsv = [(n + np.random.uniform(-mutation, mutation)) for n in baseline_hsv]
# wrap hue around
new_hsv[0] = new_hsv[0] % 1
# reflect saturation and value
if new_hsv[1] > 1:
new_hsv[1] = 2 - new_hsv[1]
if new_hsv[2] > 1:
new_hsv[2] = 2 - new_hsv[2]
return new_hsv
[docs]
def plot_agent(
ax,
data,
color,
agent_shape,
membrane_width=0.1,
membrane_color=[0, 0, 0],
alpha=1,
):
"""Plot an agent
Args:
ax: The axes to draw on.
data (dict): The agent data dictionary.
color (list): HSV color of agent body.
agent_shape (str): One of ``rectangle``, ``segment``, and ``circle``.
membrane_width (float): Width of drawn agent boundary.
membrane_color (list): RGB color of drawn agent boundary.
"""
if not data or not data.get("boundary"):
return
x_center = data["boundary"]["location"][0]
y_center = data["boundary"]["location"][1]
# get color, convert to rgb. Strings are already RGB
if isinstance(color, str):
rgb = color
else:
rgb = hsv_to_rgb(color)
if agent_shape == "rectangle":
theta = (
data["boundary"]["angle"] / PI * 180 + 90
) # rotate 90 degrees to match field
length = data["boundary"]["length"]
width = data["boundary"]["width"]
# get bottom left position
x_offset = width / 2
y_offset = length / 2
theta_rad = math.radians(theta)
dx = x_offset * math.cos(theta_rad) - y_offset * math.sin(theta_rad)
dy = x_offset * math.sin(theta_rad) + y_offset * math.cos(theta_rad)
x = x_center - dx
y = y_center - dy
# Create a rectangle
shape = patches.Rectangle(
(x, y),
width,
length,
angle=theta,
linewidth=membrane_width,
edgecolor=membrane_color,
alpha=alpha,
facecolor=rgb,
)
ax.add_patch(shape)
elif agent_shape == "segment":
theta = (
data["boundary"]["angle"] / PI * 180 + 90
) # rotate 90 degrees to match field
length = data["boundary"]["length"]
width = data["boundary"]["width"]
radius = width / 2
# get the two ends
length_offset = (length / 2) - radius
theta_rad = math.radians(theta)
dx = -length_offset * math.sin(theta_rad)
dy = length_offset * math.cos(theta_rad)
x1 = x_center - dx
y1 = y_center - dy
x2 = x_center + dx
y2 = y_center + dy
# segment plot
membrane = LineWidthData(
[x1, x2],
[y1, y2],
color=membrane_color,
linewidth=width,
alpha=alpha,
solid_capstyle="round",
)
line = LineWidthData(
[x1, x2],
[y1, y2],
color=rgb,
alpha=alpha,
linewidth=width - membrane_width,
solid_capstyle="round",
)
ax.add_line(membrane)
ax.add_line(line)
elif agent_shape == "circle":
diameter = data["boundary"]["diameter"]
# get bottom left position
radius = diameter / 2
x = x_center - radius
y = y_center - radius
# Create a circle
circle = patches.Circle(
(x, y),
radius,
linewidth=membrane_width,
edgecolor=membrane_color,
alpha=alpha,
)
ax.add_patch(circle)
[docs]
def plot_agents(
ax,
agents,
agent_colors=None,
agent_shape="segment",
dead_color=None,
membrane_width=0.1,
membrane_color=[1, 1, 1],
alpha=1,
highlight_agent=None,
):
"""Plot agents.
Args:
ax: the axis for plot
agents (dict): a mapping from agent ID to that agent's data,
which should have keys ``location``, ``angle``, ``length``,
and ``width``.
agent_colors (dict): Mapping from agent ID to HSV color.
dead_color (list): List of 3 floats that define HSV color to use
for dead cells. Dead cells only get treated differently if
this is set.
membrane_width (float): Width of agent outline to draw.
membrane_color (list): List of 3 floats that define the RGB
color to use for agent outlines.
alpha: Alpha value for agents.
highlight_agent: Mapping of agent ID to `membrane_color` and
`membrane_width`. Useful for highlighting specific agents,
with rest using default `membrane_width` and `membrane_color`
"""
if not agent_colors:
agent_colors = dict()
if not highlight_agent:
highlight_agent = []
for agent_id, agent_data in agents.items():
color = agent_colors.get(agent_id, [DEFAULT_HUE] + DEFAULT_SV)
if dead_color and agent_data.get("boundary"):
if agent_data["boundary"].get("dead"):
color = dead_color
if agent_data:
if agent_id in highlight_agent:
agent_membrane_width = highlight_agent[agent_id]["membrane_width"]
agent_membrane_color = highlight_agent[agent_id]["membrane_color"]
else:
agent_membrane_width = membrane_width
agent_membrane_color = membrane_color
plot_agent(
ax,
agent_data,
color,
agent_shape,
agent_membrane_width,
agent_membrane_color,
alpha,
)
if len(agents) == 1:
ax.set_title("1 agent", y=1.1)
else:
ax.set_title(f"{len(agents)} agents", y=1.1)
[docs]
def color_phylogeny(ancestor_id, phylogeny, baseline_hsv, phylogeny_colors=None):
"""
get colors for all descendants of the ancestor
through recursive calls to each generation
"""
if not phylogeny_colors:
phylogeny_colors = {}
phylogeny_colors.update({ancestor_id: baseline_hsv})
daughter_ids = phylogeny.get(ancestor_id)
if daughter_ids:
for daughter_id in daughter_ids:
daughter_color = mutate_color(baseline_hsv)
color_phylogeny(daughter_id, phylogeny, daughter_color, phylogeny_colors)
return phylogeny_colors
[docs]
def get_phylogeny_colors_from_names(agent_ids):
"""Get agent colors using phlogeny saved in agent_ids
This assumes the names use daughter_phylogeny_id() from meta_division
"""
# make phylogeny with {mother_id: [daughter_1_id, daughter_2_id]}
phylogeny = {agent_id: [] for agent_id in agent_ids}
for agent1, agent2 in itertools.combinations(agent_ids, 2):
if agent1 == agent2[0:-1]:
phylogeny[agent1].append(agent2)
elif agent2 == agent1[0:-1]:
phylogeny[agent2].append(agent1)
# get initial ancestors
daughters = list(phylogeny.values())
daughters = set([item for sublist in daughters for item in sublist])
mothers = set(list(phylogeny.keys()))
ancestors = list(mothers - daughters)
# agent colors based on phylogeny
agent_colors = {agent_id: [] for agent_id in agent_ids}
for agent_id in ancestors:
hue = random.choice(HUES) # select random initial hue
initial_color = [hue] + DEFAULT_SV
agent_colors.update(color_phylogeny(agent_id, phylogeny, initial_color))
return agent_colors
[docs]
def format_snapshot_data(data):
agents = {}
fields = {}
for time, time_data in data.items():
agents[time] = time_data.get("agents", {})
fields[time] = time_data.get("fields", {})
return agents, fields
[docs]
def get_field_range(
fields,
time_vec,
include_fields=None,
skip_fields=None,
min_pct=0.95,
max_pct=1,
):
if not skip_fields:
skip_fields = []
field_range = {}
if fields:
if include_fields is None:
field_ids = set(fields[time_vec[0]].keys())
else:
field_ids = set(include_fields)
field_ids -= set(skip_fields)
for field_id in field_ids:
field_min = min(
[
min(min(field_data[field_id])) * min_pct
for t, field_data in fields.items()
]
)
field_max = max(
[
max(max(field_data[field_id])) * max_pct
for t, field_data in fields.items()
]
)
field_range[field_id] = [field_min, field_max]
return field_range
[docs]
def get_agent_ids(agents):
agent_ids = set()
for time, time_data in agents.items():
current_agents = list(time_data.keys())
agent_ids.update(current_agents)
return list(agent_ids)
[docs]
def get_agent_colors(
agents,
phylogeny_names=True,
agent_fill_color=None,
):
agent_ids = get_agent_ids(agents)
agent_colors = {}
if agents:
# set agent colors
if agent_fill_color:
agent_colors = {agent_id: agent_fill_color for agent_id in agent_ids}
elif phylogeny_names:
agent_colors = get_phylogeny_colors_from_names(agent_ids)
else:
agent_colors = {}
for agent_id in agent_ids:
hue = random.choice(HUES)
color = [hue] + DEFAULT_SV
agent_colors[agent_id] = color
return agent_colors
[docs]
def plot_snapshots(
bounds,
agents=None,
fields=None,
n_snapshots=5,
snapshot_times=None,
agent_fill_color=None,
agent_colors=None,
phylogeny_names=True,
skip_fields=None,
include_fields=None,
out_dir=None,
filename="snapshots",
min_pct=0.95,
max_pct=1,
**kwargs,
):
"""Plot snapshots of the simulation over time
The snapshots depict the agents and environmental molecule
concentrations.
Arguments:
data (dict): A dictionary with the following keys:
* **bounds** (:py:class:`tuple`): The dimensions of the
environment.
* **agents** (:py:class:`dict`): A mapping from times to
dictionaries of agent data at that timepoint. Agent data
dictionaries should have the same form as the hierarchy
tree rooted at ``agents``.
* **fields** (:py:class:`dict`): A mapping from times to
dictionaries of environmental field data at that
timepoint. Field data dictionaries should have the same
form as the hierarchy tree rooted at ``fields``.
* **n_snapshots** (:py:class:`int`): Number of snapshots to
show per row (i.e. for each molecule). Defaults to 6.
* **phylogeny_names** (:py:class:`bool`): This selects agent
colors based on phylogenies seved in their names using
meta_division.py daughter_phylogeny_id()
* **skip_fields** (:py:class:`Iterable`): Keys of fields to
exclude from the plot. This takes priority over
``include_fields``.
* **include_fields** (:py:class:`Iterable`): Keys of fields
to plot.
* **snapshot_times** (:py:class:`Iterable`): Times to plot
snapshots for. Defaults to None, in which case n_snapshots
is used.
* **out_dir** (:py:class:`str`): Output directory, which is
``out`` by default.
* **filename** (:py:class:`str`): Base name of output file.
``snapshots`` by default.
* **min_pct** (:py:class:`float`): Percent of minimum value
to use as minimum in colorbar (1 = 100%)
* **max_pct** (:py:class:`float`): Percent of maximum value
to use as maximum in colorbar (1 = 100%)
"""
if not agents:
agents = {}
if not fields:
fields = {}
if not skip_fields:
skip_fields = []
# Strip units from bounds if present.
if isinstance(bounds[0], Quantity):
bounds = tuple(bound.to(units.um).magnitude for bound in bounds)
# time steps that will be used
if agents and fields:
assert set(list(agents.keys())) == set(
list(fields.keys())
), "agent and field times are different"
time_vec = list(agents.keys())
elif agents:
time_vec = list(agents.keys())
elif fields:
time_vec = list(fields.keys())
agents = {t: {} for t in time_vec}
else:
raise Exception("No agents or field data")
# get fields id and range
field_range = get_field_range(
fields, time_vec, include_fields, skip_fields, min_pct, max_pct
)
# get agent ids
if not agent_colors:
agent_colors = get_agent_colors(agents, phylogeny_names, agent_fill_color)
# get time data
if snapshot_times:
n_snapshots = len(snapshot_times)
time_indices = [time_vec.index(time) for time in snapshot_times]
else:
time_indices = np.round(np.linspace(0, len(time_vec) - 1, n_snapshots)).astype(
int
)
snapshot_times = [time_vec[i] for i in time_indices]
return make_snapshots_figure(
agents=agents,
agent_colors=agent_colors,
fields=fields,
field_range=field_range,
n_snapshots=n_snapshots,
time_indices=time_indices,
snapshot_times=snapshot_times,
bounds=bounds,
out_dir=out_dir,
filename=filename,
**kwargs,
)
[docs]
def make_snapshots_figure(
agents,
fields,
bounds,
n_snapshots,
time_indices,
snapshot_times,
time_unit="s",
plot_width=12,
field_range=None,
agent_colors=None,
dead_color=[0, 0, 0],
membrane_width=0.1,
membrane_color=[1, 1, 1],
default_font_size=36,
field_label_size=32,
agent_shape="segment",
agent_alpha=1,
colorbar_decimals=3,
show_timeline=True,
scale_bar_length=1,
scale_bar_color="black",
xlim=None,
ylim=None,
min_color="white",
max_color="gray",
out_dir=None,
filename="snapshots",
figsize=None,
):
"""
Args:
bounds (:py:class:`tuple`): The dimensions of the
environment.
field_label_size (:py:class:`float`): Font size of the
field label.
dead_color (:py:class:`list` of 3 :py:class:`float`):
Color for dead cells in HSV. Defaults to [0, 0, 0], which
is black.
default_font_size (:py:class:`float`): Font size for
titles and axis labels.
agent_shape (:py:class:`str`): the shape of the agents.
select from ``rectangle``, ``segment``
agent_alpha (:py:class:`float`): Alpha for agent
plots.
colorbar_decimals (:py:class:`int`): number of decimals in
colorbar.
scale_bar_length (:py:class:`float`): Length of scale
bar. Defaults to 1 (in units of micrometers). If 0, no
bar plotted.
scale_bar_color (:py:class:`str`): Color of scale bar
xlim (:py:class:`tuple` of :py:class:`float`): Tuple
of lower and upper x-axis limits.
ylim (:py:class:`tuple` of :py:class:`float`): Tuple
of lower and upper y-axis limits.
min_color (any valid matplotlib color): Color for
minimum field values.
max_color (any valid matplotlib color): Color for
maximum field values.
out_dir (:py:class:`str`): Output directory, which is
``out`` by default.
filename (:py:class:`str`): Base name of output file.
``snapshots`` by default.
figsize (:py:class:`tuple`): Dimensions of figure in
inches (takes precedence over `plot_width`)
"""
edge_length_x = bounds[0]
edge_length_y = bounds[1]
# make the figure
field_ids = list(field_range.keys())
n_rows = max(len(field_ids), 1)
n_cols = n_snapshots + 1 # one column for the colorbar
if not figsize:
figsize = (plot_width * n_cols, plot_width * n_rows)
max_dpi = min([2**16 // dim for dim in figsize]) - 1
fig = plt.figure(figsize=figsize, dpi=min(max_dpi, 100))
grid = plt.GridSpec(n_rows, n_cols, wspace=0.2, hspace=0.2)
original_fontsize = plt.rcParams["font.size"]
plt.rcParams.update({"font.size": default_font_size})
# Add time axis across subplots
if show_timeline:
add_time_axis(fig, grid, n_rows, n_cols, n_snapshots, snapshot_times, time_unit)
# Make the colormap
min_rgb = matplotlib.colors.to_rgb(min_color)
max_rgb = matplotlib.colors.to_rgb(max_color)
colors_dict = {
"red": [
[0, min_rgb[0], min_rgb[0]],
[1, max_rgb[0], max_rgb[0]],
],
"green": [
[0, min_rgb[1], min_rgb[1]],
[1, max_rgb[1], max_rgb[1]],
],
"blue": [
[0, min_rgb[2], min_rgb[2]],
[1, max_rgb[2], max_rgb[2]],
],
}
cmap = matplotlib.colors.LinearSegmentedColormap(
"field", segmentdata=colors_dict, N=512
)
stats = {
"agents": {},
}
if field_ids:
stats["fields"] = {field_id: {} for field_id in field_ids}
# plot snapshot data in each subsequent column
for col_idx, (time_idx, time) in enumerate(zip(time_indices, snapshot_times)):
stats["agents"][time] = len(agents[time])
if field_ids:
for row_idx, field_id in enumerate(field_ids):
ax = init_axes(
fig,
edge_length_x,
edge_length_y,
grid,
row_idx,
col_idx,
time,
field_id,
field_label_size,
)
ax.tick_params(
axis="both",
which="both",
bottom=False,
top=False,
left=False,
right=False,
)
# transpose field to align with agents
field = np.transpose(np.array(fields[time][field_id]))
vmin, vmax = field_range[field_id]
q1, q2, q3 = np.percentile(field, [25, 50, 75])
stats["fields"][field_id][time] = (field.min(), q1, q2, q3, field.max())
im = plt.imshow(
field.tolist(),
origin="lower",
extent=[0, edge_length_x, 0, edge_length_y],
vmin=vmin,
vmax=vmax,
cmap=cmap,
)
if agents:
agents_now = agents[time]
plot_agents(
ax,
agents_now,
agent_colors,
agent_shape=agent_shape,
dead_color=dead_color,
membrane_width=membrane_width,
membrane_color=membrane_color,
alpha=agent_alpha,
)
if xlim:
ax.set_xlim(*xlim)
if ylim:
ax.set_ylim(*ylim)
# colorbar in new column after final snapshot
if col_idx == n_snapshots - 1:
cbar_col = col_idx + 1
ax = fig.add_subplot(grid[row_idx, cbar_col])
if row_idx == 0:
ax.set_title("Concentration\n(mM)", y=1.08)
ax.axis("off")
if vmin == vmax:
continue
divider = make_axes_locatable(ax)
cax = divider.append_axes("left", size="5%", pad=0.0)
fig.colorbar(im, cax=cax, format=f"%.{colorbar_decimals}f")
ax.axis("off")
# Scale bar in first snapshot of each row
if col_idx == 0 and scale_bar_length:
scale_bar = anchored_artists.AnchoredSizeBar(
ax.transData,
scale_bar_length,
f"{scale_bar_length} μm",
"lower left",
color=scale_bar_color,
frameon=False,
sep=scale_bar_length,
size_vertical=scale_bar_length / 20,
)
ax.add_artist(scale_bar)
else:
row_idx = 0
ax = init_axes(fig, bounds[0], bounds[1], grid, row_idx, col_idx, time, "")
if agents:
agents_now = agents[time]
plot_agents(
ax,
agents_now,
agent_colors,
agent_shape=agent_shape,
dead_color=dead_color,
membrane_width=membrane_width,
membrane_color=membrane_color,
alpha=agent_alpha,
)
if xlim:
ax.set_xlim(*xlim)
if ylim:
ax.set_ylim(*ylim)
# Scale bar in first snapshot of each row
if col_idx == 0 and scale_bar_length:
scale_bar = anchored_artists.AnchoredSizeBar(
ax.transData,
scale_bar_length,
f"{scale_bar_length} μm",
"lower left",
color=scale_bar_color,
frameon=False,
sep=scale_bar_length,
size_vertical=scale_bar_length / 20,
)
ax.add_artist(scale_bar)
plt.rcParams.update({"font.size": original_fontsize})
if out_dir:
fig_path = os.path.join(out_dir, filename)
fig.subplots_adjust(wspace=0.7, hspace=0.1)
fig.savefig(fig_path, bbox_inches="tight")
return fig
[docs]
def plot_tags(
data,
bounds,
snapshot_times=None,
n_snapshots=5,
**kwargs,
):
agents, fields = format_snapshot_data(data)
time_vec = list(agents.keys())
# get time data
if snapshot_times:
n_snapshots = len(snapshot_times)
time_indices = [time_vec.index(time) for time in snapshot_times]
else:
time_indices = np.round(np.linspace(0, len(time_vec) - 1, n_snapshots)).astype(
int
)
snapshot_times = [time_vec[i] for i in time_indices]
return make_tags_figure(
agents=agents,
bounds=bounds,
n_snapshots=n_snapshots,
time_indices=time_indices,
snapshot_times=snapshot_times,
**kwargs,
)
[docs]
def get_tag_ranges(
agents, tagged_molecules, time_indices, convert_to_concs, tag_colors
):
# get tag ids and range
tag_ranges = {}
for time_idx, (time, time_data) in enumerate(agents.items()):
if time_idx in time_indices:
for agent_id, agent_data in time_data.items():
volume = agent_data.get("boundary", {}).get("volume", 0)
for tag_id in tagged_molecules:
level = get_value_from_path(agent_data, tag_id)
if level is None:
continue
if convert_to_concs:
level = level / volume if volume else 0
if tag_id in tag_ranges:
tag_ranges[tag_id] = [
min(tag_ranges[tag_id][0], level),
max(tag_ranges[tag_id][1], level),
]
else:
# add new tag
tag_ranges[tag_id] = [level, level]
# select random initial hue
if tag_id not in tag_colors:
hue = random.choice(HUES)
tag_color = [hue] + FLOURESCENT_SV
tag_colors[tag_id] = tag_color
return tag_ranges, tag_colors
[docs]
def make_tags_figure(
agents,
bounds,
time_indices,
snapshot_times,
tag_ranges=None,
tag_colors=None,
min_color="black",
agent_colors=None,
n_snapshots=6,
scale_bar_length=1,
scale_bar_color="black",
show_timeline=True,
show_colorbar=True,
time_unit="s",
tagged_molecules=None,
out_dir=False,
filename="tags",
agent_shape="segment",
background_color="black",
colorbar_decimals=1,
tag_path_name_map=None,
tag_label_size=20,
plot_width=12,
default_font_size=36,
convert_to_concs=True,
membrane_width=0.1,
membrane_color=None,
xlim=None,
ylim=None,
figsize=None,
highlight_agent=None,
):
"""Plot snapshots of the simulation over time
The snapshots depict the agents and the levels of tagged molecules
in each agent by agent color intensity.
Arguments:
data (dict): A dictionary with the following keys:
* **agents** (:py:class:`dict`): A mapping from times to
dictionaries of agent data at that timepoint. Agent data
dictionaries should have the same form as the hierarchy
tree rooted at ``agents``.
* **n_snapshots** (:py:class:`int`): Number of snapshots to
show per row (i.e. for each molecule). Defaults to 6.
* **out_dir** (:py:class:`str`): Output directory, which is
``out`` by default.
* **filename** (:py:class:`str`): Base name of output file.
``tags`` by default.
* **tagged_molecules** (:py:class:`typing.Iterable`): The
tagged molecules whose concentrations will be indicated by
agent color. Each molecule should be specified as a
:py:class:`tuple` of the path in the agent compartment
to where the molecule's count can be found, with the last
value being the molecule's count variable.
* **convert_to_concs** (:py:class:`bool`): if True, convert counts
to concentrations.
* **background_color** (:py:class:`str`): use matplotlib colors,
``black`` by default
* **colorbar_decimals** (:py:class:`int`): number of decimals in
colorbar.
* **tag_label_size** (:py:class:`float`): The font size for
the tag name label
* **default_font_size** (:py:class:`float`): Font size for
titles and axis labels.
* **membrane_width** (:py:class:`float`): Width to use for
drawing agent edges.
* **membrane_color** (:py:class:`list`): RGB color to use
for drawing agent edges.
* **tag_colors** (:py:class:`dict`): Mapping from tag ID to
the HSV color to use for that tag as a list. Alternatively,
each tag ID is mapped to a dictionary containing the `cmp`
and `norm` keys with the :py:class:`matplotlib.colors.Colormap`
and the :py:class:`matplotlib.colors.Normalize` instance to use
for that tag. Using dictionaries will override ``min_color``
* **figsize** (:py:class:`tuple`): Dimensions of figure in
inches (takes precedence over `plot_width`)
* **highlight_agent** (:py:class:`dict`): Mapping of agent IDs to
`membrane_color` and `membrane_width`. Useful for highlighting
specific agents, with rest using default color / width
"""
membrane_color = membrane_color or [1, 1, 1]
agent_colors = agent_colors or {}
tag_colors = tag_colors or {}
tag_path_name_map = tag_path_name_map or {}
tagged_molecules = tagged_molecules or []
if tagged_molecules == []:
raise ValueError("At least one molecule must be tagged.")
if not tag_ranges:
tag_ranges, tag_colors = get_tag_ranges(
agents, tagged_molecules, time_indices, convert_to_concs, tag_colors
)
# get data
edge_length_x, edge_length_y = bounds
# make the figure
n_rows = len(tagged_molecules)
n_cols = n_snapshots + 1 # one column for the colorbar
if not figsize:
figsize = (plot_width * n_cols, plot_width * n_rows)
max_dpi = min([2**16 // dim for dim in figsize]) - 1
fig = plt.figure(figsize=figsize, dpi=min(max_dpi, 100))
grid = plt.GridSpec(n_rows, n_cols, wspace=0.2, hspace=0.2)
original_fontsize = plt.rcParams["font.size"]
plt.rcParams.update({"font.size": default_font_size})
# Add time axis across subplots
if show_timeline:
add_time_axis(fig, grid, n_rows, n_cols, n_snapshots, snapshot_times, time_unit)
# plot tags
for row_idx, tag_id in enumerate(tag_ranges.keys()):
for col_idx, (time_idx, time) in enumerate(zip(time_indices, snapshot_times)):
tag_name = tag_path_name_map.get(tag_id, tag_id)
ax = init_axes(
fig,
edge_length_x,
edge_length_y,
grid,
row_idx,
col_idx,
time,
tag_name,
tag_label_size,
title_size=default_font_size,
)
ax.tick_params(
axis="both",
which="both",
bottom=False,
top=False,
left=False,
right=False,
)
ax.set_facecolor(background_color)
# update agent colors based on tag_level
min_tag, max_tag = tag_ranges[tag_id]
agent_tag_colors = {}
if isinstance(tag_colors[tag_id], dict):
cmap = tag_colors[tag_id]["cmp"]
norm = tag_colors[tag_id]["norm"]
mappable = matplotlib.cm.ScalarMappable(norm=norm, cmap=cmap)
else:
tag_h, tag_s, tag_v = tag_colors[tag_id]
tag_color_rgb = np.array(
matplotlib.colors.hsv_to_rgb([tag_h, tag_s, tag_v])
)
min_color = np.array(matplotlib.colors.to_rgb(min_color))
cmap = matplotlib.colors.LinearSegmentedColormap.from_list(
name=f"row_{row_idx}", colors=[min_color, tag_color_rgb]
)
norm = matplotlib.colors.Normalize(vmin=min_tag, vmax=max_tag)
mappable = matplotlib.cm.ScalarMappable(norm=norm, cmap=cmap)
for agent_id, agent_data in agents[time].items():
# get current tag concentration, and determine color
level = get_value_from_path(agent_data, tag_id)
if convert_to_concs:
volume = agent_data.get("boundary", {}).get("volume", 0)
level = level / volume if volume else 0
agent_tag_colors[agent_id] = rgb_to_hsv(mappable.to_rgba(level)[:3])
agent_tag_colors.update(agent_colors)
plot_agents(
ax,
agents[time],
agent_tag_colors,
agent_shape,
None,
membrane_width,
membrane_color,
highlight_agent=highlight_agent,
)
if xlim:
ax.set_xlim(*xlim)
if ylim:
ax.set_ylim(*ylim)
# Scale bar in first snapshot of each row
if col_idx == 0 and scale_bar_length:
scale_bar = anchored_artists.AnchoredSizeBar(
ax.transData,
scale_bar_length,
f"{scale_bar_length} μm",
"lower left",
color=scale_bar_color,
frameon=False,
sep=scale_bar_length,
size_vertical=scale_bar_length / 20,
)
ax.add_artist(scale_bar)
# colorbar in new column after final snapshot
if col_idx == n_snapshots - 1 and show_colorbar:
cbar_col = col_idx + 1
ax = fig.add_subplot(grid[row_idx, cbar_col])
if row_idx == 0:
if convert_to_concs:
ax.set_title("Concentration\n(counts/fL)", y=1.08)
ax.axis("off")
if min_tag == max_tag:
continue
divider = make_axes_locatable(ax)
cax = divider.append_axes("left", size="5%", pad=0.0)
fig.colorbar(mappable, cax=cax, format=f"%.{colorbar_decimals}f")
plt.rcParams.update({"font.size": original_fontsize})
if out_dir:
fig_path = os.path.join(out_dir, filename)
fig.subplots_adjust(wspace=0.7, hspace=0.1)
fig.savefig(fig_path, bbox_inches="tight")
return fig
[docs]
def save_snapshot_figure(data_at_time, kwargs):
time = data_at_time[0]
fig_path = os.path.join(kwargs.pop("images_dir", ""), f"img{time}.jpg")
multibody_agents, multibody_fields = format_snapshot_data({time: data_at_time[1]})
fig = make_snapshots_figure(
agents=multibody_agents,
fields=multibody_fields,
n_snapshots=1,
time_indices=[time],
snapshot_times=[time],
plot_width=PLOT_WIDTH,
scale_bar_length=0,
**kwargs,
)
fig.savefig(fig_path, bbox_inches="tight")
plt.close()
return fig_path
[docs]
def save_tags_figure(data_at_time, kwargs):
time = data_at_time[0]
fig_path = os.path.join(kwargs.pop("images_dir", ""), f"img{time}.jpg")
agents = {time: data_at_time[1].get("agents", {})}
fig = make_tags_figure(
time_indices=[time],
snapshot_times=[time],
agents=agents,
n_snapshots=1,
plot_width=PLOT_WIDTH,
**kwargs,
)
fig.savefig(fig_path, bbox_inches="tight", dpi=300)
plt.close()
return fig_path
[docs]
def save_timeseries_figure(t_index, kwargs):
time_vec = list(kwargs["data"].keys())
plot_settings = {
"column_width": 6,
"row_height": 2,
"stack_column": True,
"tick_label_size": 10,
"linewidth": 2,
"title_size": 10,
}
if kwargs["show_timeseries"]:
plot_settings.update({"include_paths": kwargs["show_timeseries"]})
# remove agents not included in highlight_agents
if kwargs["highlight_agents"]:
for time, state in kwargs["data"].items():
agents = state["agents"]
for agent_id, agent_state in agents.items():
if agent_id not in kwargs["highlight_agents"]:
del kwargs["data"][time]["agents"][agent_id]
agent_colors = {
agent_id: kwargs["highlight_color"]
for agent_id in kwargs["highlight_agents"]
}
plot_settings.update({"agent_colors": agent_colors})
fig_path = os.path.join(kwargs["images_dir"], f"timeseries{t_index}.jpg")
time_indices = np.array(range(0, t_index + 1))
current_data = {
time_vec[index]: kwargs["data"][time_vec[index]] for index in time_indices
}
fig = plot_agents_multigen(current_data, dict(plot_settings))
fig.savefig(fig_path, bbox_inches="tight")
plt.close()
return fig_path
[docs]
def video_from_images(img_paths, out_file):
# make the video
img_array = []
size = None
for img_file in img_paths:
img = cv2.imread(img_file)
height, width, layers = img.shape
if size:
if width < size[0]:
size[0] = width
if height < size[0]:
size[1] = height
else:
size = [width, height]
img_array.append(img)
out = cv2.VideoWriter(out_file, cv2.VideoWriter_fourcc(*"mp4v"), 15, size)
for i in range(len(img_array)):
# Crop all images to smallest size to avoid frame skips
img_array[i] = img_array[i][0 : size[1], 0 : size[0]]
out.write(img_array[i])
out.release()
[docs]
def make_video(
data,
bounds,
plot_type="fields",
step=1,
highlight_agents=None,
show_timeseries=None,
highlight_color=DEFAULT_HIGHLIGHT_COLOR,
out_dir="out",
filename="snapshot_vid",
cpus=1,
**kwargs,
):
"""Make a video with snapshots across time
Args:
plot_type: (str) select either 'fields' or 'tags'. 'fields' is the default
"""
# Remove last timestep since data may be empty
data = dict(list(data.items())[:-1])
highlight_agents = highlight_agents or []
show_timeseries = show_timeseries or []
# Strip units from bounds if present.
if isinstance(bounds[0], Quantity):
bounds = tuple(bound.to(units.um).magnitude for bound in bounds)
# make images directory, remove if existing
out_file = os.path.join(out_dir, f"{filename}.mp4")
out_file2 = os.path.join(out_dir, f"{filename}_timeseries.mp4")
images_dir = os.path.join(out_dir, f"_images_{plot_type}")
if os.path.isdir(images_dir):
shutil.rmtree(images_dir)
os.makedirs(images_dir)
agent_colors = {}
if highlight_agents:
agent_colors = {agent_id: highlight_color for agent_id in highlight_agents}
# get the single snapshots function
multibody_agents, multibody_fields = format_snapshot_data(data)
time_vec = list(multibody_agents.keys())
if plot_type == "fields":
multibody_field_range = get_field_range(multibody_fields, time_vec)
multibody_agent_colors = get_agent_colors(multibody_agents)
multibody_agent_colors.update(agent_colors)
do_plot = save_snapshot_figure
plot_kwargs = {
"multibody_agent_colors": multibody_agent_colors,
"multibody_field_range": multibody_field_range,
"images_dir": images_dir,
"bounds": bounds,
**kwargs,
}
elif plot_type == "tags":
time_indices = np.array(range(0, len(time_vec)))
tag_ranges, tag_colors = get_tag_ranges(
agents=multibody_agents,
tagged_molecules=kwargs.get("tagged_molecules", None),
time_indices=time_indices,
convert_to_concs=kwargs.get("convert_to_concs", False),
tag_colors=kwargs.pop("tag_colors", {}),
)
do_plot = save_tags_figure
plot_kwargs = {
"tag_ranges": tag_ranges,
"tag_colors": tag_colors,
"images_dir": images_dir,
"agent_colors": agent_colors,
"bounds": bounds,
**kwargs,
}
# Only plot data for every `step` timepoints
if step != 1:
filtered_data = {}
time_counter = 0
for timepoint in time_vec:
if time_counter % step == 0:
filtered_data[timepoint] = data[timepoint]
time_counter += 1
data = filtered_data
with concurrent.futures.ProcessPoolExecutor(cpus) as executor:
img_paths = list(
tqdm(
executor.map(do_plot, data.items(), itertools.repeat(plot_kwargs)),
total=len(data),
)
)
img_paths_2 = []
if show_timeseries:
plot_kwargs = {
"show_timeseries": show_timeseries,
"highlight_agents": highlight_agents,
"highlight_color": highlight_color,
"images_dir": images_dir,
"data": data,
**kwargs,
}
time_indices = list(range(0, len(time_vec)))
with concurrent.futures.ProcessPoolExecutor() as executor:
img_paths_2 = list(
tqdm(
executor.map(
save_timeseries_figure,
time_indices,
itertools.repeat(plot_kwargs),
),
total=len(time_indices),
)
)
# make the video
video_from_images(img_paths, out_file)
video_from_images(img_paths_2, out_file2)
# delete image folder
shutil.rmtree(images_dir)
[docs]
def plot(
params: dict[str, Any],
conn: DuckDBPyConnection,
history_sql: str,
config_sql: str,
sim_data_paths: dict[str, dict[int, str]],
validation_data_paths: list[str],
outdir: str,
variant_metadata: dict[str, dict[int, Any]],
variant_name: str,
):
# TODO: Write analysis script using DuckDB
raise NotImplementedError("Still need to write analysis script using DuckDB!")