Skip to content

Commit

Permalink
[Relay][Pass] Add a relay pass to extract fake quantized ops (apache#…
Browse files Browse the repository at this point in the history
…10089)

* add relay pass to collect fake quantized ops

* add more tests

* more tests

* lint

* lint

* remove unused imports

* update comment

* lint

* reuse SubgraphExtractor and update test assertions

* remove print

* lint

* remove unneeded comment

Co-authored-by: Margaret Qian <mqian@octoml.ai>
  • Loading branch information
2 people authored and ylc committed Feb 16, 2022
1 parent 5fb57a1 commit 1a78096
Show file tree
Hide file tree
Showing 5 changed files with 335 additions and 57 deletions.
16 changes: 16 additions & 0 deletions python/tvm/relay/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,22 @@ def list_op_freqs(mod):
return _ffi_api.ExtractOperators(mod)


def list_fake_quantized_op_freqs(mod):
"""Pass to extract fake quantized op names and the frequency that they appear
in fake quantized regions of an IRModule.
Parameters
----------
mod : tvm.IRModule
Returns
-------
ret : Dict[str, int]
Dict of fake quantized operator names to frequency
"""
return _ffi_api.ExtractFakeQuantizedOps(mod)


def search_fc_transpose(expr):
"""Search fc weight name in the patten: y = nn.dense(x, transpose(w, [1, 0]))
Expand Down
80 changes: 80 additions & 0 deletions src/relay/analysis/extract_fake_quantized_ops.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file extract_fake_quantized_ops.cc
* \brief Extract fake quantized operators from an IRModule
*/
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>

#include "../transforms/fake_quantization_to_integer.h"

namespace tvm {
namespace relay {

using ExprSet = std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual>;

class ExtractFakeQuantizedOpsWrapper : private MixedModeVisitor {
public:
Map<String, tvm::Integer> Extract(const IRModule& m) {
IRModule mod(m);
mod = transform::InferType()(mod);
VisitExpr(mod->Lookup("main"));

return fake_quantized_op_freqs_;
}

private:
using MixedModeVisitor::VisitExpr_;

void VisitExpr_(const CallNode* call_node) override {
if (call_node->op == quantize_op_) {
SubgraphExtractor extractor;
ExprSet subgraph = extractor.GetSubgraph(GetRef<Expr>(call_node));

for (auto expr : subgraph) {
const Op op = Downcast<Op>(expr.as<CallNode>()->op);
if (op != dequantize_op_) {
if (fake_quantized_op_freqs_.find(op->name) != fake_quantized_op_freqs_.end()) {
fake_quantized_op_freqs_.Set(op->name,
int64_t(fake_quantized_op_freqs_.at(op->name)) + 1);
} else {
fake_quantized_op_freqs_.Set(op->name, 1);
}
}
}
}
}

Map<String, tvm::Integer> fake_quantized_op_freqs_;
const Op quantize_op_ = Op::Get("qnn.quantize");
const Op dequantize_op_ = Op::Get("qnn.dequantize");
};

Map<String, tvm::Integer> ExtractFakeQuantizedOpsPacked(const IRModule& mod) {
return ExtractFakeQuantizedOpsWrapper().Extract(mod);
}

TVM_REGISTER_GLOBAL("relay.analysis.ExtractFakeQuantizedOps")
.set_body_typed(ExtractFakeQuantizedOpsPacked);

} // namespace relay
} // namespace tvm
109 changes: 52 additions & 57 deletions src/relay/transforms/fake_quantization_to_integer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,15 @@
* to actual integer operations.
*/

#include <tvm/ir/affine_type.h>
#include "fake_quantization_to_integer.h"

#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/qnn/attrs.h>
#include <tvm/relay/transform.h>

#include <unordered_map>

namespace tvm {
namespace relay {

Expand Down Expand Up @@ -75,69 +78,61 @@ using AffineTypeMap = Map<Expr, AffineType>;
using FTVMFakeQuantizationToInteger =
runtime::TypedPackedFunc<Array<ObjectRef>(const Expr& expr, const AffineTypeMap& map)>;

class SubgraphExtractor : public ExprVisitor {
public:
const ExprSet GetSubgraph(const Expr& expr) {
VisitExpr(expr);
ExprSet subgraph;
if (is_fake_quantized_) {
for (auto kv : this->visit_counter_) {
if (auto call_node = GetRef<ObjectRef>(kv.first).as<CallNode>()) {
if (call_node->op != quantize_op_) {
subgraph.insert(Downcast<Expr>(GetRef<ObjectRef>(kv.first)));
}
const ExprSet SubgraphExtractor::GetSubgraph(const Expr& expr) {
VisitExpr(expr);
ExprSet subgraph;
if (is_fake_quantized_) {
for (auto kv : this->visit_counter_) {
if (auto call_node = GetRef<ObjectRef>(kv.first).as<CallNode>()) {
if (call_node->op != quantize_op_) {
subgraph.insert(Downcast<Expr>(GetRef<ObjectRef>(kv.first)));
}
}
}
return subgraph;
}
const AffineTypeMap GetAffineTypes() { return affine_types_; }
void VisitExpr(const Expr& expr) override {
// When looking for fake quantized subgraphs, we only support data-flow regions of the graph,
// i.e. call nodes/tuples/constants/etc. If we see anything else (like control flow) we
// abort the rewrite.
if (expr.as<CallNode>() == nullptr && expr.as<OpNode>() == nullptr &&
expr.as<TupleNode>() == nullptr && expr.as<TupleGetItemNode>() == nullptr &&
expr.as<ConstantNode>() == nullptr) {
DLOG(INFO) << "FakeQuantizationToInteger found a non-dataflow op inside"
<< " a fake quantize region, aborting this rewrite";
is_fake_quantized_ = false;
} else {
ExprVisitor::VisitExpr(expr);
}
return subgraph;
}
const AffineTypeMap SubgraphExtractor::GetAffineTypes() { return affine_types_; }
void SubgraphExtractor::VisitExpr(const Expr& expr) {
// When looking for fake quantized subgraphs, we only support data-flow regions of the graph,
// i.e. call nodes/tuples/constants/etc. If we see anything else (like control flow) we
// abort the rewrite.
if (expr.as<CallNode>() == nullptr && expr.as<OpNode>() == nullptr &&
expr.as<TupleNode>() == nullptr && expr.as<TupleGetItemNode>() == nullptr &&
expr.as<ConstantNode>() == nullptr) {
DLOG(INFO) << "FakeQuantizationToInteger found a non-dataflow op inside"
<< " a fake quantize region, aborting this rewrite";
is_fake_quantized_ = false;
} else {
ExprVisitor::VisitExpr(expr);
}
}

protected:
void VisitExpr_(const CallNode* call_node) override {
if (call_node->op == quantize_op_) {
const auto* attrs = call_node->attrs.as<qnn::QuantizeAttrs>();
ICHECK(attrs != nullptr);
// Only look at arg0 for quantize
VisitExpr(call_node->args[0]);
// Collect type of quantize ops
affine_types_.Set(
GetRef<Expr>(call_node),
TensorAffineType(call_node->args[1], call_node->args[2], attrs->out_dtype, attrs->axis));
} else if (call_node->op == dequantize_op_) {
const auto* attrs = call_node->attrs.as<qnn::DequantizeAttrs>();
ICHECK(attrs != nullptr);
// Collect type of dequantize ops
affine_types_.Set(
GetRef<Expr>(call_node),
TensorAffineType(call_node->args[1], call_node->args[2],
call_node->args[0]->checked_type().as<TensorTypeNode>()->dtype,
attrs->axis));
} else {
// run normally on everything else.
ExprVisitor::VisitExpr_(call_node);
}
void SubgraphExtractor::VisitExpr_(const CallNode* call_node) {
const Op test_op = Downcast<Op>(call_node->op);
if (call_node->op == quantize_op_) {
const auto* attrs = call_node->attrs.as<qnn::QuantizeAttrs>();
ICHECK(attrs != nullptr);
// Only look at arg0 for quantize
VisitExpr(call_node->args[0]);
// Collect type of quantize ops
affine_types_.Set(
GetRef<Expr>(call_node),
TensorAffineType(call_node->args[1], call_node->args[2], attrs->out_dtype, attrs->axis));
} else if (call_node->op == dequantize_op_) {
const auto* attrs = call_node->attrs.as<qnn::DequantizeAttrs>();
ICHECK(attrs != nullptr);
// Collect type of dequantize ops
affine_types_.Set(
GetRef<Expr>(call_node),
TensorAffineType(call_node->args[1], call_node->args[2],
call_node->args[0]->checked_type().as<TensorTypeNode>()->dtype,
attrs->axis));
} else {
// run normally on everything else.
ExprVisitor::VisitExpr_(call_node);
}

const Op quantize_op_ = Op::Get("qnn.quantize");
const Op dequantize_op_ = Op::Get("qnn.dequantize");
bool is_fake_quantized_ = true;
AffineTypeMap affine_types_;
};
}

class SubgraphMutator : public ExprMutator {
public:
Expand Down
54 changes: 54 additions & 0 deletions src/relay/transforms/fake_quantization_to_integer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file src/relay/transforms/fake_quantization_to_integer.h
* \brief Extract subgraph of a fake quantized region.
*/
#ifndef TVM_RELAY_TRANSFORMS_FAKE_QUANTIZATION_TO_INTEGER_H_
#define TVM_RELAY_TRANSFORMS_FAKE_QUANTIZATION_TO_INTEGER_H_

#include <tvm/ir/affine_type.h>
#include <tvm/relay/expr_functor.h>

#include <unordered_set>

namespace tvm {
namespace relay {

class SubgraphExtractor : public ExprVisitor {
public:
const std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual> GetSubgraph(const Expr& expr);
const Map<Expr, AffineType> GetAffineTypes();
void VisitExpr(const Expr& expr) override;

protected:
void VisitExpr_(const CallNode* call_node) override;

private:
const Op quantize_op_ = Op::Get("qnn.quantize");
const Op dequantize_op_ = Op::Get("qnn.dequantize");
bool is_fake_quantized_ = true;
Map<Expr, AffineType> affine_types_;
};

} // namespace relay
} // namespace tvm

#endif // TVM_RELAY_TRANSFORMS_FAKE_QUANTIZATION_TO_INTEGER_H_
Loading

0 comments on commit 1a78096

Please sign in to comment.