Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay][Pass] Add a relay pass to extract fake quantized ops #10089

Merged
merged 12 commits into from
Feb 2, 2022

Conversation

margaretqian
Copy link
Contributor

Add a relay pass to collect fake quantized ops and frequencies from within fake quantized regions.

@AndrewZhaoLuo
Copy link
Contributor

I will take a look tomorrow

Copy link
Contributor

@AndrewZhaoLuo AndrewZhaoLuo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR! A few comments

mod = tvm.IRModule.from_expr(op)
fake_quantized_op_freqs = relay.analysis.list_fake_quantized_op_freqs(mod)

assert len(fake_quantized_op_freqs) == 1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you can just do direct equality throughout this file

dict(fake_quantized_op_freqs) == {"nn.conv2d": 1}


using ExprSet = std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual>;

class FakeQuantizedRegionExtractor : public ExprVisitor {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you just reuse SubgraphExtractor in src/relay/transforms/fake_quantization_to_integer.cc?

Copy link
Contributor

@anwang2009 anwang2009 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tests look great! very comprehensive.

Copy link
Contributor

@AndrewZhaoLuo AndrewZhaoLuo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, some nits

* \file src/relay/transforms/fake_quantization_to_integer.h
* \brief Extract subgraph of a fake quantized region.
*
* https://llvm.org/doxygen/CallGraph_8h_source.html
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line is probably copypasta?

if (op != dequantize_op_) {
if (fake_quantized_op_freqs_.find(op_name) != fake_quantized_op_freqs_.end()) {
fake_quantized_op_freqs_.Set(op_name,
int64_t(fake_quantized_op_freqs_.at(op_name)) + 1);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do you need to cast here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i was getting compile-time errors that fake_quantized_op_freqs_.at(op_name) + 1 is a PrimExpr instead of a tvm::Integer and it seemed like casting worked around the issue -- lmk if there's a better way around this?

@AndrewZhaoLuo AndrewZhaoLuo merged commit efe662f into apache:main Feb 2, 2022
ylc pushed a commit to ylc/tvm that referenced this pull request Feb 16, 2022
…10089)

* add relay pass to collect fake quantized ops

* add more tests

* more tests

* lint

* lint

* remove unused imports

* update comment

* lint

* reuse SubgraphExtractor and update test assertions

* remove print

* lint

* remove unneeded comment

Co-authored-by: Margaret Qian <mqian@octoml.ai>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants