From f84f7becfae4f61e632850c7307657220683408f Mon Sep 17 00:00:00 2001
From: JackCaoG <59073027+JackCaoG@users.noreply.github.com>
Date: Wed, 10 Jul 2024 12:46:31 -0700
Subject: [PATCH] Cherry-pick Eagerly execute inplace ops if in eager mode
 #7322 (#7666)

---
 test/eager/test_eager.py  | 67 +++++++++++++++++++++++++++++++++++++++
 test/run_tests.sh         |  3 +-
 torch_xla/csrc/tensor.cpp |  7 ++++
 3 files changed, 76 insertions(+), 1 deletion(-)
 create mode 100644 test/eager/test_eager.py

diff --git a/test/eager/test_eager.py b/test/eager/test_eager.py
new file mode 100644
index 00000000000..8abf4d36460
--- /dev/null
+++ b/test/eager/test_eager.py
@@ -0,0 +1,67 @@
+import unittest
+import sys
+
+import torch
+import torch_xla
+import torch_xla.debug.metrics as met
+import torch_xla.core.xla_model as xm
+
+
+class Eager(unittest.TestCase):
+
+  @classmethod
+  def setUpClass(cls):
+    torch_xla.experimental.eager_mode(True)
+
+  def test_eager_basic(self):
+    met.clear_all()
+    self.assertTrue(torch_xla.experimental.is_eager_mode())
+    device = torch_xla.device()
+
+    # For some reason randn will also trigger an execution of
+    # size [5, 5] full of 0.
+    t1 = torch.randn(5, 5, device=device)
+    xm.wait_device_ops()
+    self.assertEqual(met.metric_data("EagerOpCompileTime")[0], 2)
+    self.assertEqual(met.metric_data("EagerOpExecuteTime")[0], 2)
+
+    t1 *= 5
+    xm.wait_device_ops()
+    self.assertEqual(met.metric_data("EagerOpCompileTime")[0], 3)
+    self.assertEqual(met.metric_data("EagerOpExecuteTime")[0], 3)
+
+  def test_eager_recompile(self):
+    self.assertTrue(torch_xla.experimental.is_eager_mode())
+    device = torch_xla.device()
+
+    t1 = torch.randn(5, 5, device=device)
+    xm.wait_device_ops()
+    met.clear_all()
+
+    t2 = torch.logsumexp(t1, 0)
+    xm.wait_device_ops()
+    self.assertEqual(met.metric_data("EagerOpCompileTime")[0], 1)
+    self.assertEqual(met.metric_data("EagerOpExecuteTime")[0], 1)
+
+    t3 = torch.logsumexp(t1, 0)
+    xm.wait_device_ops()
+    # make sure no recompilation
+    self.assertEqual(met.metric_data("EagerOpCompileTime")[0], 1)
+    self.assertEqual(met.metric_data("EagerOpExecuteTime")[0], 2)
+
+  def test_eager_in_place(self):
+    self.assertTrue(torch_xla.experimental.is_eager_mode())
+    device = torch_xla.device()
+
+    t1 = torch.randn(5, 5, device=device)
+    xm.wait_device_ops()
+    met.clear_all()
+    xm.optimization_barrier_([t1])
+    xm.wait_device_ops()
+    self.assertEqual(met.metric_data("EagerOpCompileTime")[0], 1)
+    self.assertEqual(met.metric_data("EagerOpExecuteTime")[0], 1)
+
+
+if __name__ == '__main__':
+  test = unittest.main()
+  sys.exit(0 if test.result.wasSuccessful() else 1)
diff --git a/test/run_tests.sh b/test/run_tests.sh
index c1aabd738f9..4f0f53f423c 100755
--- a/test/run_tests.sh
+++ b/test/run_tests.sh
@@ -203,7 +203,8 @@ function run_xla_op_tests2 {
   run_downcast_bf16 "$CDIR/test_data_type.py"
   run_test "$CDIR/pjrt/test_dtypes.py"
   run_test "$CDIR/test_while_loop.py"
-  run_test "$CDIR/test_autocast.py" # TODO(yeounoh) this is expensive on GPU
+  run_test "$CDIR/test_autocast.py"
+  run_test "$CDIR/eager/test_eager.py"
   run_test "$CDIR/eager/test_eager_with_xla_compile.py"
   run_test "$CDIR/eager/test_eager_with_torch_compile.py"
 }
diff --git a/torch_xla/csrc/tensor.cpp b/torch_xla/csrc/tensor.cpp
index 8f516f9016a..13eb94ae30d 100644
--- a/torch_xla/csrc/tensor.cpp
+++ b/torch_xla/csrc/tensor.cpp
@@ -358,6 +358,13 @@ void XLATensor::SetInPlaceIrValue(torch::lazy::Value ir_value) {
         torch::lazy::MakeNode<Cast>(ir_value, xla_shape.get().element_type());
   }
   SetIrValue(std::move(ir_value), /*inplace=*/true);
+  XLAGraphExecutor* graph_executor = XLAGraphExecutor::Get();
+
+  // in place update should also be triggered eagerly if configured
+  if (graph_executor->UseEagerMode()) {
+    std::vector<XLATensorPtr> xtensors({c10::make_intrusive<XLATensor>(*this)});
+    graph_executor->ApplyEagerSync(xtensors);
+  }
 }
 
 void XLATensor::AssignIrValue(torch::lazy::Value ir_value) const {