# Source code for ecoli.analysis.multivariant.centralCarbonMetabolismScatter

"""
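Compare central carbon metabolism fluxes from a multivariant simulation against
the experimental measurements of Toya 2010.

Scatter plot: x-axis = Toya 2010 measured flux; y-axis = mean simulated flux,
with error bars of one standard deviation on both axes. Fluxes are pooled across
all seeds and generations within each variant, and each variant gets its own
subplot in a grid of up to four columns. Per-variant fit statistics (Pearson R^2
and R^2 against the y = x line) are also written to a CSV file.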
"""

from typing import Any, TYPE_CHECKING
import os
import pickle
import numpy as np

from ecoli.analysis.multivariant.utils import create_variant_label
from wholecell.utils import units, toya
from fsspec import open as fsspec_open
from ecoli.library.parquet_emitter import (
    field_metadata,
    ndlist_to_ndarray,
    open_arbitrary_sim_data,
    read_stacked_columns,
)

import altair as alt
import polars as pl

if TYPE_CHECKING:
    from duckdb import DuckDBPyConnection

# Base units shared by every quantity in this analysis.
COUNTS_UNITS = units.mmol
VOLUME_UNITS = units.L
MASS_UNITS = units.g
TIME_UNITS = units.s

# Derived units: concentration (mmol / L) and flux (mmol / L / s).
CONC_UNITS = COUNTS_UNITS / VOLUME_UNITS
FLUX_UNITS = CONC_UNITS / TIME_UNITS

# Divisor that turns per-timestep flux values into per-second rates.
# NOTE(review): assumes the simulation timestep is 1 s — confirm.
TIMESTEP = 1 * TIME_UNITS


def _fit_statistics(sim: np.ndarray, measured: np.ndarray) -> tuple[float, float]:
    """Return ``(pearson_r2, r_squared)`` for simulated vs. measured fluxes.

    ``pearson_r2`` is the squared Pearson correlation coefficient between the
    two inputs. ``r_squared`` is the coefficient of determination of ``sim``
    against the y = x line, ``1 - SS_res / SS_tot``, with ``SS_tot`` taken
    over ``measured``. A statistic is NaN when it is undefined: fewer than
    two points, or a constant input.
    """
    sim = np.asarray(sim, dtype=float)
    measured = np.asarray(measured, dtype=float)
    nan = float("nan")
    if sim.size < 2:
        return nan, nan

    # Explicit constant-input checks avoid dividing by a zero (or
    # rounding-noise) sum of squares, which would emit RuntimeWarnings and
    # yield -inf / NaN.
    sim_varies = bool(np.any(sim != sim[0]))
    measured_varies = bool(np.any(measured != measured[0]))

    ss_res = np.sum((sim - measured) ** 2)
    ss_tot = np.sum((measured - np.mean(measured)) ** 2)
    r_squared = float(1 - ss_res / ss_tot) if measured_varies else nan
    if sim_varies and measured_varies:
        pearson_r2 = float(np.corrcoef(sim, measured)[0, 1]) ** 2
    else:
        pearson_r2 = nan
    return pearson_r2, r_squared


def _build_scatter_chart(df_all: "pl.DataFrame", flux_unit_str: str):
    """Build the faceted simulated-vs-Toya scatter chart from ``df_all``.

    Each facet (one per ``variant_label``) layers the mean fluxes as points,
    horizontal and vertical rules spanning mean +/- one standard deviation,
    and a text annotation with that variant's Pearson R^2 and R^2 to y = x.
    """
    base = alt.Chart().encode(
        x=alt.X("mean(toya_mean):Q", title=f"Toya 2010 Reaction Flux {flux_unit_str}"),
        y=alt.Y("mean(sim_mean):Q", title=f"Mean WCM Reaction Flux {flux_unit_str}"),
        detail=alt.Detail("reaction:N"),
        tooltip=["reaction:N", "mean(toya_mean):Q", "mean(sim_mean):Q"],
    )
    points = base.mark_point(color="steelblue", size=50, filled=True)

    # Horizontal bars: Toya mean +/- stdev at the simulated mean.
    x_errorbars = (
        alt.Chart()
        .mark_rule(color="black", strokeWidth=1)
        .encode(
            x=alt.X("mean(toya_lo):Q"),
            x2="mean(toya_hi):Q",
            y=alt.Y("mean(sim_mean):Q"),
            detail=alt.Detail("reaction:N"),
        )
    )
    # Vertical bars: simulated mean +/- stdev at the Toya mean.
    y_errorbars = (
        alt.Chart()
        .mark_rule(color="black", strokeWidth=1)
        .encode(
            y=alt.Y("mean(sim_lo):Q"),
            y2="mean(sim_hi):Q",
            x=alt.X("mean(toya_mean):Q"),
            detail=alt.Detail("reaction:N"),
        )
    )
    # One label per facet, pinned to the top-left pixel. The per-variant
    # statistics are repeated on every row, so "mean" just recovers them.
    annotation = (
        alt.Chart()
        .mark_text(
            lineBreak="\n", align="left", baseline="top", dx=5, dy=5, fontSize=11
        )
        .encode(
            x=alt.value(0),
            y=alt.value(0),
        )
        .transform_aggregate(
            pearson_r2="mean(pearson_r2)",
            r_squared="mean(r_squared)",
            groupby=["variant_label"],
        )
        .transform_calculate(
            multiline_label="'Pearson R\u00b2 = ' + format(datum.pearson_r2, '.2f') + '\\n' + "
            "'R\u00b2 to y=x is ' + format(datum.r_squared, '.2f')"
        )
        .encode(text="multiline_label:N")
    )

    return (
        alt.layer(points, x_errorbars, y_errorbars, annotation, data=df_all)
        .properties(width=300, height=300)
        .facet(facet=alt.Facet("variant_label:N", title="Variant"), columns=4)
        .resolve_scale(y="independent")
        .configure_view(strokeWidth=0, fill=None)
        .properties(title="Central Carbon Metabolism Flux by Variant")
    )


def plot(
    params: dict[str, Any],
    conn: "DuckDBPyConnection",
    history_sql: str,
    config_sql: str,
    success_sql: str,
    sim_data_dict: dict[str, dict[int, str]],
    validation_data_paths: list[str],
    outdir: str,
    variant_metadata: dict[str, dict[int, Any]],
    variant_names: dict[str, str],
):
    """Plot simulated vs. Toya 2010 central carbon fluxes, one facet per variant.

    All rows returned by ``history_sql`` for a variant are pooled before
    computing per-reaction means and standard deviations. Only reactions that
    appear both in the Toya 2010 data and in the simulation's base reaction
    IDs are compared. Writes ``central_metabolism_stats.csv`` (per-variant
    Pearson R^2 and R^2 to y = x) and ``centralCarbonMetabolismScatter.html``
    into ``outdir``.

    Args:
        params: Analysis parameters. ``is_redux`` (default ``True``) is
            forwarded as ``remove_first`` to ``read_stacked_columns``. When
            true, fluxes are also scaled by the ``counts_to_molar`` listener.
        conn: DuckDB connection used to read simulation output.
        history_sql: SQL selecting the simulation history to analyze.
        config_sql: SQL used to look up the reaction IDs labelling the
            base reaction flux listener.
        success_sql: Not used by this analysis.
        sim_data_dict: Passed to ``open_arbitrary_sim_data`` to load sim_data.
        validation_data_paths: Paths to pickled validation data; only the
            first entry is read.
        outdir: Directory that receives the CSV and HTML outputs.
        variant_metadata: Mapping of experiment ID to per-variant parameters.
            Only the first experiment's entry is used to label variants.
        variant_names: Not used by this analysis.
    """
    experiment_id = next(iter(variant_metadata.keys()), None)
    per_variant_params: dict[int, Any] = (
        variant_metadata[experiment_id] if experiment_id else {}
    )
    is_redux = params.get("is_redux", True)

    # NOTE(review): pickle.load executes arbitrary code; these paths must only
    # point at trusted pipeline outputs.
    with open_arbitrary_sim_data(sim_data_dict) as f:
        sim_data = pickle.load(f)
    with fsspec_open(validation_data_paths[0], "rb") as f:
        validation_data = pickle.load(f)
    cell_density = sim_data.constants.cell_density

    query = [
        "listeners__mass__cell_mass AS cell_mass",
        "listeners__mass__dry_mass AS dry_mass",
        "listeners__fba_results__base_reaction_fluxes AS base_reaction_fluxes",
        "listeners__enzyme_kinetics__counts_to_molar AS counts_to_molar",
    ]
    raw = pl.DataFrame(
        read_stacked_columns(
            history_sql, query, order_results=True, conn=conn, remove_first=is_redux
        )
    )

    reaction_ids = np.array(
        field_metadata(conn, config_sql, "listeners__fba_results__base_reaction_fluxes")
    )
    toya_data = validation_data.reactionFlux.toya2010fluxes
    toya_reactions = toya_data["reactionID"]
    # Build the membership set once rather than once per Toya reaction.
    sim_reaction_set = set(reaction_ids)
    common_reactions = [rxn for rxn in toya_reactions if rxn in sim_reaction_set]
    n_common = len(common_reactions)

    flux_matrix = ndlist_to_ndarray(raw["base_reaction_fluxes"])
    if is_redux:
        counts_to_molar = raw["counts_to_molar"].to_numpy()[:, np.newaxis]
        sim_reaction_fluxes = CONC_UNITS / TIMESTEP * counts_to_molar * flux_matrix
    else:
        sim_reaction_fluxes = CONC_UNITS / TIMESTEP * flux_matrix

    # Pull the per-row columns out once; each variant is then a boolean mask
    # over the same row order as flux_matrix.
    variant_col = raw["variant"].to_numpy()
    cell_mass_all = raw["cell_mass"].to_numpy()
    dry_mass_all = raw["dry_mass"].to_numpy()
    unique_variants = sorted(raw["variant"].unique().to_list())

    all_dfs = []
    summary_rows = []
    for variant_val in unique_variants:
        variant_label = " ".join(create_variant_label(variant_val, per_variant_params))
        mask = variant_col == variant_val
        cell_masses_ref = units.fg * cell_mass_all[mask]
        dry_masses_ref = units.fg * dry_mass_all[mask]

        toya_fluxes = toya.adjust_toya_data(
            toya_data["reactionFlux"], cell_masses_ref, dry_masses_ref, cell_density
        )
        toya_stdevs = toya.adjust_toya_data(
            toya_data["reactionFluxStdev"], cell_masses_ref, dry_masses_ref, cell_density
        )
        # Simulated and measured values are both taken over common_reactions
        # so every per-reaction array below has the same length and order.
        sim_flux_means, sim_flux_stdevs = toya.process_simulated_fluxes(
            common_reactions, reaction_ids, sim_reaction_fluxes[mask, :]
        )
        toya_flux_means = toya.process_toya_data(
            common_reactions, toya_reactions, toya_fluxes
        )
        toya_flux_stdevs = toya.process_toya_data(
            common_reactions, toya_reactions, toya_stdevs
        )

        sim_means_num = sim_flux_means.asNumber(FLUX_UNITS)
        sim_stdevs_num = sim_flux_stdevs.asNumber(FLUX_UNITS)
        toya_means_num = toya_flux_means.asNumber(FLUX_UNITS)
        toya_stdevs_num = toya_flux_stdevs.asNumber(FLUX_UNITS)

        pearson_r2, r_squared = _fit_statistics(sim_means_num, toya_means_num)

        summary_rows.append(
            {
                "variant_val": variant_val,
                "variant_label": variant_label,
                "pearson_r2": pearson_r2,
                "r_squared": r_squared,
            }
        )
        all_dfs.append(
            pl.DataFrame(
                {
                    "reaction": list(common_reactions),
                    "toya_mean": toya_means_num,
                    "toya_stdev": toya_stdevs_num,
                    "sim_mean": sim_means_num,
                    "sim_stdev": sim_stdevs_num,
                    "toya_lo": toya_means_num - toya_stdevs_num,
                    "toya_hi": toya_means_num + toya_stdevs_num,
                    "sim_lo": sim_means_num - sim_stdevs_num,
                    "sim_hi": sim_means_num + sim_stdevs_num,
                    "variant_label": [variant_label] * n_common,
                    "pearson_r2": [pearson_r2] * n_common,
                    "r_squared": [r_squared] * n_common,
                }
            )
        )

    df_all = pl.concat(all_dfs)
    summary_df = pl.DataFrame(summary_rows)
    summary_df.write_csv(os.path.join(outdir, "central_metabolism_stats.csv"))

    final_chart = _build_scatter_chart(df_all, FLUX_UNITS.strUnit())
    final_chart.save(os.path.join(outdir, "centralCarbonMetabolismScatter.html"))