-
Notifications
You must be signed in to change notification settings - Fork 3
/
train.py
363 lines (314 loc) · 14.8 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
# Adapted from CleanRL's discrete-action SAC (sac_atari.py); docs and experiment results for the original can be found at https://docs.cleanrl.dev/rl-algorithms/sac/#sac_ataripy
import os
import random
import time
from dataclasses import dataclass
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tyro
from stable_baselines3.common.atari_wrappers import ClipRewardEnv
from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.save_util import load_from_pkl, save_to_pkl
from torch.distributions.categorical import Categorical
from torch.utils.tensorboard import SummaryWriter
@dataclass
class Args:
    """Command-line configuration for this training script (parsed with ``tyro.cli``).

    All step-based quantities (``total_timesteps``, ``learning_starts``,
    ``update_frequency``, ``target_network_frequency``) are counted in
    iterations of the training loop, and each iteration steps every one of
    the ``n_envs`` environments once.
    """

    exp_name: str = os.path.basename(__file__)[: -len(".py")]
    """the name of this experiment"""
    seed: int = 1
    """seed of the experiment"""
    torch_deterministic: bool = True
    """if toggled, `torch.backends.cudnn.deterministic=False`"""
    cuda: bool = True
    """if toggled, cuda will be enabled by default"""
    track: bool = True
    """if toggled, this experiment will be tracked with Weights and Biases"""
    wandb_project_name: str = "U3"
    """the wandb's project name"""
    wandb_entity: str = None
    """the entity (team) of wandb's project"""
    capture_video: bool = False
    """whether to capture videos of the agent performances (check out `videos` folder)"""
    n_envs: int = 16
    """Num of parallel envs"""

    # Algorithm specific arguments
    total_timesteps: int = int(1e8)
    """total number of vectorized steps (each step advances all `n_envs` environments)"""
    buffer_size: int = int(1e6)
    """the replay memory buffer size"""
    gamma: float = 0.99
    """the discount factor gamma"""
    tau: float = 1.0
    """target smoothing coefficient (default: 1)"""
    batch_size: int = 64
    """the batch size of sample from the replay memory"""
    # int(...) keeps the default an int, matching the annotation (2e4 alone is a float).
    learning_starts: int = int(2e4)
    """vectorized step at which learning starts"""
    policy_lr: float = 3e-4
    """the learning rate of the policy network optimizer"""
    q_lr: float = 3e-4
    """the learning rate of the Q network optimizer"""
    update_frequency: int = 4
    """the frequency of training updates"""
    target_network_frequency: int = 8000
    """the frequency of updates for the target networks"""
    alpha: float = 0.2
    """Entropy regularization coefficient."""
    autotune: bool = True
    """automatic tuning of the entropy coefficient"""
    target_entropy_scale: float = 0.89
    """coefficient for scaling the autotune entropy target"""
def make_env(seed, idx, capture_video, run_name):
    """Build a zero-argument factory that creates one wrapped U3 environment.

    Args:
        seed: seed applied to the environment's action space.
        idx: index of this environment; also passed to the env as ``worker_id``.
        capture_video: when truthy, the environment with ``idx == 0`` records videos.
        run_name: name used for the ``videos/<run_name>`` output directory.

    Returns:
        A callable (suitable for ``gym.vector.*VectorEnv``) that builds the env,
        wraps it with episode statistics and reward clipping, and seeds its
        action space.
    """

    def thunk():
        u3_kwargs = {
            "file_name": 'unity/Builds/LinuxTraining/XLand',
            "worker_id": idx,
            "disable_env_checker": True,
            "camera_width": 84,
            "camera_height": 84,
            "world_folder": '/network/scratch/o/omar.younis/u3-datasets/medium_low',
            "rule_folder": '/network/scratch/o/omar.younis/u3-datasets/middle_few',
        }
        env = gym.make("u3gym:U3GymEnv-v0", **u3_kwargs)

        # Only the first worker records video, to avoid duplicate recordings.
        should_record = capture_video and idx == 0
        if should_record:
            env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")

        # Statistics are recorded before clipping, so logged returns are unclipped.
        for wrapper_cls in (gym.wrappers.RecordEpisodeStatistics, ClipRewardEnv):
            env = wrapper_cls(env)

        env.action_space.seed(seed)
        return env

    return thunk
