Source code for laser_measles.abm.components.process_initialize_equilibrium_states

"""
Component for initializing the population in each of the model states by rough equilibrium of R0.
"""

import numpy as np
import polars as pl

from laser_measles.abm.base import PatchLaserFrame
from laser_measles.abm.base import PeopleLaserFrame
from laser_measles.abm.model import ABMModel
from laser_measles.components import BaseInitializeEquilibriumStatesParams
from laser_measles.components import BaseInitializeEquilibriumStatesProcess


class InitializeEquilibriumStatesParams(BaseInitializeEquilibriumStatesParams):
    """
    Configuration for the ABM ``InitializeEquilibriumStatesProcess``.

    This class declares no fields of its own; every parameter is inherited
    unchanged from ``BaseInitializeEquilibriumStatesParams``.
    """


class InitializeEquilibriumStatesProcess(BaseInitializeEquilibriumStatesProcess):
    """
    Initialize the S and R states of the population by a rough R0 equilibrium.

    In each patch, a fraction ``1 - 1/R0`` of the population is placed in R and
    the remainder in S. This component extends the base functionality so that
    both the patch-level state counts and the individual agents are
    initialized, with agent states consistent with the (integer) patch counts.
    """

    def _initialize(self, model: ABMModel) -> None:
        """
        Initialize the population in each of the model states by rough equilibrium of R0.

        For ABM models, this involves:
        1. Calculating equilibrium patch-level state counts (base class).
        2. Initializing individual agents with states consistent with patch counts.
        3. Assigning ``patch_id`` and ``susceptibility`` values appropriately.

        Args:
            model: The ABM model whose ``patches`` and ``people`` are initialized.
        """
        # Patch-level equilibrium comes from the shared base implementation.
        super()._initialize(model)
        # Then make the agents (and the integer patch counts) agree with it.
        self._initialize_people_from_patches(model)

    def _initialize_people_from_patches(self, model: ABMModel) -> None:
        """
        Initialize individual agents to match the patch-level state counts.

        Agents are laid out contiguously by patch, in scenario row order. The
        first ``pop[0]`` active agents get ``patch_id`` 0, the next ``pop[1]``
        get ``patch_id`` 1, and so on.

        Within each patch, ``round(pop * (1 - 1/R0))`` agents (clamped to
        ``[0, pop]``) are drawn at random without replacement from
        ``model.prng``. They are set to state R with susceptibility 0.0. Every
        other active agent is state S with susceptibility 1.0. The patch
        ``S``/``R`` counts are overwritten with these integer totals.

        Args:
            model: The ABM model. Reads ``model.scenario``, ``model.params``
                and ``model.prng``; writes ``model.people`` and
                ``model.patches``.

        Raises:
            ValueError: If the scenario's total population does not equal the
                number of active agents.
        """
        scenario_df = model.scenario.unwrap()
        people: PeopleLaserFrame = model.people
        patches: PatchLaserFrame = model.patches
        num_active = len(people)

        patch_pops = scenario_df["pop"].to_numpy()

        # np.repeat gives zero-population patches zero agents. The previous
        # polars repeat_by/explode approach turned an empty list into a null
        # row, which misaligned every later patch_id.
        patch_ids = np.repeat(np.arange(len(patch_pops)), patch_pops)
        if patch_ids.size != num_active:
            raise ValueError(
                f"Scenario population total ({patch_ids.size}) does not match "
                f"the number of active agents ({num_active})"
            )
        people.patch_id[:num_active] = patch_ids.astype(people.patch_id.dtype)

        s_state = model.params.states.index("S")
        r_state = model.params.states.index("R")

        # Start every active agent susceptible; R agents are carved out per patch below.
        people.state[:num_active] = s_state
        people.susceptibility[:num_active] = 1.0

        # NOTE(review): only the S and R patch counts are rewritten here. Any
        # other state counts set by the base class are left untouched; confirm
        # they are zero at this point.
        current_index = 0
        for patch_idx, raw_pop in enumerate(patch_pops):
            patch_pop = int(raw_pop)

            # Equilibrium R count (not a fraction), rounded to an integer. It is
            # clamped to [0, patch_pop] because 0 < R0 < 1 makes it negative.
            expected_r_count = patch_pop * (1 - 1 / self.params.R0)
            patch_r_count = max(0, min(int(np.round(expected_r_count)), patch_pop))
            patch_s_count = patch_pop - patch_r_count

            # Keep the patch counts consistent with the integer agent assignment.
            patches.states.S[patch_idx] = patch_s_count
            patches.states.R[patch_idx] = patch_r_count

            if patch_r_count > 0:
                # This patch's agents occupy a contiguous index range (see patch_id above).
                patch_agents = np.arange(current_index, current_index + patch_pop)
                r_agents = model.prng.choice(patch_agents, size=patch_r_count, replace=False)
                people.state[r_agents] = r_state
                people.susceptibility[r_agents] = 0.0

            current_index += patch_pop

        if model.params.verbose:
            # Count only active agents. The per-agent arrays may be longer than
            # num_active, and unused slots could otherwise be counted as S.
            active_states = people.state[:num_active]
            total_s = int(np.count_nonzero(active_states == s_state))
            total_r = int(np.count_nonzero(active_states == r_state))
            print(f"Initialized {total_s} susceptible and {total_r} recovered agents")