"""Hole detection for cell-wall lattices (module ``ecoli.library.cell_wall.hole_detection``)."""

import os
from collections.abc import MutableMapping
from functools import reduce
from operator import __or__
from time import perf_counter

import matplotlib.pyplot as plt
import numpy as np
import pytest
from skimage import measure

class HoleSizeDict(MutableMapping):
    """Forest-backed mapping from hole ids to hole sizes.

    Keys are ``frozenset`` objects of primitive (integer) hole ids. Each key
    in ``self.mapping`` maps either to an ``int`` (the key is a root and the
    int is the size of the hole it represents) or to another key (the key has
    been merged into a larger hole). Reading any key follows that chain until
    an ``int`` is reached, so every id belonging to the same hole reports the
    same size.

    ``largest_hole`` is the largest size ever stored. It is not reduced when
    entries are deleted or subtrees are pruned.
    """

    def __init__(self, data=None):
        if not data:
            data = {}
        # Raw forest: key -> int (root size) or key -> parent key.
        self.mapping = {}
        # Keys whose stored value is an int.
        self.roots = set()
        self.largest_hole = 0
        self.update(data)
        if data:
            # default avoids max() on an empty sequence when no value is an int.
            self.largest_hole = max(
                (v for v in data.values() if isinstance(v, int)),
                default=self.largest_hole,
            )

    def __getitem__(self, key: frozenset):
        """Return the size of the hole containing ``key`` (KeyError if absent)."""
        return self.mapping[self.get_containing_hole(key)]

    def __delitem__(self, key):
        """Remove ``key``.

        Deleting a root prunes its whole subtree. Deleting an inner node
        re-points the nodes that referenced it directly at the root of its
        tree, so their chains stay valid.
        """
        if key in self.roots:
            self.prune_subtree(key)
            return

        destination = self.get_containing_hole(key)
        # Scan the raw mapping: self.items() would resolve every value to its
        # int size, so no entry could ever compare equal to ``key``.
        children = [k for k, v in self.mapping.items() if v == key]
        for child in children:
            self.mapping[child] = destination
        del self.mapping[key]

    def __setitem__(self, key: frozenset, value: int):
        """Store ``value`` as the size of the hole containing ``key``."""
        loc = key
        # Walk up to the root; an unknown key is treated as a new root.
        while not isinstance(self.mapping.get(loc, 0), int):
            loc = self.mapping[loc]
        self.mapping[loc] = value
        self.roots.add(loc)
        if value > self.largest_hole:
            self.largest_hole = value

    def merge(self, holes):
        """Merge the holes containing each id in ``holes`` into one hole.

        The new root key is the union of the current root keys and its size is
        the sum of their sizes. If all ids already share a root, nothing
        changes.

        Returns:
            The root key of the (possibly merged) hole.
        """
        containing_holes = {self.get_containing_hole(hole) for hole in holes}
        merged_hole = reduce(__or__, containing_holes)
        # Only merge if some entries do not already map to the same root.
        if len(containing_holes) > 1:
            new_size = sum(self.mapping[hole] for hole in containing_holes)
            self.mapping[merged_hole] = new_size
            for hole in containing_holes:
                self.mapping[hole] = merged_hole
                # No longer a root node after the merge.
                self.roots.remove(hole)
            if new_size > self.largest_hole:
                self.largest_hole = new_size
            self.roots.add(merged_hole)
        return merged_hole

    def prune_subtree(self, root):
        """Delete ``root`` and every node on a chain from its leaves to it.

        Raises:
            ValueError: if ``root`` is present but is not a root node.
            KeyError: if ``root`` is not present.
        """
        if not isinstance(self.mapping[root], int):
            raise ValueError(f"Node {root} is not a root node of a subtree.")
        for primitive_id in root:
            current = frozenset([primitive_id])
            # Each merge links old roots to the union, so walking from every
            # single-id leaf up to the root visits every node of the tree.
            while current != root and current in self.mapping:
                next_node = self.mapping[current]
                del self.mapping[current]
                current = next_node
        del self.mapping[root]
        self.roots.remove(root)

    def get_containing_hole(self, hole):
        """Return the root key of the tree containing ``hole`` (KeyError if absent)."""
        containing_hole = hole
        while not isinstance(self.mapping[containing_hole], int):
            containing_hole = self.mapping[containing_hole]
        return containing_hole

    def max(self):
        """Return the largest hole size recorded so far."""
        return self.largest_hole

    def get_depth(self):
        """Return the length of the longest chain from any key to its root.

        Somewhat naive implementation, but this should not be needed often.
        """
        depth = 1
        seen = set()
        for key in self.mapping:
            if key in seen:
                continue
            seen.add(key)
            key_depth = 1
            while not isinstance(self.mapping[key], int):
                key = self.mapping[key]
                key_depth += 1
            if key_depth > depth:
                depth = key_depth
        return depth

    def __iter__(self):
        # Iterates every stored key, roots and inner nodes alike.
        return iter(self.mapping)

    def __len__(self):
        # Number of stored keys (roots and inner nodes), not number of holes.
        return len(self.mapping)

    def __repr__(self):
        return f"{type(self).__name__}({self.mapping})"
