# Source code for laser_measles.components.base_case_surveillance

"""
Component for tracking case surveillance
"""

from collections.abc import Callable

import matplotlib.pyplot as plt
import numpy as np
import polars as pl
from matplotlib.figure import Figure
from pydantic import BaseModel
from pydantic import Field

from laser_measles.base import BaseLaserModel
from laser_measles.base import BasePhase
from laser_measles.utils import cast_type


class BaseCaseSurveillanceParams(BaseModel):
    """Parameters specific to the case surveillance component.

    Attributes:
        detection_rate: Probability in [0, 1] used to thin infected counts into
            detected cases at each tick. At 1.0, every infection is reported.
        filter_fn: Predicate on a node id. Only nodes for which it returns True
            are tracked. The default accepts every node.
        aggregation_level: Zero-based depth of the ':'-separated node id at which
            cases are aggregated. The group key is the first
            ``aggregation_level + 1`` components of the id. For example, ``2``
            groups ``"country:state:lga:..."`` ids by ``"country:state:lga"``.
            The default ``-1`` disables aggregation, so every tracked node is its
            own group.
    """

    detection_rate: float = Field(default=0.1, description="Probability of detecting an infected case", ge=0.0, le=1.0)
    filter_fn: Callable[[str], bool] = Field(default=lambda x: True, description="Function to filter which nodes to include in aggregation")
    # Zero-based depth into the ':'-separated node id; -1 means "do not aggregate".
    aggregation_level: int = Field(
        default=-1,
        description="Zero-based depth of the ':'-separated node id to aggregate at (e.g., 2 for country:state:lga); -1 disables aggregation",
    )


