"""
Component for seeding initial infections in the agent-based model.
This component initializes infections either in user-specified patches or, by default,
in the patch with the largest population.
"""
import numpy as np
from pydantic import BaseModel
from pydantic import Field
from pydantic import field_validator
from laser_measles.abm.base import PatchLaserFrame
from laser_measles.abm.base import PeopleLaserFrame
from laser_measles.base import BaseComponent
from laser_measles.base import BaseLaserModel
class InfectionSeedingParams(BaseModel):
    """Parameters for the infection seeding component.

    Attributes
    ----------
    num_infections : int
        Number of infections to seed when no per-patch count is given (>= 1).
        Used for the largest-patch default and as the fallback count for every
        entry of ``target_patches`` when ``infections_per_patch`` is None.
    target_patches : list[str] | None
        Explicit patch IDs to seed. When None, the largest patch is seeded.
    infections_per_patch : int | list[int] | None
        Either a single count applied to every target patch, or a list with one
        count per entry of ``target_patches``. Every value must be >= 1.
    use_largest_patch : bool
        Whether the largest patch may be seeded when ``target_patches`` is None.
    """

    num_infections: int = Field(default=1, description="Default number of infections to seed", ge=1)
    target_patches: list[str] | None = Field(default=None, description="List of specific patch IDs to seed")
    infections_per_patch: (
        int | list[int]
    ) | None = Field(default=None, description="Number of infections per patch (single int or list matching target_patches)")
    use_largest_patch: bool = Field(default=True, description="Whether to seed the largest patch by default")

    @field_validator("infections_per_patch")
    @classmethod
    def validate_infections_per_patch(cls, v, info):
        """Validate the per-patch infection counts.

        Every count must be >= 1 (matching the ``ge=1`` constraint on
        ``num_infections``), whether or not ``target_patches`` is set. When both
        a list of counts and ``target_patches`` are given, their lengths must
        match.

        Raises
        ------
        ValueError
            If a count is < 1 or the list length does not match ``target_patches``.
        """
        if v is None:
            return v
        # Fields validate in declaration order, so a successfully validated
        # ``target_patches`` is already present in ``info.data``. It is absent
        # if its own validation failed.
        target_patches = info.data.get("target_patches")
        if isinstance(v, list):
            if target_patches is not None and len(v) != len(target_patches):
                raise ValueError("Length of infections_per_patch must match length of target_patches")
            if any(x < 1 for x in v):
                raise ValueError("All values in infections_per_patch must be >= 1")
        elif isinstance(v, int):
            if v < 1:
                raise ValueError("infections_per_patch must be >= 1")
        return v

    @field_validator("target_patches")
    @classmethod
    def validate_target_patches(cls, v):
        """Validate that every target patch ID is a non-empty, non-blank string.

        Raises
        ------
        ValueError
            If any entry is not a string or is empty/whitespace only.
        """
        if v is not None:
            for patch_id in v:
                if not isinstance(patch_id, str) or not patch_id.strip():
                    raise ValueError("All target_patches must be non-empty strings")
        return v
