Source code for emod_api.channelreports.utils

"""
Helper functions, primarily for property reports, which are channel reports.
"""

import json
from pathlib import Path
from typing import Dict, List, Optional, Union

import matplotlib.pyplot as plt
import numpy as np

from emod_api.channelreports.channels import ChannelReport

__all__ = [
    "property_report_to_csv",
    "read_json_file",
    "get_report_channels",
    "_validate_property_report_channels",
    "_validate_property_report_ips",
    "accumulate_channel_data",
    "__get_trace_name",
    "save_to_csv",
    "plot_traces",
    "__index_for",
    "__title_for"]


def property_report_to_csv(
    source_file: Union[str, Path],
    csv_file: Union[str, Path],
    channels: Optional[List[str]] = None,
    groupby: Optional[List[str]] = None,
    transpose: bool = False,
) -> None:
    """
    Write a property report to a CSV formatted file. Optionally select a subset
    of the available channels. Optionally "roll up" IP:value sub-channels into a
    "parent" IP.

    Args:
        source_file: filename of the property report
        csv_file: filename of the CSV formatted result
        channels: list of channels to output; None results in writing _all_ channels to the output
        groupby: list of IPs into which to aggregate the remaining IPs; None indicates no grouping, [] indicates _all_ IPs aggregated
        transpose: write channels as columns rather than rows
    """

    json_data = read_json_file(Path(source_file))
    channel_data = get_report_channels(json_data)

    if channels is None:
        channels = sorted({key.split(":")[0] for key in channel_data})
    elif isinstance(channels, str):
        channels = [channels]

    if isinstance(groupby, str):
        groupby = [groupby]

    _validate_property_report_channels(channels, channel_data)
    _validate_property_report_ips(groupby, channel_data)

    trace_values = accumulate_channel_data(channels, False, groupby, channel_data)

    save_to_csv(trace_values, csv_file, transpose)

    return

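# Usage sketch (illustrative only; the report filename, the "Infected" channel,
# and the "Age_Bin" IP below are assumptions about the caller's report, not
# something this module provides):
#
#     property_report_to_csv(
#         source_file="PropertyReport.json",
#         csv_file="infected_by_age.csv",
#         channels=["Infected"],
#         groupby=["Age_Bin"],
#         transpose=True,
#     )
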
def read_json_file(filename: Union[str, Path]) -> Dict:

    with Path(filename).open("r", encoding="utf-8") as file:
        json_data = json.load(file)

    return json_data

def get_report_channels(json_data: Dict) -> Dict:

    try:
        channel_data = json_data['Channels']
    except KeyError as exc:
        raise KeyError("Didn't find 'Channels' in JSON data.") from exc

    return channel_data

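# For reference, a property report is expected to look roughly like the sketch
# below (the channel name, IP values, and data are illustrative, not taken from
# a real report); get_report_channels() returns the dictionary under "Channels":
#
#     {
#         "Header": { ... },
#         "Channels": {
#             "Infected:Age_Bin:Age_Bin_Property_From_0_To_20,QualityOfCare:High": {
#                 "Units": "",
#                 "Data": [0, 1, 3, ...]
#             },
#             ...
#         }
#     }
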
def _validate_property_report_channels(channels, channel_data) -> None:

    if channels:
        keys = set(map(lambda name: name.split(":", 1)[0], channel_data))
        not_found = [name for name in channels if name not in keys]
        if not_found:
            print("Valid channel names:")
            print("\n".join(keys))
            raise ValueError(f"Specified channel(s) - {not_found} - is/are not valid channel names.")

    return


def _validate_property_report_ips(groupby, channel_data) -> None:

    if groupby:
        first = next(iter(channel_data))
        ip_string = first.split(":", 1)[1]
        ips = [kvp.split(":")[0] for kvp in ip_string.split(",")]
        not_found = [ip for ip in groupby if ip not in ips]
        if not_found:
            print("Valid IPs:")
            print("\n".join(ips))
            raise ValueError(f"Specified groupby IP(s) - {not_found} - is/are not valid IP names.")

    return

def accumulate_channel_data(channels: List[str], verbose: bool, groupby: List[str], channel_data: Dict) -> Dict[str, np.ndarray]:
    """
    Extract selected channel(s) from property report data.

    Aggregate on groupby IP(s), if provided, otherwise one trace per unique
    IP:value pair (e.g., "QualityOfCare:High"), per main channel (e.g., "Infected").

    Args:
        channels: names of the channels to extract
        verbose: output some "debugging"/progress information if true
        groupby: IP(s) under which to aggregate other IP:value pairs
        channel_data: data for channels keyed on channel name

    Returns:
        dictionary of aggregated trace data, keyed on trace name
    """

    trace_values = {}

    pool_keys = sorted(channel_data)
    name_ip_pairs = map(lambda key: tuple(key.split(":", 1)), pool_keys)
    name_ip_pairs_to_process = filter(lambda p: p[0] in channels, name_ip_pairs)

    for (channel_title, key_value_pairs) in name_ip_pairs_to_process:

        if verbose:
            print(f"Processing channel '{channel_title}:{key_value_pairs}'")

        key_value_pairs = key_value_pairs.split(',')
        trace_name = __get_trace_name(channel_title, key_value_pairs, groupby)
        trace_data = np.array(channel_data[f"{channel_title}:{','.join(key_value_pairs)}"]['Data'], dtype=np.float32)

        if trace_name not in trace_values:
            if verbose:
                print(f"New trace: '{trace_name}'")
            trace_values[trace_name] = trace_data
        else:
            if verbose:
                print(f"Add to trace: '{trace_name}'")
            trace_values[trace_name] += trace_data

    return trace_values

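# A small worked example of the aggregation above (the channel keys and data
# are made up for illustration):
#
#     channel_data = {
#         "Infected:QualityOfCare:High,Risk:Low":  {"Data": [1, 2, 3]},
#         "Infected:QualityOfCare:High,Risk:High": {"Data": [4, 5, 6]},
#     }
#     traces = accumulate_channel_data(["Infected"], False, ["QualityOfCare"], channel_data)
#     # traces == {"Infected:QualityOfCare:High": array([5., 7., 9.], dtype=float32)}
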
def __get_trace_name(channel_title: str, key_value_pairs: List[str], groupby: List[str]) -> str:
    """
    Return the "canonical" trace name for a given channel, IP:value list, and groupby list.

    Since we may be aggregating by IP values, the trace name may not equal any
    particular channel name.

    Example:

        title = "Infected"
        key_value_pairs = ["Age_Bin:Age_Bin_Property_From_0_To_20", "QualityOfCare:High", "QualityOfCare1:High", "QualityOfCare2:High"]

        groupby = None
        returns "Infected:Age_Bin:Age_Bin_Property_From_0_To_20,QualityOfCare:High,QualityOfCare1:High,QualityOfCare2:High"

        groupby = ["Age_Bin"]
        returns "Infected:Age_Bin:Age_Bin_Property_From_0_To_20"

        groupby = ["Age_Bin", "QualityOfCare"]
        returns "Infected:Age_Bin:Age_Bin_Property_From_0_To_20,QualityOfCare:High"

        groupby = []
        returns "Infected"
    """

    # trace name will have channel title and any property:value pairs
    # which aren't being grouped
    trace_name = channel_title + ':'

    if groupby is None:
        trace_name = f"{channel_title}:{','.join(key_value_pairs)}"
    elif len(groupby) > 0:
        kvps = filter(lambda pair: pair.split(":")[0] in groupby, key_value_pairs)
        trace_name = f"{channel_title}:{','.join(kvps)}"
    else:
        trace_name = channel_title

    return trace_name

def save_to_csv(trace_values: Dict[str, np.ndarray], filename: Union[str, Path], transpose: bool = False) -> None:
    """
    Save property report to CSV. Uses the underlying ChannelReport.to_csv() function.

    Args:
        trace_values: full set of available channels, keyed on channel name
        filename: destination file for CSV data
        transpose: write channels as columns rather than rows
    """

    report = ChannelReport()
    for channel, data in trace_values.items():
        report.channels[channel] = data

    # by default, to_csv() uses _all_ the channels we just added
    report.to_csv(Path(filename), transpose=transpose)

    return

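# Minimal sketch of calling save_to_csv() directly with already-aggregated
# traces (the trace name and output path are hypothetical):
#
#     save_to_csv(
#         {"Infected:QualityOfCare:High": np.array([5.0, 7.0, 9.0], dtype=np.float32)},
#         "infected_high_qoc.csv",
#     )
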
def plot_traces(
    trace_values: Dict[str, np.ndarray],
    norm_values: Optional[Union[int, np.ndarray]],
    overlay: bool,
    channels: List[str],
    title: str,
    legend: bool,
) -> plt.Figure:
    """
    Plot trace data. One subplot per channel unless overlaying all variations
    of rolled-up IP(s) is requested.

    A trace (like an old-time pen and ink EKG) may represent the aggregation of
    several IP values, so a trace may not equal any particular channel's data.

    Args:
        trace_values: channel data, keyed on channel name
        norm_values: normalization data for the channels
        overlay: whether or not to overlay all variations of a given channel on one subplot
        channels: selection of channel names to plot
        title: plot title
        legend: whether or not to include a legend on the plots

    Returns:
        plt.Figure
    """

    if len(trace_values) == 0:
        print("Didn't find requested channel(s) in property report.")
        return

    if not overlay:
        plot_count = len(trace_values)
    else:
        plot_count = len(channels)

    normalize = norm_values is not None
    if normalize:
        plot_count *= 2

    figure = plt.figure(title, figsize=(16, 9), dpi=300)
    trace_keys = sorted(trace_values)

    # plotting here
    for trace_name in trace_keys:
        plot_index = __index_for(trace_name, channels, trace_keys, normalize, overlay)
        plt.subplot(plot_count, 1, plot_index)
        plt.plot(trace_values[trace_name], label=trace_name)
        if normalize:
            plt.subplot(plot_count, 1, plot_index + 1)
            plt.ylim((0.0, 1.0))    # yes, this takes a tuple
            plt.plot(trace_values[trace_name] / norm_values, label=trace_name)

    # make it pretty
    _ = plt.subplot(plot_count, 1, 1)
    for trace_name in trace_keys:
        plot_index = __index_for(trace_name, channels, trace_keys, normalize, overlay)
        plot_title = __title_for(trace_name, channels, overlay)
        plt.subplot(plot_count, 1, plot_index)
        plt.title(plot_title)
        if legend:
            plt.legend()
        if normalize:
            plt.subplot(plot_count, 1, plot_index + 1)
            plt.title(f"{plot_title} normalized by 'Statistical Population'")
            if legend:
                plt.legend()

    plt.tight_layout()

    return figure

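# Usage sketch combining accumulate_channel_data() and plot_traces(). The report
# filename and the presence of "Infected" and "Statistical Population" channels
# are assumptions about the caller's property report, not enforced here:
#
#     json_data = read_json_file("PropertyReport.json")
#     channel_data = get_report_channels(json_data)
#     traces = accumulate_channel_data(["Infected"], False, [], channel_data)
#     pops = accumulate_channel_data(["Statistical Population"], False, [], channel_data)
#     norm = next(iter(pops.values()))
#     fig = plot_traces(traces, norm, overlay=False, channels=["Infected"], title="Infected", legend=True)
#     fig.savefig("infected.png")
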
def __index_for(trace_name: str, channels: List[str], trace_keys: List[str], normalize: bool, overlay: bool) -> int:

    if overlay:
        # all pools of the same channel overlaid
        index = 0
        for channel in channels:
            if channel in trace_name:
                break
            index += 1
    else:
        # each trace separate
        index = trace_keys.index(trace_name)

    # if we're normalizing, there's a normalized trace per regular trace
    if normalize:
        index *= 2

    # matplotlib is 1-based (like MATLAB)
    return index + 1


def __title_for(trace_name: str, channels: List[str], overlay: bool):

    # use the channel name when overlaying, otherwise the full trace name
    if overlay:
        for channel in channels:
            if channel in trace_name:
                title = channel
                break
    else:
        title = trace_name

    return title