Source code for laser_measles.abm.components.process_initialize_equilibrium_states
"""
Component that initializes the population across the model states using a rough R0-based equilibrium.
"""
import numpy as np
import polars as pl
from laser_measles.abm.base import PatchLaserFrame
from laser_measles.abm.base import PeopleLaserFrame
from laser_measles.abm.model import ABMModel
from laser_measles.components import BaseInitializeEquilibriumStatesParams
from laser_measles.components import BaseInitializeEquilibriumStatesProcess
class InitializeEquilibriumStatesParams(BaseInitializeEquilibriumStatesParams):
    """
    Parameter set for ``InitializeEquilibriumStatesProcess``.

    This subclass declares no fields of its own; all parameters come from
    ``BaseInitializeEquilibriumStatesParams``.
    """
[docs]
class InitializeEquilibriumStatesProcess(BaseInitializeEquilibriumStatesProcess):
    """
    Initialize the S and R states of the population from a rough R0 equilibrium.

    Extends the base component, which sets patch-level state counts. The patch
    S/R counts are rounded to integers, and individual agents are assigned
    states and susceptibilities consistent with those integer counts.
    """

    def _initialize(self, model: ABMModel) -> None:
        """
        Initialize the population in each of the model states by rough equilibrium of R0.

        For ABM models, this involves:
        1. Calculating equilibrium patch-level state counts (via the base class)
        2. Initializing individual agents with states consistent with patch counts
        3. Assigning patch_id and susceptibility values appropriately

        Args:
            model: The ABM model whose patches and people are initialized in place.
        """
        # First, apply the base equilibrium calculation to patch states
        super()._initialize(model)
        # Now initialize the people LaserFrame to match the patch states
        self._initialize_people_from_patches(model)

    def _initialize_people_from_patches(self, model: ABMModel) -> None:
        """
        Initialize individual agents to match the patch-level state counts.

        Agents are laid out contiguously by patch, in scenario row order. Each
        patch receives ``round(pop * (1 - 1 / R0))`` recovered agents, clamped to
        ``[0, pop]``. These agents are chosen uniformly at random without
        replacement from the patch's agents; the remaining agents are susceptible.
        The patch ``S``/``R`` counts are overwritten with the same integers so
        that patch and agent states agree exactly.

        Args:
            model: The ABM model. Reads ``scenario``, ``params`` and ``prng``;
                writes ``people.patch_id``, ``people.state``,
                ``people.susceptibility`` and ``patches.states.S``/``.R``.

        Raises:
            ValueError: If the total scenario population differs from the number
                of active agents in ``model.people``.
        """
        scenario_df = model.scenario.unwrap()
        people: PeopleLaserFrame = model.people
        patches: PatchLaserFrame = model.patches
        num_active = len(people)

        # Hoist per-patch populations and state indices out of the per-patch loop.
        pops = scenario_df["pop"].to_numpy()
        s_index = model.params.states.index("S")
        r_index = model.params.states.index("R")

        total_pop = int(pops.sum())
        if total_pop != num_active:
            raise ValueError(
                f"Scenario population ({total_pop}) does not match the number of active agents ({num_active})"
            )

        # Assign patch_id: patch i owns a contiguous run of pops[i] agents.
        # np.repeat gives zero-population patches no entries; exploding an empty
        # polars list would instead emit a spurious null row.
        people.patch_id[:num_active] = np.repeat(
            np.arange(len(pops), dtype=people.patch_id.dtype), pops
        )

        # Initialize all agents as susceptible first
        people.state[:num_active] = s_index
        people.susceptibility[:num_active] = 1.0

        # Equilibrium recovered share; negative when R0 < 1, which the clamp below handles.
        recovered_share = 1 - 1 / self.params.R0
        start = 0
        for patch_idx, raw_pop in enumerate(pops):
            patch_pop = int(raw_pop)
            # Round the equilibrium R count to an integer and keep it within [0, patch_pop]
            patch_r_count = int(np.round(patch_pop * recovered_share))
            patch_r_count = max(0, min(patch_r_count, patch_pop))

            # Overwrite patch states with the integer counts so they match the agents
            patches.states.S[patch_idx] = patch_pop - patch_r_count
            patches.states.R[patch_idx] = patch_r_count

            # Randomly select which of this patch's agents start recovered
            if patch_r_count > 0:
                patch_agents = np.arange(start, start + patch_pop)
                r_agents = model.prng.choice(patch_agents, size=patch_r_count, replace=False)
                people.state[r_agents] = r_index
                people.susceptibility[r_agents] = 0.0
            start += patch_pop

        if model.params.verbose:
            # Count only active agents; the state array may extend past num_active.
            active_states = people.state[:num_active]
            total_s = np.sum(active_states == s_index)
            total_r = np.sum(active_states == r_index)
            print(f"Initialized {total_s} susceptible and {total_r} recovered agents")