Source code for laser_measles.components.base_case_surveillance
"""
Component for tracking case surveillance
"""
from collections.abc import Callable
import matplotlib.pyplot as plt
import numpy as np
import polars as pl
from matplotlib.figure import Figure
from pydantic import BaseModel
from pydantic import Field
from laser_measles.base import BaseLaserModel
from laser_measles.base import BasePhase
from laser_measles.utils import cast_type
class BaseCaseSurveillanceParams(BaseModel):
    """Parameters specific to the case surveillance component.
    Attributes:
        detection_rate: Probability of detecting an infected case.
        filter_fn: Function to filter which nodes to include in aggregation.
        aggregate_cases: Whether to aggregate cases by geographic level.
        aggregation_level: Number of levels to use for aggregation (e.g., 2 for country:state:lga).
    """
    detection_rate: float = Field(default=0.1, description="Probability of detecting an infected case", ge=0.0, le=1.0)
    filter_fn: Callable[[str], bool] = Field(default=lambda x: True, description="Function to filter which nodes to include in aggregation")
    aggregation_level: int = Field(default=-1, description="Number of levels to use for aggregation (e.g., 2 for country:state:lga)")
[docs]
class BaseCaseSurveillanceTracker(BasePhase):
    """Component for tracking detected cases in the model.
    This component:
    1. Simulates case detection based on a detection rate
    2. Optionally tracks detected cases aggregated by geographic level
    3. Uses a filter function to determine which nodes to include
    Case detection is simulated using a binomial distribution. Cases can be tracked
    at individual node level or aggregated by geographic level. Uses a filter function
    to determine which nodes to include. Note that a single infection can be detected multiple times.
    Args:
        model: The simulation model containing nodes, states, and parameters.
        verbose: Whether to print verbose output during simulation. Defaults to False.
        params: Component-specific parameters. If None, will use default parameters.
    """
    def __init__(self, model, verbose: bool = False, params: BaseCaseSurveillanceParams | None = None) -> None:
        super().__init__(model, verbose)
        self.params = params or BaseCaseSurveillanceParams()
        self._validate_params()
        # Extract node IDs and create mapping for filtered nodes
        self.node_mapping = {}
        self.node_indices = []
        for node_idx, node_id in enumerate(model.scenario["id"]):
            if self.params.filter_fn(node_id):
                if self.params.aggregation_level >= 0:
                    # Create geographic grouping key
                    group_key = ":".join(node_id.split(":")[: self.params.aggregation_level + 1])
                    if group_key not in self.node_mapping:
                        self.node_mapping[group_key] = []
                    self.node_mapping[group_key].append(node_idx)
                else:
                    group_key = node_id
                    self.node_mapping[group_key] = [node_idx]
        # Initialize reported cases tracker
        # For aggregated cases: nticks x num_groups
        self.reported_cases = np.zeros((len(self.node_mapping), model.params.num_ticks), dtype=model.patches.states.dtype)
        # Store group IDs in order
        self.group_ids = sorted(self.node_mapping.keys())
    def _validate_params(self) -> None:
        """Validate component parameters.
        Raises:
            ValueError: If aggregation_level is less than 1.
        """
        if self.params.aggregation_level < -1:
            raise ValueError("aggregation_level must be at least -1")
    def __call__(self, model, tick: int) -> None:
        """Process case surveillance for the current tick.
        Args:
            model: The simulation model.
            tick: Current time step.
        """
        # Get current infected cases
        infected = model.patches.states.I  # Infected state is index 1
        # For each group, aggregate detected cases from its nodes
        for group_idx, (_, node_indices) in enumerate(self.node_mapping.items()):
            # Get infected cases for this group's nodes
            group_infected = infected[node_indices]
            if self.params.detection_rate < 1:
                # Simulate case detection using binomial distribution
                detected_cases = cast_type(model.prng.binomial(n=group_infected, p=self.params.detection_rate), model.patches.states.dtype)
            else:
                # Otherwise report infections
                detected_cases = cast_type(group_infected, model.patches.states.dtype)
            # Store total detected cases for this group
            self.reported_cases[group_idx, tick] = detected_cases.sum()
    def get_dataframe(self) -> pl.DataFrame:
        """Get a DataFrame of reported cases over time.
        Returns:
            DataFrame with columns:
                - tick: Time step
                - group_id: Group identifier (if aggregated) or node_id (if not aggregated)
                - cases: Number of reported cases
        """
        # Create a list to store the data
        data = []
        # For each tick and group, add the reported cases
        for tick in range(self.model.params.num_ticks):
            for group_idx, group_id in enumerate(self.group_ids):
                data.append({"tick": tick, "group_id": group_id, "cases": self.reported_cases[group_idx, tick]})
        # Create DataFrame
        return pl.DataFrame(data)
    def initialize(self, model: BaseLaserModel) -> None:
        pass
    def plot(self, fig: Figure | None = None):
        """Create a heatmap visualization of log(cases+1) over time.
        Args:
            fig: Existing figure to plot on. If None, a new figure will be created.
        Yields:
            The figure containing the heatmap visualization.
        """
        # Get the case data
        df = self.get_dataframe()
        # Convert to pandas for easier plotting
        pdf = df.to_pandas()
        # Create pivot table for heatmap
        pivot_df = pdf.pivot(index="group_id", columns="tick", values="cases")
        # Create figure and axis if not provided
        if fig is None:
            fig, ax = plt.subplots(figsize=(12, 8))
        else:
            ax = fig.gca()
        # Create heatmap with log scale
        heatmap_data = np.log1p(pivot_df.values)
        im = ax.imshow(heatmap_data, aspect="auto", cmap="viridis")
        cbar = fig.colorbar(im, ax=ax)
        cbar.set_label("log(cases + 1)")
        # Set axis ticks and labels
        ax.set_xticks(np.arange(pivot_df.shape[1]))
        ax.set_xticklabels(pivot_df.columns)
        ax.set_yticks(np.arange(pivot_df.shape[0]))
        ax.set_yticklabels(pivot_df.index)
        # Customize plot
        ax.set_title("Log Cases Heatmap")
        ax.set_xlabel("Time Step")
        ax.set_ylabel("Location ID")
        # Rotate x-axis labels for better readability
        plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
        # Adjust layout to prevent label cutoff
        plt.tight_layout()
        yield fig