# Source code for laser.measles.compartmental.components.process_importation_pressure

"""
Component for simulating the importation pressure in the compartmental model.
"""

from collections.abc import Sequence

import numpy as np
from pydantic import BaseModel
from pydantic import Field
from pydantic import field_validator

from laser.measles.base import BasePhase
from laser.measles.compartmental.model import CompartmentalModel
from laser.measles.utils import cast_type


class ImportationPressureParams(BaseModel):
    """Parameters for the importation pressure component.

    Importation pressure simulates external case introductions from outside the
    modeled population (e.g., international travel, cross-border movement),
    moving susceptible individuals into the exposed (E) compartment each daily tick.

    Note:
        Unlike the ABM and biweekly models which move S→I, the compartmental model
        moves imported cases S→E, letting them progress through the latent period.

    Attributes:
        crude_importation_rate: Yearly importation rate per 1,000 population.
            Negative and NaN values are rejected. Three forms are accepted:

            - **float** (scalar): uniform rate applied to every patch.
              ``0.0`` disables importation entirely.

            - **list / tuple / numpy array** (sequence): per-patch rates in the
              same order as ``model.scenario`` rows. Length must equal the number
              of patches.

            - **dict[str, float]** (sparse patch override): maps patch string ids
              (values from ``model.scenario["id"]``, e.g. ``"n_0_0"``, ``"n_2_2"``)
              to rates. Patches absent from the dict receive a rate of **0.0** —
              they do *not* inherit the scalar default of 1.0.

        importation_start: Day on which importation begins (inclusive). Default
            ``0`` starts importation at the first tick.

        importation_end: Day on which importation ends (inclusive). Use ``-1``
            (default) to keep importation active for the full simulation. Must be
            greater than ``importation_start`` when not ``-1``.

    Examples:
        Uniform low background pressure across all patches::

            params = ImportationPressureParams(crude_importation_rate=0.05)

        Disable importation entirely::

            params = ImportationPressureParams(crude_importation_rate=0.0)

        Per-patch sequence (one entry per patch, aligned to scenario row order)::

            # 25-patch model; patch at row index 12 is the metro hub
            rates = [0.02] * 25
            rates[12] = 0.5
            params = ImportationPressureParams(crude_importation_rate=rates)

        Sparse dict — only named patches receive importation; all others get 0.0::

            # Use string ids from model.scenario["id"], e.g. "n_0_0", "n_2_2"
            params = ImportationPressureParams(
                crude_importation_rate={"n_2_2": 0.5, "n_0_0": 0.1},
            )

        Numpy array input (accepted and converted to list internally)::

            import numpy as np
            params = ImportationPressureParams(
                crude_importation_rate=np.array([0.01, 0.05, 0.01, 0.01, 0.01])
            )

        Time-windowed importation active only during the first year (days 0-364)::

            params = ImportationPressureParams(
                crude_importation_rate=0.1,
                importation_start=0,
                importation_end=364,
            )

        Metro-only importation for the first year, then stop::

            params = ImportationPressureParams(
                crude_importation_rate={"n_2_2": 2.0},
                importation_start=0,
                importation_end=364,
            )
    """

    crude_importation_rate: float | list[float] | dict[str, float] = Field(
        default=1.0,
        description=(
            "Yearly crude importation rate per 1k population. "
            "Scalar: uniform across all patches. "
            "Sequence (list/tuple/ndarray): per-patch rates aligned to scenario row order, length must equal n_patches. "
            "Dict[str, float]: sparse override keyed by patch string id (from model.scenario['id']); "
            "omitted patches default to 0.0, not 1.0."
        ),
    )
    importation_start: int = Field(
        default=0,
        description="Day on which importation begins (inclusive).",
        ge=0,
    )
    importation_end: int = Field(
        default=-1,
        description="Day on which importation ends (inclusive). Use -1 (default) to run for the full simulation.",
        ge=-1,
    )

    @field_validator("importation_end")
    @classmethod
    def validate_importation_end(cls, v, info):
        """Validate that importation_end is greater than importation_start when not -1.

        ``info.data`` only holds fields that already validated. If
        ``importation_start`` is missing from it, 0 is assumed for the comparison.

        Raises:
            ValueError: If ``v`` is not -1 and ``v <= importation_start``.
        """
        if v != -1:
            start = info.data.get("importation_start", 0)
            if v <= start:
                raise ValueError("importation_end must be greater than importation_start")
        return v

    @field_validator("crude_importation_rate")
    @classmethod
    def validate_importation_rate(cls, v):
        """Reject negative or NaN rates and normalize array/sequence input to a list.

        NaN is checked explicitly because ``nan < 0`` is False. Without this
        check a NaN would pass validation and flow into the per-patch importation
        probabilities.

        Returns:
            The scalar or dict unchanged, or a plain ``list`` for array/sequence input.

        Raises:
            ValueError: If any rate is negative or NaN.
            TypeError: If ``v`` is not a number, sequence, or dict.
        """
        if isinstance(v, (int, float)):
            if v < 0:
                raise ValueError("crude_importation_rate must be >= 0")
            if np.isnan(v):
                raise ValueError("crude_importation_rate must not be NaN")
        elif isinstance(v, np.ndarray):
            if np.any(v < 0):
                raise ValueError("All crude_importation_rate values must be >= 0")
            if np.any(np.isnan(v)):
                raise ValueError("All crude_importation_rate values must not be NaN")
            return v.tolist()
        elif isinstance(v, Sequence) and not isinstance(v, (str, bytes)):
            values = list(v)
            if any(x < 0 for x in values):
                raise ValueError("All crude_importation_rate values must be >= 0")
            if any(np.isnan(x) for x in values):
                raise ValueError("All crude_importation_rate values must not be NaN")
            return values
        elif isinstance(v, dict):
            if any(x < 0 for x in v.values()):
                raise ValueError("All crude_importation_rate values must be >= 0")
            if any(np.isnan(x) for x in v.values()):
                raise ValueError("All crude_importation_rate values must not be NaN")
        else:
            raise TypeError("crude_importation_rate must be a float, sequence of floats, or dict[str, float]")
        return v


