"""ecoli.analysis.antibiotics_colony.snapshot_and_hist_plot

Generate snapshot and histogram figures for specified molecules.
"""

import argparse
import ast
import concurrent.futures
import os
import warnings
import json
import matplotlib
import numpy as np
import pandas as pd
import seaborn as sns
from tqdm import tqdm

from ecoli.analysis.antibiotics_colony.plot_utils import prettify_axis

matplotlib.use("Agg")
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
from vivarium.core.emitter import DatabaseEmitter
from vivarium.core.serialize import deserialize_value
from vivarium.library.dict_utils import get_value_from_path
from vivarium.library.topology import convert_path_style
from vivarium.library.units import remove_units

from ecoli.analysis.colony.snapshots import (
    format_snapshot_data,
    get_tag_ranges,
    plot_tags,
)
from ecoli.analysis.antibiotics_colony import COUNTS_PER_FL_TO_NANOMOLAR, PATHS_TO_LOAD


# Fraction of a cell's volume used as the denominator when converting counts of
# the molecules listed in PERIPLASMIC_VARS into concentrations; every other
# molecule is normalized by the remaining (1 - fraction) of the volume.
PERIPLASMIC_VOLUME_FRACTION = 0.2

# Column names whose counts are normalized by the periplasmic volume fraction.
PERIPLASMIC_VARS = [
    "OmpF monomer",
    "TolC monomer",
    "AmpC monomer",
]


def deserialize_and_remove_units(d):
    """Deserialize a stored value with vivarium and strip any units from it.

    Defined as a named module-level function (rather than a lambda) because it
    is passed to ``ProcessPoolExecutor.map``, which must pickle the callable to
    send it to worker processes.

    Args:
        d: Serialized value as stored by the emitter.

    Returns:
        The deserialized value with units removed.
    """
    return remove_units(deserialize_value(d))
