From 63845167bbb81a42ecc73dc5835868befd1baa0f Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Wed, 1 Nov 2023 19:39:47 +0000 Subject: [PATCH] server -> device --- TROUBLESHOOTING.md | 6 ++-- docs/first_steps.md | 6 ++-- docs/pytorch_xla_overview.md | 6 ++-- test/cpp/cpp_test_util.cpp | 2 +- test/cpp/test_replication.cpp | 2 +- test/metrics_compare_utils_test.py | 32 +++++++++---------- test/pjrt/test_metrics.py | 4 +-- test/pjrt/test_profiler.py | 2 +- test/spmd/test_xla_virtual_device.py | 8 ++--- test/test_metrics.py | 6 ++-- test/test_operations.py | 6 ++-- test/test_profiler.py | 4 +-- torch_xla/csrc/device.h | 2 +- torch_xla/csrc/init_python_bindings.cpp | 2 +- torch_xla/csrc/runtime/computation_client.cc | 12 +++---- torch_xla/csrc/runtime/computation_client.h | 12 +++---- torch_xla/csrc/runtime/metrics_analysis.cc | 2 +- .../csrc/runtime/pjrt_computation_client.cc | 18 +++++------ .../csrc/runtime/pjrt_computation_client.h | 6 ++-- .../runtime/pjrt_computation_client_test.cc | 4 +-- torch_xla/csrc/tensor_util.cpp | 8 ++--- torch_xla/csrc/xla_graph_executor.cpp | 6 ++-- torch_xla/csrc/xla_sharding_util.cpp | 2 +- torch_xla/debug/metrics.py | 2 +- 24 files changed, 80 insertions(+), 80 deletions(-) diff --git a/TROUBLESHOOTING.md b/TROUBLESHOOTING.md index ca3ddf702c2..3c7122b1fa1 100644 --- a/TROUBLESHOOTING.md +++ b/TROUBLESHOOTING.md @@ -43,7 +43,7 @@ vm:~$ git clone --branch r2.1 https://github.com/pytorch/xla.git vm:~$ python xla/test/test_train_mp_imagenet.py --fake_data ``` -If you can get the resnet to run we can conclude that torch_xla is installed correctly. +If you can get the resnet to run we can conclude that torch_xla is installed correctly. ## Performance Debugging @@ -60,10 +60,10 @@ We provide ways to automatically analyze the metrics report and provide a summar ``` pt-xla-profiler: CompileTime too frequent: 21 counts during 11 steps -pt-xla-profiler: TransferFromServerTime too frequent: 11 counts during 11 steps +pt-xla-profiler: TransferFromDeviceTime too frequent: 11 counts during 11 steps pt-xla-profiler: Op(s) not lowered: aten::_ctc_loss, aten::_ctc_loss_backward, Please open a GitHub issue with the above op lowering requests. pt-xla-profiler: CompileTime too frequent: 23 counts during 12 steps -pt-xla-profiler: TransferFromServerTime too frequent: 12 counts during 12 steps +pt-xla-profiler: TransferFromDeviceTime too frequent: 12 counts during 12 steps ``` Following section will explain how to get and understand a more detail metrics report. diff --git a/docs/first_steps.md b/docs/first_steps.md index a0591a3430d..2658d2d8bb8 100644 --- a/docs/first_steps.md +++ b/docs/first_steps.md @@ -19,7 +19,7 @@ For more details and examples, please refer to the [LazyTensor guide](https://py The operations in the IR graph are executed only when values of tensors are needed. This is referred to as evaluation or materialization of tensors. Sometimes this is also called lazy evaluation and it can lead to significant [performance improvements](https://arxiv.org/pdf/2102.13267.pdf). -The _synchronous_ operations in Pytorch XLA, like printing, logging, checkpointing or callbacks block tracing and result in slower execution. In the case when an operation requires a specific value of an XLA tensor, e.g. `print(xla_tensor_z)`, tracing is blocked until the value of that tensor is available to the host. Note that only the part of the graph responsible for computing that tensor value is executed. 
These operations do not cut the IR graph, but they trigger host-device communication through `TransferFromServer`, which results in slower performance. +The _synchronous_ operations in Pytorch XLA, like printing, logging, checkpointing or callbacks block tracing and result in slower execution. In the case when an operation requires a specific value of an XLA tensor, e.g. `print(xla_tensor_z)`, tracing is blocked until the value of that tensor is available to the host. Note that only the part of the graph responsible for computing that tensor value is executed. These operations do not cut the IR graph, but they trigger host-device communication through `TransferFromDevice`, which results in slower performance. A _barrier_ is a special instruction that tells XLA to execute the IR graph and materialize the tensors. This means that the PyTorch XLA tensors will be evaluated, and the results will be available to the host. The user-exposed barrier in Pytorch XLA is [xm.mark_step()](https://github.com/pytorch/xla/blob/bdceee54eca1269ee954f6cdd1868c584d0e88a4/torch_xla/core/xla_model.py#L808), which breaks the IR graph and results in code execution on the XLA devices. One of the key properties of `xm.mark_step` is that unlike synchronous operations it does not block the further tracing while the device is executing the graph. However, it does block access to the values of the tensors that are being materialized. @@ -233,9 +233,9 @@ Now, let's examine the XL version of the model and do the same thing. We will ad This time, in addition to the large gap in the middle, which is caused by the `pipe_watermark` tracing, there are many small gaps between the inference steps within [this loop](https://github.com/pytorch-tpu/diffusers/blob/0243d2ef9c2c7bc06956bb1bcc92c23038f6519d/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L814-L830). -First look closer into the large gap that is caused by `pipe_watermark`. The gap is preceded with `TransferFromServer` which indicates that something is happening on the host machine that is waiting for computation to finish before proceeding. Looking into watermark [code](https://github.com/pytorch-tpu/diffusers/blob/0243d2ef9c2c7bc06956bb1bcc92c23038f6519d/src/diffusers/pipelines/stable_diffusion_xl/watermark.py#L29), we can see that tensors are transferred to cpu and converted to numpy arrays in order to be processed with `cv2` and `pywt` libraries later. Since this part is not straightforward to optimize, we will leave this as is. +First look closer into the large gap that is caused by `pipe_watermark`. The gap is preceded with `TransferFromDevice` which indicates that something is happening on the host machine that is waiting for computation to finish before proceeding. Looking into watermark [code](https://github.com/pytorch-tpu/diffusers/blob/0243d2ef9c2c7bc06956bb1bcc92c23038f6519d/src/diffusers/pipelines/stable_diffusion_xl/watermark.py#L29), we can see that tensors are transferred to cpu and converted to numpy arrays in order to be processed with `cv2` and `pywt` libraries later. Since this part is not straightforward to optimize, we will leave this as is. -Now if we zoom in on the loop, we can see that the graph within the loop is broken into smaller parts because the `TransferFromServer` operation happens. +Now if we zoom in on the loop, we can see that the graph within the loop is broken into smaller parts because the `TransferFromDevice` operation happens. 
![Alt text](assets/image-3.png) diff --git a/docs/pytorch_xla_overview.md b/docs/pytorch_xla_overview.md index 456273515ed..da087098cae 100644 --- a/docs/pytorch_xla_overview.md +++ b/docs/pytorch_xla_overview.md @@ -21,7 +21,7 @@ For more details and examples, please refer to the [LazyTensor guide](https://py The operations in the IR graph are executed only when values of tensors are needed. This is referred to as evaluation or materialization of tensors. Sometimes this is also called lazy evaluation and it can lead to significant [performance improvements](https://arxiv.org/pdf/2102.13267.pdf). -The _synchronous_ operations in Pytorch XLA, like printing, logging, checkpointing or callbacks block tracing and result in slower execution. In the case when an operation requires a specific value of an XLA tensor, e.g. `print(xla_tensor_z)`, tracing is blocked until the value of that tensor is available to the host. Note that only the part of the graph responsible for computing that tensor value is executed. These operations do not cut the IR graph, but they trigger host-device communication through `TransferFromServer`, which results in slower performance. +The _synchronous_ operations in Pytorch XLA, like printing, logging, checkpointing or callbacks block tracing and result in slower execution. In the case when an operation requires a specific value of an XLA tensor, e.g. `print(xla_tensor_z)`, tracing is blocked until the value of that tensor is available to the host. Note that only the part of the graph responsible for computing that tensor value is executed. These operations do not cut the IR graph, but they trigger host-device communication through `TransferFromDevice`, which results in slower performance. A _barrier_ is a special instruction that tells XLA to execute the IR graph and materialize the tensors. This means that the PyTorch XLA tensors will be evaluated, and the results will be available to the host. The user-exposed barrier in Pytorch XLA is [xm.mark_step()](https://github.com/pytorch/xla/blob/bdceee54eca1269ee954f6cdd1868c584d0e88a4/torch_xla/core/xla_model.py#L808), which breaks the IR graph and results in code execution on the XLA devices. One of the key properties of `xm.mark_step` is that unlike synchronous operations it does not block the further tracing while the device is executing the graph. However, it does block access to the values of the tensors that are being materialized. @@ -235,9 +235,9 @@ Now, let's examine the XL version of the model and do the same thing. We will ad This time, in addition to the large gap in the middle, which is caused by the `pipe_watermark` tracing, there are many small gaps between the inference steps within [this loop](https://github.com/pytorch-tpu/diffusers/blob/0243d2ef9c2c7bc06956bb1bcc92c23038f6519d/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L814-L830). -First look closer into the large gap that is caused by `pipe_watermark`. The gap is preceded with `TransferFromServer` which indicates that something is happening on the host machine that is waiting for computation to finish before proceeding. Looking into watermark [code](https://github.com/pytorch-tpu/diffusers/blob/0243d2ef9c2c7bc06956bb1bcc92c23038f6519d/src/diffusers/pipelines/stable_diffusion_xl/watermark.py#L29), we can see that tensors are transferred to cpu and converted to numpy arrays in order to be processed with `cv2` and `pywt` libraries later. Since this part is not straightforward to optimize, we will leave this as is. 
+First look closer into the large gap that is caused by `pipe_watermark`. The gap is preceded with `TransferFromDevice` which indicates that something is happening on the host machine that is waiting for computation to finish before proceeding. Looking into watermark [code](https://github.com/pytorch-tpu/diffusers/blob/0243d2ef9c2c7bc06956bb1bcc92c23038f6519d/src/diffusers/pipelines/stable_diffusion_xl/watermark.py#L29), we can see that tensors are transferred to cpu and converted to numpy arrays in order to be processed with `cv2` and `pywt` libraries later. Since this part is not straightforward to optimize, we will leave this as is. -Now if we zoom in on the loop, we can see that the graph within the loop is broken into smaller parts because the `TransferFromServer` operation happens. +Now if we zoom in on the loop, we can see that the graph within the loop is broken into smaller parts because the `TransferFromDevice` operation happens. ![Alt text](assets/image-3.png) diff --git a/test/cpp/cpp_test_util.cpp b/test/cpp/cpp_test_util.cpp index 28f3d23d5c1..e9c3ccebdc9 100644 --- a/test/cpp/cpp_test_util.cpp +++ b/test/cpp/cpp_test_util.cpp @@ -302,7 +302,7 @@ std::vector Fetch( absl::Span device_data) { std::vector literals = - torch_xla::runtime::GetComputationClient()->TransferFromServer( + torch_xla::runtime::GetComputationClient()->TransferFromDevice( device_data); std::vector tensors; for (auto& literal : literals) { diff --git a/test/cpp/test_replication.cpp b/test/cpp/test_replication.cpp index 6d7a54add0c..e4b8b168342 100644 --- a/test/cpp/test_replication.cpp +++ b/test/cpp/test_replication.cpp @@ -76,7 +76,7 @@ void TestSingleReplication( for (size_t i = 0; i < results.size(); ++i) { auto literals = - torch_xla::runtime::GetComputationClient()->TransferFromServer( + torch_xla::runtime::GetComputationClient()->TransferFromDevice( results[i]); ASSERT_EQ(literals.size(), 1); diff --git a/test/metrics_compare_utils_test.py b/test/metrics_compare_utils_test.py index 3d88578ccb3..69a942646a8 100644 --- a/test/metrics_compare_utils_test.py +++ b/test/metrics_compare_utils_test.py @@ -10,7 +10,7 @@ Accumulator: 10GB Rate: 16.8665 / second Percentiles: 1%=393.00KB; 5%=393.00KB; 10%=786.00KB; 20%=1.54MB; 50%=1.54MB; 80%=1.54MB; 90%=1.54MB; 95%=1.54MB; 99%=1.54MB -Metric: TransferToServerTime +Metric: TransferToDeviceTime TotalSamples: 2616 Accumulator: 01m29s615ms ValueRate: 783ms426.227us / second @@ -27,7 +27,7 @@ TotalSamples: 73216 Accumulator: 64.75TB Percentiles: 1%=393.00KB; 5%=393.00KB; 10%=786.00KB; 20%=1.54MB; 50%=1.54MB; 80%=1.54MB; 90%=1.54MB; 95%=1.54MB; 99%=1.54MB -Metric: TransferToServerTime +Metric: TransferToDeviceTime TotalSamples: 247016 Accumulator: 04d17h11m07s495ms546.299us Percentiles: 1%=05m003ms; 5%=05m004ms; 10%=05m010ms; 20%=05m015ms; 50%=05m026ms; 80%=05m035ms; 90%=05m082ms; 95%=05m108ms; 99%=05m129ms @@ -43,7 +43,7 @@ TotalSamples: 73216 Accumulator: 64.75GB Percentiles: 1%=393.00KB; 5%=393.00KB; 10%=786.00KB; 20%=1.54MB; 50%=1.54MB; 80%=1.54MB; 90%=1.54MB; 95%=1.54MB; 99%=1.54MB -Metric: TransferToServerTime +Metric: TransferToDeviceTime TotalSamples: 247016 Accumulator: 1s Percentiles: 1%=05m003ms; 5%=05m004ms; 10%=05m010ms; 20%=05m015ms; 50%=05m026ms; 80%=05m035ms; 90%=05m082ms; 95%=05m108ms; 99%=05m129ms @@ -66,7 +66,7 @@ TotalSamples: 70000 Accumulator: 74.75GB Percentiles: 1%=393.00KB; 5%=393.00KB; 10%=786.00KB; 20%=1.54MB; 50%=1.54MB; 80%=1.54MB; 90%=1.54MB; 95%=1.54MB; 99%=1.54MB -Metric: TransferToServerTime +Metric: TransferToDeviceTime 
TotalSamples: 247016 Accumulator: 1s Percentiles: 1%=05m003ms; 5%=05m004ms; 10%=05m010ms; 20%=05m015ms; 50%=05m026ms; 80%=05m035ms; 90%=05m082ms; 95%=05m108ms; 99%=05m129ms @@ -89,7 +89,7 @@ TotalSamples: 73216 Accumulator: 64.75GB Percentiles: 1%=393.00KB; 5%=393.00KB; 10%=786.00KB; 20%=1.54MB; 50%=1.54MB; 80%=1.54MB; 90%=1.54MB; 95%=1.54MB; 99%=1.54MB -Metric: TransferToServerTime +Metric: TransferToDeviceTime TotalSamples: 247016 Accumulator: 1s Percentiles: 1%=05m003ms; 5%=05m004ms; 10%=05m010ms; 20%=05m015ms; 50%=05m026ms; 80%=05m035ms; 90%=05m082ms; 95%=05m108ms; 99%=05m129ms @@ -157,19 +157,19 @@ def test_get_data_points_from_metrics_reports(self): 'InboundData__Percentile_90_mb': [1.54, 1.54, 1.54], 'InboundData__Percentile_95_mb': [1.54, 1.54, 1.54], 'InboundData__Percentile_99_mb': [1.54, 1.54, 1.54], - 'TransferToServerTime__TotalSamples': [2616.0, 247016.0, 247016.0], - 'TransferToServerTime__Accumulator_sec': [ + 'TransferToDeviceTime__TotalSamples': [2616.0, 247016.0, 247016.0], + 'TransferToDeviceTime__Accumulator_sec': [ 89.615, 407467.495546299, 1.0 ], - 'TransferToServerTime__Percentile_1_sec': [300.003, 300.003, 300.003], - 'TransferToServerTime__Percentile_5_sec': [300.004, 300.004, 300.004], - 'TransferToServerTime__Percentile_10_sec': [300.01, 300.01, 300.01], - 'TransferToServerTime__Percentile_20_sec': [300.015, 300.015, 300.015], - 'TransferToServerTime__Percentile_50_sec': [300.026, 300.026, 300.026], - 'TransferToServerTime__Percentile_80_sec': [300.035, 300.035, 300.035], - 'TransferToServerTime__Percentile_90_sec': [300.082, 300.082, 300.082], - 'TransferToServerTime__Percentile_95_sec': [300.108, 300.108, 300.108], - 'TransferToServerTime__Percentile_99_sec': [300.129, 300.129, 300.129], + 'TransferToDeviceTime__Percentile_1_sec': [300.003, 300.003, 300.003], + 'TransferToDeviceTime__Percentile_5_sec': [300.004, 300.004, 300.004], + 'TransferToDeviceTime__Percentile_10_sec': [300.01, 300.01, 300.01], + 'TransferToDeviceTime__Percentile_20_sec': [300.015, 300.015, 300.015], + 'TransferToDeviceTime__Percentile_50_sec': [300.026, 300.026, 300.026], + 'TransferToDeviceTime__Percentile_80_sec': [300.035, 300.035, 300.035], + 'TransferToDeviceTime__Percentile_90_sec': [300.082, 300.082, 300.082], + 'TransferToDeviceTime__Percentile_95_sec': [300.108, 300.108, 300.108], + 'TransferToDeviceTime__Percentile_99_sec': [300.129, 300.129, 300.129], 'UniqueMetric__TotalSamples': [None, None, 9000.0], 'UniqueMetric__Accumulator': [None, None, 9000.0], 'UniqueMetric__Percentile_1': [None, None, 8902.0], diff --git a/test/pjrt/test_metrics.py b/test/pjrt/test_metrics.py index 360de76c1d8..5cee1b7ea5d 100644 --- a/test/pjrt/test_metrics.py +++ b/test/pjrt/test_metrics.py @@ -13,8 +13,8 @@ "ExecuteTime", "InboundData", "OutboundData", - "TransferFromServerTime", - "TransferToServerTime", + "TransferFromDeviceTime", + "TransferToDeviceTime", ] diff --git a/test/pjrt/test_profiler.py b/test/pjrt/test_profiler.py index f458b7fafa7..383cf21e489 100644 --- a/test/pjrt/test_profiler.py +++ b/test/pjrt/test_profiler.py @@ -56,7 +56,7 @@ def test_profiler_output(self): content = file.read() ascii_content = codecs.decode(content, 'ascii', errors='ignore') - expected_methods = ('TransferToServer', 'Compile', 'ExecuteComputation') + expected_methods = ('TransferToDevice', 'Compile', 'ExecuteComputation') for method in (f'PjRtComputationClient::{m}' for m in expected_methods): self.assertIn(method, ascii_content) diff --git a/test/spmd/test_xla_virtual_device.py 
b/test/spmd/test_xla_virtual_device.py index ac304e7285d..0f435e881c9 100644 --- a/test/spmd/test_xla_virtual_device.py +++ b/test/spmd/test_xla_virtual_device.py @@ -117,7 +117,7 @@ def test_virtual_device_no_upload(self): t1_debug_info = torch_xla._XLAC._get_xla_tensor_debug_info(t1) # t1's upload to device should be deferred self.assertIn("Tensor on host: with size [5, 5]", t1_debug_info) - self.assertNotIn("TransferToServerTime", met.metric_names()) + self.assertNotIn("TransferToDeviceTime", met.metric_names()) # t1 should be on SPMD device under spmd context self.assertIn("Device: SPMD:0", t1_debug_info) self.assertIn("IR: None", t1_debug_info) @@ -136,7 +136,7 @@ def test_virtual_device_upload_after_mark_sharding(self): self.assertIn("Tensor on host: None", t1_debug_info_new) self.assertIn("xla::device_data", t1_debug_info_new) self.assertIn("XLAShardedData", t1_debug_info_new) - self.assertIn("TransferToServerTime", met.metric_names()) + self.assertIn("TransferToDeviceTime", met.metric_names()) def test_virtual_device_upload_after_tracing(self): met.clear_all() @@ -149,7 +149,7 @@ def test_virtual_device_upload_after_tracing(self): # tensor should be uploaded to device after being used as input to other op. self.assertIn("Tensor on host: None", t1_debug_info_new) self.assertIn("xla::device_data", t1_debug_info_new) - self.assertIn("TransferToServerTime", met.metric_names()) + self.assertIn("TransferToDeviceTime", met.metric_names()) def test_virtual_device_upload_for_sharded_dataloader(self): met.clear_counters() @@ -165,7 +165,7 @@ def test_virtual_device_upload_for_sharded_dataloader(self): self.assertIn("Tensor on host: None", t1_debug_info) self.assertIn("xla::device_data", t1_debug_info) self.assertIn("XLAShardedData", t1_debug_info) - self.assertIn("TransferToServerTime", met.metric_names()) + self.assertIn("TransferToDeviceTime", met.metric_names()) if __name__ == '__main__': diff --git a/test/test_metrics.py b/test/test_metrics.py index 3037692ea65..f2d4e33eff3 100644 --- a/test/test_metrics.py +++ b/test/test_metrics.py @@ -48,8 +48,8 @@ def test_short_metrics_report_default_list(self): self.assertNotIn("TensorToData", short_report) self.assertIn("CompileTime", short_report) self.assertIn("ExecuteTime", short_report) - self.assertIn("TransferToServerTime", short_report) - self.assertIn("TransferFromServerTime", short_report) + self.assertIn("TransferToDeviceTime", short_report) + self.assertIn("TransferFromDeviceTime", short_report) self.assertIn("MarkStep", short_report) # repeat the same computation and expect to see the CachedCompile counter t3 = t1 * 2 @@ -93,7 +93,7 @@ def test_short_metrics_fallback_counter(self): metric_names=['InboundData'])) def test_metrics_report(self): - # TODO(jwtan): Add test to cover TrimIrGraph, SyncTensorsToData, TransferToServerAsync, IrValueTensorToXlaData + # TODO(jwtan): Add test to cover TrimIrGraph, SyncTensorsToData, TransferToDeviceAsync, IrValueTensorToXlaData xla_device = xm.xla_device() t1 = torch.tensor(2077, device=xla_device) t2 = t1 * 2 diff --git a/test/test_operations.py b/test/test_operations.py index 99d4c1265c1..8e31a29f3c4 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -1647,13 +1647,13 @@ def test_cached_addcdiv(self): t3 = torch.randn(1, 3).to(xla_device) t1.addcdiv_(t2, t3, value=0.1) xm.mark_step() - self.assertEqual(met.metric_data("TransferToServerTime")[0], 4) + self.assertEqual(met.metric_data("TransferToDeviceTime")[0], 4) - # The following two scalars shouldn't trigger 
TransferToServerTime. + # The following two scalars shouldn't trigger TransferToDeviceTime. t1.addcdiv_(t2, t3, value=0.1) t1.addcdiv_(t2, t3, value=0.1) xm.mark_step() - self.assertEqual(met.metric_data("TransferToServerTime")[0], 4) + self.assertEqual(met.metric_data("TransferToDeviceTime")[0], 4) @skipOnEagerDebug def test_print_executation(self): diff --git a/test/test_profiler.py b/test/test_profiler.py index 4d9969974c3..7c65bcfd460 100644 --- a/test/test_profiler.py +++ b/test/test_profiler.py @@ -30,8 +30,8 @@ def _check_metrics_warnings_exist(self, fname): with open(fname, 'r') as f: debug_warnings = f.read() logging.info(f'PT_XLA_DEBUG_FILE Contents:\n{debug_warnings}') - self.assertTrue('TransferFromServerTime too frequent' in debug_warnings, - f'Expected "TransferFromServerTime" warning in: {fname}') + self.assertTrue('TransferFromDeviceTime too frequent' in debug_warnings, + f'Expected "TransferFromDeviceTime" warning in: {fname}') self.assertTrue('CompileTime too frequent' in debug_warnings, f'Expected "CompileTime" wraning in: {fname}') diff --git a/torch_xla/csrc/device.h b/torch_xla/csrc/device.h index 1dc939bb17a..440a945ae06 100644 --- a/torch_xla/csrc/device.h +++ b/torch_xla/csrc/device.h @@ -12,7 +12,7 @@ namespace torch_xla { -// TODO(yeounoh) `SPMD` is a virtual device that defers data `TransferToServer` +// TODO(yeounoh) `SPMD` is a virtual device that defers data `TransferToDevice` // until after the paritioning pass. This avoids transfering the full input // tensor to the device. enum class XlaDeviceType { CPU, CUDA, ROCM, GPU, TPU, XPU, NEURON, SPMD }; diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 1884310c5fd..d0808fcf357 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -792,7 +792,7 @@ class PyLoweringContext { // Fetch this parameter data std::vector literals = - runtime::GetComputationClient()->TransferFromServer( + runtime::GetComputationClient()->TransferFromDevice( UnwrapXlaData(device_data)); // Create a mapping from paramater id to the tensor data diff --git a/torch_xla/csrc/runtime/computation_client.cc b/torch_xla/csrc/runtime/computation_client.cc index b2feb2e25dc..f29bfb90a94 100644 --- a/torch_xla/csrc/runtime/computation_client.cc +++ b/torch_xla/csrc/runtime/computation_client.cc @@ -53,21 +53,21 @@ int64_t ComputationClient::GetDeviceOrdinal(const std::string& device) { return std::stoi(device.substr(pos + 1)); } -metrics::Metric* ComputationClient::TransferToServerMetric() { +metrics::Metric* ComputationClient::TransferToDeviceMetric() { static metrics::Metric* metric = - new metrics::Metric("TransferToServerTime", metrics::MetricFnTime); + new metrics::Metric("TransferToDeviceTime", metrics::MetricFnTime); return metric; } -metrics::Metric* ComputationClient::TransferToServerTransformMetric() { +metrics::Metric* ComputationClient::TransferToDeviceTransformMetric() { static metrics::Metric* metric = new metrics::Metric( - "TransferToServerTransformTime", metrics::MetricFnTime); + "TransferToDeviceTransformTime", metrics::MetricFnTime); return metric; } -metrics::Metric* ComputationClient::TransferFromServerMetric() { +metrics::Metric* ComputationClient::TransferFromDeviceMetric() { static metrics::Metric* metric = - new metrics::Metric("TransferFromServerTime", metrics::MetricFnTime); + new metrics::Metric("TransferFromDeviceTime", metrics::MetricFnTime); return metric; } diff --git a/torch_xla/csrc/runtime/computation_client.h 
b/torch_xla/csrc/runtime/computation_client.h index 19d72aedc61..07bf335f117 100644 --- a/torch_xla/csrc/runtime/computation_client.h +++ b/torch_xla/csrc/runtime/computation_client.h @@ -251,12 +251,12 @@ class ComputationClient { virtual std::optional GetDataSharding(DataPtr handle) = 0; // Transfers local tensor values to the TPU devices and fetches the handles. - virtual std::vector TransferToServer( + virtual std::vector TransferToDevice( absl::Span> tensors) = 0; // Transfers local sharded tensor values to the TPU devices and returns a // `PjRtShardedData`. - virtual DataPtr TransferShardsToServer( + virtual DataPtr TransferShardsToDevice( absl::Span> tensor_shards, std::string device, xla::Shape shape, xla::OpSharding sharding) = 0; @@ -265,7 +265,7 @@ class ComputationClient { // Reads the tensor literal values stored at TPU server sites, behind the // supplied handles. - virtual std::vector TransferFromServer( + virtual std::vector TransferFromDevice( absl::Span handles) = 0; // Compiles a set of computations. @@ -353,9 +353,9 @@ class ComputationClient { protected: // Metrics common to all client interfaces. - static metrics::Metric* TransferToServerMetric(); - static metrics::Metric* TransferToServerTransformMetric(); - static metrics::Metric* TransferFromServerMetric(); + static metrics::Metric* TransferToDeviceMetric(); + static metrics::Metric* TransferToDeviceTransformMetric(); + static metrics::Metric* TransferFromDeviceMetric(); static metrics::Metric* CompileMetric(); static metrics::Metric* ExecuteMetric(); static metrics::Metric* ExecuteReplicatedMetric(); diff --git a/torch_xla/csrc/runtime/metrics_analysis.cc b/torch_xla/csrc/runtime/metrics_analysis.cc index 5d99165459d..eccc18de765 100644 --- a/torch_xla/csrc/runtime/metrics_analysis.cc +++ b/torch_xla/csrc/runtime/metrics_analysis.cc @@ -186,7 +186,7 @@ class UnloweredOp : public Analyzer { std::vector* GetAnalyzers() { static std::vector* analyzers = new std::vector{ new MetricFrequency("CompileTime", 0.5f, 10), - new MetricFrequency("TransferFromServerTime", 0.5f), + new MetricFrequency("TransferFromDeviceTime", 0.5f), new MetricTime("CompileTime", 300e9), new MetricTime("ExecuteTime", 30e9), new UnloweredOp(), diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index c16b35562cd..81205e0b66a 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -260,10 +260,10 @@ std::optional PjRtComputationClient::GetDataSharding( return std::optional(); } -std::vector PjRtComputationClient::TransferToServer( +std::vector PjRtComputationClient::TransferToDevice( absl::Span> tensors) { - metrics::TimedSection timed(TransferToServerMetric()); - tsl::profiler::TraceMe activity("PjRtComputationClient::TransferToServer", + metrics::TimedSection timed(TransferToDeviceMetric()); + tsl::profiler::TraceMe activity("PjRtComputationClient::TransferToDevice", tsl::profiler::TraceMeLevel::kInfo); std::vector datas; datas.reserve(tensors.size()); @@ -300,18 +300,18 @@ std::vector PjRtComputationClient::TransferToServer( return datas; } -ComputationClient::DataPtr PjRtComputationClient::TransferShardsToServer( +ComputationClient::DataPtr PjRtComputationClient::TransferShardsToDevice( absl::Span> tensor_shards, std::string device, xla::Shape shape, xla::OpSharding sharding) { tsl::profiler::TraceMe activity( - "PjRtComputationClient::TransferShardsToServer", + "PjRtComputationClient::TransferShardsToDevice", 
tsl::profiler::TraceMeLevel::kInfo); // TODO(jonbolin): Consider using CopyToDevice when sharding is REPLICATED. // We are opting out of CopyToDevice for now due to the synchronization // issues observed in ShardingUtil::InputHandler, but because CopyToDevice // directly copies buffers between devices using ICI, it can be much faster // than transferring from the host to each device. - auto data_shards = TransferToServer(tensor_shards); + auto data_shards = TransferToDevice(tensor_shards); std::vector> pjrt_data_shards; for (auto& shard : data_shards) { auto pjrt_shard = dynamic_cast(shard.get()); @@ -415,10 +415,10 @@ ComputationClient::DataPtr PjRtComputationClient::ReplicateShardedData( return handle; } -std::vector PjRtComputationClient::TransferFromServer( +std::vector PjRtComputationClient::TransferFromDevice( absl::Span handles) { - metrics::TimedSection timed(TransferFromServerMetric()); - tsl::profiler::TraceMe activity("PjRtComputationClient::TransferFromServer", + metrics::TimedSection timed(TransferFromDeviceMetric()); + tsl::profiler::TraceMe activity("PjRtComputationClient::TransferFromDevice", tsl::profiler::TraceMeLevel::kInfo); std::vector literals; literals.reserve(handles.size()); diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.h b/torch_xla/csrc/runtime/pjrt_computation_client.h index 91b48f509f8..7702d91122a 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.h +++ b/torch_xla/csrc/runtime/pjrt_computation_client.h @@ -35,16 +35,16 @@ class PjRtComputationClient : public ComputationClient { std::optional GetDataSharding(DataPtr handle) override; - std::vector TransferToServer( + std::vector TransferToDevice( absl::Span> tensors) override; // Use XLA replication to re-assemble the sharded data. DataPtr ReplicateShardedData(const DataPtr& handle); - std::vector TransferFromServer( + std::vector TransferFromDevice( absl::Span handles) override; - DataPtr TransferShardsToServer(absl::Span> tensor_shards, + DataPtr TransferShardsToDevice(absl::Span> tensor_shards, std::string device, xla::Shape shape, xla::OpSharding sharding) override; diff --git a/torch_xla/csrc/runtime/pjrt_computation_client_test.cc b/torch_xla/csrc/runtime/pjrt_computation_client_test.cc index d6240f08e98..0c677393e2d 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client_test.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client_test.cc @@ -66,12 +66,12 @@ TEST(PjRtComputationClientTest, Init) { // Execute the graph. std::vector results = client->ExecuteComputation( - *computations[0], client->TransferToServer(absl::MakeConstSpan(args)), + *computations[0], client->TransferToDevice(absl::MakeConstSpan(args)), device, options); // Copy the output from device back to host and assert correctness.. 
ASSERT_EQ(results.size(), 1); - auto result_literals = client->TransferFromServer(results); + auto result_literals = client->TransferFromDevice(results); ASSERT_THAT(result_literals, ::testing::SizeIs(1)); EXPECT_TRUE(xla::LiteralTestUtil::Equal( xla::LiteralUtil::CreateR2({{6.0f, 8.0f}, {10.0f, 12.0f}}), diff --git a/torch_xla/csrc/tensor_util.cpp b/torch_xla/csrc/tensor_util.cpp index 94410b2184c..280d205ac95 100644 --- a/torch_xla/csrc/tensor_util.cpp +++ b/torch_xla/csrc/tensor_util.cpp @@ -591,7 +591,7 @@ torch::lazy::BackendDataPtr TensorToXlaData( source_tensors.push_back(std::make_shared(tensor, shape, device.toString())); auto handles = - runtime::GetComputationClient()->TransferToServer(source_tensors); + runtime::GetComputationClient()->TransferToDevice(source_tensors); XLA_CHECK_EQ(handles.size(), 1); return handles.front(); } @@ -817,7 +817,7 @@ std::vector CreateTensorsData( source_tensors.push_back(std::make_shared(tensors[i], std::move(shape), devices[i])); } return WrapXlaData( - runtime::GetComputationClient()->TransferToServer(source_tensors)); + runtime::GetComputationClient()->TransferToDevice(source_tensors)); } std::vector CreateTensorsData( @@ -851,7 +851,7 @@ std::vector CreateTensorsData( } else { source_tensors.push_back(std::make_shared(tensors[i], std::move(shape), devices[i])); new_handles = - runtime::GetComputationClient()->TransferToServer(source_tensors); + runtime::GetComputationClient()->TransferToDevice(source_tensors); } handles.insert(handles.end(), new_handles.begin(), new_handles.end()); } @@ -879,7 +879,7 @@ std::vector XlaDataToTensors( absl::Span xla_data, at::ScalarType dest_element_type) { std::vector literals = - runtime::GetComputationClient()->TransferFromServer( + runtime::GetComputationClient()->TransferFromDevice( UnwrapXlaData(xla_data)); std::vector tensors; tensors.reserve(literals.size()); diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp index 45ed57b1b04..63163e0a381 100644 --- a/torch_xla/csrc/xla_graph_executor.cpp +++ b/torch_xla/csrc/xla_graph_executor.cpp @@ -421,7 +421,7 @@ std::vector XLAGraphExecutor::GetTensors( async != nullptr ? async->tensors_data : absl::Span()); - // Execution is async in PJRT, so TransferFromServer may block until execution + // Execution is async in PJRT, so TransferFromDevice may block until execution // completes. Release the GIL so other threads can proceed and unblock any // collective computations. // HACK: This method may be called outside of python (mainly in C++ tests) or @@ -436,7 +436,7 @@ std::vector XLAGraphExecutor::GetTensors( save = PyEval_SaveThread(); } std::vector literals = - runtime::GetComputationClient()->TransferFromServer( + runtime::GetComputationClient()->TransferFromDevice( UnwrapXlaData(tensors_data)); if (save) { PyEval_RestoreThread(save); @@ -1333,7 +1333,7 @@ XLAGraphExecutor::SyncTensorsGraphInternal( // second `SyncTensorsGraphInternal` will find there is nothing to sync and // return. It is possible that by the time second `SyncTensorsGraphInternal` // returned, first computation is still running. If user trying to call - // `TransferFromServer` on placeholder XLAData, runtime will segfault. Force + // `TransferFromDevice` on placeholder XLAData, runtime will segfault. Force // the `SyncTensorsGraphInternal` to block until previous computation either // here or in `ScheduleSyncTensorsGraph` will solve this issue. 
TensorCollectionBarrier(&coll); diff --git a/torch_xla/csrc/xla_sharding_util.cpp b/torch_xla/csrc/xla_sharding_util.cpp index 059fc3ba2b5..128b8840d37 100644 --- a/torch_xla/csrc/xla_sharding_util.cpp +++ b/torch_xla/csrc/xla_sharding_util.cpp @@ -730,7 +730,7 @@ runtime::ComputationClient::DataPtr ShardingUtil::CreateShardedData( CreateComputationShapeFromTensor(local_shards[j], &shard_device); source_tensors.push_back(std::make_shared(local_shards[j], shard_shape, devices[j])); } - return runtime::GetComputationClient()->TransferShardsToServer( + return runtime::GetComputationClient()->TransferShardsToDevice( source_tensors, GetVirtualDevice().toString(), global_shape, sharding); } diff --git a/torch_xla/debug/metrics.py b/torch_xla/debug/metrics.py index cc2f94021f9..108b0d61bf5 100644 --- a/torch_xla/debug/metrics.py +++ b/torch_xla/debug/metrics.py @@ -76,6 +76,6 @@ def short_metrics_report(counter_names: list = None, metric_names: list = None): if not metric_names: metric_names = [ 'CompileTime', 'ExecuteTime', 'ExecuteReplicatedTime', - 'TransferToServerTime', 'TransferFromServerTime' + 'TransferToDeviceTime', 'TransferFromDeviceTime' ] return torch_xla._XLAC._short_xla_metrics_report(counter_names, metric_names)
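
For readers who want to confirm that their environment reflects this rename, the metric names can be inspected directly from Python. The sketch below is illustrative only and is not part of the patch; it assumes a torch_xla build that already includes the server -> device rename and uses only APIs that appear in the diff above (`torch_xla.core.xla_model`, `torch_xla.debug.metrics`, `xm.mark_step()`).

```
# Minimal sketch (not part of the patch): check that the renamed transfer
# metrics are reported. Assumes a torch_xla build containing this change.
import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

device = xm.xla_device()

# Uploading a host tensor records TransferToDeviceTime
# (formerly TransferToServerTime).
t = torch.randn(2, 2).to(device)
xm.mark_step()

# Fetching the value back to the host records TransferFromDeviceTime
# (formerly TransferFromServerTime).
print(t.cpu())

assert "TransferToDeviceTime" in met.metric_names()
assert "TransferFromDeviceTime" in met.metric_names()
print(met.short_metrics_report())
```

If the old `TransferToServerTime`/`TransferFromServerTime` names still appear instead, the installed torch_xla predates this patch.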