# Source code for laser.measles.abm.components.process_initialize_equilibrium_states

"""
Component for initializing the population in each of the model states by rough equilibrium of R0.
"""

import numpy as np
import polars as pl

from laser.measles.abm.base import PatchLaserFrame
from laser.measles.abm.base import PeopleLaserFrame
from laser.measles.abm.model import ABMModel
from laser.measles.components import BaseInitializeEquilibriumStatesParams
from laser.measles.components import BaseInitializeEquilibriumStatesProcess


class InitializeEquilibriumStatesParams(BaseInitializeEquilibriumStatesParams):
    """
    Configuration for ``InitializeEquilibriumStatesProcess``.

    Nothing is added or overridden here; every field comes from
    ``BaseInitializeEquilibriumStatesParams``.
    """


class InitializeEquilibriumStatesProcess(BaseInitializeEquilibriumStatesProcess):
    """
    Seed the S and R states of an ABM at a rough R0-based equilibrium.

    The parent process is run first; this subclass then assigns every agent a
    patch, a state and a susceptibility so that the agent-level data agrees
    with the integer S/R counts written to the patches.
    """

    def _initialize(self, model: ABMModel) -> None:
        """
        Run the parent initialization, then align the agents with the patches.

        Steps performed:
        1. Delegate to the parent class ``_initialize`` (presumably the
           patch-level equilibrium calculation -- the base class is not
           visible from this file).
        2. Fill in per-agent ``patch_id``, ``state`` and ``susceptibility``
           through ``_initialize_people_from_patches``.
        """
        super()._initialize(model)
        self._initialize_people_from_patches(model)

    def _initialize_people_from_patches(self, model: ABMModel) -> None:
        """
        Give each agent a patch and an S or R state matching per-patch counts.

        Agents are laid out contiguously in scenario row order: the first
        ``pop`` agents belong to row 0, the next ``pop`` to row 1, and so on.
        For every patch the recovered count is ``pop * (1 - 1 / R0)`` rounded
        to an integer and clamped to ``[0, pop]``; that many agents are drawn
        without replacement from the patch and marked ``R`` with zero
        susceptibility. The patch ``S`` and ``R`` entries are overwritten with
        the same integer counts so both levels stay consistent.
        """
        agents: PeopleLaserFrame = model.people
        patch_frame: PatchLaserFrame = model.patches
        frame = model.scenario.unwrap()
        n_agents = len(model.people)

        # Build one patch index per agent: each scenario row index is repeated
        # ``pop`` times and the resulting lists are flattened.
        indexed = frame.with_row_index()
        repeated = indexed.select(pl.col("index").repeat_by(pl.col("pop")))
        per_agent_patch = repeated.explode("index")["index"].to_numpy()
        agents.patch_id[:n_agents] = np.array(per_agent_patch, dtype=agents.patch_id.dtype)

        # Everyone starts fully susceptible; recovered agents are carved out below.
        susceptible_code = model.params.states.index("S")
        agents.state[:n_agents] = susceptible_code
        agents.susceptibility[:n_agents] = 1.0

        # ``offset`` is the first agent index belonging to the current patch.
        offset = 0
        for patch_no, size in enumerate(frame["pop"].to_list()):
            # Equilibrium recovered head-count, rounded to a whole number of agents.
            expected_recovered = size * (1 - 1 / self.params.R0)
            n_recovered = int(np.round(expected_recovered))

            # Keep the count inside [0, size]; R0 < 1 would otherwise go negative.
            n_recovered = max(0, min(n_recovered, size))
            n_susceptible = size - n_recovered

            # Patch-level states are overwritten with the integer counts.
            patch_frame.states.S[patch_no] = n_susceptible
            patch_frame.states.R[patch_no] = n_recovered

            members = np.arange(offset, offset + size)
            if n_recovered > 0:
                chosen = model.prng.choice(members, size=n_recovered, replace=False)
                agents.state[chosen] = model.params.states.index("R")
                agents.susceptibility[chosen] = 0.0

            # Advance regardless of whether any agent was marked recovered.
            offset += size

        if not model.params.verbose:
            return

        total_s = np.sum(agents.state == susceptible_code)
        total_r = np.sum(agents.state == model.params.states.index("R"))
        print(f"Initialized {total_s} susceptible and {total_r} recovered agents")