forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
LegacyBatchingRegistrations.cpp
1294 lines (1171 loc) · 59.3 KB
/
LegacyBatchingRegistrations.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#include <torch/library.h>
#include <ATen/ATen.h>
#include <ATen/LegacyVmapTransforms.h>
#include <ATen/LegacyBatchedFallback.h>
#include <ATen/RedispatchFunctions.h>
#include <ATen/native/ResizeCommon.h>
#include <ATen/core/IListRef.h>
#include <c10/util/irange.h>
#include <c10/core/SymIntArrayRef.h>
#include <utility>
namespace at {
// NOTE: [What is a batching rule?]
//
// A *batching rule* implements the logic of how to call an operator on inputs
// that have zero or more additional batch dimensions. When one does a vmap, the
// dimension(s) being vmap'ed over get recorded as batch dimensions.
//
// For example, vmap(torch.add)(x, y)
// 1. wraps `x` into batched_x = BatchedTensor(x, bdims=[(lvl=1, dim=0)]);
// 2. wraps `y` into batched_y = BatchedTensor(y, bdims=[(lvl=1, dim=0)]);
// 3. and then runs `torch.add(batched_x, batched_y)`.
// NOTE: [When should I add a batching rule?]
// When you are adding a new operator, you'll need to add a batching rule so
// that vmap can work efficiently with said operator. If you do not, we'll attempt
// to generate a slow fallback for the batching rule.
// NOTE: [How to write batching rules?]
// The signature of a batching rule should look exactly like the C++ signature
// of its operator.
//
// First, see NOTE: [Logical vs physical args] in VmapTransforms.h for terminology.
//
// At a high level, what a batching rule does is the following:
// 1. Converts (logical) BatchedTensors to views on physical tensors.
// 2. Converts logical arguments (e.g. dimension indexes, shapes) to physical
// arguments that correspond to the physical tensors.
// 3. Calls at:: operations on the physical tensors and arguments to produce
// some physical results.
// 4. Converts physical results back to BatchedTensors.
//
// Steps 1, 2, and 4 differ for operators with different batching behaviors. When
// writing a new batching rule, please select a VmapTransform that matches the
// batching behavior of your operation. The VmapTransform provides helper functions
// to do steps (1), (2), and (4).
// (see NOTE: [What is an VmapTransform?] in VmapTransforms.h)
// Note: [Future plans]
// The API for writing a batching rule isn't stable. In the future, we'd like
// to think about the problem of translating these batching rules to TorchScript.
// Ideally batching rules in eager mode vs TorchScript would look pretty similar,
// if not use the same mechanism. In order to accomplish that we might have to
// do some refactoring.
namespace{
// Returns true for the dims PyTorch accepts on a 0-dim (scalar) tensor.
// Ops such as sum/transpose allow dim 0 and dim -1 there; every other value
// is rejected.
static bool is_allowed_dim_on_scalar_tensor(int64_t dim) {
  switch (dim) {
    case 0:
    case -1:
      return true;
    default:
      return false;
  }
}
// Batching rule for at::sum(self, dims, keepdim, dtype).
//
// Reduces over the logical `opt_dims` of each example. Logical dims are
// remapped to physical dims so the batch dims are never reduced.
Tensor sum_batching_rule(const Tensor& self, OptionalIntArrayRef opt_dims, bool keepdim, optional<ScalarType> dtype) {
  if (opt_dims.has_value()) {
    auto dims = opt_dims.value();
    // PyTorch has a special case where sum(scalar_tensor, dim=0) does not fail
    // and instead returns a new scalar tensor (this also happens for dim=-1)
    // If the following happens:
    // >>> x = torch.randn(B0)  # the per-examples are all scalars
    // >>> vmap(partial(torch.sum, dim=0), x)
    // then we replicate the behavior of sum(scalar_tensor, dim=0).
    if (/*logical*/self.dim() == 0 && (dims.empty() || (dims.size() == 1 && is_allowed_dim_on_scalar_tensor(dims[0])))) {
      // The shortcut must still honor sum's output dtype:
      // - an explicit `dtype` wins;
      // - otherwise integral (and bool) inputs are summed as int64 (kLong);
      // - floating/complex inputs keep their dtype.
      // Returning a plain clone ignored both rules.
      const ScalarType out_dtype = dtype.has_value()
          ? *dtype
          : (at::isIntegralType(self.scalar_type(), /*includeBool=*/true)
                 ? kLong
                 : self.scalar_type());
      if (out_dtype == self.scalar_type()) {
        return self.clone();
      }
      // A dtype conversion always allocates a new tensor, so the result
      // never aliases `self`.
      return self.to(out_dtype);
    }
  }
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto dims_physical = self_physical.getPhysicalDims(opt_dims);
  auto result = at::sum(self_physical.tensor(), dims_physical, keepdim, dtype);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}
// Returns true iff `logical_tensor` is a plain (non-batched) 0-dim tensor.
// Such a tensor is a scalar both logically and physically. A BatchedTensor
// whose logical rank is 0 does not qualify, because it has physical batch dims.
bool isPhysicalScalarTensor(const Tensor& logical_tensor) {
  const bool is_zero_dim = !(logical_tensor.dim() > 0);
  const bool is_unbatched = maybeGetBatchedImpl(logical_tensor) == nullptr;
  return is_zero_dim && is_unbatched;
}
// Generic batching rule for binary pointwise ops.
//
// `Func` is the underlying at:: op. `args...` are forwarded unchanged after
// the two tensor operands. Three cases are handled:
//   1. neither operand is a logical scalar: broadcast across batch dims;
//   2. one operand is an unbatched 0-dim tensor: pass it to Func untouched;
//   3. otherwise: emulate TensorIterator's scalar type promotion manually.
template <typename F, F Func, typename... ExtraArgs>
Tensor binary_pointwise_batching_rule(
    const Tensor& self, const Tensor& other, ExtraArgs... args) {
  // Case 1: both operands have logical rank >= 1.
  if (self.dim() > 0 && other.dim() > 0) {
    auto physical_pair = BroadcastingVmapTransform::logicalToPhysical({self, other});
    auto out = Func(physical_pair[0].tensor(), physical_pair[1].tensor(), args...);
    return physical_pair[0].getPhysicalToLogicalMap().apply(out);
  }

  // Case 2: an unbatched 0-dim operand is handed to Func unchanged, exactly as
  // a non-vmap call would see it.
  if (isPhysicalScalarTensor(self)) {
    auto other_view = MultiBatchVmapTransform::logicalToPhysical(other);
    auto out = Func(self, other_view.tensor(), args...);
    return other_view.getPhysicalToLogicalMap().apply(out);
  }
  if (isPhysicalScalarTensor(other)) {
    auto self_view = MultiBatchVmapTransform::logicalToPhysical(self);
    auto out = Func(self_view.tensor(), other, args...);
    return self_view.getPhysicalToLogicalMap().apply(out);
  }

  // Case 3: at least one operand is a *batched* logical scalar tensor. Here we
  // must emulate TensorIterator's special behavior on Scalars.
  //
  // Motivating example:
  //   x = torch.randn(3, 10)
  //   y = torch.randn(3, dtype=torch.double)
  //   vmap(torch.mul)(x, y)
  // At a per-example level, we are multiplying FloatTensor[10] and
  // DoubleTensor[]; type promotion dictates that the result should be
  // FloatTensor[10]. Passing the physical tensors (x and y) straight to
  // TensorIterator would promote them to DoubleTensor instead, so we compute
  // the logical result type first and cast.
  //
  // FIXME(rzou): this does not emulate everything TensorIterator does (it
  // would be better to refactor out the TensorIterator logic). In particular,
  // cross-device logical scalar tensors are not handled:
  //   cpu_tensor = torch.randn(3)
  //   cuda_tensor = torch.randn(3, 10, device='cuda')
  //   vmap(torch.mul)(cpu_tensor, cuda_tensor)
  // Per example that is CPUTensor[] * CUDATensor[10], which TensorIterator
  // allows because one side is a CPU scalar, but the code below will throw.
  // Such use cases are expected to be rare.
  const auto common_dtype = at::native::result_type(self, other);
  auto promoted_self =
      self.scalar_type() == common_dtype ? self : self.to(common_dtype);
  auto promoted_other =
      other.scalar_type() == common_dtype ? other : other.to(common_dtype);
  auto physical_pair = BroadcastingVmapTransform::logicalToPhysical(
      {std::move(promoted_self), std::move(promoted_other)});
  auto out = Func(physical_pair[0].tensor(), physical_pair[1].tensor(), args...);
  return physical_pair[0].getPhysicalToLogicalMap().apply(out);
}
// Batching rule for Tensor.expand(size, implicit).
//
// `size` is the logical target shape of each example. Batch dims are carried
// through untouched.
Tensor expand_batching_rule(const Tensor& self, IntArrayRef size, bool implicit) {
  auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self);
  auto target_shape = physical_view.getPhysicalShape(size);
  const auto& physical_tensor = physical_view.tensor();
  const int64_t current_rank = physical_tensor.dim();
  const int64_t target_rank = static_cast<int64_t>(target_shape.size());
  TORCH_CHECK(current_rank <= target_rank,
      "expand: the number of sizes provided (", /*logical*/size.size(), ") ",
      "must be greater or equal to the number of dimensions in the tensor (",
      /*logical dim*/self.dim(), ")");

