Skip to content

Commit

Permalink
[Topi & Relay] Add quantization support for the vision transform mode…
Browse files Browse the repository at this point in the history
…l in GPU (apache#7814)

* Add cuda batch matmul int8 support for quantized vit model

* Fix for combine parallel pass with dense and batch_matmul

* Reformat based on lint

* Add plevel & update the file download method
  • Loading branch information
huochaitiantang authored and Trevor Morris committed May 6, 2021
1 parent f7c382f commit 79da315
Show file tree
Hide file tree
Showing 16 changed files with 528 additions and 19 deletions.
8 changes: 7 additions & 1 deletion include/tvm/relay/attrs/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -942,8 +942,14 @@ struct DenseAttrs : public tvm::AttrsNode<DenseAttrs> {
/*! \brief Attributes for batch matmul operator */
struct BatchMatmulAttrs : public tvm::AttrsNode<BatchMatmulAttrs> {
tvm::String auto_scheduler_rewritten_layout; // The layout after auto-scheduler's layout rewrite
DataType out_dtype;

TVM_DECLARE_ATTRS(BatchMatmulAttrs, "relay.attrs.BatchMatmulAttrs") {}
TVM_DECLARE_ATTRS(BatchMatmulAttrs, "relay.attrs.BatchMatmulAttrs") {
// use 0 bits to indicate none.
TVM_ATTR_FIELD(out_dtype)
.set_default(NullValue<DataType>())
.describe("Output data type, set to explicit type under mixed precision setting");
}
};

/*! \brief Attributes for sparse_dense operator */
Expand Down
7 changes: 5 additions & 2 deletions python/tvm/relay/op/nn/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2038,7 +2038,7 @@ def group_norm(data, gamma, beta, num_groups, axis=1, epsilon=1e-5, center=True,
return _make.group_norm(data, gamma, beta, num_groups, axis, epsilon, center, scale)


def batch_matmul(x, y):
def batch_matmul(x, y, out_dtype=""):
r"""
Computes batch matrix multiplication of `x` and `y` when `x` and `y` are data
in batch.
Expand All @@ -2055,12 +2055,15 @@ def batch_matmul(x, y):
y : tvm.relay.Expr
The second input.
out_dtype : str, optional
Specifies the output data type for mixed precision batch matmul
Returns
-------
result: tvm.relay.Expr
The computed result.
"""
return _make.batch_matmul(x, y)
return _make.batch_matmul(x, y, out_dtype)


# pylint: disable=no-else-return,inconsistent-return-statements
Expand Down
21 changes: 15 additions & 6 deletions python/tvm/relay/op/strategy/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -722,12 +722,21 @@ def dense_strategy_cuda(attrs, inputs, out_type, target):
def batch_matmul_strategy_cuda(attrs, inputs, out_type, target):
"""batch_matmul cuda strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_batch_matmul(topi.cuda.batch_matmul),
wrap_topi_schedule(topi.cuda.schedule_batch_matmul),
name="batch_matmul.cuda",
plevel=10,
)
x, y = inputs
if x.dtype == "int8" and y.dtype == "int8" and out_type.dtype == "int32":
strategy.add_implementation(
wrap_compute_batch_matmul(topi.cuda.batch_matmul_int8, need_out_dtype=True),
wrap_topi_schedule(topi.cuda.schedule_batch_matmul_int8),
name="batch_matmul_int8.cuda",
plevel=10,
)
else:
strategy.add_implementation(
wrap_compute_batch_matmul(topi.cuda.batch_matmul),
wrap_topi_schedule(topi.cuda.schedule_batch_matmul),
name="batch_matmul.cuda",
plevel=10,
)
if target.kind.name == "cuda" and "cublas" in target.libs:
strategy.add_implementation(
wrap_compute_batch_matmul(topi.cuda.batch_matmul_cublas),
Expand Down
4 changes: 3 additions & 1 deletion python/tvm/relay/op/strategy/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,13 +746,15 @@ def dense_pack_strategy(attrs, inputs, out_type, target):


# batch_matmul
def wrap_compute_batch_matmul(topi_compute, need_auto_scheduler_layout=False):
def wrap_compute_batch_matmul(topi_compute, need_auto_scheduler_layout=False, need_out_dtype=False):
"""wrap batch_matmul topi compute"""

def _compute_batch_matmul(attrs, inputs, out_type):
args = [inputs[0], inputs[1], out_type.shape]
if need_auto_scheduler_layout:
args.append(get_auto_scheduler_rewritten_layout(attrs))
if need_out_dtype:
args.append(out_type.dtype)
return [topi_compute(*args)]

return _compute_batch_matmul
Expand Down
25 changes: 25 additions & 0 deletions python/tvm/relay/quantize/_annotate.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,3 +407,28 @@ def global_avg_pool2d_rewrite(ref_call, new_args, ctx):
# stop quantize after global_avg_pool2d
quantize_context().stop_quantize()
return expr


@register_annotate_function("nn.batch_matmul")
def batch_matmul_rewrite(ref_call, new_args, ctx):
"""Rewrite function for batch_matmul"""
if quantize_context().check_to_skip(ref_call):
return None

lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
rhs_expr, rhs_kind = _get_expr_kind(new_args[1])

if lhs_kind is None or lhs_kind == QAnnotateKind.ACTIVATION:
if _analysis.check_constant(lhs_expr):
lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.WEIGHT)
else:
lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT)

if rhs_kind is None or rhs_kind == QAnnotateKind.ACTIVATION:
if _analysis.check_constant(rhs_expr):
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT)
else:
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT)

expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
28 changes: 28 additions & 0 deletions python/tvm/relay/quantize/_calibrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,32 @@ def func(_):
return func


def _find_scale_by_percentile(arr, percentile=0.99999):
assert isinstance(arr, np.ndarray)
x = np.abs(arr)
max_k = int(x.size * percentile)
return np.partition(x, max_k)[max_k]


def _percentile_scale(mod, dataset):
cfg = quantize.current_qconfig()
chunk_by = cfg.calibrate_chunk_by
scales = []
for samples in collect_stats(mod, dataset, chunk_by):
logging.info("finding threshold with percentile for calibration...")
with mp.Pool() as pool:
scales += list(pool.map(_find_scale_by_percentile, samples))

def func(_):
scale = scales[func.scale_idx]
func.scale_idx += 1
return scale

func.scale_idx = 0

return func


def _set_params(mod, input_scale_func, weight_scale_func):
quantize_op = _op.get("relay.op.annotation.simulated_quantize")
cfg = quantize.current_qconfig()
Expand Down Expand Up @@ -195,6 +221,8 @@ def wrapped_func(mod, _):
input_scale_func = _kl_scale(mod, dataset)
elif cfg.calibrate_mode == "global_scale":
input_scale_func = _global_scale
elif cfg.calibrate_mode == "percentile":
input_scale_func = _percentile_scale(mod, dataset)
else:
raise ValueError("Unknown calibrate mode {}".format(cfg.calibrate_mode))

Expand Down
147 changes: 147 additions & 0 deletions python/tvm/topi/cuda/batch_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity
from .. import nn, generic
from ..utils import traverse_inline, get_const_tuple, get_max_power2_factor
from .tensor_intrin import dp4a


@autotvm.register_topi_compute("batch_matmul.cuda")
Expand Down Expand Up @@ -170,3 +171,149 @@ def batch_matmul_cublas(cfg, x, y, out_shape=None):
def schedule_batch_matmul_cublas(_, outs):
"""Schedule batch_matmul operator using CUBLAS"""
return generic.schedule_extern(outs)


@autotvm.register_topi_compute("batch_matmul_int8.cuda")
def batch_matmul_int8(cfg, x, y, out_shape=None, out_dtype=None):
"""Batch Matmul operator for int8 on CUDA"""
if out_dtype is None:
out_dtype = x.dtype

x_shape = get_const_tuple(x.shape)
y_shape = get_const_tuple(y.shape)
assert len(x_shape) == 3 and len(y_shape) == 3, "only support 3-dim batch_matmul"

XB, M, XK = x.shape
YB, N, YK = y.shape
assert XB == YB or XB == 1 or YB == 1, "batch dimension doesn't match"
assert XK == YK, "shapes of x and y is inconsistent"

nB = tvm.te.max(XB, YB)
nK = ((XK + 3) // 4) * 4
reduce_k = te.reduce_axis((0, nK), name="k")

# pad for _dp4a vectorize
pad_x = te.compute(
(XB, M, nK),
lambda b, i, j: tvm.te.if_then_else(
j >= XK, tvm.runtime.convert(0).astype(x.dtype), x[b, i, j]
),
)
pad_y = te.compute(
(YB, N, nK),
lambda b, i, j: tvm.te.if_then_else(
j >= YK, tvm.runtime.convert(0).astype(y.dtype), y[b, i, j]
),
)

out = te.compute(
(nB, M, N),
lambda b, i, j: te.sum(
pad_x[b if XB != 1 else 0, i, reduce_k].astype(out_dtype)
* pad_y[b if YB != 1 else 0, j, reduce_k].astype(out_dtype),
axis=[reduce_k],
),
tag="batch_matmul_int8",
)
cfg.add_flop(XB * M * N * nK * 2)
return out


@autotvm.register_topi_schedule("batch_matmul_int8.cuda")
def schedule_batch_matmul_int8(cfg, outs):
"""Batch Matmul schedule for int8 on CUDA"""
outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
s = te.create_schedule([x.op for x in outs])

def _callback(op):
if "batch_matmul_int8" in op.tag:
_schedule_batch_matmul_int8(cfg, s, op.output(0))

traverse_inline(s, outs[0].op, _callback)
return s


_dp4a = dp4a("shared", "shared", "local")


def _schedule_batch_matmul_int8(cfg, s, output):
input_x, input_y = s[output].op.input_tensors

B, M, K = get_const_tuple(input_x.shape)
_, N, _ = get_const_tuple(input_y.shape)

k_factor = 4
assert K % k_factor == 0, "Input dimension must divide {}".format(k_factor)
if K % 16 == 0:
k_factor = 16

cfg.define_split("tile_f", B, num_outputs=4)
cfg.define_split("tile_m", M, num_outputs=4)
cfg.define_split("tile_n", N, num_outputs=4)
cfg.define_split("tile_k", K // k_factor, num_outputs=2)
cfg.define_knob("auto_unroll_max_step", [0, 256, 512, 1024])

batch_matmul_op = s.outputs[0]
s[input_x].compute_inline()
s[input_y].compute_inline()

x_cache = s.cache_read(input_x, "shared", [batch_matmul_op])
y_cache = s.cache_read(input_y, "shared", [batch_matmul_op])
batch_matmul_cache = s.cache_write(batch_matmul_op.output(0), "local")

# tile reduce axis
ko = batch_matmul_cache.op.reduce_axis[0]
ko, ki = s[batch_matmul_cache].split(ko, factor=4)
ko, kt = cfg["tile_k"].apply(s, batch_matmul_cache, ko)
# dp4a tensorize
s[batch_matmul_cache].tensorize(ki, _dp4a)

# tile axis
f, m, n = batch_matmul_op.axis
kernel_scope, f = s[batch_matmul_op].split(f, nparts=1)

bf, vf, tf, fi = cfg["tile_f"].apply(s, batch_matmul_op, f)
bm, vm, tm, mi = cfg["tile_m"].apply(s, batch_matmul_op, m)
bn, vn, tn, ni = cfg["tile_n"].apply(s, batch_matmul_op, n)
s[batch_matmul_op].reorder(bf, bm, bn, vf, vm, vn, tf, tm, tn, fi, mi, ni)

# bind axis
s[batch_matmul_op].bind(bf, tvm.te.thread_axis("blockIdx.z"))
s[batch_matmul_op].bind(bm, tvm.te.thread_axis("blockIdx.y"))
s[batch_matmul_op].bind(bn, tvm.te.thread_axis("blockIdx.x"))

s[batch_matmul_op].bind(vf, tvm.te.thread_axis("vthread"))
s[batch_matmul_op].bind(vm, tvm.te.thread_axis("vthread"))
s[batch_matmul_op].bind(vn, tvm.te.thread_axis("vthread"))

s[batch_matmul_op].bind(tf, tvm.te.thread_axis("threadIdx.z"))
s[batch_matmul_op].bind(tm, tvm.te.thread_axis("threadIdx.y"))
s[batch_matmul_op].bind(tn, tvm.te.thread_axis("threadIdx.x"))

# cache compute at
s[batch_matmul_cache].compute_at(s[batch_matmul_op], tn)
fo, mo, no = batch_matmul_cache.op.axis[:3]
s[batch_matmul_cache].reorder(ko, kt, fo, mo, no, ki)

# for load in [splited_x_op, splited_y_op]
for load in [x_cache, y_cache]:
s[load].compute_at(s[batch_matmul_cache], ko)
outer, inner = s[load].split(s[load].op.axis[-1], factor=k_factor)
s[load].vectorize(inner)

fused = s[load].op.axis[:-1] + [outer]
fused = s[load].fuse(*fused)

fused, tx = s[load].split(fused, factor=cfg["tile_n"].size[2])
fused, ty = s[load].split(fused, factor=cfg["tile_m"].size[2])
fused, tz = s[load].split(fused, factor=cfg["tile_f"].size[2])

s[load].bind(tz, tvm.te.thread_axis("threadIdx.z"))
s[load].bind(ty, tvm.te.thread_axis("threadIdx.y"))
s[load].bind(tx, tvm.te.thread_axis("threadIdx.x"))

# max unroll
s[batch_matmul_op].pragma(kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val)
s[batch_matmul_op].pragma(kernel_scope, "unroll_explicit", False)

return s
12 changes: 9 additions & 3 deletions python/tvm/topi/testing/batch_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import numpy as np


def batch_matmul(x, y):
def batch_matmul(x, y, out_dtype=None):
"""batch_matmul operator implemented in numpy.
Parameters
Expand All @@ -30,6 +30,9 @@ def batch_matmul(x, y):
y : numpy.ndarray
3-D with shape [batch, N, K]
out_dtype: string, optional
Specify the dtype of output
Returns
-------
out : numpy.ndarray
Expand All @@ -38,7 +41,10 @@ def batch_matmul(x, y):
XB, M, _ = x.shape
YB, N, _ = y.shape
batch = max(XB, YB)
out = np.zeros((batch, M, N)).astype(x.dtype)
dtype = x.dtype if out_dtype is None else out_dtype
out = np.zeros((batch, M, N)).astype(dtype)
for i in range(batch):
out[i] = np.dot(x[i if XB != 1 else 0], y[i if YB != 1 else 0].T)
out[i] = np.dot(
x[i if XB != 1 else 0].astype(dtype), y[i if YB != 1 else 0].T.astype(dtype)
)
return out
2 changes: 1 addition & 1 deletion src/relay/op/make_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ Expr MakeConcatenate(Expr data, int axis);

Expr MakeDense(Expr data, Expr weight, IndexExpr units, DataType out_dtype);

Expr MakeBatchMatmul(Expr lhs, Expr rhs);
Expr MakeBatchMatmul(Expr lhs, Expr rhs, DataType out_dtype);

Expr MakeExpandDims(Expr data, int axis, int num_newaxis);

Expand Down
9 changes: 7 additions & 2 deletions src/relay/op/nn/nn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -943,14 +943,19 @@ bool BatchMatmulRel(const Array<Type>& types, int num_inputs, const Attrs& attrs
oshape.Set(2, y_shape[1]);
}

DataType out_dtype = param->out_dtype;
if (out_dtype.bits() == 0) {
out_dtype = x->dtype;
}
// assign output type
reporter->Assign(types[2], TensorType(oshape, x->dtype));
reporter->Assign(types[2], TensorType(oshape, out_dtype));
return true;
}

// Positional relay function to create batch_matmul operator used by frontend FFI.
Expr MakeBatchMatmul(Expr x, Expr y) {
Expr MakeBatchMatmul(Expr x, Expr y, DataType out_dtype) {
auto attrs = make_object<BatchMatmulAttrs>();
attrs->out_dtype = out_dtype;
static const Op& op = Op::Get("nn.batch_matmul");
return Call(op, {x, y}, Attrs(attrs), {});
}
Expand Down
Loading

0 comments on commit 79da315

Please sign in to comment.