def make_snapshot_and_hist_plot(
    timepoint_data, metadata, bounds, molecule, title=None, tag_hsv=None
):
    """Generate a tagged snapshot plot plus a histogram for one molecule.

    The top panel is a colony snapshot (drawn by ``plot_tags``) in which each
    cell is shaded by its concentration of ``molecule``, with a horizontal
    colorbar beneath it. The bottom panel is a histogram of that
    concentration across all cells at the same timepoint.

    Args:
        timepoint_data: DataFrame for exactly one timepoint. It must have the
            columns ``"Time"``, ``"Agent ID"``, ``"Boundary"`` and
            ``molecule``. Each ``"Boundary"`` entry is a dict with at least a
            ``"volume"`` key.
        metadata: Nested dict ``{condition: {seed: {...}}}`` holding exactly
            one condition and one seed. ``metadata[condition][seed]["fields"]``
            is indexed by the timepoint.
        bounds: Physical bounds for the snapshot plot.
        molecule: Column of ``timepoint_data`` holding per-cell counts of the
            molecule to tag.
        title: Title for the snapshot panel. Defaults to the molecule name.
        tag_hsv: HSV color for the tag, colorbar and histogram. Defaults to
            ``[0.6, 1, 1]``.

    Returns:
        ``(fig, axes)``, where ``axes`` is ``fig.get_axes()``: the snapshot,
        colorbar and histogram axes.

    Raises:
        ValueError: if ``timepoint_data`` spans more than one timepoint, or
            ``metadata`` does not hold exactly one condition and one seed.
    """
    # None sentinel instead of a shared mutable list default.
    if tag_hsv is None:
        tag_hsv = [0.6, 1, 1]
    # ``molecule[-1]`` made sense when molecules were path tuples; for a plain
    # column-name string it would yield only the last character.
    label = molecule[-1] if isinstance(molecule, tuple) else molecule

    times = timepoint_data["Time"].unique()
    if len(times) != 1:
        raise ValueError(f"Expected only one timepoint, got {times}")
    time = times[0]

    conditions = list(metadata.keys())
    if len(conditions) != 1:
        raise ValueError(f"Expected exactly one condition, got {conditions}")
    condition = conditions[0]

    seeds = list(metadata[condition].keys())
    if len(seeds) != 1:
        raise ValueError(f"Expected exactly one seed, got {seeds}")
    seed = seeds[0]

    # Periplasmic molecules are normalized by the periplasmic share of the
    # volume, everything else by the remainder. This is loop-invariant.
    compartment_fraction = (
        PERIPLASMIC_VOLUME_FRACTION
        if molecule in PERIPLASMIC_VARS
        else 1 - PERIPLASMIC_VOLUME_FRACTION
    )

    # Convert DataFrame data back to the nested dictionary form plot_tags uses.
    # NOTE(review): when metadata is loaded from JSON its keys are strings, so
    # the ``["fields"][time]`` lookup depends on how "fields" is keyed — confirm.
    tag_data = {
        time: {
            "agents": {
                agent_id: {
                    "boundary": boundary,
                    # Convert from counts to uM
                    molecule: (
                        (molecule_count / (boundary["volume"] * compartment_fraction))
                        * COUNTS_PER_FL_TO_NANOMOLAR
                        / 10**3  # Convert to uM
                    ),
                }
                for agent_id, boundary, molecule_count in zip(
                    timepoint_data.loc[:, "Agent ID"],
                    timepoint_data.loc[:, "Boundary"],
                    timepoint_data.loc[:, molecule],
                )
            },
            "fields": metadata[condition][seed]["fields"][time],
        }
    }

    # Get snapshot figure from plot_tags
    fig = plot_tags(
        tag_data,
        bounds,
        snapshot_times=[time],
        tagged_molecules=[(molecule,)],
        show_timeline=False,
        background_color="white",
        default_font_size=10,
        scale_bar_length=None,  # TODO: scale bar length looks wrong?
        min_color="white",
        tag_colors={(molecule,): tag_hsv},
        convert_to_concs=False,
        xlim=[8, 42],
        ylim=[8, 42],
    )
    snapshot_ax = fig.get_axes()[0]

    # Prettify axis labels
    snapshot_ax.set(ylabel=None)
    snapshot_ax.set_title(label if title is None else title)

    grid = fig.add_gridspec(
        2, 2, width_ratios=[2, 1], height_ratios=[3, 1], wspace=0, hspace=0
    )

    # Reposition the snapshot axes, leaving room for the histogram below.
    snapshot_ax.set_position(grid[0, 0].get_position(fig))
    snapshot_ax.set_subplotspec(grid[0, 0])

    # Remove the colorbar axes plot_tags created; recreating is easier than
    # re-positioning them.
    for ax in fig.get_axes():
        if ax != snapshot_ax:
            ax.remove()

    # Re-create the colorbar underneath the snapshot.
    divider = make_axes_locatable(snapshot_ax)
    cax = divider.append_axes("bottom", size="5%", pad=0.05)
    agents, _ = format_snapshot_data(tag_data)
    tag_ranges, _ = get_tag_ranges(
        agents, [(molecule,)], [0], False, {(molecule,): tag_hsv}
    )
    min_tag, max_tag = tag_ranges[(molecule,)]
    norm = matplotlib.colors.Normalize(vmin=min_tag, vmax=max_tag)
    cmap = matplotlib.colors.LinearSegmentedColormap.from_list(
        "", ["white", np.array(matplotlib.colors.hsv_to_rgb(tag_hsv))]
    )
    mappable = matplotlib.cm.ScalarMappable(norm, cmap)
    cbar = fig.colorbar(
        mappable,
        cax=cax,
        orientation="horizontal",
        ticks=[min_tag, max_tag],
    )

    def format_tick(tick):
        # Whole numbers for zero and large values, one decimal otherwise.
        if tick == 0 or tick > 100:
            return f"{tick:.0f}"
        return f"{tick:.1f}"

    cbar.ax.set_xticklabels(
        [f"{format_tick(min_tag)} μM", f"{format_tick(max_tag)} μM"], fontsize=8
    )

    # Add histogram plot
    hist_ax = fig.add_subplot(grid[1, 0])

    # Distribution of concentration across agents
    hist_data = pd.DataFrame(
        {
            label: [
                get_value_from_path(agent_data, (molecule,))
                for agent_data in tag_data[time]["agents"].values()
            ]
        }
    )

    sns.histplot(
        data=hist_data,
        x=label,
        ax=hist_ax,
        color=matplotlib.colors.hsv_to_rgb(tag_hsv),
    )

    # Aesthetics
    hist, bins = np.histogram(hist_data[label].to_numpy(), bins="auto")
    hist_ax.set_xlabel(None)
    hist_ax.set_ylabel("Cells", fontsize=9, labelpad=-5)
    hist_ax.set(
        xticks=[bins[0], bins[-1]],
        xticklabels=[f"{format_tick(bins[0])} μM", f"{format_tick(bins[-1])} μM"],
        xlim=[bins[0], bins[-1]],
        yticks=[0, hist.max()],
        ylim=[0, hist.max()],
    )
    prettify_axis(
        hist_ax,
        label_fontsize=9,
        ticklabel_fontsize=8,
        tick_format_x="{:.1f} μM",
        tick_format_y="{:.0f}",
    )
    # Final x ticks match the colorbar range (applied after prettify_axis so
    # they are not overridden).
    hist_ax.set_xticks(
        [min_tag, max_tag],
        labels=[f"{format_tick(min_tag)} μM", f"{format_tick(max_tag)} μM"],
    )

    return fig, fig.get_axes()