def detect_holes(lattice, on_cylinder=True, critical_size=None, prune_subtrees=True):
    """Find connected holes in ``lattice`` with a single row-major scan.

    A position is a hole unless it equals 1. Each hole position is compared
    with its already-visited neighbours (the three above it and the one to its
    left), which yields 8-connectivity. When ``on_cylinder`` is true, the first
    and last rows are also treated as adjacent. Columns never wrap.

    Args:
        lattice: 2-D array; positions equal to 1 are wall, anything else is
            hole.
        on_cylinder: Whether to merge holes across the top/bottom edge.
        critical_size: If truthy, stop scanning further rows once the largest
            hole found reaches this size.
        prune_subtrees: Whether to periodically drop finished holes (not seen
            in the current row) from the size dictionary to bound memory.

    Returns:
        ``(hole_sizes, hole_view)``.
        ``hole_sizes`` is a HoleSizeDict.
        ``hole_view`` is an object array shaped like ``lattice``. It holds,
        for each hole position, the id assigned when that position was
        visited, or an empty ``frozenset`` for non-holes. Those ids may since
        have been merged or pruned.
    """
    # Ids are frozensets because two holes that initially look separate may
    # later be discovered to be the same hole.
    hole_view = np.full_like(lattice, frozenset(), dtype=object)
    hole_sizes = HoleSizeDict()

    # Pruning runs once len(hole_sizes) reaches this threshold; the threshold
    # then grows geometrically but never drops below its initial value.
    min_prune_size = 100
    next_prune_size = min_prune_size
    # Roots in this set are immune to pruning (row-0 holes may still merge
    # with last-row holes in the cylindrical case).
    prune_immune = set()

    next_hole_id = 1
    rows, cols = lattice.shape
    for r in range(rows):
        ids_in_row = set()
        for c in range(cols):
            # Skip non-holes
            if lattice[r, c] == 1:
                continue

            # Neighbours (N) of the current position (X) already visited:
            #
            #   N N N
            #   N X
            neighbor_pos = {(r - 1, c - 1), (r - 1, c), (r - 1, c + 1), (r, c - 1)}
            # Set comprehension removes duplicate ids.
            neighbor_holes = list(
                {
                    hole_view[n_r, n_c]
                    for n_r, n_c in neighbor_pos
                    if 0 <= n_r < rows
                    and 0 <= n_c < cols
                    and len(hole_view[n_r, n_c]) > 0
                }
            )

            if not neighbor_holes:
                # Start a new hole
                new_id = frozenset([next_hole_id])
                hole_sizes[new_id] = 1
                next_hole_id += 1
            else:
                new_id = hole_sizes.merge(neighbor_holes)
                hole_sizes[new_id] += 1

            hole_view[r, c] = new_id
            for primitive_id in new_id:
                ids_in_row.add(frozenset([primitive_id]))

        # Holes touching the top row must survive until the final top/bottom
        # merge in the cylindrical case.
        if on_cylinder and r == 0:
            for hole_id in ids_in_row:
                prune_immune.add(hole_sizes.get_containing_hole(hole_id))

        # Prune for memory efficiency: a subtree with no position in this row
        # cannot be reached by later rows (except via cylindrical wraparound,
        # which prune_immune covers). The last row is never pruned because the
        # cylinder merge below still reads it.
        if prune_subtrees and len(hole_sizes) >= next_prune_size and r != rows - 1:
            # Immune roots may have been merged since; refresh them.
            prune_immune = {
                hole_sizes.get_containing_hole(hole) for hole in prune_immune
            }
            subtrees_to_prune = {
                root
                for root in hole_sizes.roots
                if not any(
                    frozenset([primitive_id]) in ids_in_row for primitive_id in root
                )
            }
            for subtree in subtrees_to_prune - prune_immune:
                hole_sizes.prune_subtree(subtree)
            # Guard against log(0) when everything was pruned.
            remaining = max(len(hole_sizes), 1)
            next_prune_size = max(
                min_prune_size, int(np.exp(np.ceil(np.log(remaining))))
            )

        # Early stopping if the critical size has been reached
        if critical_size and hole_sizes.max() >= critical_size:
            break

    if on_cylinder and rows > 0:
        # Merge holes in the top row with touching holes in the bottom row.
        for c in range(cols):
            if lattice[0, c] == 1:
                continue
            neighbor_pos = {(0, c), (rows - 1, c - 1), (rows - 1, c), (rows - 1, c + 1)}
            # 0 <= n_c prevents index -1 from wrapping to the last column;
            # only rows wrap on the cylinder.
            neighbor_holes = {
                hole_view[n_r, n_c]
                for n_r, n_c in neighbor_pos
                if 0 <= n_r < rows
                and 0 <= n_c < cols
                and len(hole_view[n_r, n_c]) > 0
            }
            if len(neighbor_holes) > 1:
                hole_sizes.merge(neighbor_holes)

    return hole_sizes, hole_view
