Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Filter tensor arguments from traced model. #5689

Merged
merged 7 commits into from
Oct 16, 2023

Conversation

ysiraichi
Copy link
Collaborator

This PR filters tensor arguments, when collecting tensor information, from the list of arguments that would be given to the model.

Problem: dynamo bridge assumed all arguments were tensors.
Solution: filter tensor arguments so that we correctly collect tensor information.

This PR filters tensor arguments from the list of arguments that would be given to the
model.

**Problem:** dynamo bridge assumed all arguments were tensors.
**Solution:** filter tensor arguments so that we correctly collect tensor information.
@JackCaoG
Copy link
Collaborator

Thanks, can you add a test cast that would fail without this change?

@ysiraichi ysiraichi marked this pull request as ready for review October 11, 2023 12:58
@ysiraichi ysiraichi changed the title [WIP] Filter tensor arguments from traced model. Filter tensor arguments from traced model. Oct 11, 2023
test/dynamo/test_bridge.py Outdated Show resolved Hide resolved
Copy link
Collaborator

@wonjoolee95 wonjoolee95 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR! LGTM pending the CI tests.

module = Emb()
module.to(device)

@torch.compile(backend="openxla")
Copy link
Collaborator Author

@ysiraichi ysiraichi Oct 12, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, openxla doesn't have this problem. This only happens with openxla_eval. Since we are trying to get rid of openxla_eval, are we still interested in merging this PR?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is aot-autograd handles the argument. What's the non-tensor input passed to openxla_eval?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, exactly. One of the inputs passed was a custom class that inherited from nn.Embedding.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let me check with Brian whether it is possible for us to get non-tensor input after aot-autograd

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

actually @bdhirsh any thought?

@JackCaoG
Copy link
Collaborator

I think it is safer to have this change in, I am going to merge it.

@JackCaoG JackCaoG merged commit 994f9fb into pytorch:master Oct 16, 2023
17 checks passed
zpcore pushed a commit that referenced this pull request Oct 19, 2023
* Filter tensor arguments from traced model.

This PR filters tensor arguments from the list of arguments that would be given to the
model.

**Problem:** dynamo bridge assumed all arguments were tensors.
**Solution:** filter tensor arguments so that we correctly collect tensor information.

* Add test.

* Fix lint issues.

* Simplified test.

* Use `openxla` instead of `openxla_eval` backend.

* Rename variables for readability.

* Use `openxla_eval` instead of `openxla`.
ghpvnist pushed a commit to ghpvnist/xla that referenced this pull request Oct 31, 2023
* Filter tensor arguments from traced model.

This PR filters tensor arguments from the list of arguments that would be given to the
model.

**Problem:** dynamo bridge assumed all arguments were tensors.
**Solution:** filter tensor arguments so that we correctly collect tensor information.

* Add test.

* Fix lint issues.

* Simplified test.

* Use `openxla` instead of `openxla_eval` backend.

* Rename variables for readability.

* Use `openxla_eval` instead of `openxla`.
mbzomowski pushed a commit to mbzomowski-test-org/xla that referenced this pull request Nov 16, 2023
* Filter tensor arguments from traced model.

This PR filters tensor arguments from the list of arguments that would be given to the
model.

**Problem:** dynamo bridge assumed all arguments were tensors.
**Solution:** filter tensor arguments so that we correctly collect tensor information.

* Add test.

* Fix lint issues.

* Simplified test.

* Use `openxla` instead of `openxla_eval` backend.

* Rename variables for readability.

* Use `openxla_eval` instead of `openxla`.
@lezcano lezcano added the xla:gpu label Dec 1, 2023
chunnienc pushed a commit to chunnienc/xla that referenced this pull request Dec 14, 2023
* Filter tensor arguments from traced model.

This PR filters tensor arguments from the list of arguments that would be given to the
model.

**Problem:** dynamo bridge assumed all arguments were tensors.
**Solution:** filter tensor arguments so that we correctly collect tensor information.

* Add test.

* Fix lint issues.

* Simplified test.

* Use `openxla` instead of `openxla_eval` backend.

* Rename variables for readability.

* Use `openxla_eval` instead of `openxla`.
golechwierowicz pushed a commit that referenced this pull request Jan 12, 2024
* Filter tensor arguments from traced model.

This PR filters tensor arguments from the list of arguments that would be given to the
model.

**Problem:** dynamo bridge assumed all arguments were tensors.
**Solution:** filter tensor arguments so that we correctly collect tensor information.

* Add test.

* Fix lint issues.

* Simplified test.

* Use `openxla` instead of `openxla_eval` backend.

* Rename variables for readability.

* Use `openxla_eval` instead of `openxla`.
bhavya01 pushed a commit that referenced this pull request Apr 22, 2024
* Filter tensor arguments from traced model.

This PR filters tensor arguments from the list of arguments that would be given to the
model.

**Problem:** dynamo bridge assumed all arguments were tensors.
**Solution:** filter tensor arguments so that we correctly collect tensor information.

* Add test.

* Fix lint issues.

* Simplified test.

* Use `openxla` instead of `openxla_eval` backend.

* Rename variables for readability.

* Use `openxla_eval` instead of `openxla`.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants