from collections import defaultdict
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import polars as pl
from matplotlib.patches import Polygon
from shapefile import Reader
from shapefile import Writer
[docs]
def check_field(path: str | Path, field_name: str) -> bool:
    """
    Report whether a shapefile defines an attribute field with the given name.

    Args:
        path: Path to the shapefile.
        field_name: Attribute field name to look for (exact, case-sensitive match).

    Returns:
        True if ``field_name`` is one of the shapefile's attribute fields, else False.
    """
    shapefile_path = path if not isinstance(path, str) else Path(path)
    with Reader(shapefile_path) as reader:
        # reader.fields[0] is pyshp's DeletionFlag bookkeeping entry, not an attribute.
        attribute_fields = reader.fields[1:]
        found = any(descriptor[0] == field_name for descriptor in attribute_fields)
    return found
[docs]
def add_dotname(
    path: str | Path,
    dot_name_fields: list[str],
    dotname_symbol: str = ":",
    append_suffix: str = "dotname",
    inplace: bool = False,
    field_name: str = "DOTNAME",
) -> None:
    """
    Add a DOTNAME attribute field to a shapefile.

    The DOTNAME of each record is built by lower-casing the values of
    ``dot_name_fields`` (in order) and joining them with ``dotname_symbol``;
    e.g. values ``"A"`` and ``"B"`` with the default symbol give ``"a:b"``.
    A copy of the shapefile with the extra field is written next to the
    original as ``<stem>_<append_suffix>.*``. When ``inplace`` is True, that
    copy then replaces the original files.

    Args:
        path: Path to the shapefile (the ``.shp`` file or its base name).
        dot_name_fields: Attribute fields combined, in order, into the DOTNAME.
        dotname_symbol: Separator placed between the field values.
        append_suffix: Suffix appended to the stem of the output files.
        inplace: If True, move the output files over the original ones.
        field_name: Name of the attribute field to add.

    Raises:
        ValueError: If any of ``dot_name_fields`` is missing, if the resulting
            DOTNAMEs are not unique, or if the longest DOTNAME does not fit in
            a dBASE character field (254 bytes).

    Note:
        If ``field_name`` already exists in the shapefile, nothing is done.
    """
    # dBASE character fields hold at most 254 bytes, and pyshp truncates longer
    # values without warning. The new field is therefore sized to its longest
    # value (never narrower than the historical width of 50).
    min_width = 50
    max_width = 254
    encoding = "utf-8"

    def make_temp_path(path: Path, append_suffix: str, suffix: str) -> Path:
        return path.with_name(path.stem + "_" + append_suffix + suffix)

    def make_dotname(record) -> str:
        # str() lets non-text attributes (e.g. numeric codes) take part too.
        return dotname_symbol.join(str(record[field]).lower() for field in dot_name_fields)

    # Resolve shapefile
    path = Path(path) if isinstance(path, str) else path
    output_path = make_temp_path(path, append_suffix=append_suffix, suffix=path.suffix)

    with Reader(path) as sf:
        # sf.fields[0] is pyshp's DeletionFlag entry, not a real attribute field.
        attribute_fields = sf.fields[1:]
        fields = [field[0] for field in attribute_fields]
        if not all(field in fields for field in dot_name_fields):
            raise ValueError(f"Dot name fields {dot_name_fields} not found in shapefile {path}. Choices are {fields}")
        if field_name in fields:
            return  # already present: leave the shapefile untouched

        dotnames = [make_dotname(shaperec.record) for shaperec in sf.iterShapeRecords()]
        # check that all dotnames are unique
        if len(dotnames) != len(set(dotnames)):
            raise ValueError(f"Dotnames are not unique in shapefile {path}")

        # The field width is measured in encoded bytes, using the Writer's encoding.
        width = max((len(name.encode(encoding)) for name in dotnames), default=0)
        if width > max_width:
            raise ValueError(
                f"Longest {field_name} value in shapefile {path} is {width} bytes, "
                f"which exceeds the {max_width}-byte limit of a dBASE character field"
            )

        # Write the copy: original fields + geometry, plus the new field.
        with Writer(output_path, shapeType=sf.shapeType, encoding=encoding) as w:
            for field in attribute_fields:
                w.field(*field)
            w.field(field_name, "C", max(width, min_width))
            for shaperec, dotname in zip(sf.iterShapeRecords(), dotnames):
                w.record(*shaperec.record, dotname)
                w.shape(shaperec.shape)

    # pyshp's Writer does not produce a .prj file. Copy the original one so the
    # new shapefile keeps its coordinate reference system.
    source_prj = path.with_suffix(".prj")
    if source_prj.exists():
        make_temp_path(path, append_suffix=append_suffix, suffix=".prj").write_bytes(source_prj.read_bytes())

    # Move the new files over the originals. This runs after the Reader is
    # closed, so no file being replaced is still open.
    if inplace:
        for suffix in (".shp", ".shx", ".dbf", ".prj", ".cpg"):
            temp_path = make_temp_path(path, append_suffix=append_suffix, suffix=suffix)
            if temp_path.exists():
                # Path.replace overwrites an existing target on every platform.
                temp_path.replace(path.with_suffix(suffix))
