-
Notifications
You must be signed in to change notification settings - Fork 480
/
xla_graph_executor.cpp
1364 lines (1262 loc) · 58.4 KB
/
xla_graph_executor.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#include "torch_xla/csrc/xla_graph_executor.h"
#include <Python.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/lazy/core/hash.h>
#include <torch/csrc/lazy/core/helpers.h>
#include <torch/csrc/lazy/core/ir_util.h>
#include <torch/csrc/lazy/core/lazy_graph_executor.h>
#include <torch/csrc/lazy/core/metrics.h>
#include <torch/csrc/lazy/core/tensor_util.h>
#include <torch/csrc/lazy/core/unique.h>
#include <torch/csrc/lazy/core/util.h>
#include <algorithm>
#include <cmath>
#include <condition_variable>
#include <exception>
#include <fstream>
#include <functional>
#include <mutex>
#include <set>
#include <stdexcept>
#include <unordered_map>
#include <unordered_set>
#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_join.h"
#include "stablehlo/dialect/Serialization.h" // from @stablehlo
#include "torch_xla/csrc/aten_xla_bridge.h"
#include "torch_xla/csrc/dtype.h"
#include "torch_xla/csrc/helpers.h"
#include "torch_xla/csrc/ir_dump_util.h"
#include "torch_xla/csrc/layout_manager.h"
#include "torch_xla/csrc/ops/arithmetic_ir_ops.h"
#include "torch_xla/csrc/ops/cast.h"
#include "torch_xla/csrc/ops/device_data.h"
#include "torch_xla/csrc/ops/dynamic_ir.h"
#include "torch_xla/csrc/ops/expand.h"
#include "torch_xla/csrc/ops/expand_symint.h"
#include "torch_xla/csrc/ops/ops.h"
#include "torch_xla/csrc/ops/view.h"
#include "torch_xla/csrc/ops/xla_ops.h"
#include "torch_xla/csrc/runtime/cache.h"
#include "torch_xla/csrc/runtime/computation_client.h"
#include "torch_xla/csrc/runtime/debug_macros.h"
#include "torch_xla/csrc/runtime/env_vars.h"
#include "torch_xla/csrc/runtime/runtime.h"
#include "torch_xla/csrc/runtime/stablehlo_helper.h"
#include "torch_xla/csrc/runtime/sys_util.h"
#include "torch_xla/csrc/runtime/xla_util.h"
#include "torch_xla/csrc/shape_helper.h"
#include "torch_xla/csrc/tensor_util.h"
#include "torch_xla/csrc/thread_pool.h"
#include "torch_xla/csrc/torch_util.h"
#include "torch_xla/csrc/xla_backend_impl.h"
#include "torch_xla/csrc/xla_sharding_util.h"
#include "tsl/platform/errors.h"
#include "tsl/profiler/lib/traceme.h"
#include "xla/literal_util.h"
#include "xla/shape_util.h"
namespace torch_xla {
namespace {
// Builds a DeviceData IR node holding `value`: the scalar is materialized as a
// 0-dim host tensor of `scalar_type`, transferred to `device`, and the
// resulting backend handle becomes the node's data.
torch::lazy::Value IrValueFromScalar(const at::Scalar& value,
                                     at::ScalarType scalar_type,
                                     const torch::lazy::BackendDevice& device) {
  return torch::lazy::MakeNode<DeviceData>(TensorToXlaData(
      at::scalar_tensor(value, at::TensorOptions(scalar_type)), device));
}
// Returns false when the IR value's op is the `xla_not_supported` op, and true
// otherwise. CollectSyncTensors uses this to decide which pending IR values
// take part in a graph sync.
bool ShouldSyncIrValue(const torch::lazy::Value& ir_value) {
  const torch::lazy::OpKind& kind = ir_value->op();
  return !(kind == xla_not_supported);
}
} // namespace
// Returns the process-wide DeviceContextArena. The instance is heap-allocated
// on first call and never deleted, presumably so it outlives any static
// destructors that might still touch it.
auto XLAGraphExecutor::DeviceContextArena::Get() -> DeviceContextArena* {
  static DeviceContextArena* const instance = new DeviceContextArena();
  return instance;
}
// Collects an XLATensorPtr for every still-alive registered tensor, either on
// `device` or across all device contexts (as decided by ForAllDeviceContexts).
// Each device context is scanned under its own mutex.
std::vector<XLATensorPtr> XLAGraphExecutor::DeviceContextArena::GetLiveTensors(
    const torch::lazy::BackendDevice* device) {
  std::vector<XLATensorPtr> live;
  ForAllDeviceContexts(
      [&live](DeviceContext* devctx) {
        std::lock_guard<std::mutex> guard(devctx->lock);
        for (auto& entry : devctx->tensors_data) {
          auto tensor_data =
              std::dynamic_pointer_cast<XLATensor::Data>(entry.second.lock());
          if (tensor_data == nullptr) {
            // Expired weak reference, or not an XLATensor::Data.
            continue;
          }
          live.push_back(XLATensor::Create(std::move(tensor_data)));
        }
      },
      device);
  return live;
}
// Advances the device's RNG seed and returns it as an IR value.
//
// Two parallel states are advanced with the same linear recurrence
// seed' = kSeedAdd + kSeedMul * seed:
//  - `running_seed`, a host-side integer, so it can be returned without
//    executing a graph;
//  - `seed_ir_value`, an IR expression chained from the root seed. The root is
//    lazily created as DeviceData on first use.
// Chaining from the root avoids creating a new computation parameter per call.
torch::lazy::Value XLAGraphExecutor::DeviceContextArena::GetRngSeed(
    const torch::lazy::BackendDevice& device) {
  static const at::ScalarType kSeedType = at::ScalarType::Long;
  static const uint64_t kSeedMul = 214013;
  static const uint64_t kSeedAdd = 2531011;

  DeviceContext* devctx = GetDeviceContext(device);
  std::lock_guard<std::mutex> guard(devctx->lock);

  torch::lazy::Value& seed_value = devctx->seed_ir_value;
  if (!seed_value) {
    seed_value =
        IrValueFromScalar(MakeIntScalar(devctx->seed), kSeedType, device);
  }

  // Host-side copy of the recurrence (unsigned wraparound, order irrelevant).
  devctx->running_seed = kSeedMul * devctx->running_seed + kSeedAdd;

  // Too many XLA computation parameters could overflow TPU capacity, hence
  // composing from the existing seed node instead of uploading a new one.
  const xla::PrimitiveType seed_xla_type =
      MakeXlaPrimitiveType(kSeedType, &device);
  torch::lazy::Value mul = ScalarOp(MakeIntScalar(kSeedMul), seed_xla_type);
  torch::lazy::Value add = ScalarOp(MakeIntScalar(kSeedAdd), seed_xla_type);
  seed_value = add + mul * seed_value;
  return seed_value;
}
// Resets the device's seed state to its root `seed` and returns the backend
// data holding that root. The running host seed is set back to `seed`, and the
// IR seed chain restarts from a fresh DeviceData node wrapping the uploaded
// value.
torch::lazy::BackendDataPtr
XLAGraphExecutor::DeviceContextArena::GetBaseSeedData(
    const torch::lazy::BackendDevice& device) {
  static const at::ScalarType kSeedType = at::ScalarType::Long;
  DeviceContext* devctx = GetDeviceContext(device);
  std::lock_guard<std::mutex> guard(devctx->lock);

  torch::lazy::BackendDataPtr seed_data =
      TensorToXlaData(at::scalar_tensor(MakeIntScalar(devctx->seed),
                                        at::TensorOptions(kSeedType)),
                      device);
  devctx->seed_ir_value = torch::lazy::MakeNode<DeviceData>(seed_data);
  devctx->running_seed = devctx->seed;

  torch_xla::DeviceData* root =
      torch_xla::DeviceData::Cast(devctx->seed_ir_value.node.get());
  return root->data();
}
// Records the textual graph for `hash` in hash_to_graph_map. This only happens
// when XLA_SAVE_TENSORS_FILE resolves to a non-empty path; that env lookup is
// done once per process. An existing entry for `hash` is never overwritten, and
// the (potentially expensive) graph text is only generated for new hashes.
void XLAGraphExecutor::DeviceContextArena::SaveGraphAsString(
    torch::lazy::hash_t hash, absl::Span<const XLATensorPtr> tensors,
    const std::vector<size_t>* indices, DebugUtil::GraphFormat format) {
  static const bool should_save_graph =
      !runtime::sys_util::GetEnvOrdinalPath(
           "XLA_SAVE_TENSORS_FILE", "", bridge::GetCurrentDevice().ordinal())
           .empty();
  if (!should_save_graph) {
    return;
  }
  if (hash_to_graph_map.find(hash) != hash_to_graph_map.end()) {
    return;  // Already recorded.
  }
  hash_to_graph_map[hash] =
      DebugUtil::GetTensorsGraphInfo(tensors, indices, format);
}
// Stores `output_shapes` under `hash` unless an entry already exists (first
// writer wins). try_emplace does a single lookup and leaves `output_shapes`
// untouched when the key is already present.
void XLAGraphExecutor::DeviceContextArena::SaveOutputShapes(
    torch::lazy::hash_t hash, std::vector<xla::Shape> output_shapes) {
  hash_to_output_shape_map.try_emplace(hash, std::move(output_shapes));
}
// Returns the graph text saved for `hash`. For an unknown hash it logs at INFO
// level and returns an empty string.
std::string XLAGraphExecutor::DeviceContextArena::GetGraphByHash(
    torch::lazy::hash_t hash) {
  auto it = hash_to_graph_map.find(hash);
  if (it != hash_to_graph_map.end()) {
    return it->second;
  }
  TF_LOG(INFO) << "Trying to dump graph with an invalid hash";
  return "";
}
// Returns a pointer to the output shapes saved for `hash`; fails an XLA_CHECK
// when the hash was never recorded. The pointer refers into the arena's map.
std::vector<xla::Shape>*
XLAGraphExecutor::DeviceContextArena::GetOutputShapesByHash(
    torch::lazy::hash_t hash) {
  auto found = hash_to_output_shape_map.find(hash);
  const bool present = found != hash_to_output_shape_map.end();
  XLA_CHECK(present) << "Hash not found, can't retrive output shape";
  return &found->second;
}
// Uploads `value` (as a 0-dim tensor of `scalar_type`) to `device` and wraps
// the resulting handle in a DeviceData IR node.
torch::lazy::Value XLAGraphExecutor::DeviceContextArena::IrValueFromScalar(
    const at::Scalar& value, at::ScalarType scalar_type,
    const torch::lazy::BackendDevice& device) {
  at::Tensor host_scalar =
      at::scalar_tensor(value, at::TensorOptions(scalar_type));
  return torch::lazy::MakeNode<DeviceData>(
      TensorToXlaData(host_scalar, device));
}
// Builds the XLA async-execution state.
//
// `coll`, `parameters_data` and `tensors_data` are forwarded to the base
// torch::lazy::LazyGraphExecutor::Async, which receives nullptr for its own
// cached computation. The XLA-specific computation is kept in the derived
// `cached_computation` member. The two data vectors are sink parameters taken
// by value, so they are moved (not copied) into the base. This avoids an atomic
// refcount bump per BackendDataPtr element.
XLAGraphExecutor::Async::Async(
    SyncTensorCollection* coll,
    std::vector<torch::lazy::BackendDataPtr> parameters_data,
    std::vector<torch::lazy::BackendDataPtr> tensors_data,
    ComputationCache::TypePtr cached_computation)
    : torch::lazy::LazyGraphExecutor::Async(coll, std::move(parameters_data),
                                            std::move(tensors_data), nullptr),
      cached_computation(std::move(cached_computation)) {}
// Returns the process-wide executor, constructed on first use (function-local
// static; initialization is thread-safe since C++11).
XLAGraphExecutor* XLAGraphExecutor::Get() {
  static XLAGraphExecutor executor;
  return &executor;
}
// Registers a tensor's data with the device context arena and bumps the
// "CreateXlaTensor" counter. `data` is a by-value sink: moving it avoids an
// extra atomic refcount increment/decrement on the shared_ptr.
void XLAGraphExecutor::RegisterTensor(
    std::shared_ptr<torch::lazy::LazyTensor::Data> data) {
  DeviceContextArena::Get()->RegisterTensor(std::move(data));
  TORCH_LAZY_COUNTER("CreateXlaTensor", 1);
}
// Removes a tensor's data from the device context arena and bumps the
// "DestroyXlaTensor" counter.
void XLAGraphExecutor::UnregisterTensor(torch::lazy::LazyTensor::Data* data) {
  DeviceContextArena* arena = DeviceContextArena::Get();
  arena->UnregisterTensor(data);
  TORCH_LAZY_COUNTER("DestroyXlaTensor", 1);
}
// Kicks off a sync of `tensors` on all devices without waiting for completion
// and without syncing LTC data.
void XLAGraphExecutor::ApplyEagerSync(std::vector<XLATensorPtr>& tensors) {
  SyncTensorsGraph(&tensors, /*devices=*/{}, /*wait=*/false,
                   /*sync_ltc_data=*/false);
}
// Wraps `value` as device data (using the host torch type that `type` maps
// to via MaybeUpcastToHostTorchType) and returns a DeviceData IR node for it.
// The data is tagged read-only with tensor_id -1 (presumably meaning "not
// owned by any tensor" — confirm against DeviceDataInfo users).
torch::lazy::Value XLAGraphExecutor::GetDeviceDataIrValue(
    const at::Scalar& value, xla::PrimitiveType type,
    const torch::lazy::BackendDevice& device) {
  const auto host_type = MaybeUpcastToHostTorchType(type);
  torch::lazy::BackendDataPtr data = GetDeviceData(value, host_type, device);
  auto info =
      std::make_shared<torch::lazy::LazyGraphExecutor::DeviceDataInfo>(
          /*tensor_id=*/-1, /*read_only=*/true);
  data->SetInfo(std::move(info));
  return torch::lazy::MakeNode<DeviceData>(std::move(data));
}
// Returns a constant IR value of `value` with `shape`'s element type. The
// scalar is broadcast (via Expand) to `shape`'s dimensions when it has any.
torch::lazy::Value XLAGraphExecutor::GetIrValueForConstant(
    const at::Scalar& value, const xla::Shape& shape) {
  // `value` is a const reference; std::move on it would only copy anyway.
  torch::lazy::Value scalar_node = ScalarOp(value, shape.element_type());
  if (shape.dimensions().empty()) {
    return scalar_node;
  }
  return torch::lazy::MakeNode<Expand>(
      scalar_node, torch::lazy::ToVector<int64_t>(shape.dimensions()));
}
// Returns an IR value for `value` of XLA type `type`. Scalars accepted by
// torch::lazy::IsSpecialScalar become inline ScalarOp constants; every other
// scalar goes through device data (GetDeviceDataIrValue).
torch::lazy::Value XLAGraphExecutor::GetIrValueForScalar(
    const at::Scalar& value, xla::PrimitiveType type,
    const torch::lazy::BackendDevice& device) {
  if (!torch::lazy::IsSpecialScalar(value)) {
    return GetDeviceDataIrValue(value, type, device);
  }
  return ScalarOp(value, type);
}
// Overload that derives the XLA type from the scalar's own torch dtype on
// `device`.
torch::lazy::Value XLAGraphExecutor::GetIrValueForScalar(
    const at::Scalar& value, const torch::lazy::BackendDevice& device) {
  const xla::PrimitiveType inferred_type =
      MakeXlaPrimitiveType(GetScalarType(value), &device);
  return GetIrValueForScalar(value, inferred_type, device);
}
// Overload that broadcasts the scalar IR value (via Expand) to static
// `dimensions`. With no dimensions, the scalar value is returned as-is.
torch::lazy::Value XLAGraphExecutor::GetIrValueForScalar(
    const at::Scalar& value, xla::PrimitiveType type,
    absl::Span<const int64_t> dimensions,
    const torch::lazy::BackendDevice& device) {
  torch::lazy::Value scalar_node = GetIrValueForScalar(value, type, device);
  if (dimensions.empty()) {
    return scalar_node;
  }
  return torch::lazy::MakeNode<Expand>(
      scalar_node, torch::lazy::ToVector<int64_t>(dimensions));
}
// Overload that expands the scalar IR value to a (possibly symbolic) size
// given as c10 SymInts, through an ExpandSymInt node.
torch::lazy::Value XLAGraphExecutor::GetIrValueForScalar(
    const at::Scalar& value, xla::PrimitiveType type,
    c10::SymIntArrayRef sym_size, const torch::lazy::BackendDevice& device) {
  SymIntElements size_elements(sym_size);
  torch::lazy::Value scalar_node = GetIrValueForScalar(value, type, device);
  return torch::lazy::MakeNode<ExpandSymInt>(scalar_node, size_elements);
}
// Overload taking the element type and static dimensions from `shape`.
torch::lazy::Value XLAGraphExecutor::GetIrValueForScalar(
    const at::Scalar& value, const xla::Shape& shape,
    const torch::lazy::BackendDevice& device) {
  const xla::PrimitiveType element_type = shape.element_type();
  return GetIrValueForScalar(value, element_type, shape.dimensions(), device);
}
// Overload using `shape`'s dimensions. The element type comes from
// `logical_element_type` when given, and from `shape` otherwise.
torch::lazy::Value XLAGraphExecutor::GetIrValueForScalar(
    const at::Scalar& value, const xla::Shape& shape,
    c10::optional<at::ScalarType> logical_element_type,
    const torch::lazy::BackendDevice& device) {
  xla::PrimitiveType type = shape.element_type();
  if (logical_element_type.has_value()) {
    type = MakeXlaPrimitiveType(*logical_element_type, &device);
  }
  return GetIrValueForScalar(value, type, shape.dimensions(), device);
}
// Overload that expands to symbolic `size_elements`. The element type comes
// from `logical_element_type` when given, and from `shape` otherwise.
torch::lazy::Value XLAGraphExecutor::GetIrValueForScalar(
    const at::Scalar& value, const xla::Shape& shape,
    SymIntElements size_elements,
    c10::optional<at::ScalarType> logical_element_type,
    const torch::lazy::BackendDevice& device) {
  xla::PrimitiveType primitive_type = shape.element_type();
  if (logical_element_type.has_value()) {
    primitive_type = MakeXlaPrimitiveType(*logical_element_type, &device);
  }
  torch::lazy::Value scalar_node =
      GetIrValueForScalar(value, primitive_type, device);
  return torch::lazy::MakeNode<ExpandSymInt>(scalar_node, size_elements);
}
// Forwards to the arena: advances and returns the IR RNG seed for `device`.
torch::lazy::Value XLAGraphExecutor::GetRngSeed(
    const torch::lazy::BackendDevice& device) {
  DeviceContextArena* arena = DeviceContextArena::Get();
  return arena->GetRngSeed(device);
}
// Forwards to the arena: sets the root RNG seed for `device`.
void XLAGraphExecutor::SetRngSeed(const torch::lazy::BackendDevice& device,
                                  uint64_t seed) {
  DeviceContextArena* arena = DeviceContextArena::Get();
  arena->SetRngSeed(device, seed);
}
// Forwards to the arena: returns the host-side running seed for `device`.
uint64_t XLAGraphExecutor::GetRunningSeed(
    const torch::lazy::BackendDevice& device) {
  DeviceContextArena* arena = DeviceContextArena::Get();
  return arena->GetRunningSeed(device);
}
// Forwards to the arena: resets and returns the root seed data for `device`.
torch::lazy::BackendDataPtr XLAGraphExecutor::GetBaseSeedData(
    const torch::lazy::BackendDevice& device) {
  DeviceContextArena* arena = DeviceContextArena::Get();
  return arena->GetBaseSeedData(device);
}
std::string XLAGraphExecutor::DumpHloComputation(
const std::vector<XLATensorPtr>& tensors, EmitMode mode) {
std::vector<torch::lazy::Value> ir_values;
for (auto& tensor : tensors) {
torch::lazy::Value ir_value = tensor->CurrentIrValue();
if (ir_value) {
ir_values.push_back(std::move(ir_value));
}
}
return !ir_values.empty()
? DumpUtil::ToHlo(ir_values, bridge::GetCurrentDevice(), mode)
: std::string();
}
// Forwards to the arena: returns every still-alive registered tensor
// (filtered by `device` as the arena decides).
std::vector<XLATensorPtr> XLAGraphExecutor::GetLiveTensors(
    const torch::lazy::BackendDevice* device) {
  DeviceContextArena* arena = DeviceContextArena::Get();
  return arena->GetLiveTensors(device);
}
// Compiles and (asynchronously) executes the pending graph for `tensors`.
//
// In warm-up mode, force_ltc_data is turned off and the call never blocks.
// Otherwise, when `wait` is set and work was scheduled, it blocks until the
// async execution finishes.
void XLAGraphExecutor::SyncTensorsGraph(std::vector<XLATensorPtr>* tensors,
                                        absl::Span<const std::string> devices,
                                        bool wait, bool sync_ltc_data,
                                        bool warm_up_cache_only) {
  TF_VLOG(4) << "Trying to sync the value of " << tensors->size()
             << " tensor(s)";
  tsl::profiler::TraceMe activity("SyncTensorsGraph",
                                  tsl::profiler::TraceMeLevel::kInfo);
  SyncTensorsConfig config;
  config.sync_ltc_data = sync_ltc_data;
  if (warm_up_cache_only) {
    config.force_ltc_data = false;
  }
  auto async =
      SyncTensorsGraphInternal(tensors, devices, config, warm_up_cache_only);
  const bool should_wait = wait && !warm_up_cache_only;
  if (should_wait && async != nullptr) {
    async->mwait.Wait();
  }
}
// Syncs every live tensor (as reported by GetLiveTensors(device)), always with
// sync_ltc_data enabled.
void XLAGraphExecutor::SyncLiveTensorsGraph(
    const torch::lazy::BackendDevice* device,
    c10::ArrayRef<std::string> devices, bool wait) {
  tsl::profiler::TraceMe activity("SyncLiveTensorsGraph",
                                  tsl::profiler::TraceMeLevel::kInfo);
  std::vector<XLATensorPtr> live_tensors = GetLiveTensors(device);
  TF_VLOG(4) << live_tensors.size() << " live tensors: devices=("
             << c10::Join(",", devices) << ")";
  SyncTensorsGraph(&live_tensors, devices, wait, /*sync_ltc_data=*/true);
}
// Marks a step boundary for `device`. It counts the step, lets the arena mark
// the step, then resets IR scopes and the trim counter.
void XLAGraphExecutor::MarkStep(const torch::lazy::BackendDevice& device) {
  // TODO(jwtan): Replace this with TORCH_LAZY_COUNTER. We need MarkStep to
  // remain as XLA_COUNTER to support
  // runtime::metrics::CreatePerformanceReport(). For more information, see
  // NOTE: [TORCH_LAZY_COUNTER v.s. XLA_COUNTER].
  XLA_COUNTER("MarkStep", 1);
  DeviceContextArena* arena = DeviceContextArena::Get();
  arena->MarkStep(device);
  torch::lazy::ScopePusher::ResetScopes();
  ResetTrimCounter();
}
// Blocks until outstanding work on the given devices has released its device
// locks. With no devices given, it waits on "SPMD:0" when the virtual device is
// in use, and on all local devices otherwise.
void XLAGraphExecutor::WaitDeviceOps(absl::Span<const std::string> devices) {
  std::set<torch::lazy::BackendDevice> wait_devices;
  auto add_all = [&wait_devices](const auto& device_strings) {
    for (const auto& device_str : device_strings) {
      wait_devices.insert(ParseDeviceString(device_str));
    }
  };
  if (!devices.empty()) {
    add_all(devices);
  } else if (UseVirtualDevice()) {
    wait_devices.insert(ParseDeviceString("SPMD:0"));
  } else {
    add_all(runtime::GetComputationClient()->GetLocalDevices());
  }
  // LockDevices() returns torch::lazy::ExceptionCleanup objects that are
  // discarded right away, so acquiring and releasing them acts as a barrier.
  DeviceLockerArena::Get()->LockDevices(wait_devices);
  TF_VLOG(4) << "XLAGraphExecutor::WaitDeviceOps completed";
}
std::vector<at::Tensor> XLAGraphExecutor::GetTensors(
std::vector<XLATensorPtr>* tensors) {
TF_VLOG(4) << "Trying to get the value of " << tensors->size()
<< " tensor(s)";
SyncTensorsConfig config;
config.force_ltc_data = false;
auto async = SyncTensorsGraphInternal(tensors, {}, config);
if (async != nullptr) {
async->mwait.Wait();
}
std::vector<torch::lazy::BackendDataPtr> tensors_data = GatherTensorsXlaData(
*tensors, async != nullptr ? async->indices : absl::Span<const size_t>(),
async != nullptr ? async->tensors_data
: absl::Span<const torch::lazy::BackendDataPtr>());
std::vector<xla::Literal> literals = ReleaseGilAndTransferData(tensors_data);
return FetchTensors(tensors, literals,
async != nullptr ? &async->indices : nullptr);
}
// Computes the hash of the graph that syncing `tensors` would produce.
//
// The result combines the collection hash with a hash of the post-order
// parameter sequence. As a side effect, the output shapes (with device layout)
// and the graph text are recorded in the arena under that hash.
torch::lazy::hash_t XLAGraphExecutor::GetGraphHash(
    const std::vector<XLATensorPtr>& tensors) {
  SyncTensorsConfig config;
  config.sync_ltc_data = true;
  config.force_ltc_data = false;
  SyncTensorCollection coll = CollectSyncTensors(tensors, config);

  const size_t num_outputs = coll.indices.size();
  std::vector<torch::lazy::Value> ir_values;
  std::vector<xla::Shape> output_shapes;
  ir_values.reserve(num_outputs);
  output_shapes.reserve(num_outputs);
  for (size_t index : coll.indices) {
    const XLATensorPtr& tensor = tensors[index];
    const auto device_type =
        static_cast<XlaDeviceType>(tensor->GetDevice().type());
    ir_values.push_back(tensor->CurrentIrValue());
    output_shapes.push_back(
        MakeShapeWithDeviceLayout(tensor->shape(), device_type));
  }

  PostOrderData po_data = RunPostOrder(ir_values, &coll);
  const torch::lazy::hash_t graph_hash = torch::lazy::HashCombine(
      coll.hash, torch::lazy::Hash(po_data.parameter_sequence));

  DeviceContextArena* arena = DeviceContextArena::Get();
  arena->SaveOutputShapes(graph_hash, std::move(output_shapes));
  arena->SaveGraphAsString(graph_hash, tensors, &coll.indices);
  return graph_hash;
}
// Appends the saved graph text for `hash`, under a "[name]" header, to the
// file named by XLA_SAVE_TENSORS_FILE. It does nothing when that path is empty
// or no graph text was saved. The path is resolved once per thread.
void XLAGraphExecutor::MaybeDumpGraph(std::string name,
                                      torch::lazy::hash_t hash) {
  thread_local const std::string save_file =
      runtime::sys_util::GetEnvOrdinalPath(
          "XLA_SAVE_TENSORS_FILE", "", bridge::GetCurrentDevice().ordinal());
  if (save_file.empty()) {
    return;
  }
  const std::string graph = DeviceContextArena::Get()->GetGraphByHash(hash);
  if (graph.empty()) {
    return;
  }
  std::ofstream graph_file(save_file, std::ios_base::app);
  graph_file << "[" << name << "]\n" << graph << "\n";
}
// Returns the process-wide compiled-computation cache. Its capacity comes from
// XLA_COMPILATION_CACHE_SIZE (default 1024). It is heap-allocated on first use
// and never freed.
XLAGraphExecutor::ComputationCache* XLAGraphExecutor::GetComputationCache() {
  static ComputationCache* const cache = [] {
    const size_t max_cache_size =
        runtime::sys_util::GetEnvInt("XLA_COMPILATION_CACHE_SIZE", 1024);
    return new ComputationCache(max_cache_size);
  }();
  return cache;
}
// Detaches pending IR from each distinct tensor in `tensors` that has no
// device handle yet.
//
// - If the IR is already a DeviceData node, its data becomes the tensor's
//   handle.
// - Otherwise a data placeholder with the tensor's device-layout shape is
//   created on `device`.
// In both cases the tensor's IR value, view and host tensor data are then
// cleared. Tensors without an IR value are left unchanged.
void XLAGraphExecutor::ClearPendingIrs(
    std::vector<XLATensorPtr> tensors,
    const torch::lazy::BackendDevice& device) {
  std::unordered_set<int64_t> seen_ids;
  for (const XLATensorPtr& tensor : tensors) {
    // Skip duplicates (by unique id) and tensors that already hold a handle.
    if (!seen_ids.insert(tensor->GetUniqueId()).second ||
        tensor->CurrentDataHandle() != nullptr) {
      continue;
    }
    torch::lazy::Value ir_value = tensor->CurrentIrValue();
    if (!ir_value) {
      continue;
    }
    if (DeviceData* device_data =
            torch_xla::DeviceData::Cast(ir_value.node.get())) {
      tensor->data()->handle = device_data->data();
    } else {
      xla::Shape placeholder_shape = MakeShapeWithDeviceLayout(
          tensor->shape(), static_cast<XlaDeviceType>(device.type()));
      torch::lazy::BackendDataPtr placeholder =
          runtime::GetComputationClient()->CreateDataPlaceholder(
              device.toString(), std::move(placeholder_shape));
      tensor->data()->handle = placeholder;
      TF_VLOG(4) << "Replacing the IR " << ir_value.node.get()->ToString()
                 << " of Tensor with ID " << tensor->GetUniqueId()
                 << " with placeholder";
    }
    tensor->AssignIrValue(torch::lazy::Value());
    tensor->data()->view = nullptr;
    tensor->data()->tensor_data = c10::nullopt;
  }
}
// Selects which of `tensors` need materializing by a sync and computes the
// collection hash for them.
//
// The collection is seeded with a hash of config.force_ltc_data. Each distinct
// tensor (by unique id) that has no device handle, or whose view is not up to
// date, is then handled as follows:
//  - If it has a current IR value accepted by ShouldSyncIrValue, the IR hash is
//    folded into coll.hash and its index is appended to coll.indices. The
//    tensor's sharding spec, if any, is copied onto the IR node.
//  - Otherwise, if config.force_ltc_data is set, its host at::Tensor data is
//    queued. After the loop, all queued tensors are uploaded in one
//    CreateTensorsData call and each new handle is stored on its tensor.
// Returns a default-constructed collection when no device was collected
// (e.g. `tensors` is empty).
XLAGraphExecutor::SyncTensorCollection XLAGraphExecutor::CollectSyncTensors(
const std::vector<XLATensorPtr>& tensors, const SyncTensorsConfig& config) {
tsl::profiler::TraceMe activity("CollectSyncTensors",
tsl::profiler::TraceMeLevel::kInfo);
torch::lazy::Unique<torch::lazy::BackendDevice> unique_device;
for (size_t i = 0; i < tensors.size(); ++i) {
unique_device.set(tensors[i]->GetDevice());
}
SyncTensorCollection coll;
if (!unique_device) {
return coll;
}
// Parallel arrays describing the host-only tensors to upload; entry k of each
// refers to tensors[at_tensor_index[k]].
std::vector<at::Tensor> at_tensors;
std::vector<std::string> devices;
std::vector<XLATensor::ShardingSpecPtr> shardings;
std::vector<size_t> at_tensor_index;
std::unordered_set<int64_t> tensor_ids;
// The force_ltc_data controls aliasing compilation, so effectively the same
// graph with on/off force_ltc_data should not match, hash wise.
coll.hash = torch::lazy::MHash(config.force_ltc_data);
coll.config = config;
coll.device = *unique_device;
coll.indices.reserve(tensors.size());
for (size_t i = 0; i < tensors.size(); ++i) {
if (tensor_ids.insert(tensors[i]->GetUniqueId()).second &&
// A tensor's xla_data might not be up to date if there is a view
// associated with it. Make sure to sync those tensors here too.
(tensors[i]->CurrentDataHandle() == nullptr ||
(tensors[i]->data()->view != nullptr &&
!tensors[i]->data()->view->IsUpToDate()))) {
torch::lazy::Value ir_value = tensors[i]->CurrentIrValue();
if (ir_value) {
if (ShouldSyncIrValue(ir_value)) {
// Add only tensors which need to be synced.
// The hash depends on the order in which indices are visited.
coll.hash = torch::lazy::HashCombine(coll.hash, ir_value.hash());
coll.indices.push_back(i);
// `sharding_spec()` checks sharding equality. If IR node has no
// sharding, then sync XLATensor sharding to the IR node. XLATensor's
// sharding takes the precedence as the source of the truth.
XLATensor::ShardingSpecPtr sharding = tensors[i]->sharding_spec();
if (sharding) {
// NOTE(review): the dynamic_cast result is dereferenced without a
// null check; this assumes every IR node here is an XlaNode —
// confirm.
dynamic_cast<XlaNode*>(ir_value.node.get())
->SetSharding(sharding->sharding, ir_value.index);
}
}
} else if (config.force_ltc_data) {
// The tensor only has at::Tensor data. We need to queue it for a
// device upload.
c10::optional<at::Tensor> tensor_data = tensors[i]->CurrentTensorData();
XLA_CHECK(tensor_data);
at_tensors.push_back(*tensor_data);
shardings.push_back(tensors[i]->sharding_spec());
devices.push_back(tensors[i]->GetDevice().toString());
at_tensor_index.push_back(i);
}
}
}
if (!at_tensors.empty()) {
TORCH_LAZY_COUNTER("SyncTensorsToData", at_tensors.size());
// Create data handles with shardings. If a tensor has a
// sharding annotation, then a BackendDataPtr with PjRtShardedData is
// returned; if there is no sharding annotation, then a BackendDataPtr with
// PjRtData is returned.
std::vector<torch::lazy::BackendDataPtr> handles =
CreateTensorsData(at_tensors, shardings, devices);
// Assumes handles[k] corresponds to at_tensors[k] (one handle per input, in
// order) — confirm against CreateTensorsData.
for (size_t i = 0; i < handles.size(); ++i) {
// If we are here, it means that the IR torch::lazy::Value for the
// tensor is not present. Also, we uploaded the at::Tensor data to the
// device, but such data is still valid so we leave it live on the XLA
// tensor (so that a following ToTensor() does not need to fetch it from
// device).
tensors[at_tensor_index[i]]->data()->handle = std::move(handles[i]);
}
}
TF_VLOG(4) << "Tensors graph hash " << torch::lazy::HashToString(coll.hash)
<< " on device " << coll.device;
return coll;
}
// Waits on the collection's device via the base LazyGraphExecutor barrier.
// It adds profiling and VLOG(4) start/done markers; the SPMD device is not
// locked yet (see TODO).
void XLAGraphExecutor::TensorCollectionBarrier(SyncTensorCollection* coll) {
  tsl::profiler::TraceMe activity("TensorCollectionBarrier",
                                  tsl::profiler::TraceMeLevel::kInfo);
  TF_VLOG(4) << "waiting barrier for device " << coll->device.toString()
             << " start";
  // Delegate the actual device locking to the generic lazy implementation.
  torch::lazy::LazyGraphExecutor::TensorCollectionBarrier(coll);
  // TODO(yeounoh) lock SPMD device
  TF_VLOG(4) << "waiting barrier for device " << coll->device.toString()
             << " done";
}
// Runs the computation cached under `hash` (Dynamo path) on `device`.
//
// Output placeholders are created and returned immediately. The execution
// itself is wrapped in a completer and handed to thread::Schedule; when it
// finishes, each placeholder is filled via Assign(). The device lock is taken
// before the inputs are gathered, and the locks are stored on the Async
// object. If execution throws, the exception status is set on those unlockers
// and the exception is rethrown (surfacing through mwait.Wait()).
std::vector<torch::lazy::BackendDataPtr>
XLAGraphExecutor::ExecuteComputationWithBarrier(
    torch::lazy::hash_t hash, const std::vector<at::IValue>& graph_inputs,
    const torch::lazy::BackendDevice& device) {
  tsl::profiler::TraceMe activity("ExecuteComputationWithBarrier",
                                  tsl::profiler::TraceMeLevel::kInfo);
  MaybeDumpGraph("dynamo", hash);
  auto cachedComputation =
      XLAGraphExecutor::Get()->GetComputationCache()->Get(hash);
  // TODO implement a fallback mechanism, or make sure those entries
  // never get kicked out
  // The null check must come before the first dereference: the LRU cache may
  // no longer hold this hash.
  XLA_CHECK(cachedComputation)
      << "Failed to get computation by hash " << torch::lazy::HashToString(hash)
      << ". Maybe the entry get "
         "kicked out of the LRU cache";
  TF_VLOG(5) << "Cached computation (hash: " << torch::lazy::HashToString(hash)
             << ") is_sharded=" << cachedComputation->is_sharded;
  DebugUtil::analyze_graph_execution_python_frame(
      DebugUtil::GraphAnalysisSource::DynamoExecution,
      /*graph_hash=*/hash,
      /*program_shape=*/&(cachedComputation->computation->program_shape()));

  // Create DataPlaceHolder that will get filled in async executions.
  std::vector<xla::Shape>* output_shapes =
      DeviceContextArena::Get()->GetOutputShapesByHash(hash);
  std::vector<torch::lazy::BackendDataPtr> placeholders;
  if (static_cast<XlaDeviceType>(device.type()) == XlaDeviceType::SPMD) {
    // TODO(JackCaoG): Use LRU cache and add same cache to non-dynamo path.
    static std::unordered_map<torch::lazy::hash_t,
                              std::vector<XLATensor::ShardingSpecPtr>,
                              torch::lazy::HashReducer>
        output_sharding_hash;
    // For any given graph (each hash corresponds to one graph) there is only
    // one output sharding, so cache it instead of retrieving it from the
    // computation every time.
    auto sharding_it = output_sharding_hash.find(hash);
    if (sharding_it == output_sharding_hash.end()) {
      TORCH_LAZY_COUNTER("UncachedOutputSharding", 1);
      sharding_it =
          output_sharding_hash
              .emplace(hash, ShardingUtil::GetOutputSharding(
                                 output_shapes, cachedComputation->computation,
                                 device))
              .first;
    }
    placeholders = ShardingUtil::CreateShardedPlaceholder(sharding_it->second);
  } else {
    placeholders.reserve(output_shapes->size());
    for (const xla::Shape& shape : *output_shapes) {
      // `shape` is const; it is copied into the placeholder request.
      torch::lazy::BackendDataPtr handle =
          runtime::GetComputationClient()->CreateDataPlaceholder(
              device.toString(), shape);
      placeholders.push_back(std::move(handle));
    }
  }

  SyncTensorCollection coll;
  coll.device = device;
  {
    tsl::profiler::TraceMe activity("DeviceBarrier",
                                    tsl::profiler::TraceMeLevel::kInfo);
    TF_VLOG(5) << "Lock device " << device.toString() << "...";
    coll.unlocker = DeviceLockerArena::Get()->LockDevices({device});
    TF_VLOG(5) << "Locking device " << device.toString() << " Done!";
  }

  std::vector<torch::lazy::BackendDataPtr> arguments;
  {
    // GetXlaData must be called within a lock region, otherwise it might
    // extract the placeholder inserted by previous execution.
    TORCH_LAZY_TIMED("RunCachedGraphInputData");
    arguments.reserve(graph_inputs.size());
    for (const auto& ivalue : graph_inputs) {
      torch::lazy::BackendDataPtr dataptr;
      if (auto xla_tensor_ptr = bridge::TryGetXlaTensor(ivalue.toTensor())) {
        dataptr = xla_tensor_ptr->GetXlaData();
      } else {
        XLA_CHECK(device.type() != static_cast<int8_t>(XlaDeviceType::SPMD))
            << "SPMD device data should already be on the XLA backend "
               "(XLATensor).";
        dataptr = torch_xla::TensorToXlaData(ivalue.toTensor(), device);
      }
      arguments.push_back(dataptr);
    }
  }

  std::shared_ptr<XLAGraphExecutor::Async> async = std::make_shared<Async>(
      &coll, std::move(arguments), placeholders, std::move(cachedComputation));

  auto syncfn = [async, hash]() {
    try {
      tsl::profiler::TraceMe activity("ExecuteComputationWithBarrier_syncfn",
                                      tsl::profiler::TraceMeLevel::kInfo);
      TF_VLOG(3) << "Executing Dynamo IR graph hash "
                 << torch::lazy::HashToString(hash) << " on device "
                 << async->device << " ...";
      std::vector<torch::lazy::BackendDataPtr> results;
      if (async->cached_computation->is_sharded) {
        std::vector<std::string> devices =
            runtime::GetComputationClient()->GetLocalDevices();
        runtime::ComputationClient::ExecuteReplicatedOptions execute_options;
        // OutputHandler creates sharded data for sharded
        // tensor results. Both sharded and unsharded results should be
        // "Assign"ed to the corresponding data placeholders.
        std::vector<runtime::ComputationClient::DataPtr> outputs =
            runtime::GetComputationClient()->ExecuteReplicated(
                *async->cached_computation->computation,
                UnwrapXlaData(async->parameters_data), devices,
                execute_options);
        results = WrapXlaData(outputs);
        TF_VLOG(3) << "Executing Dynamo IR sharded graph hash "
                   << torch::lazy::HashToString(hash) << " on devices "
                   << absl::StrJoin(devices, ",") << " done!";
      } else {
        results = torch::lazy::getBackend()->ExecuteComputation(
            async->cached_computation->computation, async->parameters_data,
            async->device);
        TF_VLOG(3) << "Executing Dynamo IR graph hash "
                   << torch::lazy::HashToString(hash) << " on device "
                   << async->device << " done!";
      }
      // Updating placeholder with actual output handle.
      {
        tsl::profiler::TraceMe activity("update_placeholder",
                                        tsl::profiler::TraceMeLevel::kInfo);
        for (size_t i = 0; i < results.size(); ++i) {
          XLA_CHECK(async->tensors_data[i] != nullptr);
          async->tensors_data[i]->Assign(*results[i]);
        }
      }
    } catch (...) {
      // There are two paths of discovery of an exception happening on an
      // asynchronous task. One happens if the creator of the asynchronous task
      // explicitly waits for completion, in which case the exception will be
      // thrown from the Wait() API. Re-throwing the exception below makes sure
      // this will be captured by the completer function created below, and
      // surfaced by the Wait() API. But we also need to surface the exception
      // even in case the caller does not wait, and that is accomplished by
      // setting the unlockers status. In that case the exception will be
      // surfaced when the user tries to acquire the device locks the next time.
      for (auto& unlocker : async->unlocker) {
        unlocker.SetStatus(std::current_exception());
      }
      throw;
    }
  };

  thread::Schedule(async->mwait.Completer(std::move(syncfn)));
  return placeholders;
}
// Compiles a serialized StableHLO portable artifact and executes it on
// `device`.
//
// The bytecode is deserialized into an MLIR module, converted to HLO, compiled
// for `device` (with the program result shape laid out for that device type),
// and run with one argument per entry of `graph_inputs`. Inputs that are
// already XLA tensors contribute their current XLA data; any other tensor is
// converted onto `device` with TensorToXlaData.
//
// Returns the data handles produced by the execution, wrapped as lazy backend
// data. Fails an XLA_CHECK if the bytecode cannot be deserialized.
std::vector<torch::lazy::BackendDataPtr> XLAGraphExecutor::ExecuteStablehlo(
    std::string bytecode, const std::vector<at::IValue>& graph_inputs,
    const torch::lazy::BackendDevice& device) {
  // Convert StableHLO to HLO for XLA compilation.
  // TODO(lsy323): Pass StableHLO to PjrtComputationClient for compilation
  // after StableHLO compilation API is added in ComputationClient.
  mlir::MLIRContext context;
  mlir::OwningOpRef<mlir::ModuleOp> module =
      mlir::stablehlo::deserializePortableArtifact(bytecode, &context);
  // A null module signals that deserialization failed; dereferencing it below
  // would be undefined behavior, so fail with a clear error instead.
  XLA_CHECK(module) << "Failed to deserialize StableHLO portable artifact";
  mlir::ModuleOp mlir_module = *module;
  xla::HloProto hlo_proto;
  ConvertStableHloToHlo(&mlir_module, &context, &hlo_proto);
  xla::HloModuleProto* hlo_module_proto = hlo_proto.mutable_hlo_module();
  xla::XlaComputation computation(*hlo_module_proto);
  // Get program output shape.
  // TODO(lsy323): Get shape info from MLIR Module.
  xla::ProgramShape program_shape = ConsumeValue(computation.GetProgramShape());
  xla::Shape shape = MakeShapeWithDeviceLayout(
      program_shape.result(), static_cast<XlaDeviceType>(device.type()));
  std::vector<runtime::ComputationClient::CompileInstance> instances;
  instances.emplace_back(
      std::move(computation), device.toString(),
      runtime::GetComputationClient()->GetCompilationDevices(
          device.toString(),
          runtime::GetComputationClient()->GetLocalDevices()),
      &shape);
  std::vector<std::shared_ptr<runtime::ComputationClient::Computation>>
      computations =
          runtime::GetComputationClient()->Compile(std::move(instances));

  // Set up one argument per graph input.
  // NOTE(review): no device lock is acquired in this function, so
  // GetXlaData() may observe a placeholder installed by a still-pending
  // asynchronous execution -- confirm whether a lock region is required here.
  std::vector<torch::lazy::BackendDataPtr> arguments;
  arguments.reserve(graph_inputs.size());
  for (const at::IValue& ivalue : graph_inputs) {
    const at::Tensor& tensor = ivalue.toTensor();
    torch::lazy::BackendDataPtr dataptr;
    if (auto xla_tensor_ptr = bridge::TryGetXlaTensor(tensor)) {
      dataptr = xla_tensor_ptr->GetXlaData();
    } else {
      dataptr = torch_xla::TensorToXlaData(tensor, device);
    }
    arguments.push_back(std::move(dataptr));
  }

  std::vector<runtime::ComputationClient::DataPtr> result_data =
      runtime::GetComputationClient()->ExecuteComputation(
          *computations[0], UnwrapXlaData(arguments), device.toString());
  return WrapXlaData(result_data);
}
std::vector<torch::lazy::BackendDataPtr> XLAGraphExecutor::GatherTensorsXlaData(
const std::vector<XLATensorPtr>& tensors, absl::Span<const size_t> indices,
absl::Span<const torch::lazy::BackendDataPtr> tensors_data) {
std::vector<torch::lazy::BackendDataPtr> result_tensors_data;
std::unordered_map<int64_t, size_t> uid_index_map;
size_t indices_index = 0;
for (size_t i = 0; i < tensors.size(); ++i) {
int64_t tensor_id = tensors[i]->GetUniqueId();
auto it = uid_index_map.find(tensor_id);
if (it != uid_index_map.end()) {
// Current tensor is a duplicate of a previously processed tensor that
// had an IR XlaNode to sync. Get the XLA data from the tensor_data_map.
result_tensors_data.push_back(result_tensors_data[it->second]);
} else if (indices_index < indices.size() && i == indices[indices_index]) {
// If we are at the current index (it means that the tensor at index
// 'i' had an IR node to sync), use the XLA data held within the Async
// object.
uid_index_map.emplace(tensor_id, result_tensors_data.size());
result_tensors_data.push_back(tensors_data[indices_index]);
++indices_index;
} else if (!tensors[i]->CurrentTensorData()) {
torch::lazy::BackendDataPtr handle = tensors[i]->CurrentDataHandle();
XLA_CHECK(handle != nullptr);
result_tensors_data.push_back(std::move(handle));
}
}
return result_tensors_data;
}
std::vector<torch::lazy::Value> XLAGraphExecutor::CollectRoots(
const std::vector<XLATensorPtr>& tensors,
absl::Span<const size_t> indices) {
std::vector<torch::lazy::Value> roots;
roots.reserve(indices.size());
for (auto index : indices) {
roots.push_back(tensors.at(index)->CurrentIrValue());
}
return roots;
}
std::vector<torch::lazy::BackendDataPtr> XLAGraphExecutor::SetTensorData(
std::vector<XLATensorPtr>* tensors, const SyncTensorsConfig& config,
absl::Span<const size_t> indices,
const std::vector<torch::lazy::BackendDataPtr>& tensor_data_vec) {
tsl::profiler::TraceMe activity("SetTensorData",
tsl::profiler::TraceMeLevel::kInfo);
std::vector<torch::lazy::BackendDataPtr> tensors_data;
tensors_data.reserve(indices.size());
for (int i = 0; i < indices.size(); i++) {
auto index = indices[i];
XLATensorPtr& tensor = (*tensors)[index];
// If the config.force_ltc_data flag is true, the purpose of this tensor
// sync operation is to truncate the IR graph and materialize device data
// in place of IR graph, on selected tensors. But since operation will
// complete asynchronously, if a tensor does not already have device data,
// we need to install a placeholder. Since at this point we hold a lock on
// the device where the tensors reside (locks held within the coll
// structure, and moved into the async variable), any other operation
// trying to access the tensor's device data will have to wait until the
// asynchronous operation completes.
torch::lazy::BackendDataPtr handle = tensor->CurrentDataHandle();
if (!handle && tensor->CurrentIrValue()) {
handle = torch::lazy::getBackend()->GetComputationDataFromNode(
tensor->CurrentIrValue().node.get());
}
if (handle == nullptr && config.force_ltc_data) {
handle = tensor_data_vec[i];
// Note: We are not using SetXlaData method here since that method
// resets the ir_value. We have already done the resetting as part
// of ExtractIRAndPrepareXlaData_ to overlap with previous execution.
tensor->data()->handle = handle;
tensor->data()->view = nullptr;
tensor->data()->tensor_data = c10::nullopt;
}
tensors_data.emplace_back(std::move(handle));
}
return tensors_data;
}
// For every tensor selected by `indices`, appends (in `indices` order) its
// current IR value to `ir_values` and a newly created device-data placeholder
// to `tensor_data_vec`. Each placeholder is created on the tensor's device with
// the tensor's shape laid out for that device type.
//
// When `config.force_ltc_data` is set and a tensor has no data handle yet, its
// IR value is reset here (after being recorded), so that the reset overlaps
// with any previous execution rather than happening in SetTensorData.
void XLAGraphExecutor::ExtractIRAndPrepareXlaData_(
    std::vector<XLATensorPtr>* tensors, const SyncTensorsConfig& config,
    const absl::Span<const size_t> indices,
    std::vector<torch::lazy::Value>& ir_values,
    std::vector<torch::lazy::BackendDataPtr>& tensor_data_vec) {
  tsl::profiler::TraceMe activity("ExtractIRAndPrepareXlaData_",
                                  tsl::profiler::TraceMeLevel::kInfo);
  const size_t expected = indices.size();
  ir_values.reserve(expected);
  tensor_data_vec.reserve(expected);
  for (const size_t tensor_index : indices) {
    XLATensorPtr& xtensor = (*tensors)[tensor_index];
    ir_values.push_back(xtensor->CurrentIrValue());

    const torch::lazy::BackendDevice& target_device = xtensor->GetDevice();
    xla::Shape placeholder_shape = MakeShapeWithDeviceLayout(
        xtensor->shape(), static_cast<XlaDeviceType>(target_device.type()));
    torch::lazy::BackendDataPtr placeholder =
        runtime::GetComputationClient()->CreateDataPlaceholder(
            target_device.toString(), std::move(placeholder_shape));
    tensor_data_vec.push_back(std::move(placeholder));

    const bool needs_ir_reset =
        xtensor->CurrentDataHandle() == nullptr && config.force_ltc_data;
    if (needs_ir_reset) {
      xtensor->AssignIrValue(torch::lazy::Value());
    }
  }
}
std::vector<at::Tensor> XLAGraphExecutor::FetchTensors(
std::vector<XLATensorPtr>* tensors, absl::Span<const xla::Literal> literals,
const std::vector<size_t>* indices) {
std::vector<at::Tensor> results;
size_t literals_index = 0;
size_t sync_index = 0;
results.reserve(tensors->size());
for (size_t i = 0; i < tensors->size(); ++i) {
if (indices != nullptr && sync_index < indices->size() &&
i == (*indices)[sync_index]) {
results.push_back(MakeTensorFromXlaLiteral(literals[literals_index],
(*tensors)[i]->dtype()));
++literals_index;
++sync_index;
} else {
c10::optional<at::Tensor> tensor_data =
(*tensors)[i]->CurrentTensorData();
if (tensor_data) {
results.push_back(*tensor_data);
} else {
XLA_CHECK_LT(literals_index, literals.size());
results.push_back(MakeTensorFromXlaLiteral(literals[literals_index],
(*tensors)[i]->dtype()));
++literals_index;
}
}
}
return results;
}
std::shared_ptr<XLAGraphExecutor::Async>
XLAGraphExecutor::ScheduleSyncTensorsGraph(
SyncTensorCollection* coll,
std::vector<torch::lazy::BackendDataPtr> parameters_data,
std::vector<torch::lazy::BackendDataPtr> tensors_data,
std::vector<XLATensor::ShardingSpecPtr> sharding_specs,
ComputationCache::TypePtr cached_computation) {
DebugUtil::analyze_graph_execution_python_frame(
DebugUtil::GraphAnalysisSource::Execution,
/*graph_hash=*/coll->hash,
/*program_shape=*/&(cached_computation->computation->program_shape()));
tsl::profiler::TraceMe activity("ScheduleSyncTensorsGraph",
tsl::profiler::TraceMeLevel::kInfo);
TensorCollectionBarrier(coll);
std::shared_ptr<XLAGraphExecutor::Async> async = std::make_shared<Async>(
coll, std::move(parameters_data), std::move(tensors_data),
std::move(cached_computation));
auto syncfn = [async, hash = coll->hash, sharding_specs = sharding_specs]() {
try {
std::vector<torch::lazy::BackendDataPtr> results;
// Execute replicated if the compiled computation is partitioned.
if (async->cached_computation->is_sharded) {
std::vector<std::string> devices =
runtime::GetComputationClient()->GetLocalDevices();
runtime::ComputationClient::ExecuteReplicatedOptions execute_options;
TF_VLOG(3) << "Executing IR graph hash "
<< torch::lazy::HashToString(hash)
<< " on devices: " << absl::StrJoin(devices, ",");
// OutputHandler creates sharded data for sharded
// tensor results. Both sharded and unsharded results should be
// "Assign"ed to the corresponding data placeholders.
std::vector<runtime::ComputationClient::DataPtr> outputs =
runtime::GetComputationClient()->ExecuteReplicated(
*async->cached_computation->computation,
UnwrapXlaData(async->parameters_data), devices,
execute_options);
results = WrapXlaData(outputs);
TF_VLOG(3) << "Executing IR graph hash "
<< torch::lazy::HashToString(hash)
<< " on devices: " << absl::StrJoin(devices, ",")
<< " done!";
} else {
TF_VLOG(3) << "Executing IR graph hash "
<< torch::lazy::HashToString(hash) << " on device "
<< async->device << " ...";
results = torch::lazy::getBackend()->ExecuteComputation(
async->cached_computation->computation, async->parameters_data,
async->device);