-
Notifications
You must be signed in to change notification settings - Fork 0
/
dreamer_main.py
1003 lines (973 loc) · 33.8 KB
/
dreamer_main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import argparse
import os
from functools import partialmethod
from math import inf
import numpy as np
import torch
from torch import nn, optim
from torch.distributions import Normal
from torch.distributions.kl import kl_divergence
from torch.nn import functional as F
from torchvision.utils import make_grid, save_image
from tqdm import tqdm
from dreamer import DreamerAgent
from env import CONTROL_SUITE_ENVS, GYM_ENVS, Env, EnvBatcher
from memory import ExperienceReplay
from models import Encoder, ObservationModel, RewardModel, TransitionModel, bottle
from utils import lineplot, write_video
# Hyperparameters
# ``add_arg`` is a short alias for ``parser.add_argument``. Options are
# registered in exactly the same order as before, so ``--help`` output,
# defaults and parsed values are unchanged.
parser = argparse.ArgumentParser(description="Dreamer")
add_arg = parser.add_argument

# Experiment, environment and replay buffer
add_arg("--id", type=str, default="dreamer", help="Experiment ID")
add_arg("--seed", type=int, default=1, metavar="S", help="Random seed")
add_arg("--disable-cuda", action="store_true", help="Disable CUDA")
add_arg("--env", type=str, default="Pendulum-v1",
        choices=GYM_ENVS + CONTROL_SUITE_ENVS, help="Gym/Control Suite environment")
add_arg("--symbolic-env", action="store_true", help="Symbolic features")
add_arg("--max-episode-length", type=int, default=1000, metavar="T",
        help="Max episode length")
# Original implementation has an unlimited buffer size, but 1 million is the
# max experience collected anyway
add_arg("--experience-size", type=int, default=1000000, metavar="D",
        help="Experience replay size")

# Model / agent architecture
add_arg("--activation-function", type=str, default="relu", choices=dir(F),
        help="Model activation function")
add_arg("--dreamer-activation-function", type=str, default="elu", choices=dir(F),
        help="Agent activation function")
# Note that the default encoder for visual observations outputs a 1024D
# vector; for other embedding sizes an additional fully-connected layer is used
add_arg("--embedding-size", type=int, default=1024, metavar="E",
        help="Observation embedding size")
add_arg("--hidden-size", type=int, default=200, metavar="H", help="Hidden size")
add_arg("--dreamer-hidden-size", type=int, default=300, metavar="H", help="Hidden size")
add_arg("--belief-size", type=int, default=200, metavar="H", help="Belief/hidden size")
add_arg("--state-size", type=int, default=30, metavar="Z", help="State/latent size")

# Acting and episode schedule
add_arg("--action-repeat", type=int, default=2, metavar="R", help="Action repeat")
add_arg("--action-noise", type=float, default=0.3, metavar="ε", help="Action noise")
add_arg("--episodes", type=int, default=1000, metavar="E", help="Total number of episodes")
add_arg("--seed-episodes", type=int, default=5, metavar="S", help="Seed episodes")
add_arg("--collect-interval", type=int, default=100, metavar="C", help="Collect interval")
add_arg("--batch-size", type=int, default=50, metavar="B", help="Batch size")
add_arg("--chunk-size", type=int, default=50, metavar="L", help="Chunk size")

# World-model objective
add_arg("--overshooting-distance", type=int, default=50, metavar="D",
        help="Latent overshooting distance/latent overshooting weight for t = 1")
add_arg("--overshooting-kl-beta", type=float, default=0, metavar="β>1",
        help="Latent overshooting KL weight for t > 1 (0 to disable)")
add_arg("--overshooting-reward-scale", type=float, default=0, metavar="R>1",
        help="Latent overshooting reward prediction weight for t > 1 (0 to disable)")
add_arg("--global-kl-beta", type=float, default=0, metavar="βg",
        help="Global KL weight (0 to disable)")
add_arg("--free-nats", type=float, default=3, metavar="F", help="Free nats")
add_arg("--bit-depth", type=int, default=5, metavar="B",
        help="Image bit depth (quantisation)")

