import os
from typing import Any
import altair as alt
import pickle
import polars as pl
import numpy as np
from duckdb import DuckDBPyConnection
import pandas as pd
from ecoli.library.parquet_emitter import open_arbitrary_sim_data, named_idx
from ecoli.library.schema import bulk_name_to_idx
# ----------------------------------------- #
[docs]
def calc_rna_doubling_time(
    produced_col: str, count_col: str, borderline: float
) -> pl.Expr:
    """
    Calculate rRNA doubling time with sanitation.
    """
    production_rate = pl.col(produced_col) / pl.col("time_step_sec")
    growth_rate = production_rate / pl.col(count_col)
    dt_min = float(np.log(2)) / growth_rate / 60
    valid = (
        (pl.col(produced_col) >= 0)
        & (pl.col(count_col) > 0)
        & (growth_rate > 0)
        & dt_min.is_finite()
        & (dt_min > 0)
        & (dt_min < 2 * borderline)
    )
    return pl.when(valid).then(dt_min).otherwise(None) 
[docs]
def plot(
    params: dict[str, Any],
    conn: DuckDBPyConnection,
    history_sql: str,
    config_sql: str,
    success_sql: str,
    sim_data_dict: dict[str, dict[int, str]],
    validation_data_paths: list[str],
    outdir: str,
    variant_metadata: dict[str, dict[int, Any]],
    variant_names: dict[str, str],
):
    """Visualize ribosome production metrics for E. coli simulation."""
    with open_arbitrary_sim_data(sim_data_dict) as f:
        sim_data = pickle.load(f)
    sim_doubling_time = sim_data.doubling_time.asNumber()
    # define rRNA groups and bulk IDs
    s30_16s = list(sim_data.molecule_groups.s30_16s_rRNA) + [
        sim_data.molecule_ids.s30_full_complex
    ]
    s50_23s = list(sim_data.molecule_groups.s50_23s_rRNA) + [
        sim_data.molecule_ids.s50_full_complex
    ]
    s50_5s = list(sim_data.molecule_groups.s50_5s_rRNA) + [
        sim_data.molecule_ids.s50_full_complex
    ]
    bulk_ids = sim_data.internal_state.bulk_molecules.bulk_data["id"].tolist()
    # precompute indices as Python ints
    idx_16s = [int(i) for i in np.atleast_1d(bulk_name_to_idx(s30_16s, bulk_ids))]
    idx_23s = [int(i) for i in np.atleast_1d(bulk_name_to_idx(s50_23s, bulk_ids))]
    idx_5s = [int(i) for i in np.atleast_1d(bulk_name_to_idx(s50_5s, bulk_ids))]
    required_columns = [
        "time",
        "variant",
        "generation",
        "agent_id",
        "listeners__mass__instantaneous_growth_rate",
        "listeners__mass__dry_mass",
        "listeners__ribosome_data__rRNA16S_initiated",
        "listeners__ribosome_data__rRNA23S_initiated",
        "listeners__ribosome_data__rRNA5S_initiated",
        "listeners__ribosome_data__rRNA16S_init_prob",
        "listeners__ribosome_data__rRNA23S_init_prob",
        "listeners__ribosome_data__rRNA5S_init_prob",
        "listeners__ribosome_data__effective_elongation_rate",
        "listeners__unique_molecule_counts__active_ribosome",
    ]
    # load data
    # Create the bulk index expressions
    bulk_16s_expr = named_idx("bulk", [f"bulk_{i}" for i in idx_16s], [idx_16s])
    bulk_23s_expr = named_idx("bulk", [f"bulk_{i}" for i in idx_23s], [idx_23s])
    bulk_5s_expr = named_idx("bulk", [f"bulk_{i}" for i in idx_5s], [idx_5s])
    # Combine all columns and expressions
    all_columns = ", ".join(required_columns)
    bulk_expressions = ", ".join([bulk_16s_expr, bulk_23s_expr, bulk_5s_expr])
    # Build the SQL query
    sql = f"""
    SELECT {all_columns}, {bulk_expressions}
    FROM ({history_sql})
    WHERE agent_id = 0
    ORDER BY generation, time
    """
    df = conn.sql(sql).pl()
    # time
    df = df.with_columns((pl.col("time") / 60).alias("time_min"))
    df = df.with_columns(
        pl.col("time")
        .diff()
        .over(["variant", "generation", "agent_id"])
        .alias("time_step_sec")
    )
    df = df.with_columns(
        time_step_sec=pl.when(pl.col("time_step_sec").is_null())
        .then(pl.col("time"))
        .otherwise(pl.col("time_step_sec"))
    )
    # cell doubling time
    if "listeners__mass__instantaneous_growth_rate" in df.columns:
        val = (
            float(np.log(2)) / pl.col("listeners__mass__instantaneous_growth_rate") / 60
        )
        df = df.with_columns(
            pl.when(val.is_between(0, 2 * sim_doubling_time, closed="both"))
            .then(val)
            .otherwise(None)
            .alias("cell_doubling_time_min")
        )
    df = df.with_columns(
        [
            pl.sum_horizontal([pl.col(f"bulk_{i}") for i in idx_16s]).alias(
                "bulk_16s_count"
            ),
            pl.sum_horizontal([pl.col(f"bulk_{i}") for i in idx_23s]).alias(
                "bulk_23s_count"
            ),
            pl.sum_horizontal([pl.col(f"bulk_{i}") for i in idx_5s]).alias(
                "bulk_5s_count"
            ),
            pl.col("listeners__unique_molecule_counts__active_ribosome")
            .fill_null(0)
            .alias("ribosome_count"),
        ]
    )
    # total rRNA
    df = df.with_columns(
        [
            (pl.col("bulk_16s_count") + pl.col("ribosome_count")).alias("rrn16s_count"),
            (pl.col("bulk_23s_count") + pl.col("ribosome_count")).alias("rrn23s_count"),
            (pl.col("bulk_5s_count") + pl.col("ribosome_count")).alias("rrn5s_count"),
        ]
    )
    # rRNA doubling times
    if "listeners__ribosome_data__rRNA16S_initiated" in df.columns:
        df = df.with_columns(
            rrn16S_doubling_time_min=calc_rna_doubling_time(
                "listeners__ribosome_data__rRNA16S_initiated",
                "rrn16s_count",
                sim_doubling_time,
            )
        )
    if "listeners__ribosome_data__rRNA23S_initiated" in df.columns:
        df = df.with_columns(
            rrn23S_doubling_time_min=calc_rna_doubling_time(
                "listeners__ribosome_data__rRNA23S_initiated",
                "rrn23s_count",
                sim_doubling_time,
            )
        )
    if "listeners__ribosome_data__rRNA5S_initiated" in df.columns:
        df = df.with_columns(
            rrn5S_doubling_time_min=calc_rna_doubling_time(
                "listeners__ribosome_data__rRNA5S_initiated",
                "rrn5s_count",
                sim_doubling_time,
            )
        )
    # reference probabilities
    cond = sim_data.condition
    trans = sim_data.process.transcription
    synth_probs = trans.cistron_tu_mapping_matrix.dot(trans.rna_synth_prob[cond])
    def fit_prob(group_ids):
        cistrons = [rid[:-3] for rid in group_ids]
        idxs = np.where(np.isin(trans.cistron_data["id"], cistrons))[0]
        return synth_probs[idxs].sum() if idxs.size else 0.0
    ref_probs = {
        "16S": fit_prob(sim_data.molecule_groups.s30_16s_rRNA),
        "23S": fit_prob(sim_data.molecule_groups.s50_23s_rRNA),
        "5S": fit_prob(sim_data.molecule_groups.s50_5s_rRNA),
    }
    # ----------------------------------------- #
    # prepare for plotting
    plot_cols = ["time_min", "variant", "generation"]
    for c in [
        "listeners__mass__dry_mass",
        "cell_doubling_time_min",
        "rrn16S_doubling_time_min",
        "rrn23S_doubling_time_min",
        "rrn5S_doubling_time_min",
        "rrn16S_init_prob",
        "rrn23S_init_prob",
        "rrn5S_init_prob",
        "listeners__ribosome_data__effective_elongation_rate",
    ]:
        if c in df.columns:
            plot_cols.append(c)
    plot_df = df.select(plot_cols)
    init_dm = (
        plot_df.filter(pl.col("time_min") == 0)
        .select(["variant", "listeners__mass__dry_mass"])
        .rename({"listeners__mass__dry_mass": "initial_dry_mass"})
    )
    plot_df = plot_df.join(init_dm, on=["variant"], how="left")
    plot_df = plot_df.with_columns(
        (pl.col("listeners__mass__dry_mass") / pl.col("initial_dry_mass")).alias(
            "dry_mass_normalized"
        )
    )
    # generate Altair charts
    def create_line_chart(y, title, y_title, ref=None):
        base = alt.Chart(plot_df)
        line = (
            base.mark_line()
            .encode(
                x=alt.X("time_min:Q", title="Time (min)"),
                y=alt.Y(f"{y}:Q", title=y_title),
                color=alt.Color(
                    "generation:N",
                    legend=alt.Legend(title="Simulated Multigeneration Data"),
                ),
            )
            .properties(title=title, width=600, height=120)
        )
        if ref is not None:
            rule = (
                alt.Chart(pd.DataFrame({"y": [ref]}))
                .mark_rule(color="red", strokeDash=[5, 5])
                .encode(y="y:Q")
            )
            return line + rule
        return line
    def create_histogram(
        col: str, title: str, bins: int = 30, probability: bool = False
    ) -> alt.Chart:
        if probability:
            density = (
                alt.Chart(plot_df)
                .transform_density(col, as_=[col, "density"], counts=False, steps=bins)
                .mark_area(opacity=0.6)
                .encode(
                    x=alt.X(f"{col}:Q", title=f"bin={bins}"),
                    y=alt.Y("density:Q", title="Density"),
                )
                .properties(width=200, height=120, title=title)
            )
            return density
        else:
            hist = (
                alt.Chart(plot_df)
                .mark_bar(opacity=0.6)
                .encode(
                    x=alt.X(f"{col}:Q", bin=alt.Bin(maxbins=bins), title=f"bin={bins}"),
                    y=alt.Y("count():Q", title="Count"),
                    color=alt.value("steelblue"),
                )
                .properties(width=200, height=120, title=title)
            )
            return hist
    plots = []
    # Dry mass
    if "dry_mass_normalized" in plot_df.columns:
        line = create_line_chart(
            "dry_mass_normalized",
            "Normalized Dry Mass Over Time",
            "Dry mass (relative to t=0)",
        )
        hist = create_histogram(
            "dry_mass_normalized", "Normalized Dry Mass Distribution", probability=True
        )
        plots.append(alt.hconcat(line, hist))
    # Cell Doubling Time
    if "cell_doubling_time_min" in plot_df.columns:
        line = create_line_chart(
            "cell_doubling_time_min",
            "Cell Doubling Time",
            "Doubling Time (min)",
            sim_doubling_time,
        )
        hist = create_histogram(
            "cell_doubling_time_min",
            "Cell Doubling Time (min) Distribution",
            probability=True,
        )
        plots.append(alt.hconcat(line, hist))
    # rRNA Doubl;ing Time
    for suffix in ["16S", "23S", "5S"]:
        col = f"rrn{suffix}_doubling_time_min"
        if col in plot_df.columns:
            line = create_line_chart(
                col,
                f"{suffix} rRNA Doubling Time",
                "Doubling Time (min)",
                sim_doubling_time,
            )
            hist = create_histogram(
                col, f"{suffix} rRNA Doubling Time Distribution", probability=True
            )
            plots.append(alt.hconcat(line, hist))
    # rRNA Initiation Probability
    for suffix, ref in ref_probs.items():
        col = f"rrn{suffix}_init_prob"
        if col in plot_df.columns:
            line = create_line_chart(
                col, f"{suffix} rRNA Initiation Probability", "Probability", ref
            )
            hist = create_histogram(
                col,
                f"{suffix} rRNA Initiation Probability Distribution",
                probability=True,
            )
            plots.append(alt.hconcat(line, hist))
    # Ribosome Elongation Rate
    if "listeners__ribosome_data__effective_elongation_rate" in plot_df.columns:
        line = create_line_chart(
            "listeners__ribosome_data__effective_elongation_rate",
            "Ribosome Elongation Rate",
            "Amino acids/s",
        )
        hist = create_histogram(
            "listeners__ribosome_data__effective_elongation_rate",
            "Ribosome Elongation Rate Distribution",
            probability=True,
        )
        plots.append(alt.hconcat(line, hist))
    if not plots:
        fallback = pl.DataFrame({"message": ["No data available"], "x": [0], "y": [0]})
        plots.append(
            alt.Chart(fallback)
            .mark_text(size=20, color="red")
            .encode(x="x:Q", y="y:Q", text="message:N")
            .properties(width=600, height=400, title="No Data")
        )
    combined = (
        alt.vconcat(*plots)
        .resolve_scale(x="shared", y="independent")
        .properties(title="Ribosome Production Metrics")
    )
    out_path = os.path.join(outdir, "ribosome_production_report.html")
    combined.save(out_path)
    print(f"Saved visualization to: {out_path}")
    return combined