"""Source code for idmtools_calibra.plotters.site_data_plotter."""

# flake8: noqa E402
import logging
import os
import matplotlib.pyplot as plt

plt.switch_backend('agg')
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from idmtools.core.enums import ItemType
from idmtools_calibra.iteration_state import IterationState
from idmtools_calibra.plotters.base_plotter import BasePlotter
from idmtools_calibra.process_state import StatusPoint

# Seaborn theme applied to every figure this module draws:
# plain white background with thin (0.5) axes lines.
sns.set_style(style='white', rc={'axes.linewidth': 0.5})

logger = logging.getLogger(__name__)


class SiteDataPlotter(BasePlotter):
    """Per-site analyzer plotter that also maintains the log-likelihood CSV.

    Work is only done when the iteration status is ``StatusPoint.plot``. Plot
    generation (:meth:`plot_analyzers`) is currently disabled inside
    :meth:`visualize`; the log-likelihood CSV is still written on each plot step.
    """

    # plot_best/plot_all write '.png'; '.pdf' is what cleanup previously looked
    # for, so it is kept to make sure such files are still removed.
    _PLOT_EXTENSIONS = ('.png', '.pdf')

    def __init__(self, combine_sites=True, num_to_plot=5, ll_all_name: str = 'LL_all.csv'):
        """
        :param combine_sites: forwarded to BasePlotter. When True, samples are ranked by
            the combined ``total`` column; otherwise by each site's ``<site>_total`` column.
        :param num_to_plot: number of best-ranked samples plotted per analyzer.
        :param ll_all_name: file name, inside :attr:`directory`, of the log-likelihood CSV.
        """
        super().__init__(combine_sites)
        self.num_to_plot = num_to_plot
        self.ll_all_name = ll_all_name

    @property
    def directory(self):
        """Output directory for plots and the LL CSV (delegates to ``get_plot_directory``)."""
        return self.get_plot_directory()

    # TODO(ZD): self.iteration_state.analyzer_list doesn't keep site info; analyzers are
    # matched by the uid '<site>_<analyzer>', which assumes all analyzers have different names.
    def get_site_analyzer(self, site_name, analyzer_name):
        """Return the analyzer whose uid is ``'<site_name>_<analyzer_name>'``.

        :param site_name: a key of ``self.site_analyzer_names``.
        :param analyzer_name: analyzer name within that site.
        :raises Exception: if the site is unknown or no analyzer has the expected uid.
        """
        if site_name in self.site_analyzer_names:
            site_analyzer = f'{site_name}_{analyzer_name}'
            for analyzer in self.iteration_state.analyzer_list:
                if analyzer.uid == site_analyzer:
                    return analyzer
        raise Exception(f'Unable to find analyzer={analyzer_name} for site={site_name}')

    def get_analyzer_data(self, iteration, site_name, analyzer_name):
        """Restore the saved state of ``iteration`` and return its data for '<site>_<analyzer>'.

        The returned object is indexed with ``'samples'`` and ``'ref'`` by the plotting methods.
        """
        site_analyzer = f'{site_name}_{analyzer_name}'
        return IterationState.restore_state(iteration).analyzers[site_analyzer]

    def visualize(self, iteration_state):
        """Entry point called by the calibration loop.

        Records the iteration state, then (only at ``StatusPoint.plot``) ranks the results
        per site and writes the log-likelihood CSV. CSV failures are logged and skipped;
        ranking failures are logged and re-raised.
        """
        self.iteration_state = iteration_state
        self.site_analyzer_names = iteration_state.site_analyzer_names
        if self.iteration_state.status != StatusPoint.plot:
            return  # Only plot once results are available

        try:
            for site_name, analyzer_names in self.site_analyzer_names.items():
                if self.combine_sites:
                    sort_column = 'total'
                else:
                    # Presumably adds the '<site>_total' column used below (BasePlotter helper).
                    self.combine_by_site(site_name, analyzer_names, self.all_results)
                    sort_column = f'{site_name}_total'
                # Ranked input for plot_analyzers; plotting is currently disabled:
                sorted_results = self.all_results.sort_values(by=sort_column, ascending=False).reset_index()
                # self.plot_analyzers(site_name, analyzer_names, sorted_results)
        except Exception:
            logger.error("SiteDataPlotter could not plot for one or more analyzer(s).")
            raise

        try:
            self.write_LL_csv()
        except Exception:
            # Best effort: the CSV must not stop the calibration, but keep the reason visible.
            logger.warning("Log likelihood CSV could not be created. Skipping...", exc_info=True)

    def plot_analyzers(self, site_name, analyzer_names, samples):
        """Plot the best ``num_to_plot`` samples and an all-samples overlay for each analyzer.

        :param samples: ranked results (index = rank) with 'iteration', 'sample' and 'total' columns.
        """
        totals = samples['total']
        cmin, cmax = totals.min(), totals.max()
        if not cmin < cmax:
            cmin = cmax - 1  # avoid divide by zero in color range
        for analyzer_name in analyzer_names:
            os.makedirs(os.path.join(self.directory, f'{site_name}_{analyzer_name}'), exist_ok=True)
            self.plot_best(site_name, analyzer_name, samples.iloc[:self.num_to_plot])
            self.plot_all(site_name, analyzer_name, samples, clim=(cmin, cmax))

    def plot_best(self, site_name, analyzer_name, samples):
        """Save one PNG per row of ``samples``, named ``rank<index>``, overlaying sample and reference.

        :param samples: rows with 'iteration' and 'sample' columns; the index is the rank.
        """
        analyzer = self.get_site_analyzer(site_name, analyzer_name)
        out_dir = os.path.join(self.directory, f'{site_name}_{analyzer_name}')
        for iteration, iter_samples in samples.groupby('iteration'):
            analyzer_data = self.get_analyzer_data(iteration, site_name, analyzer_name)
            for rank, sample in iter_samples['sample'].items():  # index is rank
                fname = os.path.join(out_dir, f'rank{rank:d}')
                fig = plt.figure(fname, figsize=(8, 6))
                try:
                    analyzer.plot_comparison(fig, analyzer_data['samples'][sample],
                                             fmt='-o', color='#CB5FA4', alpha=1, linewidth=1)
                    analyzer.plot_comparison(fig, analyzer_data['ref'],
                                             fmt='-o', color='#8DC63F', alpha=1, linewidth=1, reference=True)
                    fig.set_tight_layout(True)
                    # Save this figure explicitly; plt.savefig would save whatever is "current".
                    fig.savefig(fname + '.png', format='PNG')
                finally:
                    plt.close(fig)  # never leak figures, even when plotting fails

    def plot_all(self, site_name, analyzer_name, samples, clim):
        """Save ``<site>_<analyzer>_all.png``: every sample colored by its 'total', plus the reference.

        :param clim: (min, max) of 'total' used to normalize the Blues colormap.
        """
        analyzer = self.get_site_analyzer(site_name, analyzer_name)
        fname = os.path.join(self.directory, f'{site_name}_{analyzer_name}_all')
        cmin, cmax = clim
        span = (cmax - cmin) or 1  # guard against a degenerate (zero-width) color range
        fig = plt.figure(fname, figsize=(4, 3))
        try:
            for iteration, iter_samples in samples.groupby('iteration'):
                analyzer_data = self.get_analyzer_data(iteration, site_name, analyzer_name)
                for sample, result in zip(iter_samples['sample'], iter_samples['total']):
                    analyzer.plot_comparison(fig, analyzer_data['samples'][sample], fmt='-',
                                             color=cm.Blues((result - cmin) / span), alpha=0.5, linewidth=0.5)
                analyzer.plot_comparison(fig, analyzer_data['ref'],
                                         fmt='-o', color='#8DC63F', alpha=1, linewidth=1, reference=True)
            fig.set_tight_layout(True)
            fig.savefig(fname + '.png', format='PNG')
        finally:
            plt.close(fig)

    def cleanup(self):
        """Delete plot files previously written for every site and analyzer."""
        for site, analyzers in self.site_analyzer_names.items():
            self.cleanup_plot_by_analyzers(site, analyzers, self.all_results)

    def cleanup_plot_by_analyzers(self, site, analyzers, samples):
        """Delete the best-sample and all-sample plots of each analyzer of ``site``.

        :param site: site name.
        :param analyzers: analyzer names of that site.
        :param samples: results frame; only its first ``num_to_plot`` rows are considered.
        """
        # plot_best names files by position in the ranked results (rank0, rank1, ...),
        # so renumber positionally instead of trusting this frame's index labels.
        best_samples = samples.iloc[:self.num_to_plot].reset_index(drop=True)
        for analyzer in analyzers:
            site_analyzer = f'{site}_{analyzer}'
            self.cleanup_plot_for_best(site_analyzer, best_samples)
            self.cleanup_plot_for_all(site_analyzer)

    def cleanup_plot_for_best(self, site_analyzer, samples):
        """Delete ``<directory>/<site_analyzer>/rank<index>`` plots for each index of ``samples``.

        :param site_analyzer: '<site>_<analyzer>' directory name.
        :param samples: frame whose index holds the ranks to delete.
        """
        for rank in samples.index:
            self._remove_plot_files(os.path.join(self.directory, site_analyzer, f'rank{int(rank)}'))

    def cleanup_plot_for_all(self, site_analyzer):
        """Delete the ``<directory>/<site_analyzer>_all`` plot.

        :param site_analyzer: '<site>_<analyzer>' name.
        """
        self._remove_plot_files(os.path.join(self.directory, f'{site_analyzer}_all'))

    @classmethod
    def _remove_plot_files(cls, fname):
        """Delete ``fname`` with each known plot extension; log (do not raise) on failure."""
        for ext in cls._PLOT_EXTENSIONS:
            plot_path = fname + ext
            try:
                os.remove(plot_path)
            except FileNotFoundError:
                continue  # nothing to clean for this extension
            except OSError:
                logger.error(f"Failed to delete {plot_path}")

    def write_LL_csv(self, ll_all_name: str = None):
        """Append this iteration's log-likelihood results to the LL CSV (``self.ll_all_name``).

        Results are joined with the simulation ids of each (iteration, sample) and their
        output paths, concatenated with any existing CSV content, and the whole file is
        rewritten sorted by 'total' (descending).

        :param ll_all_name: optional new CSV file name; when given it replaces ``self.ll_all_name``.
        """
        if ll_all_name:
            self.ll_all_name = ll_all_name

        iteration_state = self.iteration_state
        iteration = iteration_state.iteration
        suite_id = iteration_state.suite_id

        # Copy all_results so the calibration's own data is not disturbed.
        all_results = self.all_results.copy()
        # Index the likelihood-results DataFrame on (iteration, sample) to join with simulation info
        results_df = all_results.reset_index().set_index(['iteration', 'sample'])

        # Simulation info from the iteration state: one row per simulation id.
        siminfo_df = pd.DataFrame.from_dict(iteration_state.simulations, orient='index')
        siminfo_df.index.name = 'simid'
        siminfo_df['iteration'] = iteration
        siminfo_df = siminfo_df.rename(columns={'__sample_index__': 'sample'}).reset_index()

        # Group simIDs by sample point and merge back into results
        grouped_simids_df = siminfo_df.groupby(['iteration', 'sample']).simid.agg(lambda x: tuple(x))
        join_results_df = results_df.join(grouped_simids_df, how='right')  # right: only this iteration with new sim info

        platform = iteration_state.platform
        sims_paths = platform.create_sim_directory_map(item_id=suite_id, item_type=ItemType.SUITE)

        def find_path(sim_ids):
            """Turn a collection of simulation ids into a comma-separated string of output paths."""
            try:
                ids = tuple(sim_ids)
            except TypeError:  # e.g. NaN when no simulation matched the sample
                return ""
            paths = []
            for sim_id in ids:
                path = sims_paths.get(sim_id)
                if path is None:
                    # Skip only the unknown id instead of truncating the whole list.
                    logger.debug("No output directory found for simulation %s", sim_id)
                    continue
                paths.append(str(path))
            return ",".join(paths)

        join_results_df['outputs'] = join_results_df['simid'].apply(find_path)

        # Concatenate with any existing data from previous iterations and dump to file
        csv_path = os.path.join(self.directory, self.ll_all_name)
        if os.path.exists(csv_path):
            current = pd.read_csv(csv_path, index_col=['iteration', 'sample'])
            final_results_df = pd.concat([current, join_results_df])
        else:
            final_results_df = join_results_df
        final_results_df.sort_values(by='total', ascending=False).to_csv(csv_path)