class BaseCaseSurveillanceTracker(BasePhase):
    """Component for tracking detected cases in the model.

    This component:

    1. Simulates case detection from the current infected counts using a
       detection rate (binomial thinning when the rate is below 1).
    2. Tracks detected cases per node, or aggregated by a prefix of the
       ':'-separated node id when ``aggregation_level >= 0``.
    3. Uses ``filter_fn`` to decide which nodes are included at all.

    Detection is re-drawn every tick, so a single infection can be detected
    multiple times.

    Results live in ``reported_cases``, an array of shape
    ``(len(group_ids), num_ticks)``. Row ``i`` holds the counts for
    ``group_ids[i]`` (group ids are sorted).

    Args:
        model: The simulation model containing nodes, states, and parameters.
        verbose: Whether to print verbose output during simulation. Defaults to False.
        params: Component-specific parameters. If None, default parameters are used.
    """

    def __init__(self, model, verbose: bool = False, params: BaseCaseSurveillanceParams | None = None) -> None:
        super().__init__(model, verbose)
        self.params = params if params is not None else BaseCaseSurveillanceParams()
        self._validate_params()

        # Group key -> indices of the filtered nodes belonging to that group.
        self.node_mapping: dict[str, list[int]] = {}
        # Indices of every node that passed ``filter_fn``, in scenario order.
        self.node_indices: list[int] = []
        level = self.params.aggregation_level
        for node_idx, node_id in enumerate(model.scenario["id"]):
            if not self.params.filter_fn(node_id):
                continue
            self.node_indices.append(node_idx)
            if level >= 0:
                # Group key = first ``level + 1`` ':'-separated components of the id.
                group_key = ":".join(node_id.split(":")[: level + 1])
                self.node_mapping.setdefault(group_key, []).append(node_idx)
            else:
                # No aggregation: each node is its own group.
                self.node_mapping[node_id] = [node_idx]

        # Group ids in sorted order. Row i of ``reported_cases`` belongs to group_ids[i].
        self.group_ids = sorted(self.node_mapping)
        # Reported cases tracker, shape (num_groups, num_ticks).
        self.reported_cases = np.zeros((len(self.group_ids), model.params.num_ticks), dtype=model.patches.states.dtype)

    def _validate_params(self) -> None:
        """Validate component parameters.

        Raises:
            ValueError: If aggregation_level is less than -1.
        """
        if self.params.aggregation_level < -1:
            raise ValueError("aggregation_level must be at least -1")

    def __call__(self, model, tick: int) -> None:
        """Record detected cases for every group at ``tick``.

        For each group, the infected counts of its nodes are thinned with
        ``model.prng.binomial(n=..., p=detection_rate)`` when
        ``detection_rate < 1``. Otherwise all infections are reported. The
        group total is stored in ``reported_cases[group_idx, tick]``.

        Args:
            model: The simulation model.
            tick: Current time step; used as the column index into ``reported_cases``.
        """
        # Infected counts, indexed below by node index.
        infected = model.patches.states.I
        dtype = model.patches.states.dtype
        rate = self.params.detection_rate
        # Iterate in ``group_ids`` order so that row ``group_idx`` matches the label
        # that get_dataframe()/plot() attach to it. Iterating ``node_mapping`` would
        # follow insertion order, which differs from the sorted ``group_ids`` whenever
        # scenario ids are not already sorted.
        for group_idx, group_id in enumerate(self.group_ids):
            group_infected = infected[self.node_mapping[group_id]]
            if rate < 1:
                # Simulate case detection using a binomial draw per node.
                detected_cases = cast_type(model.prng.binomial(n=group_infected, p=rate), dtype)
            else:
                # Perfect detection: report every infection.
                detected_cases = cast_type(group_infected, dtype)
            self.reported_cases[group_idx, tick] = detected_cases.sum()

    def get_dataframe(self) -> pl.DataFrame:
        """Get a DataFrame of reported cases over time.

        Rows are ordered tick-major: all groups for tick 0, then tick 1, and so on.

        Returns:
            DataFrame with columns:
                - tick: Time step
                - group_id: Group identifier (if aggregated) or node_id (if not aggregated)
                - cases: Number of reported cases
        """
        data = [
            {"tick": tick, "group_id": group_id, "cases": self.reported_cases[group_idx, tick]}
            for tick in range(self.model.params.num_ticks)
            for group_idx, group_id in enumerate(self.group_ids)
        ]
        return pl.DataFrame(data)

    def initialize(self, model: BaseLaserModel) -> None:
        """No-op: all setup happens in ``__init__``."""
        pass

    def plot(self, fig: Figure | None = None):
        """Create a heatmap of ``log(cases + 1)`` per group over time.

        This is a generator that yields the finished figure once.

        Args:
            fig: Existing figure to plot on; its current axes are used. If None,
                a new 12x8-inch figure is created.

        Yields:
            The figure containing the heatmap visualization.
        """
        df = self.get_dataframe()
        # Pivot via pandas: rows are group ids, columns are ticks.
        pivot_df = df.to_pandas().pivot(index="group_id", columns="tick", values="cases")

        if fig is None:
            fig, ax = plt.subplots(figsize=(12, 8))
        else:
            ax = fig.gca()

        # log1p keeps zero-case cells finite.
        heatmap_data = np.log1p(pivot_df.values)
        im = ax.imshow(heatmap_data, aspect="auto", cmap="viridis")
        cbar = fig.colorbar(im, ax=ax)
        cbar.set_label("log(cases + 1)")

        ax.set_xticks(np.arange(pivot_df.shape[1]))
        ax.set_xticklabels(pivot_df.columns)
        ax.set_yticks(np.arange(pivot_df.shape[0]))
        ax.set_yticklabels(pivot_df.index)

        ax.set_title("Log Cases Heatmap")
        ax.set_xlabel("Time Step")
        ax.set_ylabel("Location ID")

        # Rotate x-axis labels for better readability.
        plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

        # Lay out *this* figure. plt.tight_layout() acts on pyplot's current figure,
        # which need not be ``fig`` when the caller passes one in.
        fig.tight_layout()
        yield fig