Skip to content

multidim_plotter

plot_from_sql(x_tag, y_tag, output, label, exp_id=None)

Plot colormap/3D figure from data in `<experiment_id>/results.db`.

Parameters:

Name Type Description Default
x_tag str

Tag to use as x axis.

required
y_tag str

Tag to use as y axis.

required
output str

String to use as output, needs to correspond to one of the output cols in the db.

required
label str

Figure needs a label.

required
exp_id str

Optional experiment id. If omitted, 'latest_experiment' is used.

None
Source code in emod_api/multidim_plotter.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
def plot_from_sql(x_tag: str,
                  y_tag: str,
                  output: str,
                  label: str,
                  exp_id: str = None):
    """
    Plot colormap/3D figure from data in <experiment_id>/results.db.

    The ``results`` table is grouped by the two tag columns and the average of
    ``output`` per (x, y) pair is drawn as a triangulated surface, viewed from
    directly above so it reads like a colormap. The figure is shown with
    ``plt.show()``.

    If the query fails, or returns no rows, a message is printed and the
    function returns without creating a figure.

    Args:
        x_tag: Tag to use as x axis. Spaces and dashes are replaced with
            underscores to form the column name.
        y_tag: Tag to use as y axis. Spaces and dashes are replaced with
            underscores to form the column name.
        output: String to use as output, needs to correspond to one of the output cols in the db.
        label: Figure needs a label. Used as the plot title.
        exp_id: Optional experiment id. If omitted, 'latest_experiment' is used.
    """

    if exp_id:
        db = os.path.join(str(exp_id), "results.db")
    else:
        db = os.path.join("latest_experiment", "results.db")

    x_tag = x_tag.replace(' ', '_').replace('-', '_')
    y_tag = y_tag.replace(' ', '_').replace('-', '_')
    # NOTE: column names are SQL identifiers and cannot be bound as query
    # parameters, so they are interpolated directly. Only pass trusted
    # tag/output names to this function.
    query = f"select {x_tag}, {y_tag}, avg({output}) from results group by {x_tag}, {y_tag};"

    con = sqlite3.connect(db)
    try:
        results = con.execute(query).fetchall()
    except sqlite3.Error as ex:
        print(f"Encountered fatal exception {ex} when executing query {query} on db {db}.")
        return
    finally:
        # Always release the database handle, on success and on failure.
        con.close()

    if not results:
        # plot_trisurf cannot triangulate an empty point set.
        print(f"Query {query} on db {db} returned no rows; nothing to plot.")
        return

    # Transpose the (x, y, avg) rows into three parallel coordinate lists.
    x, y, z = (list(column) for column in zip(*results))

    # Create the figure only once there is data to draw, so a failed query
    # does not leave an empty figure registered with pyplot.
    fig = plt.figure()
    # Axes3D(fig) no longer attaches itself to the figure in recent
    # Matplotlib releases; add_subplot is the supported way to get 3D axes.
    ax = fig.add_subplot(projection="3d")

    surf = ax.plot_trisurf(x, y, z, cmap=cm.jet, linewidth=0.1)
    ax.set_xlabel(f"{x_tag} rate")
    ax.set_ylabel(f"{y_tag} rate")
    fig.colorbar(surf, shrink=0.5, aspect=5)
    # Look straight down the z axis so the surface is seen as a 2D colormap.
    ax.view_init(elev=90, azim=0)
    plt.title(label)
    plt.show()