Source code for emodpy_hiv.plotting.plot_inset_chart

#!/usr/bin/python

"""
This module contains methods for plotting channel reports (i.e. InsetChart).
"""

import argparse
import matplotlib.pyplot as plt
import numpy as np
import json
import sys
import os
import pylab
from math import sqrt, ceil

import emodpy_hiv.plotting.helpers as helpers

[docs]def get_raw_color(idx: int): """ When plotting the raw data as background, use a lighter color than the test data. Needs to be synchronized with get_color_name(). Args: idx: index of the plot used to select color Returns: Matplotlib basic color to use for plotting. """ colors = [('blue', 0.1), ('green', 0.1), ('cyan', 0.1), ('magenta', 0.1), ('yellow', 0.1), ('black', 0.1)] return colors[idx % len(colors)]
[docs]def get_color_name(idx: int): """ Return name of color that should be returned by getColor() given the same input value. Needs to be synchronized with get_raw_color(). Args: idx: index of the plot used to select color Returns: Name of the basic color to use for plotting in matplotlib. """ color_names = ['blue', 'green', 'cyan', 'magenta', 'yellow', 'black'] return color_names[idx % len(color_names)]
[docs]def get_list_of_channels(ref_data: dict, test_data: list[dict]): """ Returns a list of the unique channel names used in both the reference data and the test data. This should enable the display of all of the channels even when both reports do not have the same channels. Args: ref_data: channel report, json dictionary consider to contain the baseline data test_data: a list of channel reports (dictionaries) containing data to compare to Returns: Unique list of channels from all the channels in the input """ channel_titles_list = [] if ref_data is not None: channel_titles_list = list(ref_data["Channels"].keys()) for data in test_data: channel_titles_list = channel_titles_list + list(data["Channels"].keys()) channel_titles_set = set(channel_titles_list) channel_titles_list = sorted(list(channel_titles_set)) return channel_titles_list
[docs]def create_title_string(reference: str, data_filenames: list[str]): """ Returns a string that contains the input file names where the color used in plotting is included in the name. This can be used as the title of the plot. Args: reference: name of the reference data file data_filenames: a list of the test data file names Returns: A string where each file name is on its own line and includes the color to be used in plotting in the name. """ title = "" if reference is not None: title = "reference(red)=" + reference + "\n" for i, filename in enumerate(data_filenames): color_name = get_color_name(i) title = title + "test(" + color_name + ")=" + filename if i < (len(data_filenames) - 1): title = title + "\n" return title
[docs]def plot_subplot(chan_title: str, data: dict, color: str, linewidth: int, subplot: plt.Axes): if chan_title in data["Channels"]: tstep = 1 if "Simulation_Timestep" in data["Header"]: tstep = data["Header"]["Simulation_Timestep"] x_len = len(data["Channels"][chan_title]["Data"]) x_data = np.arange(0, (x_len * tstep), tstep) y_data = data["Channels"][chan_title]["Data"] subplot.plot(x_data, y_data, color=color, linewidth=linewidth) else: print("Raw Data missing channel = " + chan_title)
[docs]def plot_data(title: str, ref_data: dict = None, test_data: list[dict] = None, raw_data_list_of_lists: list[list[dict]] = None, test_filenames: list[str] = None, subplot_index_min: int = 0, subplot_index_max: int = 100, img_dir: str = None, plot_name: str = None): """ Plot the data such that there is a grid of subplots with each subplot representing a "channel" of data. Each subplot will have time on the x-axis and the units of that channel on the y-axis. Args: title: The string to put at the top of the page ref_data: A channel report dictionary whose data will be plotted in red test_data: A list of channel report dictionaries whose data will be plotted in colors other than red test_file_names: The list of file names in parallel to the test_data. subplot_index_min: The index of the first subplot to show based on the alphabetical order of the channels in the report. subplot_index_min: The index of the last subplot to show based on the alphabetical order of the channels in the report. img_dir: The name of the directory to save the images to. If not provided, it will open a window. plot_name: If provided the name of the file for the saved image. Returns: Nothing """ if test_filenames is None: test_filenames = [] if img_dir is not None: plt.figure(figsize=(24, 13)) channel_titles_list = get_list_of_channels(ref_data, test_data) num_chans = len(channel_titles_list) if subplot_index_max >= num_chans: subplot_index_max = num_chans - 1 num_chans = subplot_index_max - subplot_index_min + 1 square_root = ceil(sqrt(num_chans)) # Explicitly perform a float division here as integer division floors in Python 2.x n_figures_y = ceil(float(num_chans) / float(square_root)) n_figures_x = square_root ref_color = "red" if len(test_data) == 0: ref_color = "blue" idx = -1 for subplot_index, chan_title in enumerate(channel_titles_list): if (subplot_index < subplot_index_min) or (subplot_index > subplot_index_max): continue idx += 1 idx_x = idx % n_figures_x idx_y = int(idx / n_figures_x) try: subplot = plt.subplot2grid((n_figures_y, n_figures_x), (idx_y, idx_x)) if raw_data_list_of_lists is not None: for list_index, raw_data_list in enumerate(raw_data_list_of_lists): raw_color = get_raw_color(list_index) for raw_data in raw_data_list: plot_subplot(chan_title=chan_title, data=raw_data, color=raw_color, linewidth=1, subplot=subplot) if ref_data is not None: plot_subplot(chan_title=chan_title, data=ref_data, color=ref_color, linewidth=2, subplot=subplot) for test_idx, data in enumerate(test_data): tst_color = get_color_name(test_idx) plot_subplot(chan_title=chan_title, data=data, color=tst_color, linewidth=1, subplot=subplot) plt.setp(subplot.get_xticklabels(), fontsize='7') plt.title(chan_title, fontsize='9') except Exception as ex: print("Exception: " + str(ex)) plt.suptitle(title) plt.subplots_adjust(left=0.04, right=0.99, bottom=0.04, top=0.9, wspace=0.3, hspace=0.3) if img_dir: if not os.path.exists(img_dir): os.makedirs(img_dir) plot_name = plot_name.replace(" ", "_") plot_name = plot_name.replace("\n", "_") fn = os.path.join(img_dir, plot_name + ".png") print(fn) pylab.savefig(fn, dpi=300, orientation='landscape') else: plt.show() plt.close() return
[docs]def plot_inset_chart(dir_name: str = None, reference: str = None, comparison1: str = None, comparison2: str = None, comparison3: str = None, title: str = None, include_filenames_in_title=True, output: str = None): """ Plot the inset chart using the provided parameters. Args: dir_name: Directory containing channel reports with .json extension reference: Reference channel report filename comparison1: Comparison1 channel report filename comparison2: Comparison2 channel report filename comparison3: Comparison3 channel report filename title: Title of Plot include_filenames_in_title: If true, includes the filenames in the title (needed for testing) output: If provided, a directory will be created and images saved to the folder. If not provided, it opens windows. Returns: Nothing """ test_filenames = [] test_data = [] if dir_name is not None: test_filenames = helpers.get_filenames(dir_or_filename=dir_name, file_prefix="InsetChart", file_extension="json") if comparison1 is not None: test_filenames.append(comparison1) if comparison2 is not None: test_filenames.append(comparison2) if comparison3 is not None: test_filenames.append(comparison3) for test_fn in test_filenames: with open(test_fn, "r") as test_file: test_data.append(json.loads(test_file.read())) ref_data = None if reference is not None: with open(reference, "r") as ref_file: ref_data = json.loads(ref_file.read()) plot_name = title if plot_name is None: plot_name = "InsetChart" if title is None: title = "" if include_filenames_in_title: num_files = 0 if reference is not None: num_files = 1 num_files = num_files + len(test_filenames) if (num_files > 4) and (dir_name is not None): title = title + "\n" + dir_name else: title = title + "\n" + create_title_string(reference, test_filenames) plot_data(title=title, ref_data=ref_data, test_data=test_data, img_dir=output, plot_name=plot_name)
if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument('reference', default=None, nargs='?', help='Reference InsetChart filename') parser.add_argument('comparison1', default=None, nargs='?', help='Comparison1 InsetChart filename') parser.add_argument('comparison2', default=None, nargs='?', help='Comparison2 InsetChart filename') parser.add_argument('comparison3', default=None, nargs='?', help='Comparison3 InsetChart filename') parser.add_argument('-d', '--dir', default=None, nargs='?', help='Directory, or parent directory that contains subdirectories, of InsetChart.json files') parser.add_argument('-t', '--title', default=None, nargs='?', help='Title of Plot') parser.add_argument('-o', '--output', default=None, help='If provided, a directory will be created and images saved to the folder. If not provided, it opens windows.') args = parser.parse_args() if len(sys.argv) == 1: parser.print_help() sys.exit() plot_inset_chart(dir_name=args.dir, reference=args.reference, comparison1=args.comparison1, comparison2=args.comparison2, comparison3=args.comparison3, title=args.title, output=args.output)