Source code for laser_measles.components.base_tracker_state

import inspect
from collections.abc import Callable

import matplotlib.pyplot as plt
import numpy as np
import polars as pl
from matplotlib.figure import Figure
from pydantic import BaseModel
from pydantic import Field

from laser_measles.base import BaseLaserModel
from laser_measles.base import BasePhase
from laser_measles.base import StateArray


class BaseStateTrackerParams(BaseModel):
    """Configuration options for the state tracker component.

    Attributes:
        filter_fn: Predicate called with each node ID string; only nodes for
            which it returns True contribute to the tracked counts. The default
            accepts every node.
        aggregation_level: Zero-based depth of the colon-separated node-ID
            hierarchy at which nodes are grouped. The tracker keeps the first
            ``aggregation_level + 1`` components of each ID. For example, with
            IDs like ``country:state:lga``, 0 groups by country and 2 groups by
            ``country:state:lga``. Use -1 (the default) to sum over all patches.
    """

    filter_fn: Callable[[str], bool] = Field(
        default=lambda _node_id: True,
        description="Function to filter which nodes to include in aggregation",
    )
    aggregation_level: int = Field(
        default=-1,
        description="Number of levels to use for aggregation. Use -1 to sum over all patches",
    )


class BaseStateTracker(BasePhase):
    """Component for tracking the number in each SEIR state for each time tick.

    This class maintains a time series of state counts across nodes in the model.
    One property is created per entry of ``model.params.states`` (e.g., "S", "E",
    "I", "R"). Each property returns a numpy array with the time series for that
    state.

    The tracking can be done at different aggregation levels:

    - ``aggregation_level = -1``: sum over all patches selected by ``filter_fn``
      (default).
    - ``aggregation_level >= 0``: group patches by the first
      ``aggregation_level + 1`` colon-separated components of their node ID and
      track each group separately. The group axis of ``state_tracker`` follows
      the sorted order in ``group_ids``.

    Args:
        model: The simulation model containing nodes, states, and parameters.
        verbose: Whether to print verbose output during simulation. Defaults to False.
        params: Component-specific parameters. If None, default parameters are used.
    """

    def __init__(self, model, verbose: bool = False, params: BaseStateTrackerParams | None = None) -> None:
        super().__init__(model, verbose)
        self.name = "StateTracker"
        self.params = params or BaseStateTrackerParams()
        self._validate_params()

        level = self.params.aggregation_level

        # Collect the indices of patches accepted by filter_fn. They are grouped
        # by ID prefix when level >= 0, otherwise kept as one flat list.
        node_mapping: dict[str, list[int]] = {}
        self.node_indices: list[int] = []
        for node_idx, node_id in enumerate(model.scenario["id"]):
            if not self.params.filter_fn(node_id):
                continue
            if level >= 0:
                # Keep the first `level + 1` colon-separated components of the ID.
                group_key = ":".join(node_id.split(":")[: level + 1])
                node_mapping.setdefault(group_key, []).append(node_idx)
            else:
                self.node_indices.append(node_idx)

        if level >= 0:
            self.group_ids = sorted(node_mapping)
            # Store the mapping in group_ids order so that iterating it agrees
            # with the group axis of state_tracker.
            self.node_mapping = {group_id: node_mapping[group_id] for group_id in self.group_ids}
        else:
            self.node_mapping = {}
            self.group_ids = ["all_patches"]

        # Shape: (num_states, num_ticks, num_groups); num_groups is 1 when
        # summing over all patches.
        self.state_tracker = StateArray(
            np.zeros(
                (len(model.params.states), model.params.num_ticks, len(self.group_ids)),
                dtype=model.patches.states.dtype,
            ),
            model.params.states,
        )

        # Expose one read-only property per state name (e.g. `tracker.S`).
        # Properties only take effect on the class. The `idx=i` default binds
        # the loop value at definition time.
        for i, state in enumerate(model.params.states):
            setattr(self.__class__, state, property(lambda self, idx=i: self._get_state_data(idx)))

    def _validate_params(self) -> None:
        """Validate component parameters.

        Raises:
            ValueError: If aggregation_level is less than -1.
        """
        if self.params.aggregation_level < -1:
            raise ValueError("aggregation_level must be at least -1")

    def _get_state_data(self, state_idx: int) -> np.ndarray:
        """Get state data for a specific state index.

        Args:
            state_idx: Index of the state to retrieve.

        Returns:
            Array of shape (num_ticks,) for aggregation_level = -1, or
            (num_ticks, num_groups) for aggregation_level >= 0.
        """
        if self.params.aggregation_level == -1:
            # Return (num_ticks,) for backward compatibility
            return self.state_tracker[state_idx, :, 0]
        # Return (num_ticks, num_groups)
        return self.state_tracker[state_idx, :, :]

    def _get_total_state_data(self, state_idx: int) -> np.ndarray:
        """Return the (num_ticks,) series for one state, summed over all tracked groups."""
        if self.params.aggregation_level == -1:
            return self._get_state_data(state_idx)
        return self._get_state_data(state_idx).sum(axis=1)

    def __call__(self, model, tick: int) -> None:
        """Record the current per-state counts at index `tick` of the tracker.

        Args:
            model: The model whose `patches.states` (indexed as [state, patch])
                is aggregated.
            tick: Time index to write.
        """
        states = model.patches.states
        if self.params.aggregation_level >= 0:
            # Iterate in group_ids order so that column `group_idx` always holds
            # the data for group_ids[group_idx].
            for group_idx, group_id in enumerate(self.group_ids):
                selection = np.asarray(self.node_mapping[group_id], dtype=np.intp)
                self.state_tracker[:, tick, group_idx] = states[:, selection].sum(axis=1)
        else:
            # Sum over the patches accepted by filter_fn. If the filter matched
            # nothing, the selection is empty and the recorded counts are zero,
            # rather than silently falling back to the whole population.
            selection = np.asarray(self.node_indices, dtype=np.intp)
            self.state_tracker[:, tick, 0] = states[:, selection].sum(axis=1)

    def plot(self, fig: Figure | None = None):
        """Plot the time series of each state in its own subplot.

        One subplot is stacked vertically per entry of `model.params.states`.
        When grouping is enabled (aggregation_level >= 0), each plotted series
        is the sum over all tracked groups.

        Parameters:
            fig (Figure, optional): A matplotlib Figure to draw into. If None, a
                new figure is created.

        Yields:
            None: Yields once after drawing so the caller (e.g. model.visualize())
            can show or save the figure. This is a generator function: nothing is
            drawn until it is iterated.

        Example:
            for _ in tracker.plot():
                plt.show()
        """
        n_states = len(self.model.params.states)
        fig = plt.figure(figsize=(12, 3 * n_states), dpi=128) if fig is None else fig
        fig.suptitle("SEIR State Counts Over Time")

        time = np.arange(self.model.params.num_ticks)
        colors = ["blue", "orange", "red", "green"]  # S, E, I, R

        for i, state in enumerate(self.model.params.states):
            # Draw on the requested figure, not whatever pyplot considers current.
            ax = fig.add_subplot(n_states, 1, i + 1)
            color = colors[i] if i < len(colors) else "black"
            ax.plot(time, self._get_total_state_data(i), label=f"{state} (Total)", color=color, linewidth=2)
            ax.set_ylabel(f"Number in {state}")
            ax.grid(True, alpha=0.3)
            ax.legend()

            # Format y-axis with scientific notation for large numbers
            ax.ticklabel_format(style="scientific", axis="y", scilimits=(0, 0))

            # Only add xlabel to the bottom subplot
            if i == n_states - 1:
                ax.set_xlabel("Time (days)")

        fig.tight_layout()
        yield

    def plot_combined(self, fig: Figure | None = None):
        """Plot all states on a single axes for easy comparison.

        When grouping is enabled (aggregation_level >= 0), each plotted series
        is the sum over all tracked groups.

        Parameters:
            fig (Figure, optional): A matplotlib Figure to draw into. If None, a
                new figure is created.

        Yields:
            None: Yields once after drawing; nothing is drawn until iterated.
        """
        fig = plt.figure(figsize=(12, 6), dpi=128) if fig is None else fig
        # Use the figure's own axes rather than pyplot's current figure.
        ax = fig.gca()

        time = np.arange(self.model.params.num_ticks)
        colors = ["blue", "orange", "red", "green"]  # S, E, I, R
        linestyles = ["-", "--", "-.", ":"]

        for i, state in enumerate(self.model.params.states):
            color = colors[i] if i < len(colors) else "black"
            linestyle = linestyles[i] if i < len(linestyles) else "-"
            ax.plot(time, self._get_total_state_data(i), label=f"{state}", color=color, linestyle=linestyle, linewidth=2)

        ax.set_xlabel("Time (days)")
        ax.set_ylabel("Number of Individuals")
        ax.set_title("SEIR Model Dynamics")
        ax.legend()
        ax.grid(True, alpha=0.3)
        ax.ticklabel_format(style="scientific", axis="y", scilimits=(0, 0))

        fig.tight_layout()
        yield

    def get_dataframe(self) -> pl.DataFrame:
        """Get a long-format DataFrame of state counts over time.

        Rows are ordered by tick, then state, then group.

        Returns:
            DataFrame with columns:
                - tick: Time step
                - state: State name (S, E, I, R, etc.)
                - group_id: Group identifier (if aggregated) or "all_patches" (if summed)
                - count: Number of individuals in this state
            The columns are present even when there are no rows.
        """
        states = list(self.model.params.states)
        counts = np.asarray(self.state_tracker)  # (num_states, num_ticks, num_groups)
        num_states, num_ticks, num_groups = counts.shape

        return pl.DataFrame(
            {
                "tick": np.repeat(np.arange(num_ticks, dtype=np.int64), num_states * num_groups),
                "state": [state for state in states for _ in range(num_groups)] * num_ticks,
                "group_id": list(self.group_ids) * (num_ticks * num_states),
                # Reorder to (tick, state, group) so the flat order matches the rows above.
                "count": counts.transpose(1, 0, 2).ravel(),
            }
        )

    def _initialize(self, model: BaseLaserModel) -> None:
        pass