add torch_xla.experimental.compile for eager mode #7246
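For context, the new API is intended to be used like this. A minimal sketch based on the example and test files in this PR; the `nn.Linear` model, `step_fn`, and data here are illustrative, not part of the change:

import torch
import torch.nn as nn
import torch_xla

# Everything runs eagerly by default once eager mode is enabled.
torch_xla.experimental.eager_mode(True)

device = torch_xla.device()
model = nn.Linear(8, 8).to(device)
data = torch.randn(4, 8, device=device)


def step_fn(batch):
  # This function is traced and compiled as a single graph; code outside
  # of it keeps executing eagerly.
  return model(batch).sum()


step_fn = torch_xla.experimental.compile(step_fn)
loss = step_fn(data)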
@@ -0,0 +1,21 @@
import sys
import os
example_folder = os.path.dirname(os.path.dirname(os.path.abspath(sys.argv[0])))
sys.path.append(example_folder)
from train_decoder_only_base import TrainDecoderOnlyBase

import torch_xla


class TrainDecoderOnlyEagerWithCompile(TrainDecoderOnlyBase):

  def __init__(self):
    super().__init__()
    # The step fn will be compiled; the rest will run eagerly.
    self.step_fn = torch_xla.experimental.compile(self.step_fn)


if __name__ == '__main__':
  torch_xla.experimental.eager_mode(True)
  trainer = TrainDecoderOnlyEagerWithCompile()
  trainer.start_training()
@@ -0,0 +1,56 @@
import unittest
import sys

import torch
import torch_xla
import torch_xla.debug.metrics as met
import torch_xla.core.xla_model as xm


class EagerWithXLACompileTest(unittest.TestCase):

  @classmethod
  def setUpClass(cls):
    torch_xla.experimental.eager_mode(True)

  def dummy_cos_sin(self, tensor):
    return torch.cos(torch.sin(tensor))

  def test_eager_with_compile_basic(self):
    met.clear_all()
    self.assertTrue(torch_xla.experimental.is_eager_mode())
    device = torch_xla.device()

    # this part happens eagerly
    t1 = torch.randn(5, 5, device=device)
    t1 *= 5

    t2 = self.dummy_cos_sin(t1)
    t2_compiled = torch_xla.experimental.compile(self.dummy_cos_sin)(t1)
    self.assertTrue(torch.allclose(t2, t2_compiled))
    xm.wait_device_ops()
    # We execute one compiled graph
    self.assertEqual(met.metric_data("ExecuteTime")[0], 1)
    # and many eager ops
    self.assertGreater(met.metric_data("EagerOpExecuteTime")[0], 5)

  def test_eager_execute_compiled_multiple_times(self):
    met.clear_all()
    self.assertTrue(torch_xla.experimental.is_eager_mode())
    device = torch_xla.device()
    # this part happens eagerly
    t1 = torch.randn(10, 5, device=device)
    t1.add_(0.5)
    compiled = torch_xla.experimental.compile(self.dummy_cos_sin)
    res = compiled(compiled(t1))
    self.assertTrue(
        torch.allclose(res * 0.3,
                       self.dummy_cos_sin(self.dummy_cos_sin(t1)) * 0.3))
    xm.wait_device_ops()
    self.assertEqual(met.metric_data("ExecuteTime")[0], 2)


if __name__ == '__main__':
  test = unittest.main()
  sys.exit(0 if test.result.wasSuccessful() else 1)
@@ -1,5 +1,7 @@
-from .eager import eager_mode
+from .eager import eager_mode, compile, is_eager_mode
 
 __all__ = [
     "eager_mode",
+    "compile",
+    "is_eager_mode",
 ]
@@ -1,9 +1,48 @@
import functools

import torch_xla


def eager_mode(enable: bool):
  """Configure torch_xla's default execution mode.

  Under eager mode, only functions that were `torch_xla.compile`d will be
  traced and compiled. Other torch ops will be executed eagerly.
  """
  torch_xla._XLAC._set_use_eager_mode(enable)


def is_eager_mode() -> bool:
  """Return True if torch_xla is currently in eager mode."""
  return torch_xla._XLAC._get_use_eager_mode()


def compile(func):
  """Compile the func with Lazy Tensor.

  Return an optimized function that takes the exact same inputs. Compile will
  run the target func under tracing mode using Lazy Tensor.
  """

  @functools.wraps(func)  # Keep function's name, docstring, etc.
  def wrapper(*args, **kwargs):
    # compile should only be called while eager mode is enabled.
    assert torch_xla._XLAC._get_use_eager_mode() == True
    torch_xla._XLAC._set_use_eager_mode(False)
    # clear the pending graph if any
    torch_xla.sync()
    try:
      # Target function execution
      result = func(*args, **kwargs)
    except Exception as e:
      # Handle exceptions (if needed)
      print(f"Error in target function: {e}")
      raise  # Re-raise the exception
    # Sync the graph generated by the target function.
    torch_xla.sync()
    torch_xla._XLAC._set_use_eager_mode(True)

    return result

  return wrapper

Review comment (on the except block): Here the exception is a tracing exception, right?
Reply: Yeah, execution is async, so we won't be able to catch it here.

Review comment (on the second torch_xla.sync()): This actually runs the graph, and you might get exceptions here too.
Reply: Right, the way that LTC works is that async execution happens in a separate thread and the runtime error will be set in the unlocker. The next time we try to get the device lock, it will find that exception and throw. That being said, I agree with you that there is no harm to put …
Review comment: So the mechanism for caching the graph is already there, right? No need to do anything extra?
Reply: We still need to trace the whole model (run all the python code); we just skip the XLA compilation and lowering-to-HLO part. This `compile` does not modify the function bytecode or transform the function in any way besides enabling tracing mode. It would actually be more accurate to call it `trace`, but `compile` is more aligned with the pytorch API.
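To illustrate the caching point: calling the compiled function twice with the same shapes retraces the python code but reuses the cached executable, so only the execution counter grows. A small sketch, not part of this PR, that mirrors test_eager_execute_compiled_multiple_times above; the expected counter value is based on that test's assertion:

import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

torch_xla.experimental.eager_mode(True)
device = torch_xla.device()


def cos_sin(t):
  return torch.cos(torch.sin(t))


compiled = torch_xla.experimental.compile(cos_sin)
t = torch.randn(5, 5, device=device)

met.clear_all()
# The first call traces and executes the graph; the second call produces the
# same graph, so the cached executable should be reused.
compiled(t)
compiled(t)
xm.wait_device_ops()

# Mirrors the test above: two compiled-graph executions.
print(met.metric_data("ExecuteTime")[0])  # expected: 2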