import inspect
from collections.abc import Callable
import matplotlib.pyplot as plt
import numpy as np
import polars as pl
from matplotlib.figure import Figure
from pydantic import BaseModel
from pydantic import Field
from laser_measles.base import BaseLaserModel
from laser_measles.base import BasePhase
from laser_measles.base import StateArray
class BaseStateTrackerParams(BaseModel):
    """Configuration for the state tracker component.

    Attributes:
        filter_fn: Predicate called with each node id; only nodes for which it
            returns True contribute to the tracked counts. Defaults to accepting
            every node.
        aggregation_level: Depth used to group colon-separated node ids: the
            first ``aggregation_level + 1`` components form the group key
            (e.g., 2 for country:state:lga). Use -1 (the default) to sum over
            all patches instead of grouping.
    """

    filter_fn: Callable[[str], bool] = Field(
        description="Function to filter which nodes to include in aggregation",
        default=lambda node_id: True,
    )
    aggregation_level: int = Field(
        description="Number of levels to use for aggregation. Use -1 to sum over all patches",
        default=-1,
    )
class BaseStateTracker(BasePhase):
    """
    Component for tracking the number in each state for each time tick.

    This class maintains a time series of state counts across nodes in the model.
    The states are dynamically generated as properties based on model.params.states
    (e.g., "S", "E", "I", "R"). Each state can be accessed as a property that returns
    a numpy array containing the time series for that state.

    The tracking can be done at different aggregation levels:
    - aggregation_level = -1: Sum over all patches accepted by ``filter_fn``
      (default, backward compatible).
    - aggregation_level >= 0: Group nodes by the first ``aggregation_level + 1``
      colon-separated components of their id and track each group separately.
      Group columns are stored in the order of ``self.group_ids`` (sorted keys).

    Args:
        model: The simulation model containing nodes, states, and parameters.
        verbose: Whether to print verbose output during simulation. Defaults to False.
        params: Component-specific parameters. If None, will use default parameters.
    """

    def __init__(self, model, verbose: bool = False, params: BaseStateTrackerParams | None = None) -> None:
        super().__init__(model, verbose)
        self.name = "StateTracker"
        self.params = params if params is not None else BaseStateTrackerParams()
        self._validate_params()

        # Per-instance state name -> row index in self.state_tracker. The state
        # properties installed below live on the class (shared by all instances),
        # so they must resolve the index through this map rather than capture a
        # position, or a tracker with a different state list would rebind them.
        self._state_indices = {state: i for i, state in enumerate(model.params.states)}

        # Extract node IDs and create the mapping for nodes accepted by filter_fn.
        self.node_mapping: dict[str, list[int]] = {}
        self.node_indices: list[int] = []
        for node_idx, node_id in enumerate(model.scenario["id"]):
            if not self.params.filter_fn(node_id):
                continue
            if self.params.aggregation_level >= 0:
                # Group key = first (aggregation_level + 1) ":"-separated components.
                group_key = ":".join(node_id.split(":")[: self.params.aggregation_level + 1])
                self.node_mapping.setdefault(group_key, []).append(node_idx)
            else:
                self.node_indices.append(node_idx)

        if self.params.aggregation_level >= 0:
            # One column per group, in sorted key order; __call__ and
            # get_dataframe both index groups through this list.
            self.group_ids = sorted(self.node_mapping)
        else:
            # A single column holding the sum over all filtered patches.
            self.group_ids = ["all_patches"]

        # Shape: (num_states, num_ticks, num_groups)
        self.state_tracker = StateArray(
            np.zeros(
                (len(model.params.states), model.params.num_ticks, len(self.group_ids)),
                dtype=model.patches.states.dtype,
            ),
            model.params.states,
        )

        # Dynamically expose each state as a property (e.g. tracker.S).
        for state in model.params.states:
            setattr(type(self), state, property(lambda obj, name=state: obj._state_property(name)))

    def _validate_params(self) -> None:
        """Validate component parameters.

        Raises:
            ValueError: If aggregation_level is less than -1.
        """
        if self.params.aggregation_level < -1:
            raise ValueError("aggregation_level must be at least -1")

    def _state_property(self, name: str) -> np.ndarray:
        """Return the time series for the state called ``name`` on this tracker.

        Raises:
            AttributeError: If this tracker does not track ``name`` (the class-level
                property may have been installed by another tracker instance).
        """
        try:
            state_idx = self._state_indices[name]
        except KeyError:
            raise AttributeError(f"{type(self).__name__} does not track state {name!r}") from None
        return self._get_state_data(state_idx)

    def _get_state_data(self, state_idx: int) -> np.ndarray:
        """Get state data for a specific state index.

        Args:
            state_idx: Index of the state to retrieve.

        Returns:
            Array of shape (num_ticks,) for aggregation_level = -1,
            or (num_ticks, num_groups) for aggregation_level >= 0.
        """
        if self.params.aggregation_level == -1:
            # Return (num_ticks,) for backward compatibility
            return self.state_tracker[state_idx, :, 0]
        # Return (num_ticks, num_groups)
        return self.state_tracker[state_idx, :, :]

    def __call__(self, model, tick: int) -> None:
        """Record the state counts from ``model.patches.states`` for ``tick``.

        ``model.patches.states`` is indexed as (state, patch); counts are summed
        over the patches belonging to each tracked group.
        """
        states = model.patches.states
        if self.params.aggregation_level >= 0:
            # Iterate in group_ids order so column g always belongs to group_ids[g].
            for group_idx, group_id in enumerate(self.group_ids):
                self.state_tracker[:, tick, group_idx] = states[:, self.node_mapping[group_id]].sum(axis=1)
        else:
            # Sum over the patches accepted by filter_fn. If the filter rejected
            # every node this is an empty selection and records zeros, rather
            # than silently counting the excluded patches.
            self.state_tracker[:, tick, 0] = states[:, self.node_indices].sum(axis=1)

    def plot(self, fig: Figure | None = None):
        """
        Plot the time series of each state count in its own subplot.

        Parameters:
            fig (Figure, optional): A matplotlib Figure to draw into. If None, a new
                figure is created.

        Yields:
            None: Yields once after the figure has been drawn. This is a generator
            function, so nothing is drawn until it is iterated.

        Example:
            # Use as a generator (for model.visualize()):
            for _ in tracker.plot():
                plt.show()
        """
        n_states = len(self.model.params.states)
        fig = plt.figure(figsize=(12, 3 * n_states), dpi=128) if fig is None else fig
        fig.suptitle("SEIR State Counts Over Time")

        time = np.arange(self.model.params.num_ticks)
        colors = ["blue", "orange", "red", "green"]  # S, E, I, R
        for i, state in enumerate(self.model.params.states):
            # Draw on the given figure, not whatever pyplot considers "current".
            ax = fig.add_subplot(n_states, 1, i + 1)
            color = colors[i] if i < len(colors) else "black"
            ax.plot(time, self._get_state_data(i), label=f"{state} (Total)", color=color, linewidth=2)
            ax.set_ylabel(f"Number in {state}")
            ax.grid(True, alpha=0.3)
            ax.legend()
            # Format y-axis with scientific notation for large numbers
            ax.ticklabel_format(style="scientific", axis="y", scilimits=(0, 0))
            # Only add xlabel to the bottom subplot
            if i == n_states - 1:
                ax.set_xlabel("Time (days)")

        fig.tight_layout()
        yield

    def plot_combined(self, fig: Figure | None = None):
        """
        Plot all states on a single axes for easy comparison.

        Parameters:
            fig (Figure, optional): A matplotlib Figure to draw into. If None, a new
                figure is created.

        Yields:
            None: Yields once after the figure has been drawn. This is a generator
            function, so nothing is drawn until it is iterated.
        """
        fig = plt.figure(figsize=(12, 6), dpi=128) if fig is None else fig
        # Draw on the given figure, not whatever pyplot considers "current".
        ax = fig.gca()

        time = np.arange(self.model.params.num_ticks)
        colors = ["blue", "orange", "red", "green"]  # S, E, I, R
        linestyles = ["-", "--", "-.", ":"]
        for i, state in enumerate(self.model.params.states):
            color = colors[i] if i < len(colors) else "black"
            linestyle = linestyles[i] if i < len(linestyles) else "-"
            ax.plot(time, self._get_state_data(i), label=f"{state}", color=color, linestyle=linestyle, linewidth=2)

        ax.set_xlabel("Time (days)")
        ax.set_ylabel("Number of Individuals")
        ax.set_title("SEIR Model Dynamics")
        ax.legend()
        ax.grid(True, alpha=0.3)
        ax.ticklabel_format(style="scientific", axis="y", scilimits=(0, 0))
        fig.tight_layout()
        yield

    def get_dataframe(self) -> pl.DataFrame:
        """Get a long-format DataFrame of state counts over time.

        Rows are ordered by tick, then state, then group.

        Returns:
            DataFrame with columns:
            - tick: Time step
            - state: State name (S, E, I, R, etc.)
            - group_id: Group identifier (if aggregated) or "all_patches" (if summed)
            - count: Number of individuals in this state (dtype of the tracker array)
        """
        state_names = list(self.model.params.states)
        num_ticks = self.model.params.num_ticks
        num_states = len(state_names)
        num_groups = len(self.group_ids)

        # Reorder (state, tick, group) -> (tick, state, group) so a C-order flatten
        # matches the tick -> state -> group row order of the columns built below.
        # NOTE(review): assumes StateArray is array-like (ndarray subclass); confirm.
        counts = np.asarray(self.state_tracker).transpose(1, 0, 2).reshape(-1)

        return pl.DataFrame(
            {
                "tick": np.repeat(np.arange(num_ticks, dtype=np.int64), num_states * num_groups),
                "state": [name for name in state_names for _ in range(num_groups)] * num_ticks,
                "group_id": list(self.group_ids) * (num_ticks * num_states),
                "count": counts,
            }
        )

    def _initialize(self, model: BaseLaserModel) -> None:
        """No-op: all setup for this component happens in ``__init__``."""
        pass