# Optimisation
add_arg("--learning-rate", type=float, default=6e-4, metavar="α", help="Learning rate")
add_arg("--learning-rate-schedule", type=int, default=0, metavar="αS",
        help="Linear learning rate schedule (optimisation steps from 0 to final learning rate; 0 to disable)")
add_arg("--adam-epsilon", type=float, default=1e-4, metavar="ε",
        help="Adam optimiser epsilon value")
# Note that original has a linear learning rate decay, but it seems unlikely
# that this makes a significant difference
add_arg("--policy-lr", type=float, default=8e-5, metavar="pα", help="Learning rate for policy")
add_arg("--value-lr", type=float, default=8e-5, metavar="vα", help="Learning rate for value")
add_arg("--gamma", type=float, default=0.99, metavar="G", help="gamma for value estimation")
add_arg("--lam", type=float, default=0.95, metavar="lam", help="lambda for value estimation")
add_arg("--grad-clip-norm", type=float, default=100, metavar="C", help="Gradient clipping norm")
add_arg("--horizon", type=int, default=15, metavar="I", help="Imagination horizon distance")

# Run mode, logging, evaluation and checkpointing
add_arg("--test", action="store_true", help="Test only")
add_arg("--quiet", action="store_true", help="Disable tqdm bar")
add_arg("--verbose", action="store_true", help="print losses at each step")
add_arg("--test-interval", type=int, default=25, metavar="I", help="Test interval (episodes)")
add_arg("--test-episodes", type=int, default=10, metavar="E", help="Number of test episodes")
add_arg("--checkpoint-interval", type=int, default=50, metavar="I",
        help="Checkpoint interval (episodes)")
add_arg("--checkpoint-experience", action="store_true", help="Checkpoint experience replay")

# Checkpoint files to resume from (empty string means "do not load"). Every
# one of these help texts has the form "Load <what> checkpoint".
for _flag, _what in (
    ("--metrics", "metrics"),
    ("--model", "model"),
    ("--transition-model", "transition model"),
    ("--observation-model", "observation model"),
    ("--reward-model", "reward model"),
    ("--encoder", "encoder"),
    ("--optimiser", "optimiser"),
    ("--agent-model", "agent"),
    ("--policy-model", "policy"),
    ("--value-model", "value"),
):
    add_arg(_flag, type=str, default="", metavar="M", help="Load " + _what + " checkpoint")
del _flag, _what
add_arg("--experience-replay", type=str, default="", metavar="ER",
        help="Load experience replay")

# Rendering, video and dm_control backgrounds
add_arg("--render", action="store_true", help="Render environment")
add_arg("--video", action="store_true", help="Record video of environment")
add_arg("--img-source", type=str, default=None,
        choices=[None, "color", "noise", "images", "video"],
        help="Type of dm_control background")
add_arg("--resource-files", type=str, default="",
        help="resource files for dm_control background")

args = parser.parse_args()
# Overshooting distance cannot be greater than chunk size
args.overshooting_distance = min(args.chunk_size, args.overshooting_distance)
# Echo the resolved configuration, one option per line under an
# indented "Options" heading.
_options_indent = " " * 26
print(_options_indent + "Options")
for _opt_name, _opt_value in vars(args).items():
    print(_options_indent + _opt_name + ": " + str(_opt_value))

# Setup
# Patch tqdm globally so that every progress bar created later honours --quiet.
tqdm.__init__ = partialmethod(tqdm.__init__, disable=args.quiet)
results_dir = os.path.join("results", args.env, args.id)
os.makedirs(results_dir, exist_ok=True)

# Seed NumPy and PyTorch; the CUDA generator is seeded only when CUDA is used.
np.random.seed(args.seed)
torch.manual_seed(args.seed)
use_cuda = torch.cuda.is_available() and not args.disable_cuda
args.device = torch.device("cuda" if use_cuda else "cpu")
if use_cuda:
    torch.cuda.manual_seed(args.seed)