class ImportationPressureProcess(BasePhase):
    """
    Component for simulating the importation pressure in the model.

    This component handles the simulation of disease importation into the population.
    It processes:
    - Importation of cases based on crude importation rate
    - Time-windowed importation (start/end times)
    - Population updates: Moves individuals from susceptible to exposed state

    Parameters
    ----------
    model : object
        The simulation model containing nodes, states, and parameters
    verbose : bool, default=False
        Whether to print verbose output during simulation
    params : Optional[ImportationPressureParams], default=None
        Component-specific parameters. If None, will use default parameters

    Notes
    -----
    - Importation rates are specified per year per 1,000 population. They are
      converted to a per-tick probability using ``model.params.time_step_days``.
    - Imports are drawn binomially from each patch's total population and then
      capped at that patch's susceptible count, so S is never driven below zero.
    """

    def __init__(self, model, verbose: bool = False, params: ImportationPressureParams | None = None) -> None:
        super().__init__(model, verbose)
        self.params = params if params is not None else ImportationPressureParams()
        # Per-patch yearly rates per 1k population, filled in by ``_initialize``.
        self.patch_rates_per_year_per_1k: np.ndarray | None = None

    def __call__(self, model, tick: int) -> None:
        """Import cases for one tick by moving individuals from S to E.

        Args:
            model: The running model. Reads ``model.params.time_step_days``,
                ``model.patches.states`` and ``model.prng``.
            tick: Current simulation tick.

        Raises:
            RuntimeError: If the tick is inside the importation window but
                ``_initialize`` has not been run.
        """
        time_step_days = model.params.time_step_days

        # The window is configured in days; convert it to ticks before comparing.
        if tick < self.params.importation_start // time_step_days:
            return
        if self.params.importation_end != -1 and tick > self.params.importation_end // time_step_days:
            return

        if self.patch_rates_per_year_per_1k is None:
            raise RuntimeError("ImportationPressureProcess not initialized")

        # state counts
        states = model.patches.states
        # Per-patch population (sum over axis 0, presumably the compartment axis).
        # Promote to int64, otherwise binomial draw will fail.
        population = states.sum(axis=0, dtype=np.int64)

        # Yearly rate per 1k -> per-tick probability. Scaling by the tick length
        # keeps the yearly rate correct for any time_step_days (identical for 1-day ticks).
        p = self.patch_rates_per_year_per_1k * time_step_days / 365.0 / 1000.0
        p = np.clip(p, 0.0, 1.0)

        # Sample actual number of imported cases
        imported_cases = model.prng.binomial(population, p)
        imported_cases = cast_type(imported_cases, states.dtype)
        # Never import more cases than a patch has susceptibles.
        np.minimum(imported_cases, states.S, out=imported_cases)

        # update states: move imported cases to the exposed state
        states.S -= imported_cases
        states.E += imported_cases

    def _initialize(self, model: CompartmentalModel) -> None:
        """Expand ``crude_importation_rate`` into a per-patch float64 array.

        The result is stored in ``self.patch_rates_per_year_per_1k``, aligned to
        the row order of ``model.scenario["id"]``.

        Raises:
            ValueError: If a sequence's length differs from the patch count, or a
                dict names patch ids not present in ``model.scenario["id"]``.
            TypeError: If the configured rate has an unsupported type.
        """
        n_patches = model.patches.count
        patch_ids = model.scenario["id"].to_list()
        rates = self.params.crude_importation_rate

        if isinstance(rates, (int, float)):
            arr = np.full(n_patches, float(rates), dtype=np.float64)
        elif isinstance(rates, (np.ndarray, list)) or (isinstance(rates, Sequence) and not isinstance(rates, (str, bytes))):
            if len(rates) != n_patches:
                raise ValueError(f"crude_importation_rate length {len(rates)} does not match number of patches {n_patches}")
            arr = np.asarray(rates, dtype=np.float64)
        elif isinstance(rates, dict):
            unknown = set(rates) - set(patch_ids)
            if unknown:
                raise ValueError(f"Unknown patch ids in crude_importation_rate: {sorted(unknown)}")
            # Patches not named in the dict get 0.0 (no importation).
            arr = np.array([rates.get(pid, 0.0) for pid in patch_ids], dtype=np.float64)
        else:
            raise TypeError("Unsupported crude_importation_rate type")

        self.patch_rates_per_year_per_1k = arr