Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Topi & Relay] Add quantization support for the vision transform model in GPU #7814

Merged
merged 5 commits into from
Apr 14, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion include/tvm/relay/attrs/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -942,8 +942,14 @@ struct DenseAttrs : public tvm::AttrsNode<DenseAttrs> {
/*! \brief Attributes for batch matmul operator */
struct BatchMatmulAttrs : public tvm::AttrsNode<BatchMatmulAttrs> {
tvm::String auto_scheduler_rewritten_layout; // The layout after auto-scheduler's layout rewrite
DataType out_dtype;

TVM_DECLARE_ATTRS(BatchMatmulAttrs, "relay.attrs.BatchMatmulAttrs") {}
TVM_DECLARE_ATTRS(BatchMatmulAttrs, "relay.attrs.BatchMatmulAttrs") {
// use 0 bits to indicate none.
TVM_ATTR_FIELD(out_dtype)
.set_default(NullValue<DataType>())
.describe("Output data type, set to explicit type under mixed precision setting");
}
};

/*! \brief Attributes for sparse_dense operator */
Expand Down
7 changes: 5 additions & 2 deletions python/tvm/relay/op/nn/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2038,7 +2038,7 @@ def group_norm(data, gamma, beta, num_groups, axis=1, epsilon=1e-5, center=True,
return _make.group_norm(data, gamma, beta, num_groups, axis, epsilon, center, scale)


def batch_matmul(x, y):
def batch_matmul(x, y, out_dtype=""):
r"""
Computes batch matrix multiplication of `x` and `y` when `x` and `y` are data
in batch.
Expand All @@ -2055,12 +2055,15 @@ def batch_matmul(x, y):
y : tvm.relay.Expr
The second input.

out_dtype : str, optional
Specifies the output data type for mixed precision batch matmul

Returns
-------
result: tvm.relay.Expr
The computed result.
"""
return _make.batch_matmul(x, y)
return _make.batch_matmul(x, y, out_dtype)


# pylint: disable=no-else-return,inconsistent-return-statements
Expand Down
21 changes: 15 additions & 6 deletions python/tvm/relay/op/strategy/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -722,12 +722,21 @@ def dense_strategy_cuda(attrs, inputs, out_type, target):
def batch_matmul_strategy_cuda(attrs, inputs, out_type, target):
"""batch_matmul cuda strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_batch_matmul(topi.cuda.batch_matmul),
wrap_topi_schedule(topi.cuda.schedule_batch_matmul),
name="batch_matmul.cuda",
plevel=10,
)
x, y = inputs
if x.dtype == "int8" and y.dtype == "int8" and out_type.dtype == "int32":
strategy.add_implementation(
wrap_compute_batch_matmul(topi.cuda.batch_matmul_int8, need_out_dtype=True),
wrap_topi_schedule(topi.cuda.schedule_batch_matmul_int8),
name="batch_matmul_int8.cuda",
jcf94 marked this conversation as resolved.
Show resolved Hide resolved
plevel=10,
)
else:
strategy.add_implementation(
wrap_compute_batch_matmul(topi.cuda.batch_matmul),
wrap_topi_schedule(topi.cuda.schedule_batch_matmul),
name="batch_matmul.cuda",
plevel=10,
)
if target.kind.name == "cuda" and "cublas" in target.libs:
strategy.add_implementation(
wrap_compute_batch_matmul(topi.cuda.batch_matmul_cublas),
Expand Down
4 changes: 3 additions & 1 deletion python/tvm/relay/op/strategy/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,13 +746,15 @@ def dense_pack_strategy(attrs, inputs, out_type, target):


# batch_matmul
def wrap_compute_batch_matmul(topi_compute, need_auto_scheduler_layout=False):
def wrap_compute_batch_matmul(topi_compute, need_auto_scheduler_layout=False, need_out_dtype=False):
"""wrap batch_matmul topi compute"""

def _compute_batch_matmul(attrs, inputs, out_type):
args = [inputs[0], inputs[1], out_type.shape]
if need_auto_scheduler_layout:
args.append(get_auto_scheduler_rewritten_layout(attrs))
if need_out_dtype:
args.append(out_type.dtype)
return [topi_compute(*args)]

return _compute_batch_matmul
Expand Down
25 changes: 25 additions & 0 deletions python/tvm/relay/quantize/_annotate.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,3 +407,28 @@ def global_avg_pool2d_rewrite(ref_call, new_args, ctx):
# stop quantize after global_avg_pool2d
quantize_context().stop_quantize()
return expr


@register_annotate_function("nn.batch_matmul")
def batch_matmul_rewrite(ref_call, new_args, ctx):
"""Rewrite function for batch_matmul"""
if quantize_context().check_to_skip(ref_call):
return None

lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
rhs_expr, rhs_kind = _get_expr_kind(new_args[1])

if lhs_kind is None or lhs_kind == QAnnotateKind.ACTIVATION:
if _analysis.check_constant(lhs_expr):
lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.WEIGHT)
else:
lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT)

if rhs_kind is None or rhs_kind == QAnnotateKind.ACTIVATION:
if _analysis.check_constant(rhs_expr):
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT)
else:
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT)

expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
28 changes: 28 additions & 0 deletions python/tvm/relay/quantize/_calibrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,32 @@ def func(_):
return func


def _find_scale_by_percentile(arr, percentile=0.99999):
assert isinstance(arr, np.ndarray)
x = np.abs(arr)
max_k = int(x.size * percentile)
return np.partition(x, max_k)[max_k]


def _percentile_scale(mod, dataset):
cfg = quantize.current_qconfig()
chunk_by = cfg.calibrate_chunk_by
scales = []
for samples in collect_stats(mod, dataset, chunk_by):
logging.info("finding threshold with percentile for calibration...")
with mp.Pool() as pool:
scales += list(pool.map(_find_scale_by_percentile, samples))

def func(_):
scale = scales[func.scale_idx]
func.scale_idx += 1
return scale

func.scale_idx = 0

return func


def _set_params(mod, input_scale_func, weight_scale_func):
quantize_op = _op.get("relay.op.annotation.simulated_quantize")
cfg = quantize.current_qconfig()
Expand Down Expand Up @@ -195,6 +221,8 @@ def wrapped_func(mod, _):
input_scale_func = _kl_scale(mod, dataset)
elif cfg.calibrate_mode == "global_scale":
input_scale_func = _global_scale
elif cfg.calibrate_mode == "percentile":
input_scale_func = _percentile_scale(mod, dataset)
else:
raise ValueError("Unknown calibrate mode {}".format(cfg.calibrate_mode))

Expand Down
147 changes: 147 additions & 0 deletions python/tvm/topi/cuda/batch_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity
from .. import nn, generic
from ..utils import traverse_inline, get_const_tuple, get_max_power2_factor
from .tensor_intrin import dp4a


@autotvm.register_topi_compute("batch_matmul.cuda")
Expand Down Expand Up @@ -170,3 +171,149 @@ def batch_matmul_cublas(cfg, x, y, out_shape=None):
def schedule_batch_matmul_cublas(_, outs):
"""Schedule batch_matmul operator using CUBLAS"""
return generic.schedule_extern(outs)


@autotvm.register_topi_compute("batch_matmul_int8.cuda")
def batch_matmul_int8(cfg, x, y, out_shape=None, out_dtype=None):
"""Batch Matmul operator for int8 on CUDA"""
if out_dtype is None:
out_dtype = x.dtype

x_shape = get_const_tuple(x.shape)
y_shape = get_const_tuple(y.shape)
assert len(x_shape) == 3 and len(y_shape) == 3, "only support 3-dim batch_matmul"

XB, M, XK = x.shape
YB, N, YK = y.shape
assert XB == YB or XB == 1 or YB == 1, "batch dimension doesn't match"
assert XK == YK, "shapes of x and y is inconsistent"

nB = tvm.te.max(XB, YB)
nK = ((XK + 3) // 4) * 4
reduce_k = te.reduce_axis((0, nK), name="k")

# pad for _dp4a vectorize
pad_x = te.compute(
(XB, M, nK),
lambda b, i, j: tvm.te.if_then_else(
j >= XK, tvm.runtime.convert(0).astype(x.dtype), x[b, i, j]
),
)
pad_y = te.compute(
(YB, N, nK),
lambda b, i, j: tvm.te.if_then_else(
j >= YK, tvm.runtime.convert(0).astype(y.dtype), y[b, i, j]
),
)

out = te.compute(
(nB, M, N),
lambda b, i, j: te.sum(
pad_x[b if XB != 1 else 0, i, reduce_k].astype(out_dtype)
* pad_y[b if YB != 1 else 0, j, reduce_k].astype(out_dtype),
axis=[reduce_k],
),
tag="batch_matmul_int8",
)
cfg.add_flop(XB * M * N * nK * 2)
return out


@autotvm.register_topi_schedule("batch_matmul_int8.cuda")
def schedule_batch_matmul_int8(cfg, outs):
"""Batch Matmul schedule for int8 on CUDA"""
outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
s = te.create_schedule([x.op for x in outs])

def _callback(op):
if "batch_matmul_int8" in op.tag:
_schedule_batch_matmul_int8(cfg, s, op.output(0))

traverse_inline(s, outs[0].op, _callback)
return s


_dp4a = dp4a("shared", "shared", "local")


def _schedule_batch_matmul_int8(cfg, s, output):
input_x, input_y = s[output].op.input_tensors

B, M, K = get_const_tuple(input_x.shape)
_, N, _ = get_const_tuple(input_y.shape)

k_factor = 4
assert K % k_factor == 0, "Input dimension must divide {}".format(k_factor)
if K % 16 == 0:
k_factor = 16

cfg.define_split("tile_f", B, num_outputs=4)
cfg.define_split("tile_m", M, num_outputs=4)
cfg.define_split("tile_n", N, num_outputs=4)
cfg.define_split("tile_k", K // k_factor, num_outputs=2)
cfg.define_knob("auto_unroll_max_step", [0, 256, 512, 1024])

batch_matmul_op = s.outputs[0]
s[input_x].compute_inline()
s[input_y].compute_inline()

x_cache = s.cache_read(input_x, "shared", [batch_matmul_op])
y_cache = s.cache_read(input_y, "shared", [batch_matmul_op])
batch_matmul_cache = s.cache_write(batch_matmul_op.output(0), "local")

# tile reduce axis
ko = batch_matmul_cache.op.reduce_axis[0]
ko, ki = s[batch_matmul_cache].split(ko, factor=4)
ko, kt = cfg["tile_k"].apply(s, batch_matmul_cache, ko)
# dp4a tensorize
s[batch_matmul_cache].tensorize(ki, _dp4a)

# tile axis
f, m, n = batch_matmul_op.axis
kernel_scope, f = s[batch_matmul_op].split(f, nparts=1)

bf, vf, tf, fi = cfg["tile_f"].apply(s, batch_matmul_op, f)
bm, vm, tm, mi = cfg["tile_m"].apply(s, batch_matmul_op, m)
bn, vn, tn, ni = cfg["tile_n"].apply(s, batch_matmul_op, n)
s[batch_matmul_op].reorder(bf, bm, bn, vf, vm, vn, tf, tm, tn, fi, mi, ni)

# bind axis
s[batch_matmul_op].bind(bf, tvm.te.thread_axis("blockIdx.z"))
s[batch_matmul_op].bind(bm, tvm.te.thread_axis("blockIdx.y"))
s[batch_matmul_op].bind(bn, tvm.te.thread_axis("blockIdx.x"))

s[batch_matmul_op].bind(vf, tvm.te.thread_axis("vthread"))
s[batch_matmul_op].bind(vm, tvm.te.thread_axis("vthread"))
s[batch_matmul_op].bind(vn, tvm.te.thread_axis("vthread"))

s[batch_matmul_op].bind(tf, tvm.te.thread_axis("threadIdx.z"))
s[batch_matmul_op].bind(tm, tvm.te.thread_axis("threadIdx.y"))
s[batch_matmul_op].bind(tn, tvm.te.thread_axis("threadIdx.x"))

# cache compute at
s[batch_matmul_cache].compute_at(s[batch_matmul_op], tn)
fo, mo, no = batch_matmul_cache.op.axis[:3]
s[batch_matmul_cache].reorder(ko, kt, fo, mo, no, ki)

# for load in [splited_x_op, splited_y_op]
for load in [x_cache, y_cache]:
s[load].compute_at(s[batch_matmul_cache], ko)
outer, inner = s[load].split(s[load].op.axis[-1], factor=k_factor)
s[load].vectorize(inner)

fused = s[load].op.axis[:-1] + [outer]
fused = s[load].fuse(*fused)

fused, tx = s[load].split(fused, factor=cfg["tile_n"].size[2])
fused, ty = s[load].split(fused, factor=cfg["tile_m"].size[2])
fused, tz = s[load].split(fused, factor=cfg["tile_f"].size[2])

s[load].bind(tz, tvm.te.thread_axis("threadIdx.z"))
s[load].bind(ty, tvm.te.thread_axis("threadIdx.y"))
s[load].bind(tx, tvm.te.thread_axis("threadIdx.x"))

# max unroll
s[batch_matmul_op].pragma(kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val)
s[batch_matmul_op].pragma(kernel_scope, "unroll_explicit", False)

return s
12 changes: 9 additions & 3 deletions python/tvm/topi/testing/batch_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import numpy as np


def batch_matmul(x, y):
def batch_matmul(x, y, out_dtype=None):
"""batch_matmul operator implemented in numpy.

Parameters
Expand All @@ -30,6 +30,9 @@ def batch_matmul(x, y):
y : numpy.ndarray
3-D with shape [batch, N, K]

out_dtype: string, optional
Specify the dtype of output

Returns
-------
out : numpy.ndarray
Expand All @@ -38,7 +41,10 @@ def batch_matmul(x, y):
XB, M, _ = x.shape
YB, N, _ = y.shape
batch = max(XB, YB)
out = np.zeros((batch, M, N)).astype(x.dtype)
dtype = x.dtype if out_dtype is None else out_dtype
out = np.zeros((batch, M, N)).astype(dtype)
for i in range(batch):
out[i] = np.dot(x[i if XB != 1 else 0], y[i if YB != 1 else 0].T)
out[i] = np.dot(
x[i if XB != 1 else 0].astype(dtype), y[i if YB != 1 else 0].T.astype(dtype)
)
return out
2 changes: 1 addition & 1 deletion src/relay/op/make_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ Expr MakeConcatenate(Expr data, int axis);

Expr MakeDense(Expr data, Expr weight, IndexExpr units, DataType out_dtype);

Expr MakeBatchMatmul(Expr lhs, Expr rhs);
Expr MakeBatchMatmul(Expr lhs, Expr rhs, DataType out_dtype);

Expr MakeExpandDims(Expr data, int axis, int num_newaxis);

Expand Down
9 changes: 7 additions & 2 deletions src/relay/op/nn/nn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -943,14 +943,19 @@ bool BatchMatmulRel(const Array<Type>& types, int num_inputs, const Attrs& attrs
oshape.Set(2, y_shape[1]);
}

DataType out_dtype = param->out_dtype;
if (out_dtype.bits() == 0) {
out_dtype = x->dtype;
}
// assign output type
reporter->Assign(types[2], TensorType(oshape, x->dtype));
reporter->Assign(types[2], TensorType(oshape, out_dtype));
return true;
}

// Positional relay function to create batch_matmul operator used by frontend FFI.
Expr MakeBatchMatmul(Expr x, Expr y) {
Expr MakeBatchMatmul(Expr x, Expr y, DataType out_dtype) {
auto attrs = make_object<BatchMatmulAttrs>();
attrs->out_dtype = out_dtype;
static const Op& op = Op::Get("nn.batch_matmul");
return Call(op, {x, y}, Attrs(attrs), {});
}
Expand Down
Loading