Source code for laser_measles.demographics.raster_patch

"""
Raster patch generator for demographic data.
You can use this to generate initial conditions (e.g., population, MCV1 coverage) for a
laser-measles scenario.
"""

from datetime import UTC
from datetime import datetime
from pathlib import Path

import alive_progress
import numpy as np
import polars as pl
from PIL import Image
from pydantic import BaseModel
from pydantic import Field
from pydantic import field_validator
from rastertoolkit import raster_clip
from rastertoolkit import raster_clip_weighted

from laser_measles.demographics import cache
from laser_measles.demographics import shapefiles
from laser_measles.demographics.gadm import GADMShapefile


class RasterPatchParams(BaseModel):
    """Configuration consumed by ``RasterPatchGenerator``.

    Holds the scenario identifier, the region code, the shapefile path and
    the optional raster paths. Only the shapefile is checked for existence
    at model-validation time.
    """

    id: str = Field(..., description="Unique identifier for the scenario")
    region: str = Field(..., description="Country identifier (ISO3 code)")
    shapefile: str | Path = Field(..., description="Path to the shapefile")
    population_raster: str | Path | None = Field(default=None, description="Path to the population raster")
    mcv1_raster: str | Path | None = Field(default=None, description="Path to the MCV1 raster")
    mcv2_raster: str | Path | None = Field(default=None, description="Path to the MCV2 raster")

    @field_validator("shapefile")
    @classmethod
    def shapefile_exists(cls, v, info):
        """Ensure the configured shapefile exists on disk.

        Args:
            v: The ``shapefile`` value (``str`` or ``Path``).
            info: Validation context supplied by pydantic (unused).

        Returns:
            ``v`` unchanged, so the field keeps the type the caller passed.

        Raises:
            ValueError: If the path does not exist.
        """
        # Path() accepts both str and Path, so no isinstance branch is needed.
        path = Path(v)
        if not path.exists():
            raise ValueError(f"Shapefile does not exist: {path}")
        return v
