Ignore non-XLA nodes and their direct dependents. #6170

Merged: 5 commits, Jan 11, 2024


ysiraichi
Collaborator

Fix: #5966

This PR generalizes FallBackNodeCollector into UnsupportedNodesCollector, improving and solving a few issues the old implementation had:

  • nodes which resulted in a non-XLA tensor weren't flagged as fallback
  • nodes whose arguments were a container with non-XLA tensors weren't flagged as fallback (e.g. stack)
  • the "unsupported" nodes being tracked weren't really only fallback nodes
    • they were also nodes that can't sit on partition boundaries (e.g. as arguments and return values)
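
The three cases in the list above can be sketched in isolation. This is only an illustrative model, not the code in this PR: `FakeTensor`, `FakeNode`, `ran_on_fallback` and `collect_unsupported` are hypothetical names, and arguments are modelled as plain runtime values rather than `torch.fx` nodes.

```python
from dataclasses import dataclass, field
from typing import Any, List


# Hypothetical stand-ins; these are NOT torch.fx or torch_xla types.
@dataclass
class FakeTensor:
    device: str  # e.g. "xla:0" or "cpu"


@dataclass
class FakeNode:
    name: str
    args: List[Any] = field(default_factory=list)  # runtime argument values
    result: Any = None             # value produced when the node ran
    ran_on_fallback: bool = False  # assumed flag: op was executed on CPU


def contains_non_xla_tensor(value: Any) -> bool:
    """True if value is, or contains, a tensor that is not on an XLA device."""
    # Recurse into containers so that ops taking tensor lists (e.g. stack)
    # are covered as well.
    if isinstance(value, (list, tuple)):
        return any(contains_non_xla_tensor(v) for v in value)
    if isinstance(value, dict):
        return any(contains_non_xla_tensor(v) for v in value.values())
    return isinstance(value, FakeTensor) and not value.device.startswith("xla")


def collect_unsupported(nodes: List[FakeNode]) -> List[str]:
    """Return the names of nodes that should stay outside any partition."""
    unsupported = []
    for node in nodes:
        if (
            node.ran_on_fallback                      # classic fallback op
            or contains_non_xla_tensor(node.result)   # produces a non-XLA tensor
            or contains_non_xla_tensor(node.args)     # consumes one, even nested
        ):
            unsupported.append(node.name)
    return unsupported
```

With this model, a node that only reads and writes XLA tensors is kept, while a `stack` over a list holding one CPU tensor is flagged even though the op itself did not fall back.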

@ysiraichi
Collaborator Author

cc @JackCaoG @miladm

@ysiraichi
Collaborator Author

Note: this PR may create more partitions than before, which could affect performance. Maybe we should make sure there are no regressions before actually landing it.

@JackCaoG
Collaborator

I think the definition of a fallback is that we have to execute the op on the fallback device (usually CPU). For the cases you mentioned:

  1. nodes which resulted in a non-XLA tensor weren't flagged as fallback
    Is the operation executed on CPU or on GPU/TPU?

  2. nodes whose arguments were a container with non-XLA tensors weren't flagged as fallback (e.g. stack)
    I guess this one is technically a fallback, since the op will be executed on CPU. It will show up in the dynamo fallback messages, but shouldn't be in XLA's metrics aten counter, I think?

I don't fully understand what point 3 means. Can you give me an example?

@ysiraichi
Collaborator Author

ysiraichi commented Jan 8, 2024

The problem was that FallBackNodeCollector was doing more than just collecting operations that were executed on CPU. It was also flagging nodes that had non-XLA tensor arguments as fallback (even though they were executed on XLA).

Reason: my guess is that it did so in order to guarantee that the inputs and outputs of the generated partitions were all XLA tensors (which I believe extract_internal assumes).

Assuming that's the case, it still missed the 2 cases I mentioned:

  • nodes which result in non-XLA tensor: might end up as output of some partition
  • nodes whose arguments were a container with non-XLA tensors: arguments might end up as the input of some partition

Solution:

  • disambiguate fallback from unsupported nodes (from the perspective of the partitioner)
  • add the missing 2 cases
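
To illustrate why the partitioner cares about "unsupported" rather than just "fallback", here is a deliberately simplified sketch. It assumes a linear node order and a hypothetical `split_into_partitions` helper; the real partitioner works on the graph's dependency structure, so this only shows the effect on partition boundaries.

```python
from itertools import groupby
from typing import List, Set


def split_into_partitions(ordered_names: List[str], unsupported: Set[str]) -> List[List[str]]:
    """Group consecutive supported nodes; unsupported nodes belong to no partition."""
    partitions = []
    # groupby yields maximal runs of nodes that share the same supported/unsupported status.
    for is_supported, run in groupby(ordered_names, key=lambda n: n not in unsupported):
        if is_supported:
            partitions.append(list(run))
    return partitions


order = ["a", "b", "c", "d", "e"]
fallback = {"c"}          # assumed: op executed on CPU
boundary_unsafe = {"e"}   # assumed: e.g. produces a non-XLA tensor
unsupported = fallback | boundary_unsafe  # fallback is only one kind of unsupported

print(split_into_partitions(order, fallback))     # "e" would end a partition
print(split_into_partitions(order, unsupported))  # "e" is kept out of every partition
```

Every node added to the unsupported set either shrinks a partition or splits one, which is also where the concern about creating more partitions comes from.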

@JackCaoG let me know what you think.

@JackCaoG
Collaborator

JackCaoG commented Jan 8, 2024

Yeah, that makes sense. Is this PR ready for review?

@ysiraichi
Collaborator Author

Yes, it is.

Collaborator

@JackCaoG JackCaoG left a comment

mostly lgtm, minor questions on test cases.

@ysiraichi ysiraichi force-pushed the ysiraichi/ignore-non-xla-nodes branch from c38c1a0 to 65cd98a Compare January 10, 2024 22:34
@ysiraichi ysiraichi force-pushed the ysiraichi/ignore-non-xla-nodes branch from 65cd98a to 416bccb Compare January 10, 2024 23:28
@ysiraichi
Collaborator Author

@JackCaoG could you approve this PR?

@ysiraichi ysiraichi merged commit 8141078 into master Jan 11, 2024
19 checks passed
Successfully merging this pull request may close these issues.

[torchbench] AssertionError: All tensors should be on xla
2 participants