diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h index c3c58e54517cc..42adefa314293 100644 --- a/include/tvm/relay/attrs/nn.h +++ b/include/tvm/relay/attrs/nn.h @@ -942,8 +942,14 @@ struct DenseAttrs : public tvm::AttrsNode { /*! \brief Attributes for batch matmul operator */ struct BatchMatmulAttrs : public tvm::AttrsNode { tvm::String auto_scheduler_rewritten_layout; // The layout after auto-scheduler's layout rewrite + DataType out_dtype; - TVM_DECLARE_ATTRS(BatchMatmulAttrs, "relay.attrs.BatchMatmulAttrs") {} + TVM_DECLARE_ATTRS(BatchMatmulAttrs, "relay.attrs.BatchMatmulAttrs") { + // use 0 bits to indicate none. + TVM_ATTR_FIELD(out_dtype) + .set_default(NullValue()) + .describe("Output data type, set to explicit type under mixed precision setting"); + } }; /*! \brief Attributes for sparse_dense operator */ diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py index a1147fec4d7eb..5acdafef0ed96 100644 --- a/python/tvm/relay/op/nn/nn.py +++ b/python/tvm/relay/op/nn/nn.py @@ -2038,7 +2038,7 @@ def group_norm(data, gamma, beta, num_groups, axis=1, epsilon=1e-5, center=True, return _make.group_norm(data, gamma, beta, num_groups, axis, epsilon, center, scale) -def batch_matmul(x, y): +def batch_matmul(x, y, out_dtype=""): r""" Computes batch matrix multiplication of `x` and `y` when `x` and `y` are data in batch. @@ -2055,12 +2055,15 @@ def batch_matmul(x, y): y : tvm.relay.Expr The second input. + out_dtype : str, optional + Specifies the output data type for mixed precision batch matmul + Returns ------- result: tvm.relay.Expr The computed result. """ - return _make.batch_matmul(x, y) + return _make.batch_matmul(x, y, out_dtype) # pylint: disable=no-else-return,inconsistent-return-statements diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index 1a67425266077..61ab421427322 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -722,12 +722,21 @@ def dense_strategy_cuda(attrs, inputs, out_type, target): def batch_matmul_strategy_cuda(attrs, inputs, out_type, target): """batch_matmul cuda strategy""" strategy = _op.OpStrategy() - strategy.add_implementation( - wrap_compute_batch_matmul(topi.cuda.batch_matmul), - wrap_topi_schedule(topi.cuda.schedule_batch_matmul), - name="batch_matmul.cuda", - plevel=10, - ) + x, y = inputs + if x.dtype == "int8" and y.dtype == "int8" and out_type.dtype == "int32": + strategy.add_implementation( + wrap_compute_batch_matmul(topi.cuda.batch_matmul_int8, need_out_dtype=True), + wrap_topi_schedule(topi.cuda.schedule_batch_matmul_int8), + name="batch_matmul_int8.cuda", + plevel=10, + ) + else: + strategy.add_implementation( + wrap_compute_batch_matmul(topi.cuda.batch_matmul), + wrap_topi_schedule(topi.cuda.schedule_batch_matmul), + name="batch_matmul.cuda", + plevel=10, + ) if target.kind.name == "cuda" and "cublas" in target.libs: strategy.add_implementation( wrap_compute_batch_matmul(topi.cuda.batch_matmul_cublas), diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index 0645c80872f91..845995e6ace4a 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -746,13 +746,15 @@ def dense_pack_strategy(attrs, inputs, out_type, target): # batch_matmul -def wrap_compute_batch_matmul(topi_compute, need_auto_scheduler_layout=False): +def wrap_compute_batch_matmul(topi_compute, need_auto_scheduler_layout=False, need_out_dtype=False): """wrap batch_matmul topi compute""" def _compute_batch_matmul(attrs, inputs, out_type): args = [inputs[0], inputs[1], out_type.shape] if need_auto_scheduler_layout: args.append(get_auto_scheduler_rewritten_layout(attrs)) + if need_out_dtype: + args.append(out_type.dtype) return [topi_compute(*args)] return _compute_batch_matmul diff --git a/python/tvm/relay/quantize/_annotate.py b/python/tvm/relay/quantize/_annotate.py index 6c395e257cc72..ff673d23144a8 100644 --- a/python/tvm/relay/quantize/_annotate.py +++ b/python/tvm/relay/quantize/_annotate.py @@ -407,3 +407,28 @@ def global_avg_pool2d_rewrite(ref_call, new_args, ctx): # stop quantize after global_avg_pool2d quantize_context().stop_quantize() return expr + + +@register_annotate_function("nn.batch_matmul") +def batch_matmul_rewrite(ref_call, new_args, ctx): + """Rewrite function for batch_matmul""" + if quantize_context().check_to_skip(ref_call): + return None + + lhs_expr, lhs_kind = _get_expr_kind(new_args[0]) + rhs_expr, rhs_kind = _get_expr_kind(new_args[1]) + + if lhs_kind is None or lhs_kind == QAnnotateKind.ACTIVATION: + if _analysis.check_constant(lhs_expr): + lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.WEIGHT) + else: + lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT) + + if rhs_kind is None or rhs_kind == QAnnotateKind.ACTIVATION: + if _analysis.check_constant(rhs_expr): + rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT) + else: + rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT) + + expr = _forward_op(ref_call, [lhs_expr, rhs_expr]) + return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION) diff --git a/python/tvm/relay/quantize/_calibrate.py b/python/tvm/relay/quantize/_calibrate.py index a906a98dccd45..b56c09cdad097 100644 --- a/python/tvm/relay/quantize/_calibrate.py +++ b/python/tvm/relay/quantize/_calibrate.py @@ -109,6 +109,32 @@ def func(_): return func +def _find_scale_by_percentile(arr, percentile=0.99999): + assert isinstance(arr, np.ndarray) + x = np.abs(arr) + max_k = int(x.size * percentile) + return np.partition(x, max_k)[max_k] + + +def _percentile_scale(mod, dataset): + cfg = quantize.current_qconfig() + chunk_by = cfg.calibrate_chunk_by + scales = [] + for samples in collect_stats(mod, dataset, chunk_by): + logging.info("finding threshold with percentile for calibration...") + with mp.Pool() as pool: + scales += list(pool.map(_find_scale_by_percentile, samples)) + + def func(_): + scale = scales[func.scale_idx] + func.scale_idx += 1 + return scale + + func.scale_idx = 0 + + return func + + def _set_params(mod, input_scale_func, weight_scale_func): quantize_op = _op.get("relay.op.annotation.simulated_quantize") cfg = quantize.current_qconfig() @@ -195,6 +221,8 @@ def wrapped_func(mod, _): input_scale_func = _kl_scale(mod, dataset) elif cfg.calibrate_mode == "global_scale": input_scale_func = _global_scale + elif cfg.calibrate_mode == "percentile": + input_scale_func = _percentile_scale(mod, dataset) else: raise ValueError("Unknown calibrate mode {}".format(cfg.calibrate_mode)) diff --git a/python/tvm/topi/cuda/batch_matmul.py b/python/tvm/topi/cuda/batch_matmul.py index 04e484f526d23..fb91912f29a01 100644 --- a/python/tvm/topi/cuda/batch_matmul.py +++ b/python/tvm/topi/cuda/batch_matmul.py @@ -23,6 +23,7 @@ from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity from .. import nn, generic from ..utils import traverse_inline, get_const_tuple, get_max_power2_factor +from .tensor_intrin import dp4a @autotvm.register_topi_compute("batch_matmul.cuda") @@ -170,3 +171,149 @@ def batch_matmul_cublas(cfg, x, y, out_shape=None): def schedule_batch_matmul_cublas(_, outs): """Schedule batch_matmul operator using CUBLAS""" return generic.schedule_extern(outs) + + +@autotvm.register_topi_compute("batch_matmul_int8.cuda") +def batch_matmul_int8(cfg, x, y, out_shape=None, out_dtype=None): + """Batch Matmul operator for int8 on CUDA""" + if out_dtype is None: + out_dtype = x.dtype + + x_shape = get_const_tuple(x.shape) + y_shape = get_const_tuple(y.shape) + assert len(x_shape) == 3 and len(y_shape) == 3, "only support 3-dim batch_matmul" + + XB, M, XK = x.shape + YB, N, YK = y.shape + assert XB == YB or XB == 1 or YB == 1, "batch dimension doesn't match" + assert XK == YK, "shapes of x and y is inconsistent" + + nB = tvm.te.max(XB, YB) + nK = ((XK + 3) // 4) * 4 + reduce_k = te.reduce_axis((0, nK), name="k") + + # pad for _dp4a vectorize + pad_x = te.compute( + (XB, M, nK), + lambda b, i, j: tvm.te.if_then_else( + j >= XK, tvm.runtime.convert(0).astype(x.dtype), x[b, i, j] + ), + ) + pad_y = te.compute( + (YB, N, nK), + lambda b, i, j: tvm.te.if_then_else( + j >= YK, tvm.runtime.convert(0).astype(y.dtype), y[b, i, j] + ), + ) + + out = te.compute( + (nB, M, N), + lambda b, i, j: te.sum( + pad_x[b if XB != 1 else 0, i, reduce_k].astype(out_dtype) + * pad_y[b if YB != 1 else 0, j, reduce_k].astype(out_dtype), + axis=[reduce_k], + ), + tag="batch_matmul_int8", + ) + cfg.add_flop(XB * M * N * nK * 2) + return out + + +@autotvm.register_topi_schedule("batch_matmul_int8.cuda") +def schedule_batch_matmul_int8(cfg, outs): + """Batch Matmul schedule for int8 on CUDA""" + outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs + s = te.create_schedule([x.op for x in outs]) + + def _callback(op): + if "batch_matmul_int8" in op.tag: + _schedule_batch_matmul_int8(cfg, s, op.output(0)) + + traverse_inline(s, outs[0].op, _callback) + return s + + +_dp4a = dp4a("shared", "shared", "local") + + +def _schedule_batch_matmul_int8(cfg, s, output): + input_x, input_y = s[output].op.input_tensors + + B, M, K = get_const_tuple(input_x.shape) + _, N, _ = get_const_tuple(input_y.shape) + + k_factor = 4 + assert K % k_factor == 0, "Input dimension must divide {}".format(k_factor) + if K % 16 == 0: + k_factor = 16 + + cfg.define_split("tile_f", B, num_outputs=4) + cfg.define_split("tile_m", M, num_outputs=4) + cfg.define_split("tile_n", N, num_outputs=4) + cfg.define_split("tile_k", K // k_factor, num_outputs=2) + cfg.define_knob("auto_unroll_max_step", [0, 256, 512, 1024]) + + batch_matmul_op = s.outputs[0] + s[input_x].compute_inline() + s[input_y].compute_inline() + + x_cache = s.cache_read(input_x, "shared", [batch_matmul_op]) + y_cache = s.cache_read(input_y, "shared", [batch_matmul_op]) + batch_matmul_cache = s.cache_write(batch_matmul_op.output(0), "local") + + # tile reduce axis + ko = batch_matmul_cache.op.reduce_axis[0] + ko, ki = s[batch_matmul_cache].split(ko, factor=4) + ko, kt = cfg["tile_k"].apply(s, batch_matmul_cache, ko) + # dp4a tensorize + s[batch_matmul_cache].tensorize(ki, _dp4a) + + # tile axis + f, m, n = batch_matmul_op.axis + kernel_scope, f = s[batch_matmul_op].split(f, nparts=1) + + bf, vf, tf, fi = cfg["tile_f"].apply(s, batch_matmul_op, f) + bm, vm, tm, mi = cfg["tile_m"].apply(s, batch_matmul_op, m) + bn, vn, tn, ni = cfg["tile_n"].apply(s, batch_matmul_op, n) + s[batch_matmul_op].reorder(bf, bm, bn, vf, vm, vn, tf, tm, tn, fi, mi, ni) + + # bind axis + s[batch_matmul_op].bind(bf, tvm.te.thread_axis("blockIdx.z")) + s[batch_matmul_op].bind(bm, tvm.te.thread_axis("blockIdx.y")) + s[batch_matmul_op].bind(bn, tvm.te.thread_axis("blockIdx.x")) + + s[batch_matmul_op].bind(vf, tvm.te.thread_axis("vthread")) + s[batch_matmul_op].bind(vm, tvm.te.thread_axis("vthread")) + s[batch_matmul_op].bind(vn, tvm.te.thread_axis("vthread")) + + s[batch_matmul_op].bind(tf, tvm.te.thread_axis("threadIdx.z")) + s[batch_matmul_op].bind(tm, tvm.te.thread_axis("threadIdx.y")) + s[batch_matmul_op].bind(tn, tvm.te.thread_axis("threadIdx.x")) + + # cache compute at + s[batch_matmul_cache].compute_at(s[batch_matmul_op], tn) + fo, mo, no = batch_matmul_cache.op.axis[:3] + s[batch_matmul_cache].reorder(ko, kt, fo, mo, no, ki) + + # for load in [splited_x_op, splited_y_op] + for load in [x_cache, y_cache]: + s[load].compute_at(s[batch_matmul_cache], ko) + outer, inner = s[load].split(s[load].op.axis[-1], factor=k_factor) + s[load].vectorize(inner) + + fused = s[load].op.axis[:-1] + [outer] + fused = s[load].fuse(*fused) + + fused, tx = s[load].split(fused, factor=cfg["tile_n"].size[2]) + fused, ty = s[load].split(fused, factor=cfg["tile_m"].size[2]) + fused, tz = s[load].split(fused, factor=cfg["tile_f"].size[2]) + + s[load].bind(tz, tvm.te.thread_axis("threadIdx.z")) + s[load].bind(ty, tvm.te.thread_axis("threadIdx.y")) + s[load].bind(tx, tvm.te.thread_axis("threadIdx.x")) + + # max unroll + s[batch_matmul_op].pragma(kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val) + s[batch_matmul_op].pragma(kernel_scope, "unroll_explicit", False) + + return s diff --git a/python/tvm/topi/testing/batch_matmul.py b/python/tvm/topi/testing/batch_matmul.py index a48c92967c77c..96d1fcbb5bc31 100644 --- a/python/tvm/topi/testing/batch_matmul.py +++ b/python/tvm/topi/testing/batch_matmul.py @@ -19,7 +19,7 @@ import numpy as np -def batch_matmul(x, y): +def batch_matmul(x, y, out_dtype=None): """batch_matmul operator implemented in numpy. Parameters @@ -30,6 +30,9 @@ def batch_matmul(x, y): y : numpy.ndarray 3-D with shape [batch, N, K] + out_dtype: string, optional + Specify the dtype of output + Returns ------- out : numpy.ndarray @@ -38,7 +41,10 @@ def batch_matmul(x, y): XB, M, _ = x.shape YB, N, _ = y.shape batch = max(XB, YB) - out = np.zeros((batch, M, N)).astype(x.dtype) + dtype = x.dtype if out_dtype is None else out_dtype + out = np.zeros((batch, M, N)).astype(dtype) for i in range(batch): - out[i] = np.dot(x[i if XB != 1 else 0], y[i if YB != 1 else 0].T) + out[i] = np.dot( + x[i if XB != 1 else 0].astype(dtype), y[i if YB != 1 else 0].T.astype(dtype) + ) return out diff --git a/src/relay/op/make_op.h b/src/relay/op/make_op.h index 36a5ec1c0e72c..e5a20abd76242 100644 --- a/src/relay/op/make_op.h +++ b/src/relay/op/make_op.h @@ -46,7 +46,7 @@ Expr MakeConcatenate(Expr data, int axis); Expr MakeDense(Expr data, Expr weight, IndexExpr units, DataType out_dtype); -Expr MakeBatchMatmul(Expr lhs, Expr rhs); +Expr MakeBatchMatmul(Expr lhs, Expr rhs, DataType out_dtype); Expr MakeExpandDims(Expr data, int axis, int num_newaxis); diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc index b2404cc1954b2..32c0a21d46c7c 100644 --- a/src/relay/op/nn/nn.cc +++ b/src/relay/op/nn/nn.cc @@ -943,14 +943,19 @@ bool BatchMatmulRel(const Array& types, int num_inputs, const Attrs& attrs oshape.Set(2, y_shape[1]); } + DataType out_dtype = param->out_dtype; + if (out_dtype.bits() == 0) { + out_dtype = x->dtype; + } // assign output type - reporter->Assign(types[2], TensorType(oshape, x->dtype)); + reporter->Assign(types[2], TensorType(oshape, out_dtype)); return true; } // Positional relay function to create batch_matmul operator used by frontend FFI. -Expr MakeBatchMatmul(Expr x, Expr y) { +Expr MakeBatchMatmul(Expr x, Expr y, DataType out_dtype) { auto attrs = make_object(); + attrs->out_dtype = out_dtype; static const Op& op = Op::Get("nn.batch_matmul"); return Call(op, {x, y}, Attrs(attrs), {}); } diff --git a/src/relay/quantize/realize.cc b/src/relay/quantize/realize.cc index d77ede3acbf92..968628fbfe39c 100644 --- a/src/relay/quantize/realize.cc +++ b/src/relay/quantize/realize.cc @@ -539,6 +539,41 @@ Expr CastHintRealize(const Call& ref_call, const Array& new_args, const Ob RELAY_REGISTER_OP("annotation.cast_hint") .set_attr("FQRealizeRewrite", CastHintRealize); +Expr BatchMatmulRealize(const Call& ref_call, const Array& new_args, const ObjectRef& ctx) { + const QConfig& cfg = QConfig::Current(); + ICHECK_EQ(new_args.size(), 2); + if (!new_args[0]->IsInstance() || !new_args[1]->IsInstance()) { + return Expr(nullptr); + } + const auto* lhs = new_args[0].as(); + const auto* rhs = new_args[1].as(); + + Expr ldata = lhs->data; + Expr rdata = rhs->data; + DataType dtype = cfg->dtype_input; + + if (lhs->dtype != dtype) { + ldata = Cast(ldata, dtype); + } + if (rhs->dtype != dtype) { + rdata = Cast(rdata, dtype); + } + + const auto ref_attrs = ref_call->attrs.as(); + auto attrs = make_object(); + *attrs = *ref_attrs; + DataType out_dtype = cfg->dtype_activation; + attrs->out_dtype = out_dtype; + + Expr ret = Call(ref_call->op, {ldata, rdata}, Attrs(attrs), ref_call->type_args); + Expr mul = Multiply(lhs->dom_scale, rhs->dom_scale); + Expr dom_scale = FoldConstantOpt(mul); + return QRealizeIntExpr(ret, dom_scale, out_dtype); +} + +RELAY_REGISTER_OP("nn.batch_matmul") + .set_attr("FQRealizeRewrite", BatchMatmulRealize); + Pass QuantizeRealizePass() { runtime::TypedPackedFunc pass_func = [=](Function f, IRModule m, PassContext pc) { diff --git a/src/relay/transforms/combine_parallel_batch_matmul.cc b/src/relay/transforms/combine_parallel_batch_matmul.cc index 20a7c7ff78157..f8c46d93c6755 100644 --- a/src/relay/transforms/combine_parallel_batch_matmul.cc +++ b/src/relay/transforms/combine_parallel_batch_matmul.cc @@ -57,15 +57,20 @@ class ParallelBatchMatmulCombiner : public ParallelOpCombiner { bool CanOpsBeCombined(const CallNode* a, const CallNode* b) { StructuralEqual eq; + const auto* attrs_a = a->attrs.as(); + const auto* attrs_b = b->attrs.as(); + ICHECK(attrs_a); + ICHECK(attrs_b); const auto* rhs_a = a->args[1]->type_as(); const auto* rhs_b = b->args[1]->type_as(); const auto* restype_a = a->type_as(); const auto* restype_b = b->type_as(); // shape[2] is the contraction axis and automatically consistent // if it were valid batch_matmul ops + auto res = eq(rhs_a->dtype, rhs_b->dtype) && eq(restype_a->dtype, restype_b->dtype) && (rhs_a->shape.size() == 3) && (rhs_b->shape.size() == 3) && - eq(rhs_a->shape[0], rhs_b->shape[0]); + eq(rhs_a->shape[0], rhs_b->shape[0]) && eq(attrs_a->out_dtype, attrs_b->out_dtype); return res; } @@ -78,7 +83,10 @@ class ParallelBatchMatmulCombiner : public ParallelOpCombiner { weights.push_back(call->args[1]); } Expr new_weight = MakeConcatenate(Tuple(weights), 1); - return Downcast(MakeBatchMatmul(data, new_weight)); + + const auto* origin_attrs = branches[0][0]->attrs.as(); + ICHECK(origin_attrs); + return Downcast(MakeBatchMatmul(data, new_weight, origin_attrs->out_dtype)); } bool IsArgCompatible(const CallNode* a, const CallNode* b, size_t index) { return true; } diff --git a/src/relay/transforms/combine_parallel_dense.cc b/src/relay/transforms/combine_parallel_dense.cc index d9ca4bf2042ef..3cd9cca4fec44 100644 --- a/src/relay/transforms/combine_parallel_dense.cc +++ b/src/relay/transforms/combine_parallel_dense.cc @@ -70,7 +70,9 @@ class ParallelDenseToBatchCombiner : public ParallelOpBatchCombiner { } CHECK_EQ(num_args, 2); - return Downcast(MakeBatchMatmul(new_args[0], new_args[1])); + const auto* origin_attrs = branches[0][0]->attrs.as(); + ICHECK(origin_attrs); + return Downcast(MakeBatchMatmul(new_args[0], new_args[1], origin_attrs->out_dtype)); } virtual bool CanOpsBeCombined(const CallNode* a, const CallNode* b) { diff --git a/tests/python/nightly/quantization/test_quantization_accuracy_for_vit.py b/tests/python/nightly/quantization/test_quantization_accuracy_for_vit.py new file mode 100644 index 0000000000000..2d581d8c2acd7 --- /dev/null +++ b/tests/python/nightly/quantization/test_quantization_accuracy_for_vit.py @@ -0,0 +1,148 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +import os +import sys +from tvm import relay +from tvm.relay import quantize as qtz +import logging +import onnx +import tvm.testing +import mxnet as mx +from test_quantization_accuracy import Config, get_val_data, eval_acc + +logging.basicConfig(level=logging.INFO) + + +def calibrate_dataset(model_name, rec_val, batch_size, calibration_samples): + val_data, _ = get_val_data(model_name, rec_val=rec_val, batch_size=batch_size) + val_data.reset() + for i, batch in enumerate(val_data): + if i * batch_size >= calibration_samples: + break + data = batch.data[0].asnumpy() + yield {"data": data} + + +def download_file(url_base, file_name): + if not os.path.exists(file_name) or not os.path.isfile(file_name): + import urllib.request as urllib2 + + url = "{}/{}".format(url_base, file_name) + try: + print("download from {}".format(url)) + if sys.version_info >= (3,): + urllib2.urlretrieve(url, file_name) + else: + f = urllib2.urlopen(url) + data = f.read() + with open(file_name, "wb") as code: + code.write(data) + except Exception as err: + if os.path.exists(file_name): + os.remove(file_name) + raise Exception("download {} failed due to {}!".format(file_name, repr(err))) + + +def get_onnx_model(model_name, batch_size, qconfig, target=None, original=False, dataset=None): + assert model_name == "vit32", "Only support vit32 model!" + base = "https://github.com/TheGreatCold/tvm-vit/raw/d2aa1e60eef42e2fdedbd1e13aa85ac5faf0a7fc" + logfile = "gtx1660_vit_B32_224.log" + onnx_path = "vit_B32_224.onnx" + + download_file(base, logfile) + download_file(base, onnx_path) + + onnx_graph = onnx.load(open(onnx_path, "rb")) + data_shape = (batch_size, 3, 224, 224) + mod, params = relay.frontend.from_onnx(onnx_graph, {"data": data_shape}) + + with tvm.transform.PassContext(opt_level=3): + qfunc = relay.quantize.prerequisite_optimize(mod, params=params) + logging.debug("original") + logging.debug(qfunc.astext(show_meta_data=False)) + if original: + return qfunc, logfile + + with qconfig: + logging.debug("current quantize config") + logging.debug(qtz.current_qconfig()) + + if dataset is not None: + with tvm.target.cuda(): + with tvm.autotvm.apply_history_best(logfile): + qfunc = qtz.quantize(qfunc, params, dataset=dataset) + else: + qfunc = qtz.quantize(qfunc, params) + + logging.debug("after quantize") + logging.debug(qfunc.astext(show_meta_data=False)) + return qfunc, logfile + + +@tvm.testing.requires_gpu +def test_onnx_quantize_acc(cfg, rec_val, batch_size=1, original=False): + qconfig = qtz.qconfig( + skip_conv_layers=[0], + skip_dense_layer=False, + nbit_input=cfg.nbit_input, + nbit_weight=cfg.nbit_input, + dtype_input=cfg.dtype_input, + dtype_weight=cfg.dtype_input, + dtype_activation=cfg.dtype_output, + debug_enabled_ops=None, + calibrate_mode="percentile", + calibrate_chunk_by=8, + ) + + dataset = list(calibrate_dataset(cfg.model, rec_val, batch_size, 64)) + model, logfile = get_onnx_model( + cfg.model, batch_size, qconfig, tvm.target.cuda(), original=original, dataset=dataset + ) + val_data, batch_fn = get_val_data(cfg.model, rec_val=rec_val, batch_size=batch_size) + + with tvm.autotvm.apply_history_best(logfile): + acc = eval_acc(model, val_data, batch_fn, log_interval=1000) + assert acc > cfg.expected_acc + return acc + + +if __name__ == "__main__": + # TODO(for user): replace the line with the path to imagenet validation dataset + rec_val = "/scratch/tqchen/imagenet/val.rec" + + configs = [ + Config( + "vit32", + nbit_input=8, + dtype_input="int8", + nbit_output=32, + dtype_output="int32", + global_scale=8.0, + expected_acc=0.727, + ), + ] + + for config in configs: + + # float32 model + acc = test_onnx_quantize_acc(config, rec_val, batch_size=1, original=True) + print("{}-float32: {}".format(config.model, acc)) + + # int8 model + acc = test_onnx_quantize_acc(config, rec_val, batch_size=1, original=False) + print("{}-int8: {}".format(config.model, acc)) diff --git a/tests/python/relay/test_pass_auto_quantize.py b/tests/python/relay/test_pass_auto_quantize.py index 326416f3c501b..6ebff0e6ac8b9 100644 --- a/tests/python/relay/test_pass_auto_quantize.py +++ b/tests/python/relay/test_pass_auto_quantize.py @@ -74,6 +74,29 @@ def _check_batch_flatten(node): relay.analysis.post_order_visit(qmod["main"], _check_batch_flatten) +def test_batch_matmul_rewrite(): + data = relay.var("data", shape=(1, 4, 16, 16)) + data2 = relay.sigmoid(relay.var("data", shape=(4, 16, 64))) + out = relay.nn.conv2d(data, relay.var("weight"), kernel_size=(3, 3), padding=(1, 1), channels=8) + + out = relay.nn.batch_flatten(out) + out = relay.reshape(out, [1, 32, 64]) + out = relay.nn.batch_matmul(out, data2) + + qmod = quantize_and_build(out) + + def _check_batch_matmul(node): + if isinstance(node, Call): + + if node.op.name in ["nn.batch_matmul", "nn.conv2d"]: + assert node.checked_type.dtype == "int32" + elif node.op.name == "nn.batch_flatten": + assert node.checked_type.dtype == "int8" + + # check if batch_matmul is quantized + relay.analysis.post_order_visit(qmod["main"], _check_batch_matmul) + + def get_calibration_dataset(mod, input_name): dataset = [] input_shape = [int(x) for x in mod["main"].checked_type.arg_types[0].shape] @@ -106,6 +129,13 @@ def test_calibrate_memory_bound(): relay.quantize.quantize(mod, params, dataset) +def test_calibrate_percentile(): + mod, params = testing.synthetic.get_workload() + dataset = get_calibration_dataset(mod, "data") + with relay.quantize.qconfig(calibrate_mode="percentile"): + relay.quantize.quantize(mod, params, dataset) + + #################################### # Quant/Dequant Partitioning Tests # #################################### @@ -343,9 +373,11 @@ def visit_call(self, call): if __name__ == "__main__": test_mul_rewrite() test_batch_flatten_rewrite() + test_batch_matmul_rewrite() test_calibrate_target(False) test_calibrate_target(True) test_calibrate_memory_bound() + test_calibrate_percentile() test_add_partition() test_conv2d_partition() diff --git a/tests/python/topi/python/test_topi_batch_matmul.py b/tests/python/topi/python/test_topi_batch_matmul.py index 05f2c3029bc9d..e6edd53ba1158 100644 --- a/tests/python/topi/python/test_topi_batch_matmul.py +++ b/tests/python/topi/python/test_topi_batch_matmul.py @@ -24,6 +24,7 @@ from tvm.contrib.pickle_memoize import memoize import tvm.testing +from common import Int8Fallback _batch_matmul_implement = { "generic": (topi.nn.batch_matmul, topi.generic.schedule_batch_matmul), @@ -91,6 +92,45 @@ def check_device(target, dev): check_device(target, dev) +def verify_batch_matmul_int8(x_batch, y_batch, M, N, K): + dtype = "int8" + out_dtype = "int32" + assert x_batch == y_batch or x_batch == 1 or y_batch == 1 + x = te.placeholder((x_batch, M, K), name="x", dtype=dtype) + y = te.placeholder((y_batch, N, K), name="y", dtype=dtype) + + # use memoize to pickle the test data for next time use + @memoize("topi.tests.test_topi_batch_matmul") + def get_ref_data(): + a_np = np.random.randint(low=-128, high=127, size=(x_batch, M, K)).astype(dtype) + b_np = np.random.randint(low=-128, high=127, size=(y_batch, N, K)).astype(dtype) + c_np = tvm.topi.testing.batch_matmul(a_np, b_np, out_dtype=out_dtype) + return (a_np, b_np, c_np) + + # get the test data + a_np, b_np, c_np = get_ref_data() + + def check_device(device): + dev = tvm.device(device, 0) + if device == "cuda" and not tvm.contrib.nvcc.have_int8(dev.compute_version): + print("Skip because int8 intrinsics are not available") + return + + print("Running on target: %s" % device) + with tvm.target.Target(device): + out = topi.cuda.batch_matmul_int8(x, y, None, out_dtype) + s = topi.cuda.schedule_batch_matmul_int8([out]) + a = tvm.nd.array(a_np, dev) + b = tvm.nd.array(b_np, dev) + c = tvm.nd.array(np.zeros(get_const_tuple(out.shape), dtype=out_dtype), dev) + f = tvm.build(s, [x, y, out], device, name="batch_matmul_int8") + f(a, b, c) + tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5) + + for device in ["cuda"]: + check_device(device) + + @tvm.testing.uses_gpu def test_batch_matmul(): verify_batch_matmul(1, 1, 16, 16, 32) @@ -106,5 +146,18 @@ def test_batch_matmul(): verify_batch_matmul(5, 5, 16, 16, 32, dynamic=True) +@tvm.testing.requires_cuda +@tvm.testing.requires_gpu +def test_batch_matmul_int8(): + with Int8Fallback(): + verify_batch_matmul_int8(1, 1, 2, 3, 1) + verify_batch_matmul_int8(1, 1, 16, 24, 32) + verify_batch_matmul_int8(5, 5, 24, 16, 32) + verify_batch_matmul_int8(30, 30, 16, 20, 32) + verify_batch_matmul_int8(1, 5, 16, 16, 32) + verify_batch_matmul_int8(5, 1, 16, 16, 32) + + if __name__ == "__main__": test_batch_matmul() + test_batch_matmul_int8()