# Source code for laser_measles.abm.components.process_infection_seeding

"""
Component for seeding initial infections in the compartmental model.

This component allows initialization of infections in specific patches or automatically
selects the largest patch by population for seeding.
"""

import numpy as np
from pydantic import BaseModel
from pydantic import Field
from pydantic import field_validator

from laser_measles.abm.base import PatchLaserFrame
from laser_measles.abm.base import PeopleLaserFrame
from laser_measles.base import BaseComponent
from laser_measles.base import BaseLaserModel


class InfectionSeedingParams(BaseModel):
    """
    Settings that control where and how many initial infections are seeded.

    ``num_infections`` is the fallback count; ``target_patches`` optionally
    names the patches to seed; ``infections_per_patch`` optionally overrides
    the count for those patches (one int for all, or one int per patch);
    ``use_largest_patch`` enables the default of seeding the most populous
    patch.
    """

    num_infections: int = Field(default=1, description="Default number of infections to seed", ge=1)
    target_patches: list[str] | None = Field(default=None, description="List of specific patch IDs to seed")
    infections_per_patch: int | list[int] | None = Field(
        default=None,
        description="Number of infections per patch (single int or list matching target_patches)",
    )
    use_largest_patch: bool = Field(default=True, description="Whether to seed the largest patch by default")

    @field_validator("infections_per_patch")
    @classmethod
    def validate_infections_per_patch(cls, v, info):
        """
        Check ``infections_per_patch`` against ``target_patches``.

        The checks only run when both values are present: a list must have one
        entry per target patch and every count (list entry or single int) must
        be at least 1.
        """
        # target_patches is declared earlier, so it is already in info.data
        # unless its own validation failed (then the key is simply absent).
        targets = info.data.get("target_patches")
        if v is None or targets is None:
            return v

        if isinstance(v, list):
            if len(v) != len(targets):
                raise ValueError("Length of infections_per_patch must match length of target_patches")
            if any(count < 1 for count in v):
                raise ValueError("All values in infections_per_patch must be >= 1")
        elif isinstance(v, int) and v < 1:
            raise ValueError("infections_per_patch must be >= 1")
        return v

    @field_validator("target_patches")
    @classmethod
    def validate_target_patches(cls, v):
        """Reject any patch ID that is not a non-blank string."""
        if v is None:
            return v
        if any(not isinstance(patch_id, str) or not patch_id.strip() for patch_id in v):
            raise ValueError("All target_patches must be non-empty strings")
        return v


