"""
Raster patch generator for demographic data.
You can use this to generate initial conditions (e.g., population, MCV1 coverage) for a
laser-measles scenario.
"""
from datetime import UTC
from datetime import datetime
from pathlib import Path
import alive_progress
import numpy as np
import polars as pl
from PIL import Image
from pydantic import BaseModel
from pydantic import Field
from pydantic import field_validator
from rastertoolkit import raster_clip
from rastertoolkit import raster_clip_weighted
from laser_measles.demographics import cache
from laser_measles.demographics import shapefiles
from laser_measles.demographics.gadm import GADMShapefile
class RasterPatchParams(BaseModel):
    """Input parameters for :class:`RasterPatchGenerator`.

    Each field is described by its ``Field(description=...)``. Only
    ``shapefile`` is validated by this model (it must exist on disk); the
    raster paths are accepted as given.
    """

    id: str = Field(..., description="Unique identifier for the scenario")
    region: str = Field(..., description="Country identifier (ISO3 code)")
    shapefile: str | Path = Field(..., description="Path to the shapefile")
    population_raster: str | Path | None = Field(default=None, description="Path to the population raster")
    mcv1_raster: str | Path | None = Field(default=None, description="Path to the MCV1 raster")
    mcv2_raster: str | Path | None = Field(default=None, description="Path to the MCV2 raster")

    @field_validator("shapefile")
    @classmethod
    def shapefile_exists(cls, v, info):
        """Reject a ``shapefile`` value that does not point to an existing path.

        Args:
            v: The ``shapefile`` value being validated (``str`` or ``Path``).
            info: Validation context supplied by pydantic; not used here.

        Returns:
            ``v`` unchanged, so the caller-supplied type is preserved.

        Raises:
            ValueError: If the path does not exist.
        """
        path = Path(v)
        if not path.exists():
            raise ValueError(f"Shapefile does not exist: {path}")
        return v
