[ONNX] Improve diagnostics performance #99936
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/99936

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV: there is 1 currently active SEV. If your PR is affected, please view it below.

✅ No Failures: as of commit a8c6282, there are no failures.

This comment was automatically generated by Dr. CI and updates every 15 minutes.
I'm inclined to include a sanity test like the one below, but didn't due to concerns over its flakiness. Ideas are welcome.

```python
import time

import transformers

import torch
from torch.testing._internal import common_utils


def test_export_remains_efficient_with_diagnostics(self):
    model_name = "gpt2"
    # Download the pretrained model and tokenizer, and build sample inputs.
    model = transformers.AutoModel.from_pretrained(model_name)
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
    inputs = tokenizer("Hello world!", return_tensors="pt")
    start_time = time.time()
    with common_utils.TemporaryFileName() as path:
        torch.onnx.dynamo_export(model, **inputs).save(path)
    elapsed_time = time.time() - start_time
    time_threshold_in_seconds = 15.0
    self.assertTrue(
        elapsed_time < time_threshold_in_seconds,
        (
            f"Exporting GPT2 model took too long! "
            f"{elapsed_time} seconds > {time_threshold_in_seconds} seconds. "
            f"This is a sanity check that `torch.onnx.dynamo_export` remains "
            f"reasonably efficient with all the diagnostics and analysis enabled. "
            f"The time constraint is loosely set such that the test should pass "
            f"on most machines."
        ),
    )
```
Curious about the speed gain?
I was just about to suggest we include GPT-2 as some kind of nodes/s baseline test with a large tolerance.
This PR is a good balance and we should merge it as-is (unless you want to add the test), but we may want to further only gather stack info under trace/debug level for another ~2x bump later?
@justinchuby for context, I stuck a timer around the following export:
```python
import torch
import transformers

torch.onnx.dynamo_export(
    transformers.GPT2Model.from_pretrained("gpt2"),
    **transformers.GPT2Tokenizer.from_pretrained("gpt2")(
        "Tokenize me",
        return_tensors="pt",
    ),
).save("gpt2.onnx")
```
@abock thanks for posting the speed gain. Yep, I think it should be configurable through the API; we'd want export to be fast, so any perf-heavy diagnosing should hide behind it. Merging after adding comments per @justinchuby's suggestion.
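As a rough illustration of the gating idea discussed above, here is a minimal sketch that assumes a plain Python `logging` level check standing in for the diagnostics verbosity setting; it is not the actual `torch.onnx` diagnostics API:

```python
import logging
import traceback

# Hypothetical logger name, used only for this illustration.
logger = logging.getLogger("onnx_export_diagnostics_example")


def record_diagnostic(message: str) -> None:
    """Record a diagnostic entry, deferring expensive work unless debugging."""
    payload = {"message": message}
    # Walking and formatting the Python stack is comparatively expensive,
    # so only capture it when debug-level diagnostics are explicitly enabled.
    if logger.isEnabledFor(logging.DEBUG):
        payload["stack"] = "".join(traceback.format_stack())
    logger.info("%s", payload)
```

Cheap metadata is always recorded, while the comparatively expensive stack capture only happens when the caller opts into debug-level diagnostics.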
Summary:
- Do not call `fx_graph_module.print_readable` when recording `fx.GraphModule` function argument diagnostics.
- Cache `inspect.getsourcelines` results.
@pytorchbot merge
Merge started: your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: 1 mandatory check(s) failed. The first few are: … Dig deeper by viewing the failures on hud.
Summary:
- Do not call `fx_graph_module.print_readable` when recording `fx.GraphModule` function argument diagnostics.
- Cache `inspect.getsourcelines` results.
@pytorchbot merge
Merge started: your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Stack from ghstack (oldest at bottom):

Summary
- Do not call `fx_graph_module.print_readable` when recording `fx.GraphModule` function argument diagnostics.
- Cache `inspect.getsourcelines` results.
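As an aside, here is a minimal sketch of the caching idea from the summary, assuming a plain `functools.lru_cache` wrapper is an acceptable stand-in for however the diagnostics code actually memoizes the call:

```python
import functools
import inspect


@functools.lru_cache(maxsize=None)
def cached_getsourcelines(obj):
    # inspect.getsourcelines re-reads and re-parses the source file on each
    # call; memoizing it avoids repeated filesystem and parsing work when the
    # same functions and modules are diagnosed many times during export.
    return inspect.getsourcelines(obj)
```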