# Per-episode metric histories (each starts empty).
metrics = {
    metric_name: []
    for metric_name in (
        "steps",
        "episodes",
        "train_rewards",
        "test_episodes",
        "test_rewards",
        "observation_loss",
        "reward_loss",
        "kl_loss",
        "policy_loss",
        "value_loss",
    )
}
# Initialise training environment and experience replay memory.
# Both environment families share the same leading positional arguments;
# Control Suite environments additionally receive the background options.
_common_env_args = (
    args.env,
    args.symbolic_env,
    args.seed,
    args.max_episode_length,
    args.action_repeat,
    args.bit_depth,
)
if args.env in GYM_ENVS:
    env = Env(*_common_env_args)
elif args.env in CONTROL_SUITE_ENVS:
    env = Env(*_common_env_args, args.img_source, args.resource_files)

if args.experience_replay != "" and os.path.exists(args.experience_replay):
    # Resume from a saved replay buffer and rebuild a coarse step/episode
    # history from its counters.
    D = torch.load(args.experience_replay)
    metrics["steps"] = [D.steps] * D.episodes
    metrics["episodes"] = list(range(1, D.episodes + 1))
elif not args.test:
    D = ExperienceReplay(
        args.experience_size,
        args.symbolic_env,
        env.observation_size,
        env.action_size,
        args.bit_depth,
        args.device,
    )
    # Initialise dataset D with S random seed episodes
    for seed_episode in range(1, args.seed_episodes + 1):
        observation, done, t = env.reset(), False, 0
        while not done:
            action = env.sample_random_action()
            next_observation, reward, done = env.step(action)
            D.append(observation, action, reward, done)
            observation = next_observation
            t += 1
        # Cumulative environment steps: each agent step spans action_repeat
        # environment steps.
        steps_so_far = metrics["steps"][-1] if metrics["steps"] else 0
        metrics["steps"].append(t * args.action_repeat + steps_so_far)
        metrics["episodes"].append(seed_episode)
# Initialise model parameters randomly
transition_model = TransitionModel(
    args.belief_size,
    args.state_size,
    env.action_size,
    args.hidden_size,
    args.embedding_size,
    args.activation_function,
).to(device=args.device)
observation_model = ObservationModel(
    args.symbolic_env,
    env.observation_size,
    args.belief_size,
    args.state_size,
    args.embedding_size,
    args.activation_function,
).to(device=args.device)
reward_model = RewardModel(
    args.belief_size, args.state_size, args.hidden_size, args.activation_function
).to(device=args.device)
encoder = Encoder(
    args.symbolic_env,
    env.observation_size,
    args.embedding_size,
    args.activation_function,
).to(device=args.device)
# All world-model parameters are optimised jointly by a single Adam optimiser.
param_list = (
    list(transition_model.parameters())
    + list(observation_model.parameters())
    + list(reward_model.parameters())
    + list(encoder.parameters())
)
# With a learning-rate schedule the rate starts at 0 and is ramped up in the
# training loop; otherwise the final rate is used from the start.
optimiser = optim.Adam(
    param_list,
    lr=0 if args.learning_rate_schedule != 0 else args.learning_rate,
    eps=args.adam_epsilon,
)


def _load_checkpoint(path):
    """Load the checkpoint stored at *path* onto ``args.device``.

    Returns ``None`` when *path* is empty or does not exist, so optional
    checkpoint flags can simply be skipped. ``map_location`` remaps tensors
    that were saved on another device (e.g. a CUDA checkpoint opened on a
    CPU-only machine or with ``--disable-cuda``) instead of failing.
    """
    if path == "" or not os.path.exists(path):
        return None
    return torch.load(path, map_location=args.device)


# Resume metrics, then the combined world-model (+ optimiser) checkpoint.
loaded_metrics = _load_checkpoint(args.metrics)
if loaded_metrics is not None:
    metrics = loaded_metrics
model_dicts = _load_checkpoint(args.model)
if model_dicts is not None:
    transition_model.load_state_dict(model_dicts["transition_model"])
    observation_model.load_state_dict(model_dicts["observation_model"])
    reward_model.load_state_dict(model_dicts["reward_model"])
    encoder.load_state_dict(model_dicts["encoder"])
    optimiser.load_state_dict(model_dicts["optimiser"])
