Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

index: fix index of 0-element tensor by 0-element tensor. #7113

Merged
merged 3 commits into from
May 28, 2024

Conversation

ysiraichi
Copy link
Collaborator

This PR adds support for indexing a 0-element tensor with a 0-element tensor index. It also adds a fast path whenever there are 0-element tensor indices. If found, we know that we will return a 0-element tensor.

In summary, here are the steps of this fast path:

  • Assume there is, at least, one 0-element tensor index
  • Deduce the output shape from the indices and start_dim
    • The output shape will be a concatenation of 3 sub-sequences:
      1. The shape of the indexed tensor up to dimension start_dim
      2. The shape of the indices
      3. The shape of the indexed tensor after dimension start_dim + len(indices)

For further details, see PyTorch's implementation of output shape computation.

cc @miladm @JackCaoG @lezcano

Copy link
Collaborator

@lezcano lezcano left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have a meta implementation of the meta of index_Tensor at https://github.com/pytorch/pytorch/blob/ee6cb6daa173896f8ea1876266a19775aaa4f610/torch/_meta_registrations.py#L2983

Perhaps it's better to replicate its behaviour here, as it's quite tricky.

I assume we cannot use this information here, even though dynamo knows about it, right?

@ysiraichi
Copy link
Collaborator Author

I think that I'm basically doing the same thing that this part of the meta implementation is doing. The only thing is that, at this point, we have already preprocessed the inputs such that:

  • base is permuted so that the indexing dimensions are adjacent
  • start_dim is the dimension that indices[0] indexes
  • indices is a list of defined (i.e. tensor.defined() == true) tensors that indexes adjacent dimensions starting from start_dim

I assume we cannot use this information here, even though dynamo knows about it, right?

I think the problem is two-fold:

  1. This function lives in Python. Thus, we need some way of calling it (e.g. this PoC)

  2. This function only retrieves the output, throwing away all preprocess done in the beginning of it

    • For lowering, we would need to do the pre-processing, again

Copy link
Collaborator

@lezcano lezcano left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds reasonable. Thank you for looking into this!

@ysiraichi
Copy link
Collaborator Author

Thank you, Mario. I think I will wait for @JackCaoG review, too, before merging this PR.

@JackCaoG
Copy link
Collaborator

Let me take a look today

@JackCaoG JackCaoG merged commit be3b08e into master May 28, 2024
19 checks passed
@JackCaoG JackCaoG deleted the ysiraichi/fix-index-zero-tensor-by-zero-tensor branch May 28, 2024 18:24
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[torchbench] vision_maskrcnn failing on inference with dynamo after bfloat16 conversion.
3 participants