-
Notifications
You must be signed in to change notification settings - Fork 0
/
TD3.py
575 lines (429 loc) · 22.2 KB
/
TD3.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
"""
Twin Delayed Deep Deterministic Policy Gradient (TD3) Algorithm.
Reference:
- [TD3] Addressing Function Approximation Error in Actor-Critic Methods: http://proceedings.mlr.press/v80/fujimoto18a.html
- OpenAI Spinning Up's TD3 implementation: https://spinningup.openai.com/en/latest/algorithms/td3.html
"""
import numpy as np
from pysc2.lib import actions as ACTIONS
import torch
import torch.nn as nn
from torch.nn.functional import gumbel_softmax
from torch.distributions import Categorical
import os
import copy
from utils.ReplayBuffer import DDPGReplayBuffer
class TD3:
    """
    The Twin Delayed Deep Deterministic Policy Gradient (TD3) Agent for a pysc2 ``SC2Env``.

    Observations are converted to dicts of numpy arrays with keys ``minimap``, ``screen`` and ``non_spatial``
    (a multi-hot mask of the available actions). Actions are dicts with keys ``function_id`` (one-hot over the
    action space) and ``coordinate1`` / ``coordinate2`` (one-hot maps of shape (1, map_size, map_size)).
    The actor maps an observation dict to unnormalized scores under the same action keys; each critic maps an
    (observation, action) pair to Q-values.
    """

    # TODO: abstract a base class
    def __init__(self, env, actor, critic1, critic2, replay_buffer_size=10000, device=None, actor_lr=0.001,
                 critic_lr=0.001, gamma=0.99, tau=0.005, target_noise=0.2, noise_clip=0.5, batch_size=32,
                 warmup_steps=1000, update_steps=50, actor_delay_steps=2, soft_update_steps=10, map_size=64, seed=0,
                 action_space=len(ACTIONS.FUNCTIONS), save_path="./Saves/", model_name='UnnamedModel', save_epochs=100):
        """
        Initialization.
        :param env: the SC2Env environment instance
        :param actor: the actor network
        :param critic1: the first critic network (Q1)
        :param critic2: the second critic network (Q2)
        :param replay_buffer_size: the capacity of the replay buffer
        :param device: the training device; defaults to CUDA when available, otherwise CPU
        :param actor_lr: the learning rate of the actor
        :param critic_lr: the learning rate of both critics
        :param gamma: the discount factor
        :param tau: the soft update factor
        :param target_noise: standard deviation of the target-policy smoothing noise
                             (stored but currently unused: target policy smoothing is not implemented yet)
        :param noise_clip: limit for the absolute value of the smoothing noise (currently unused, see above)
        :param batch_size: batch size sampled from the replay buffer per optimization step
        :param warmup_steps: number of environment steps collected before any optimization happens
        :param update_steps: used twice: optimization runs every `update_steps` environment steps,
                             and each such round performs `update_steps` optimization steps
        :param actor_delay_steps: the actor is updated once every `actor_delay_steps` critic updates
        :param soft_update_steps: the target networks are soft-updated once every `soft_update_steps`
                                  optimization steps
        :param map_size: the side length of the (square) screen/minimap feature layers
        :param seed: a random seed for torch and numpy
        :param action_space: the length of the action space
        :param save_path: the root directory to save models into
        :param model_name: a name for the training model (a sub-directory of `save_path`)
        :param save_epochs: a check point is saved every `save_epochs` epochs
        """
        # set random seed
        torch.manual_seed(seed)
        np.random.seed(seed)
        # initialize the environment
        self.env = env
        # initialize the device
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
        # initialize the policy actor and the target actor
        self.actor = actor.to(self.device)
        self.target_actor = copy.deepcopy(actor).to(self.device)
        # * initialize the policy critics (Q1 and Q2) and their target copies
        self.critic_1 = critic1.to(self.device)
        self.critic_2 = critic2.to(self.device)
        self.target_critic_1 = copy.deepcopy(critic1).to(self.device)
        self.target_critic_2 = copy.deepcopy(critic2).to(self.device)
        # the target networks are only used for evaluation
        self.target_actor.eval()
        self.target_critic_1.eval()
        self.target_critic_2.eval()
        # * freeze target networks with respect to optimizers (they change only via soft updates)
        for p in self.target_critic_1.parameters():
            p.requires_grad = False
        for p in self.target_critic_2.parameters():
            p.requires_grad = False
        for p in self.target_actor.parameters():
            p.requires_grad = False
        # * initialize three optimizers using Adam
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), actor_lr)
        self.critic_optimizer_1 = torch.optim.Adam(self.critic_1.parameters(), critic_lr)
        self.critic_optimizer_2 = torch.optim.Adam(self.critic_2.parameters(), critic_lr)
        # total number of environment steps taken so far
        self.iteration = 0
        # ! how many environment-steps between update rounds AND how many optimization steps per round
        self.update_steps = update_steps
        # ! how many optimization steps between two actor updates
        self.actor_delay_steps = actor_delay_steps
        # initialize the replay buffer
        self.replay_buffer = DDPGReplayBuffer(replay_buffer_size)
        # the number of possible actions
        self.action_space = action_space
        # the side length of the feature layers
        self.map_size = map_size
        # the batch size of training
        self.batch_size = batch_size
        # the discount factor
        self.gamma = gamma
        # TODO: target policy smoothing is not implemented yet; these two values are currently unused
        self.target_noise = target_noise
        self.noise_clip = noise_clip
        # for the soft update
        self.tau = tau
        self.soft_update_steps = soft_update_steps
        # for training warm-up steps
        self.warmup_steps = warmup_steps
        # model and training information saved location
        self.model_name = model_name
        self.save_path = os.path.join(save_path, model_name)
        os.makedirs(self.save_path, exist_ok=True)
        self.check_point_save_epochs = save_epochs
        # the cumulative score of every finished epoch
        self.epoch_rewards = []

    def _state_2_obs_np(self, state):
        """
        Transform an SC2Env-returned state (TimeStep) into an obs dict with numpy values.
        :param state: a state from SC2Env
        :return: a dict with keys 'minimap', 'screen' and 'non_spatial' (multi-hot of available actions)
        """
        avail_actions = np.zeros((self.action_space,), dtype='float32')
        avail_actions[state.observation['available_actions']] = 1
        obs_np = {'minimap': state.observation.feature_minimap,
                  'screen': state.observation.feature_screen,
                  'non_spatial': avail_actions}
        # the channel counts (11 minimap / 27 screen layers) are hard-coded expectations;
        # the spatial size and the action-space length follow the agent's configuration
        assert obs_np['minimap'].shape == (11, self.map_size, self.map_size), "obs_np.minimap is in the wrong shape"
        assert obs_np['screen'].shape == (27, self.map_size, self.map_size), "obs_np.screen is in the wrong shape"
        assert obs_np['non_spatial'].shape == (self.action_space,), "obs_np.non_spatial is in the wrong shape"
        return obs_np

    def _function_call_2_action_np(self, function_call):
        """
        Transform an SC2Env-interacted FunctionCall into an action dict with numpy values.

        The layout matches what the actor produces: `function_id` is one-hot over the action space, and each
        spatial argument becomes a one-hot map of shape (1, map_size, map_size) indexed as [0, row, column].
        :param function_call: a FunctionCall as built by `_action_ts_2_function_call`
        :return: a dict with keys 'function_id', 'coordinate1' and 'coordinate2'
        """
        function_id = np.zeros(shape=(self.action_space,), dtype='float32')
        function_id[function_call.function] = 1
        # two independent arrays: `[array] * 2` would alias one array twice, so marking coordinate1
        # would silently mark coordinate2 as well
        coordinates = [np.zeros(shape=(1, self.map_size, self.map_size), dtype='float32') for _ in range(2)]
        c_i = 0
        for arg in function_call.arguments:
            # `_action_ts_2_function_call` encodes every non-spatial argument (e.g. `queued`) as [0]
            if arg != [0]:
                x, y = arg[0], arg[1]
                # the point is (x, y) = (column, row), derived from a row-major flattening of the
                # coordinate map, so the matching one-hot cell is [0, row, column]
                coordinates[c_i][0, y, x] = 1
                c_i += 1
        action_np = {'function_id': function_id,
                     'coordinate1': coordinates[0],
                     'coordinate2': coordinates[1]}
        assert action_np['function_id'].shape == (self.action_space,), "action_np.function_id is in the wrong shape"
        assert action_np['coordinate1'].shape == (1, self.map_size, self.map_size), \
            "action_np.coordinate1 is in the wrong shape"
        assert action_np['coordinate2'].shape == (1, self.map_size, self.map_size), \
            "action_np.coordinate2 is in the wrong shape"
        return action_np

    def _action_ts_2_function_call(self, action_ts, available_actions):
        """
        Sample a FunctionCall from an actor output (dict of tensors) to interact with SC2Env.
        :param action_ts: the actor output for a single observation (batch of 1)
        :param available_actions: multi-hot mask of the actions available at this step, shape (1, action_space)
        :return: an SC2Env-interacted FunctionCall instance
        """
        probable_function_id = torch.softmax(action_ts['function_id'], dim=-1).detach()
        # mask out the unavailable actions; Categorical renormalizes the remaining probabilities
        probable_function_id = probable_function_id * available_actions
        if probable_function_id.sum(1) == 0:
            # ? all available actions got zero probability: fall back to a uniform choice among them
            distribution = Categorical(available_actions)
        else:
            distribution = Categorical(probable_function_id)
        # sample the function id from the distribution
        function_id = distribution.sample().item()
        # sample each coordinate as a flat index into the row-major flattened coordinate map
        coordinate_position1 = torch.softmax(action_ts['coordinate1'].view(1, -1), dim=-1).detach()
        coordinate_position1 = Categorical(coordinate_position1).sample().item()
        coordinate_position2 = torch.softmax(action_ts['coordinate2'].view(1, -1), dim=-1).detach()
        coordinate_position2 = Categorical(coordinate_position2).sample().item()
        # flat index -> (x, y) = (column, row)
        positions = [[int(coordinate_position1 % self.map_size), int(coordinate_position1 // self.map_size)],
                     [int(coordinate_position2 % self.map_size), int(coordinate_position2 // self.map_size)]]
        # put the position coordinates into the args list
        args = []
        number_of_arg = 0
        for arg in ACTIONS.FUNCTIONS[function_id].args:
            if arg.name in ['screen', 'screen2', 'minimap']:
                args.append(positions[number_of_arg])
                number_of_arg += 1
            else:
                # for now, other kinds of arguments (such as `queued`) get [0] by default
                args.append([0])
        return ACTIONS.FunctionCall(function_id, args)

    def _obs_np_2_obs_ts(self, obs_np):
        """
        Transform an obs dict with numpy values into a batch-of-one obs dict of float32 tensors on the device.
        :param obs_np: an obs dict with numpy values
        :return: an obs dict with tensor values, each with a leading batch dimension of 1
        """
        return {key: torch.from_numpy(np.expand_dims(value.astype('float32'), 0)).to(self.device)
                for key, value in obs_np.items()}

    def _gumbel_softmax(self, x):
        """
        Apply a hard (one-hot, straight-through) Gumbel-softmax to `x`.

        For 4-D inputs the softmax runs over all non-batch elements at once (one one-hot cell per sample);
        otherwise it runs over the last dimension.
        :param x: input tensor
        :return: a one-hot tensor with the same shape as `x`
        """
        shape = x.shape
        if len(shape) == 4:
            y = gumbel_softmax(x.contiguous().view(shape[0], -1), hard=True, dim=-1)
            return y.contiguous().view(shape)
        return gumbel_softmax(x, hard=True, dim=-1)

    def sample_batch(self):
        """
        Sample a batch of transitions from the replay buffer and transform them into float32 tensors.
        :return: (obs, action, obs_next, reward, done); the dicts hold stacked tensors with a leading
                 batch dimension, reward and done are 1-D tensors of length batch_size
        """
        transitions = self.replay_buffer.sample(self.batch_size)
        obs_ts = {'minimap': [], 'screen': [], 'non_spatial': []}
        action_ts = {'function_id': [], 'coordinate1': [], 'coordinate2': []}
        reward_ts = []
        obs_next_ts = {'minimap': [], 'screen': [], 'non_spatial': []}
        done_ts = []
        for transition in transitions:
            for key, value in transition.obs.items():
                obs_ts[key].append(torch.as_tensor(value, dtype=torch.float32))
            for key, value in transition.action.items():
                action_ts[key].append(torch.as_tensor(value, dtype=torch.float32))
            for key, value in transition.obs_next.items():
                obs_next_ts[key].append(torch.as_tensor(value, dtype=torch.float32))
            reward_ts.append(torch.as_tensor(transition.reward, dtype=torch.float32))
            done_ts.append(torch.as_tensor(1 if transition.done else 0, dtype=torch.float32))
        for key in obs_ts.keys():
            obs_ts[key] = torch.stack(obs_ts[key], dim=0).to(self.device)
            obs_next_ts[key] = torch.stack(obs_next_ts[key], dim=0).to(self.device)
        for key in action_ts.keys():
            action_ts[key] = torch.stack(action_ts[key], dim=0).to(self.device)
        reward_ts = torch.stack(reward_ts).to(self.device)
        done_ts = torch.stack(done_ts).to(self.device)
        return obs_ts, action_ts, obs_next_ts, reward_ts, done_ts

    def select_action_from_obs_np(self, obs_np):
        """
        Sample a FunctionCall (with arguments) from the actor for one observation.
        :param obs_np: an obs dict with numpy values, as returned by `_state_2_obs_np`
        :return: a pysc2 FunctionCall as an action
        """
        obs_ts = self._obs_np_2_obs_ts(obs_np)
        # inference only: no autograd graph is needed for acting
        with torch.no_grad():
            action_ts = self.actor(obs_ts)
        available_actions_now = obs_ts['non_spatial']
        return self._action_ts_2_function_call(action_ts, available_actions_now)

    def soft_update(self, target, source, tau):
        """
        Polyak-average `source` into `target`: target = (1 - tau) * target + tau * source.
        Use tau = 1 for a hard copy.
        :param target: target network
        :param source: original training network
        :param tau: soft update factor, usually near to 0
        """
        for target_param, source_param in zip(target.parameters(), source.parameters()):
            target_param.data.copy_(target_param.data * (1 - tau) + source_param.data * tau)

    def optimize(self, iteration):
        """
        Sample a batch from the replay buffer and perform one TD3 optimization step.

        Both critics are updated on every call. The actor is updated only when `iteration` is a multiple
        of `actor_delay_steps`, and the target networks are soft-updated only when `iteration` is a
        multiple of `soft_update_steps`.
        :param iteration: the index of this optimization step within the current round of updates
        """
        obs_0_ts, action_0_ts, obs_1_ts, reward, done = self.sample_batch()

        # ! optimize the critics
        # flatten to one Q-value per sample: `reward` / `done` have shape (B,), and a (B, 1) critic output
        # would otherwise silently broadcast the loss to (B, B)
        # (assumes each critic returns exactly one Q-value per sample -- TODO confirm against the critic)
        q_value_1 = self.critic_1(obs_0_ts, action_0_ts).reshape(-1)
        q_value_2 = self.critic_2(obs_0_ts, action_0_ts).reshape(-1)
        with torch.no_grad():
            actor_action_by_obs_1 = self.target_actor(obs_1_ts)
            action_1_ts = {key: self._gumbel_softmax(value) for key, value in actor_action_by_obs_1.items()}
            # TODO: target policy smoothing (clipped noise using self.target_noise / self.noise_clip)
            # * clipped double-Q target: the minimum of both target critics
            q_value_1_target = self.target_critic_1(obs_1_ts, action_1_ts).reshape(-1)
            q_value_2_target = self.target_critic_2(obs_1_ts, action_1_ts).reshape(-1)
            q_value_target = torch.min(q_value_1_target, q_value_2_target)
            backup = reward + self.gamma * (1 - done) * q_value_target
        q_1_loss = ((q_value_1 - backup) ** 2).mean()
        q_2_loss = ((q_value_2 - backup) ** 2).mean()
        loss_critic = q_1_loss + q_2_loss
        self.critic_optimizer_1.zero_grad()
        self.critic_optimizer_2.zero_grad()
        loss_critic.backward()
        nn.utils.clip_grad_norm_(self.critic_1.parameters(), 0.5)
        nn.utils.clip_grad_norm_(self.critic_2.parameters(), 0.5)
        self.critic_optimizer_1.step()
        self.critic_optimizer_2.step()

        # ! optimize the actor (delayed)
        if iteration % self.actor_delay_steps == 0:
            # * freeze the Q-networks so no gradients are computed for them during the policy step
            for p in self.critic_1.parameters():
                p.requires_grad = False
            for p in self.critic_2.parameters():
                p.requires_grad = False
            actor_action_by_obs_0 = self.actor(obs_0_ts)
            action_0_predicted_ts = {key: self._gumbel_softmax(value)
                                     for key, value in actor_action_by_obs_0.items()}
            q_1_max = self.critic_1(obs_0_ts, action_0_predicted_ts)
            q_1_max = -1 * q_1_max.mean()
            # L2 penalty on the actor weights; must start from zero (torch.FloatTensor(1) is uninitialized)
            l2_reg = torch.zeros(1, device=self.device)
            for param in self.actor.parameters():
                l2_reg = l2_reg + param.norm(2)
            loss_actor = q_1_max + torch.squeeze(l2_reg) * 1e-3
            self.actor_optimizer.zero_grad()
            loss_actor.backward()
            nn.utils.clip_grad_norm_(self.actor.parameters(), 0.5)
            self.actor_optimizer.step()
            for p in self.critic_1.parameters():
                p.requires_grad = True
            for p in self.critic_2.parameters():
                p.requires_grad = True

        # update the target networks
        if iteration % self.soft_update_steps == 0:
            self.soft_update(self.target_actor, self.actor, self.tau)
            self.soft_update(self.target_critic_1, self.critic_1, self.tau)
            self.soft_update(self.target_critic_2, self.critic_2, self.tau)

    def learn(self, epochs=1000):
        """
        Run the training loop.

        Each epoch plays one episode. Every transition is stored in the replay buffer; after `warmup_steps`
        environment steps, every `update_steps` environment steps trigger `update_steps` optimization steps.
        The best-scoring model, periodic check points and the final model are saved, and the environment is
        closed at the end.
        :param epochs: number of epochs (episodes) to train
        """
        # start below any possible score so the first finished episode always produces a 'best' model
        best_epoch_reward = float('-inf')
        best_epoch_reward_time = 0
        for epoch in range(epochs):
            obs_np = self._state_2_obs_np(self.env.reset()[0])
            while True:
                function_call = self.select_action_from_obs_np(obs_np)
                state_next = self.env.step(actions=[function_call])[0]
                # TODO: may need an action noise
                action_np = self._function_call_2_action_np(function_call)
                obs_next_np = self._state_2_obs_np(state_next)
                # the reward earned by `function_call` is carried by the TimeStep returned from `env.step`
                self.replay_buffer.store(obs_np, action_np, obs_next_np, state_next.reward, state_next.last())
                self.iteration += 1
                # ! THIS IS THE OVERALL CONDITION TO START TRAINING
                # * optimize the model after the warmup steps
                if self.iteration > self.warmup_steps and self.iteration % self.update_steps == 0:
                    for opti in range(self.update_steps):
                        self.optimize(opti)
                if state_next.last():
                    epoch_reward = state_next.observation['score_cumulative'][0]
                    self.epoch_rewards.append(epoch_reward)
                    # save the best model
                    if epoch_reward > best_epoch_reward:
                        self.save_models(token='best')
                        best_epoch_reward = epoch_reward
                        best_epoch_reward_time = 1
                    elif epoch_reward == best_epoch_reward:
                        best_epoch_reward_time += 1
                    print("Epoch: \033[34m{}\033[0m, epoch rewards:\033[32m{}\033[0m, best rewards: \033[35m{}\033[0m "
                          "with \033[33m{}\033[0m times".format(epoch + 1,
                                                                epoch_reward,
                                                                best_epoch_reward,
                                                                best_epoch_reward_time))
                    break
                obs_np = copy.deepcopy(obs_next_np)
            # save a check-point
            if (epoch + 1) % self.check_point_save_epochs == 0:
                self.save_models(token="{}".format(epoch + 1))
        # before the training completes, update the target networks once again
        self.soft_update(self.target_actor, self.actor, self.tau)
        self.soft_update(self.target_critic_1, self.critic_1, self.tau)
        self.soft_update(self.target_critic_2, self.critic_2, self.tau)
        self.save_models(token='final')
        self.env.close()
        print("Training Completed!")

    def save_models(self, token=''):
        """
        Save the target actor and critic networks plus the recorded epoch rewards.
        Files are written to `<save_path>/<token>/`.
        :param token: a token to identify the model
        """
        save_path = os.path.join(self.save_path, token)
        os.makedirs(save_path, exist_ok=True)
        torch.save(self.target_actor.state_dict(), os.path.join(save_path, 'actor.pt'))
        torch.save(self.target_critic_1.state_dict(), os.path.join(save_path, 'critic_1.pt'))
        torch.save(self.target_critic_2.state_dict(), os.path.join(save_path, 'critic_2.pt'))
        np.save(os.path.join(save_path, "epoch_rewards.npy"), self.epoch_rewards)
        print('Model and Information with token-{} saved successfully'.format(token))

    def load_models(self, token=''):
        """
        Load saved networks into the actor and critics, then hard-copy them into the target networks.
        :param token: the token to identify the model (a sub-directory of `save_path`)
        """
        model_path = os.path.join(self.save_path, token)
        # map onto the current device so checkpoints saved on GPU also load on a CPU-only machine
        self.actor.load_state_dict(torch.load(os.path.join(model_path, 'actor.pt'), map_location=self.device))
        self.critic_1.load_state_dict(torch.load(os.path.join(model_path, 'critic_1.pt'), map_location=self.device))
        self.critic_2.load_state_dict(torch.load(os.path.join(model_path, 'critic_2.pt'), map_location=self.device))
        self.soft_update(self.target_actor, self.actor, 1)
        self.soft_update(self.target_critic_1, self.critic_1, 1)
        self.soft_update(self.target_critic_2, self.critic_2, 1)
        print('Models loaded successfully')

    def restore(self, token, episodes=1000, restore_token=1):
        """
        Restore a saved model and continue training.
        :param token: the token to identify the model
        :param episodes: number of episodes to continue training
        :param restore_token: the restore count N; for N > 1 the model is loaded from
                              "self.save_path/restore-(N-1)/token/". New saves go to "self.save_path/restore-N/".
        """
        assert isinstance(restore_token, int), "the restore_token parameter is NOT an int value"
        if restore_token == 1:
            self.load_models(token)
        else:
            self.load_models(os.path.join("restore-{}".format(restore_token - 1), token))
        # change the save_path to a new folder
        self.save_path = os.path.join(self.save_path, "restore-{}".format(restore_token))
        self.learn(episodes)