Source code for laser.measles.abm.model

"""
A class to represent the agent-based model.
"""

import numpy as np
import polars as pl
from matplotlib import pyplot as plt
from matplotlib.figure import Figure

from laser.measles.abm.base import BaseABMScenario
from laser.measles.abm.base import PatchLaserFrame
from laser.measles.abm.base import PeopleLaserFrame
from laser.measles.base import BaseLaserModel
from laser.measles.base import BaseScenario
from laser.measles.utils import StateArray

from . import components
from .params import ABMParams


class ABMModel(BaseLaserModel):
    """
    Agent-based model for measles transmission with daily timesteps (SEIR).

    **Both** ``scenario`` and ``params`` are required positional arguments.
    There is no default constructor — omitting ``params`` raises ``TypeError``.

    Args:
        scenario (pl.DataFrame): A DataFrame containing the metapopulation patch data.
            Required columns: ``id`` (str), ``pop`` (int), ``lat`` (Float64),
            ``lon`` (Float64), ``mcv1`` (Float64).
        params (ABMParams): Simulation parameters including ``num_ticks``, ``seed``,
            and ``start_time``. This argument is **mandatory**.
        name (str, optional): The name of the model. Defaults to ``"abm"``.

    Notes:
        Typical usage::

            from laser.measles.abm import ABMModel, ABMParams
            from laser.measles.abm import components

            params = ABMParams(num_ticks=365, seed=42)
            model = ABMModel(scenario=df, params=params)
            model.add_component(components.InfectionSeedingProcess)
            model.add_component(components.InfectionProcess)
            model.run()
    """

    people: PeopleLaserFrame

    # Specify the scenario wrapper class for auto-wrapping DataFrames
    scenario_wrapper_class = BaseABMScenario

    def __init__(self, scenario: BaseABMScenario | pl.DataFrame, params: ABMParams, name: str = "abm") -> None:
        """
        Initialize the disease model with the given scenario and parameters.

        Args:
            scenario (BaseABMScenario | pl.DataFrame): The metapopulation patch data,
                including population, latitude, and longitude.
            params (ABMParams): A set of parameters for the model and simulations.
            name (str, optional): The name of the model. Defaults to "abm".

        Returns:
            None
        """
        super().__init__(scenario, params, name)

        if self.params.verbose:
            print(f"Initializing the {name} model with {len(scenario)} patches…")

        # Setup patches
        self.setup_patches()
        # Setup people - initialization is done via components
        self.setup_people()

    def __call__(self, model, tick: int) -> None:
        """No-op: the model object itself performs no work when called with a tick."""
        pass

    def setup_patches(self) -> None:
        """Setup the patches for the model.

        Creates ``self.patches`` with one slot per scenario row and a ``states``
        array holding one row per entry of ``self.params.states`` and one column
        per patch. Every patch starts fully susceptible.
        """
        scenario: BaseScenario = self.scenario
        self.patches = PatchLaserFrame(capacity=len(scenario))

        # Create the state vector for each of the patches: (num_states, num_patches)
        self.patches.add_array_property("states", shape=(len(self.params.states), len(scenario)))  # S, E, I, R

        # Wrap the states array with StateArray for attribute access
        self.patches.states = StateArray(self.patches.states, state_names=self.params.states)

        # Start with totally susceptible population
        self.patches.states.S[:] = scenario["pop"]  # All susceptible initially
        self.patches.states.E[:] = 0  # No exposed initially
        self.patches.states.I[:] = 0  # No infected initially
        self.patches.states.R[:] = 0  # No recovered initially

    def setup_people(self) -> None:
        """Placeholder for people - sets the data types for patch_id, state and susceptibility.

        The frame is created with a capacity of 1; it only declares the
        properties so that later resizing can copy them.
        """
        self.people = PeopleLaserFrame(capacity=1)
        self.people.add_scalar_property("patch_id", dtype=np.uint16)  # patch id
        self.people.add_scalar_property("state", dtype=np.uint8, default=0)  # state
        self.people.add_scalar_property("susceptibility", dtype=np.float32, default=0)  # susceptibility factor

    def initialize_people_capacity(self, capacity: int, initial_count: int = -1) -> None:
        """
        Initialize the people LaserFrame with a new capacity while preserving all properties.

        This method uses the ``create_with_capacity`` factory of the current people
        frame's own type to build a replacement instance with the specified capacity,
        passing the existing instance so its properties can be carried over.

        Args:
            capacity: The new capacity for the people LaserFrame.
            initial_count: Forwarded unchanged to ``create_with_capacity``.
                Defaults to ``-1``. NOTE(review): the meaning of ``-1`` is defined
                by the factory, not here — confirm against its documentation.

        Raises:
            RuntimeError: If ``self.people`` is ``None``.
        """
        if self.people is None:
            raise RuntimeError("Cannot initialize capacity: people LaserFrame is None")

        # Use the factory method to create a new instance with the same type and properties
        new_people = type(self.people).create_with_capacity(capacity, self.people, initial_count=initial_count)

        # Update the people laserframe
        self.people = new_people

    def infect(self, indices: int | np.ndarray, num_infected: int | np.ndarray) -> None:
        """
        Infect agents by moving them from Susceptible to Exposed state.

        This method finds the single component exposing an ``infect`` method and
        delegates to it, passing the model and the agent indices.

        Args:
            indices (int | np.ndarray): The indices of the agents to infect. A single
                Python or NumPy integer is wrapped into a one-element array.
            num_infected (int | np.ndarray): The number of agents to infect (for API
                consistency). A scalar must equal ``len(indices)``; an array must
                sum to ``len(indices)``.

        Raises:
            ValueError: If ``num_infected`` is inconsistent with ``indices``.
            RuntimeError: If no component, or more than one component, has an
                ``infect`` method.
        """
        # NumPy integer scalars (e.g. an element taken from an index array) are not
        # instances of ``int``, so both scalar kinds are checked explicitly.
        if isinstance(indices, (int, np.integer)):
            indices = np.array([indices])

        if isinstance(num_infected, (int, np.integer)):
            # Scalar count: must equal the number of indices
            if len(indices) != num_infected:
                raise ValueError(f"Number of indices ({len(indices)}) must match num_infected ({num_infected})")
        elif isinstance(num_infected, np.ndarray):
            # Array of counts: the total must equal the number of indices
            total = num_infected.sum()
            if len(indices) != total:
                raise ValueError(f"Length of indices ({len(indices)}) must match sum of num_infected ({total})")

        # Find the component with infect method; exactly one is allowed
        candidates = [instance for instance in self.instances if hasattr(instance, "infect")]
        if len(candidates) > 1:
            raise RuntimeError("Multiple components found with an infect method")
        if not candidates:
            raise RuntimeError("No component found with an infect method")

        # Delegate to the transmission component
        candidates[0].infect(self, indices)

    def plot(self, fig: Figure | None = None):
        """
        Plot a pie chart of the time spent in each update phase.

        Parameters:
            fig (Figure, optional): A matplotlib Figure object to use for plotting.
                If None, a new figure will be created.

        Yields:
            None: This function is a generator that yields control back to the
            caller once, after the pie chart has been drawn.
        """
        _fig = plt.figure(figsize=(12, 9), dpi=128) if fig is None else fig

        column_names = ["tick"] + [type(phase).__name__ for phase in self.phases]
        metrics = pl.DataFrame(self.metrics, schema=column_names)
        sum_columns = metrics.select([pl.sum(col).alias(col) for col in metrics.columns[1:]]).to_dict(as_series=False)

        # Build labels (strip "do_" if present)
        labels = [name.removeprefix("do_") for name in sum_columns]
        # ``to_dict(as_series=False)`` maps each column to a one-element list;
        # unwrap them so the pie receives a flat, 1-D sequence of totals.
        values = [column[0] for column in sum_columns.values()]

        # Draw on the selected figure rather than on pyplot's current figure,
        # so a caller-supplied ``fig`` is actually used.
        ax = _fig.gca()
        ax.pie(
            values,
            labels=labels,
            autopct="%1.1f%%",
            startangle=140,
        )
        ax.set_title("Update Phase Times")

        yield

    def _setup_components(self) -> None:
        """No-op hook: this model performs no extra component setup."""
        pass

    def _initialize(self) -> None:
        """
        Ensure a people frame exists before the base-class initialization runs.

        When the number of people differs from the summed patch states, the
        ``NoBirthsProcess`` component is prepended to the model before
        delegating to ``BaseLaserModel._initialize``.
        """
        # This will re-run all instantiation
        if len(self.people) != self.patches.states.sum():
            if self.params.verbose:
                print("No vital dynamics provided. Creating a new people laserframe with the same properties as the patches.")
            self.prepend_component(components.NoBirthsProcess)

        super()._initialize()
# Alias for backwards compatibility Model = ABMModel