-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Relay][Pass] Add a relay pass to extract fake quantized ops #10089
Conversation
I will take a look tomorrow |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the PR! A few comments
mod = tvm.IRModule.from_expr(op) | ||
fake_quantized_op_freqs = relay.analysis.list_fake_quantized_op_freqs(mod) | ||
|
||
assert len(fake_quantized_op_freqs) == 1 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think you can just do direct equality throughout this file
dict(fake_quantized_op_freqs) == {"nn.conv2d": 1}
|
||
using ExprSet = std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual>; | ||
|
||
class FakeQuantizedRegionExtractor : public ExprVisitor { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you just reuse SubgraphExtractor
in src/relay/transforms/fake_quantization_to_integer.cc
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
tests look great! very comprehensive.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, some nits
* \file src/relay/transforms/fake_quantization_to_integer.h | ||
* \brief Extract subgraph of a fake quantized region. | ||
* | ||
* https://llvm.org/doxygen/CallGraph_8h_source.html |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This line is probably copypasta?
if (op != dequantize_op_) { | ||
if (fake_quantized_op_freqs_.find(op_name) != fake_quantized_op_freqs_.end()) { | ||
fake_quantized_op_freqs_.Set(op_name, | ||
int64_t(fake_quantized_op_freqs_.at(op_name)) + 1); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do you need to cast here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i was getting compile-time errors that fake_quantized_op_freqs_.at(op_name) + 1
is a PrimExpr
instead of a tvm::Integer
and it seemed like casting worked around the issue -- lmk if there's a better way around this?
…10089) * add relay pass to collect fake quantized ops * add more tests * more tests * lint * lint * remove unused imports * update comment * lint * reuse SubgraphExtractor and update test assertions * remove print * lint * remove unneeded comment Co-authored-by: Margaret Qian <mqian@octoml.ai>
Add a relay pass to collect fake quantized ops and frequencies from within fake quantized regions.