feat: support amax dynamo converter #2241
Conversation
    name,
    input_val,
    args[1],
    args_bounds_check(args, 2, replacement=False),
@gs-olive can this check be done in a validator?
This check can be done in a validator, but in this context it would make the most sense for it to be done here, since we can support cases where this argument is both present and absent.
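For context, args_bounds_check is the utility used above to read an optional schema argument with a fallback default. A minimal sketch of such a helper (the actual implementation in the codebase may differ) looks like:

from typing import Any, Sequence


def args_bounds_check(args: Sequence[Any], i: int, replacement: Any = None) -> Any:
    # Return args[i] when the optional argument is present on the node,
    # otherwise fall back to the provided default (e.g. the schema default).
    return args[i] if len(args) > i else replacement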
Converter looks great, and the testing is very robust! Made a few small comments
if (isinstance(input_val, TRTTensor)) and (
    input_val.dtype == trt.int8 or input_val.dtype == trt.int32
):
    input_val = cast_trt_tensor(network, input_val, trt.float32, name)
Is this required for the IReduceLayer?
Yes, it's required, because I get `TypeError: Provided an unsupported data type as an input data type (support: bool, int32, long, half, float), got: torch.int8` when I test int8. Even though the doc says the Reduce Layer supports int32, I tested it and it doesn't.
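As a small self-contained illustration (not the converter code itself) of why promoting to float32 before the reduction is safe here:

import torch

x_int8 = torch.randint(-4, 4, (2, 3), dtype=torch.int8)

# PyTorch's eager amax accepts int8 directly...
eager = torch.amax(x_int8, dim=1)

# ...while TensorRT's reduce layer does not, so the converter promotes the
# input to float32 first; for integer values the reduction result is unchanged.
promoted = torch.amax(x_int8.to(torch.float32), dim=1).to(x_int8.dtype)

assert torch.equal(eager, promoted)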
Thanks for your review! I've resolved all the issues and pushed.
It looks like it failed on C++ tests, which have nothing to do with this commit?
I reran that C++ test here.
Passed! Please let me know if anything needs to be fixed in this PR.
@@ -440,6 +440,25 @@ def aten_ops_expand(
)


@dynamo_tensorrt_converter(torch.ops.aten.amax.default)
Based on this schema, it seems that the dimension (dim) can also be absent. For instance, the following model creates the subsequent graph:
class argmax(torch.nn.Module):
    def forward(self, x):
        return torch.argmax(x)

"""
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %argmax : [num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%arg0_1,), kwargs = {})
    return (argmax,)
"""
If we cannot support this case, you can add a capability_validator function to this decorator, which will note that case as unsupported. Roughly, that could be something like:
def amax_param_validator(amax_node: Node) -> bool:
    return len(amax_node.args) >= 2


@dynamo_tensorrt_converter(
    torch.ops.aten.amax.default, capability_validator=amax_param_validator
)  # type: ignore[misc]
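As a rough way to sanity-check such a validator against a traced graph (assuming a PyTorch build that provides torch.export; the tracing entry point used in this repo may differ):

import torch
from torch.fx.node import Node


def amax_param_validator(amax_node: Node) -> bool:
    # Only accept nodes where dim was explicitly provided.
    return len(amax_node.args) >= 2


class WithDim(torch.nn.Module):
    def forward(self, x):
        return torch.amax(x, dim=1)


ep = torch.export.export(WithDim(), (torch.randn(2, 3),))
for node in ep.graph_module.graph.nodes:
    if node.target == torch.ops.aten.amax.default:
        print(amax_param_validator(node))  # True: node args are (x, [1])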
@gs-olive Thank you so much, George! That's a good catch. I fixed it and have learned a lot from these details recently 😃
No problem - glad to hear it!
Looks great!
Description
Support amax dynamo converter.
Function Schema: torch.ops.aten.amax.default
Original PyTorch API: https://pytorch.org/docs/stable/generated/torch.amax.html
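For quick reference, the eager-mode semantics of the op being converted:

import torch

x = torch.tensor([[1.0, 5.0, 2.0],
                  [4.0, 3.0, 6.0]])

# Maximum over dim 1, keeping the reduced dimension.
print(torch.amax(x, dim=1, keepdim=True))  # tensor([[5.], [6.]])

# dim may also be a tuple of dimensions to reduce over.
print(torch.amax(x, dim=(0, 1)))           # tensor(6.)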
Fixes issue #2095