class RasterPatchGenerator:
    """Build per-patch demographic tables by clipping rasters to a shapefile.

    Results are stored in the cache returned by ``cache.load_cache()`` under
    keys of the form ``"<config.id>:<name>"`` together with a creation
    timestamp, and are recomputed when a source file is newer than that
    timestamp.
    """

    def __init__(self, config: RasterPatchParams, verbose: bool = True):
        """Store the configuration and validate it immediately.

        Args:
            config: Scenario parameters (paths and identifiers).
            verbose: Stored on the instance; not read elsewhere in this class.

        Raises:
            FileNotFoundError: If the shapefile or the MCV1 raster is missing.
            ValueError: If the shapefile has no ``DOTNAME`` field.
        """
        self.config = config
        self.verbose = verbose
        self.population = None
        self.mcv1 = None
        self.mcv2 = None
        self._validate_config()

    def _get_file_mtime(self, file_path: str | Path) -> float:
        """Get the modification time of a file."""
        return Path(file_path).stat().st_mtime

    def _check_source_files_modified(self, cache_key: str) -> bool:
        """Check if any source files have been modified since cache was created.

        Args:
            cache_key: Full cache key as returned by ``get_cache_key``.

        Returns:
            ``True`` when the entry is missing, malformed, or older than one
            of its source files; ``False`` otherwise.
        """
        with cache.load_cache() as c:
            if cache_key not in c:
                return True
            cache_entry = c[cache_key]

        if not isinstance(cache_entry, dict) or "timestamp" not in cache_entry:
            return True
        cache_time = cache_entry["timestamp"]

        # The shapefile and the population raster feed every cached product.
        sources = [self.shapefile, self.config.population_raster]
        # Cache keys are "<id>:<name>", so compare against the full key; a bare
        # "mcv1" comparison never matches and would skip the MCV1 raster.
        if cache_key == self.get_cache_key("mcv1") and self.config.mcv1_raster:
            sources.append(self.config.mcv1_raster)

        # Unset (None) sources are skipped rather than passed to Path().
        return any(self._get_file_mtime(src) > cache_time for src in sources if src is not None)

    def generate_demographics(self) -> None:
        """Populate ``self.population`` and, if configured, ``self.mcv1``.

        The MCV1 table is only generated when ``config.mcv1_raster`` is set.
        """
        self._validate_shapefile()
        self.population = self.generate_population()
        if self.config.mcv1_raster is not None:
            self.mcv1 = self.generate_mcv1()

    def _validate_config(self) -> None:
        """Validate the shapefile and the optional MCV1 raster path.

        Raises:
            FileNotFoundError: If the shapefile or the MCV1 raster is missing.
            ValueError: If the shapefile has no ``DOTNAME`` field.
        """
        # _validate_shapefile performs the existence and DOTNAME checks (and
        # sets self.shapefile), so they are not duplicated here.
        self._validate_shapefile()

        # Validate mcv1_raster if provided
        if self.config.mcv1_raster is not None:
            path = Path(self.config.mcv1_raster)
            if not path.exists():
                raise FileNotFoundError(f"MCV1 raster path does not exist: {path}")

    def _validate_shapefile(self):
        """Check the shapefile exists and has a ``DOTNAME`` field.

        On success the path is stored as a ``Path`` on ``self.shapefile``.

        Raises:
            FileNotFoundError: If the shapefile path does not exist.
            ValueError: If the shapefile has no ``DOTNAME`` field.
        """
        path = Path(self.config.shapefile)
        if not path.exists():
            raise FileNotFoundError(f"Shapefile path does not exist: {path}")
        if not shapefiles.check_field(path, "DOTNAME"):
            raise ValueError(f"Shapefile {path} does not have a DOTNAME field")
        self.shapefile = path

    def get_cache_key(self, key) -> str:
        """Return the scenario-scoped cache key ``"<config.id>:<key>"``.

        Raises:
            ValueError: If ``key`` is not one of the supported names.
        """
        keys = ["shapefile", "population", "mcv1"]
        if key not in keys:
            raise ValueError(f"Invalid key: {key}\nValid keys are: {keys}")
        return f"{self.config.id}:{key}"

    def generate_population(self) -> pl.DataFrame:
        """Population, counts

        Clips the population raster to the shapefile (or reuses the cached
        result) and returns columns ``dotname``, ``lat``, ``lon``, ``pop``.

        Raises:
            ValueError: If no population raster is configured.
        """
        if self.config.population_raster is None:
            raise ValueError("Population raster path is not provided")

        cache_key = self.get_cache_key("population")
        with cache.load_cache() as c:
            if cache_key not in c or self._check_source_files_modified(cache_key):
                # clip the raster to the shapefile
                with alive_progress.alive_bar(title="Clipping population raster to shapefile"):
                    popdict = raster_clip(self.config.population_raster, self.shapefile, include_latlon=True)

                new_dict = {
                    "dotname": list(popdict),
                    "lat": [v["lat"] for v in popdict.values()],
                    "lon": [v["lon"] for v in popdict.values()],
                    "pop": [v["pop"] for v in popdict.values()],
                }
                # Store data with timestamp
                c[cache_key] = {"data": new_dict, "timestamp": datetime.now(UTC).timestamp()}
            return pl.DataFrame(c[cache_key]["data"])

    @staticmethod
    def _write_raster_copy(source: Path, target: Path, negative_fill) -> None:
        """Save a copy of ``source`` at ``target`` with negative cells replaced."""
        with Image.open(source) as raster:
            data = np.array(raster)
            data[data < 0] = negative_fill
            new_raster = Image.fromarray(data, mode=raster.mode)
            new_raster.info.update(raster.info)  # Preserve metadata
            new_raster.save(target, tiffinfo=raster.tag_v2)

    def generate_mcv1(self) -> pl.DataFrame:
        """MCV1 coverage, population weighted

        Writes temporary ``*_zeros.tif`` copies of both rasters next to the
        originals, clips them with ``raster_clip_weighted`` and caches the
        result. Returns columns ``dotname``, ``lat``, ``lon``, ``mcv1``.

        Raises:
            ValueError: If the MCV1 or the population raster is not configured.
        """
        cache_key = self.get_cache_key("mcv1")
        if self.config.mcv1_raster is None:
            raise ValueError("MCV1 raster path is not provided")
        if self.config.population_raster is None:
            raise ValueError("Population raster path is not provided")

        with cache.load_cache() as c:
            if cache_key not in c or self._check_source_files_modified(cache_key):
                # The config fields may be plain strings; with_name/stem need Path.
                values_src = Path(self.config.mcv1_raster)
                weights_src = Path(self.config.population_raster)
                new_values_raster_file = values_src.with_name(f"{values_src.stem}_zeros.tif")
                new_weight_raster_file = weights_src.with_name(f"{weights_src.stem}_zeros.tif")
                try:
                    # Value array: set negative values to zero
                    self._write_raster_copy(values_src, new_values_raster_file, 0)
                    # Weight array: set negative values to NaN.
                    # NOTE(review): NaN assignment needs a floating-point
                    # raster dtype - confirm population rasters are float.
                    self._write_raster_copy(weights_src, new_weight_raster_file, np.nan)

                    with alive_progress.alive_bar(title="Clipping MCV1 raster to shapefile"):
                        mcv_dict = raster_clip_weighted(
                            new_weight_raster_file,
                            new_values_raster_file,
                            shape_stem=self.config.shapefile,
                            include_latlon=True,
                            weight_summary_func=np.mean,
                        )
                finally:
                    # remove the rasters, even when clipping failed part-way
                    new_weight_raster_file.unlink(missing_ok=True)
                    new_values_raster_file.unlink(missing_ok=True)

                # Store data with timestamp
                c[cache_key] = {"data": mcv_dict, "timestamp": datetime.now(UTC).timestamp()}
            mcv_dict = c[cache_key]["data"]

        new_dict = {
            "dotname": list(mcv_dict),
            "lat": [v["lat"] for v in mcv_dict.values()],
            "lon": [v["lon"] for v in mcv_dict.values()],
            "mcv1": [v["val"] for v in mcv_dict.values()],
        }
        return pl.DataFrame(new_dict)

    def clear_cache(self) -> None:
        """Delete every cache entry that belongs to this scenario."""
        # Include the ":" separator so that id "NGA" does not match "NGA2:...".
        prefix = f"{self.config.id}:"
        with cache.load_cache() as c:
            # Collect the keys first; entries are deleted only after iteration.
            stale = [k for k in c.iterkeys() if isinstance(k, str) and k.startswith(prefix)]
            for k in stale:
                del c[k]

    def generate_birth_rates(self) -> pl.DataFrame:
        """Placeholder: not implemented; the body is empty."""
        ...

    def generate_mortality_rates(self) -> pl.DataFrame:
        """Placeholder: not implemented; the body is empty."""
        ...

    def _add_dotname(self) -> None:
        """Placeholder: not implemented; the body is empty."""
        ...
if __name__ == "__main__":
    # Demo: download the Nigeria GADM shapefile and build the demographics.
    gadm = GADMShapefile("NGA")
    gadm.clear_cache()
    gadm.download()
    gadm.add_dotnames()

    # RasterPatchParams requires `id`. The former start_year / end_year /
    # granularity / patch_size_km keyword arguments are not fields of the
    # model and were dropped.
    config = RasterPatchParams(
        id="NGA",
        region="NGA",
        shapefile=gadm.get_shapefile_path(2),
        # The field is `population_raster` (was passed as the unknown name
        # `population_raster_path`).
        # NOTE(review): gadm.shapefile_dir looks like a directory, not a
        # raster file - replace with the actual population raster path.
        population_raster=gadm.shapefile_dir,
    )
    generator = RasterPatchGenerator(config)
    generator.generate_demographics()