class InfectionSeedingProcess(BaseComponent):
    """
    Component that seeds initial infections into selected patches.

    For every target patch, susceptible agents are drawn at random and handed
    to the single model component that exposes an ``infect`` method. The
    target is either:

    1. The patch with the largest population (default), or
    2. An explicit list of patch IDs provided by the user.

    The seeding itself is performed by ``_initialize``.

    Parameters
    ----------
    model : BaseLaserModel
        The model instance this component is attached to.
    verbose : bool, default=False
        Whether to print progress and warnings while seeding.
    params : InfectionSeedingParams | None, default=None
        Component-specific parameters. If None, default parameters are used.

    Examples
    --------
    # Seed 1 infection in largest patch (default)
    seeding_params = InfectionSeedingParams()

    # Seed 5 infections in largest patch
    seeding_params = InfectionSeedingParams(num_infections=5)

    # Seed specific patches with same number of infections
    seeding_params = InfectionSeedingParams(
        target_patches=["nigeria:kano:kano:A0001", "nigeria:kano:kano:A0002"],
        infections_per_patch=3
    )

    # Seed specific patches with different numbers of infections
    seeding_params = InfectionSeedingParams(
        target_patches=["nigeria:kano:kano:A0001", "nigeria:kano:kano:A0002"],
        infections_per_patch=[5, 2]
    )
    """

    def __init__(self, model: BaseLaserModel, verbose: bool = False, params: InfectionSeedingParams | None = None) -> None:
        super().__init__(model, verbose)
        # Explicit None check so the choice never depends on the truthiness
        # of a caller-supplied params object.
        self.params = params if params is not None else InfectionSeedingParams()
        self._validate_params()

    def _validate_params(self) -> None:
        """
        Validate component parameters.

        Raises
        ------
        ValueError
            If no target patches are given and largest-patch seeding is disabled,
            i.e. there would be nothing to seed.
        """
        if self.params.target_patches is None and not self.params.use_largest_patch:
            raise ValueError("Either target_patches must be provided or use_largest_patch must be True")

    def _initialize(self, model: BaseLaserModel) -> None:
        """
        Seed infections into the model.

        Reads patch IDs and populations from the model scenario, resolves which
        patches to seed and with how many infections, checks that those patches
        exist, and then performs the seeding. The base component is expected to
        invoke this during model initialization, before the simulation starts.

        Parameters
        ----------
        model : BaseLaserModel
            The model to seed.

        Raises
        ------
        ValueError
            If a target patch is not present in the scenario.
        RuntimeError
            If the model lacks people/patches or a unique ``infect`` handler.
        """
        if self.verbose:
            print("Initializing infection seeding...")

        # Patch information comes from the scenario table.
        scenario_df = model.scenario.unwrap()
        patch_ids = scenario_df["id"].to_list()
        populations = scenario_df["pop"].to_list()

        if self.params.target_patches is None:
            # No explicit targets: fall back to the most populous patch.
            target_patches, infections_per_patch = self._get_largest_patch_seeding(patch_ids, populations)
        else:
            target_patches, infections_per_patch = self._get_specified_patch_seeding()

        self._validate_patches_exist(target_patches, patch_ids)

        total_seeded = self._seed_infections(model, target_patches, infections_per_patch, patch_ids)

        if self.verbose:
            print(f"Successfully seeded {total_seeded} infections across {len(target_patches)} patches")

    def _get_largest_patch_seeding(self, patch_ids: list[str], populations: list[int]) -> tuple[list[str], list[int]]:
        """
        Select the most populous patch as the single seeding target.

        Parameters
        ----------
        patch_ids : list[str]
            Patch identifiers, aligned with ``populations``.
        populations : list[int]
            Population of each patch.

        Returns
        -------
        tuple[list[str], list[int]]
            One-element lists: the chosen patch ID and ``params.num_infections``.

        Raises
        ------
        ValueError
            If there are no patches to choose from.
        """
        if len(populations) == 0:
            # np.argmax would raise ValueError here too, with a less helpful message.
            raise ValueError("Cannot select the largest patch: the scenario contains no patches")

        max_pop_idx = int(np.argmax(populations))
        largest_patch = patch_ids[max_pop_idx]

        if self.verbose:
            print(f"Selected largest patch: {largest_patch} (population: {populations[max_pop_idx]:,})")

        return [largest_patch], [self.params.num_infections]

    def _get_specified_patch_seeding(self) -> tuple[list[str], list[int]]:
        """
        Build the target list and per-patch counts from the user parameters.

        Returns
        -------
        tuple[list[str], list[int]]
            Copies of the configured target patches and the infections to seed
            in each: ``num_infections`` when ``infections_per_patch`` is unset,
            the single int repeated when it is an int, or the list itself.
        """
        target_patches = list(self.params.target_patches) if self.params.target_patches is not None else []
        per_patch = self.params.infections_per_patch

        if per_patch is None:
            # Fall back to the default count for every patch.
            infections_per_patch = [self.params.num_infections] * len(target_patches)
        elif isinstance(per_patch, int):
            # Same count for every patch.
            infections_per_patch = [per_patch] * len(target_patches)
        else:
            # Explicit per-patch list; copy so the params object is not aliased.
            infections_per_patch = list(per_patch)

        return target_patches, infections_per_patch

    def _validate_patches_exist(self, target_patches: list[str], patch_ids: list[str]) -> None:
        """
        Ensure every target patch is present in the model.

        Raises
        ------
        ValueError
            Listing all target patches that are missing from ``patch_ids``.
        """
        known_patches = set(patch_ids)  # O(1) membership per target
        missing_patches = [p for p in target_patches if p not in known_patches]
        if missing_patches:
            raise ValueError(f"Target patches not found in model: {missing_patches}")

    def _find_infector(self, model: BaseLaserModel):
        """
        Return the single model instance that exposes an ``infect`` method.

        The lookup is done before anything is infected so that an ambiguous
        configuration fails without having modified any agents.

        Raises
        ------
        RuntimeError
            If no instance, or more than one instance, has an ``infect`` method.
        """
        candidates = [instance for instance in model.instances if hasattr(instance, "infect")]
        if not candidates:
            raise RuntimeError("No instance found with an infect method")
        if len(candidates) > 1:
            raise RuntimeError("Multiple instances found with an infect method")
        return candidates[0]

    def _seed_infections(
        self, model: BaseLaserModel, target_patches: list[str], infections_per_patch: list[int], patch_ids: list[str]
    ) -> int:
        """
        Seed infections in the specified patches.

        For each target patch the requested count is capped by the patch's
        susceptible count, that many susceptible agents are sampled without
        replacement, and they are passed to the model's ``infect`` handler.

        Parameters
        ----------
        model : BaseLaserModel
            Model providing ``people``, ``patches``, ``prng``, ``params`` and ``instances``.
        target_patches : list[str]
            Patch IDs to seed.
        infections_per_patch : list[int]
            Requested infections for each target patch.
        patch_ids : list[str]
            All patch IDs; a patch's position here is its numeric index.

        Returns
        -------
        int
            Number of agents handed to the ``infect`` handler across all patches.

        Raises
        ------
        RuntimeError
            If the model has no people or patches, or no unique ``infect`` handler.
        """
        if not hasattr(model, "people") or model.people is None:
            raise RuntimeError("Model does not have people attribute or it is None")
        if not hasattr(model, "patches") or model.patches is None:
            raise RuntimeError("Model does not have patches attribute or it is None")

        people: PeopleLaserFrame = model.people
        patches: PatchLaserFrame = model.patches
        num_active = len(people)
        total_seeded = 0

        # The two lists are built with equal length by the callers in this class.
        for patch_id, num_infections in zip(target_patches, infections_per_patch, strict=False):
            patch_idx = patch_ids.index(patch_id)

            # Cap the request by the susceptible count recorded for the patch.
            current_susceptible = int(patches.states.S[patch_idx])
            actual_infections = min(num_infections, current_susceptible)

            if actual_infections < num_infections and self.verbose:
                print(
                    f"Warning: Patch {patch_id} has only {current_susceptible} susceptible individuals, "
                    f"seeding {actual_infections} instead of {num_infections}"
                )

            if actual_infections <= 0:
                if self.verbose:
                    print(f"Warning: No susceptible individuals in patch {patch_id}, skipping")
                continue

            # Indices of active agents that are in this patch and susceptible.
            susceptible_state = model.params.states.index("S")
            in_patch = people.patch_id[:num_active] == patch_idx
            is_susceptible = people.state[:num_active] == susceptible_state
            idx = np.where(np.logical_and(in_patch, is_susceptible))[0]

            # Shuffle in place and keep the first k: a uniform sample without
            # replacement. Indices are < num_active by construction.
            model.prng.shuffle(idx)
            idx = idx[:actual_infections]

            # Resolve the handler first so a misconfigured model raises before
            # any agent is infected (and never infects the same agents twice).
            infector = self._find_infector(model)
            infector.infect(model, idx)

            # Report what was actually selected, which can be fewer than the
            # patch-level count if agent data and patch states disagree.
            seeded_here = len(idx)
            total_seeded += seeded_here

            if self.verbose:
                print(f"Seeded {seeded_here} infections in patch {patch_id}")

        return total_seeded