Calibration

class Calibration(sim, calib_pars, n_workers=None, total_trials=None, reseed=True, build_fn=None, build_kw=None, eval_fn=None, eval_kw=None, components=None, prune_fn=None, label=None, study_name=None, db_name=None, keep_db=None, continue_db=None, storage=None, sampler=None, die=False, debug=False, verbose=True)

Bases: prettyobj

A class to handle calibration of Starsim simulations. Uses the Optuna hyperparameter optimization library (optuna.org).

Parameters:
  • sim (Sim) – the base simulation to calibrate

  • calib_pars (dict) – the parameters to calibrate, given as dict(key1=dict(low=1, high=2, guess=1.5, **kwargs), key2=...), where kwargs can include “suggest_type” to choose the trial's suggest method (e.g. suggest_float), plus arguments passed to that suggest function, such as “log” and “step”

  • n_workers (int) – the number of parallel workers (if None, will use all available CPUs)

  • total_trials (int) – the total number of trials to run; each worker will run approximately n_trials = total_trials / n_workers trials

  • reseed (bool) – whether to generate new random seeds for each trial

  • build_fn (callable) – function that takes a sim object and a calib_pars dictionary and returns a modified sim

  • build_kw (dict) – a dictionary of options passed to build_fn to aid in modifying the base simulation. The call is self.build_fn(sim, calib_pars=calib_pars, **self.build_kw), where sim is a copy of the base simulation to be modified with calib_pars

  • components (list) – a list of CalibComponents, each of which independently assesses a pseudo-likelihood as part of evaluating the quality of the input parameters

  • prune_fn (callable) – function that takes a dictionary of parameters and returns True if the trial should be pruned

  • eval_fn (callable) – function mapping a sim to a float (e.g. a negative log likelihood) to be minimized. If None, the default will use CalibComponents.

  • eval_kw (dict) – additional keyword arguments to pass to eval_fn

  • label (str) – a label for this calibration object

  • study_name (str) – name of the optuna study

  • db_name (str) – the name of the database file (default: ‘starsim_calibration.db’)

  • continue_db (bool) – whether to continue from the database if it already exists; if false, any existing database is deleted (default: false)

  • keep_db (bool) – whether to keep the database after calibration (default: false, the database will be deleted)

  • storage (str) – the storage location of the database (default: SQLite)

  • sampler (BaseSampler) – the sampler used by Optuna, such as optuna.samplers.TPESampler

  • die (bool) – whether to stop if an exception is encountered (default: false)

  • debug (bool) – if True, do not run in parallel

  • verbose (bool) – whether to print details of the calibration
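The calib_pars format above can be illustrated with a small, dependency-free sketch. It assumes each entry becomes one call to a trial suggest method, defaulting to suggest_float when suggest_type is absent (an assumption, not stated above), with guess removed first because it is not a suggest argument. FakeTrial is a hypothetical stand-in for an Optuna trial, and the parameter names beta, init_prev, and n_contacts are illustrative only.

```python
class FakeTrial:
    """Hypothetical stand-in for an Optuna trial that returns each range's midpoint."""

    def suggest_float(self, name, low, high, step=None, log=False):
        return (low + high) / 2

    def suggest_int(self, name, low, high, step=1, log=False):
        return (low + high) // 2


def suggest_pars(trial, calib_pars):
    """Draw one value per calib_pars entry using the requested suggest method."""
    pars = {}
    for name, spec in calib_pars.items():
        spec = dict(spec)  # copy so the caller's calib_pars is left untouched
        spec.pop('guess', None)  # the guess is a starting point, not a suggest argument
        suggest_type = spec.pop('suggest_type', 'suggest_float')  # assumed default
        low, high = spec.pop('low'), spec.pop('high')
        suggest = getattr(trial, suggest_type)
        pars[name] = suggest(name, low, high, **spec)  # leftovers such as log or step
    return pars


calib_pars = dict(
    beta=dict(low=0.01, high=0.30, guess=0.15, log=True),
    init_prev=dict(low=0.01, high=0.05, guess=0.02),
    n_contacts=dict(low=2, high=10, guess=4, suggest_type='suggest_int'),
)
pars = suggest_pars(FakeTrial(), calib_pars)
```

With a real Optuna trial the same call pattern applies, since suggest_float and suggest_int take name, low, and high followed by keyword options such as log and step.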

Returns:

A Calibration object
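The build_fn and build_kw calling convention described above, build_fn(sim, calib_pars=calib_pars, **build_kw) applied to a copy of the base simulation, can be sketched as follows. A plain dictionary stands in for a Starsim Sim, calib_pars is assumed to have already been reduced to sampled values, and scale is a hypothetical build_kw option; check how your Starsim version structures calib_pars before reusing this pattern.

```python
import copy


def build_fn(sim, calib_pars, scale=1.0):
    """Hypothetical build function: write sampled values into the sim's parameters."""
    for name, value in calib_pars.items():
        sim['pars'][name] = value * scale  # scale is an illustrative build_kw option
    return sim


base_sim = {'pars': {'beta': 0.1, 'init_prev': 0.01}}  # stand-in for a Sim
build_kw = dict(scale=1.0)
sampled = {'beta': 0.2}

# Mirror the documented call: modify a copy, never the base simulation itself
sim = build_fn(copy.deepcopy(base_sim), calib_pars=sampled, **build_kw)
```

Working on a deep copy is what lets every trial start from the same, unmodified base simulation.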

Methods

run_sim(calib_pars=None, label=None)

Create and run a simulation

plot(**kwargs)

Plot the calibration results. For a component-based likelihood, it only makes sense to call plot directly after calling eval_fn.

run_trial(trial)

Define the objective for Optuna
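The work done for a single trial can be sketched without Optuna or Starsim. The parameters are assumed to be sampled already, and the order shown (optionally prune, build from a copy of the base sim, run, then score) as well as the summing of component scores are assumptions made for illustration. TrialPruned stands in for optuna.TrialPruned, and toy_build, toy_run, and squared_error are hypothetical; the squared error is a simple stand-in for a likelihood component, so lower is better.

```python
import copy


class TrialPruned(Exception):
    """Stand-in for optuna.TrialPruned, so the sketch needs no dependencies."""


def toy_build(sim, calib_pars):
    # Hypothetical build_fn: copy sampled values into the sim's parameters
    sim['pars'].update(calib_pars)
    return sim


def toy_run(sim):
    # Hypothetical "simulation": cases scale linearly with beta
    sim['cases'] = sim['pars']['beta'] * 1000
    return sim


def squared_error(observed):
    # Stand-in for a likelihood component: lower values mean a better fit
    return lambda sim: (sim['cases'] - observed) ** 2


def objective(pars, base_sim, components, prune_fn=None):
    """Score one set of sampled parameters, mirroring the steps of a trial."""
    if prune_fn is not None and prune_fn(pars):
        raise TrialPruned(f'pruned: {pars}')
    sim = toy_build(copy.deepcopy(base_sim), calib_pars=pars)  # never mutate the base sim
    toy_run(sim)
    # Combine the components by summing their contributions (assumed)
    return sum(component(sim) for component in components)


base_sim = {'pars': {'beta': 0.1}}
components = [squared_error(observed=200)]
prune_fn = lambda pars: pars['beta'] > 0.5  # hypothetical rule: skip implausible values
```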

worker()

Run a single worker

run_workers()

Run multiple workers in parallel
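The constructor notes that each worker runs roughly total_trials / n_workers trials. One way to divide the work, sketched below, splits the total into near-equal chunks that sum exactly to total_trials; Starsim may round differently. A thread pool is used only to keep the example self-contained (Starsim's workers are presumably separate processes sharing the Optuna storage), and toy_worker and run_workers_sketch are hypothetical.

```python
from concurrent.futures import ThreadPoolExecutor


def split_trials(total_trials, n_workers):
    """Split total_trials into n_workers near-equal chunks that sum to total_trials."""
    base, extra = divmod(total_trials, n_workers)
    # The first `extra` workers each take one additional trial
    return [base + (1 if i < extra else 0) for i in range(n_workers)]


def toy_worker(n_trials):
    # Hypothetical worker: a real one would run n_trials Optuna trials
    return [f'trial-{i}' for i in range(n_trials)]


def run_workers_sketch(total_trials, n_workers):
    chunks = split_trials(total_trials, n_workers)
    with ThreadPoolExecutor(max_workers=n_workers) as pool:
        results = list(pool.map(toy_worker, chunks))
    return chunks, results


chunks, results = run_workers_sketch(total_trials=10, n_workers=4)
```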

remove_db()

Remove the database file if keep_db is false and the path exists

make_study()

Make a study, deleting any existing study unless continue_db is true
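The database handling described for remove_db and make_study can be approximated at the file level. This sketch deletes the SQLite file itself, whereas make_study works with an Optuna study; prepare_db is a hypothetical helper, and its default file name follows the db_name default given in the constructor parameters.

```python
from pathlib import Path


def remove_db(db_name, keep_db=False):
    """Delete the database file unless keep_db is True; report whether it was removed."""
    path = Path(db_name)
    if not keep_db and path.exists():
        path.unlink()
        return True
    return False


def prepare_db(db_name='starsim_calibration.db', continue_db=False):
    """Hypothetical helper: start fresh unless the caller wants to continue an existing run."""
    if not continue_db:
        remove_db(db_name, keep_db=False)  # discard results from any earlier calibration
    return Path(db_name)
```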

calibrate(calib_pars=None, **kwargs)

Perform calibration.

Parameters:
  • calib_pars (dict) – if supplied, overwrite stored calib_pars

  • kwargs (dict) – if supplied, overwrite stored run_args (n_trials, n_workers, etc.)
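Overriding stored run arguments through calibrate(**kwargs) amounts to a dictionary merge. The sketch below is a hypothetical helper, not Starsim code, and the run_args values are illustrative; rejecting unknown keys is a design choice added here so that a misspelled option fails loudly instead of being ignored.

```python
def merge_run_args(run_args, **kwargs):
    """Return a copy of run_args with any keyword overrides applied."""
    unknown = set(kwargs) - set(run_args)
    if unknown:  # reject typos instead of silently ignoring them
        raise KeyError(f'Unknown run_args: {sorted(unknown)}')
    return {**run_args, **kwargs}


run_args = dict(n_trials=25, n_workers=4, keep_db=False)  # illustrative values
updated = merge_run_args(run_args, n_workers=2)
```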

to_df(top_k=None)

Return the top K results as a dataframe, sorted by value
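Selecting the top K results sorted by value can be sketched with plain dictionaries in place of a dataframe. The sketch assumes lower values are better, as for a minimized objective such as a negative log likelihood; top_k_results is hypothetical and the record fields and values are made up for illustration.

```python
def top_k_results(rows, top_k=None):
    """Sort trial records by their 'value' field, best (lowest) first, and keep top_k."""
    ranked = sorted(rows, key=lambda row: row['value'])
    return ranked if top_k is None else ranked[:top_k]


rows = [
    {'trial': 0, 'beta': 0.10, 'value': 12.5},
    {'trial': 1, 'beta': 0.20, 'value': 3.1},
    {'trial': 2, 'beta': 0.15, 'value': 7.8},
]
best_two = top_k_results(rows, top_k=2)
```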

check_fit(do_plot=True)

Run before and after simulations to validate the fit
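One way to read the before and after runs in check_fit is to compare the guess values from calib_pars against the best-fit parameters. That interpretation is an assumption; the sketch only assembles the two parameter sets and scores them with a hypothetical toy_score, whereas check_fit runs full simulations.

```python
def before_after_pars(calib_pars, best_pars):
    """Build the 'before' (initial guess) and 'after' (best-fit) parameter sets."""
    before = {name: spec['guess'] for name, spec in calib_pars.items()}
    after = {name: best_pars[name] for name in calib_pars}
    return before, after


def toy_score(pars, observed=200):
    # Hypothetical fit measure: squared error of a linear toy model; lower is better
    return (pars['beta'] * 1000 - observed) ** 2


calib_pars = dict(beta=dict(low=0.01, high=0.30, guess=0.15))
best_pars = {'beta': 0.2}  # e.g. the best trial's parameters
before, after = before_after_pars(calib_pars, best_pars)
```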

parse_study(study)

Parse the study into a dataframe (called automatically)

to_json(filename=None, indent=2, **kwargs)

Convert the results to JSON
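A minimal sketch of exporting result records to JSON with the standard library, mirroring the filename and indent arguments of to_json. The helper results_to_json and the record layout are hypothetical, and how to_json uses its extra **kwargs is not covered here.

```python
import json


def results_to_json(rows, filename=None, indent=2):
    """Serialize result records to JSON text, optionally also writing them to a file."""
    text = json.dumps(rows, indent=indent)
    if filename is not None:
        with open(filename, 'w') as f:
            f.write(text)
    return text


rows = [{'trial': 1, 'beta': 0.2, 'value': 3.1}]  # illustrative record
text = results_to_json(rows)
```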

plot_final(**kwargs)

Plot sims after calibration

Parameters:

kwargs (dict) – passed to MultiSim.plot()

plot_optuna(methods=None)

Plot Optuna’s visualizations