[docs]
def get_shapefile_dataframe(shapefile_path: str | Path) -> pl.DataFrame:
    """
    Load a shapefile's attribute table and geometries into a DataFrame.

    Args:
        shapefile_path: The path to the shapefile. This can be the ``.shp`` file
            itself or the base name without an extension, which pyshp also accepts.

    Returns:
        A DataFrame with one column per attribute field of the shapefile (in
        field order) and a ``shape`` column of dtype ``pl.Object``. Each row holds
        the pyshp shape of the corresponding record, in file order.

    Raises:
        FileNotFoundError: If neither the path nor its ``.shp`` sibling exists.
    """
    shapefile_path = Path(shapefile_path) if isinstance(shapefile_path, str) else shapefile_path
    # pyshp resolves "<base>" to "<base>.shp", so accept that form as well.
    if not (shapefile_path.exists() or shapefile_path.with_suffix(".shp").exists()):
        raise FileNotFoundError(f"Shapefile not found at {shapefile_path}")

    with Reader(shapefile_path) as sf:
        # Pre-seed one column per attribute field (sf.fields[0] is pyshp's
        # DeletionFlag entry), so an empty shapefile still keeps its columns.
        columns: dict[str, list] = {field[0]: [] for field in sf.fields[1:]}
        shapes = []
        for shaperec in sf.iterShapeRecords():
            for key, value in shaperec.record.as_dict().items():
                columns.setdefault(key, []).append(value)
            shapes.append(shaperec.shape)

    df = pl.DataFrame(columns)
    # Shapes are arbitrary Python objects, so store them in an Object column.
    return df.with_columns(pl.Series(name="shape", values=shapes, dtype=pl.Object))
[docs]
def plot_shapefile_dataframe(df: pl.DataFrame, ax: plt.Axes | None = None, plot_kwargs: dict | None = None) -> plt.Figure:
    """
    Draw every shape in ``df["shape"]`` onto a matplotlib Axes.

    Each part (ring) of each shape is added as its own
    ``matplotlib.patches.Polygon``. Interior rings are therefore drawn as
    separate patches, not cut out of the exterior. The axes limits are fitted
    to the bounding box of all points, and the axes are hidden. When every
    y value lies within [-90, 90], the coordinates are treated as degrees and
    the aspect ratio is stretched by 1/cos(mean latitude); otherwise an equal
    aspect is used.

    Args:
        df: DataFrame with a ``shape`` column of pyshp shapes, e.g. as returned by
            ``get_shapefile_dataframe``.
        ax: Axes to draw on. A new figure and axes are created when None.
        plot_kwargs: Extra keyword arguments for ``Polygon``. They override the
            defaults: closed, filled white, with a black edge of width 1.0.

    Returns:
        The figure that owns ``ax``.
    """
    if ax is None:
        _, ax = plt.subplots()

    polygon_kwargs = {
        "closed": True,
        "fill": True,
        "edgecolor": "black",
        "linewidth": 1.0,
        "facecolor": "white",
    }
    if plot_kwargs is not None:
        polygon_kwargs.update(plot_kwargs)

    xmin = ymin = float("inf")
    xmax = ymax = float("-inf")
    for shape in df["shape"]:
        if not shape.points:
            continue  # empty / NULL geometry: nothing to draw and no extent
        # shape.parts holds each ring's start index; the end bounds the last ring.
        bounds = [*shape.parts, len(shape.points)]
        for start, stop in zip(bounds[:-1], bounds[1:]):
            ax.add_patch(Polygon(shape.points[start:stop], **polygon_kwargs))
        xs = [point[0] for point in shape.points]
        ys = [point[1] for point in shape.points]
        xmin, xmax = min(xmin, min(xs)), max(xmax, max(xs))
        ymin, ymax = min(ymin, min(ys)), max(ymax, max(ys))

    ax.set_axis_off()
    if not np.isfinite([xmin, xmax, ymin, ymax]).all():
        # No shape had any points; leave the default limits and aspect.
        return ax.figure

    ax.set_xlim(xmin, xmax)
    ax.set_ylim(ymin, ymax)
    # Longitude degrees shrink with latitude; correct for it only when the
    # coordinates plausibly are degrees. Otherwise (e.g. projected metres) the
    # cosine is meaningless and can even be negative.
    lat_scale = np.cos(np.mean([ymin, ymax]) * np.pi / 180)
    if -90.0 <= ymin and ymax <= 90.0 and lat_scale > 0:
        ax.set_aspect(1 / lat_scale)
    else:
        ax.set_aspect("equal")
    return ax.figure
if __name__ == "__main__":
    import sys

    # Shapefile to plot: the first command-line argument. Without one, fall
    # back to the original author's local example path.
    default_path = "/home/krosenfeld/code/laser-measles/examples/expanding_kano/gadm41_NGA_shp/gadm41_NGA_2_shp"
    shapefile = sys.argv[1] if len(sys.argv) > 1 else default_path
    df = get_shapefile_dataframe(shapefile)
    plot_shapefile_dataframe(df)
    # Without this the figure is built and discarded when the script exits.
    plt.show()