# Per-component checkpoints are applied afterwards, so each one overrides the
# matching entry of the combined --model checkpoint.
for _ckpt_path, _component, _ckpt_key in (
    (args.transition_model, transition_model, "transition_model"),
    (args.observation_model, observation_model, "observation_model"),
    (args.reward_model, reward_model, "reward_model"),
    (args.encoder, encoder, "encoder"),
    (args.optimiser, optimiser, "optimiser"),
):
    model_dicts = _load_checkpoint(_ckpt_path)
    if model_dicts is not None:
        _component.load_state_dict(model_dicts[_ckpt_key])
# Actor-critic agent that learns in imagination on top of the world model.
agent = DreamerAgent(
    env.action_size,
    args.belief_size,
    args.state_size,
    args.dreamer_hidden_size,
    args.horizon,
    args.policy_lr,
    args.value_lr,
    args.gamma,
    args.lam,
    args.grad_clip_norm,
    args.dreamer_activation_function,
    transition_model,
    reward_model,
    env.action_range[0],
    env.action_range[1],
    args.device,
)
# Agent checkpoints: the combined file first, then per-network files override
# it. map_location lets checkpoints saved on a GPU load on a CPU-only machine
# (or with --disable-cuda) instead of raising.
if args.agent_model != "" and os.path.exists(args.agent_model):
    model_dicts = torch.load(args.agent_model, map_location=args.device)
    agent.policy_model.load_state_dict(model_dicts["policy_model"])
    agent.value_model.load_state_dict(model_dicts["value_model"])
if args.policy_model != "" and os.path.exists(args.policy_model):
    model_dicts = torch.load(args.policy_model, map_location=args.device)
    agent.policy_model.load_state_dict(model_dicts["policy_model"])
if args.value_model != "" and os.path.exists(args.value_model):
    model_dicts = torch.load(args.value_model, map_location=args.device)
    agent.value_model.load_state_dict(model_dicts["value_model"])
global_prior = Normal(
    torch.zeros(args.batch_size, args.state_size, device=args.device),
    torch.ones(args.batch_size, args.state_size, device=args.device),
)  # Global prior N(0, I)
free_nats = torch.full(
    (1,), args.free_nats, dtype=torch.float32, device=args.device
)  # Allowed deviation in KL divergence
def update_belief_and_act(
    args,
    env,
    agent,
    transition_model,
    encoder,
    belief,
    posterior_state,
    action,
    observation,
    min_action=-inf,
    max_action=inf,
    explore=False,
):
    """Advance the filtered latent state by one observation, then step ``env``.

    The previous ``belief``/``posterior_state`` and previous ``action`` are
    combined with the encoded ``observation`` by ``transition_model`` to infer
    q(s_t | o_≤t, a_<t). ``agent.act`` then chooses the next action from that
    state. With ``explore=True``, noise scaled by ``args.action_noise`` is added.
    The action is clamped in place to ``[min_action, max_action]`` and passed
    to ``env.step``. An ``EnvBatcher`` receives the whole batch; any other env
    receives only the first row.

    Returns:
        ``(belief, posterior_state, action, next_observation, reward, done)``,
        with the length-1 time axis removed from belief and state.
    """
    # The transition model consumes sequences, so give the previous action and
    # the embedded observation a leading time axis of length 1.
    action_seq = action.unsqueeze(dim=0)
    embedding_seq = encoder(observation).unsqueeze(dim=0)
    (
        belief_seq,
        _prior_state,
        _prior_mean,
        _prior_std_dev,
        posterior_seq,
        _posterior_mean,
        _posterior_std_dev,
    ) = transition_model(posterior_state, action_seq, belief, embedding_seq)
    # Remove the time axis again.
    belief = belief_seq.squeeze(dim=0)
    posterior_state = posterior_seq.squeeze(dim=0)

    action = agent.act(belief, posterior_state)
    if explore:
        # Exploration noise ε ~ p(ε) added to the action.
        action = action + args.action_noise * torch.randn_like(action)
    action.clamp_(min=min_action, max=max_action)

    # Action repeats are handled inside env.step.
    if isinstance(env, EnvBatcher):
        env_action = action.cpu()
    else:
        env_action = action[0].cpu()
    next_observation, reward, done = env.step(env_action)
    return belief, posterior_state, action, next_observation, reward, done
