run_solver.py
""" Calls sampling-based solver that uses agents' value functions to find the optimal contract.
The low-level implementation of the solver can be found inside environments/two_stage_train.py
"""
import copy
import time

import numpy as np
from ray.rllib.agents import ppo

from utils.ray_config_utils import get_config_and_env, get_solver_config


def run_solver(params_dict, checkpoint_paths, logger):
    """Roll out frozen policies under sampled contracts and log the resulting returns."""
    num_samples = params_dict.get('solver_samples', 10)
    trainer_config, _ = get_config_and_env(params_dict)  # rebuild the config from params so nothing needs to be pickled or stored
    env_copy = get_solver_config(params_dict, trainer_config, checkpoint_paths)
    logger.set_stage(2)
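    # Evaluate each frozen Stage-1 checkpoint with a lightweight, inference-only config.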
    for path in checkpoint_paths:
        train_config = copy.deepcopy(trainer_config)
        train_config['num_workers'] = 1
        train_config['evaluation_num_workers'] = 0
        train_config['num_gpus'] = 0
        if train_config.get('stop_cond'):
            del train_config['stop_cond']
        frozen_trainer = ppo.PPOTrainer(config=train_config, env=train_config['env'])
        frozen_trainer.load_checkpoint(path)
        all_rewards, all_contracts = [], []
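        # Monte-Carlo sampling: each reset draws a contract parameter, which is then
        # scored by the total episode return the frozen policies achieve under it.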
        for _ in range(num_samples):
            env_obs = env_copy.reset()
            contract_param = env_copy.contract_param
            env_dones = {'__all__': False}
            if params_dict.get('joint'):
                active_agents = ['a0']
            else:
                active_agents = ['a' + str(i) for i in range(train_config['env_config']['num_agents'])]
            domain_steps = 0
            ep_rewards = 0
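            # Roll out the frozen policies until all agents are done or the horizon is hit.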
            while (not env_dones['__all__']) and domain_steps < train_config['horizon']:
                act_dict = {}
                for key in active_agents:
                    if params_dict['shared_policy']:
                        act_dict[key] = frozen_trainer.compute_single_action(
                            env_obs[key], policy_id='policy')
                    else:
                        act_dict[key] = frozen_trainer.compute_single_action(
                            env_obs[key], policy_id=key)
                env_obs, r, env_dones, _info = env_copy.step(act_dict)
                # Stop querying policies for agents that have finished the episode.
                for key in env_dones:
                    if env_dones[key] and key in active_agents:
                        active_agents.remove(key)
                domain_steps += 1
                for key in env_obs:
                    ep_rewards += r[key]
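            # One episode finished: log its return against the contract it was played under.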
            log_dict = {'ep_rewards': ep_rewards, 'contract_param': float(contract_param[0])}
            logger.simple_log(log_dict)
            all_rewards.append(ep_rewards)
            all_contracts.append(contract_param)
        logger.simple_log({'mean reward': np.mean(all_rewards), 'mean contract': np.mean(all_contracts),
                           'std reward': np.std(all_rewards), 'std contract': np.std(all_contracts)})
    time.sleep(300)  # give the asynchronous logger time to finish writing before exit
    print('Finished Stage 2')
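

# A minimal smoke-test sketch, not part of the original file: it shows one way to
# invoke run_solver directly. The params_dict keys simply mirror the lookups above,
# the checkpoint path is a placeholder, and _PrintLogger is a hypothetical stand-in
# exposing only the two methods run_solver actually calls.
if __name__ == '__main__':
    class _PrintLogger:
        def set_stage(self, stage):
            print(f'--- stage {stage} ---')

        def simple_log(self, metrics):
            print(metrics)

    params = {'solver_samples': 2, 'shared_policy': True, 'joint': False}
    run_solver(params,
               checkpoint_paths=['checkpoints/checkpoint-000100'],
               logger=_PrintLogger())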