add torch_xla.experimental.compile for eager mode (#7246)
JackCaoG authored Jun 12, 2024
1 parent f4665a7 commit 90168e8
Showing 7 changed files with 132 additions and 8 deletions.
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -77,6 +77,7 @@ experimental
----------------------------------
.. automodule:: torch_xla.experimental
.. autofunction:: eager_mode
.. autofunction:: compile

debug
----------------------------------
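For orientation, here is a minimal usage sketch of the two functions documented above, condensed from the example and test files added later in this commit (the `cos_sin` helper and the tensor shape are illustrative, not part of the commit):

import torch
import torch_xla

# Enable eager mode: torch ops now execute immediately instead of being
# accumulated into a lazy graph.
torch_xla.experimental.eager_mode(True)

device = torch_xla.device()
t = torch.randn(5, 5, device=device)  # runs eagerly


def cos_sin(tensor):
  return torch.cos(torch.sin(tensor))


# Only the wrapped function is traced and then executed as one XLA graph.
compiled_cos_sin = torch_xla.experimental.compile(cos_sin)
out = compiled_cos_sin(t)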
21 changes: 21 additions & 0 deletions examples/eager/train_decoder_only_eager_with_compile.py
@@ -0,0 +1,21 @@
import sys
import os
example_folder = os.path.dirname(os.path.dirname(os.path.abspath(sys.argv[0])))
sys.path.append(example_folder)
from train_decoder_only_base import TrainDecoderOnlyBase

import torch_xla


class TrainDecoderOnlyEagerWithCompile(TrainDecoderOnlyBase):

  def __init__(self):
    super().__init__()
    # The step fn will be compiled; everything else runs eagerly.
    self.step_fn = torch_xla.experimental.compile(self.step_fn)


if __name__ == '__main__':
  torch_xla.experimental.eager_mode(True)
  trainer = TrainDecoderOnlyEagerWithCompile()
  trainer.start_training()
16 changes: 10 additions & 6 deletions examples/train_decoder_only_base.py
@@ -41,17 +41,21 @@ def _train_update(self, step, loss, tracker, epoch):
   def run_optimizer(self):
     self.optimizer.step()
 
+  def step_fn(self, data, target):
+    self.optimizer.zero_grad()
+    logits = self.model(data)
+    loss = self.loss_fn(
+        logits.view(-1, self.config.vocab_size), target.view(-1))
+    loss.backward()
+    self.run_optimizer()
+    return loss
+
   def train_loop_fn(self, loader, epoch):
     tracker = xm.RateTracker()
     self.model.train()
     loader = itertools.islice(loader, self.num_steps)
     for step, (data, target) in enumerate(loader):
-      self.optimizer.zero_grad()
-      logits = self.model(data)
-      loss = self.loss_fn(
-          logits.view(-1, self.config.vocab_size), target.view(-1))
-      loss.backward()
-      self.run_optimizer()
+      loss = self.step_fn(data, target)
       tracker.add(self.batch_size)
       if step % 10 == 0:
         xm.add_step_closure(
56 changes: 56 additions & 0 deletions test/eager/test_eager_with_xla_compile.py
@@ -0,0 +1,56 @@
import unittest
import sys

import torch
import torch_xla
import torch_xla.debug.metrics as met
import torch_xla.core.xla_model as xm


class EagerWithXLACompileTest(unittest.TestCase):

  @classmethod
  def setUpClass(cls):
    torch_xla.experimental.eager_mode(True)

  def dummy_cos_sin(self, tensor):
    return torch.cos(torch.sin(tensor))

  def test_eager_with_compile_basic(self):
    met.clear_all()
    self.assertTrue(torch_xla.experimental.is_eager_mode())
    device = torch_xla.device()

    # this part happens eagerly
    t1 = torch.randn(5, 5, device=device)
    t1 *= 5

    t2 = self.dummy_cos_sin(t1)
    t2_compiled = torch_xla.experimental.compile(self.dummy_cos_sin)(t1)
    self.assertTrue(torch.allclose(t2, t2_compiled))
    xm.wait_device_ops()
    # We execute one compiled graph
    self.assertEqual(met.metric_data("ExecuteTime")[0], 1)
    # and many eager ops
    self.assertGreater(met.metric_data("EagerOpExecuteTime")[0], 5)


  def test_eager_execute_compiled_multiple_times(self):
    met.clear_all()
    self.assertTrue(torch_xla.experimental.is_eager_mode())
    device = torch_xla.device()
    # this part happens eagerly
    t1 = torch.randn(10, 5, device=device)
    t1.add_(0.5)
    compiled = torch_xla.experimental.compile(self.dummy_cos_sin)
    res = compiled(compiled(t1))
    self.assertTrue(
        torch.allclose(res * 0.3,
                       self.dummy_cos_sin(self.dummy_cos_sin(t1)) * 0.3))
    xm.wait_device_ops()
    self.assertEqual(met.metric_data("ExecuteTime")[0], 2)


if __name__ == '__main__':
  test = unittest.main()
  sys.exit(0 if test.result.wasSuccessful() else 1)
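For anyone reproducing these assertions interactively, a small sketch of how the counters used above can be inspected outside unittest; reading index 0 of `metric_data` as the recorded call count follows the tests above, while the snippet itself is only an assumed REPL-style way to poke at the metrics on a machine with an XLA device:

import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

torch_xla.experimental.eager_mode(True)
met.clear_all()

t = torch.randn(3, 3, device=torch_xla.device())  # eager op
out = torch_xla.experimental.compile(lambda x: x.sin().cos())(t)  # one compiled graph

xm.wait_device_ops()  # wait for all device work before reading metrics
print(met.metric_data("ExecuteTime")[0])         # compiled graph executions (1 here)
print(met.metric_data("EagerOpExecuteTime")[0])  # eager op executions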
1 change: 1 addition & 0 deletions test/tpu/run_tests.sh
@@ -37,3 +37,4 @@ python3 examples/fsdp/train_decoder_only_fsdp_v2.py
python3 examples/fsdp/train_resnet_fsdp_auto_wrap.py
python3 examples/train_resnet_amp.py
python3 examples/eager/train_decoder_only_eager.py
python3 examples/eager/train_decoder_only_eager_with_compile.py
6 changes: 4 additions & 2 deletions torch_xla/experimental/__init__.py
@@ -1,5 +1,7 @@
-from .eager import eager_mode
+from .eager import eager_mode, compile, is_eager_mode
 
 __all__ = [
     "eager_mode",
-]
+    "compile",
+    "is_eager_mode",
+]
39 changes: 39 additions & 0 deletions torch_xla/experimental/eager.py
@@ -1,9 +1,48 @@
import functools

import torch_xla


def eager_mode(enable: bool):
  """Configure torch_xla's default execution mode.

  Under eager mode only functions that were `torch_xla.experimental.compile`d
  will be traced and compiled. Other torch ops will be executed eagerly.
  """
  torch_xla._XLAC._set_use_eager_mode(enable)


def is_eager_mode() -> bool:
  """Return True if torch_xla is currently in eager mode."""
  return torch_xla._XLAC._get_use_eager_mode()


def compile(func):
  """Compile the func with the Lazy Tensor tracing machinery.

  Returns an optimized function that takes exactly the same inputs. `compile`
  will run the target func in tracing mode using Lazy Tensors.
  """

  @functools.wraps(func)  # Keep function's name, docstring, etc.
  def wrapper(*args, **kwargs):
    # compile should only be called while eager mode is enabled
    assert torch_xla._XLAC._get_use_eager_mode() == True
    torch_xla._XLAC._set_use_eager_mode(False)
    # clear the pending graph if any
    torch_xla.sync()
    try:
      # Target function execution
      result = func(*args, **kwargs)
    except Exception as e:
      # Handle exceptions (if needed)
      print(f"Error in target function: {e}")
      raise  # Re-raise the exception
    # Sync the graph generated by the target function.
    torch_xla.sync()
    torch_xla._XLAC._set_use_eager_mode(True)

    return result

  return wrapper
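A usage note on the wrapper above: because `compile` returns a plain `functools.wraps`-preserving wrapper, it can also be applied decorator-style. A minimal sketch, assuming eager mode is enabled before the wrapped function is called; the decorator form and the `scaled_norm` toy function are assumptions, since the commit itself only wraps functions explicitly as in the example and tests above:

import torch
import torch_xla

torch_xla.experimental.eager_mode(True)  # required: the wrapper asserts eager mode is on
device = torch_xla.device()


@torch_xla.experimental.compile  # assumed decorator-style use of the wrapper above
def scaled_norm(t):
  return (t * 3.0).norm()


x = torch.ones(4, 4, device=device)  # executed eagerly
y = scaled_norm(x)  # traced, then executed as one compiled XLA graph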
