Source code for laser.measles.abm.components.process_importation_pressure

from collections.abc import Sequence

import numpy as np
from pydantic import BaseModel
from pydantic import Field
from pydantic import field_validator

from laser.measles.abm.model import ABMModel
from laser.measles.base import BasePhase
from laser.measles.utils import cast_type


class ImportationPressureParams(BaseModel):
    """Configuration for the importation pressure component.

    Importation pressure stands for infections that arrive from outside the
    modeled population (international travel, cross-border movement, ...).
    Each tick inside the importation window the component seeds new
    infections among susceptible individuals.

    Note:
        The number of imports per patch is drawn against the **total** patch
        population (``binomial(population, rate)``), yet only **susceptible**
        individuals can be infected. Whenever the draw is larger than the
        number of susceptibles in a patch, the surplus is dropped, so the
        *effective* importation rate falls as immunity builds up.

    Attributes:
        crude_importation_rate: Importation rate per year per 1,000 **total**
            population (not per susceptible). It can be given as:

            - a **scalar** -- the same rate for every patch; ``0.0`` turns
              importation off.

            - a **sequence** (list / tuple) -- one rate per patch, in the row
              order of ``model.scenario``. The length has to match the number
              of patches. A numpy array that reaches the validator is checked
              and converted with ``tolist()``.

            - a **dict[str, float]** -- sparse mapping from patch string id
              (the values of ``model.scenario["id"]``, e.g. ``"n_0_0"``) to a
              rate. Patches missing from the dict get **0.0**; they do *not*
              fall back to the scalar default of 1.0.

        importation_start: First day (inclusive) with importation. Always
            given in days; the component converts it to ticks using the
            model's ``time_step_days``. The default ``0`` starts at the first
            tick.

        importation_end: Last day (inclusive) with importation. ``-1`` (the
            default) keeps importation on for the whole simulation. Any other
            value must be strictly greater than ``importation_start``.

    Examples:
        Same low background rate everywhere::

            params = ImportationPressureParams(crude_importation_rate=0.05)

        No importation at all::

            params = ImportationPressureParams(crude_importation_rate=0.0)

        One rate per patch, aligned with the scenario rows::

            # 25 patches; row 12 is the metro hub
            rates = [0.02] * 25
            rates[12] = 0.5
            params = ImportationPressureParams(crude_importation_rate=rates)

        Sparse dict -- only the listed patches import cases::

            params = ImportationPressureParams(
                crude_importation_rate={"n_2_2": 0.5, "n_0_0": 0.1},
            )

        Importation restricted to the first year (days 0-364)::

            params = ImportationPressureParams(
                crude_importation_rate=0.1,
                importation_start=0,
                importation_end=364,
            )

        Metro-only importation during the first year::

            params = ImportationPressureParams(
                crude_importation_rate={"n_2_2": 2.0},
                importation_start=0,
                importation_end=364,
            )
    """

    crude_importation_rate: float | list[float] | dict[str, float] = Field(
        default=1.0,
        description=(
            "Yearly crude importation rate per 1k population. "
            "Scalar: uniform across all patches. "
            "Sequence (list/tuple/ndarray): per-patch rates aligned to scenario row order, length must equal n_patches. "
            "Dict[str, float]: sparse override keyed by patch string id (from model.scenario['id']); "
            "omitted patches default to 0.0, not 1.0."
        ),
    )
    # NOTE: ``importation_start`` must stay declared before ``importation_end``;
    # the end-day validator reads the already-validated start day.
    importation_start: int = Field(
        default=0,
        ge=0,
        description="Day on which importation begins (inclusive). Converted to ticks internally.",
    )
    importation_end: int = Field(
        default=-1,
        ge=-1,
        description="Day on which importation ends (inclusive). Use -1 (default) to run for the full simulation.",
    )

    @field_validator("importation_end")
    @classmethod
    def validate_importation_end(cls, v, info):
        """Require an explicit end day to lie strictly after the start day."""
        # -1 is the "no end" sentinel and is always acceptable.
        if v == -1:
            return v
        # Falls back to 0 when the start day is not available in ``info.data``.
        first_day = info.data.get("importation_start", 0)
        if v <= first_day:
            raise ValueError("importation_end must be greater than importation_start")
        return v

    @field_validator("crude_importation_rate")
    @classmethod
    def validate_importation_rate(cls, v):
        """Reject negative rates and normalise array/sequence input to a list."""
        if isinstance(v, (int, float)):
            if v < 0:
                raise ValueError("crude_importation_rate must be >= 0")
            return v

        if isinstance(v, np.ndarray):
            if np.any(v < 0):
                raise ValueError("All crude_importation_rate values must be >= 0")
            return v.tolist()

        # Strings are sequences too, but never a valid collection of rates.
        if isinstance(v, Sequence) and not isinstance(v, (str, bytes)):
            for rate in v:
                if rate < 0:
                    raise ValueError("All crude_importation_rate values must be >= 0")
            return list(v)

        if isinstance(v, dict):
            for rate in v.values():
                if rate < 0:
                    raise ValueError("All crude_importation_rate values must be >= 0")
            return v

        raise TypeError("crude_importation_rate must be a float, sequence of floats, or dict[str, float]")