def get_data(experiment_id, time, molecules, host, port, cpus, verbose):
    """Load and deserialize one experiment's data (not yet implemented).

    Data retrieval is being migrated to DuckDB. Until then this function
    always raises ``NotImplementedError``, after printing a progress message
    if ``verbose`` is set.

    Args:
        experiment_id: ID of the experiment to load.
        time: Timepoint whose ``("dimensions", "bounds")`` should be returned.
        molecules: Paths of the molecules to load.
        host: Database host.
        port: Database port.
        cpus: Number of worker processes used for deserialization.
        verbose: Whether to print progress information.

    Returns:
        Once implemented: ``(data, bounds)``.

    Raises:
        NotImplementedError: always, until DuckDB retrieval is implemented.
    """
    if verbose:
        print(f"Accessing data for experiment {experiment_id}...")
    # TODO: Retrieve data using DuckDB. The query should load the molecule
    # paths (mRNA / monomer counts and other inner paths) plus
    # ("data", "dimensions"), then pass the raw {time: serialized state}
    # mapping to _deserialize_data(raw, time, cpus, verbose).
    raise NotImplementedError("Still need to update to use DuckDB!")


def _deserialize_data(raw_data, time, cpus, verbose):
    """Deserialize a ``{time: serialized state}`` mapping in parallel.

    Returns ``(data, bounds)``, where ``bounds`` is
    ``data[time]["dimensions"]["bounds"]``.
    """
    with concurrent.futures.ProcessPoolExecutor(cpus) as executor:
        deserialized = executor.map(
            deserialize_and_remove_units, raw_data.values()
        )
        if verbose:
            deserialized = tqdm(
                deserialized, desc="Deserializing data", total=len(raw_data)
            )
        # Force the lazy map while the executor is still alive.
        deserialized = list(deserialized)
    data = dict(zip(raw_data.keys(), deserialized))
    bounds = data[time]["dimensions"]["bounds"]
    return data, bounds
def _parse_args():
    """Build the command-line parser and parse ``sys.argv``."""
    parser = argparse.ArgumentParser(
        "Generate snapshot and histogram figures for specified molecules."
    )
    parser.add_argument(
        "experiment_id",
        help="ID of the experiment for which to make the figure(s).",
    )
    parser.add_argument(
        "--molecule_paths",
        "-m",
        nargs="+",
        required=True,
        help="Paths (in A>B>C form) of the molecule(s) for which to generate figure(s). "
        'Can be preceded by an alias for that molecule, e.g. "OmpF=monomer>EG10671-MONOMER".',
    )
    parser.add_argument(
        "--time",
        "-t",
        type=int,
        default=None,
        help="Timepoint which to plot (defaults to last timepoint).",
    )
    parser.add_argument(
        "--outdir",
        "-d",
        default="out/snapshot_hist_plots",
        help="Directory in which to output the generated figures.",
    )
    parser.add_argument("--svg", "-s", action="store_true", help="Save as svg.")
    parser.add_argument("--host", "-o", default="localhost", type=str)
    parser.add_argument("--port", "-p", default=27017, type=int)
    parser.add_argument(
        "--local",
        "-l",
        default=None,
        type=str,
        help="Locally saved dataframe file to run the plots on (if provided). "
        "Setting this option overrides database options (experiment_id, host, port).",
    )
    parser.add_argument("--cpus", "-c", type=int, default=1)
    parser.add_argument("--verbose", "-v", action="store_true")
    return parser.parse_args()


def _parse_molecule_paths(molecule_paths):
    """Split ``[alias=]A>B>C`` specs into ``(paths, display_names)``.

    Without an alias, the display name is the last element of the path.
    """
    molecules = []
    molecule_names = []
    for spec in molecule_paths:
        parts = spec.split("=")
        if len(parts) == 1:  # no alias given
            path = convert_path_style(parts[0])
            molecules.append(path)
            molecule_names.append(path[-1])
        else:
            molecules.append(convert_path_style(parts[-1]))
            molecule_names.append(parts[0])
    return molecules, molecule_names


