"""Plotting and log-likelihood reporting for calibration site data (module idmtools_calibra.plotters.site_data_plotter)."""

# flake8: noqa E402
import logging
import os
import matplotlib.pyplot as plt

import matplotlib.cm as cm
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from idmtools.core.enums import ItemType
from idmtools_calibra.iteration_state import IterationState
from idmtools_calibra.plotters.base_plotter import BasePlotter
from idmtools_calibra.process_state import StatusPoint

# Global seaborn style for subsequent plots: white background with thin (0.5) axis lines.
sns.set_style(style='white', rc={'axes.linewidth': 0.5})

# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)

class SiteDataPlotter(BasePlotter):
    """
    Compare calibration samples with per-site reference data and record sample log-likelihoods.

    For each site/analyzer pair the plotter can draw the best-ranked samples one per file and all
    samples together (colored by score). Each time ``visualize`` runs at the plot stage it also
    appends the iteration's results, with simulation ids and output directories, to a CSV file.
    Per-analyzer plotting is currently disabled inside ``visualize``.

    ``self.all_results`` is read but not set here; presumably it is provided by ``BasePlotter``.
    """

    # Plots are written as PNG. '.pdf' is also removed during cleanup because earlier cleanup code
    # targeted PDF files, so either variant left on disk is deleted.
    _PLOT_EXTENSIONS = ('.png', '.pdf')

    def __init__(self, combine_sites: bool = True, num_to_plot: int = 5, ll_all_name: str = 'LL_all.csv'):
        """
        Args:
            combine_sites: If True, samples are ranked by the combined 'total' column; otherwise
                each site's samples are ranked by that site's '<site>_total' column.
            num_to_plot: Number of best-ranked samples plotted (and cleaned up) individually.
            ll_all_name: File name, inside the plot directory, of the log-likelihood CSV.
        """
        super().__init__(combine_sites)
        self.num_to_plot = num_to_plot
        self.ll_all_name = ll_all_name

    @property
    def directory(self):
        """Directory in which plots and the log-likelihood CSV are written."""
        return self.get_plot_directory()

    def _sorted_results(self, sort_column):
        """Return ``all_results`` sorted by ``sort_column`` (best first), re-indexed so index == rank."""
        return self.all_results.sort_values(by=sort_column, ascending=False).reset_index()

    def _ranking_column(self, site_name):
        """Column used to rank samples for ``site_name`` during cleanup."""
        if self.combine_sites:
            return 'total'
        site_column = f'{site_name}_total'
        # '<site>_total' is presumably added by combine_by_site; fall back to the overall total
        # when it is absent so cleanup never fails on a missing column.
        return site_column if site_column in self.all_results.columns else 'total'

    def _remove_plot(self, fname):
        """Delete every on-disk variant of the extension-less plot path ``fname``; log failures."""
        for ext in self._PLOT_EXTENSIONS:
            plot_path = fname + ext
            try:
                os.remove(plot_path)
            except FileNotFoundError:
                pass  # nothing to clean up for this extension
            except OSError:
                logger.error(f"Failed to delete {plot_path}")

    # NOTE: self.iteration_state.analyzer_list does not keep site information, so analyzers are
    # matched by their '<site>_<analyzer>' uid; this assumes all analyzers have distinct names.
    def get_site_analyzer(self, site_name, analyzer_name):
        """
        Find the analyzer object for ``analyzer_name`` at ``site_name``.

        Raises:
            Exception: If the site is unknown or no analyzer has uid '<site_name>_<analyzer_name>'.
        """
        if site_name in self.site_analyzer_names:
            site_analyzer = f'{site_name}_{analyzer_name}'
            for analyzer in self.iteration_state.analyzer_list:
                if analyzer.uid == site_analyzer:
                    return analyzer
        raise Exception(f'Unable to find analyzer={analyzer_name} for site={site_name}')

    def get_analyzer_data(self, iteration, site_name, analyzer_name):
        """Return the stored data of analyzer '<site_name>_<analyzer_name>' for ``iteration``."""
        site_analyzer = f'{site_name}_{analyzer_name}'
        return IterationState.restore_state(iteration).analyzers[site_analyzer]

    def visualize(self, iteration_state):
        """
        Process ``iteration_state`` and append its results to the log-likelihood CSV.

        Does nothing unless the iteration status is ``StatusPoint.plot``. Errors while ranking are
        logged and re-raised; a failure to write the CSV is logged and skipped.
        """
        self.iteration_state = iteration_state
        self.site_analyzer_names = iteration_state.site_analyzer_names

        if self.iteration_state.status != StatusPoint.plot:
            return  # Only plot once results are available

        try:
            for site_name, analyzer_names in self.site_analyzer_names.items():
                if self.combine_sites:
                    sorted_results = self._sorted_results('total')
                else:
                    self.combine_by_site(site_name, analyzer_names, self.all_results)
                    sorted_results = self._sorted_results(f'{site_name}_total')
                # NOTE(review): per-analyzer plotting is disabled; re-enable with
                # self.plot_analyzers(site_name, analyzer_names, sorted_results)
        except Exception:
            logger.error("SiteDataPlotter could not plot for one or more analyzer(s).")
            raise

        try:
            self.write_LL_csv()
        except Exception:
            # Best effort: the failure is logged (with traceback) and skipped, not raised.
            logger.warning("Log likelihood CSV could not be created. Skipping...", exc_info=True)

    def plot_analyzers(self, site_name, analyzer_names, samples):
        """
        Draw the best-sample and all-sample plots for every analyzer of ``site_name``.

        Args:
            samples: Results sorted best-first with index == rank (see ``_sorted_results``).
        """
        cmin, cmax = samples['total'].min(), samples['total'].max()
        cmin = cmin if cmin < cmax else cmax - 1  # avoid divide by zero in color range
        best_samples = samples.iloc[:self.num_to_plot]
        for analyzer_name in analyzer_names:
            os.makedirs(os.path.join(self.directory, f'{site_name}_{analyzer_name}'), exist_ok=True)
            self.plot_best(site_name, analyzer_name, best_samples)
            self.plot_all(site_name, analyzer_name, samples, clim=(cmin, cmax))

    def plot_best(self, site_name, analyzer_name, samples):
        """Save one '<site>_<analyzer>/rank<N>.png' comparison plot per sample in ``samples``."""
        analyzer = self.get_site_analyzer(site_name, analyzer_name)
        plot_dir = os.path.join(self.directory, f'{site_name}_{analyzer_name}')
        for iteration, iter_samples in samples.groupby('iteration'):
            analyzer_data = self.get_analyzer_data(iteration, site_name, analyzer_name)
            for rank, sample in iter_samples['sample'].items():  # index is rank
                fname = os.path.join(plot_dir, f'rank{rank:d}')
                fig = plt.figure(fname, figsize=(8, 6))
                try:
                    analyzer.plot_comparison(fig, analyzer_data['samples'][sample], fmt='-o',
                                             color='#CB5FA4', alpha=1, linewidth=1)
                    analyzer.plot_comparison(fig, analyzer_data['ref'], fmt='-o', color='#8DC63F',
                                             alpha=1, linewidth=1, reference=True)
                    fig.set_tight_layout(True)
                    fig.savefig(fname + '.png', format='PNG')
                finally:
                    plt.close(fig)  # release the figure even if plotting fails

    def plot_all(self, site_name, analyzer_name, samples, clim):
        """
        Save '<site>_<analyzer>_all.png' showing every sample, colored by its 'total' within ``clim``.

        Args:
            clim: (cmin, cmax) range mapped onto the Blues colormap; cmax must differ from cmin.
        """
        analyzer = self.get_site_analyzer(site_name, analyzer_name)
        fname = os.path.join(self.directory, f'{site_name}_{analyzer_name}_all')
        cmin, cmax = clim
        fig = plt.figure(fname, figsize=(4, 3))
        try:
            analyzer_data = None
            for iteration, iter_samples in samples.groupby('iteration'):
                analyzer_data = self.get_analyzer_data(iteration, site_name, analyzer_name)
                results_by_sample = iter_samples.reset_index().set_index('sample')['total']
                for sample, result in results_by_sample.items():
                    color = cm.Blues((result - cmin) / (cmax - cmin))
                    analyzer.plot_comparison(fig, analyzer_data['samples'][sample], fmt='-',
                                             color=color, alpha=0.5, linewidth=0.5)
            # Reference data is drawn once, on top, from the last iteration's analyzer data;
            # skipped when there were no samples (no analyzer data was loaded).
            if analyzer_data is not None:
                analyzer.plot_comparison(fig, analyzer_data['ref'], fmt='-o', color='#8DC63F',
                                         alpha=1, linewidth=1, reference=True)
            fig.set_tight_layout(True)
            fig.savefig(fname + '.png', format='PNG')
        finally:
            plt.close(fig)

    def cleanup(self):
        """
        Remove existing plots for every site/analyzer.

        Samples are ranked the same way as for plotting so the deleted 'rank<N>' files match.
        """
        for site, analyzers in self.site_analyzer_names.items():
            ranked = self._sorted_results(self._ranking_column(site))
            self.cleanup_plot_by_analyzers(site, analyzers, ranked)

    def cleanup_plot_by_analyzers(self, site, analyzers, samples):
        """
        Remove the best-sample and all-sample plots of each analyzer of ``site``.

        Args:
            site: Site name.
            analyzers: Analyzer names for the site.
            samples: Results sorted best-first with index == rank.
        """
        best_samples = samples.iloc[:self.num_to_plot]
        for analyzer in analyzers:
            site_analyzer = f'{site}_{analyzer}'
            self.cleanup_plot_for_best(site_analyzer, best_samples)
            self.cleanup_plot_for_all(site_analyzer)

    def cleanup_plot_for_best(self, site_analyzer, samples):
        """
        Remove the per-rank plots '<site_analyzer>/rank<N>' for the ranks in ``samples``.

        Args:
            site_analyzer: '<site>_<analyzer>' plot sub-directory name.
            samples: Ranked samples whose index is the rank.
        """
        for rank in samples.index:  # index is rank
            self._remove_plot(os.path.join(self.directory, site_analyzer, f'rank{rank:d}'))

    def cleanup_plot_for_all(self, site_analyzer):
        """
        Remove the '<site_analyzer>_all' plot.

        Args:
            site_analyzer: '<site>_<analyzer>' plot name prefix.
        """
        self._remove_plot(os.path.join(self.directory, f'{site_analyzer}_all'))

    def write_LL_csv(self, ll_all_name: str = None):
        """
        Append this iteration's results, with simulation ids and output paths, to the LL CSV.

        The file is ``self.ll_all_name`` (default 'LL_all.csv') in the plot directory. Rows from
        earlier iterations already in the file are kept, and the whole file is rewritten sorted
        by 'total' (best first).

        Args:
            ll_all_name: Optional new CSV file name; when given, it replaces ``self.ll_all_name``.
        """
        if ll_all_name:
            self.ll_all_name = ll_all_name

        iteration_state = self.iteration_state
        iteration = iteration_state.iteration
        suite_id = iteration_state.suite_id

        # Copy all_results so the calibration's own DataFrame is not disturbed, then index the
        # likelihood results on (iteration, sample) to join with the simulation info.
        results_df = self.all_results.copy().reset_index().set_index(['iteration', 'sample'])

        # One row per simulation, indexed by simulation id.
        siminfo_df = pd.DataFrame.from_dict(iteration_state.simulations, orient='index')
        siminfo_df.index.name = 'simid'
        siminfo_df['iteration'] = iteration
        siminfo_df = siminfo_df.rename(columns={'__sample_index__': 'sample'}).reset_index()

        # Group simulation ids by sample point and merge back into the results.
        grouped_simids = siminfo_df.groupby(['iteration', 'sample'])['simid'].agg(tuple)
        # how='right': keep only this iteration's samples, which have new simulation info
        join_results_df = results_df.join(grouped_simids, how='right')

        sims_paths = iteration_state.platform.create_sim_directory_map(
            item_id=suite_id, item_type=ItemType.SUITE)

        def find_path(simids):
            """Turn a tuple of simulation ids into a comma-separated string of their output paths."""
            if not isinstance(simids, tuple):
                return ''
            paths = []
            for simid in simids:
                try:
                    paths.append(str(sims_paths[simid]))
                except KeyError:
                    # Skip only the unknown id instead of dropping all remaining paths.
                    logger.debug(f"No output directory found for simulation {simid}")
            return ",".join(paths)

        join_results_df['outputs'] = join_results_df['simid'].apply(find_path)

        # Concatenate with any existing data from previous iterations and dump to file.
        csv_path = os.path.join(self.directory, self.ll_all_name)
        if os.path.exists(csv_path):
            current = pd.read_csv(csv_path, index_col=['iteration', 'sample'])
            final_results_df = pd.concat([current, join_results_df])
        else:
            final_results_df = join_results_df
        final_results_df.sort_values(by='total', ascending=False).to_csv(csv_path)