add torch_xla.experimental.compile for eager mode #7246
Conversation
ok test added, should be ready for review.
lgtm, few qqs:
```python
      result = func(*args, **kwargs)
    except Exception as e:
      # Handle exceptions (if needed)
      print(f"Error in target function: {e}")
```
Here the exception is a tracing exception, right?
yea, execution is async so we won't be able to catch it here.
print(f"Error in target function: {e}") | ||
raise # Re-raise the exception | ||
# Sync the graph generated by the target function. | ||
torch_xla.sync() |
this actually runs the graph, so you might get exceptions here too.
Right, the way LTC works is that async execution happens in a separate thread, and the runtime error is set in the unlocker. The next time we try to acquire the device lock, it finds that exception and throws:
https://github.com/pytorch/xla/blob/master/torch_xla/csrc/xla_graph_executor.cpp#L861-L864
That being said, I agree with you that there is no harm in putting `sync` in the `try` region. Let me update that in a follow-up PR.
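To make that suggestion concrete, here is a minimal sketch of the wrapper with the sync moved inside the `try` region; the shape is reconstructed from the hunks quoted in this thread, not the exact diff of this PR:

```python
import functools

import torch_xla


def compile(func):

  @functools.wraps(func)  # Keep function's name, docstring, etc.
  def wrapper(*args, **kwargs):
    try:
      result = func(*args, **kwargs)
      # Syncing inside the try means errors surfaced when the device lock
      # is next acquired are caught and reported here as well.
      torch_xla.sync()
    except Exception as e:
      print(f"Error in target function: {e}")
      raise  # Re-raise the exception
    return result

  return wrapper
```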
""" | ||
|
||
@functools.wraps(func) # Keep function's name, docstring, etc. | ||
def wrapper(*args, **kwargs): |
so the mechanism for caching the graph is already there, right, so there is no need to do anything extra?
we still need to trace the whole model (run all the Python code); we just skip the XLA compilation and lowering-to-HLO part.
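For a concrete picture of that caching behavior, a hedged usage sketch (the import path follows the PR title, the eager-mode toggle comes from the surrounding eager-mode work, and the model and input names are made up):

```python
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.experimental

torch_xla.experimental.eager_mode(True)  # assumed eager-mode toggle

device = xm.xla_device()
model = torch.nn.Linear(8, 8).to(device)


@torch_xla.experimental.compile
def step(x):
  return model(x)


x = torch.randn(4, 8, device=device)
step(x)  # first call: trace, lower to HLO, XLA-compile, execute
step(x)  # same shapes: the Python code is re-traced, but the cached executable is reused
```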
This `compile` does not modify the function bytecode or transform the function in any way besides enabling the tracing mode. It would actually be more accurate to call it `trace`, but `compile` is more aligned with the PyTorch API.
This should only be used for eager mode. The `compile` pretty much enables the LTC tracing before entering the function and disables it again on exit. TODO
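A rough illustration of that enable/disable behavior (the eager-mode toggle name is an assumption, not taken from this PR's diff):

```python
import functools

import torch_xla
import torch_xla.experimental


def compile(func):

  @functools.wraps(func)
  def wrapper(*args, **kwargs):
    # Leave eager mode so ops inside the function are recorded lazily (LTC).
    torch_xla.experimental.eager_mode(False)
    try:
      result = func(*args, **kwargs)
      # Execute the captured graph before handing control back to the caller.
      torch_xla.sync()
    finally:
      # Restore per-op eager execution on exit.
      torch_xla.experimental.eager_mode(True)
    return result

  return wrapper
```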