diff --git a/docs/source/index.rst b/docs/source/index.rst
index cf0eb8b0125..7823096ae81 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -77,6 +77,7 @@ experimental
 ----------------------------------
 .. automodule:: torch_xla.experimental
 .. autofunction:: eager_mode
+.. autofunction:: compile
 
 debug
 ----------------------------------
diff --git a/examples/eager/train_decoder_only_eager_with_compile.py b/examples/eager/train_decoder_only_eager_with_compile.py
new file mode 100644
index 00000000000..f9a134dc8a0
--- /dev/null
+++ b/examples/eager/train_decoder_only_eager_with_compile.py
@@ -0,0 +1,21 @@
+import sys
+import os
+example_folder = os.path.dirname(os.path.dirname(os.path.abspath(sys.argv[0])))
+sys.path.append(example_folder)
+from train_decoder_only_base import TrainDecoderOnlyBase
+
+import torch_xla
+
+
+class TrainDecoderOnlyEagerWithCompile(TrainDecoderOnlyBase):
+
+  def __init__(self):
+    super().__init__()
+    # The step fn will be compiled; everything else runs eagerly.
+    self.step_fn = torch_xla.experimental.compile(self.step_fn)
+
+
+if __name__ == '__main__':
+  torch_xla.experimental.eager_mode(True)
+  trainer = TrainDecoderOnlyEagerWithCompile()
+  trainer.start_training()
diff --git a/examples/train_decoder_only_base.py b/examples/train_decoder_only_base.py
index cd99a4303a5..63a6ad028aa 100644
--- a/examples/train_decoder_only_base.py
+++ b/examples/train_decoder_only_base.py
@@ -41,17 +41,21 @@ def _train_update(self, step, loss, tracker, epoch):
   def run_optimizer(self):
     self.optimizer.step()
 
+  def step_fn(self, data, target):
+    self.optimizer.zero_grad()
+    logits = self.model(data)
+    loss = self.loss_fn(
+        logits.view(-1, self.config.vocab_size), target.view(-1))
+    loss.backward()
+    self.run_optimizer()
+    return loss
+
   def train_loop_fn(self, loader, epoch):
     tracker = xm.RateTracker()
     self.model.train()
     loader = itertools.islice(loader, self.num_steps)
     for step, (data, target) in enumerate(loader):
-      self.optimizer.zero_grad()
-      logits = self.model(data)
-      loss = self.loss_fn(
-          logits.view(-1, self.config.vocab_size), target.view(-1))
-      loss.backward()
-      self.run_optimizer()
+      loss = self.step_fn(data, target)
       tracker.add(self.batch_size)
       if step % 10 == 0:
         xm.add_step_closure(
diff --git a/test/eager/test_eager_with_xla_compile.py b/test/eager/test_eager_with_xla_compile.py
new file mode 100644
index 00000000000..0509316a0ef
--- /dev/null
+++ b/test/eager/test_eager_with_xla_compile.py
@@ -0,0 +1,55 @@
+import unittest
+import sys
+
+import torch
+import torch_xla
+import torch_xla.debug.metrics as met
+import torch_xla.core.xla_model as xm
+
+
+class EagerWithXLACompileTest(unittest.TestCase):
+
+  @classmethod
+  def setUpClass(cls):
+    torch_xla.experimental.eager_mode(True)
+
+  def dummy_cos_sin(self, tensor):
+    return torch.cos(torch.sin(tensor))
+
+  def test_eager_with_compile_basic(self):
+    met.clear_all()
+    self.assertTrue(torch_xla.experimental.is_eager_mode())
+    device = torch_xla.device()
+
+    # This part happens eagerly.
+    t1 = torch.randn(5, 5, device=device)
+    t1 *= 5
+
+    t2 = self.dummy_cos_sin(t1)
+    t2_compiled = torch_xla.experimental.compile(self.dummy_cos_sin)(t1)
+    self.assertTrue(torch.allclose(t2, t2_compiled))
+    xm.wait_device_ops()
+    # We execute one compiled graph
+    self.assertEqual(met.metric_data("ExecuteTime")[0], 1)
+    # and many eager ops.
+    self.assertGreater(met.metric_data("EagerOpExecuteTime")[0], 5)
+
+  def test_eager_execute_compiled_multiple_times(self):
+    met.clear_all()
+    self.assertTrue(torch_xla.experimental.is_eager_mode())
+    device = torch_xla.device()
+    # This part happens eagerly.
+    t1 = torch.randn(10, 5, device=device)
+    t1.add_(0.5)
+    compiled = torch_xla.experimental.compile(self.dummy_cos_sin)
+    res = compiled(compiled(t1))
+    self.assertTrue(
+        torch.allclose(res * 0.3,
+                       self.dummy_cos_sin(self.dummy_cos_sin(t1)) * 0.3))
+    xm.wait_device_ops()
+    self.assertEqual(met.metric_data("ExecuteTime")[0], 2)
+
+
+if __name__ == '__main__':
+  test = unittest.main()
+  sys.exit(0 if test.result.wasSuccessful() else 1)
diff --git a/test/tpu/run_tests.sh b/test/tpu/run_tests.sh
index c7f12f93e14..281b6751f67 100755
--- a/test/tpu/run_tests.sh
+++ b/test/tpu/run_tests.sh
@@ -37,3 +37,4 @@ python3 examples/fsdp/train_decoder_only_fsdp_v2.py
 python3 examples/fsdp/train_resnet_fsdp_auto_wrap.py
 python3 examples/train_resnet_amp.py
 python3 examples/eager/train_decoder_only_eager.py
+python3 examples/eager/train_decoder_only_eager_with_compile.py
diff --git a/torch_xla/experimental/__init__.py b/torch_xla/experimental/__init__.py
index 68bd36dc06d..1676e7d6349 100644
--- a/torch_xla/experimental/__init__.py
+++ b/torch_xla/experimental/__init__.py
@@ -1,5 +1,7 @@
-from .eager import eager_mode
+from .eager import eager_mode, compile, is_eager_mode
 
 __all__ = [
     "eager_mode",
-]
\ No newline at end of file
+    "compile",
+    "is_eager_mode",
+]
diff --git a/torch_xla/experimental/eager.py b/torch_xla/experimental/eager.py
index 37861fdd8bf..085df0419c3 100644
--- a/torch_xla/experimental/eager.py
+++ b/torch_xla/experimental/eager.py
@@ -1,9 +1,48 @@
+import functools
+
 import torch_xla
 
 
 def eager_mode(enable: bool):
   """Configure torch_xla's default executation mode.
 
+  Under eager mode, only functions wrapped by `torch_xla.experimental.compile`
+  will be traced and compiled. Other torch ops will be executed eagerly.
   """
   torch_xla._XLAC._set_use_eager_mode(enable)
+
+
+def is_eager_mode() -> bool:
+  """Return True if torch_xla is currently running in eager mode.
+  """
+  return torch_xla._XLAC._get_use_eager_mode()
+
+
+def compile(func):
+  """Compile the func with Lazy Tensor.
+
+  Return an optimized function that takes the exact same inputs. `compile`
+  will run the target func under tracing mode using Lazy Tensor.
+  """
+
+  @functools.wraps(func)  # Keep function's name, docstring, etc.
+  def wrapper(*args, **kwargs):
+    # compile should only be called while eager mode is enabled.
+    assert torch_xla._XLAC._get_use_eager_mode() == True
+    torch_xla._XLAC._set_use_eager_mode(False)
+    # Clear the pending graph, if any, before tracing the target function.
+    torch_xla.sync()
+    try:
+      # Run the target function under lazy (tracing) mode.
+      result = func(*args, **kwargs)
+    except Exception as e:
+      # Report errors raised while tracing the target function.
+      print(f"Error in target function: {e}")
+      raise  # Re-raise the exception
+    # Sync the graph generated by the target function.
+    torch_xla.sync()
+    torch_xla._XLAC._set_use_eager_mode(True)
+
+    return result
+
+  return wrapper
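
For reference, a minimal standalone sketch of how the APIs introduced by this change (`torch_xla.experimental.eager_mode` and `torch_xla.experimental.compile`) are meant to be used together. It assumes an XLA device such as a TPU is available and a torch_xla build that includes these experimental functions; the `cos_sin` helper is purely illustrative.

import torch
import torch_xla


def cos_sin(t):
  # A toy "hot path" that we want traced and compiled into one XLA graph.
  return torch.cos(torch.sin(t))


if __name__ == '__main__':
  # Execute torch ops eagerly by default; only compiled functions are traced.
  torch_xla.experimental.eager_mode(True)

  device = torch_xla.device()
  t = torch.randn(5, 5, device=device)  # runs eagerly

  # Wrap the hot path so it is traced and executed as a single compiled graph.
  compiled_cos_sin = torch_xla.experimental.compile(cos_sin)
  print(compiled_cos_sin(t))

The same pattern appears in the example above: everything outside the wrapped function (data creation, logging, the training loop scaffolding) stays eager, while the wrapped step function is compiled once and reused on every call.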