-
Notifications
You must be signed in to change notification settings - Fork 315
/
functional.py
1432 lines (1220 loc) · 48.6 KB
/
functional.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
import math
import warnings
from functools import wraps
from typing import Optional, Tuple, Union
import torch
try:
from torch.compiler import is_dynamo_compiling
except ImportError:
from torch._dynamo import is_compiling as is_dynamo_compiling
# Public estimators exported by this module. Names prefixed with ``vec_`` are
# the vectorized counterparts of the loop-based estimators of the same name.
__all__ = [
    "generalized_advantage_estimate",
    "vec_generalized_advantage_estimate",
    "td0_advantage_estimate",
    "td0_return_estimate",
    "td1_return_estimate",
    "vec_td1_return_estimate",
    "td1_advantage_estimate",
    "vec_td1_advantage_estimate",
    "td_lambda_return_estimate",
    "vec_td_lambda_return_estimate",
    "td_lambda_advantage_estimate",
    "vec_td_lambda_advantage_estimate",
    "vtrace_advantage_estimate",
]
from torchrl.objectives.value.utils import (
_custom_conv1d,
_get_num_per_traj,
_inv_pad_sequence,
_make_gammas_tensor,
_split_and_pad_sequence,
)
# Message of the RuntimeError raised by the estimators below whenever their
# value / reward / done / terminated tensors do not all have the same shape.
SHAPE_ERR = (
    "All input tensors (value, reward and done states) "
    "must share a unique shape."
)
def _transpose_time(fun):
"""Checks the time_dim argument of the function to allow for any dim.
If not -2, makes a transpose of all the multi-dim input tensors to bring
time at -2, and does the opposite transform for the outputs.
"""
ERROR = (
"The tensor shape and the time dimension are not compatible: "
"got {} and time_dim={}."
)
@wraps(fun)
def transposed_fun(*args, **kwargs):
time_dim = kwargs.pop("time_dim", -2)
def transpose_tensor(tensor):
if not isinstance(tensor, torch.Tensor) or tensor.numel() <= 1:
return tensor, False
if time_dim >= 0:
timedim = time_dim - tensor.ndim
else:
timedim = time_dim
if timedim < -tensor.ndim or timedim >= 0:
raise RuntimeError(ERROR.format(tensor.shape, timedim))
if tensor.ndim >= 2:
single_dim = False
tensor = tensor.transpose(timedim, -2)
elif tensor.ndim == 1 and timedim == -1:
single_dim = True
tensor = tensor.unsqueeze(-1)
else:
raise RuntimeError(ERROR.format(tensor.shape, timedim))
return tensor, single_dim
if time_dim != -2:
single_dim = False
if args:
args, single_dim = zip(*(transpose_tensor(arg) for arg in args))
single_dim = any(single_dim)
for k, item in list(kwargs.items()):
item, sd = transpose_tensor(item)
single_dim = single_dim or sd
kwargs[k] = item
# We don't pass time_dim because it isn't supposed to be used thereafter
out = fun(*args, **kwargs)
if isinstance(out, torch.Tensor):
out = transpose_tensor(out)[0]
if single_dim:
out = out.squeeze(-2)
return out
if single_dim:
return tuple(transpose_tensor(_out)[0].squeeze(-2) for _out in out)
return tuple(transpose_tensor(_out)[0] for _out in out)
# We don't pass time_dim because it isn't supposed to be used thereafter
out = fun(*args, **kwargs)
if isinstance(out, tuple):
for _out in out:
if _out.ndim < 2:
raise RuntimeError(ERROR.format(_out.shape, time_dim))
else:
if out.ndim < 2:
raise RuntimeError(ERROR.format(out.shape, time_dim))
return out
return transposed_fun
########################################################################
# GAE
# ---
@_transpose_time
def generalized_advantage_estimate(
    gamma: float,
    lmbda: float,
    state_value: torch.Tensor,
    next_state_value: torch.Tensor,
    reward: torch.Tensor,
    done: torch.Tensor,
    terminated: torch.Tensor | None = None,
    *,
    time_dim: int = -2,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Generalized advantage estimate of a trajectory (sequential implementation).

    Refer to "HIGH-DIMENSIONAL CONTINUOUS CONTROL USING GENERALIZED ADVANTAGE ESTIMATION"
    https://arxiv.org/pdf/1506.02438.pdf for more context.

    The advantage is accumulated backwards in time::

        delta_t = reward_t + gamma * (1 - terminated_t) * next_value_t - value_t
        A_t     = delta_t + gamma * lmbda * (1 - done_t) * A_{t+1}

    Bootstrapping is therefore disabled on terminated steps, and the
    recursion is cut at every step flagged as ``done``.

    Args:
        gamma (scalar): discount factor.
        lmbda (scalar): GAE lambda, i.e. the decay of the exponential average
            over n-step advantages.
        state_value (Tensor): value function result with old_state input.
        next_state_value (Tensor): value function result with new_state input.
        reward (Tensor): reward of taking actions in the environment.
        done (Tensor): boolean flag for end of trajectory.
        terminated (Tensor): boolean flag for the end of episode. Defaults to ``done``
            if not provided.
        time_dim (int): dimension where the time is unrolled. Defaults to -2.

    All tensors (values, reward and done) must have shape
    ``[*Batch x TimeSteps x *F]``, with ``*F`` feature dimensions.

    Returns:
        ``(advantage, value_target)`` where ``value_target = advantage + state_value``.
    """
    if terminated is None:
        terminated = done.clone()
    shapes_match = (
        next_state_value.shape
        == state_value.shape
        == reward.shape
        == done.shape
        == terminated.shape
    )
    if not shapes_match:
        raise RuntimeError(SHAPE_ERR)

    not_done = (~done).int()
    not_terminated = (~terminated).int()
    *_, time_steps, _ = not_done.shape
    advantage = torch.empty(
        not_done.shape, device=state_value.device, dtype=next_state_value.dtype
    )

    delta = reward + (gamma * not_terminated * next_state_value) - state_value
    discount = lmbda * gamma * not_done

    running = 0
    for t in range(time_steps - 1, -1, -1):
        running = delta[..., t, :] + running * discount[..., t, :]
        advantage[..., t, :] = running

    return advantage, advantage + state_value
def _geom_series_like(t, r, thr):
"""Creates a geometric series of the form [1, gammalmbda, gammalmbda**2] with the shape of `t`.
Drops all elements which are smaller than `thr` (unless in compile mode).
"""
if is_dynamo_compiling():
if isinstance(r, torch.Tensor):
rs = r.expand_as(t)
else:
rs = torch.full_like(t, r)
else:
if isinstance(r, torch.Tensor):
r = r.item()
if r == 0.0:
return torch.zeros_like(t)
elif r >= 1.0:
lim = t.numel()
else:
lim = int(math.log(thr) / math.log(r))
rs = torch.full_like(t[:lim], r)
rs[0] = 1.0
rs = rs.cumprod(0)
rs = rs.unsqueeze(-1)
return rs
def _fast_vec_gae(
    reward: torch.Tensor,
    state_value: torch.Tensor,
    next_state_value: torch.Tensor,
    done: torch.Tensor,
    terminated: torch.Tensor,
    gamma: float,
    lmbda: float,
    thr: float = 1e-7,
):
    """Fast vectorized Generalized Advantage Estimate for scalar gamma and lmbda.

    In contrast to `vec_generalized_advantage_estimate`, this function does not
    need to allocate a big tensor of the form [B, T, T]. Instead, the one-step
    TD errors are split per trajectory, padded, and convolved with the
    geometric filter ``[1, gamma*lmbda, (gamma*lmbda)**2, ...]``.

    Args:
        reward (torch.Tensor): a [*B, T, F] tensor containing rewards
        state_value (torch.Tensor): a [*B, T, F] tensor containing state values (value function)
        next_state_value (torch.Tensor): a [*B, T, F] tensor containing next state values (value function)
        done (torch.Tensor): a [*B, T, F] boolean tensor containing the done states.
        terminated (torch.Tensor): a [*B, T, F] boolean tensor containing the terminated states.
        gamma (scalar): discount factor.
        lmbda (scalar): GAE lambda (decay of the exponential average).
        thr (:obj:`float`): threshold for the filter. Filter components below
            this limit are ignored. Defaults to 1e-7.

    All tensors (values, reward and done) must have shape
    ``[*Batch x TimeSteps x F]``, with ``F`` feature dimensions.

    Returns:
        ``(advantage, value_target)``, both shaped like ``reward``.
    """
    # The trajectory split/pad helpers expect time on the last dim, so every
    # input has its (time, feature) dims swapped first.
    done, terminated, reward, state_value, next_state_value = (
        tensor.transpose(-2, -1)
        for tensor in (done, terminated, reward, state_value, next_state_value)
    )

    # One-step TD errors; no bootstrapping through terminated steps.
    td0 = reward + (~terminated).int() * gamma * next_state_value - state_value

    traj_lengths = _get_num_per_traj(done)
    td0_padded, valid = _split_and_pad_sequence(td0, traj_lengths, return_mask=True)
    kernel = _geom_series_like(td0_padded[0], gamma * lmbda, thr=thr)

    advantage = _custom_conv1d(td0_padded.unsqueeze(1), kernel).squeeze(1)
    # Drop the padding and restore the original (time-last) layout.
    advantage = advantage[valid].view_as(reward)
    value_target = advantage + state_value
    return advantage.transpose(-1, -2), value_target.transpose(-1, -2)
@_transpose_time
def vec_generalized_advantage_estimate(
    gamma: Union[float, torch.Tensor],
    lmbda: Union[float, torch.Tensor],
    state_value: torch.Tensor,
    next_state_value: torch.Tensor,
    reward: torch.Tensor,
    done: torch.Tensor,
    terminated: torch.Tensor | None = None,
    *,
    time_dim: int = -2,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Vectorized Generalized advantage estimate of a trajectory.

    Refer to "HIGH-DIMENSIONAL CONTINUOUS CONTROL USING GENERALIZED ADVANTAGE ESTIMATION"
    https://arxiv.org/pdf/1506.02438.pdf for more context.

    When ``gamma * lmbda`` is a Python scalar or a single-element tensor, the
    computation is delegated to ``_fast_vec_gae``. Otherwise, a per-step
    decay tensor is built with ``_make_gammas_tensor`` and applied with
    ``_custom_conv1d``.

    Args:
        gamma (scalar or Tensor): discount factor.
        lmbda (scalar or Tensor): GAE lambda (decay of the exponential average).
        state_value (Tensor): value function result with old_state input.
        next_state_value (Tensor): value function result with new_state input.
        reward (Tensor): reward of taking actions in the environment.
        done (Tensor): boolean flag for end of trajectory.
        terminated (Tensor): boolean flag for the end of episode. Defaults to ``done``
            if not provided.
        time_dim (int): dimension where the time is unrolled. Defaults to -2.

    All tensors (values, reward and done) must have shape
    ``[*Batch x TimeSteps x *F]``, with ``*F`` feature dimensions.

    Returns:
        ``(advantage, value_target)`` where ``value_target = advantage + state_value``.
    """
    if terminated is None:
        terminated = done.clone()
    if not (
        next_state_value.shape
        == state_value.shape
        == reward.shape
        == done.shape
        == terminated.shape
    ):
        raise RuntimeError(SHAPE_ERR)

    dtype = state_value.dtype
    *batch_size, time_steps, lastdim = terminated.shape
    n_batch_dims = len(batch_size)

    decay = gamma * lmbda
    if not (isinstance(decay, torch.Tensor) and decay.numel() > 1):
        # Scalar decays: the filter-based implementation avoids the big tensor.
        return _fast_vec_gae(
            reward=reward,
            state_value=state_value,
            next_state_value=next_state_value,
            done=done,
            terminated=terminated,
            gamma=gamma,
            lmbda=lmbda,
        )

    # Per-step decays, zeroed at trajectory ends. Building them from the
    # tensor-valued gamma/lmbda keeps gradients flowing to those tensors.
    decays = (~done).to(dtype) * decay
    decays = _make_gammas_tensor(decays, time_steps, True).cumprod(-2)

    # Truncate the filter only once *all* geometric sequences fall below 1e-7.
    below_thr = (decays < 1e-7).flatten(0, 1).all(0).all(-1)
    if below_thr.any():
        cut = torch.where(below_thr)[0][0].item()
        decays = decays[..., :cut, :]

    td0 = reward + (~terminated).to(dtype) * gamma * next_state_value - state_value
    # Bring the batch dims down to exactly one leading dim.
    if n_batch_dims > 1:
        td0 = td0.flatten(0, n_batch_dims - 1)
    elif n_batch_dims == 0:
        td0 = td0.unsqueeze(0)

    td0_r = td0.transpose(-2, -1)
    shapes = td0_r.shape[:2]
    if lastdim != 1:
        # Fold the feature dim into the batch and insert a singleton channel.
        td0_r = td0_r.flatten(0, 1).unsqueeze(1)
    advantage = _custom_conv1d(td0_r, decays)
    if lastdim != 1:
        advantage = advantage.squeeze(1).unflatten(0, shapes)

    if n_batch_dims > 1:
        advantage = advantage.unflatten(0, batch_size)
    elif n_batch_dims == 0:
        advantage = advantage.squeeze(0)

    advantage = advantage.transpose(-2, -1)
    return advantage, advantage + state_value
########################################################################
# TD(0)
# -----
def td0_advantage_estimate(
    gamma: float,
    state_value: torch.Tensor,
    next_state_value: torch.Tensor,
    reward: torch.Tensor,
    done: torch.Tensor,
    terminated: torch.Tensor | None = None,
) -> torch.Tensor:
    """TD(0) advantage estimate of a trajectory.

    Also known as bootstrapped Temporal Difference or one-step return. The
    advantage is ``td0_return_estimate(...) - state_value``.

    Args:
        gamma (scalar): discount factor.
        state_value (Tensor): value function result with old_state input.
        next_state_value (Tensor): value function result with new_state input.
        reward (Tensor): reward of taking actions in the environment.
        done (Tensor): boolean flag for end of trajectory. Here it only serves
            as the default for ``terminated`` and in the shape check.
        terminated (Tensor): boolean flag for the end of episode. Defaults to ``done``
            if not provided.

    All tensors (values, reward and done) must have shape
    ``[*Batch x TimeSteps x *F]``, with ``*F`` feature dimensions.

    Returns:
        The advantage tensor (a single tensor, not a tuple).

    Raises:
        RuntimeError: if the input shapes differ.
    """
    if terminated is None:
        terminated = done.clone()
    if not (
        next_state_value.shape
        == state_value.shape
        == reward.shape
        == done.shape
        == terminated.shape
    ):
        raise RuntimeError(SHAPE_ERR)
    returns = td0_return_estimate(gamma, next_state_value, reward, terminated)
    return returns - state_value
def td0_return_estimate(
gamma: float,
next_state_value: torch.Tensor,
reward: torch.Tensor,
terminated: torch.Tensor | None = None,
*,
done: torch.Tensor | None = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
# noqa: D417
"""TD(0) discounted return estimate of a trajectory.
Also known as bootstrapped Temporal Difference or one-step return.
Args:
gamma (scalar): exponential mean discount.
next_state_value (Tensor): value function result with new_state input.
must be a [Batch x TimeSteps x 1] or [Batch x TimeSteps] tensor
reward (Tensor): reward of taking actions in the environment.
must be a [Batch x TimeSteps x 1] or [Batch x TimeSteps] tensor
terminated (Tensor): boolean flag for the end of episode. Defaults to ``done``
if not provided.
Keyword Args:
done (Tensor): Deprecated. Use ``terminated`` instead.
All tensors (values, reward and done) must have shape
``[*Batch x TimeSteps x *F]``, with ``*F`` feature dimensions.
"""
if done is not None and terminated is None:
terminated = done.clone()
warnings.warn(
"done for td0_return_estimate is deprecated. Pass ``terminated`` instead."
)
if not (next_state_value.shape == reward.shape == terminated.shape):
raise RuntimeError(SHAPE_ERR)
not_terminated = (~terminated).int()
advantage = reward + gamma * not_terminated * next_state_value
return advantage
########################################################################
# TD(1)
# ----------
@_transpose_time
def td1_return_estimate(
    gamma: float,
    next_state_value: torch.Tensor,
    reward: torch.Tensor,
    done: torch.Tensor,
    terminated: torch.Tensor | None = None,
    rolling_gamma: Optional[bool] = None,
    *,
    time_dim: int = -2,
) -> torch.Tensor:
    r"""TD(1) return estimate (sequential implementation).

    Returns are accumulated backwards in time. Each step's return is
    ``reward + gamma * (1 - terminated) * G``. ``G`` is the step's own
    next-state value if the step is done but not terminated, and the return of
    the following step otherwise. The last step of the batch always bootstraps
    from its own next-state value, in both ``rolling_gamma`` modes.

    Args:
        gamma (scalar or Tensor): discount factor. A tensor with the same shape
            as ``done`` is used per step. Anything else is broadcast with
            ``torch.full_like``.
        next_state_value (Tensor): value function result with new_state input.
        reward (Tensor): reward of taking actions in the environment.
        done (Tensor): boolean flag for end of trajectory.
        terminated (Tensor): boolean flag for the end of episode. Defaults to ``done``
            if not provided.
        rolling_gamma (bool, optional): if ``True``, it is assumed that each gamma
            of a gamma tensor is tied to a single event::

                gamma = [g1, g2, g3, g4]
                value = [v1, v2, v3, v4]
                return = [
                    v1 + g1 v2 + g1 g2 v3 + g1 g2 g3 v4,
                    v2 + g2 v3 + g2 g3 v4,
                    v3 + g3 v4,
                    v4,
                ]

            if ``False``, it is assumed that each gamma is tied to the upcoming
            trajectory::

                gamma = [g1, g2, g3, g4]
                value = [v1, v2, v3, v4]
                return = [
                    v1 + g1 v2 + g1**2 v3 + g1**3 v4,
                    v2 + g2 v3 + g2**2 v4,
                    v3 + g3 v4,
                    v4,
                ]

            Default is ``True``.
        time_dim (int): dimension where the time is unrolled. Defaults to -2.

    All tensors (values, reward and done) must have shape
    ``[*Batch x TimeSteps x *F]``, with ``*F`` feature dimensions.

    Returns:
        A tensor shaped like ``next_state_value``.

    Raises:
        RuntimeError: if the input shapes differ, or if ``rolling_gamma=False``
            is requested while ``gamma`` is not a tensor shaped like ``done``.
    """
    if terminated is None:
        terminated = done.clone()
    if not (next_state_value.shape == reward.shape == done.shape == terminated.shape):
        raise RuntimeError(SHAPE_ERR)
    not_done = (~done).int()
    not_terminated = (~terminated).int()

    returns = torch.empty_like(next_state_value)
    T = returns.shape[-2]

    single_gamma = False
    if not (isinstance(gamma, torch.Tensor) and gamma.shape == not_done.shape):
        single_gamma = True
        gamma = torch.full_like(next_state_value, gamma)

    if rolling_gamma is None:
        rolling_gamma = True
    elif not rolling_gamma and single_gamma:
        raise RuntimeError(
            "rolling_gamma=False is expected only with time-sensitive gamma values"
        )

    done_but_not_terminated = (done & ~terminated).int()
    if rolling_gamma:
        gamma = gamma * not_terminated
        g = next_state_value[..., -1, :]
        for i in reversed(range(T)):
            # if not done (and hence not terminated), get the bootstrapped value
            # if done but not terminated, get next_val
            # if terminated, take nothing (gamma = 0)
            dnt = done_but_not_terminated[..., i, :]
            g = returns[..., i, :] = reward[..., i, :] + gamma[..., i, :] * (
                (1 - dnt) * g + dnt * next_state_value[..., i, :]
            )
    else:
        for k in range(T):
            # Same bootstrap at the end of the batch as the rolling branch.
            # Starting from 0 would silently drop the final next-state value
            # whenever the last step is not flagged as done.
            g = next_state_value[..., -1, :]
            # Gamma of step k, applied to every later step i and zeroed where
            # step i is terminated.
            _gamma = gamma[..., k, :].unsqueeze(-2) * not_terminated
            for i in reversed(range(k, T)):
                dnt = done_but_not_terminated[..., i, :]
                g = reward[..., i, :] + _gamma[..., i, :] * (
                    (1 - dnt) * g + dnt * next_state_value[..., i, :]
                )
            returns[..., k, :] = g
    return returns
def td1_advantage_estimate(
    gamma: float,
    state_value: torch.Tensor,
    next_state_value: torch.Tensor,
    reward: torch.Tensor,
    done: torch.Tensor,
    terminated: torch.Tensor | None = None,
    rolling_gamma: bool = None,
    time_dim: int = -2,
) -> torch.Tensor:
    """TD(1) advantage estimate: ``td1_return_estimate(...) - state_value``.

    Args:
        gamma (scalar or Tensor): discount factor.
        state_value (Tensor): value function result with old_state input.
        next_state_value (Tensor): value function result with new_state input.
        reward (Tensor): reward of taking actions in the environment.
        done (Tensor): boolean flag for end of trajectory.
        terminated (Tensor): boolean flag for the end of episode. Defaults to ``done``
            if not provided.
        rolling_gamma (bool, optional): if ``True``, it is assumed that each gamma
            of a gamma tensor is tied to a single event::

                gamma = [g1, g2, g3, g4]
                value = [v1, v2, v3, v4]
                return = [
                    v1 + g1 v2 + g1 g2 v3 + g1 g2 g3 v4,
                    v2 + g2 v3 + g2 g3 v4,
                    v3 + g3 v4,
                    v4,
                ]

            if ``False``, it is assumed that each gamma is tied to the upcoming
            trajectory::

                gamma = [g1, g2, g3, g4]
                value = [v1, v2, v3, v4]
                return = [
                    v1 + g1 v2 + g1**2 v3 + g1**3 v4,
                    v2 + g2 v3 + g2**2 v4,
                    v3 + g3 v4,
                    v4,
                ]

            Default is ``True``.
        time_dim (int): dimension where the time is unrolled. Defaults to -2.

    All tensors (values, reward and done) must have shape
    ``[*Batch x TimeSteps x *F]``, with ``*F`` feature dimensions.

    Raises:
        RuntimeError: if the input shapes differ.
    """
    if terminated is None:
        terminated = done.clone()
    reference_shape = next_state_value.shape
    if any(
        tensor.shape != reference_shape
        for tensor in (state_value, reward, done, terminated)
    ):
        raise RuntimeError(SHAPE_ERR)

    returns = td1_return_estimate(
        gamma,
        next_state_value,
        reward,
        done,
        terminated=terminated,
        rolling_gamma=rolling_gamma,
        time_dim=time_dim,
    )
    return returns - state_value
@_transpose_time
def vec_td1_return_estimate(
    gamma,
    next_state_value,
    reward,
    done: torch.Tensor,
    terminated: torch.Tensor | None = None,
    rolling_gamma: Optional[bool] = None,
    time_dim: int = -2,
):
    """Vectorized TD(1) return estimate.

    Implemented as ``vec_td_lambda_return_estimate`` with ``lmbda=1``.

    Args:
        gamma (scalar, Tensor): discount factor. May be tensor-valued (e.g. one
            value per step, see ``rolling_gamma``).
        next_state_value (Tensor): value function result with new_state input.
        reward (Tensor): reward of taking actions in the environment.
        done (Tensor): boolean flag for end of trajectory.
        terminated (Tensor): boolean flag for the end of episode. Defaults to ``done``
            if not provided.
        rolling_gamma (bool, optional): if ``True``, it is assumed that each gamma
            of the gamma tensor is tied to a single event::

                gamma = [g1, g2, g3, g4]
                value = [v1, v2, v3, v4]
                return = [
                    v1 + g1 v2 + g1 g2 v3 + g1 g2 g3 v4,
                    v2 + g2 v3 + g2 g3 v4,
                    v3 + g3 v4,
                    v4,
                ]

            if ``False``, it is assumed that each gamma is tied to the upcoming
            trajectory::

                gamma = [g1, g2, g3, g4]
                value = [v1, v2, v3, v4]
                return = [
                    v1 + g1 v2 + g1**2 v3 + g1**3 v4,
                    v2 + g2 v3 + g2**2 v4,
                    v3 + g3 v4,
                    v4,
                ]

            Default is ``True``.
        time_dim (int): dimension where the time is unrolled. Defaults to ``-2``.

    All tensors (values, reward and done) must have shape
    ``[*Batch x TimeSteps x *F]``, with ``*F`` feature dimensions.
    """
    # TD(1) is TD(lambda) with lambda = 1.
    return vec_td_lambda_return_estimate(
        gamma=gamma,
        lmbda=1,
        next_state_value=next_state_value,
        reward=reward,
        done=done,
        terminated=terminated,
        rolling_gamma=rolling_gamma,
        time_dim=time_dim,
    )
def vec_td1_advantage_estimate(
    gamma,
    state_value,
    next_state_value,
    reward,
    done: torch.Tensor,
    terminated: torch.Tensor | None = None,
    rolling_gamma: bool = None,
    time_dim: int = -2,
):
    """Vectorized TD(1) advantage estimate: ``vec_td1_return_estimate(...) - state_value``.

    Args:
        gamma (scalar, Tensor): discount factor. May be tensor-valued (e.g. one
            value per step, see ``rolling_gamma``).
        state_value (Tensor): value function result with old_state input.
        next_state_value (Tensor): value function result with new_state input.
        reward (Tensor): reward of taking actions in the environment.
        done (Tensor): boolean flag for end of trajectory.
        terminated (Tensor): boolean flag for the end of episode. Defaults to ``done``
            if not provided.
        rolling_gamma (bool, optional): if ``True``, it is assumed that each gamma
            of a gamma tensor is tied to a single event::

                gamma = [g1, g2, g3, g4]
                value = [v1, v2, v3, v4]
                return = [
                    v1 + g1 v2 + g1 g2 v3 + g1 g2 g3 v4,
                    v2 + g2 v3 + g2 g3 v4,
                    v3 + g3 v4,
                    v4,
                ]

            if ``False``, it is assumed that each gamma is tied to the upcoming
            trajectory::

                gamma = [g1, g2, g3, g4]
                value = [v1, v2, v3, v4]
                return = [
                    v1 + g1 v2 + g1**2 v3 + g1**3 v4,
                    v2 + g2 v3 + g2**2 v4,
                    v3 + g3 v4,
                    v4,
                ]

            Default is ``True``.
        time_dim (int): dimension where the time is unrolled. Defaults to -2.

    All tensors (values, reward and done) must have shape
    ``[*Batch x TimeSteps x *F]``, with ``*F`` feature dimensions.

    Raises:
        RuntimeError: if the input shapes differ.
    """
    if terminated is None:
        terminated = done.clone()
    reference_shape = next_state_value.shape
    if any(
        tensor.shape != reference_shape
        for tensor in (state_value, reward, done, terminated)
    ):
        raise RuntimeError(SHAPE_ERR)

    returns = vec_td1_return_estimate(
        gamma,
        next_state_value,
        reward,
        done=done,
        terminated=terminated,
        rolling_gamma=rolling_gamma,
        time_dim=time_dim,
    )
    return returns - state_value
########################################################################
# TD(lambda)
# ----------
@_transpose_time
def td_lambda_return_estimate(
    gamma: float,
    lmbda: float,
    next_state_value: torch.Tensor,
    reward: torch.Tensor,
    done: torch.Tensor,
    terminated: torch.Tensor | None = None,
    rolling_gamma: bool = None,
    *,
    time_dim: int = -2,
) -> torch.Tensor:
    r"""TD(:math:`\lambda`) return estimate (sequential implementation).

    Next-state values are zeroed on terminated steps. The return is then
    accumulated backwards: at a done step the running bootstrap is reset to the
    step's own next-state value, and each step computes
    ``reward + gamma * ((1 - lmbda) * next_value + lmbda * running)``.

    Args:
        gamma (scalar or Tensor): discount factor. A tensor with the same shape
            as ``done`` is used per step. Anything else is broadcast with
            ``torch.full_like``.
        lmbda (scalar or Tensor): :math:`\lambda`, the decay of the exponential
            average over n-step returns. Broadcast like ``gamma``.
        next_state_value (Tensor): value function result with new_state input.
        reward (Tensor): reward of taking actions in the environment.
        done (Tensor): boolean flag for end of trajectory.
        terminated (Tensor): boolean flag for the end of episode. Defaults to ``done``
            if not provided.
        rolling_gamma (bool, optional): if ``True``, it is assumed that each gamma
            of a gamma tensor is tied to a single event::

                gamma = [g1, g2, g3, g4]
                value = [v1, v2, v3, v4]
                return = [
                    v1 + g1 v2 + g1 g2 v3 + g1 g2 g3 v4,
                    v2 + g2 v3 + g2 g3 v4,
                    v3 + g3 v4,
                    v4,
                ]

            if ``False``, it is assumed that each gamma is tied to the upcoming
            trajectory::

                gamma = [g1, g2, g3, g4]
                value = [v1, v2, v3, v4]
                return = [
                    v1 + g1 v2 + g1**2 v3 + g1**3 v4,
                    v2 + g2 v3 + g2**2 v4,
                    v3 + g3 v4,
                    v4,
                ]

            Default is ``True``.
        time_dim (int): dimension where the time is unrolled. Defaults to -2.

    All tensors (values, reward and done) must have shape
    ``[*Batch x TimeSteps x *F]``, with ``*F`` feature dimensions.

    Raises:
        RuntimeError: if the input shapes differ, or if ``rolling_gamma=False``
            is requested while neither gamma nor lmbda is time-sensitive.
    """
    if terminated is None:
        terminated = done.clone()
    if not (next_state_value.shape == reward.shape == done.shape == terminated.shape):
        raise RuntimeError(SHAPE_ERR)

    returns = torch.empty_like(next_state_value)
    # No bootstrapping through terminated steps.
    next_state_value = next_state_value * (~terminated).int()
    *_, T, _ = returns.shape

    # Anything that is not already a per-step tensor is broadcast to one.
    gamma_is_scalar = not (
        isinstance(gamma, torch.Tensor) and gamma.shape == done.shape
    )
    if gamma_is_scalar:
        gamma = torch.full_like(next_state_value, gamma)
    lmbda_is_scalar = not (
        isinstance(lmbda, torch.Tensor) and lmbda.shape == done.shape
    )
    if lmbda_is_scalar:
        lmbda = torch.full_like(next_state_value, lmbda)

    if rolling_gamma is None:
        rolling_gamma = True
    elif not rolling_gamma and gamma_is_scalar and lmbda_is_scalar:
        raise RuntimeError(
            "rolling_gamma=False is expected only with time-sensitive gamma or lambda values"
        )

    if rolling_gamma:
        running = next_state_value[..., -1, :]
        for i in range(T - 1, -1, -1):
            is_done = done[..., i, :].int()
            nv = next_state_value[..., i, :]
            lmd = lmbda[..., i, :]
            # At a done step, restart the bootstrap from the next value.
            running = running * (1 - is_done) + nv * is_done
            running = reward[..., i, :] + gamma[..., i, :] * (
                (1 - lmd) * nv + lmd * running
            )
            returns[..., i, :] = running
    else:
        for k in range(T):
            # gamma / lmbda of step k are applied to the whole tail k..T-1.
            running = next_state_value[..., -1, :]
            gamma_k = gamma[..., k, :]
            lmbda_k = lmbda[..., k, :]
            for i in range(T - 1, k - 1, -1):
                is_done = done[..., i, :].int()
                nv = next_state_value[..., i, :]
                running = running * (1 - is_done) + nv * is_done
                running = reward[..., i, :] + gamma_k * (
                    (1 - lmbda_k) * nv + lmbda_k * running
                )
            returns[..., k, :] = running
    return returns
def td_lambda_advantage_estimate(
    gamma: float,
    lmbda: float,
    state_value: torch.Tensor,
    next_state_value: torch.Tensor,
    reward: torch.Tensor,
    done: torch.Tensor,
    terminated: torch.Tensor | None = None,
    rolling_gamma: bool = None,
    # not a kwarg because used directly
    time_dim: int = -2,
) -> torch.Tensor:
    r"""TD(:math:`\lambda`) advantage estimate: ``td_lambda_return_estimate(...) - state_value``.

    Args:
        gamma (scalar or Tensor): discount factor.
        lmbda (scalar or Tensor): :math:`\lambda`, the decay of the exponential
            average over n-step returns.
        state_value (Tensor): value function result with old_state input.
        next_state_value (Tensor): value function result with new_state input.
        reward (Tensor): reward of taking actions in the environment.
        done (Tensor): boolean flag for end of trajectory.
        terminated (Tensor): boolean flag for the end of episode. Defaults to ``done``
            if not provided.
        rolling_gamma (bool, optional): if ``True``, it is assumed that each gamma
            of a gamma tensor is tied to a single event::

                gamma = [g1, g2, g3, g4]
                value = [v1, v2, v3, v4]
                return = [
                    v1 + g1 v2 + g1 g2 v3 + g1 g2 g3 v4,
                    v2 + g2 v3 + g2 g3 v4,
                    v3 + g3 v4,
                    v4,
                ]

            if ``False``, it is assumed that each gamma is tied to the upcoming
            trajectory::

                gamma = [g1, g2, g3, g4]
                value = [v1, v2, v3, v4]
                return = [
                    v1 + g1 v2 + g1**2 v3 + g1**3 v4,
                    v2 + g2 v3 + g2**2 v4,
                    v3 + g3 v4,
                    v4,
                ]

            Default is ``True``.
        time_dim (int): dimension where the time is unrolled. Defaults to -2.

    All tensors (values, reward and done) must have shape
    ``[*Batch x TimeSteps x *F]``, with ``*F`` feature dimensions.

    Raises:
        RuntimeError: if the input shapes differ.
    """
    if terminated is None:
        terminated = done.clone()
    reference_shape = next_state_value.shape
    if any(
        tensor.shape != reference_shape
        for tensor in (state_value, reward, done, terminated)
    ):
        raise RuntimeError(SHAPE_ERR)

    returns = td_lambda_return_estimate(
        gamma,
        lmbda,
        next_state_value,
        reward,
        done,
        terminated=terminated,
        rolling_gamma=rolling_gamma,
        time_dim=time_dim,
    )
    return returns - state_value
def _fast_td_lambda_return_estimate(
gamma: Union[torch.Tensor, float],
lmbda: float,
next_state_value: torch.Tensor,
reward: torch.Tensor,
done: torch.Tensor,
terminated: torch.Tensor,
thr: float = 1e-7,
):
"""Fast vectorized TD lambda return estimate.
In contrast to the generalized `vec_td_lambda_return_estimate` this function does not need
to allocate a big tensor of the form [B, T, T], but it only works with gamma/lmbda being scalars.
Args:
gamma (scalar): the gamma decay, can be a tensor with a single element (trajectory discount)
lmbda (scalar): the lambda decay (exponential mean discount)
next_state_value (torch.Tensor): a [*B, T, F] tensor containing next state values (value function)
reward (torch.Tensor): a [*B, T, F] tensor containing rewards
done (Tensor): boolean flag for end of trajectory.
terminated (Tensor): boolean flag for end of episode.
thr (:obj:`float`): threshold for the filter. Below this limit, components will ignored.
Defaults to 1e-7.
All tensors (values, reward and done) must have shape
``[*Batch x TimeSteps x F]``, with ``F`` feature dimensions.
"""
device = reward.device