# flake8: noqa E402
import logging
import os
import matplotlib
matplotlib.use('Agg', force=True)
import matplotlib.pyplot as plt
import seaborn as sns
from idmtools_calibra.process_state import StatusPoint
from idmtools_calibra.plotters.base_plotter import BasePlotter
sns.set_style('white')
logger = logging.getLogger(__name__)
[docs]class LikelihoodPlotter(BasePlotter):
def __init__(self, combine_sites=True):
super(LikelihoodPlotter, self).__init__(combine_sites)
@property
def param_names(self):
return self.iteration_state.param_names
@property
def prior_fn(self):
return self.iteration_state.next_point.prior_fn
@property
def directory(self):
return self.get_iteration_directory()
[docs] def visualize(self, iteration_state):
self.iteration_state = iteration_state
self.site_analyzer_names = iteration_state.site_analyzer_names
iteration_status = self.iteration_state.status
if iteration_status != StatusPoint.plot:
return # Only plot once results are available
if self.combine_sites:
self.plot_by_parameter()
else:
self.plot_by_parameter_and_site()
[docs] def plot_by_parameter_and_site(self):
for site, analyzers in self.site_analyzer_names.items():
self.combine_by_site(site, analyzers, self.all_results)
self.plot_by_parameter(site=site)
[docs] def plot_by_parameter(self, site='', **kwargs):
for param in self.param_names:
fig = plt.figure('LL by parameter ' + param, figsize=(5, 4))
ax = fig.add_subplot(111)
total = site + '_total' if site else 'total'
results = self.all_results[[total, 'iteration', param]]
self.plot1d_by_iteration(results, param, total, **kwargs)
try:
sample_range = self.prior_fn.sample_functions[param].sample_range
if sample_range.is_log():
ax.set_xscale('log')
ax.set_xlim(sample_range.get_xlim())
except (KeyError, AttributeError):
pass
ax.set(xlabel=param, ylabel='log likelihood')
os.makedirs(os.path.join(self.directory, site), exist_ok=True)
fig.set_tight_layout(True)
plt.savefig(os.path.join(self.directory, site, 'LL_%s.pdf' % param), format='PDF')
plt.close(fig)
[docs] @staticmethod
def plot1d_by_iteration(results, param, total, **kwargs):
iterations = results.groupby('iteration', sort=True)
n_iterations = len(iterations)
colors = ['#4BB5C1'] * (n_iterations - 1) + ['#FF2D00']
for iteration, values in iterations:
sorted_values = values.sort_values(by=param)
plt.plot(sorted_values[param], sorted_values[total],
color=colors[iteration],
linewidth=(iteration + 1) / (n_iterations + 1.) * 2,
alpha=(iteration + 1) / (n_iterations + 1.),
**kwargs)