-
Notifications
You must be signed in to change notification settings - Fork 5
/
run.py
145 lines (111 loc) · 4.68 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import random
import os
import warnings
import comet_ml
import supersuit as ss
import numpy as np
from gym import spaces
import torch
from torch.nn import Module
import hydra
from omegaconf import DictConfig
from src.envs import get_env
from src.envs import ObstoStateWrapper, pettingzoo_env_to_vec_env_v1, concat_vec_envs_v1, black_death_v3, PermuteObsWrapper, AddStateSpaceActMaskWrapper, CooperativeRewardsWrapper, ParallelEnv
from src.replay_buffer import ReplayBuffer, ReplayBufferImageObs
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)
def make_train_env(env_config):
    """Build the vectorized training environment described by *env_config*.

    marlgrid environments are wrapped per-thread (cooperative rewards,
    permuted observations, state/action-mask spaces) and batched with
    ParallelEnv; every other family goes through the supersuit
    pettingzoo -> gym vector pipeline.
    """
    if env_config.family == 'marlgrid':
        # One fully wrapped copy of the environment per rollout thread.
        wrapped = []
        for _ in range(env_config.rollout_threads):
            base = get_env(env_config.name, env_config.family, env_config.params)
            wrapped.append(
                AddStateSpaceActMaskWrapper(
                    PermuteObsWrapper(CooperativeRewardsWrapper(base))))
        return ParallelEnv(wrapped)

    env_class = get_env(env_config.name, env_config.family, env_config.params)
    env = env_class.parallel_env(**env_config.params)
    if env_config.continuous_action:
        env = ss.clip_actions_v0(env)
    is_starcraft = env_config.family == 'starcraft'
    if is_starcraft:
        # Starcraft agents can die mid-episode; black_death keeps them in the batch.
        env = black_death_v3(env)
    else:
        # Pad heterogeneous obs/action spaces to a common shape.
        env = ss.pad_observations_v0(env)
        env = ss.pad_action_space_v0(env)
    env = ObstoStateWrapper(env)
    env = pettingzoo_env_to_vec_env_v1(env, black_death=is_starcraft)
    return concat_vec_envs_v1(env, env_config.rollout_threads, num_cpus=1, base_class='gym')
def make_eval_env(env_config):
    """Build a single (non-vectorized) evaluation environment.

    Mirrors make_train_env's wrapper pipeline, but with one environment
    instance and no gym vector-env concatenation.
    """
    if env_config.family == 'marlgrid':
        # Note: no CooperativeRewardsWrapper here, unlike training.
        base = get_env(env_config.name, env_config.family, env_config.params)
        return ParallelEnv([AddStateSpaceActMaskWrapper(PermuteObsWrapper(base))])

    env_class = get_env(env_config.name, env_config.family, env_config.params)
    env = env_class.parallel_env(**env_config.params)
    if env_config.continuous_action:
        env = ss.clip_actions_v0(env)
    if env_config.family == 'starcraft':
        env = black_death_v3(env)
    else:
        # Pad heterogeneous obs/action spaces to a common shape.
        env = ss.pad_observations_v0(env)
        env = ss.pad_action_space_v0(env)
    return ObstoStateWrapper(env)
@hydra.main(config_path="configs/", config_name="config.yaml")
def main(cfg: DictConfig):
    """Entry point: seed RNGs, build envs, policy, buffer and runner, then train/evaluate.

    All configuration is supplied by hydra via *cfg*.
    """
    # Seed every RNG source for reproducibility.
    random.seed(cfg.seed)
    np.random.seed(cfg.seed)
    torch.manual_seed(cfg.seed)
    torch.backends.cudnn.deterministic = cfg.torch_deterministic

    use_cuda = torch.cuda.is_available() and cfg.cuda
    device = torch.device("cuda" if use_cuda else "cpu")

    train_envs = make_train_env(cfg.env)
    eval_env = make_eval_env(cfg.env)

    # Unwrap Dict/tuple containers down to a single per-agent space.
    obs_space = train_envs.observation_space
    if isinstance(obs_space, spaces.Dict):
        obs_space = obs_space['observation']
    elif isinstance(obs_space, tuple):
        obs_space = obs_space[0]

    act_space = train_envs.action_space
    if isinstance(act_space, tuple):
        act_space = act_space[0]

    image_conv = cfg.env.obs_type == 'image' and cfg.policy.params.type == 'conv'
    if image_conv:
        # With a conv encoder the global state is the concatenation of all
        # agents' encoded observations, so its space is synthesized here.
        state_space = spaces.Box(
            low=-float('inf'),
            high=float('inf'),
            shape=(cfg.policy.params.conv_out_size * cfg.n_agents,),
            dtype='float',
        )
    else:
        state_space = train_envs.state_space
        if isinstance(state_space, tuple):
            state_space = state_space[0]

    policy = hydra.utils.instantiate(
        cfg.policy,
        observation_space=obs_space,
        action_space=act_space,
        state_space=state_space,
        params=cfg.policy.params).to(device)

    # Image observations get a dedicated buffer (no precomputed state).
    if image_conv:
        buffer = ReplayBufferImageObs(obs_space, act_space, cfg.buffer, device)
    else:
        buffer = ReplayBuffer(obs_space, act_space, state_space, cfg.buffer, device)

    runner = hydra.utils.instantiate(
        cfg.runner,
        train_env=train_envs,
        eval_env=eval_env,
        env_family=cfg.env.family,
        policy=policy,
        buffer=buffer,
        params=cfg.runner.params,
        device=device)

    if not cfg.test_mode:
        runner.run()
    mean_rewards, std_rewards, mean_wins, std_wins = runner.evaluate()
    print(f"Eval Rewards: {mean_rewards} +- {std_rewards} | Eval Win Rate: {mean_wins} +- {std_wins}")

    train_envs.close()
    eval_env.close()
if __name__ == "__main__":
    # hydra.main injects the composed DictConfig; no arguments passed here.
    main()