adapting apache#11642
liaopeiyuan committed Oct 1, 2022
1 parent 7abc68e commit 01989e5
Showing 3 changed files with 347 additions and 10 deletions.
268 changes: 267 additions & 1 deletion python/tvm/relay/op/contrib/tachikoma.py
@@ -35,10 +35,12 @@
import logging

import tvm.ir
from tvm import relay
from tvm.relay import transform
from tvm.relay.build_module import bind_params_by_name
from tvm.relay.expr import const

from ...dataflow_pattern import wildcard, is_op, is_constant, is_expr, rewrite, DFPatternCallback
from .register import register_pattern_table

logger = logging.getLogger("Tachikoma")
@@ -59,6 +61,10 @@ def _register_external_op_helper(op_name, supported=True):

    @tvm.ir.register_op_attr(op_name, "target.tachikoma")
    def _func_wrapper(expr):
        args = expr.args
        if any(x.checked_type.dtype == "int64" for x in args):
            logger.info("Tachikoma does not support int64.")
            return False
        return supported

    return _func_wrapper
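For reference, this helper is invoked elsewhere in the file (outside this hunk) once per simple op. A hedged sketch; the op names below are illustrative, not the commit's actual registration list:

# Hypothetical registrations; the real list lives outside this hunk.
_register_external_op_helper("nn.relu")
_register_external_op_helper("nn.softmax", supported=False)

With the int64 guard above, any call whose arguments include an int64 tensor is rejected at partition time rather than failing inside the Tachikoma runtime.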
@@ -174,6 +180,68 @@ def make_tachikoma_pattern(op, with_bias, with_eltwise):
        tachikoma_pattern = ()
    return tachikoma_pattern
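A minimal sketch of calling this factory; the return shape mirrors the qnn factories below, and the concrete pattern name is an assumption:

# Hedged example: assumes the usual "<target>.<op>[_bias][_<eltwise>]" naming.
pattern_name, pattern = make_tachikoma_pattern("nn.conv2d", with_bias=True, with_eltwise="nn.relu")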

def make_qnn_conv2d_pattern():
    """Make qnn.conv2d based pattern supported by Tachikoma

    Returns
    -------
    pattern : Tuple(pattern_name, CallPattern)
        Created pattern name, along with its CallPattern.
    """
    data = wildcard()
    weight = is_constant()
    bias = is_constant()
    o_scl = is_constant()
    dst_zp = is_constant()
    act_scl = is_constant()
    sum_scl = is_constant()
    sum_src = wildcard()

    zero_zp = is_expr(const(0, dtype="int32"))

    pat = is_op("qnn.conv2d")(data, weight, zero_zp, zero_zp, is_constant(), is_constant())
    pat = is_op("cast")(pat)
    pat = is_op("add")(pat, bias) | pat  # optional bias
    pat = is_op("multiply")(pat, o_scl)
    pat = is_op("clip")(pat)  # TBD: the activation is not necessarily clip
    pat = is_op("multiply")(pat, act_scl) | pat  # optional multiply, e.g. when act_scl == 1
    pat = is_op("add")(pat, sum_scl * is_op("cast")(sum_src)) | pat  # optional sum
    pat = is_op("add")(pat, dst_zp) | pat  # optional dst_zp, can be dst_zp == 0
    pat = is_op("cast")(pat)

    return "tachikoma.qnn.conv2d", pat



def make_qnn_dense_pattern():
    """Make qnn.dense based pattern supported by Tachikoma

    Returns
    -------
    pattern : Tuple(pattern_name, CallPattern)
        Created pattern name, along with its CallPattern.
    """
    data = wildcard()
    weight = is_constant()
    bias = is_constant()
    o_scl = is_constant()
    dst_zp = is_constant()
    act_scl = is_constant()
    sum_scl = is_constant()
    sum_src = wildcard()

    zero_zp = is_expr(const(0, dtype="int32"))

    pat = is_op("qnn.dense")(data, weight, zero_zp, zero_zp, is_constant(), is_constant())
    pat = is_op("cast")(pat)
    pat = is_op("add")(pat, bias) | pat  # optional bias
    pat = is_op("multiply")(pat, o_scl)
    pat = is_op("clip")(pat)  # TBD: the activation is not necessarily clip
    pat = is_op("multiply")(pat, act_scl) | pat  # optional multiply, e.g. when act_scl == 1
    pat = is_op("add")(pat, sum_scl * is_op("cast")(sum_src)) | pat  # optional sum
    pat = is_op("add")(pat, dst_zp) | pat  # optional dst_zp, can be dst_zp == 0
    pat = is_op("cast")(pat)

    return "tachikoma.qnn.dense", pat


@register_pattern_table("tachikoma")
def pattern_table():
@@ -185,6 +253,9 @@ def pattern_table():
"""
elt_list = ["nn.relu", "tanh", "sigmoid", None]
tachikoma_patterns = []
tachikoma_patterns.append(make_qnn_conv2d_pattern())
tachikoma_patterns.append(make_qnn_dense_pattern())

for with_bias in [True, False]:
for elt in elt_list:
if not with_bias and not elt:
@@ -200,6 +271,198 @@ def pattern_table():
            tachikoma_patterns.append(make_tachikoma_pattern("nn.dense", with_bias, elt))
    return tachikoma_patterns
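For context, a pattern table registered this way feeds the standard TVM BYOC partitioning passes. A minimal sketch of that flow, assuming some quantized IRModule mod:

from tvm.relay.op.contrib import get_pattern_table

table = get_pattern_table("tachikoma")            # the table registered above
mod = transform.MergeComposite(table)(mod)        # fuse matched patterns into composites
mod = transform.AnnotateTarget("tachikoma")(mod)  # mark supported regions
mod = transform.PartitionGraph()(mod)             # split them into external functions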

class LegalizeQnnOpForTachikoma(DFPatternCallback):
    """Legalize QNN based patterns to match Tachikoma

    original pattern:
      OP = qnn.dense | qnn.conv2d
      %1 = OP<int>(SRC, WGH) - OP<int>(src_zp, WGH)               // qnn.conv2d / qnn.dense
      %2 = %1 + orig_bias                                         // bias
      %3 = (%2 - rq_in_zp) * rq_in_scl / rq_out_scl + rq_out_zp   // qnn.requantize
      %4 = act(%3)                                                // activation == clip
      %5 = ((%4 - sum_lhs_zp) * sum_lhs_scl + (SRC2 - sum_rhs_zp) * sum_rhs_scl)  // qnn.add
           / sum_out_scl + sum_out_zp

    transform to Tachikoma compatible:
      %1 = OP<int>(SRC, WGH)
      %2 = cast(%1, dtype="float")
      %3 = (%2 + bias) * o_scl
      %4 = act(%3) * act_scl
      %5 = %4 + SRC2 * sum_scl
      %6 = %5 + dst_zp
      %7 = cast(%6, dtype=<final dtype of the matched pattern>)

    where:
      o_scl = rq_in_scl / rq_out_scl
      act_scl = sum_lhs_scl / sum_out_scl
      sum_scl = sum_rhs_scl / sum_out_scl
      bias = orig_bias - OP(src_zp, WGH) - rq_in_zp + rq_out_zp * rq_out_scl / rq_in_scl
      dst_zp = sum_out_zp - sum_lhs_zp * sum_lhs_scl / sum_out_scl -
               sum_rhs_zp * sum_rhs_scl / sum_out_scl
    """

    def __init__(self):
        super(LegalizeQnnOpForTachikoma, self).__init__()
        self.src = wildcard()
        self.wgh = wildcard()
        self.bias = wildcard()
        self.sum_src = wildcard()

        self.src_scl = is_constant()
        self.src_zp = is_constant()
        self.wgh_scl = is_constant()
        self.wgh_zp = is_expr(const(0))

        self.rq_in_scl = is_constant()
        self.rq_in_zp = is_constant()
        self.rq_out_scl = is_constant()
        self.rq_out_zp = is_constant()

        self.sum_lhs_scl = is_constant()
        self.sum_lhs_zp = is_constant()
        self.sum_rhs_scl = is_constant()
        self.sum_rhs_zp = is_constant()
        self.sum_out_scl = is_constant()
        self.sum_out_zp = is_constant()

        self.root = (is_op("qnn.conv2d") | is_op("qnn.dense"))(
            self.src, self.wgh, self.src_zp, self.wgh_zp, self.src_scl, self.wgh_scl
        )
        pat = is_op("add")(self.root, self.bias) | self.root  # optional bias
        pat = is_op("qnn.requantize")(
            pat, self.rq_in_scl, self.rq_in_zp, self.rq_out_scl, self.rq_out_zp
        )
        pat = is_op("clip")(pat)
        cast = is_op("cast")(pat)
        pat = is_op("qnn.add")(
            cast,
            self.sum_src,
            self.sum_lhs_scl,
            self.sum_lhs_zp,
            self.sum_rhs_scl,
            self.sum_rhs_zp,
            self.sum_out_scl,
            self.sum_out_zp,
        )
        pat = is_op("clip")(pat)
        self.pattern = pat | cast

    def callback(self, pre, post, node_map):
        root = node_map[self.root][0]
        src = node_map[self.src][0]
        wgh = node_map[self.wgh][0]
        bias = node_map.get(self.bias, default=[relay.const(0, dtype="int32")])[0]
        src_zp = node_map[self.src_zp][0]
        rq_in_scl = node_map[self.rq_in_scl][0]
        rq_in_zp = node_map[self.rq_in_zp][0]
        rq_out_scl = node_map[self.rq_out_scl][0]
        rq_out_zp = node_map[self.rq_out_zp][0]

        final_dtype = node_map[self.pattern][0].checked_type.dtype

        if root.op == relay.op.get("qnn.conv2d"):
            dst_layout = root.attrs.out_layout
            dst_layout = root.attrs.data_layout if dst_layout == "" else dst_layout
            wgh_layout = root.attrs.kernel_layout
        else:
            # qnn.dense has no layout attributes. Assume the layout is plain.
            dst_layout = "NC"
            wgh_layout = "OI"

        # TODO(@apeskov): dst_layout may be blocked
        bias_rank = len(dst_layout) - dst_layout.index("C")

        sum_src = node_map[self.sum_src][0] if self.sum_src in node_map else None
        # Default values if qnn.add is not present
        sum_lhs_scl = node_map[self.sum_lhs_scl][0] if sum_src else relay.const(1, dtype="float32")
        sum_lhs_zp = node_map[self.sum_lhs_zp][0] if sum_src else relay.const(0, dtype="int32")
        sum_rhs_scl = node_map[self.sum_rhs_scl][0] if sum_src else relay.const(0, dtype="float32")
        sum_rhs_zp = node_map[self.sum_rhs_zp][0] if sum_src else relay.const(0, dtype="int32")
        sum_out_scl = node_map[self.sum_out_scl][0] if sum_src else relay.const(1, dtype="float32")
        sum_out_zp = node_map[self.sum_out_zp][0] if sum_src else relay.const(0, dtype="int32")

        def cast_fp(op):
            return relay.op.cast(op, dtype="float32")

        # recalculate the affine factors (see the class docstring)
        o_scl = rq_in_scl / rq_out_scl
        act_scl = sum_lhs_scl / sum_out_scl
        sum_scl = sum_rhs_scl / sum_out_scl
        dst_zp = (
            cast_fp(sum_out_zp)
            - cast_fp(sum_lhs_zp) * sum_lhs_scl / sum_out_scl
            - cast_fp(sum_rhs_zp) * sum_rhs_scl / sum_out_scl
        )
        bias = self.squeeze_bias(bias, dst_layout)
        bias = (
            cast_fp(bias)
            - cast_fp(self.fake_op(src_zp, wgh, wgh_layout))
            - cast_fp(rq_in_zp)
            + cast_fp(rq_out_zp) * rq_out_scl / rq_in_scl
        )
        bias = self.broadcast_to_rank(bias, bias_rank)

        zero_zp = relay.const(0, dtype="int32")
        one_scl = relay.const(1.0, dtype="float32")

        # construct new graph with proper post op ordering
        gr = tvm.relay.Call(
            root.op,
            [src, wgh, zero_zp, zero_zp, one_scl, one_scl],
            root.attrs,
            root.type_args,
            root.span,
        )
        gr = relay.op.cast(gr, dtype="float32")
        gr = gr + bias
        gr = gr * o_scl
        gr = relay.op.clip(gr, 0, 255) * act_scl
        gr = gr + sum_scl * cast_fp(sum_src) if sum_src else gr
        gr = gr + dst_zp
        gr = relay.op.cast(gr, dtype=final_dtype)
        return gr

    @staticmethod
    def fake_op(zp, wgh, layout):
        """Fake operator implementation for zp broadcast input"""
        # Conv: reduce kernel {OC, IC, KH, KW} -> {OC}; this remains correct for grouped conv
        # Dense: reduce kernel {OC, IC} -> {OC}
        wgh_int = relay.op.cast(wgh, dtype="int32")
        reduced_kernel = relay.op.sum(
            wgh_int, axis=[layout.index("O")], keepdims=False, exclude=True
        )
        return zp * reduced_kernel

    @staticmethod
    def squeeze_bias(bias, layout):
        shape = transform.InferTypeLocal(bias).concrete_shape
        c_position = layout.index("C") - len(layout) + len(shape)
        squeeze_idxs = [i for i in range(len(shape)) if i != c_position]
        return relay.op.squeeze(bias, squeeze_idxs)

    @staticmethod
    def broadcast_to_rank(op, rank):
        """Scalar or 1D tensors are supported"""
        shape = transform.InferTypeLocal(op).concrete_shape
        if len(shape) == 0:
            return op
        if len(shape) == 1:
            return relay.op.expand_dims(op, 1, rank - 1)
        raise ValueError("Unexpected bias rank to broadcast. Only 0 and 1 are supported.")
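The factor algebra in the class docstring can be checked with plain numbers. A minimal, TVM-free sanity check with arbitrary values and an identity activation:

conv, zp_conv, orig_bias = 100.0, 7.0, 3.0  # OP(SRC, WGH), OP(src_zp, WGH), original bias
rq_in_scl, rq_in_zp, rq_out_scl, rq_out_zp = 0.5, 2.0, 0.25, 10.0
src2 = 40.0
sum_lhs_scl, sum_lhs_zp = 0.3, 1.0
sum_rhs_scl, sum_rhs_zp = 0.6, 4.0
sum_out_scl, sum_out_zp = 0.2, 5.0

# original graph (activation treated as identity)
v = conv - zp_conv + orig_bias
v = (v - rq_in_zp) * rq_in_scl / rq_out_scl + rq_out_zp
orig = ((v - sum_lhs_zp) * sum_lhs_scl + (src2 - sum_rhs_zp) * sum_rhs_scl) / sum_out_scl + sum_out_zp

# legalized graph with the recomputed factors
o_scl = rq_in_scl / rq_out_scl
act_scl = sum_lhs_scl / sum_out_scl
sum_scl = sum_rhs_scl / sum_out_scl
bias = orig_bias - zp_conv - rq_in_zp + rq_out_zp * rq_out_scl / rq_in_scl
dst_zp = sum_out_zp - sum_lhs_zp * act_scl - sum_rhs_zp * sum_scl
new = (conv + bias) * o_scl * act_scl + src2 * sum_scl + dst_zp

assert abs(orig - new) < 1e-9  # both evaluate to 408.5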

def legalize_qnn_for_tachikoma(mod):
    """Transform qnn primitives to Tachikoma compatible form. Eliminate the source zero point
    and apply a strict sequence of post ops."""
    mod["main"] = rewrite(LegalizeQnnOpForTachikoma(), mod["main"])

    seq = tvm.transform.Sequential(
        [
            transform.InferType(),
            # transform.SimplifyInference(),  # TODO: this pass decomposes nn.layer_norm
            # transform.FoldScaleAxis(),  # TODO: fails inside TVM in case of grouped convolutions
            transform.FoldConstant(),
        ]
    )
    with tvm.transform.PassContext(opt_level=3):
        mod = seq(mod)
    return mod
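Standalone usage is a one-liner. A sketch, given a quantized IRModule mod; note that partition_for_tachikoma below already applies the same rewrite internally, so calling both on one module would apply it twice:

mod = legalize_qnn_for_tachikoma(mod)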

def partition_for_tachikoma(mod, params=None):
"""Partition the graph greedily offloading supported operators to Tachikoma.
@@ -217,6 +480,9 @@ def partition_for_tachikoma(mod, params=None):

    if params:
        mod["main"] = bind_params_by_name(mod["main"], params)

    mod["main"] = rewrite(LegalizeQnnOpForTachikoma(), mod["main"])

    seq = tvm.transform.Sequential(
        [
            transform.CanonicalizeOps(),
7 changes: 7 additions & 0 deletions src/relay/backend/contrib/tachikoma/codegen.cc
@@ -37,6 +37,8 @@
#include "../../utils.h"
#include "comp_op_matcher.h"

#define USE_JSON_RUNTIME

#ifdef USE_JSON_RUNTIME
#include "../../../../runtime/contrib/json/json_node.h"
#include "../codegen_json/codegen_json.h"
@@ -524,6 +526,11 @@ class TachikomaJSONSerializer : public backend::contrib::JSONSerializer {
      std::vector<std::string> op_list = ParsingOpList("dense", name);
      call = GetRootCall(fn->body.as<CallNode>(), op_list.size() - 1, op_list);
      ICHECK(call->op.as<OpNode>()) << "Not op node";
    } else if (name.find("tachikoma.qnn.conv2d") != std::string::npos ||
               name.find("tachikoma.qnn.dense") != std::string::npos) {
      std::vector<Expr> args_loc;
      call = ParseComposite(*fn, &extra_attrs, &args_loc);
      args = BindToCallNodeArgs(args_loc, cn);
    } else {
      LOG(FATAL) << "Unrecognized tachikoma pattern: " << name;
    }