  // Same rank: a direct expand is safe.
  if (current_rank == target_rank) {
    return physical_view.getPhysicalToLogicalMap().apply(
        physical_tensor.expand(target_shape, implicit));
  }
  TORCH_INTERNAL_ASSERT(current_rank < target_rank);

  // Expanding to more dims cannot be done directly: the new leading (logical)
  // dims must go *after* the batch dims. Example with batch dim B0:
  // expand(Tensor[B0, 3], [2, 3]) must give [B0, 2, 3]. So we first view
  // [B0, 3] as [B0, 1, 3] (size-1 dims inserted after the batch dims), then
  // expand.
  const int64_t num_bdims = physical_view.numBatchDims();
  const int64_t num_new_dims = target_rank - current_rank;
  const auto current_sizes = physical_tensor.sizes();
  VmapDimVector view_shape(target_shape.size(), 1);
  for (int64_t i = 0; i < current_rank; ++i) {
    // Batch dims keep their slot; example dims shift right past the new dims.
    const int64_t dst = i < num_bdims ? i : i + num_new_dims;
    view_shape[dst] = current_sizes[i];
  }
  auto result = physical_tensor.view(view_shape).expand(target_shape, implicit);
  return physical_view.getPhysicalToLogicalMap().apply(result);
}
// Batching rule for at::chunk. Chunks every example along logical `dim` and
// returns the pieces re-wrapped as batched tensors.
std::vector<Tensor> chunk_batching_rule(const Tensor& self, int64_t chunks, int64_t dim) {
  auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self);
  const auto split_dim = physical_view.getPhysicalDim(dim);
  std::vector<Tensor> pieces = at::chunk(physical_view.tensor(), chunks, split_dim);
  physical_view.getPhysicalToLogicalMap().applyInplace(pieces);
  return pieces;
}
// Batching rule for at::clamp(self, min, max). The op is elementwise, so it
// runs directly on the physical tensor.
Tensor clamp_batching_rule(const Tensor& self, const optional<Scalar>& min, const optional<Scalar>& max) {
  auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self);
  return physical_view.getPhysicalToLogicalMap().apply(
      at::clamp(physical_view.tensor(), min, max));
}
// Batching rule for at::clamp_min; elementwise, applied to the physical tensor.
Tensor clamp_min_batching_rule(const Tensor& self, const Scalar& min) {
  auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self);
  return physical_view.getPhysicalToLogicalMap().apply(
      at::clamp_min(physical_view.tensor(), min));
}
// Batching rule for at::clamp_max; elementwise, applied to the physical tensor.
Tensor clamp_max_batching_rule(const Tensor& self, const Scalar& max) {
  auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self);
  return physical_view.getPhysicalToLogicalMap().apply(
      at::clamp_max(physical_view.tensor(), max));
}
// Batching rule for at::tensor_split(self, sections, dim). Splits every
// example into `sections` parts along logical `dim`.
std::vector<Tensor> tensor_split_sections_batching_rule(const Tensor& self, int64_t sections, int64_t dim) {
  auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self);
  const auto split_dim = physical_view.getPhysicalDim(dim);
  std::vector<Tensor> parts = at::tensor_split(physical_view.tensor(), sections, split_dim);
  physical_view.getPhysicalToLogicalMap().applyInplace(parts);
  return parts;
}
// Batching rule for at::tensor_split(self, indices, dim). Splits every
// example at the given `indices` along logical `dim`.
std::vector<Tensor> tensor_split_indices_batching_rule(const Tensor& self, IntArrayRef indices, int64_t dim) {
  auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self);
  const auto split_dim = physical_view.getPhysicalDim(dim);
  std::vector<Tensor> parts = at::tensor_split(physical_view.tensor(), indices, split_dim);
  physical_view.getPhysicalToLogicalMap().applyInplace(parts);
  return parts;
}
// Batching rule for Tensor.unsqueeze(dim).
Tensor unsqueeze_batching_rule(const Tensor& self, int64_t dim) {
  auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self);
  // NB: getPhysicalDim can't be used here. native::unsqueeze wraps `dim`
  // against (logical dim + 1), so we do the same wrapping and then offset the
  // result past the batch dims.
  // https://github.com/pytorch/pytorch/blob/b623bdeabb0aa8da44285d303246e7f8ac06c2a9/aten/src/ATen/native/TensorShape.cpp#L1413
  const auto wrapped_logical_dim = maybe_wrap_dim(dim, /*logical_dim*/self.dim() + 1);
  const auto physical_dim = physical_view.numBatchDims() + wrapped_logical_dim;
  return physical_view.getPhysicalToLogicalMap().apply(
      physical_view.tensor().unsqueeze(physical_dim));
}
// Batching rule for Tensor.fill_(Scalar). Fills the physical storage of
// `self` in place and returns `self` (the logical tensor).
Tensor& fill_inplace_scalar_batching_rule(Tensor& self, const Scalar& value) {
  MultiBatchVmapTransform::logicalToPhysical(self).tensor().fill_(value);
  return self;
}
// Batching rule for Tensor.fill_(Tensor). Returns `self` after filling it in place.
//
// An unbatched `value` is the same for every example, so a physical fill_
// suffices. A batched `value` differs per example, so both operands are
// broadcast to a common physical layout and `value` is copied into `self`.
Tensor& fill_inplace_tensor_batching_rule(Tensor& self, const Tensor& value) {
  if (!isBatchedTensor(value)) {
    auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
    self_physical.tensor().fill_(value);
    return self;
  }
  auto physical_pair = BroadcastingVmapTransform::logicalToPhysical({self, value});
  physical_pair[0].tensor().copy_(physical_pair[1].tensor());
  return self;
}
// Batching rule for Tensor.zero_(). Zeroes the physical tensor in place and
// returns `self`.
Tensor& zero_inplace_batching_rule(Tensor &self) {
  MultiBatchVmapTransform::logicalToPhysical(self).tensor().zero_();
  return self;
}
// Batching rule for Tensor.squeeze() (no dim). Drops every size-1 *example*
// dim. Batch dims are always kept, even when a batch has size 1.
Tensor squeeze_batching_rule(const Tensor& self) {
  auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self);
  const auto sizes = physical_view.tensor().sizes();
  const int64_t num_bdims = physical_view.numBatchDims();
  VmapDimVector kept_sizes;
  for (int64_t i = 0; i < static_cast<int64_t>(sizes.size()); ++i) {
    const bool is_batch_dim = i < num_bdims;
    if (is_batch_dim || sizes[i] != 1) {
      kept_sizes.push_back(sizes[i]);
    }
  }
  return physical_view.getPhysicalToLogicalMap().apply(
      physical_view.tensor().view(kept_sizes));
}
// Batching rule for Tensor.squeeze(dim) with a single logical dim.
Tensor squeeze_dim_batching_rule(const Tensor& self, int64_t dim) {
  auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self);
  const auto target_dim = physical_view.getPhysicalDim(dim);
  return physical_view.getPhysicalToLogicalMap().apply(
      physical_view.tensor().squeeze(target_dim));
}
// Batching rule for Tensor.squeeze(dims) with a list of logical dims.
Tensor squeeze_dims_batching_rule(const Tensor& self, IntArrayRef dims) {
  auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self);
  const auto target_dims = physical_view.getPhysicalDims(dims);
  return physical_view.getPhysicalToLogicalMap().apply(
      physical_view.tensor().squeeze(target_dims));
}
// Batching rule for at::trace. Takes the diagonal of the last two (physical)
// dims and sums it, which yields one trace per example.
Tensor trace_batching_rule(const Tensor& self) {
  auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self);
  auto batched_diagonal =
      at::diagonal(physical_view.tensor(), /*offset*/0, /*dim1*/-2, /*dim2*/-1);
  return physical_view.getPhysicalToLogicalMap().apply(at::sum(batched_diagonal, -1));
}
// Batching rule for trace_backward.
//
// Builds a zero grad input whose physical shape is getPhysicalShape(input_sizes)
// and writes `grad` onto the diagonal of its last two dims.
Tensor trace_backward_batching_rule(const Tensor& grad, IntArrayRef input_sizes) {
  auto grad_view = MultiBatchVmapTransform::logicalToPhysical(grad);
  auto grad_input = at::zeros(grad_view.getPhysicalShape(input_sizes), grad.options());
  auto diagonal_view = at::diagonal(grad_input, /*offset*/0, /*dim1*/-2, /*dim2*/-1);
  // unsqueeze(-1) adds a trailing size-1 dim so each example's grad value
  // broadcasts along its diagonal.
  diagonal_view.copy_(grad_view.tensor().unsqueeze(-1));
  return grad_view.getPhysicalToLogicalMap().apply(grad_input);
}
// Batching rule for Tensor.transpose(dim0, dim1).
//
// PyTorch lets scalar_tensor.transpose(dim0, dim1) succeed for dim0, dim1 in
// {0, -1} and returns the scalar tensor. To replicate that, e.g. for
//   >>> x = torch.randn(B0)  # the per-examples are all scalars
//   >>> vmap(lambda x: x.transpose(0, -1), x)
// a logical 0-dim input with such dims is returned unchanged.
Tensor transpose_int_batching_rule(const Tensor& self, int64_t dim0, int64_t dim1) {
  const bool is_scalar_noop = /*logical*/self.dim() == 0 &&
      is_allowed_dim_on_scalar_tensor(dim0) &&
      is_allowed_dim_on_scalar_tensor(dim1);
  if (is_scalar_noop) {
    return self;
  }
  auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self);
  const auto first = physical_view.getPhysicalDim(dim0);
  const auto second = physical_view.getPhysicalDim(dim1);
  return physical_view.getPhysicalToLogicalMap().apply(
      physical_view.tensor().transpose(first, second));
}
// Batching rule for Tensor.permute(dims).
//
// The physical permutation is the identity over the batch dims, followed by
// the physical images of the requested logical `dims`.
Tensor permute_batching_rule(const Tensor& self, IntArrayRef dims) {
  auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self);
  auto permuted_example_dims = physical_view.getPhysicalDims(dims);
  const int64_t num_bdims = physical_view.numBatchDims();
  VmapDimVector full_permutation;
  full_permutation.reserve(physical_view.tensor().dim());
  for (int64_t bdim = 0; bdim < num_bdims; ++bdim) {
    full_permutation.push_back(bdim);
  }
  for (const auto d : permuted_example_dims) {
    full_permutation.push_back(d);
  }
  return physical_view.getPhysicalToLogicalMap().apply(
      physical_view.tensor().permute(full_permutation));
}
// Batching rule for Tensor.select(dim, index). Selects `index` along logical
// `dim` in every example.
Tensor select_batching_rule(const Tensor& self, int64_t dim, int64_t index) {
  auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self);
  const auto target_dim = physical_view.getPhysicalDim(dim);
  return physical_view.getPhysicalToLogicalMap().apply(
      physical_view.tensor().select(target_dim, index));
}
// Maps logical `dim` of a grad input with logical shape `input_sizes` to a
// physical dim. The dim is wrapped against input_sizes.size() and then
// offset past `num_batch_dims` leading batch dims.
static int64_t getGradInputPhysicalDim(int64_t dim, IntArrayRef input_sizes, int64_t num_batch_dims) {
  const auto wrapped_logical_dim = maybe_wrap_dim(dim, input_sizes.size());
  return num_batch_dims + wrapped_logical_dim;
}
// Batching rule for select_backward. Scatters `grad` into a zero tensor of
// (physical) shape `input_sizes` at `index` along logical `dim`.
Tensor select_backward_batching_rule(const Tensor& grad, IntArrayRef input_sizes, int64_t dim, int64_t index) {
  auto grad_view = MultiBatchVmapTransform::logicalToPhysical(grad);
  auto grad_input = at::zeros(grad_view.getPhysicalShape(input_sizes), grad.options());
  const auto target_dim = getGradInputPhysicalDim(dim, input_sizes, grad_view.numBatchDims());
  auto destination = grad_input.select(target_dim, index);
  destination.copy_(grad_view.tensor());
  return grad_view.getPhysicalToLogicalMap().apply(grad_input);
}
// Batching rule for Tensor.slice(dim, start, end, step) along a logical dim.
Tensor slice_batching_rule(
    const Tensor& self,
    int64_t dim,
    c10::optional<int64_t> start,
    c10::optional<int64_t> end,
    int64_t step) {
  auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self);
  const auto target_dim = physical_view.getPhysicalDim(dim);
  return physical_view.getPhysicalToLogicalMap().apply(
      physical_view.tensor().slice(target_dim, start, end, step));
}
// Batching rule for slice_backward. Writes `grad` into the [start:end:step]
// slice (along logical `dim`) of a zero tensor of shape `input_sizes`.
Tensor slice_backward_batching_rule(const Tensor& grad, IntArrayRef input_sizes, int64_t dim, int64_t start, int64_t end, int64_t step) {
  auto grad_view = MultiBatchVmapTransform::logicalToPhysical(grad);
  auto grad_input = at::zeros(grad_view.getPhysicalShape(input_sizes), grad.options());
  const auto target_dim = getGradInputPhysicalDim(dim, input_sizes, grad_view.numBatchDims());
  auto destination = grad_input.slice(target_dim, start, end, step);
  destination.copy_(grad_view.tensor());
  return grad_view.getPhysicalToLogicalMap().apply(grad_input);
}
// Batching rule for at::diagonal(self, offset, dim1, dim2) over logical dims.
Tensor diagonal_batching_rule(const Tensor& self, int64_t offset, int64_t dim1, int64_t dim2) {
  auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self);
  const auto first = physical_view.getPhysicalDim(dim1);
  const auto second = physical_view.getPhysicalDim(dim2);
  return physical_view.getPhysicalToLogicalMap().apply(
      at::diagonal(physical_view.tensor(), offset, first, second));
}
// Batching rule for diagonal_backward. Writes `grad` onto the (offset, dim1,
// dim2) diagonal of a zero tensor of logical shape `input_sizes`.
Tensor diagonal_backward_batching_rule(const Tensor& grad, IntArrayRef input_sizes, int64_t offset, int64_t dim1, int64_t dim2) {
  auto grad_view = MultiBatchVmapTransform::logicalToPhysical(grad);
  auto grad_input = at::zeros(grad_view.getPhysicalShape(input_sizes), grad.options());
  const auto num_bdims = grad_view.numBatchDims();
  const auto first = getGradInputPhysicalDim(dim1, input_sizes, num_bdims);
  const auto second = getGradInputPhysicalDim(dim2, input_sizes, num_bdims);
  grad_input.diagonal(offset, first, second).copy_(grad_view.tensor());
  return grad_view.getPhysicalToLogicalMap().apply(grad_input);
}
// Batching rule for at::movedim(self, source, destination) over logical dims.
Tensor movedim_batching_rule(const Tensor& self, IntArrayRef source, IntArrayRef destination) {
  auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self);
  const auto from_dims = physical_view.getPhysicalDims(source);
  const auto to_dims = physical_view.getPhysicalDims(destination);
  return physical_view.getPhysicalToLogicalMap().apply(
      at::movedim(physical_view.tensor(), from_dims, to_dims));
}
// Batching rule for Tensor.reshape(shape), where `shape` is per-example.
Tensor reshape_batching_rule(const Tensor& self, IntArrayRef shape) {
  auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self);
  const auto target_shape = physical_view.getPhysicalShape(shape);
  return physical_view.getPhysicalToLogicalMap().apply(
      physical_view.tensor().reshape(target_shape));
}
// Batching rule for at::split(self, split_size, dim) along a logical dim.
std::vector<Tensor> split_batching_rule(const Tensor& self, int64_t split_size, int64_t dim) {
  auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self);
  const auto split_dim = physical_view.getPhysicalDim(dim);
  std::vector<Tensor> parts = at::split(physical_view.tensor(), split_size, split_dim);
  physical_view.getPhysicalToLogicalMap().applyInplace(parts);
  return parts;
}
// Batching rule for at::split_with_sizes(self, split_sizes, dim) along a logical dim.
std::vector<Tensor> split_with_sizes_batching_rule(const Tensor& self, IntArrayRef split_sizes, int64_t dim) {
  auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self);
  const auto split_dim = physical_view.getPhysicalDim(dim);
  std::vector<Tensor> parts = at::split_with_sizes(physical_view.tensor(), split_sizes, split_dim);
  physical_view.getPhysicalToLogicalMap().applyInplace(parts);
  return parts;
}
// Batching rule for at::unbind(self, dim) along a logical dim.
std::vector<Tensor> unbind_batching_rule(const Tensor& self, int64_t dim) {
  auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self);
  const auto unbind_dim = physical_view.getPhysicalDim(dim);
  std::vector<Tensor> slices = at::unbind(physical_view.tensor(), unbind_dim);
  physical_view.getPhysicalToLogicalMap().applyInplace(slices);
  return slices;
}
// Batching rule for Tensor.unfold(dim, size, step) along a logical dim.
Tensor unfold_batching_rule(const Tensor& self, int64_t dim, int64_t size, int64_t step) {
  auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self);
  const auto target_dim = physical_view.getPhysicalDim(dim);
  return physical_view.getPhysicalToLogicalMap().apply(
      physical_view.tensor().unfold(target_dim, size, step));
}
// Batching rule for Tensor.contiguous(memory_format). Only the default
// contiguous format is supported under vmap; anything else raises.
Tensor contiguous_batching_rule(const Tensor& self, MemoryFormat memory_format) {
  const bool is_default_format = memory_format == MemoryFormat::Contiguous;
  TORCH_CHECK(is_default_format,
      "NYI: Tensor.contiguous(...) inside of vmap for memory_format other ",
      "than torch.contiguous_format");
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  return self_physical.getPhysicalToLogicalMap().apply(
      self_physical.tensor().contiguous(memory_format));
}
// Batching rule for Tensor.view(size), where `size` is the per-example shape.
Tensor view_batching_rule(const Tensor& self, IntArrayRef size) {
  auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self);
  const auto target_shape = physical_view.getPhysicalShape(size);
  return physical_view.getPhysicalToLogicalMap().apply(
      physical_view.tensor().view(target_shape));
}
// Batching rule for at::view_as_complex.
Tensor view_as_complex_batching_rule(const Tensor& self) {
  // Reject logical 0-dim inputs up front. Otherwise a batch of scalars with
  // batch size 2 would be physically [2] and would wrongly be accepted as one
  // complex number.
  const bool has_example_dims = !self.sizes().empty();
  TORCH_CHECK(has_example_dims, "Input tensor must have one or more dimensions");
  auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self);
  return physical_view.getPhysicalToLogicalMap().apply(
      at::view_as_complex(physical_view.tensor()));
}
// Checks that the smallest batch stride is >= the largest example stride,
// i.e. that the batch dims come first in memory layout. Other layouts could
// be supported, but we choose not to because as_strided on them is
// potentially error prone.
//
// `physical_strides`: strides of the physical tensor; the first
//   `num_batch_dims` entries belong to batch dims, the rest to example dims.
// Throws (via TORCH_CHECK) when the layout is rejected.
static void checkBatchDimsAtFrontInLayout(IntArrayRef physical_strides, int64_t num_batch_dims) {
  // With no batch dims there is nothing to compare. Without this guard,
  // min_element over the empty batch range returns begin(), and the check
  // below would compare an *example* stride against the other example strides.
  if (num_batch_dims <= 0) {
    return;
  }
  const auto example_begin = physical_strides.begin() + num_batch_dims;
  auto smallest_batch_stride = std::min_element(physical_strides.begin(), example_begin);
  auto largest_example_stride = std::max_element(example_begin, physical_strides.end());
  if (largest_example_stride == physical_strides.end()) {
    // No example dimensions
    return;
  }
  TORCH_CHECK(*smallest_batch_stride >= *largest_example_stride,
    "vmap: Calling Tensor.as_strided is not supported unless the batch dims being ",
    "vmapped over are at the front of the tensor (in memory layout). When they are ",
    "not at the front of the tensor this operation can be error prone so we "
    "actively discourage it; please file us a bug report and/or try to ",
    "express the as_strided operation in terms of PyTorch view operations");
}
// Given (sizes, strides, storage_offset), returns
// native::storage_size_for(sizes, strides) + storage_offset. Returns nullopt
// when storage_size_for reports 0 (e.g. a tensor with a zero-size dim), since
// no location can be indexed then.
// NOTE(review): if storage_size_for returns an element count (max index + 1),
// the returned value is one *past* the last indexable location rather than
// the location itself — confirm. Callers only compare two such values.
static optional<int64_t> maximum_indexable_location(
    IntArrayRef sizes, IntArrayRef strides, int64_t storage_offset) {
  const auto footprint = native::storage_size_for(sizes, strides);
  if (footprint != 0) {
    return footprint + storage_offset;
  }
  return nullopt;
}
// Let x be the "first slice" of physical_tensor: the per-example view made of
// physical_tensor's sizes/strides with the leading `num_batch_dims` dims
// dropped, at physical_tensor's storage offset.
//
// Checks that every memory location reachable by
//   x.as_strided(sizes, strides, maybe_storage_offset)
// is within the range of locations reachable by x itself. When
// `maybe_storage_offset` is nullopt, x's own storage offset is used.
//
// Throws (via TORCH_CHECK) if the as_strided result could reach memory
// outside x.
static void checkBasicAsStridedValidForSlice(
    const Tensor& physical_tensor,
    int64_t num_batch_dims,
    IntArrayRef sizes,
    IntArrayRef strides,
    optional<int64_t> maybe_storage_offset) {
  auto slice_sizes = physical_tensor.sizes().slice(num_batch_dims);
  auto slice_strides = physical_tensor.strides().slice(num_batch_dims);
  auto base_offset = physical_tensor.storage_offset();
  auto storage_offset = maybe_storage_offset.value_or(base_offset);
  auto max_as_strided_loc = maximum_indexable_location(sizes, strides, storage_offset);
  auto max_slice_loc = maximum_indexable_location(slice_sizes, slice_strides, base_offset);

  // The requested view indexes no memory at all, so it cannot go out of bounds.
  if (!max_as_strided_loc.has_value()) {
    return;
  }
  // Error messages below are built from adjacent string pieces; each piece
  // ends with a space where the sentence continues, so the words don't run
  // together (previously: ")can access", "somememory", "torewrite").
  TORCH_CHECK(max_slice_loc.has_value(),
      "result = tensor.as_strided(", sizes, ",", strides, ",", storage_offset, ") ",
      "can access memory outside of `tensor`. `tensor` has no storage but the ",
      "passed-in (size, stride, storage_offset) imply a result with some storage. ",
      "This is not supported inside of vmap, please try to rewrite the ",
      "`as_strided` call as a sequence of PyTorch view operations");
  TORCH_CHECK(
      *max_as_strided_loc <= *max_slice_loc && base_offset <= storage_offset,
      "result = tensor.as_strided(", sizes, ",", strides, ",", storage_offset, ") ",
      "can access memory outside of `tensor`. `result` can access some ",
      "memory in range [", storage_offset, ", ", *max_as_strided_loc, "], but ",
      "`tensor` can only access some memory in range [", base_offset, ", ",
      *max_slice_loc, "]. This is not supported inside of vmap, please try to ",
      "rewrite the `as_strided` call as a sequence of PyTorch view operations");
}
// Batching rule for _reshape_alias. Forwards to the reshape batching rule
// using only the per-example `sizes`. The requested strides are not consulted;
// the physical layout is left to reshape_batching_rule.
Tensor _reshape_alias_batching_rule(const Tensor& self, IntArrayRef sizes, IntArrayRef /*strides*/) {
  return reshape_batching_rule(self, sizes);
}
// Batching rule for _new_zeros_with_same_feature_meta.
//
// Only the "batched grad" case is supported: `self` batched, `other` not.
// The batch-dim count is derived from `self`'s physical view, so callers must
// not pass `unused_num_batch_dims` (it must be 0).
Tensor _new_zeros_with_same_feature_meta_batching_rule(
    const Tensor& self,
    const Tensor& other,
    int64_t unused_num_batch_dims) {
  const bool is_batched_grad_case = isBatchedTensor(self) && !isBatchedTensor(other);
  TORCH_CHECK(is_batched_grad_case,
      "Only the 'batched grad' use case is supported in PyTorch core.");
  TORCH_INTERNAL_ASSERT(unused_num_batch_dims == 0,
      "num_batch_dims should not be explicitly passed in because it will be overridden");
  auto physical_view = at::MultiBatchVmapTransform::logicalToPhysical(self);
  const auto& physical_self = physical_view.tensor();
  const int64_t bdim_count = physical_view.numBatchDims();
  checkBatchDimsAtFrontInLayout(physical_self.strides(), bdim_count);
  return physical_view.getPhysicalToLogicalMap().apply(
      at::_new_zeros_with_same_feature_meta(physical_self, other, bdim_count));
}
// Batching rule for _has_same_storage_numel.
//
// Only the "batched grad" case is supported: `self` batched, `other` not.
// The storage-numel comparison is skipped and `true` is always returned:
// when the tangent is batched, using as_strided to reach storage not
// indexable by the input is not supported inside vmap anyway.
bool _has_same_storage_numel_batching_rule(const Tensor& self, const Tensor& other) {
  const bool is_batched_grad_case = isBatchedTensor(self) && !isBatchedTensor(other);
  TORCH_CHECK(is_batched_grad_case,
      "Only the 'batched grad' use case is supported in PyTorch core.");
  return true;
}
// What are the semantics of as_strided inside of vmap?
// y = vmap(lambda x: x.as_strided(sizes, strides, offset))(xs)
// This returns a view on `x`, `y`, such that each y[i] has:
// - sizes: `sizes`
// - strides: `strides`
// - storage_offset: offset + i * x.stride(batch_dim)
//
// In other words, it is as if we had treated each x[i] as having storage
// offset equal to xs.offset() and called as_strided(sizes, strides, offset).
// (that is equivalent to x[i].as_strided(
// sizes, strides, offset + x[i].storage_offset() - xs.offset()) for all i)
//
// Note that this *may* be different from actually running as_strided
// in a for-loop. This is due to how as_strided takes in `offset` to be
// an *absolute* offset. As an example, consider:
// >>> x = torch.tensor([0., 1., 2., 3., 4.]).as_strided([4], [1], 1)
// >>> z = [x[i].as_strided([1], [1], 1) for i in range(4)]
// Each z[i] is actually the same view on x (z[i] == torch.tensor([1.]))!
// However, we consider the above for-loop comprehension to be a user error:
// a user should have written the following if they wanted to use as_strided
// in a per-sample way:
// >>> z = [x[i].as_strided([1], [1], 1 + x[i].storage_offset() - 1) for i in range(4)]
// Batching rule for Tensor.as_strided under (legacy) vmap.
// Each per-sample call behaves as if it were
//   xs[i].as_strided(sizes, strides, storage_offset + xs[i].offset() - xs.offset()),
// which is realized as a single physical as_strided whose strides are the
// batch dims' own strides followed by the requested logical strides.
Tensor as_strided_batching_rule(
    const Tensor& tensor,
    IntArrayRef sizes,
    IntArrayRef strides,
    optional<int64_t> storage_offset) {
  auto view = at::MultiBatchVmapTransform::logicalToPhysical(tensor);
  const auto batch_rank = view.numBatchDims();
  // Logical `sizes` prefixed with the physical batch sizes.
  auto phys_sizes = view.getPhysicalShape(sizes);
  const auto& phys = view.tensor();

  // Checked here (instead of leaving it to the physical as_strided call)
  // because the sanity checks below already consume sizes/strides.
  TORCH_CHECK(sizes.size() == strides.size(),
      "Tensor.as_strided(size, stride, ...): size and stride must have the ",
      "same length! Got size ", sizes, " and stride ", strides);

  // Sanity checks:
  //  (a) batch dims come first in memory layout (not needed for correctness,
  //      but guards against surprising user layouts);
  //  (b) as_strided(sizes, strides, storage_offset + tensor[i].offset() - tensor.offset())
  //      is valid for every slice of the input.
  // See Note: [When will the as_strided batching rule fail?] for details.
  checkBatchDimsAtFrontInLayout(phys.strides(), batch_rank);
  checkBasicAsStridedValidForSlice(
      phys, batch_rank, sizes, strides, storage_offset);

  // Physical strides: the batch dims' strides, then the logical strides.
  const auto batch_strides = phys.strides().slice(0, batch_rank);
  at::VmapDimVector phys_strides(batch_strides.begin(), batch_strides.end());
  phys_strides.insert(phys_strides.end(), strides.begin(), strides.end());

  // When every per-sample as_strided is valid, this physical call succeeds
  // and each result[i] aliases exactly the memory of the per-sample view.
  // See NOTE: [When will the as_strided batching rule fail?]
  auto out = view.tensor().as_strided(phys_sizes, phys_strides, storage_offset);
  return view.getPhysicalToLogicalMap().apply(out);
}
// NOTE: [When will the as_strided batching rule fail?]
// If zi = xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset())
// is valid for all i, then it turns out that
// xs.as_strided(physical_sizes, physical_strides, offset) always succeeds and
// creates a tensor y such that each y[i] refers to the same memory as zi.
//
// Let's say we have xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset()).
// Furthermore, let's say that as a part of being "valid" this as_strided call
// does not return a result that can index memory not indexable by xs[i].
//
// WLOG, assume that there's only one batch dim and it is at the front of the
// `xs` tensor. Let B be the batch size and S be the stride of the batch dim.
// - If the batch dim isn't at the front of the tensor, then we can just move it
// to the front with movedim/permute. This is always valid because it just swaps
// some strides around.
// - This proof also works for tensors with multiple batch dims. We just have to
// do a little accounting:
// - instead of [B], we'd have [B0, B1, ..., Bk].
// - instead of [S], we'd have [S0, S1, ..., Sk].
// - instead of i, we'd have a list of indices [I0, I1, ..., Ik]
// - instead of S * I, we'd have \sum_{i=0}^k S_i * I_i
//
// [Equation 1]
// xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset()) has:
// - sizes: sizes
// - strides: strides
// - offset: offset + S * i
//
// x.as_strided itself checks that:
// - (sizes, strides, offset) are in bounds for `x`'s storage.
// - strides are non-negative
// - offset is non-negative
//
// Claim 1: if xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset())
// is valid, then
// ([B] + sizes, [S] + strides, offset + xs.offset()) are in bounds for `xs`'s storage.
//
// If we have the claim, then xs.as_strided([B] + sizes, [S] + strides, offset)
// won't error out. So all we need to check is that the memory locations are
// what we expected. See [Hand-wavy proof of Claim 1] for proof (it's not very important)
//
// xs.as_strided(physical_sizes, physical_strides, offset) is equivalent to
// xs.as_strided([B] + sizes, [S] + strides, offset)
//
// xs.as_strided([B] + sizes, [S] + strides, offset) has:
// - sizes: [B] + sizes
// - strides: [S] + strides
// - offset: offset
//
// xs.as_strided([B] + sizes, [S] + strides, offset)[i] has:
// - sizes: sizes
// - strides: strides
// - offset: offset + S * i
// These memory locations are exactly the same as what we got for [Equation 1],
// so the xs.as_strided([B] + sizes, [S] + strides, offset) is valid.
//
// [Hand-wavy proof of Claim 1]
// Part of our definition of being valid is that xs[i].as_strided(...)
// must return a tensor that only uses memory indexable by xs[i].
// This means that (sizes, strides, offset + xs[i].offset() - xs.offset()) satisfies:
// offset + xs[i].offset() - xs.offset() + 1 + \sum_j (sizes[j] - 1) * strides[j]
// <= xs[i].offset() + 1 + \sum_j (xs[i].size(j) - 1) * xs[i].stride(j)
// (the largest-index memory location of xs[i].as_strided(...) must be \leq
// the largest-index memory location of xs[i])
//
// Fiddling that inequality gives us:
// offset - xs.offset() + 1 + \sum_j (sizes[j] - 1) * strides[j]
// <= 1 + \sum_j (xs[i].size(j) - 1) * xs[i].stride(j)
//
// offset - xs.offset() + 1 + (B-1)*S + \sum_j (sizes[j] - 1) * strides[j]
// <= 1 + (B-1)*S + \sum_j (xs[i].size(j) - 1) * xs[i].stride(j)
//
// offset - xs.offset() + 1 + (B-1)*S + \sum_j (sizes[j] - 1) * strides[j]
// <= 1 + \sum_j (xs.size(j) - 1) * xs.stride(j)
//
// offset + 1 + (B-1)*S + \sum_j (sizes[j] - 1) * strides[j]
// <= xs.offset() + 1 + \sum_j (xs.size(j) - 1) * xs.stride(j)
// (the largest-index memory location of xs.as_strided(size, stride, offset)
// is \leq the largest-index memory location of xs)
// Under the assumptions we've made, the lower bound (lowest indexed memory)
// is trivially within the storage.
//
// Therefore ([B] + sizes, [S] + strides, offset) are in bounds for
// `xs`'s storage.
// Generic batching rule: unwrap the BatchedTensor, apply the free function
// `Func` to its physical value (forwarding the extra arguments), and rewrap
// the output with the input's batch dims unchanged.
template <typename F, F Func, typename... ExtraArgs>
Tensor unwrap_and_call(const Tensor& input, ExtraArgs... args) {
  auto* batched = unsafeGetBatchedImpl(input);
  auto physical_out = Func(batched->value(), args...);
  const auto dims = batched->bdims();
  return makeBatched(physical_out, BatchDims(dims.begin(), dims.end()));
}
// Method-pointer counterpart of unwrap_and_call: invokes the Tensor member
// function `Func` on the physical value and rewraps the output with the
// input's batch dims unchanged.
template <typename F, F Func, typename... ExtraArgs>
Tensor unwrap_and_call_method(const Tensor& input, ExtraArgs... extra_args) {
  auto* batched = unsafeGetBatchedImpl(input);
  auto physical_out = (batched->value().*Func)(extra_args...);
  const auto dims = batched->bdims();
  return makeBatched(physical_out, BatchDims(dims.begin(), dims.end()));
}
// Batching rule for pow(Scalar base, Tensor exponent): computes the
// elementwise power on the physical exponent tensor and reuses its batch dims.
Tensor pow_scalar_Tensor_batching_rule(const Scalar& other, const Tensor& self) {
  auto* batched_exponent = unsafeGetBatchedImpl(self);
  auto physical_out = at::pow(other, batched_exponent->value());
  const auto dims = batched_exponent->bdims();
  return makeBatched(physical_out, BatchDims(dims.begin(), dims.end()));
}
// Batching rule for Tensor.clone(memory_format).
// Only torch.preserve_format (or no format) and torch.contiguous_format are
// supported inside vmap. Rank-dependent formats are rejected because vmap may
// move batch dims around. For example, ChannelsLast requires exactly 4 dims:
// should a tensor with 3 logical dims plus 1 (or N>1) batch dims qualify?
Tensor clone_batching_rule(const Tensor& self, optional<MemoryFormat> memory_format) {
  TORCH_CHECK(!memory_format.has_value() || memory_format == MemoryFormat::Preserve
      || memory_format == MemoryFormat::Contiguous,
      "NYI: Tensor.clone(memory_format) inside vmap is only supported with ",
      "memory_format torch.preserve_format or torch.contiguous_format (got ",
      *memory_format, ")");

  const bool wants_contiguous = memory_format == MemoryFormat::Contiguous;
  if (!wants_contiguous) {
    // Preserve (or unspecified): clone the physical value as-is and keep the
    // batch dims exactly where they were.
    TORCH_INTERNAL_ASSERT(!memory_format.has_value() || memory_format == MemoryFormat::Preserve);
    auto* batched = unsafeGetBatchedImpl(self);
    auto physical_out = at::clone(batched->value(), memory_format);
    const auto dims = batched->bdims();
    return makeBatched(physical_out, BatchDims(dims.begin(), dims.end()));
  }

  // Contiguous: batch dims are first moved to the front, so the non-batch
  // (per-sample) dims end up contiguous. This resolves the ambiguity when the
  // batch dim is not at the front, e.g.
  //   x = torch.randn(3, B0, 5)
  //   y = vmap(lambda x: x.clone(torch.contiguous_format), in_dims=1, out_dims=0)(x)
  // vmap hides batch dims and works per sample, so only the per-sample dims
  // are made contiguous, not the whole tensor.
  auto view = MultiBatchVmapTransform::logicalToPhysical(self);
  auto physical_out = at::clone(view.tensor(), memory_format);
  return view.getPhysicalToLogicalMap().apply(physical_out);
}
// Note [Batching rules for matmul-like operators]
// at::matmul doesn't "de-expand" arguments to get better performance (maybe
// it should). In the batching rules for matmul-like operators (dot, mv, mm),
// we should be careful not to expand any unnecessary dimensions. e.g., if
// only one of the two arguments is a BatchedTensor, then we should try
// not to expand batch dimensions onto the other arg.
// Batching rule for mv(self, other): logically `self` is a matrix [L, K] and
// `other` a vector [K]; the logical result is a vector [L].
// Batch dims are only materialized on the arguments that actually carry them
// (see Note [Batching rules for matmul-like operators]).
Tensor mv_batching_rule(const Tensor& self, const Tensor& other) {
  auto self_batched = isBatchedTensor(self);
  auto other_batched = isBatchedTensor(other);

  // A shape checking API would be nice...
  TORCH_CHECK(self.dim() == 2 && other.dim() == 1,
      "mv(self, other): Shape mismatch: expected matrix "
      "(got `self` of size ", self.sizes(), ") ",
      "and vector (got `other` of size ", other.sizes(), ")");

  // See Note [Batching rules for matmul-like operators] for why we have cases
  if (self_batched && !other_batched) {
    // self_physical: [..., L, K], other: [K]
    // matmul broadcasts the vector and yields [..., L] directly.
    auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
    auto result = at::matmul(self_physical.tensor(), other);
    return self_physical.getPhysicalToLogicalMap().apply(result);
  }
  if (!self_batched && other_batched) {
    // self: [L, K], other_physical: [..., K]
    // We view other as [..., K, 1], perform matmul to get a tensor of size
    // [..., L, 1], and squeeze the trailing dim to obtain [..., L].
    auto other_physical = MultiBatchVmapTransform::logicalToPhysical(other);
    auto result = at::matmul(self, other_physical.tensor().unsqueeze(-1));
    return other_physical.getPhysicalToLogicalMap().apply(result.squeeze(-1));
  }
  if (self_batched && other_batched) {
    // self_physical: [..., L, K], other_physical: [..., K]
    // We view the tensors as [..., L, K], [..., K, 1], perform matmul to get
    // a tensor of size [..., L, 1], and squeeze the trailing dim to obtain
    // [..., L].
    auto physical_args = MultiBatchVmapTransform::logicalToPhysical({self, other});
    auto result = at::matmul(
        physical_args[0].tensor(),
        physical_args[1].tensor().unsqueeze(-1));
    return physical_args[0].getPhysicalToLogicalMap().apply(result.squeeze(-1));
  }
  TORCH_INTERNAL_ASSERT(false, "either self or other must be a BatchedTensor");
}
// Batching rule for _make_dual: does no batching work itself. It masks the
// dispatch key set down to the keys that come after Batched and redispatches
// _make_dual with the original arguments.
Tensor _make_dual_batching_rule(
  c10::DispatchKeySet ks,
  const Tensor& primal,
  const Tensor& tangent,
  int64_t level
) {
  const auto keys_after_batched =
      ks & DispatchKeySet(DispatchKeySet::FULL_AFTER, c10::DispatchKey::Batched);
  return at::redispatch::_make_dual(keys_after_batched, primal, tangent, level);
}
// Batching rule for dot(self, other): both arguments are logically vectors
// [K]; the logical result is a 0-dim tensor.
// Batch dims are only materialized on the arguments that actually carry them
// (see Note [Batching rules for matmul-like operators]).
Tensor dot_batching_rule(const Tensor& self, const Tensor& other) {
  auto self_batched = isBatchedTensor(self);
  auto other_batched = isBatchedTensor(other);
  TORCH_CHECK(/*logical*/self.dim() == 1 && /*logical*/other.dim() == 1,
      "dot(self, other): Shape mismatch: vector "
      "(got `self` of size ", self.sizes(), ") ",
      "and vector (got `other` of size ", other.sizes(), ")");

  // See Note [Batching rules for matmul-like operators] for why we have cases
  if (self_batched && !other_batched) {
    // self_physical: [..., K], other: [K]
    // View self as [..., 1, K], matmul with [K] to get [..., 1], then squeeze
    // the trailing dim to obtain [...].
    auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
    auto result = at::matmul(self_physical.tensor().unsqueeze(-2), other);
    return self_physical.getPhysicalToLogicalMap().apply(result.squeeze(-1));
  }
  if (!self_batched && other_batched) {
    // self: [K], other_physical: [..., K]
    // View other as [..., K, 1], matmul [K] with it to get [..., 1], then
    // squeeze the trailing dim to obtain [...].
    auto other_physical = MultiBatchVmapTransform::logicalToPhysical(other);
    auto result = at::matmul(self, other_physical.tensor().unsqueeze(-1));
    return other_physical.getPhysicalToLogicalMap().apply(result.squeeze(-1));
  }
  if (self_batched && other_batched) {
    // self_physical: [..., K], other_physical: [..., K]
    // View them as [..., 1, K] and [..., K, 1], matmul to get [..., 1, 1],
    // then squeeze both trailing dims to obtain [...].
    auto physical_args = MultiBatchVmapTransform::logicalToPhysical({self, other});
    auto result = at::matmul(
        physical_args[0].tensor().unsqueeze(-2),
        physical_args[1].tensor().unsqueeze(-1));
    return physical_args[0].getPhysicalToLogicalMap().apply(result.squeeze(-1).squeeze(-1));
  }
  TORCH_INTERNAL_ASSERT(false, "either self or other must be a BatchedTensor");
}
// Batching rule for bmm(self, other): both arguments are logically 3-D.
// Both are converted with BroadcastingVmapTransform, multiplied with
// at::matmul, and the result is mapped back to logical dims.
Tensor bmm_batching_rule(const Tensor& self, const Tensor& other) {
  TORCH_CHECK(/*logical*/self.dim() == 3 && /*logical*/other.dim() == 3,
      "bmm(self, other): Shape mismatch: expected 3D `self` "
      "(got `self` of size ", self.sizes(), ") ",
      "and 3D `other` (got `other` of size ", other.sizes(), ")");

  auto views = BroadcastingVmapTransform::logicalToPhysical({self, other});
  const auto& lhs = views[0].tensor();
  const auto& rhs = views[1].tensor();
  auto product = at::matmul(lhs, rhs);
  return views[0].getPhysicalToLogicalMap().apply(product);
}
// Batching rule for mm(self, other): logically `self` is [L, K] and `other`
// is [K, M]; the logical result is [L, M].
// Batch dims are only materialized on the arguments that actually carry them
// (see Note [Batching rules for matmul-like operators]).
Tensor mm_batching_rule(const Tensor& self, const Tensor& other) {
  auto self_batched = isBatchedTensor(self);
  auto other_batched = isBatchedTensor(other);
  TORCH_CHECK(/*logical*/self.dim() == 2 && /*logical*/other.dim() == 2,
      "mm(self, other): Shape mismatch: expected matrix "
      "(got `self` of size ", self.sizes(), ") ",
      "and matrix (got `other` of size ", other.sizes(), ")");

  // See Note [Batching rules for matmul-like operators] for why we have cases
  if (self_batched && !other_batched) {
    // [..., L, K] @ [K, M] -> [..., L, M]
    auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
    auto result = at::matmul(self_physical.tensor(), other);
    return self_physical.getPhysicalToLogicalMap().apply(result);
  }
  if (!self_batched && other_batched) {
    // [L, K] @ [..., K, M] -> [..., L, M]
    auto other_physical = MultiBatchVmapTransform::logicalToPhysical(other);
    auto result = at::matmul(self, other_physical.tensor());
    return other_physical.getPhysicalToLogicalMap().apply(result);
  }
  if (self_batched && other_batched) {
    // [..., L, K] @ [..., K, M] -> [..., L, M]
    // The matmul output already has the logical [L, M] shape after the batch
    // dims. It must NOT be squeezed: squeeze(-1) would drop genuine size-1
    // logical dims (L == 1 or M == 1) and could even remove a batch dim,
    // corrupting the physical-to-logical mapping.
    auto physical_args = MultiBatchVmapTransform::logicalToPhysical({self, other});
    auto result = at::matmul(physical_args[0].tensor(), physical_args[1].tensor());
    return physical_args[0].getPhysicalToLogicalMap().apply(result);
  }
  TORCH_INTERNAL_ASSERT(false, "either self or other must be a BatchedTensor");
}
// Batching rule for cat: moves batch dims of every input to the front,
// concatenates the physical tensors along the physical counterpart of `dim`,
// and maps the result back to logical dims.
Tensor cat_batching_rule(const ITensorListRef& tensors, int64_t dim) {
  auto views = MultiBatchVmapTransform::logicalToPhysical(tensors);
  auto unwrapped = fmap(
      views, [](const VmapPhysicalView& v) -> Tensor { return v.tensor(); });
  TORCH_INTERNAL_ASSERT(
      !tensors.empty(), "The dispatcher should not have dispatched here otherwise.");
  auto& lead_view = views[0];
  auto concatenated = at::cat(unwrapped, lead_view.getPhysicalDim(dim));
  return lead_view.getPhysicalToLogicalMap().apply(concatenated);
}
// Batching rule for stack: moves batch dims of every input to the front and
// stacks the physical tensors along the physical counterpart of `dim`.
Tensor stack_batching_rule(TensorList tensors, int64_t dim) {
  auto views = MultiBatchVmapTransform::logicalToPhysical(tensors);
  auto unwrapped = fmap(
      views, [](const VmapPhysicalView& v) -> Tensor { return v.tensor(); });
  TORCH_INTERNAL_ASSERT(
      !tensors.empty(), "The dispatcher should not have dispatched here otherwise.");
  // stack adds a new dimension, so `dim` is wrapped against (logical rank + 1)
  // rather than the logical rank; the result is then shifted past the batch
  // dims to obtain the physical dim.
  const auto rank_after_stack = /*logical*/tensors[0].dim() + 1;
  auto& lead_view = views[0];
  const auto physical_dim =
      lead_view.numBatchDims() + maybe_wrap_dim(dim, rank_after_stack);
  auto stacked = at::stack(unwrapped, physical_dim);
  return lead_view.getPhysicalToLogicalMap().apply(stacked);
}
// I am quite sad that we need to register operators with exploded TensorOptions,
// even though the native:: implementations can use TensorOptions&.
// This also makes it hard to metaprogram: i.e., we can't use
// unwrap_and_call<..., at::to> because at::to takes TensorOptions& (!!)
// Batching rule for Tensor.to(dtype, layout, device, pin_memory, ...):
// reassembles the exploded TensorOptions, converts the physical value, and
// rewraps the output with the input's batch dims unchanged.
Tensor to_dtype_layout_batching_rule(
    const Tensor& self,
    optional<ScalarType> dtype,
    optional<Layout> layout,
    optional<Device> device,
    optional<bool> pin_memory,
    bool non_blocking, bool copy,
    optional<MemoryFormat> memory_format) {
  const auto target_options =
      TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory);
  auto* batched = unsafeGetBatchedImpl(self);
  auto physical_out =
      batched->value().to(target_options, non_blocking, copy, memory_format);
  const auto dims = batched->bdims();
  return makeBatched(physical_out, BatchDims(dims.begin(), dims.end()));
}
// Batching rule for Tensor.new_zeros: prefixes the requested logical `size`
// with the physical batch sizes, allocates zeros with the given options, and
// maps the result back to logical dims.
Tensor new_zeros_batching_rule(
    const Tensor& self,
    IntArrayRef size,
    optional<ScalarType> dtype,
    optional<Layout> layout,
    optional<Device> device,
    optional<bool> pin_memory) {
  auto view = MultiBatchVmapTransform::logicalToPhysical(self);
  auto batched_size = view.getPhysicalShape(size);
  const auto zero_options =
      TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory);
  auto zeros = view.tensor().new_zeros(batched_size, zero_options);
  return view.getPhysicalToLogicalMap().apply(zeros);
}
Tensor new_empty_batching_rule(
const Tensor& self,
IntArrayRef size,
c10::optional<ScalarType> dtype,
c10::optional<Layout> layout,