Add qnn batch_matmul operator (apache#8401)
* Add qnn batch_matmul operator

- add support for a different out type in x86 batch_matmul

* Fix code style

* Add out_dtype to generic batch_matmul

* Restore fix in batch_matmul for dynamic shapes

* Fix documentation for qnn.batch_matmul

* Remove debug code

* Modify zero point for qnn batch_matmul test
elvin-n authored and ylc committed Sep 29, 2021
1 parent 1ef2132 commit 74baf1c
Showing 8 changed files with 600 additions and 67 deletions.
6 changes: 4 additions & 2 deletions python/tvm/relay/op/strategy/x86.py
@@ -521,14 +521,16 @@ def batch_matmul_strategy_cpu(attrs, inputs, out_type, target):
     strategy = _op.OpStrategy()
     if is_dynamic(out_type) or is_auto_scheduler_enabled():
         strategy.add_implementation(
-            wrap_compute_batch_matmul(topi.nn.batch_matmul, need_auto_scheduler_layout=True),
+            wrap_compute_batch_matmul(
+                topi.nn.batch_matmul, need_auto_scheduler_layout=True, need_out_dtype=True
+            ),
             wrap_topi_schedule(topi.generic.nn.schedule_batch_matmul),
             name="batch_matmul.generic",
             plevel=10,
         )
     else:
         strategy.add_implementation(
-            wrap_compute_batch_matmul(topi.x86.batch_matmul),
+            wrap_compute_batch_matmul(topi.x86.batch_matmul, need_out_dtype=True),
             wrap_topi_schedule(topi.x86.schedule_batch_matmul),
             name="batch_matmul.x86",
             plevel=10,
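For orientation, a minimal sketch of how a need_out_dtype flag typically threads the inferred output dtype from the strategy into the TOPI compute. This is an illustration only, not the actual wrap_compute_batch_matmul helper in python/tvm/relay/op/strategy/generic.py, whose details may differ:

def wrap_compute_batch_matmul_sketch(topi_compute, need_out_dtype=False):
    # Returns a compute callback that forwards the two inputs, the inferred
    # output shape and, when requested, the inferred output dtype to the kernel.
    def _compute(attrs, inputs, out_type):
        args = [inputs[0], inputs[1], out_type.shape]
        if need_out_dtype:
            args.append(out_type.dtype)
        return [topi_compute(*args)]

    return _compute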
38 changes: 38 additions & 0 deletions python/tvm/relay/qnn/op/qnn.py
@@ -682,6 +682,44 @@ def subtract(
     )
 
 
+def batch_matmul(x, y, x_zero_point, y_zero_point, x_scale, y_scale, out_dtype="int32"):
+    r"""
+    Computes batch matrix multiplication of `x` and `y` when `x` and `y` are data
+    in batch.
+
+    .. math::
+
+        \mbox{batch_matmul}(x, y)[i, :, :] = \mbox{matmul}(x[i, :, :], y[i, :, :]^T)
+
+    Parameters
+    ----------
+    x : tvm.relay.Expr
+        The first quantized input.
+        A quantized tensor is represented in the following manner:
+        `A = scale_a x (QA - zp_A)`
+        where QA is the quantized tensor, and scale_a and zp_A are the
+        quantization params.
+    y : tvm.relay.Expr
+        The second quantized input.
+    x_zero_point: tvm.relay.Expr
+        The first input zero point.
+    y_zero_point: tvm.relay.Expr
+        The second input zero point.
+    x_scale: tvm.relay.Expr
+        The scale for the first input tensor.
+    y_scale: tvm.relay.Expr
+        The scale for the second input tensor.
+    out_dtype : str, optional
+        Specifies the output data type for mixed precision; can be int32 or int16.
+
+    Returns
+    -------
+    result: tvm.relay.Expr
+        The computed result.
+    """
+    return _make.batch_matmul(x, y, x_zero_point, y_zero_point, x_scale, y_scale, out_dtype)
+
+
 # register fuse pattern for qnn ops
 reg.register_pattern("qnn.quantize", OpPattern.OPAQUE)
 reg.register_pattern("qnn.dequantize", OpPattern.OPAQUE)
28 changes: 21 additions & 7 deletions python/tvm/topi/nn/batch_matmul.py
@@ -21,7 +21,7 @@
 from ..utils import get_const_tuple
 
 
-def batch_matmul(x, y, oshape=None, auto_scheduler_rewritten_layout=""):
+def batch_matmul(x, y, oshape=None, auto_scheduler_rewritten_layout="", out_dtype=None):
     """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
     data in batch. Supports broadcasting for batch dimension.
@@ -67,12 +67,26 @@ def batch_matmul(x, y, oshape=None, auto_scheduler_rewritten_layout=""):
         N = y.shape[1]
         oshape = (batch, M, N)
 
-    output = te.compute(
-        oshape,
-        lambda b, i, j: te.sum(x[b if XB != 1 else 0, i, k] * y[b if YB != 1 else 0, j, k], axis=k),
-        tag="batch_matmul",
-        attrs={"layout_free_placeholders": [y]},
-    )
+    if out_dtype is None or out_dtype == x.dtype:
+        output = te.compute(
+            oshape,
+            lambda b, i, j: te.sum(
+                x[b if XB != 1 else 0, i, k] * y[b if YB != 1 else 0, j, k], axis=k
+            ),
+            tag="batch_matmul",
+            attrs={"layout_free_placeholders": [y]},
+        )
+    else:
+        output = te.compute(
+            oshape,
+            lambda b, i, j: te.sum(
+                x[b if XB != 1 else 0, i, k].astype(out_dtype)
+                * y[b if YB != 1 else 0, j, k].astype(out_dtype),
+                axis=k,
+            ),
+            tag="batch_matmul",
+            attrs={"layout_free_placeholders": [y]},
+        )
 
     if auto_scheduler_rewritten_layout:
         output = auto_scheduler.rewrite_compute_body(output, auto_scheduler_rewritten_layout)
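A standalone sketch of the pattern introduced in the else branch above (not part of the diff; sizes are illustrative): casting both operands to the requested out_dtype inside the reduction makes the multiply-accumulate happen in int32 rather than int8.

import tvm
from tvm import te

batch, M, N, K = 2, 4, 6, 8  # illustrative sizes
x = te.placeholder((batch, M, K), name="x", dtype="int8")
y = te.placeholder((batch, N, K), name="y", dtype="int8")
k = te.reduce_axis((0, K), name="k")
# Products and partial sums are computed in int32, so they cannot wrap around.
out = te.compute(
    (batch, M, N),
    lambda b, i, j: te.sum(x[b, i, k].astype("int32") * y[b, j, k].astype("int32"), axis=k),
    name="batch_matmul_i32",
)
s = te.create_schedule(out.op)
f = tvm.build(s, [x, y, out], target="llvm")  # assumes an llvm-enabled TVM build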
25 changes: 19 additions & 6 deletions python/tvm/topi/x86/batch_matmul.py
@@ -25,7 +25,7 @@
 
 
 @autotvm.register_topi_compute("batch_matmul.x86")
-def batch_matmul(cfg, x, y, out_shape=None):
+def batch_matmul(cfg, x, y, out_shape=None, out_dtype=None):
     """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
     data in batch. Supports broadcasting in batch dimension.
@@ -60,11 +60,24 @@ def batch_matmul(cfg, x, y, out_shape=None):
         _default_batch_matmul_config(cfg, M, N, K)
 
     k = te.reduce_axis((0, K), name="k")
-    C = te.compute(
-        (B, M, N),
-        lambda b, i, j: te.sum(x[b if XB != 1 else 0, i, k] * y[b if YB != 1 else 0, j, k], axis=k),
-        tag="batch_matmul",
-    )
+    if out_dtype is None or out_dtype == x.dtype:
+        C = te.compute(
+            (B, M, N),
+            lambda b, i, j: te.sum(
+                x[b if XB != 1 else 0, i, k] * y[b if YB != 1 else 0, j, k], axis=k
+            ),
+            tag="batch_matmul",
+        )
+    else:
+        C = te.compute(
+            (B, M, N),
+            lambda b, i, j: te.sum(
+                x[b if XB != 1 else 0, i, k].astype(out_dtype)
+                * y[b if YB != 1 else 0, j, k].astype(out_dtype),
+                axis=k,
+            ),
+            tag="batch_matmul",
+        )
     return C
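Why int32 accumulation matters for quantized inputs: an illustrative numpy check, not part of the commit.

import numpy as np

x = np.full((8,), 127, dtype="int8")
y = np.full((8,), 127, dtype="int8")

# Multiplying in int8 wraps: 127 * 127 = 16129 truncates to 1 in 8 bits,
# so the dot product comes out as 8 instead of 129032.
wrong = (x * y).sum()
# Casting to int32 first (what the out_dtype branch above does) is exact.
right = (x.astype("int32") * y.astype("int32")).sum()
print(wrong, right)  # 8 129032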
53 changes: 1 addition & 52 deletions src/relay/op/nn/nn.cc
@@ -935,57 +935,6 @@ If the input has size k on axis 1, then both gamma and beta have shape (k,).
 // relay.nn.batch_matmul
 TVM_REGISTER_NODE_TYPE(BatchMatmulAttrs);
 
-bool BatchMatmulRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
-                    const TypeReporter& reporter) {
-  ICHECK_EQ(types.size(), 3);
-  const auto* x = types[0].as<TensorTypeNode>();
-  const auto* y = types[1].as<TensorTypeNode>();
-  if (x == nullptr || y == nullptr) return false;
-
-  const auto* param = attrs.as<BatchMatmulAttrs>();
-  Array<PrimExpr> y_shape;
-  if (param->auto_scheduler_rewritten_layout.size() == 0) {
-    y_shape = y->shape;
-  } else {
-    y_shape = auto_scheduler::GetShapeFromRewrittenLayout(param->auto_scheduler_rewritten_layout,
-                                                          {"b", "j", "k"});
-  }
-
-  ICHECK(x->shape.size() == 3 && y_shape.size() == 3);
-  bool is_dyn = false;
-  Array<tvm::PrimExpr> oshape;
-  for (size_t i = 0; i < 3; ++i) {
-    if (x->shape[i].as<tir::AnyNode>() != nullptr || y_shape[i].as<tir::AnyNode>() != nullptr) {
-      is_dyn = true;
-      oshape.push_back(Any());
-    } else {
-      if (i == 0) {
-        oshape.push_back(max(x->shape[i], y_shape[i]));
-      } else {
-        oshape.push_back(x->shape[i]);
-      }
-    }
-  }
-  if (!is_dyn) {
-    ICHECK(reporter->AssertEQ(x->shape[0], y_shape[0]) || reporter->AssertEQ(x->shape[0], 1) ||
-           reporter->AssertEQ(y_shape[0], 1))
-        << "BatchDot: batch dimensions don't match, "
-        << " x shape=" << x->shape << ", y shape=" << y_shape;
-    ICHECK(reporter->AssertEQ(x->shape[2], y_shape[2]))
-        << "BatchDot: shapes of x and y is inconsistent, "
-        << " x shape=" << x->shape << ", y shape=" << y_shape;
-  }
-  oshape.Set(2, y_shape[1]);
-
-  DataType out_dtype = param->out_dtype;
-  if (out_dtype.bits() == 0) {
-    out_dtype = x->dtype;
-  }
-  // assign output type
-  reporter->Assign(types[2], TensorType(oshape, out_dtype));
-  return true;
-}
-
 // Positional relay function to create batch_matmul operator used by frontend FFI.
 Expr MakeBatchMatmul(Expr x, Expr y, DataType out_dtype) {
   auto attrs = make_object<BatchMatmulAttrs>();
@@ -1013,7 +962,7 @@ are data in batch.
     .add_argument("x", "3D Tensor", "First input.")
     .add_argument("y", "3D Tensor", "Second input.")
     .set_support_level(10)
-    .add_type_rel("BatchMatmul", BatchMatmulRel);
+    .add_type_rel("BatchMatmul", BatchMatmulRel<BatchMatmulAttrs>);
 
 // relay.nn.cross_entropy
 bool CrossEntropyRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
54 changes: 54 additions & 0 deletions src/relay/op/nn/nn.h
@@ -24,10 +24,12 @@
 #ifndef TVM_RELAY_OP_NN_NN_H_
 #define TVM_RELAY_OP_NN_NN_H_
 
+#include <tvm/auto_scheduler/compute_dag.h>
 #include <tvm/ir/attrs.h>
 #include <tvm/ir/expr.h>
 #include <tvm/relay/type.h>
 
+#include <algorithm>
 #include <utility>
 
 #include "../op_common.h"
@@ -137,6 +139,58 @@ bool DensePackRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   return true;
 }
 
+template <typename AttrType>
+bool BatchMatmulRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                    const TypeReporter& reporter) {
+  ICHECK_EQ(types.size(), 3);
+  const auto* x = types[0].as<TensorTypeNode>();
+  const auto* y = types[1].as<TensorTypeNode>();
+  if (x == nullptr || y == nullptr) return false;
+
+  const AttrType* param = attrs.as<AttrType>();
+  Array<PrimExpr> y_shape;
+  if (param->auto_scheduler_rewritten_layout.size() == 0) {
+    y_shape = y->shape;
+  } else {
+    y_shape = auto_scheduler::GetShapeFromRewrittenLayout(param->auto_scheduler_rewritten_layout,
+                                                          {"b", "j", "k"});
+  }
+
+  ICHECK(x->shape.size() == 3 && y_shape.size() == 3);
+  bool is_dyn = false;
+  Array<tvm::PrimExpr> oshape;
+  for (size_t i = 0; i < 3; ++i) {
+    if (x->shape[i].as<tir::AnyNode>() != nullptr || y_shape[i].as<tir::AnyNode>() != nullptr) {
+      is_dyn = true;
+      oshape.push_back(Any());
+    } else {
+      if (i == 0) {
+        oshape.push_back(max(x->shape[i], y_shape[i]));
+      } else {
+        oshape.push_back(x->shape[i]);
+      }
+    }
+  }
+  if (!is_dyn) {
+    ICHECK(reporter->AssertEQ(x->shape[0], y_shape[0]) || reporter->AssertEQ(x->shape[0], 1) ||
+           reporter->AssertEQ(y_shape[0], 1))
+        << "BatchDot: batch dimensions don't match, "
+        << " x shape=" << x->shape << ", y shape=" << y_shape;
+    ICHECK(reporter->AssertEQ(x->shape[2], y_shape[2]))
+        << "BatchDot: shapes of x and y is inconsistent, "
+        << " x shape=" << x->shape << ", y shape=" << y_shape;
+  }
+  oshape.Set(2, y_shape[1]);
+
+  DataType out_dtype = param->out_dtype;
+  if (out_dtype.bits() == 0) {
+    out_dtype = x->dtype;
+  }
+  // assign output type
+  reporter->Assign(types[2], TensorType(oshape, out_dtype));
+  return true;
+}
+
 }  // namespace relay
 }  // namespace tvm
 #endif  // TVM_RELAY_OP_NN_NN_H_
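For readability, the relation above restated in Python (illustrative only; the C++ template is authoritative): x of shape (B_x, M, K) and y of shape (B_y, N, K) must have matching or broadcastable (== 1) batch dimensions and equal reduction dimensions, and the result has shape (max(B_x, B_y), M, N) with dtype out_dtype, falling back to x's dtype.

def batch_matmul_out_type(x_shape, y_shape, x_dtype, out_dtype=None):
    """Illustrative restatement of BatchMatmulRel for static shapes."""
    bx, m, kx = x_shape
    by, n, ky = y_shape
    assert bx == by or bx == 1 or by == 1, "batch dimensions don't match"
    assert kx == ky, "reduction dimensions of x and y are inconsistent"
    return (max(bx, by), m, n), out_dtype or x_dtype

# e.g. batch_matmul_out_type((1, 4, 16), (3, 8, 16), "int8", "int32") -> ((3, 4, 8), "int32")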