-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[CUDNN] Add partitioning support for conv2d and log_softmax #10961
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks.
Not for this PR, but the direct construction of the tvm.target.cuda() here and for cublas causes some issues in the Collage branch. If there's an easy way to recover the Target already known to the te_compiler.cc machinery (via a function attribute perhaps?) that would help remove that glitch.
@@ -50,6 +51,8 @@ def partition_for_cudnn( | |||
tvm.IRModule | |||
The partitioned module. | |||
""" | |||
if params: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Given it's a one liner never figured out why folks want to fold that into every partition function. Cargo culting?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In my case, I want to pattern match against a mod where batch norm is removed by constant folding + fold scale axis. Param binding is a prereq for these passes.
cc @mbrookhart @masahi PTAL and merge if you're happy :) |
* main: (527 commits) [hexagon] 'add_hvx' test to explore HVX usage. (apache#10604) [COMMUNITY] @yzh119 -> Reviewer (apache#10993) [Metaschedule] Make custom schedule_rule registration optional (apache#10975) [ONNX] Add imports for BERT contrib operators (apache#10949) sort axes (apache#10985) [Hexagon] Remove HexagonBuffer external constructor and support (apache#10978) [CI] Update GPU image (apache#10992) [Runtime][Vulkan] Add RGP support to TVM for vulkan device (apache#10953) [FIX] resolve int64/32 for AttrStmtNode (apache#10983) [TVMC] Allow output module name to be passed as a command line argument (apache#10962) [ONNX] Add MatMulInteger importer (apache#10450) [COMMUNITY] @guberti -> Reviewer (apache#10976) Support `qnn.conv2d` in FoldExplicitPading (apache#10982) change Hexagon docker version (apache#10981) remove exception handling of autotvm xgboost extract functions (apache#10948) [CUDNN] Add partitioning support for conv2d and log_softmax (apache#10961) [Hexagon][LLVM] Enable/test tensorized Hexagon DMA on 2d transformed layout (apache#10905) [Hexagon] Move aot/graph_executor interactions into launcher (apache#10907) [HEXAGON] Split huge 1D DMA Transfers into smaller transfers with legal sizes. (apache#10971) [CI][DOCKER] Add pytest-lazy-fixture to images (apache#10970) ...
Further adds Relay partitioning support for cuDNN conv2d and log_softmax.