
feat: support amax dynamo converter #2241

Merged: 8 commits, Aug 25, 2023

Conversation

zewenli98 (Collaborator)

Description

Support amax dynamo converter.

Fixes issue #2095
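For context, a minimal usage sketch of the op this converter targets; the module name, shapes, and compile call below are illustrative assumptions, not code from this PR:

import torch
import torch_tensorrt

class AmaxModule(torch.nn.Module):
    def forward(self, x):
        # amax returns the maximum value over the given dimensions
        return torch.amax(x, dim=(1, 2), keepdim=True)

model = AmaxModule().eval().cuda()
inputs = [torch.randn(2, 3, 4, device="cuda")]

# With this converter, aten.amax.default is lowered to a TensorRT reduce
# layer instead of falling back to eager PyTorch execution.
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
print(trt_model(*inputs).shape)  # torch.Size([2, 1, 1])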

Type of change

  • New feature (non-breaking change which adds functionality)
  • This change requires a documentation update

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR so that relevant reviewers are notified

@github-actions bot added labels component: api [Python], component: conversion, component: converters, component: dynamo, and component: tests on Aug 17, 2023
name,
input_val,
args[1],
args_bounds_check(args, 2, replacement=False),
Collaborator:

@gs-olive can this check be done in a validator?

gs-olive (Collaborator), Aug 22, 2023:

This check can be done in a validator, but in this context it would make the most sense for it to be done here, since we can support cases where this argument is both present and absent.
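For reference, a rough sketch of what args_bounds_check does in that call; the real helper lives in the dynamo conversion utilities, so treat the body below as an assumption-level paraphrase rather than the actual source:

from typing import Any, Sequence

def args_bounds_check(args: Sequence[Any], i: int, replacement: Any = None) -> Any:
    # Return args[i] when the optional positional argument is present,
    # otherwise fall back to the supplied default.
    return args[i] if len(args) > i else replacement

# In the amax converter, keep_dims therefore defaults to False whenever
# the third positional argument is absent:
# keep_dims = args_bounds_check(args, 2, replacement=False)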


gs-olive (Collaborator) left a review comment:

Converter looks great, and the testing is very robust! Made a few small comments

Comment on lines +23 to +26
if (isinstance(input_val, TRTTensor)) and (
input_val.dtype == trt.int8 or input_val.dtype == trt.int32
):
input_val = cast_trt_tensor(network, input_val, trt.float32, name)
Collaborator:

Is this required for the IReduceLayer?

zewenli98 (Collaborator, Author):

Yes, it's required. Testing with int8 raises TypeError: Provided an unsupported data type as an input data type (support: bool, int32, long, half, float), got: torch.int8. Even though the docs say the Reduce Layer supports int32, my tests show it doesn't.
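To make the discussion concrete, here is a rough sketch of the converter shape being reviewed; TRTTensor and the helpers cast_trt_tensor, get_axes_for_reduce_op, and set_layer_name are assumed to come from the existing conversion utilities, and this is illustrative rather than the merged code:

import tensorrt as trt

def amax(network, target, source_ir, name, input_val, dim, keep_dims=False):
    # IReduceLayer rejects int8 (and, per the testing above, int32) inputs
    # in practice, so promote them to float32 before reducing.
    if isinstance(input_val, TRTTensor) and input_val.dtype in (trt.int8, trt.int32):
        input_val = cast_trt_tensor(network, input_val, trt.float32, name)

    # Reduce with MAX over the requested dimensions; keep_dims controls
    # whether the reduced axes remain as size-1 dimensions.
    layer = network.add_reduce(
        input_val,
        trt.ReduceOperation.MAX,
        axes=get_axes_for_reduce_op(dim),
        keep_dims=keep_dims,
    )
    set_layer_name(layer, target, name, source_ir)
    return layer.get_output(0)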

zewenli98 (Collaborator, Author):

> Converter looks great, and the testing is very robust! Made a few small comments

Thanks for your review! I've resolved all the issues and pushed the changes.

zewenli98 (Collaborator, Author):

It looks like the failure is in C++ tests that have nothing to do with this commit?

gs-olive (Collaborator):

I reran that C++ test here.

zewenli98 (Collaborator, Author):

Passed! Please let me know if anything needs fixing in this PR.

@@ -440,6 +440,25 @@ def aten_ops_expand(
)


@dynamo_tensorrt_converter(torch.ops.aten.amax.default)
gs-olive (Collaborator):

Based on this schema, it seems that the dimension (dim) can also be absent. For instance, the following model produces the graph below:

class argmax(torch.nn.Module):
    def forward(self, x):
        return torch.argmax(x)

"""
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %argmax : [num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%arg0_1,), kwargs = {})
    return (argmax,)
"""

If we cannot support this case, you can add a capability_validator function to this decorator, which will mark that case as unsupported. Roughly, that could look something like this:

def amax_param_validator(amax_node: Node) -> bool:
    return len(amax_node.args) >= 2

@dynamo_tensorrt_converter(
    torch.ops.aten.amax.default, capability_validator=amax_param_validator
)  # type: ignore[misc]
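For contrast, when dim is supplied the traced graph carries the extra arguments, so the validator above accepts the node; the module name and exact graph text below are illustrative:

class amax_dim(torch.nn.Module):
    def forward(self, x):
        return torch.amax(x, dim=1, keepdim=True)

"""
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %amax : [num_users=1] = call_function[target=torch.ops.aten.amax.default](args = (%arg0_1, [1], True), kwargs = {})
    return (amax,)
"""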

zewenli98 (Collaborator, Author):

@gs-olive Thank you so much George! That's a good catch. I fixed it, and I've learned a lot from these details recently 😃

gs-olive (Collaborator):

No problem - glad to hear it!

gs-olive (Collaborator) left a review comment:

Looks great!

gs-olive merged commit a65c95c into pytorch:main on Aug 25, 2023
4 checks passed