def layer_init(layer, bias_const=0.0):
    """Initialise ``layer`` in place and return it.

    The weight is drawn with PyTorch's Kaiming (He) normal initialiser using
    its default arguments, spelled out below; the bias is filled with
    ``bias_const``.
    """
    weight, bias = layer.weight, layer.bias
    nn.init.kaiming_normal_(weight, a=0, mode="fan_in", nonlinearity="leaky_relu")
    nn.init.constant_(bias, bias_const)
    return layer
# ALGO LOGIC: initialize agent here:
# NOTE: Sharing a CNN encoder between Actor and Critics is not recommended for SAC without stopping actor gradients
# See the SAC+AE paper https://arxiv.org/abs/1910.01741 for more info
# TL;DR The actor's gradients mess up the representation when using a joint encoder
class SoftQNetwork(nn.Module):
    """Soft Q-network: a convolutional encoder plus an MLP head giving one Q-value per discrete action.

    NOTE(review): observations are passed to the convolutions unscaled, and
    the first dimension of the observation shape is used as the input-channel
    count. Confirm that the U3 env emits channel-first observations in a
    range suitable for the network (no /255 normalisation happens here).
    """

    def __init__(self, envs):
        super().__init__()
        in_shape = envs.single_observation_space.shape
        n_actions = envs.single_action_space.n
        self.conv = nn.Sequential(
            layer_init(nn.Conv2d(in_shape[0], 32, kernel_size=8, stride=4)),
            nn.ReLU(),
            layer_init(nn.Conv2d(32, 64, kernel_size=4, stride=2)),
            nn.ReLU(),
            layer_init(nn.Conv2d(64, 64, kernel_size=3, stride=1)),
            nn.Flatten(),
        )
        # Run a dummy single-sample batch through the encoder to size the first linear layer.
        with torch.inference_mode():
            probe = torch.zeros(1, *in_shape)
            flat_features = self.conv(probe).shape[1]
        self.fc1 = layer_init(nn.Linear(flat_features, 512))
        self.fc_q = layer_init(nn.Linear(512, n_actions))

    def forward(self, x):
        """Return the Q-value of every action for each observation in the batch ``x``."""
        # The encoder ends with Flatten (no activation), so the ReLU is applied here.
        features = F.relu(self.conv(x))
        hidden = F.relu(self.fc1(features))
        return self.fc_q(hidden)
class Actor(nn.Module):
    """Categorical policy: a convolutional encoder plus an MLP head producing one logit per discrete action.

    NOTE(review): like the critic, observations are used unscaled and the
    first observation dimension is treated as channels. Confirm this matches
    the U3 env's observation format.
    """

    def __init__(self, envs):
        super().__init__()
        in_shape = envs.single_observation_space.shape
        n_actions = envs.single_action_space.n
        self.conv = nn.Sequential(
            layer_init(nn.Conv2d(in_shape[0], 32, kernel_size=8, stride=4)),
            nn.ReLU(),
            layer_init(nn.Conv2d(32, 64, kernel_size=4, stride=2)),
            nn.ReLU(),
            layer_init(nn.Conv2d(64, 64, kernel_size=3, stride=1)),
            nn.Flatten(),
        )
        # Size the first linear layer from a dummy forward pass through the encoder.
        with torch.inference_mode():
            probe = torch.zeros(1, *in_shape)
            flat_features = self.conv(probe).shape[1]
        self.fc1 = layer_init(nn.Linear(flat_features, 512))
        self.fc_logits = layer_init(nn.Linear(512, n_actions))

    def forward(self, x):
        """Return unnormalised action logits for each observation in the batch ``x``."""
        features = F.relu(self.conv(x))
        hidden = F.relu(self.fc1(features))
        return self.fc_logits(hidden)

    def get_action(self, x):
        """Sample actions for the batch ``x``.

        Returns:
            A tuple ``(action, log_prob, action_probs)``: the sampled action
            indices, the log-probabilities of every action (``log_softmax``
            over dim 1), and the probabilities of every action.
        """
        logits = self(x)
        dist = Categorical(logits=logits)
        sampled = dist.sample()
        # The full per-action distribution is returned so the discrete SAC
        # losses can take exact expectations over actions instead of sampling.
        return sampled, F.log_softmax(logits, dim=1), dist.probs