[docs]
class RasterPatchGenerator:
def __init__(self, config: RasterPatchParams, verbose: bool = True):
self.config = config
self.verbose = verbose
self.population = None
self.mcv1 = None
self.mcv2 = None
self._validate_config()
def _get_file_mtime(self, file_path: str | Path) -> float:
"""Get the modification time of a file."""
return Path(file_path).stat().st_mtime
def _check_source_files_modified(self, cache_key: str) -> bool:
"""Check if any source files have been modified since cache was created."""
with cache.load_cache() as c:
if cache_key not in c:
return True
cache_entry = c[cache_key]
if not isinstance(cache_entry, dict) or "timestamp" not in cache_entry:
return True
cache_time = cache_entry["timestamp"]
# Check shapefile modification time
if self._get_file_mtime(self.shapefile) > cache_time:
return True
# Check population raster modification time
if self._get_file_mtime(self.config.population_raster) > cache_time:
return True
# Check MCV1 raster if it exists
if cache_key == "mcv1" and self.config.mcv1_raster:
if self._get_file_mtime(self.config.mcv1_raster) > cache_time:
return True
return False
[docs]
def generate_demographics(self) -> None:
self._validate_shapefile()
self.population = self.generate_population()
if self.config.mcv1_raster is not None:
self.mcv1 = self.generate_mcv1()
def _validate_config(self) -> None:
if not shapefiles.check_field(self.config.shapefile, "DOTNAME"):
raise ValueError(f"Shapefile {self.config.shapefile_path} does not have a DOTNAME field")
# Validate mcv1_raster_path if provided
if self.config.mcv1_raster is not None:
path = Path(self.config.mcv1_raster) if isinstance(self.config.mcv1_raster, str) else self.config.mcv1_raster
if not path.exists():
raise FileNotFoundError(f"MCV1 raster path does not exist: {path}")
self._validate_shapefile()
def _validate_shapefile(self):
""" """
path = Path(self.config.shapefile) if isinstance(self.config.shapefile, str) else self.config.shapefile
if not path.exists():
raise FileNotFoundError(f"Shapefile path does not exist: {path}")
if not shapefiles.check_field(path, "DOTNAME"):
raise ValueError(f"Shapefile {path} does not have a DOTNAME field")
self.shapefile = path
[docs]
def get_cache_key(self, key) -> str:
keys = ["shapefile", "population", "mcv1"]
if key not in keys:
raise ValueError(f"Invalid key: {key}\nValid keys are: {keys}")
return f"{self.config.id}" + ":" + key
[docs]
def generate_population(self) -> pl.DataFrame:
"""Population, counts"""
cache_key = self.get_cache_key("population")
with cache.load_cache() as c:
if cache_key not in c or self._check_source_files_modified(cache_key):
# clip the raster to the shapefile
with alive_progress.alive_bar(title="Clipping population raster to shapefile"):
popdict = raster_clip(self.config.population_raster, self.shapefile, include_latlon=True)
new_dict = {"dotname": [], "lat": [], "lon": [], "pop": []}
for k, v in popdict.items():
new_dict["dotname"].append(k)
new_dict["lat"].append(v["lat"])
new_dict["lon"].append(v["lon"])
new_dict["pop"].append(v["pop"])
# Store data with timestamp
c[cache_key] = {"data": new_dict, "timestamp": datetime.now(UTC).timestamp()}
return pl.DataFrame(c[cache_key]["data"])
[docs]
def generate_mcv1(self) -> pl.DataFrame:
"""MCV1 coverage, population weighted"""
cache_key = self.get_cache_key("mcv1")
if self.config.mcv1_raster is None:
raise ValueError("MCV1 raster path is not provided")
with cache.load_cache() as c:
if cache_key not in c or self._check_source_files_modified(cache_key):
# Value array: Set negative values to zero
new_values_raster_file = self.config.mcv1_raster.with_name(f"{self.config.mcv1_raster.stem}_zeros.tif")
with Image.open(self.config.mcv1_raster) as raster:
data = np.array(raster)
data[data < 0] = 0
new_raster = Image.fromarray(data, mode=raster.mode)
new_raster.info.update(raster.info) # Preserve metadata
new_raster.save(new_values_raster_file, tiffinfo=raster.tag_v2)
# Weight array: Set negative values to zero
new_weight_raster_file = self.config.population_raster.with_name(f"{self.config.population_raster.stem}_zeros.tif")
with Image.open(self.config.population_raster) as raster:
data = np.array(raster)
data[data < 0] = np.nan
new_raster = Image.fromarray(data, mode=raster.mode)
new_raster.info.update(raster.info) # Preserve metadata
new_raster.save(new_weight_raster_file, tiffinfo=raster.tag_v2)
with alive_progress.alive_bar(title="Clipping MCV1 raster to shapefile"):
mcv_dict = raster_clip_weighted(
new_weight_raster_file,
new_values_raster_file,
shape_stem=self.config.shapefile,
include_latlon=True,
weight_summary_func=np.mean,
)
# remove the rasters
new_weight_raster_file.unlink()
new_values_raster_file.unlink()
# Store data with timestamp
c[cache_key] = {"data": mcv_dict, "timestamp": datetime.now(UTC).timestamp()}
mcv_dict = c[cache_key]["data"]
new_dict = {"dotname": [], "lat": [], "lon": [], "mcv1": []}
for k, v in mcv_dict.items():
new_dict["dotname"].append(k)
new_dict["lat"].append(v["lat"])
new_dict["lon"].append(v["lon"])
new_dict["mcv1"].append(v["val"])
return pl.DataFrame(new_dict)
[docs]
def clear_cache(self) -> None:
with cache.load_cache() as c:
for k in c.iterkeys():
if k.startswith(self.config.id):
del c[k]
[docs]
def generate_birth_rates(self) -> pl.DataFrame: ...
[docs]
def generate_mortality_rates(self) -> pl.DataFrame: ...
def _add_dotname(self) -> None: ...
if __name__ == "__main__":
    # Example run: fetch the Nigeria GADM shapefile, add DOTNAMEs, then build
    # the demographic tables for it.
    gadm = GADMShapefile("NGA")
    gadm.clear_cache()
    gadm.download()
    gadm.add_dotnames()
    config = RasterPatchParams(
        # ``id`` is a required field; it namespaces this scenario's cache entries.
        id="NGA",
        region="NGA",
        shapefile=gadm.get_shapefile_path(2),
        # The model field is ``population_raster`` (not ``population_raster_path``).
        # NOTE(review): ``shapefile_dir`` looks like a directory rather than a
        # raster file -- point this at the actual population raster.
        population_raster=gadm.shapefile_dir,
    )
    generator = RasterPatchGenerator(config)
    generator.generate_demographics()