# Testing only
if args.test:
    # Set models to eval mode. observation_model is included because it decodes
    # the predicted frames written to the video below; this matches the periodic
    # evaluation inside the training loop.
    transition_model.eval()
    observation_model.eval()
    reward_model.eval()
    encoder.eval()
    with torch.no_grad():
        total_reward = 0
        for episode in tqdm(range(args.test_episodes)):
            video_frames = []
            observation = env.reset()
            # Every episode starts from an all-zero belief, state and previous action.
            belief, posterior_state, action = (
                torch.zeros(1, args.belief_size, device=args.device),
                torch.zeros(1, args.state_size, device=args.device),
                torch.zeros(1, env.action_size, device=args.device),
            )
            pbar = tqdm(range(args.max_episode_length // args.action_repeat))
            for t in pbar:
                (
                    belief,
                    posterior_state,
                    action,
                    next_observation,
                    reward,
                    done,
                ) = update_belief_and_act(
                    args,
                    env,
                    agent,
                    transition_model,
                    encoder,
                    belief,
                    posterior_state,
                    action,
                    observation.to(device=args.device),
                    env.action_range[0],
                    env.action_range[1],
                )
                total_reward += reward
                if not args.symbolic_env and args.video:
                    # Collect real vs. predicted frames, concatenated along dim 3
                    # (width, assuming NCHW images); + 0.5 decentres the pixels.
                    video_frames.append(
                        make_grid(
                            torch.cat(
                                [
                                    observation,
                                    observation_model(belief, posterior_state).cpu(),
                                ],
                                dim=3,
                            )
                            + 0.5,
                            nrow=5,
                        ).numpy()
                    )
                observation = next_observation
                if args.render:
                    env.render()
                if done:
                    pbar.close()
                    break
            if not args.symbolic_env and args.video:
                episode_str = str(episode).zfill(len(str(args.episodes)))
                write_video(
                    video_frames, "test_episode_%s" % episode_str, results_dir
                )  # Lossy compression
    print("Average Reward:", total_reward / args.test_episodes)
    env.close()
    # quit() is injected by the ``site`` module for interactive use only (it is
    # missing under ``python -S`` and in frozen apps). Raising SystemExit
    # directly has the same effect and always works.
    raise SystemExit
# Training (and testing)
# metrics["episodes"][-1] is the number of episodes already completed (seed
# episodes, a loaded replay buffer or resumed metrics), so training resumes at
# the next episode. tqdm's ``initial`` is the number of iterations already
# done, i.e. that same count, not one more.
for episode in tqdm(
    range(metrics["episodes"][-1] + 1, args.episodes + 1),
    total=args.episodes,
    initial=metrics["episodes"][-1],
):
    # Model fitting
    losses = []
    for s in tqdm(range(args.collect_interval)):
        # Draw sequence chunks {(o_t, a_t, r_t+1, terminal_t+1)} ~ D uniformly at
        # random from the dataset (including terminal flags)
        observations, actions, rewards, nonterminals = D.sample(
            args.batch_size, args.chunk_size
        )  # Transitions start at time t = 0
        # Create initial belief and state for time t = 0
        init_belief = torch.zeros(args.batch_size, args.belief_size, device=args.device)
        init_state = torch.zeros(args.batch_size, args.state_size, device=args.device)
        # Update belief/state using posterior from previous belief/state, previous
        # action and current observation (over entire sequence at once)
        (
            beliefs,
            prior_states,
            prior_means,
            prior_std_devs,
            posterior_states,
            posterior_means,
            posterior_std_devs,
        ) = transition_model(
            init_state,
            actions[:-1],
            init_belief,
            bottle(encoder, (observations[1:],)),
            nonterminals[:-1],
        )
        # Calculate observation likelihood, reward likelihood and KL losses (for
        # t = 0 only for latent overshooting); sum over final dims, average over
        # batch and time (original implementation, though paper seems to miss
        # 1/T scaling?)
        observation_loss = (
            F.mse_loss(
                bottle(observation_model, (beliefs, posterior_states)),
                observations[1:],
                reduction="none",
            )
            .sum(dim=2 if args.symbolic_env else (2, 3, 4))
            .mean(dim=(0, 1))
        )
        reward_loss = F.mse_loss(
            bottle(reward_model, (beliefs, posterior_states)),
            rewards[:-1],
            reduction="none",
        ).mean(dim=(0, 1))
        kl_loss = torch.max(
            kl_divergence(
                Normal(posterior_means, posterior_std_devs),
                Normal(prior_means, prior_std_devs),
            ).sum(dim=2),
            free_nats,
        ).mean(
            dim=(0, 1)
        )  # Note that normalisation by overshooting distance and weighting by overshooting distance cancel out
        if args.global_kl_beta != 0:
            kl_loss += args.global_kl_beta * kl_divergence(
                Normal(posterior_means, posterior_std_devs), global_prior
            ).sum(dim=2).mean(dim=(0, 1))
        # Calculate latent overshooting objective for t > 0
        if args.overshooting_kl_beta != 0:
            # Collect variables for overshooting to process in batch
            overshooting_vars = []
            for t in range(1, args.chunk_size - 1):
                # Overshooting distance
                d = min(t + args.overshooting_distance, args.chunk_size - 1)
                # Use t_ and d_ to deal with different time indexing for latent states
                t_, d_ = t - 1, d - 1
                # Calculate sequence padding so overshooting terms can be
                # calculated in one batch
                seq_pad = (0, 0, 0, 0, 0, t - d + args.overshooting_distance)
                # Store (0) actions, (1) nonterminals, (2) rewards, (3) beliefs,
                # (4) posterior states, (5) posterior means, (6) posterior standard
                # deviations and (7) sequence masks. Posterior standard deviations
                # must be padded with > 0 to prevent infinite KL divergences.
                overshooting_vars.append(
                    (
                        F.pad(actions[t:d], seq_pad),
                        F.pad(nonterminals[t:d], seq_pad),
                        F.pad(rewards[t:d], seq_pad[2:]),
                        beliefs[t_],
                        posterior_states[t_].detach(),
                        F.pad(posterior_means[t_ + 1 : d_ + 1].detach(), seq_pad),
                        F.pad(
                            posterior_std_devs[t_ + 1 : d_ + 1].detach(),
                            seq_pad,
                            value=1,
                        ),
                        F.pad(
                            torch.ones(
                                d - t,
                                args.batch_size,
                                args.state_size,
                                device=args.device,
                            ),
                            seq_pad,
                        ),
                    )
                )
            overshooting_vars = tuple(zip(*overshooting_vars))
            # Update belief/state using prior from previous belief/state and
            # previous action (over entire sequence at once). The results are
            # bound to separate os_* names. Reusing ``beliefs``/``prior_*`` would
            # overwrite the filtered beliefs that the agent is trained on below.
            # The overshooting outputs stack every start time into the batch
            # dimension, so they no longer line up with ``posterior_states``.
            (
                os_beliefs,
                os_prior_states,
                os_prior_means,
                os_prior_std_devs,
            ) = transition_model(
                torch.cat(overshooting_vars[4], dim=0),
                torch.cat(overshooting_vars[0], dim=1),
                torch.cat(overshooting_vars[3], dim=0),
                None,
                torch.cat(overshooting_vars[1], dim=1),
            )
            seq_mask = torch.cat(overshooting_vars[7], dim=1)
            # Calculate overshooting KL loss with sequence mask (compensating for
            # extra average over each overshooting/open loop sequence)
            kl_loss += (
                (1 / args.overshooting_distance)
                * args.overshooting_kl_beta
                * torch.max(
                    (
                        kl_divergence(
                            Normal(
                                torch.cat(overshooting_vars[5], dim=1),
                                torch.cat(overshooting_vars[6], dim=1),
                            ),
                            Normal(os_prior_means, os_prior_std_devs),
                        )
                        * seq_mask
                    ).sum(dim=2),
                    free_nats,
                ).mean(dim=(0, 1))
                * (args.chunk_size - 1)
            )
            # Calculate overshooting reward prediction loss with sequence mask
            # (compensating for extra average over each overshooting/open loop
            # sequence)
            if args.overshooting_reward_scale != 0:
                reward_loss += (
                    (1 / args.overshooting_distance)
                    * args.overshooting_reward_scale
                    * F.mse_loss(
                        bottle(reward_model, (os_beliefs, os_prior_states))
                        * seq_mask[:, :, 0],
                        torch.cat(overshooting_vars[2], dim=1),
                        reduction="none",
                    ).mean(dim=(0, 1))
                    * (args.chunk_size - 1)
                )
        # Apply linearly ramping learning rate schedule
        if args.learning_rate_schedule != 0:
            for group in optimiser.param_groups:
                group["lr"] = min(
                    group["lr"] + args.learning_rate / args.learning_rate_schedule,
                    args.learning_rate,
                )
        # Update model parameters
        optimiser.zero_grad()
        (observation_loss + reward_loss + kl_loss).backward()
        nn.utils.clip_grad_norm_(param_list, args.grad_clip_norm, norm_type=2)
        optimiser.step()
        # Train the agent on the (detached) filtered beliefs/states, with the
        # t = 0 initial belief/state prepended along the time axis.
        beliefs = torch.cat([init_belief.unsqueeze(dim=0), beliefs.detach()], dim=0)
        posterior_states = torch.cat(
            [init_state.unsqueeze(dim=0), posterior_states.detach()], dim=0
        )
        policy_loss, value_loss = agent.train(beliefs, posterior_states)
        # Store (0) observation loss (1) reward loss (2) KL loss (3) policy loss (4) value loss
        losses.append(
            [
                observation_loss.item(),
                reward_loss.item(),
                kl_loss.item(),
                policy_loss.item(),
                value_loss.item(),
            ]
        )
    # Update and plot loss metrics. After the transpose, losses[i] holds loss i
    # for every step of this collect interval.
    losses = tuple(zip(*losses))
    if args.verbose:
        # Mean of each loss over the collect interval. Before the transpose,
        # losses[0] was the first step's five *different* losses.
        print(
            "OL",
            np.mean(losses[0]),
            "RL",
            np.mean(losses[1]),
            "KL",
            np.mean(losses[2]),
            "PL",
            np.mean(losses[3]),
            "VL",
            np.mean(losses[4]),
        )
    loss_names = (
        "observation_loss",
        "reward_loss",
        "kl_loss",
        "policy_loss",
        "value_loss",
    )
    for loss_index, loss_name in enumerate(loss_names):
        metrics[loss_name].append(losses[loss_index])
    for loss_name in loss_names:
        lineplot(
            metrics["episodes"][-len(metrics[loss_name]) :],
            metrics[loss_name],
            loss_name,
            results_dir,
        )
    # Data collection
    with torch.no_grad():
        observation, total_reward = env.reset(), 0
        belief, posterior_state, action = (
            torch.zeros(1, args.belief_size, device=args.device),
            torch.zeros(1, args.state_size, device=args.device),
            torch.zeros(1, env.action_size, device=args.device),
        )
        pbar = tqdm(range(args.max_episode_length // args.action_repeat))
        for t in pbar:
            (
                belief,
                posterior_state,
                action,
                next_observation,
                reward,
                done,
            ) = update_belief_and_act(
                args,
                env,
                agent,
                transition_model,
                encoder,
                belief,
                posterior_state,
                action,
                observation.to(device=args.device),
                env.action_range[0],
                env.action_range[1],
                explore=True,
            )
            D.append(observation, action.cpu(), reward, done)
            total_reward += reward
            observation = next_observation
            if args.render:
                env.render()
            if done:
                pbar.close()
                break
        # Update and plot train reward metrics. Steps are counted like the seed
        # phase: ``t`` is the 0-based index of the last agent step (so t + 1
        # steps were taken), and each agent step spans action_repeat env steps.
        metrics["steps"].append((t + 1) * args.action_repeat + metrics["steps"][-1])
        metrics["episodes"].append(episode)
        metrics["train_rewards"].append(total_reward)
        lineplot(
            metrics["episodes"][-len(metrics["train_rewards"]) :],
            metrics["train_rewards"],
            "train_rewards",
            results_dir,
        )
        if args.verbose:
            print("Train reward", np.mean(total_reward))
    # Test model
    if episode % args.test_interval == 0:
        # Set models to eval mode
        transition_model.eval()
        observation_model.eval()
        reward_model.eval()
        encoder.eval()
        # Initialise parallelised test environments
        test_envs = EnvBatcher(
            Env,
            (
                args.env,
                args.symbolic_env,
                args.seed,
                args.max_episode_length,
                args.action_repeat,
                args.bit_depth,
            ),
            {"img_source": args.img_source, "resource_files": args.resource_files},
            args.test_episodes,
        )
        with torch.no_grad():
            observation, total_rewards, video_frames = (
                test_envs.reset(),
                np.zeros((args.test_episodes,)),
                [],
            )
            belief, posterior_state, action = (
                torch.zeros(args.test_episodes, args.belief_size, device=args.device),
                torch.zeros(args.test_episodes, args.state_size, device=args.device),
                torch.zeros(args.test_episodes, env.action_size, device=args.device),
            )
            pbar = tqdm(range(args.max_episode_length // args.action_repeat))
            for t in pbar:
                (
                    belief,
                    posterior_state,
                    action,
                    next_observation,
                    reward,
                    done,
                ) = update_belief_and_act(
                    args,
                    test_envs,
                    agent,
                    transition_model,
                    encoder,
                    belief,
                    posterior_state,
                    action,
                    observation.to(device=args.device),
                    env.action_range[0],
                    env.action_range[1],
                )
                total_rewards += reward.numpy()
                if not args.symbolic_env and args.video:
                    # Collect real vs. predicted frames for video (+ 0.5 decentres)
                    video_frames.append(
                        make_grid(
                            torch.cat(
                                [
                                    observation,
                                    observation_model(belief, posterior_state).cpu(),
                                ],
                                dim=3,
                            )
                            + 0.5,
                            nrow=5,
                        ).numpy()
                    )
                observation = next_observation
                if done.sum().item() == args.test_episodes:
                    pbar.close()
                    break
        # Update and plot reward metrics (and write video if applicable) and save metrics
        metrics["test_episodes"].append(episode)
        metrics["test_rewards"].append(total_rewards.tolist())
        if args.verbose:
            print("Test reward", np.mean(total_rewards.tolist()))
        lineplot(
            metrics["test_episodes"],
            metrics["test_rewards"],
            "test_rewards",
            results_dir,
        )
        lineplot(
            np.asarray(metrics["steps"])[np.asarray(metrics["test_episodes"]) - 1],
            metrics["test_rewards"],
            "test_rewards_steps",
            results_dir,
            xaxis="step",
        )
        if not args.symbolic_env and args.video:
            episode_str = str(episode).zfill(len(str(args.episodes)))
            write_video(
                video_frames, "test_episode_%s" % episode_str, results_dir
            )  # Lossy compression
            save_image(
                torch.as_tensor(video_frames[-1]),
                os.path.join(results_dir, "test_episode_%s.png" % episode_str),
            )
        torch.save(metrics, os.path.join(results_dir, "metrics.pth"))
        # Set models to train mode
        transition_model.train()
        observation_model.train()
        reward_model.train()
        encoder.train()
        # Close test environments
        test_envs.close()
    # Checkpoint models
    if episode % args.checkpoint_interval == 0:
        torch.save(
            {
                "transition_model": transition_model.state_dict(),
                "observation_model": observation_model.state_dict(),
                "reward_model": reward_model.state_dict(),
                "encoder": encoder.state_dict(),
                "optimiser": optimiser.state_dict(),
            },
            os.path.join(results_dir, "models.pth"),
        )
        torch.save(
            {
                "policy_model": agent.policy_model.state_dict(),
                "value_model": agent.value_model.state_dict(),
            },
            os.path.join(results_dir, "agent.pth"),
        )
        if args.checkpoint_experience:
            torch.save(
                D, os.path.join(results_dir, "experience.pth")
            )  # Warning: will fail with MemoryError with large memory sizes