if __name__ == "__main__":
    # Entry point: parse Args, build the vectorized U3 envs, and run discrete
    # SAC (twin soft-Q critics + categorical actor, optional entropy autotuning).
    import stable_baselines3 as sb3

    if sb3.__version__ < "2.0":
        raise ValueError(
            """Ongoing migration: run the following command to install the new dependencies:
poetry run pip install "stable_baselines3==2.0.0a1" "gymnasium[atari,accept-rom-license]==0.28.1" "ale-py==0.8.1"
"""
        )

    args = tyro.cli(Args)
    run_name = f"U3__{args.exp_name}__{args.seed}__{int(time.time())}"
    if args.track:
        import wandb

        wandb.init(
            project=args.wandb_project_name,
            entity=args.wandb_entity,
            sync_tensorboard=True,
            config=vars(args),
            name=run_name,
            monitor_gym=True,
            save_code=True,
        )
    writer = SummaryWriter(f"runs/{run_name}")
    writer.add_text(
        "hyperparameters",
        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
    )

    # TRY NOT TO MODIFY: seeding
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = args.torch_deterministic

    device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")

    # env setup: a single env runs in-process; several run as subprocesses
    VecClass = gym.vector.SyncVectorEnv if args.n_envs == 1 else gym.vector.AsyncVectorEnv
    envs = VecClass([
        make_env(args.seed, idx, args.capture_video, run_name)
        for idx in range(args.n_envs)
    ])
    assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"

    actor = Actor(envs).to(device)
    qf1 = SoftQNetwork(envs).to(device)
    qf2 = SoftQNetwork(envs).to(device)
    qf1_target = SoftQNetwork(envs).to(device)
    qf2_target = SoftQNetwork(envs).to(device)
    qf1_target.load_state_dict(qf1.state_dict())
    qf2_target.load_state_dict(qf2.state_dict())
    # TRY NOT TO MODIFY: eps=1e-4 increases numerical stability
    q_optimizer = optim.Adam(list(qf1.parameters()) + list(qf2.parameters()), lr=args.q_lr, eps=1e-4)
    actor_optimizer = optim.Adam(list(actor.parameters()), lr=args.policy_lr, eps=1e-4)

    # Automatic entropy tuning: target is a fraction of the max entropy log(n_actions)
    if args.autotune:
        target_entropy = -args.target_entropy_scale * torch.log(1 / torch.tensor(envs.single_action_space.n))
        log_alpha = torch.zeros(1, requires_grad=True, device=device)
        alpha = log_alpha.exp().item()
        a_optimizer = optim.Adam([log_alpha], lr=args.q_lr, eps=1e-4)
    else:
        alpha = args.alpha

    rb = ReplayBuffer(
        args.buffer_size,
        envs.single_observation_space,
        envs.single_action_space,
        device,
        handle_timeout_termination=False,
        n_envs=args.n_envs,
    )
    start_time = time.time()

    # TRY NOT TO MODIFY: start the game
    obs, _ = envs.reset(seed=args.seed)
    # NOTE: each iteration steps all `n_envs` environments, so `global_step`
    # (and every step-based schedule below) counts vectorized steps, not
    # individual environment transitions.
    for global_step in range(args.total_timesteps):
        # ALGO LOGIC: put action logic here
        if global_step < args.learning_starts:
            actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])
        else:
            # Acting needs no gradients; avoid building an autograd graph every step.
            with torch.no_grad():
                actions, _, _ = actor.get_action(torch.Tensor(obs).to(device))
            actions = actions.cpu().numpy()

        # TRY NOT TO MODIFY: execute the game and log data.
        next_obs, rewards, terminations, truncations, infos = envs.step(actions)

        # Record episode statistics for every env that finished this step.
        # Vector envs put None in `final_info` for envs that did not finish,
        # so those entries must be skipped before the membership test.
        if "final_info" in infos:
            for info in infos["final_info"]:
                if info is None or "episode" not in info:
                    continue
                print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
                writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
                writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)

        # TRY NOT TO MODIFY: save data to replay buffer; handle `final_observation`
        # Truncated envs were auto-reset, so their true next observation lives in infos.
        real_next_obs = next_obs.copy()
        for idx, trunc in enumerate(truncations):
            if trunc:
                real_next_obs[idx] = infos["final_observation"][idx]
        rb.add(obs, real_next_obs, actions, rewards, terminations, infos)

        # TRY NOT TO MODIFY: CRUCIAL step easy to overlook
        obs = next_obs

        # ALGO LOGIC: training.
        if global_step > args.learning_starts:
            if global_step % args.update_frequency == 0:
                data = rb.sample(args.batch_size)
                # CRITIC training
                with torch.no_grad():
                    _, next_state_log_pi, next_state_action_probs = actor.get_action(data.next_observations)
                    qf1_next_target = qf1_target(data.next_observations)
                    qf2_next_target = qf2_target(data.next_observations)
                    # we can use the action probabilities instead of MC sampling to estimate the expectation
                    min_qf_next_target = next_state_action_probs * (
                        torch.min(qf1_next_target, qf2_next_target) - alpha * next_state_log_pi
                    )
                    # adapt Q-target for discrete Q-function
                    min_qf_next_target = min_qf_next_target.sum(dim=1)
                    next_q_value = data.rewards.flatten() + (1 - data.dones.flatten()) * args.gamma * (min_qf_next_target)

                # use Q-values only for the taken actions
                qf1_values = qf1(data.observations)
                qf2_values = qf2(data.observations)
                qf1_a_values = qf1_values.gather(1, data.actions.long()).view(-1)
                qf2_a_values = qf2_values.gather(1, data.actions.long()).view(-1)
                qf1_loss = F.mse_loss(qf1_a_values, next_q_value)
                qf2_loss = F.mse_loss(qf2_a_values, next_q_value)
                qf_loss = qf1_loss + qf2_loss

                q_optimizer.zero_grad()
                qf_loss.backward()
                q_optimizer.step()

                # ACTOR training
                _, log_pi, action_probs = actor.get_action(data.observations)
                with torch.no_grad():
                    qf1_values = qf1(data.observations)
                    qf2_values = qf2(data.observations)
                    min_qf_values = torch.min(qf1_values, qf2_values)
                # no need for reparameterization, the expectation can be calculated for discrete actions
                actor_loss = (action_probs * ((alpha * log_pi) - min_qf_values)).mean()

                actor_optimizer.zero_grad()
                actor_loss.backward()
                actor_optimizer.step()

                if args.autotune:
                    # re-use action probabilities for temperature loss
                    alpha_loss = (action_probs.detach() * (-log_alpha.exp() * (log_pi + target_entropy).detach())).mean()

                    a_optimizer.zero_grad()
                    alpha_loss.backward()
                    a_optimizer.step()
                    alpha = log_alpha.exp().item()

            # update the target networks
            if global_step % args.target_network_frequency == 0:
                for param, target_param in zip(qf1.parameters(), qf1_target.parameters()):
                    target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data)
                for param, target_param in zip(qf2.parameters(), qf2_target.parameters()):
                    target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data)

            # 100 is a multiple of the default update_frequency (4), so the loss
            # tensors below were produced by an update on this same step.
            if global_step % 100 == 0:
                writer.add_scalar("losses/qf1_values", qf1_a_values.mean().item(), global_step)
                writer.add_scalar("losses/qf2_values", qf2_a_values.mean().item(), global_step)
                writer.add_scalar("losses/qf1_loss", qf1_loss.item(), global_step)
                writer.add_scalar("losses/qf2_loss", qf2_loss.item(), global_step)
                writer.add_scalar("losses/qf_loss", qf_loss.item() / 2.0, global_step)
                writer.add_scalar("losses/actor_loss", actor_loss.item(), global_step)
                writer.add_scalar("losses/alpha", alpha, global_step)
                sps = int(global_step / (time.time() - start_time))
                print("SPS:", sps)
                writer.add_scalar("charts/SPS", sps, global_step)
                if args.autotune:
                    writer.add_scalar("losses/alpha_loss", alpha_loss.item(), global_step)

            # Periodic checkpoint of networks, optimizers and the replay buffer.
            if global_step % 200000 == 0:
                path = f"/network/scratch/o/omar.younis/U3/{run_name}"
                os.makedirs(path, exist_ok=True)
                checkpoint = {
                    "actor": actor.state_dict(),
                    "qf1": qf1.state_dict(),
                    "qf2": qf2.state_dict(),
                    "qf1_target": qf1_target.state_dict(),
                    "qf2_target": qf2_target.state_dict(),
                    "actor_optimizer": actor_optimizer.state_dict(),
                    "q_optimizer": q_optimizer.state_dict(),
                }
                if args.autotune:
                    # The entropy temperature and its optimizer are needed to resume autotuning.
                    checkpoint["log_alpha"] = log_alpha.detach().cpu()
                    checkpoint["a_optimizer"] = a_optimizer.state_dict()
                torch.save(checkpoint, f"{path}/model.pth")
                save_to_pkl(f"{path}/rb", rb, 1)

    envs.close()
    writer.close()