def detect_holes_skimage(lattice, on_cylinder=True):
    """Label holes with ``skimage.measure.label`` and report their sizes.

    Positions equal to 1 are background (murein). Every other connected
    region, with diagonal neighbours counted (``connectivity=2``), gets a
    positive integer label. When ``on_cylinder`` is true, each hole in the
    first row passes its label to touching holes in the last row. A hole
    that wraps across the top/bottom edge is then counted once.

    Returns:
        ``(hole_sizes, hole_view)``.
        ``hole_sizes`` is a 1-D array with the size of each label,
        excluding background.
        ``hole_view`` is the integer label image.
    """
    labels = measure.label(lattice, background=1, connectivity=2)
    width = labels.shape[1]

    if on_cylinder:
        for col in range(width):
            top_label = labels[0, col]
            if top_label == 0:
                # Not a hole.
                continue
            for bottom_col in (col - 1, col, col + 1):
                if not 0 <= bottom_col < width:
                    continue
                bottom_label = labels[-1, bottom_col]
                if bottom_label not in (0, top_label):
                    labels[labels == bottom_label] = top_label

    unique_labels, label_counts = np.unique(labels.ravel(), return_counts=True)
    return label_counts[unique_labels != 0], labels
def _largest_hole(hole_sizes):
    """Return the largest hole size from either detector's ``hole_sizes``."""
    if isinstance(hole_sizes, HoleSizeDict):
        # Tracks the maximum ever seen, even if those holes were pruned away
        # (so len(hole_sizes) may be 0 while the maximum is not).
        return hole_sizes.max()
    # skimage path: a plain array, whose .max() fails when empty.
    return hole_sizes.max() if hole_sizes.size > 0 else 0


def test_hole_size_dict():
    hsd = HoleSizeDict({frozenset([1]): 1, frozenset([2]): 2})

    # Merging
    hsd.merge([frozenset([1]), frozenset([2])])
    assert hsd[frozenset([1, 2])] == 3

    # Mapping updates to correct destination
    hsd[frozenset([1])] += 5
    assert hsd[frozenset([1, 2])] == 8

    # Subtree pruning: only root nodes may be pruned
    with pytest.raises(ValueError):
        hsd.prune_subtree(frozenset([1]))

    hsd[frozenset([3])] = 1
    hsd.prune_subtree(frozenset([1, 2]))
    assert len(hsd) == 1 and frozenset([3]) in hsd