class InfectionSeedingProcess(BaseComponent):
    """
    Component for seeding initial infections in the agent-based model.

    During initialization this component picks susceptible agents at random in
    one or more patches. It passes their indices to the single model component
    that exposes an ``infect`` method. That component performs the state
    transition and updates the patch-level counts. Patches are chosen either:

    1. Automatically: the patch with the largest population (default), or
    2. Explicitly: the patch IDs given in ``params.target_patches``.

    Seeding happens in ``_initialize`` before the simulation begins.

    Parameters
    ----------
    model : BaseLaserModel
        The model instance. Seeding reads ``model.scenario``, ``model.people``,
        ``model.patches``, ``model.instances``, ``model.prng`` and
        ``model.params.states``.
    verbose : bool, default=False
        Whether to print progress and warnings during initialization.
    params : InfectionSeedingParams | None, default=None
        Component-specific parameters. If None, default parameters are used.

    Raises
    ------
    ValueError
        If ``target_patches`` is None and ``use_largest_patch`` is False.

    Examples
    --------
    # Seed 1 infection in largest patch (default)
    seeding_params = InfectionSeedingParams()
    # Seed 5 infections in largest patch
    seeding_params = InfectionSeedingParams(num_infections=5)
    # Seed specific patches with same number of infections
    seeding_params = InfectionSeedingParams(
        target_patches=["nigeria:kano:kano:A0001", "nigeria:kano:kano:A0002"],
        infections_per_patch=3
    )
    # Seed specific patches with different numbers of infections
    seeding_params = InfectionSeedingParams(
        target_patches=["nigeria:kano:kano:A0001", "nigeria:kano:kano:A0002"],
        infections_per_patch=[5, 2]
    )
    """

    def __init__(self, model: BaseLaserModel, verbose: bool = False, params: InfectionSeedingParams | None = None) -> None:
        super().__init__(model, verbose)
        self.params = params if params is not None else InfectionSeedingParams()
        self._validate_params()

    def _validate_params(self) -> None:
        """Ensure the parameters leave at least one way to pick a patch.

        Raises
        ------
        ValueError
            If no ``target_patches`` are given and ``use_largest_patch`` is False.
        """
        if self.params.target_patches is None and not self.params.use_largest_patch:
            raise ValueError("Either target_patches must be provided or use_largest_patch must be True")

    def _initialize(self, model: BaseLaserModel) -> None:
        """
        Seed infections into susceptible agents before the simulation starts.

        Reads patch IDs and populations from ``model.scenario``, resolves which
        patches to seed and how many infections each receives, checks that
        every target patch exists, then performs the seeding.

        Raises
        ------
        ValueError
            If a target patch is missing from the scenario, the scenario has no
            patches, or per-patch counts do not line up with the target patches.
        RuntimeError
            If the model lacks ``people``/``patches``, or does not have exactly
            one component with an ``infect`` method.
        """
        if self.verbose:
            print("Initializing infection seeding...")
        # Patch information comes from the scenario table.
        scenario_df = model.scenario.unwrap()
        patch_ids = scenario_df["id"].to_list()
        populations = scenario_df["pop"].to_list()
        # Decide which patches to seed and how many infections each gets.
        if self.params.target_patches is None:
            target_patches, infections_per_patch = self._get_largest_patch_seeding(patch_ids, populations)
        else:
            target_patches, infections_per_patch = self._get_specified_patch_seeding()
        self._validate_patches_exist(target_patches, patch_ids)
        total_seeded = self._seed_infections(model, target_patches, infections_per_patch, patch_ids)
        if self.verbose:
            print(f"Successfully seeded {total_seeded} infections across {len(target_patches)} patches")

    def _get_largest_patch_seeding(self, patch_ids: list[str], populations: list[int]) -> tuple[list[str], list[int]]:
        """Return the most populous patch paired with ``num_infections``.

        Ties go to the first such patch (``np.argmax`` returns the first maximum).

        Raises
        ------
        ValueError
            If the scenario contains no patches.
        """
        if not populations:
            raise ValueError("Cannot select the largest patch: the scenario contains no patches")
        max_pop_idx = int(np.argmax(populations))
        largest_patch = patch_ids[max_pop_idx]
        if self.verbose:
            print(f"Selected largest patch: {largest_patch} (population: {populations[max_pop_idx]:,})")
        return [largest_patch], [self.params.num_infections]

    def _get_specified_patch_seeding(self) -> tuple[list[str], list[int]]:
        """Return the user-specified patches and their infection counts.

        ``infections_per_patch`` may be None (use ``num_infections`` for every
        patch), a single int (applied to every patch), or a list aligned with
        ``target_patches``.

        Raises
        ------
        ValueError
            If a list of counts does not have exactly one entry per target patch.
            The params model checks this at construction. It is re-checked here
            because pydantic does not re-validate fields assigned afterwards
            unless ``validate_assignment`` is enabled.
        """
        target_patches = list(self.params.target_patches) if self.params.target_patches is not None else []
        per_patch = self.params.infections_per_patch
        if per_patch is None:
            infections_per_patch = [self.params.num_infections] * len(target_patches)
        elif isinstance(per_patch, int):
            infections_per_patch = [per_patch] * len(target_patches)
        else:
            infections_per_patch = list(per_patch)
            if len(infections_per_patch) != len(target_patches):
                raise ValueError("Length of infections_per_patch must match length of target_patches")
        return target_patches, infections_per_patch

    def _validate_patches_exist(self, target_patches: list[str], patch_ids: list[str]) -> None:
        """Raise ValueError listing any target patch not present in ``patch_ids``."""
        known_patches = set(patch_ids)  # O(1) membership instead of a list scan per target
        missing_patches = [p for p in target_patches if p not in known_patches]
        if missing_patches:
            raise ValueError(f"Target patches not found in model: {missing_patches}")

    @staticmethod
    def _find_infector(model: BaseLaserModel):
        """Return the unique component in ``model.instances`` that has an ``infect`` method.

        Raises
        ------
        RuntimeError
            If no instance, or more than one instance, defines ``infect``.
        """
        infectors = [instance for instance in model.instances if hasattr(instance, "infect")]
        if not infectors:
            raise RuntimeError("No instance found with an infect method")
        if len(infectors) > 1:
            raise RuntimeError("Multiple instances found with an infect method")
        return infectors[0]

    def _seed_infections(
        self, model: BaseLaserModel, target_patches: list[str], infections_per_patch: list[int], patch_ids: list[str]
    ) -> int:
        """Infect randomly chosen susceptible agents in each target patch.

        For each ``(patch_id, count)`` pair, up to ``count`` active agents whose
        ``patch_id`` is the patch's position in ``patch_ids`` and whose state is
        ``"S"`` are chosen uniformly at random. Their indices are passed to the
        model's single ``infect``-capable component.

        Returns
        -------
        int
            Total number of agents actually passed to ``infect``.

        Raises
        ------
        RuntimeError
            If the model lacks ``people``/``patches``, or there is not exactly
            one component with an ``infect`` method (checked before any
            infection is applied).
        ValueError
            If ``target_patches`` and ``infections_per_patch`` differ in length.
        """
        if not hasattr(model, "people") or model.people is None:
            raise RuntimeError("Model does not have people attribute or it is None")
        # The infect() component is expected to update patch-level states, so
        # patches must exist even though they are not read directly here.
        if not hasattr(model, "patches") or model.patches is None:
            raise RuntimeError("Model does not have patches attribute or it is None")
        people: PeopleLaserFrame = model.people
        num_active = len(people)
        susceptible_state = model.params.states.index("S")

        # Map patch ID -> position once. Keep the first occurrence, matching list.index().
        patch_index: dict[str, int] = {}
        for position, pid in enumerate(patch_ids):
            patch_index.setdefault(pid, position)

        infector = None  # resolved lazily, only once there is something to infect
        total_seeded = 0
        for patch_id, num_infections in zip(target_patches, infections_per_patch, strict=True):
            patch_idx = patch_index[patch_id]
            # Recompute every iteration: an earlier infect() call may have changed
            # agent states (e.g. when the same patch is listed twice).
            candidates = np.where(
                np.logical_and(people.patch_id[:num_active] == patch_idx, people.state[:num_active] == susceptible_state)
            )[0]
            available = len(candidates)
            if available == 0:
                if self.verbose:
                    print(f"Warning: No susceptible individuals in patch {patch_id}, skipping")
                continue
            # Cap by the agents that can actually be infected, so the reported
            # count always equals what is handed to infect().
            actual_infections = min(num_infections, available)
            if actual_infections < num_infections and self.verbose:
                print(
                    f"Warning: Patch {patch_id} has only {available} susceptible individuals, "
                    f"seeding {actual_infections} instead of {num_infections}"
                )
            # Shuffle then take a prefix: a uniform sample without replacement.
            model.prng.shuffle(candidates)
            chosen = candidates[:actual_infections]
            if infector is None:
                # Validate uniqueness BEFORE infecting anyone, so a misconfigured
                # model cannot double-infect agents and then fail.
                infector = self._find_infector(model)
            infector.infect(model, chosen)
            total_seeded += len(chosen)
            if self.verbose:
                print(f"Seeded {len(chosen)} infections in patch {patch_id}")
        return total_seeded