Source code for emod_api.multidim_plotter

import os
import matplotlib.pyplot as plt
from matplotlib import cm
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
import sqlite3

"""
Plot EMOD output data from an experiment found in an SQLite db created by emodpy.
"""
[docs]def plot_from_sql(x_tag, y_tag, output, label, exp_id=None): """ Plot colormap/3D figure from data in <experiment_id>/results.db. Args: x_tag: Tag to use as x axis. y_tag: Tag to use as y axis. output: String to use as output, needs to correspond to one of the output cols in the db. label: Figure needs a label. exp_id: Optional experiment id. If omitted, 'latest_experiment' is used. """ fig = plt.figure() ax = Axes3D(fig) #ax = fig.gca(projection='3d') #query = f"select {x_tag}, {y_tag}, avg(output) from results where cast({x_tag} as decimal) > 0.01 group by {x_tag}, {y_tag};" #query = f"select {x_tag}, {y_tag}, avg(output), output2 from results group by {x_tag}, {y_tag};" if exp_id: db = os.path.join( str(exp_id), "results.db" ) else: db = os.path.join( "latest_experiment", "results.db" ) con = sqlite3.connect( db ) cur = con.cursor() x_tag = x_tag.replace( ' ', '_' ).replace( '-', '_' ) y_tag = y_tag.replace( ' ', '_' ).replace( '-', '_' ) #query = f"select {x_tag}, {y_tag}, avg(output) from results where cast(output2 as decimal) > 0.35 and cast(output2 as decimal)<0.45 group by {x_tag}, {y_tag};" query = f"select {x_tag}, {y_tag}, avg({output}) from results group by {x_tag}, {y_tag};" try: cur.execute( query ) results = cur.fetchall() except Exception as ex: print( f"Encountered fatal exception {ex} when executing query {query} on db {db}." ) return x = [] y = [] z = [] for result in results: x.append( result[0] ) y.append( result[1] ) z.append( result[2] ) surf = ax.plot_trisurf(x, y, z, cmap=cm.jet, linewidth=0.1) ax.set_xlabel( f"{x_tag} rate" ) ax.set_ylabel( f"{y_tag} rate" ) #ax.set_zlabel( "final prevalence" ) #ax.set_zlabel( label ) fig.colorbar(surf, shrink=0.5, aspect=5) #X, Y = np.meshgrid(np.array(x), np.array(y)) #ax.plot_surface( X, Y, np.array(z) ) ax.view_init(elev=90, azim=0) plt.title( label ) plt.show()
if __name__ == "__main__": import argparse parser = argparse.ArgumentParser(description='Spatial Report Plotting') parser.add_argument('-x', '--xtag', action='store', default="", help='X tag (must be in db)' ) parser.add_argument('-y', '--ytag', action='store', default="", help='Y tag (must be in db' ) parser.add_argument('-o', '--output', action='store', default="", help='Single value output file' ) parser.add_argument('-t', '--title', action='store', default="", help='Graph title' ) parser.add_argument('-e', '--experiment_id', action='store', default=None, help='experiment id to plot, uses latest_experiment folder if omitted (not used yet)' ) args = parser.parse_args() if not args.experiment_id: with open( "COMPS_ID", "r" ) as fp: args.experiment_id = fp.read() # check that folder with name experiment_id exists if not os.path.exists( str(args.experiment_id) ): raise ValueError( f"Don't see folder for {args.experiment_id}." ) plot_from_sql( args.xtag, args.ytag, args.output, args.title )