import altair as alt
# noinspection PyUnresolvedReferences
from duckdb import DuckDBPyConnection
import polars as pl
from typing import Any
from ecoli.library.parquet_emitter import read_stacked_columns
[docs]
def plot(
    params: dict[str, Any],
    conn: DuckDBPyConnection,
    history_sql: str,
    config_sql: str,
    success_sql: str,
    sim_data_dict: dict[str, dict[int, str]],
    validation_data_paths: list[str],
    outdir: str,
    variant_metadata: dict[str, dict[int, Any]],
    variant_names: dict[str, str],
):
    """
    Line plot of doubling time vs generation for each lineage seed. Only works for lineage
    simulations with ``single_daughters`` set to True.

    Each cell's doubling time is the span of its recorded ``time`` values
    (``max - min``) divided by 3600, i.e. hours assuming ``time`` is in seconds.

    The chart layers are:

    * one line per lineage seed, for cells that appear in ``success_sql``;
    * cross markers for cells that do not appear in ``success_sql``
      (presumably simulations that did not complete — verify against the
      success table's definition);
    * a dashed line with the mean doubling time over successful cells per
      (experiment, variant, generation);
    * a dashed horizontal rule at a fixed ``1 / 0.47`` hours.

    Clicking a seed in the legend highlights that seed's line and markers.

    Args:
        params: Not used by this plot.
        conn: DuckDB connection used to run all queries.
        history_sql: SQL for the simulation history, passed to
            ``read_stacked_columns`` to read the ``time`` column.
        config_sql: Not used by this plot.
        success_sql: SQL whose result identifies successful cells. It must
            contain the columns ``experiment_id``, ``variant``,
            ``lineage_seed``, ``generation`` and ``agent_id``.
        sim_data_dict: Not used by this plot.
        validation_data_paths: Not used by this plot.
        outdir: Directory in which ``doubling_time.html`` is written.
        variant_metadata: Not used by this plot.
        variant_names: Not used by this plot.
    """
    # Columns that identify one simulated cell. They are used for the GROUP BY
    # and as join keys against the success table.
    id_cols = ["experiment_id", "variant", "lineage_seed", "generation", "agent_id"]
    id_list = ", ".join(id_cols)
    # Single source of truth for the y-axis column. The SQL alias and every
    # chart encoding below must agree on this name.
    y_col = "Doubling Time (hr)"
    # Human-friendly names used by the chart encodings and legend selection.
    display_names = {"lineage_seed": "Seed", "generation": "Generation"}

    doubling_time_sql = read_stacked_columns(
        history_sql,
        ["time"],
        order_results=False,
    )
    # Group on the real column names; the renamed display columns only exist
    # after the .rename() calls below.
    all_doubling_times = conn.sql(f"""
        SELECT (max(time) - min(time)) / 3600 AS "{y_col}", {id_list}
        FROM ({doubling_time_sql})
        GROUP BY {id_list}
        """).pl()
    successful_sims = conn.sql(success_sql).pl()

    # Both splits must start from the unfiltered table. Anti-joining the
    # already semi-joined table against the same keys would always be empty.
    doubling_times = all_doubling_times.join(
        successful_sims, how="semi", on=id_cols
    ).rename(display_names)
    death_times = all_doubling_times.join(
        successful_sims, how="anti", on=id_cols
    ).rename(display_names)

    # Clicking a seed in the legend dims every other seed in all seed layers.
    selection = alt.selection_point(fields=["Seed"], bind="legend")
    highlight = alt.when(selection).then(alt.value(1)).otherwise(alt.value(0.2))

    chart = (
        alt.Chart(doubling_times)
        .mark_line()
        .encode(
            x="Generation",
            y=y_col,
            color=alt.Color("Seed", type="nominal"),
            tooltip=[y_col, "Seed"],
            opacity=highlight,
        )
        .add_params(selection)
        .interactive()
    )
    death_points = (
        alt.Chart(death_times)
        .mark_point(shape="cross")
        .encode(
            x="Generation",
            y=y_col,
            color=alt.Color("Seed", type="nominal"),
            opacity=highlight,
            tooltip=[y_col, "Seed"],
        )
    )
    # Fixed reference line at 1 / 0.47 hr (~2.13 hr). NOTE(review): presumably
    # an experimental doubling time for comparison — confirm the source of 0.47.
    exp_avg = alt.Chart().mark_rule(strokeDash=[2, 2]).encode(y=alt.datum(1 / 0.47))
    # Sort so the output is deterministic; polars group_by order is not.
    sim_avg_df = (
        doubling_times.group_by("experiment_id", "variant", "Generation")
        .agg(pl.mean(y_col))
        .sort("Generation")
    )
    sim_avg = (
        alt.Chart(sim_avg_df)
        .mark_line(strokeDash=[2, 2])
        .encode(x="Generation", y=y_col, tooltip=[y_col])
    )
    chart = chart + exp_avg + sim_avg + death_points
    chart.save(f"{outdir}/doubling_time.html")