import altair as alt
# noinspection PyUnresolvedReferences
from duckdb import DuckDBPyConnection
import polars as pl
from typing import Any, cast
from ecoli.library.parquet_emitter import read_stacked_columns, skip_n_gens
# [docs] -- leftover Sphinx "view source" link marker from HTML extraction; not part of the module.
def plot(
    params: dict[str, Any],
    conn: DuckDBPyConnection,
    history_sql: str,
    config_sql: str,
    success_sql: str,
    sim_data_dict: dict[str, dict[int, str]],
    validation_data_paths: list[str],
    outdir: str,
    variant_metadata: dict[str, dict[int, Any]],
    variant_names: dict[str, str],
) -> None:
    """
    Histogram of doubling times with average marked.

    The doubling time of each cell is computed as ``max(time) - min(time)``
    divided by 60, grouped by experiment_id, variant, lineage_seed,
    generation and agent_id. The histogram, a dashed red rule at the mean and
    a text label with the rounded mean are layered into one interactive chart
    and written to ``{outdir}/doubling_time_histogram.html``.

    Configure number of initial generations to skip using ``skip_n_gens`` key
    in params. The key is optional; when it is absent (or ``None``) no
    generations are skipped.

    Args:
        params: Analysis options; only the optional ``skip_n_gens`` key is
            read here.
        conn: DuckDB connection used to run the doubling time query.
        history_sql: SQL passed to ``read_stacked_columns`` to read the
            ``time`` column.
        config_sql: Not used by this function.
        success_sql: SQL forwarded to ``read_stacked_columns`` as
            ``success_sql``.
        sim_data_dict: Not used by this function.
        validation_data_paths: Not used by this function.
        outdir: Directory in which the HTML chart is saved.
        variant_metadata: Not used by this function.
        variant_names: Not used by this function.

    Raises:
        ValueError: If the mean doubling time is ``None`` (e.g. the query
            returned no rows).
    """
    # NOTE(review): the unused parameters presumably exist so that every
    # analysis script shares one call signature -- confirm against the caller.
    doubling_time_sql = read_stacked_columns(
        history_sql,
        ["time"],
        order_results=False,
        success_sql=success_sql,
    )

    # Drop the configured number of initial generations. The count comes from
    # params (it is not hard-coded), and a missing key means "skip nothing"
    # rather than a KeyError.
    n_gens_to_skip = params.get("skip_n_gens")
    if n_gens_to_skip is not None:
        doubling_time_sql = skip_n_gens(doubling_time_sql, n_gens_to_skip)

    # One row per cell: elapsed time between its first and last record.
    # NOTE(review): dividing by 60 for minutes assumes ``time`` is in
    # seconds -- confirm against the emitter.
    doubling_times = conn.sql(f"""
        SELECT (max(time) - min(time)) / 60 AS 'Doubling Time (min)', experiment_id, variant, lineage_seed, generation, agent_id
        FROM ({doubling_time_sql})
        GROUP BY experiment_id, variant, lineage_seed, generation, agent_id
        """).pl()

    # Calculate the average doubling time; the mean is None when there is
    # nothing to average.
    avg_doubling_time = doubling_times["Doubling Time (min)"].mean()
    if avg_doubling_time is None:
        raise ValueError("No doubling times found in the data.")

    # Round for display only; the rule itself is drawn at the exact mean.
    avg_doubling_time_rounded = round(cast(float, avg_doubling_time), 2)
    avg_tooltip_title = f"Average: {avg_doubling_time_rounded} min"

    # Keep the axis binning and the tooltip binning in sync.
    max_bins = 40

    # Base histogram. This layer carries the explicit X- and Y-axis titles.
    hist = (
        alt.Chart(doubling_times)
        .mark_bar()
        .encode(
            x=alt.X(
                "Doubling Time (min)",
                bin=alt.Bin(maxbins=max_bins),
                axis=alt.Axis(title="Doubling Time (min)", labelFlush=False),
            ),
            y=alt.Y("count()", axis=alt.Axis(title="Frequency")),
            tooltip=[
                alt.Tooltip("Doubling Time (min)", bin=alt.Bin(maxbins=max_bins)),
                "count()",
            ],
        )
    )

    # Single-row frame that positions both the rule and its text label.
    avg_df = pl.DataFrame({"avg": [avg_doubling_time]})

    # Dashed vertical rule at the average. No axis configuration is set on
    # this layer.
    rule = (
        alt.Chart(avg_df)
        .mark_rule(color="red", strokeDash=[5, 5], size=2)
        .encode(
            x=alt.X("avg:Q"),
            tooltip=[alt.Tooltip("avg", title=avg_tooltip_title)],
        )
    )

    # Text label next to the rule showing the rounded average. No axis
    # configuration is set on this layer either.
    text = (
        alt.Chart(avg_df)
        .mark_text(align="left", baseline="middle", dx=7, dy=-20, color="red")
        .encode(
            x=alt.X("avg:Q"),
            text=alt.value(f"{avg_doubling_time_rounded} min"),
            tooltip=[alt.Tooltip("avg", title=avg_tooltip_title)],
        )
    )

    # Layer histogram, rule and label, title the chart, and enable
    # pan/zoom interactivity.
    chart = (hist + rule + text).properties(title="Distribution of Doubling Times")
    chart = chart.interactive()

    # Save the chart
    chart.save(f"{outdir}/doubling_time_histogram.html")