def test_detect_holes():
    """Check both detectors against each test case's expected largest hole.

    The first line of each file in ``ecoli/library/cell_wall/test_cases`` is
    read as the expected maximum hole size; the remaining lines are read as
    the lattice. A plot of each detector's hole view is saved under
    ``out/hole_detection``.
    """
    test_dir = "ecoli/library/cell_wall/test_cases"
    os.makedirs("out/hole_detection", exist_ok=True)

    test_files = sorted(os.listdir(test_dir))
    detection_methods = {
        "detect_holes": detect_holes,
        "detect_holes_skimage": detect_holes_skimage,
    }

    failures = []
    for test_case in test_files:
        print(f"Test case: {test_case}")

        # Load test case
        path = f"{test_dir}/{test_case}"
        test_array = np.genfromtxt(path, dtype=int, skip_header=1)
        expected_max_size = np.loadtxt(path, dtype=int, max_rows=1)

        for method_name, detection_method in detection_methods.items():
            print(f"Detection method: {method_name}")

            # Get hole view, size of largest hole
            hole_sizes, hole_view = detection_method(test_array)
            max_hole = _largest_hole(hole_sizes)

            print(f"> Size of largest hole: {max_hole} (Expected {expected_max_size})")
            passed = max_hole == expected_max_size
            if not passed:
                failures.append(f"{test_case}[{method_name}]")
            print(f"> {'PASSED' if passed else 'FAILED'}")
            print()

            # Plot test case, hole view
            fig, ax = plt.subplots()
            ax.imshow(test_array, interpolation="nearest", aspect="auto")
            for r in range(hole_view.shape[0]):
                for c in range(hole_view.shape[1]):
                    if method_name == "detect_holes":
                        ax.text(
                            c,
                            r,
                            f"{set(hole_view[r, c]) if len(hole_view[r,c]) > 0 else ''}",
                            ha="center",
                            va="center",
                            color="w",
                        )
                    elif method_name == "detect_holes_skimage":
                        ax.text(
                            c,
                            r,
                            str(hole_view[r, c]) if hole_view[r, c] != 0 else "",
                            ha="center",
                            va="center",
                            color="w",
                        )
            ax.set_title(f"Hole View (Max hole detected = {max_hole})")

            # Save image
            fig.tight_layout()
            fig.savefig(f"out/hole_detection/test_{test_case}[{method_name}].png")
            plt.close(fig)
        print("===============================================")

    n_total = len(detection_methods) * len(test_files)
    print(f"Passed {n_total - len(failures)}/{n_total} tests.")
    print()
    # Previously the test only printed results and could never fail.
    assert not failures, f"Largest hole size mismatch for: {', '.join(failures)}"


@pytest.mark.skip(reason="Used locally to compare skimage and hand-rolled algo.")
def test_runtime():
    """Benchmark both detectors on random lattices and plot the results.

    For each detector, one figure is saved with runtime, tree size, tree
    depth and largest hole size against lattice side length, one curve per
    hole density.
    """
    os.makedirs("out/hole_detection", exist_ok=True)
    side_length = [10, 100, 200, 300, 400, 500]
    density = np.arange(0, 1.1, 0.1)
    detection_methods = {
        "detect_holes_skimage": detect_holes_skimage,
        "detect_holes": detect_holes,
    }

    for method_name, detection_method in detection_methods.items():
        # A fresh figure per method so curves from different methods don't mix.
        fig, axs = plt.subplots(nrows=4, ncols=1)
        rng = np.random.default_rng(0)
        for d in density:
            runtimes = []
            dict_sizes = []
            tree_depths = []
            max_hole = []
            for s in side_length:
                # Each position is wall (1) with probability 1 - d.
                a = rng.binomial(1, 1 - d, size=s * s).reshape((s, s))

                tick = perf_counter()
                hole_sizes, _ = detection_method(a)
                tock = perf_counter()

                runtimes.append(tock - tick)
                if method_name == "detect_holes":
                    dict_sizes.append(len(hole_sizes.mapping))
                    tree_depths.append(hole_sizes.get_depth())
                else:
                    dict_sizes.append(hole_sizes.size)
                    tree_depths.append(0)
                max_hole.append(_largest_hole(hole_sizes))

                print(
                    f"[{method_name}] Runtime for side length {s}, density {d:.1f} : {tock-tick} seconds"
                )

            axs[0].plot(
                side_length,
                runtimes,
                label=f"Density={d:.1f}",
                color=(0, 1 - (2 * d - 1) ** 2, d),
            )
            axs[1].plot(
                side_length,
                dict_sizes,
                label=f"Density={d:.1f}",
                color=(1 - (2 * d - 1) ** 2, d, 0),
            )
            axs[2].plot(
                side_length,
                tree_depths,
                label=f"Density={d:.1f}",
                color=(d, 0, 1 - (2 * d - 1) ** 2),
            )
            axs[3].plot(
                side_length,
                max_hole,
                label=f"Density={d:.1f}",
                color=(d, (2 * d - 1) ** 2, (1 - d) / 2),
            )

        axs[0].set_title("Runtime vs. Side Length, Density")
        axs[0].set_xlabel("Side length")
        axs[0].set_ylabel("Runtime (s)")
        axs[0].legend()

        axs[1].set_title("Tree Size vs. Side Length, Density")
        axs[1].set_xlabel("Side Length")
        axs[1].set_ylabel("Size (nodes)")
        axs[1].legend()

        axs[2].set_title("Tree Depth vs. Side Length, Density")
        axs[2].set_xlabel("Side Length")
        axs[2].set_ylabel("Tree Depth")
        axs[2].legend()

        axs[3].plot(
            side_length,
            np.repeat(int((np.pi * 20**2) / 4), len(side_length)),
            "k--",
            label="Critical Size",
        )
        axs[3].set_title("Max Hole Size vs. Side Length, Density")
        axs[3].set_xlabel("Side Length")
        axs[3].set_ylabel("Maximum Hole Size")
        axs[3].legend()

        fig.set_size_inches(8, 24)
        fig.tight_layout()
        fig.savefig(f"out/hole_detection/test_runtime[{method_name}].png")
        plt.close(fig)