def _load_local(args, molecules):
    """Load one timepoint of a locally saved dataframe plus its metadata.

    Returns ``(data, metadata, bounds, paths_to_columns)``, where
    ``paths_to_columns`` maps each requested path found in ``PATHS_TO_LOAD``
    to its dataframe column name.

    Raises:
        ValueError: if no rows exist for the selected timepoint.
    """
    data = pd.read_csv(args.local, dtype={"Agent ID": str, "Seed": str}, index_col=0)
    # Boundaries are stored as dict literals; literal_eval only accepts
    # Python literals, so it cannot execute arbitrary code.
    data["Boundary"] = data["Boundary"].apply(ast.literal_eval)

    # Keep only the desired columns.
    paths_to_columns = {v: k for k, v in PATHS_TO_LOAD.items() if v in molecules}
    for missing in [p for p in molecules if p not in paths_to_columns]:
        warnings.warn(f"Path {missing} is missing from locally saved dataframe.")
    keep_columns = [
        "Agent ID",
        "Dry mass",
        "Growth rate",
        "Time",
        "Seed",
        "Condition",
        "Boundary",
        *paths_to_columns.values(),
    ]
    data = data[keep_columns]

    # Metadata lives next to the data file as <data_filename>_metadata.json.
    if args.verbose:
        print(
            "Loading metadata; filename must have the form <data_filename>_metadata.json"
        )
    filename, _ = os.path.splitext(args.local)
    with open(f"{filename}_metadata.json", "r", encoding="utf-8") as f:
        metadata = json.load(f)

    # Environmental bounds come from the first condition/seed in the metadata.
    condition = list(metadata.keys())[0]
    seed = list(metadata[condition].keys())[0]
    bounds = metadata[condition][seed]["bounds"]

    # Default to the final timepoint in the data.
    time = args.time
    if time is None:
        if args.verbose:
            print(
                "No timepoint given, trying to infer and use final timepoint from data.\n"
                "If this fails, consider specifying an explicit timepoint."
            )
        time = data["Time"].max()

    # Restrict data to come from only one timepoint.
    data = data[data["Time"] == time]
    if data.empty:
        raise ValueError(f"No data found for timepoint {time} in {args.local}.")
    return data, metadata, bounds, paths_to_columns


def _load_from_database(args, molecules):
    """Load plotting inputs from the database (not yet supported).

    Always raises ``NotImplementedError``: ``get_data`` is not implemented,
    and database output is not yet converted to the DataFrame + metadata
    format that plotting needs.
    """
    time = args.time
    if time is None:
        if args.verbose:
            print(
                "No timepoint given, trying to infer and use final timepoint from data.\n"
                "If this fails, consider specifying an explicit timepoint."
            )
        config = {"host": f"{args.host}:{args.port}", "database": "simulations"}
        emitter = DatabaseEmitter(config)
        db = emitter.db
        time = list(
            db.history.aggregate(
                [
                    {"$match": {"experiment_id": args.experiment_id}},
                    {"$project": {"data.time": 1}},
                    {"$group": {"_id": None, "time": {"$max": "$data.time"}}},
                ]
            )
        )[0]["time"]

    get_data(
        experiment_id=args.experiment_id,
        time=2 * (time // 2),  # only even timesteps have the data necessary
        molecules=molecules,
        host=args.host,
        port=args.port,
        cpus=args.cpus,
        verbose=args.verbose,
    )
    # TODO: Get this in Dataframe format (see plot.py?)
    raise NotImplementedError(
        "Plotting database output is not supported yet; use --local with a "
        "saved dataframe and its _metadata.json file."
    )


def main():
    """Command-line entry point: write one snapshot + histogram figure per molecule."""
    args = _parse_args()

    # Convert molecule path styles, get names
    molecules, molecule_names = _parse_molecule_paths(args.molecule_paths)

    if args.local:
        data, metadata, bounds, paths_to_columns = _load_local(args, molecules)
    else:
        data, metadata, bounds, paths_to_columns = _load_from_database(
            args, molecules
        )

    # Generate one figure per molecule
    plt.rcParams["font.family"] = "Arial"
    plt.rcParams["svg.fonttype"] = "none"
    os.makedirs(args.outdir, exist_ok=True)
    for name, molecule in zip(molecule_names, molecules):
        column = paths_to_columns.get(molecule)
        if column is None:
            # Already warned about while loading; skip instead of a KeyError.
            continue
        if args.verbose:
            print(f"Plotting snapshot + histogram for {name}={molecule[-1]}...")
        fig, _ = make_snapshot_and_hist_plot(
            data, metadata, bounds, column, title=""
        )
        fig.set_size_inches(2.25, 2.9)
        fig.tight_layout()
        fig.savefig(
            os.path.join(
                args.outdir,
                f"snapshot_and_hist_{name}.{'svg' if args.svg else 'png'}",
            )
        )
        plt.close(fig)
# Run the command-line interface only when executed as a script.
if __name__ == "__main__":
    main()