
feat: support many unary dynamo converters #2246

Merged
merged 1 commit into pytorch:main on Aug 23, 2023

Conversation

@zewenli98

Description

Implemented all 22 unary dynamo converters: EXP, LOG, SQRT, RECIP, ABS, SIN, COS, TAN, SINH, COSH, ASIN, ACOS, ATAN, ASINH, ACOSH, ATANH, CEIL, FLOOR, NOT, SIGN, ROUND, and ISINF.

Fixes #2199 (Get IUnaryLayer Ops exposed in dynamo.conversion.impl)

Type of change

  • New feature (non-breaking change which adds functionality)
  • This change requires a documentation update

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR so that the relevant reviewers are notified

@github-actions github-actions bot added component: api [Python] Issues re: Python API component: conversion Issues re: Conversion stage component: converters Issues re: Specific op converters component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths component: tests Issues re: Tests labels Aug 19, 2023
@github-actions github-actions bot requested a review from apbose August 19, 2023 00:30
@narendasan left a comment


@zewenli98 We need a layer of indirection between aten and impl, which is why functions like sign had dedicated implementations.

For each op we need:

  1. aten registration for the actual converter in aten_ops_converters.py
  2. a dedicated impl function

So for aten::cos we need (a sketch follows below):

  1. aten_ops_cos (conversion.aten_ops_converters)
  2. impl.cos (conversion.impl.unary.ops)
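For illustration, a minimal sketch of that two-layer pattern for aten::cos. This is hedged: the converter signature (network, target, args, kwargs, name), SourceIR.ATEN, and the unary_op helper are assumptions pieced together from this thread, not verbatim repo code.

import torch
import tensorrt as trt

# Layer 1: registration in conversion/aten_ops_converters.py.
# The signature and SourceIR.ATEN are assumptions for this sketch.
@dynamo_tensorrt_converter(torch.ops.aten.cos.default)
def aten_ops_cos(network, target, args, kwargs, name):
    # Thin wrapper: dispatch straight to the dedicated impl function.
    return impl.unary.cos(network, target, SourceIR.ATEN, name, args[0])

# Layer 2: dedicated impl function in conversion/impl/unary/ops.py.
def cos(network, target, source_ir, name, input_val):
    # Alias over the shared unary helper so converter writers never pass
    # the TensorRT enum themselves.
    return unary_op(network, target, source_ir, name, input_val,
                    trt.UnaryOperation.COS)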

@@ -405,7 +405,7 @@ def aten_ops_to_copy_dtype(
 )


-@dynamo_tensorrt_converter(torch.ops.aten.clone.default)
+@dynamo_tensorrt_converter(torch.ops.aten.clone.default)  # type: ignore[misc]
@narendasan

Before this is committed, check to see if there's a way to fix this mypy error. @gs-olive can you look at the decorator as well?

@gs-olive

I recently noticed that mypy does not accurately show errors for these decorators. I will look more into this.

The decorator is already as strongly typed as it can be, though, as shown here:

def dynamo_tensorrt_converter(
    key: Target,
    enabled: bool = True,
    capability_validator: Optional[Callable[[Node], bool]] = None,
    priority: ConverterPriority = ConverterPriority.STANDARD,
) -> Callable[[Any], Union[TRTTensor, Sequence[TRTTensor]]]:
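For context, the decorator is applied as in the diff above; mypy flags the application with a [misc] error, hence the inline ignore. The wrapped function's signature and body here are placeholders, not repo code:

import torch

@dynamo_tensorrt_converter(torch.ops.aten.clone.default)  # type: ignore[misc]
def aten_ops_clone(network, target, args, kwargs, name):
    ...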

@zewenli98

> @zewenli98 We need a layer of indirection between aten and impl, which is why functions like sign had dedicated implementations.
>
> For each op we need:
>
>   1. aten registration for the actual converter in aten_ops_converters.py
>   2. a dedicated impl function
>
> So for aten::cos we need:
>
>   1. aten_ops_cos (conversion.aten_ops_converters)
>   2. impl.cos (conversion.impl.unary.ops)

I thought the code for these TensorRT-supported converters is almost the same, so I created unary_op in unary/ops.py, similar to elementwise_op in elementwise/ops.py. I was wondering why dedicated implementations are better?

@narendasan commented Aug 21, 2023

Dedicated implementations are better because they give all ops a standardized, simpler interface. So if I want to write a new converter for, say, LogSumExp, instead of my code looking like this:

layer = impl.unary_op(input, ..., trt.IUnaryOperation.LOG)
layer = impl.reduce_op(layer.output, ..., trt.IReduceOperation.SUM)
return impl.unary_op(layer.output, ..., trt.IUnaryOperation.EXP)

People's code would look closer to this:

x = impl.log(input, ...)
x = impl.sum(x, ...)
return impl.exp(x, ...)

This way a converter writer does not need in-depth knowledge of TensorRT or its APIs; they can write code that reads more like PyTorch instead.

@zewenli98

Oh I see, that makes sense. Actually, at first I implemented dedicated implementations, but after seeing elementwise_op I changed to the current code. I'll change it back, and do the same for elementwise.

@narendasan commented Aug 21, 2023

So there can still be a centralized core function; unary_op is good for that. But effectively we just need an alias for each specific op so users don't have to do that themselves.

So we would do something like aten_ops_cos -> impl.cos -> impl.unary.unary_op(..., trt.IUnaryOperation.COS)
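A hedged sketch of that centralized core: the unary_op name comes from this thread, and network.add_unary / trt.UnaryOperation are real TensorRT Python APIs, but the exact wiring in the repo is an assumption.

import tensorrt as trt

def unary_op(network, target, source_ir, name, input_val, operation_type):
    # Centralized core: every dedicated alias (impl.cos, impl.exp, ...)
    # funnels through this one helper, which adds the TensorRT layer.
    # target/source_ir are kept to match the converter calling convention.
    layer = network.add_unary(input_val, operation_type)
    layer.name = name
    return layer.get_output(0)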

@zewenli98

OK, I'll change it to something like aten_ops_cos -> impl.cos -> impl.unary.unary_op(..., trt.IUnaryOperation.COS)

@narendasan left a comment

LGTM. @zewenli98 rebase and this should be good to merge.

@zewenli98

Do we still need # type: ignore[misc]?

@narendasan

It's fine for now; we can resolve it later.

Commits pushed:
  • fix bugs
  • change to dedicated implementations
  • move input validation into ops
@zewenli98

Rebased, but I don't have access to merge.

@narendasan merged commit fabfc55 into pytorch:main on Aug 23, 2023
4 checks passed