class ImportationPressureProcess(BasePhase):
    """
    Component for simulating the importation pressure in the model.

    On every tick inside the configured importation window, a number of
    imported cases is drawn for each patch and that many randomly chosen
    susceptible agents of the patch are handed to the model's infection
    component.

    Parameters
    ----------
    model : ABMModel
        The simulation model containing patches, people, and parameters
    verbose : bool, default=False
        Passed through to ``BasePhase.__init__``
    params : Optional[ImportationPressureParams], default=None
        Component-specific parameters. If None, default parameters are used

    Notes
    -----
    - Rates are expressed per year per 1,000 total population and are divided
      by 365 and by 1000 to obtain the per-draw probability
    - Cases are drawn against the total patch population, but only
      susceptible agents are infected; draws beyond the number of available
      susceptibles are discarded
    - Exactly one entry of ``model.instances`` must expose an ``infect``
      attribute; otherwise ``__call__`` raises ``RuntimeError``
    """

    def __init__(self, model, verbose: bool = False, params: ImportationPressureParams | None = None) -> None:
        super().__init__(model, verbose)
        # Explicit None check: only a missing params object triggers the defaults.
        self.params = params if params is not None else ImportationPressureParams()
        # Per-patch yearly rate per 1k population; filled in by ``_initialize``.
        self.patch_rates_per_year_per_1k: np.ndarray | None = None

    def __call__(self, model, tick: int) -> None:
        """Import cases for one tick.

        Parameters
        ----------
        model : ABMModel
            Model whose susceptible agents may be infected
        tick : int
            Current simulation tick

        Raises
        ------
        RuntimeError
            If ``_initialize`` has not populated the per-patch rates, or if
            the number of model instances exposing ``infect`` is not exactly one
        """
        # The window is configured in days; convert both bounds to ticks.
        step_days = model.params.time_step_days
        if tick < self.params.importation_start // step_days:
            return
        end_day = self.params.importation_end
        if end_day != -1 and tick > end_day // step_days:
            return

        if self.patch_rates_per_year_per_1k is None:
            raise RuntimeError("ImportationPressureProcess not initialized")

        states = model.patches.states
        population = states.sum(axis=0, dtype=np.int64)

        # NOTE(review): the probability is a per-day value and is not scaled by
        # ``time_step_days`` -- confirm this is intended for multi-day ticks.
        p = np.clip(self.patch_rates_per_year_per_1k / 365.0 / 1000.0, 0.0, 1.0)
        imported_cases = cast_type(model.prng.binomial(population, p), states.dtype)

        # Infection is delegated to the single component that exposes ``infect``.
        infection_components = [instance for instance in model.instances if hasattr(instance, "infect")]
        if len(infection_components) != 1:
            raise RuntimeError("Issue with the infection component")
        infection_component = infection_components[0]

        susceptible_state = model.params.states.index("S")
        people = model.people
        # Loop-invariant: whether the people frame carries an ``active`` mask.
        has_active = hasattr(people, "active")

        for patch_idx, num_imported_cases in enumerate(imported_cases):
            if num_imported_cases == 0:
                continue

            in_patch = people.patch_id == patch_idx
            if has_active:
                in_patch = people.active & in_patch
            idx = np.flatnonzero(in_patch)

            # Also empty when the patch has no agents at all.
            susceptible_idx = idx[people.state[idx] == susceptible_state]
            if susceptible_idx.size == 0:
                continue

            # Draws exceeding the available susceptibles are discarded.
            n_draw = min(num_imported_cases, susceptible_idx.size)
            chosen = model.prng.choice(susceptible_idx, size=n_draw, replace=False)
            infection_component.infect(model, chosen)

    def _initialize(self, model: ABMModel) -> None:
        """Expand ``params.crude_importation_rate`` into one rate per patch.

        Parameters
        ----------
        model : ABMModel
            Model providing the patch count and the scenario patch ids

        Raises
        ------
        ValueError
            If a sequence has the wrong length or a dict names unknown patch ids
        TypeError
            If the configured rate has an unsupported type
        """
        n_patches = model.patches.count
        patch_ids = model.scenario["id"].to_list()
        rates = self.params.crude_importation_rate

        if isinstance(rates, (int, float)):
            # Scalar: the same rate for every patch.
            arr = np.full(n_patches, float(rates), dtype=np.float64)
        elif isinstance(rates, (np.ndarray, list)) or (
            isinstance(rates, Sequence) and not isinstance(rates, (str, bytes))
        ):
            # Sequence: one rate per patch, aligned with the scenario rows.
            if len(rates) != n_patches:
                raise ValueError(
                    f"crude_importation_rate length {len(rates)} does not match number of patches {n_patches}"
                )
            arr = np.asarray(rates, dtype=np.float64)
        elif isinstance(rates, dict):
            # Sparse mapping: patches that are not listed get a rate of 0.0.
            unknown = set(rates) - set(patch_ids)
            if unknown:
                raise ValueError(f"Unknown patch ids in crude_importation_rate: {sorted(unknown)}")
            arr = np.array([rates.get(pid, 0.0) for pid in patch_ids], dtype=np.float64)
        else:
            raise TypeError("Unsupported crude_importation_rate type")

        self.patch_rates_per_year_per_1k = arr