def run_test_case(side_length, density, rng=None):
    """Run ``detect_holes`` on a random square lattice (used for profiling).

    Args:
        side_length: Number of rows and columns of the lattice.
        density: Probability that a position is a hole (value 0); each
            position is wall (value 1) with probability ``1 - density``.
        rng: Optional ``numpy.random.Generator``. Defaults to a fresh
            generator seeded with 0, so repeated calls are reproducible.
            The original default was one generator shared across calls,
            which made each call's lattice depend on earlier calls.

    Returns:
        The ``(hole_sizes, hole_view)`` tuple produced by ``detect_holes``.
    """
    if rng is None:
        rng = np.random.default_rng(0)
    a = rng.binomial(1, 1 - density, size=side_length * side_length).reshape(
        (side_length, side_length)
    )
    return detect_holes(a)
@pytest.mark.skip(reason="Not designed to work on the cloud (for local testing only)")
def test_merge_time():
    """Profile ``detect_holes`` and plot how much of its time is spent merging.

    For every (side length, density) pair, ``run_test_case`` is profiled with
    cProfile. The cumulative time of ``merge`` and of ``detect_holes`` is read
    back with pstats. Both, and their ratio, are plotted against density.
    """
    import cProfile
    import pstats

    profile_dir = "out/hole_detection/merge_profile"
    os.makedirs(profile_dir, exist_ok=True)

    side_length = [10, 100, 200, 300, 400]
    density = np.arange(0, 1.1, 0.1)
    merge_times = np.zeros((len(side_length), len(density)))
    total_times = np.zeros((len(side_length), len(density)))
    for r, s in enumerate(side_length):
        for c, d in enumerate(density):
            # round() guards against float error such as 0.7 * 10 -> 6.999...
            f = f"{profile_dir}/prof_{s}_{round(d * 10)}"
            # runctx (unlike cProfile.run, which executes in __main__) sees
            # this module's globals, so it also works when run under pytest.
            cProfile.runctx(f"run_test_case({s}, {d})", globals(), {}, f)
            p = pstats.Stats(f)
            p.strip_dirs()

            # Stats entries: (file, line, func) -> (cc, nc, tt, ct, callers);
            # index 3 is the cumulative time.
            merge_stats = [v for k, v in p.stats.items() if k[2] == "merge"]
            merge_time = merge_stats[0][3] if merge_stats else 0
            total_time = [v for k, v in p.stats.items() if k[2] == "detect_holes"][0][3]

            merge_times[r, c] = merge_time
            total_times[r, c] = total_time

    fig, axs = plt.subplots(nrows=len(side_length), ncols=1)

    # Time vs. Density plots
    for r, s in enumerate(side_length):
        axs[r].plot(density, merge_times[r, :], label="Merging time", color="b")
        axs[r].plot(density, total_times[r, :], label="Total time", color="k")
        axs[r].set_title(f"Runtime vs. Density (Side length={s})")
        axs[r].set_xlabel("Density")
        axs[r].set_ylabel("Runtime (s)")
        axs[r].legend()

        ax2 = axs[r].twinx()
        ax2.plot(
            density,
            merge_times[r, :] / total_times[r, :],
            "r--",
            label="% Time Merging",
        )
        ax2.set_ylabel("% Time Merging")
        ax2.set_ylim([0, 1])
        ax2.legend()

    fig.set_size_inches(6, 3 * len(side_length))
    fig.tight_layout()
    fig.savefig("out/hole_detection/merge_time.png")
    plt.close(fig)
def main():
    """Run every test and benchmark in this module, in order (local use)."""
    for check in (test_hole_size_dict, test_detect_holes, test_runtime, test_merge_time):
        check()
if __name__ == "__main__":
    # Script entry point: run the module's tests and benchmarks.
    main()