Source code for ecoli.processes.listeners.aggregator

from vivarium.core.process import Step
from vivarium.library.topology import assoc_path, get_in


[docs] class Aggregator(Step): """ Given a list of paths and a list of functions, this will apply the ith function to the ith path and write the results through `aggregated`. If a list of functions is not supplied, len is used for all paths. """ name = "aggregator" defaults: dict[str, tuple] = {"paths": tuple(), "funcs": tuple()} def __init__(self, parameters): super().__init__(parameters) self.paths = self.parameters["paths"] if not self.parameters["funcs"]: self.funcs = (len,) * len(self.paths) else: self.funcs = self.parameters["funcs"]
[docs] def ports_schema(self): schema = {} variables = [] for path, func in zip(self.paths, self.funcs): assoc_path(schema, path, {"_default": {}}) variable = f"{path[-1]}_{func.__name__}" assert variable not in variables variables.append(variable) assert "aggregated" not in schema schema["aggregated"] = { variable: { "_default": 0, "_divider": "zero", "_updater": "set", "_emit": True, } for variable in variables } return schema
[docs] def next_update(self, timestep, states): counts = {} for path, func in zip(self.paths, self.funcs): variable = f"{path[-1]}_{func.__name__}" assert variable not in counts counts[variable] = func(get_in(states, path)) assert counts[variable] is not None return {"aggregated": counts}
[docs] def len_squared(x): return len(x) ** 2
[docs] def len_plus_one(x): return len(x) + 1
def test_aggregator(): state = { "a": { "b": { 1: 0, 2: 0, 3: 0, }, "c": {}, }, } proc = Aggregator( { "paths": (("a", "b"), ("a", "c"), ("a", "b")), "funcs": (len_squared, len_plus_one, len_plus_one), } ) schema = proc.get_schema() expected_schema = { "a": { "b": { "_default": {}, }, "c": { "_default": {}, }, }, "aggregated": { "b_len_squared": { "_default": 0, "_divider": "zero", "_updater": "set", "_emit": True, }, "c_len_plus_one": { "_default": 0, "_divider": "zero", "_updater": "set", "_emit": True, }, "b_len_plus_one": { "_default": 0, "_divider": "zero", "_updater": "set", "_emit": True, }, }, } assert schema == expected_schema update = proc.next_update(0, state) expected_update = { "aggregated": { "b_len_squared": 9, "c_len_plus_one": 1, "b_len_plus_one": 4, } } assert update == expected_update if __name__ == "__main__": test_aggregator()