# plot_metric.py
import argparse
import pickle as pkl
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from os import listdir
import itertools
import matplotlib.patches as mpatches
import jax.numpy as jnp
from jax import vmap, jit
from functools import partial
# Script: for every pickle file given with --data, read one metric from its
# "metrics" dict, subsample it, smooth it with a moving average and draw it
# against the (equally smoothed) "eval_times" with seaborn. All files share
# one figure; the legend maps each colour to its file path.

# Column-wise 1-D convolution in 'valid' mode: vmap maps jnp.convolve over
# axis 1 of the first argument while the kernel (second argument) is shared.
# The result therefore has one row per input column.
convolve = jit(vmap(partial(jnp.convolve, mode='valid'), in_axes=(1, None)))

# Endless cycle over seaborn's current palette; one colour is taken per file.
palette = itertools.cycle(sns.color_palette())

parser = argparse.ArgumentParser()
# required=True: without it a missing --data leaves args.data as None and the
# loop below fails with an opaque TypeError instead of a usage message.
parser.add_argument("--data", "-d", type=str, nargs='+', required=True)
parser.add_argument("--window", "-w", type=int, default=1)
parser.add_argument("--frequency", "-f", type=int, default=1)
parser.add_argument("--metric", "-m", type=str, default="reward_rates")
args = parser.parse_args()

# Reject values that would otherwise only surface as confusing errors deep
# inside the loop (empty kernel, "slice step cannot be zero").
if args.window < 1:
    parser.error("--window must be a positive integer")
if args.frequency < 1:
    parser.error("--frequency must be a positive integer")

# Uniform averaging kernel; identical for every file, so build it once.
kernel = np.ones(args.window) / args.window

handles = []
for d in args.data:
    color = next(palette)

    # SECURITY NOTE: pickle.load can execute arbitrary code embedded in the
    # file. Only point --data at result files you trust.
    with open(d, 'rb') as f:
        data = pkl.load(f)["metrics"]
    # The file is closed here; everything below works on in-memory data.

    # Keep every `frequency`-th row of the metric.
    # NOTE(review): presumably rows are evaluation steps and columns are
    # independent runs/seeds -- confirm against the code that writes the file.
    values = np.asarray(data[args.metric])[::args.frequency, :]

    # Repeat the subsampled time axis once per metric column and transpose so
    # `times` has the same 2-D layout as `values`.
    times = jnp.transpose(
        np.asarray([data["eval_times"][::args.frequency]] * values.shape[1])
    )

    # Smooth both with the same kernel so each averaged value stays paired
    # with the average of the timestamps it was computed from.
    values = convolve(values, kernel)
    times = convolve(times, kernel)

    handles.append(mpatches.Patch(color=color, label=d))

    # Convert the JAX arrays to NumPy before handing them to pandas, then
    # flatten so all columns contribute samples to a single line plot.
    data_frame = pd.DataFrame({
        args.metric: np.asarray(values).flatten(),
        "time": np.asarray(times).flatten(),
    })

    # Pass the colour explicitly so the line always matches its legend patch
    # rather than relying on two separate colour cycles staying in step.
    sns.lineplot(x="time", y=args.metric, data=data_frame, color=color)

plt.legend(handles=handles)
plt.grid()
plt.show()