# Source code for ecoli.analysis.multivariant.kinetic_flux

"""
Kinetic flux analysis for multivariant metabolism_redux_classic simulations.

Two rows of faceted subplots, one column per variant:
  Top row:    Scatter — per-reaction log10(avg kinetic target + ε) vs
              log10(avg estimated kinetic flux + ε) in mmol/(L·h), with a
              dashed y=x reference line and Pearson R² / R²-to-y=x annotation.
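  Bottom row: Line — kinetic objective term (the ``kinetics_term`` listener,
              shown on an axis titled "Unweighted Kinetic Term"), averaged per
              time point and plotted over continuous simulation time in
              minutes, broken at cell division.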

Variant panels are labeled with the fraction_kinetic_target value when
available from variant_metadata.
"""

from __future__ import annotations

import os
from typing import Any, TYPE_CHECKING

import altair as alt
import numpy as np
import plotly.express as px
import polars as pl

from ecoli.analysis.multivariant.utils import create_variant_label
from ecoli.library.parquet_emitter import (
    field_metadata,
    ndlist_to_ndarray,
    read_stacked_columns,
)

if TYPE_CHECKING:
    from duckdb import DuckDBPyConnection

# Switch Altair to the "vegafusion" data transformer at import time; this is a
# module-level side effect that applies to every chart built afterwards.
alt.data_transformers.enable("vegafusion")

# Plotly's qualitative "Pastel" colour list; plot() slices one colour per variant.
PASTEL = px.colors.qualitative.Pastel

# Tolerance added before log10 to handle zero fluxes
LOG_EPS = 1e-8
# Seconds per hour — converts mmol/(L·s) → mmol/(L·h)
S_PER_HR = 3600.0
# Unit label interpolated into the scatter-plot axis titles.
FLUX_UNIT_STR = "mmol/(L·h)"


def _variant_scatter_rows(
    label: str,
    rxn_names: list[str],
    mean_target: np.ndarray,
    mean_flux: np.ndarray,
) -> list[dict[str, Any]]:
    """Build the scatter-plot records for a single variant.

    Args:
        label: Display label of the variant (stored in the ``Variant`` column).
        rxn_names: Kinetic reaction names, aligned with the two arrays.
        mean_target: Time-averaged kinetic target per reaction.
        mean_flux: Time-averaged estimated flux per reaction.

    Returns:
        One record per reaction (``is_ref=False``) followed by two records
        (``is_ref=True``) holding the endpoints of the y=x reference line in
        log space. Every record carries the variant-wide ``r2_to_yx`` and
        ``pearson_r2`` values so the chart can aggregate them per facet.
    """
    # NOTE(review): a mean below -LOG_EPS makes log10 return NaN, which then
    # propagates into both R² values for the whole variant.
    log_target = np.log10(mean_target + LOG_EPS)
    log_flux = np.log10(mean_flux + LOG_EPS)

    # R² metrics computed on log-space values
    ss_res = np.sum((log_flux - log_target) ** 2)
    ss_tot = np.sum((log_target - log_target.mean()) ** 2)
    r2_to_yx = float(1 - ss_res / ss_tot) if ss_tot > 0 else float("nan")
    pearson_r2 = float(np.corrcoef(log_flux, log_target)[0, 1]) ** 2

    def _record(name: str, x: float, y: float, is_ref: bool) -> dict[str, Any]:
        return {
            "Reaction": name,
            "log_target": x,
            "log_flux": y,
            "Variant": label,
            "r2_to_yx": r2_to_yx,
            "pearson_r2": pearson_r2,
            "is_ref": is_ref,
        }

    rows = [
        _record(rxn, float(log_target[i]), float(log_flux[i]), False)
        for i, rxn in enumerate(rxn_names)
    ]

    # y=x reference line endpoints in log space for this variant's scale
    ref_lo = float(min(log_target.min(), log_flux.min()))
    ref_hi = float(max(log_target.max(), log_flux.max()))
    rows.extend(
        _record(f"_ref_{ref_val}", ref_val, ref_val, True)
        for ref_val in (ref_lo, ref_hi)
    )
    return rows


def plot(
    params: dict[str, Any],
    conn: "DuckDBPyConnection",
    history_sql: str,
    config_sql: str,
    success_sql: str,
    sim_data_dict: dict[str, dict[int, str]],
    validation_data_paths: list[str],
    outdir: str,
    variant_metadata: dict[str, dict[int, Any]],
    variant_names: dict[str, str],
) -> None:
    """
    Faceted scatter (log avg kinetic target vs log avg kinetic flux) and
    kinetic term over time, one column per variant.

    Args:
        params: Not read by this function.
        conn: DuckDB connection passed to ``field_metadata`` and
            ``read_stacked_columns``.
        history_sql: Query handed to ``read_stacked_columns`` to load the
            listener columns.
        config_sql: Query handed to ``field_metadata`` to resolve reaction
            names.
        success_sql: Not read by this function.
        sim_data_dict: Not read by this function.
        validation_data_paths: Not read by this function.
        outdir: Directory that receives ``kinetic_flux_analysis.html``.
        variant_metadata: Mapping of experiment ID to per-variant parameters;
            only the first experiment's entry is used for panel labels.
        variant_names: Not read by this function.

    Returns:
        None. Writes the chart to ``outdir`` and prints the output path, or
        prints a message and returns early when there is nothing to plot.

    Raises:
        ValueError: If a kinetic reaction name is absent from the list of all
            reaction names.
    """
    # ── Resolve per-variant parameter dicts ───────────────────────────────────
    experiment_id = next(iter(variant_metadata.keys()), None)
    per_variant_params: dict[int, Any] = (
        variant_metadata[experiment_id] if experiment_id else {}
    )

    # ── Metadata: reaction name lists ─────────────────────────────────────────
    kinetic_rxn_names: list[str] = field_metadata(
        conn, config_sql, "listeners__fba_results__target_kinetic_fluxes"
    )
    all_rxn_names: list[str] = field_metadata(
        conn, config_sql, "listeners__fba_results__solution_fluxes"
    )
    if len(kinetic_rxn_names) == 0:
        # Without reactions the per-variant min()/max() below would raise.
        print("kinetic_flux_analysis: no kinetic reactions found; skipping.")
        return

    # Name → column lookup built once. setdefault keeps the first occurrence,
    # matching what list.index() would have returned.
    rxn_position: dict[str, int] = {}
    for position, rxn_name in enumerate(all_rxn_names):
        rxn_position.setdefault(rxn_name, position)
    missing = [name for name in kinetic_rxn_names if name not in rxn_position]
    if missing:
        raise ValueError(
            f"kinetic reactions missing from solution flux metadata: {missing}"
        )
    kinetic_indices = np.array(
        [rxn_position[name] for name in kinetic_rxn_names], dtype=int
    )

    # ── Load raw listener data ────────────────────────────────────────────────
    raw = pl.DataFrame(
        read_stacked_columns(
            history_sql,
            [
                "listeners__fba_results__target_kinetic_fluxes AS target_kinetic_fluxes",
                "listeners__fba_results__estimated_fluxes AS estimated_fluxes",
                "listeners__fba_results__kinetics_term AS kinetics_term",
                "listeners__enzyme_kinetics__counts_to_molar AS counts_to_molar",
            ],
            order_results=True,
            conn=conn,
            remove_first=True,
        )
    )
    if raw.is_empty():
        print("kinetic_flux_analysis: no rows returned; skipping.")
        return

    # Continuous relative time per lineage_seed, in minutes.
    # NOTE(review): the offset is shared by every variant with the same
    # lineage_seed — confirm all variants start at the same absolute time.
    raw = raw.with_columns(
        (
            (pl.col("time") - pl.col("time").min().over("lineage_seed")) / 60
        ).alias("Time (min)")
    )

    # ── Variant label mapping ─────────────────────────────────────────────────
    row_variants: list[int] = raw["variant"].to_list()
    unique_variants: list[int] = sorted(raw["variant"].unique().to_list())

    def _make_label(v: int) -> str:
        raw_label = create_variant_label(v, per_variant_params)
        return " ".join(raw_label) if isinstance(raw_label, list) else raw_label

    variant_label_map = {v: _make_label(v) for v in unique_variants}
    variant_labels = [variant_label_map[v] for v in unique_variants]
    color_range = PASTEL[: len(unique_variants)]
    color_scale = alt.Scale(domain=variant_labels, range=color_range)

    # ── Numpy arrays ──────────────────────────────────────────────────────────
    target_arr = ndlist_to_ndarray(raw["target_kinetic_fluxes"])  # (T, n_kinetic)
    estimated_arr = ndlist_to_ndarray(raw["estimated_fluxes"])  # (T, n_all_rxns)
    kinetic_flux_arr = estimated_arr[:, kinetic_indices]  # (T, n_kinetic)
    # counts_to_molar [mmol/L per count]; multiply by S_PER_HR to get mmol/(L·h)
    counts_to_molar = raw["counts_to_molar"].to_numpy()[:, np.newaxis] * S_PER_HR
    variants_col = np.array(row_variants)

    # ── Build scatter DataFrame ───────────────────────────────────────────────
    # Each data row is one (reaction, variant) pair averaged over all timesteps.
    # Two extra rows per variant encode the y=x reference line endpoints on the
    # log-transformed axes.
    scatter_rows: list[dict] = []
    for v in unique_variants:
        mask = variants_col == v
        ctm = counts_to_molar[mask]  # (T_v, 1)
        scatter_rows.extend(
            _variant_scatter_rows(
                variant_label_map[v],
                kinetic_rxn_names,
                (target_arr[mask] * ctm).mean(axis=0),  # mmol/(L·h)
                (kinetic_flux_arr[mask] * ctm).mean(axis=0),
            )
        )

    scatter_df = pl.DataFrame(scatter_rows).to_pandas()

    log_axis_title_x = f"log₁₀(Mean Kinetic Target + ε) [{FLUX_UNIT_STR}]"
    log_axis_title_y = f"log₁₀(Mean Kinetic Flux + ε) [{FLUX_UNIT_STR}]"

    # ── Scatter layer definitions ─────────────────────────────────────────────
    # All three layers share scatter_df (bound in alt.layer below) and select
    # their own rows through the is_ref flag.
    ref_line = (
        alt.Chart()
        .mark_line(color="lightgray", strokeDash=[5, 4], strokeWidth=1.2)
        .transform_filter("datum.is_ref")
        .encode(
            x=alt.X("log_target:Q"),
            y=alt.Y("log_flux:Q"),
        )
    )

    scatter_pts = (
        alt.Chart()
        .mark_circle(size=55, opacity=0.75)
        .transform_filter("!datum.is_ref")
        .encode(
            x=alt.X("log_target:Q", title=log_axis_title_x),
            y=alt.Y("log_flux:Q", title=log_axis_title_y),
            color=alt.Color("Variant:N", scale=color_scale, legend=None),
            tooltip=[
                alt.Tooltip("Reaction:N"),
                alt.Tooltip("Variant:N"),
                alt.Tooltip("log_target:Q", title="log₁₀(target)", format=".3f"),
                alt.Tooltip("log_flux:Q", title="log₁₀(flux)", format=".3f"),
            ],
        )
    )

    # The R² values are constant within a variant, so mean() simply collapses
    # them to one text mark per facet, pinned to the top-left corner.
    annotation = (
        alt.Chart()
        .mark_text(
            lineBreak="\n",
            align="left",
            baseline="top",
            dx=5,
            dy=5,
            fontSize=10,
        )
        .transform_filter("!datum.is_ref")
        .transform_aggregate(
            pearson_r2="mean(pearson_r2)",
            r2_to_yx="mean(r2_to_yx)",
            groupby=["Variant"],
        )
        .transform_calculate(
            annotation_label=(
                "'Pearson R\u00b2 = ' + format(datum.pearson_r2, '.2f') + '\\n'"
                "+ 'R\u00b2 to y=x = ' + format(datum.r2_to_yx, '.2f')"
            )
        )
        .encode(
            x=alt.value(5),
            y=alt.value(5),
            text="annotation_label:N",
        )
    )

    num_cols = min(len(unique_variants), 4)

    scatter_faceted = (
        alt.layer(ref_line, scatter_pts, annotation, data=scatter_df)
        .properties(width=280, height=280)
        .facet(
            facet=alt.Facet("Variant:N", title="Variant"),
            columns=num_cols,
        )
        .resolve_scale(x="independent", y="independent")
        .properties(
            title=f"log₁₀(Avg Kinetic Target + ε) vs log₁₀(Avg Estimated Flux + ε) [ε={LOG_EPS}]"
        )
    )

    # ── Line-plot DataFrame ───────────────────────────────────────────────────
    variant_label_col = [variant_label_map[v] for v in row_variants]
    line_df = (
        raw.select(["Time (min)", "generation", "lineage_seed", "kinetics_term"])
        .with_columns(pl.Series("Variant", variant_label_col))
        .to_pandas()
    )

    # detail=generation draws one line segment per generation, so the trace is
    # broken at each division instead of being joined across it.
    line_faceted = (
        alt.Chart(line_df)
        .mark_line(strokeWidth=1.3, opacity=0.85)
        .encode(
            x=alt.X("Time (min):Q", title="Time (min)"),
            y=alt.Y("mean(kinetics_term):Q", title="Unweighted Kinetic Term"),
            color=alt.Color("Variant:N", scale=color_scale, legend=None),
            detail=alt.Detail("generation:N"),
        )
        .properties(width=280, height=200)
        .facet(
            facet=alt.Facet("Variant:N", title="Variant"),
            columns=num_cols,
        )
        .resolve_scale(x="independent", y="independent")
        .properties(title="Kinetic Objective Term Over Time")
    )

    # ── Combine into faceted 2-row layout ─────────────────────────────────────
    final = (
        alt.vconcat(scatter_faceted, line_faceted)
        .resolve_scale(color="shared")
        .properties(title="Kinetic Flux Analysis by Variant")
    )

    out_path = os.path.join(outdir, "kinetic_flux_analysis.html")
    final.save(out_path)
    print(f"Saved kinetic flux analysis to: {out_path}")