From f5ce7084abd054d5a8ece4b51502bb3e5d48a526 Mon Sep 17 00:00:00 2001 From: Chenfan Date: Wed, 30 Jun 2021 22:29:43 +0800 Subject: [PATCH] [Matmul] Add matmul op (#8234) * Add Matmul Op * Recover DenseAttrs * Add grad for matmul & some update * Update matmul cuda default schedule * Add blas support for matmul * Lint fix add update doc strings --- include/tvm/relay/attrs/nn.h | 26 ++++ python/tvm/relay/frontend/tensorflow.py | 19 ++- python/tvm/relay/frontend/tensorflow_ops.py | 20 ++- python/tvm/relay/op/_tensor_grad.py | 29 ++++ python/tvm/relay/op/nn/_nn.py | 63 +++++++- python/tvm/relay/op/nn/nn.py | 44 ++++++ python/tvm/relay/op/op_attrs.py | 5 + python/tvm/relay/op/strategy/cuda.py | 32 ++++ python/tvm/relay/op/strategy/generic.py | 36 +++++ python/tvm/relay/op/strategy/x86.py | 72 +++++++++ python/tvm/topi/cuda/dense.py | 50 +++++-- python/tvm/topi/generic/nn.py | 17 +++ python/tvm/topi/gpu/dense.py | 30 ++++ python/tvm/topi/nn/dense.py | 138 +++++++++++++++--- python/tvm/topi/x86/dense.py | 119 ++++++++++----- rust/tvm/src/ir/relay/attrs/nn.rs | 12 ++ src/relay/op/make_op.h | 3 + src/relay/op/nn/nn.cc | 40 ++++- src/relay/op/nn/nn.h | 72 +++++---- src/relay/qnn/op/dense.cc | 2 +- .../auto_scheduler_layout_rewrite.cc | 10 +- .../frontend/tensorflow/test_forward.py | 12 +- tests/python/relay/test_op_grad_level2.py | 17 +++ tests/python/relay/test_op_level1.py | 63 +++++++- tests/python/topi/python/test_topi_matmul.py | 26 ++++ 25 files changed, 842 insertions(+), 115 deletions(-) diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h index dc202674eb08..3c7574562676 100644 --- a/include/tvm/relay/attrs/nn.h +++ b/include/tvm/relay/attrs/nn.h @@ -961,6 +961,32 @@ struct AvgPool3DAttrs : public tvm::AttrsNode { } }; +/*! \brief Attributes for matmul operator */ +struct MatmulAttrs : public tvm::AttrsNode { + IndexExpr units; + DataType out_dtype; + bool transpose_a; + bool transpose_b; + tvm::String auto_scheduler_rewritten_layout; // The layout after auto-scheduler's layout rewrite + + TVM_DECLARE_ATTRS(MatmulAttrs, "relay.attrs.MatmulAttrs") { + TVM_ATTR_FIELD(units).describe("Number of hidden units of the dense transformation."); + + // use 0 bits to indicate none. + TVM_ATTR_FIELD(out_dtype) + .set_default(NullValue()) + .describe("Output data type, set to explicit type under mixed precision setting"); + + TVM_ATTR_FIELD(transpose_a) + .set_default(false) + .describe("Whether the first input tensor is in transposed format."); + + TVM_ATTR_FIELD(transpose_b) + .set_default(false) + .describe("Whether the second input tensor is in transposed format."); + } +}; + /*! \brief Attributes for dense operator */ struct DenseAttrs : public tvm::AttrsNode { IndexExpr units; diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 0bdec953a540..e297398ffe5b 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -44,6 +44,16 @@ __all__ = ["from_tensorflow"] +# The default configurations of Relay TensorFlow frontend. +TF_DEFAULT_CONFIGS = { + # By default, TVM converts `tf.matmul` to `transpose(weight) + nn.dense`, which introduces + # unnecessary overhead in weight transpose. Change this flag to False to directly convert to + # `nn.matmul` to get rid of the overhead. + # However, please note that `nn.matmul` is in experimental so it may have some performance + # issues. + "use_dense": True, +} + # compatible operators that do NOT require any conversion. 
_identity_list = [] @@ -1204,7 +1214,7 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): return func, self._params -def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None): +def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None, use_dense_op=True): """Load tensorflow graph which is a python tensorflow graph object into relay. The companion parameters will be handled automatically. @@ -1222,6 +1232,11 @@ def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None): outputs : List of output tensor names (Optional) if not specified then the last node is assumed as graph output. + use_dense_op : bool (Optional) = True + Ture to convert `tf.matmul` to `nn.dense`, else to `nn.matmul`. + The `nn.dense` op requires the data tensor to be non-transposed and weight tensor to be + transposed, may insert extra `transpose` to the original graph. + Returns ------- mod : tvm.IRModule @@ -1230,6 +1245,8 @@ def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None): params : dict of str to tvm.nd.NDArray Dict of converted parameters stored in tvm.nd.NDArray format """ + global TF_DEFAULT_CONFIGS + TF_DEFAULT_CONFIGS["use_dense"] = use_dense_op g = GraphProto() mod, params = g.from_tensorflow(graph, layout, shape, outputs) diff --git a/python/tvm/relay/frontend/tensorflow_ops.py b/python/tvm/relay/frontend/tensorflow_ops.py index be15f83faf0f..004174f076fd 100644 --- a/python/tvm/relay/frontend/tensorflow_ops.py +++ b/python/tvm/relay/frontend/tensorflow_ops.py @@ -1113,13 +1113,23 @@ def _impl(inputs, attr, params, mod): def _matmul(): def _impl(inputs, attr, params, mod): + from .tensorflow import TF_DEFAULT_CONFIGS + channels = _infer_channels(inputs[1], not attr["transpose_b"]) - if attr["transpose_a"]: - inputs[0] = _op.transpose(inputs[0], axes=(1, 0)) - if not attr["transpose_b"]: - inputs[1] = _op.transpose(inputs[1], axes=(1, 0)) + if TF_DEFAULT_CONFIGS["use_dense"]: + if attr["transpose_a"]: + inputs[0] = _op.transpose(inputs[0], axes=(1, 0)) + if not attr["transpose_b"]: + inputs[1] = _op.transpose(inputs[1], axes=(1, 0)) + return AttrCvt( + op_name="dense", + extras={"units": channels}, + ignores=["transpose_a", "transpose_b", "T"], + )(inputs, attr) return AttrCvt( - op_name="dense", extras={"units": channels}, ignores=["transpose_a", "transpose_b", "T"] + op_name="matmul", + extras={"units": channels}, + ignores=["T"], )(inputs, attr) return _impl diff --git a/python/tvm/relay/op/_tensor_grad.py b/python/tvm/relay/op/_tensor_grad.py index 09b1435aac0f..fa2772c1299f 100644 --- a/python/tvm/relay/op/_tensor_grad.py +++ b/python/tvm/relay/op/_tensor_grad.py @@ -554,6 +554,35 @@ def dense_grad(orig, grad): ] +@register_gradient("nn.matmul") +def matmul_grad(orig, grad): + """Returns [grad' @ tensor_b, tensor_a @ grad']""" + tensor_a, tensor_b = orig.args + if (orig.attrs["transpose_a"], orig.attrs["transpose_b"]) == (True, True): + return [ + collapse_sum_like( + _nn.matmul(tensor_b, grad, transpose_a=True, transpose_b=True), tensor_a + ), + collapse_sum_like( + _nn.matmul(grad, tensor_a, transpose_a=True, transpose_b=True), tensor_b + ), + ] + if (orig.attrs["transpose_a"], orig.attrs["transpose_b"]) == (True, False): + return [ + collapse_sum_like(_nn.matmul(tensor_b, grad, transpose_b=True), tensor_a), + collapse_sum_like(_nn.matmul(tensor_a, grad), tensor_b), + ] + if (orig.attrs["transpose_a"], orig.attrs["transpose_b"]) == (False, True): + # Keep using Dense op here for not involving extra ops + # TODO(jcf94): Merge all to 
nn.matmul when it is finally ready + return dense_grad(orig, grad) + # (orig.attrs["transpose_a"], orig.attrs["transpose_b"]) == (False, False) + return [ + collapse_sum_like(_nn.matmul(grad, tensor_b, transpose_b=True), tensor_a), + collapse_sum_like(_nn.matmul(tensor_a, grad, transpose_a=True), tensor_b), + ] + + @register_gradient("nn.batch_matmul") def batch_matmul_grad(orig, grad): """gradient for nn.batch_matmul: in einsum LHS_bik,RHS_bjk->RES_bij diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index 04d38ce39422..056cb5694a48 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -52,6 +52,32 @@ reg.register_pattern("nn.log_softmax", OpPattern.OPAQUE) +@reg.register_legalize("nn.matmul") +def legalize_matmul(attrs, inputs, types): + """Legalize matmul op. + + Parameters + ---------- + attrs : tvm.ir.Attrs + Attributes of current matmul + inputs : list of tvm.relay.Expr + The args of the Relay expr to be legalized + types : list of types + List of input and output types + + Returns + ------- + result : tvm.relay.Expr + The legalized expr + """ + return topi.nn.matmul_legalize(attrs, inputs, types) + + +# matmul +reg.register_strategy("nn.matmul", strategy.matmul_strategy) +reg.register_pattern("nn.matmul", reg.OpPattern.OUT_ELEMWISE_FUSABLE) + + @reg.register_legalize("nn.dense") def legalize_dense(attrs, inputs, types): """Legalize dense op. @@ -1160,21 +1186,44 @@ def batch_flatten_shape_func(attrs, inputs, _): @script -def _dense_shape_func(data_shape, weight_shape): - out = output_tensor((data_shape.shape[0],), "int64") +def _matmul_shape_func(tensor_a_shape, tensor_b_shape, transpose_a, transpose_b): + out = output_tensor((tensor_a_shape.shape[0],), "int64") for i in const_range(out.shape[0] - 1): - out[i] = data_shape[i] - out[out.shape[0] - 1] = weight_shape[0] + out[i] = tensor_a_shape[i] + if transpose_a: + out[out.shape[0] - 2] = out[out.shape[0] - 1] + out[out.shape[0] - 1] = tensor_b_shape[0] if transpose_b else tensor_b_shape[1] return out +@reg.register_shape_func("nn.matmul", False) +def matmul_shape_func(attrs, inputs, _): + """Shape function for matmul op.""" + ret = [ + _matmul_shape_func( + inputs[0], + inputs[1], + expr.IntImm("bool", attrs.transpose_a), + expr.IntImm("bool", attrs.transpose_b), + ) + ] + return ret + + @reg.register_shape_func("nn.dense", False) def dense_shape_func(attrs, inputs, _): + """Shape function for dense op. This is an alias of matmul_nt operator for data tensor in + non-transposed format and weight tensor in transposed format. """ - Shape function for dense op. - """ - ret = [_dense_shape_func(inputs[0], inputs[1])] + ret = [ + _matmul_shape_func( + inputs[0], + inputs[1], + expr.IntImm("bool", False), + expr.IntImm("bool", True), + ) + ] return ret diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py index bef899eeaaab..4c94102275bb 100644 --- a/python/tvm/relay/op/nn/nn.py +++ b/python/tvm/relay/op/nn/nn.py @@ -1471,6 +1471,50 @@ def bias_add(data, bias, axis=1): return _make.bias_add(data, bias, axis) +def matmul(tensor_a, tensor_b, units=None, out_dtype="", transpose_a=False, transpose_b=False): + """Matmul operator. + Applies a linear transformation. The A & B can be transposed. + + .. math:: + + `C = A * B` + + Parameters + ---------- + data : tvm.relay.Expr + The first input of the operator, + of shape `(d_1, d_2, ..., d_n, units_in)` or `(d_1, d_2, ..., units_in, d_n)`. 
+ + weight : tvm.relay.Expr + The second input expressions, 2-D matrix, + of shape `(units_in, units)` or `(units, units_in)`. + + units : Optional[int] + Number of hidden units of the matmul transformation. + + out_dtype : Optional[str] + Specifies the output data type for mixed precision matmul, + of shape `(d_1, d_2, ..., d_n, units)`. + + transpose_a : Optional[bool] = False + Whether the data tensor is in transposed format. + + transpose_b : Optional[bool] = False + Whether the weight tensor is in transposed format. + + Returns + ------- + result : tvm.relay.Expr + The computed result. + """ + # Since currently `nn.dense` has better topi schedule support, will prefer to use `dense` + # rather than `matmul` for better compatibility + if not transpose_a and transpose_b: + # TODO(jcf94): Remove this when `nn.matmul` is finnaly ready + return dense(tensor_a, tensor_b, units, out_dtype) + return _make.matmul(tensor_a, tensor_b, units, out_dtype, transpose_a, transpose_b) + + def dense(data, weight, units=None, out_dtype=""): """Dense operator. Applies a linear transformation diff --git a/python/tvm/relay/op/op_attrs.py b/python/tvm/relay/op/op_attrs.py index 74c4e2f1da49..780badc89fc4 100644 --- a/python/tvm/relay/op/op_attrs.py +++ b/python/tvm/relay/op/op_attrs.py @@ -64,6 +64,11 @@ class BiasAddAttrs(Attrs): """Atttribute of nn.bias_add""" +@tvm._ffi.register_object("relay.attrs.MatmulAttrs") +class MatmulAttrs(Attrs): + """Attributes for nn.matmul""" + + @tvm._ffi.register_object("relay.attrs.DenseAttrs") class DenseAttrs(Attrs): """Attributes for nn.dense""" diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index 683f3ecdb22b..dd265e4b4d5b 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -698,6 +698,38 @@ def conv1d_transpose_strategy_cuda(attrs, inputs, out_type, target): return strategy +@matmul_strategy.register(["cuda", "gpu"]) +def matmul_strategy_cuda(attrs, inputs, out_type, target): + """Matmul cuda strategy.""" + strategy = _op.OpStrategy() + + if is_auto_scheduler_enabled(): + strategy.add_implementation( + wrap_compute_matmul(topi.nn.matmul), + naive_schedule, + name="matmul.cuda", + ) + else: + logger.warning( + "Matmul is not optimized for cuda. Recommend to use cublas for better performance." 
+ ) + # Temporary use this as a basic schedule + strategy.add_implementation( + wrap_compute_matmul(topi.gpu.matmul_default), + wrap_topi_schedule(topi.gpu.schedule_matmul_default), + name="matmul_default.gpu", + ) + + if target.kind.name == "cuda" and "cublas" in target.libs: + strategy.add_implementation( + wrap_compute_matmul(topi.cuda.matmul_cublas), + wrap_topi_schedule(topi.cuda.schedule_matmul_cublas), + name="matmul_cublas.cuda", + plevel=25, + ) + return strategy + + @dense_strategy.register(["cuda", "gpu"]) def dense_strategy_cuda(attrs, inputs, out_type, target): """dense cuda strategy""" diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index bf317a95b711..32799e5ac73c 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -712,6 +712,42 @@ def dilation2d_strategy(attrs, inputs, out_type, target): return strategy +# matmul +def wrap_compute_matmul(topi_compute, need_auto_scheduler_layout=False): + """wrap matmul topi compute""" + + def _compute_matmul(attrs, inputs, out_type): + """Compute definition of matmul""" + out_dtype = attrs.out_dtype + out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype + args = [ + inputs[0], + inputs[1], + None, + out_dtype, + attrs.transpose_a, + attrs.transpose_b, + ] + if need_auto_scheduler_layout: + args.append(get_auto_scheduler_rewritten_layout(attrs)) + return [topi_compute(*args)] + + return _compute_matmul + + +@override_native_generic_func("matmul_strategy") +def matmul_strategy(attrs, inputs, out_type, target): + """matmul generic strategy""" + logger.warning("matmul is not optimized for this platform.") + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_matmul(topi.nn.matmul), + wrap_topi_schedule(topi.generic.schedule_matmul), + name="matmul.generic", + ) + return strategy + + # dense def wrap_compute_dense(topi_compute, need_auto_scheduler_layout=False): """wrap dense topi compute""" diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py index c21ec4d13906..d09d90a50d41 100644 --- a/python/tvm/relay/op/strategy/x86.py +++ b/python/tvm/relay/op/strategy/x86.py @@ -370,6 +370,78 @@ def conv1d_strategy_cpu(attrs, inputs, out_type, target): return strategy +@matmul_strategy.register("cpu") +def matmul_strategy_cpu(attrs, inputs, out_type, target): + """matmul x86 strategy""" + strategy = _op.OpStrategy() + + same_type = inputs[0].dtype == inputs[1].dtype == out_type.dtype + dtype = inputs[0].dtype + u8s8s32 = dtype == "uint8" and inputs[1].dtype == "int8" and out_type.dtype == "int32" + if "cblas" in target.libs: + length_before = len(strategy.specializations) if strategy.specializations else 0 + with SpecializedCondition(same_type and dtype in ["float32", "float64"]): + strategy.add_implementation( + wrap_compute_matmul(topi.x86.matmul_cblas), + wrap_topi_schedule(topi.x86.schedule_matmul_cblas), + name="matmul_cblas.x86", + plevel=13, + ) + length_after = len(strategy.specializations) if strategy.specializations else 0 + if length_before == length_after: + logger.warning( + "Currently cblas only support the data type to be float32 or float64. Skip." 
+ ) + if "mkl" in target.libs: + length_before = len(strategy.specializations) if strategy.specializations else 0 + with SpecializedCondition(same_type and dtype in ["float32", "float64"] or u8s8s32): + strategy.add_implementation( + wrap_compute_matmul(topi.x86.matmul_mkl), + wrap_topi_schedule(topi.x86.schedule_matmul_mkl), + name="matmul_mkl.x86", + plevel=14, + ) + length_after = len(strategy.specializations) if strategy.specializations else 0 + if length_before == length_after: + logger.warning( + "Currently mkl only support the data type to be float32, float64 or input with " + "uint8 and int8 while output wiht int32. Skip." + ) + if "mkldnn" in target.libs: + length_before = len(strategy.specializations) if strategy.specializations else 0 + with SpecializedCondition(same_type and dtype == "float32"): + strategy.add_implementation( + wrap_compute_matmul(topi.x86.matmul_mkldnn), + wrap_topi_schedule(topi.x86.schedule_matmul_mkldnn), + name="matmul_mkldnn.x86", + plevel=15, + ) + length_after = len(strategy.specializations) if strategy.specializations else 0 + if length_before == length_after: + logger.warning("Currently mkldnn only support the data type to be float32. Skip.") + + if is_auto_scheduler_enabled(): + strategy.add_implementation( + wrap_compute_matmul(topi.nn.matmul, need_auto_scheduler_layout=True), + naive_schedule, + name="matmul.generic", + plevel=11, + ) + else: + # If no cblas/mkl/mkldnn strategy choosed + if not strategy.specializations: + logger.warning( + "Matmul is not optimized for x86. " + "Recommend to use cblas/mkl/mkldnn for better performance." + ) + strategy.add_implementation( + wrap_compute_matmul(topi.nn.matmul), + naive_schedule, + name="matmul.generic", + ) + return strategy + + @dense_strategy.register("cpu") def dense_strategy_cpu(attrs, inputs, out_type, target): """dense x86 strategy""" diff --git a/python/tvm/topi/cuda/dense.py b/python/tvm/topi/cuda/dense.py index 0f410aef9afd..4035dce48aaf 100644 --- a/python/tvm/topi/cuda/dense.py +++ b/python/tvm/topi/cuda/dense.py @@ -28,18 +28,24 @@ logger = logging.getLogger("topi") -@autotvm.register_topi_compute("dense_cublas.cuda") -def dense_cublas(cfg, data, weight, bias=None, out_dtype=None): - """Dense operator on CUDA with CUBLAS""" - assert len(data.shape) == 2 and len(weight.shape) == 2, "only support 2-dim dense" +def _matmul_cublas_common( + cfg, + tensor_a, + tensor_b, + bias=None, + out_dtype=None, + transpose_a=False, + transpose_b=False, +): + assert len(tensor_a.shape) == 2 and len(tensor_b.shape) == 2, "only support 2-dim matmul" if bias is not None: assert len(bias.shape) == 1 if out_dtype is None: - out_dtype = data.dtype - assert out_dtype == data.dtype, "Mixed precision not supported." - batch, in_dim = get_const_tuple(data.shape) - out_dim, _ = get_const_tuple(weight.shape) - matmul = cublas.matmul(data, weight, False, True) + out_dtype = tensor_a.dtype + assert out_dtype == tensor_a.dtype, "Mixed precision not supported." 
+ batch, in_dim = get_const_tuple(tensor_a.shape) + out_dim, _ = get_const_tuple(tensor_b.shape) + matmul = cublas.matmul(tensor_a, tensor_b, transpose_a, transpose_b) if all(isinstance(d, int) for d in [batch, in_dim, out_dim]): cfg.add_flop(batch * in_dim * out_dim * 2) if bias is not None: @@ -49,6 +55,32 @@ def dense_cublas(cfg, data, weight, bias=None, out_dtype=None): return matmul +@autotvm.register_topi_compute("matmul_cublas.cuda") +def matmul_cublas( + cfg, + tensor_a, + tensor_b, + bias=None, + out_dtype=None, + transpose_a=False, + transpose_b=False, +): + """Matmul operator on CUDA with CUBLAS""" + return _matmul_cublas_common(cfg, tensor_a, tensor_b, bias, out_dtype, transpose_a, transpose_b) + + +@autotvm.register_topi_schedule("matmul_cublas.cuda") +def schedule_matmul_cublas(_, outs): + """Schedule matmul operator using CUBLAS""" + return generic.schedule_extern(outs) + + +@autotvm.register_topi_compute("dense_cublas.cuda") +def dense_cublas(cfg, data, weight, bias=None, out_dtype=None): + """Dense operator on CUDA with CUBLAS. This is an alias of matmul_nt operator.""" + return _matmul_cublas_common(cfg, data, weight, bias, out_dtype, False, True) + + @autotvm.register_topi_schedule("dense_cublas.cuda") def schedule_dense_cublas(_, outs): """Schedule dense operator using CUBLAS""" diff --git a/python/tvm/topi/generic/nn.py b/python/tvm/topi/generic/nn.py index 04d649037fef..1b3214154687 100644 --- a/python/tvm/topi/generic/nn.py +++ b/python/tvm/topi/generic/nn.py @@ -580,6 +580,23 @@ def schedule_fast_softmax(outs): return _default_schedule(outs, False) +def schedule_matmul(outs): + """Schedule for matmul + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of matmul + in the format of an array of tensors. + + Returns + ------- + sch: Schedule + The computation schedule for the op. + """ + return _default_schedule(outs, False) + + def schedule_dense(outs): """Schedule for dense diff --git a/python/tvm/topi/gpu/dense.py b/python/tvm/topi/gpu/dense.py index 806aa9f5ca44..b9009d3f3393 100644 --- a/python/tvm/topi/gpu/dense.py +++ b/python/tvm/topi/gpu/dense.py @@ -49,6 +49,36 @@ def _callback(op): return s +@autotvm.register_topi_compute("matmul_default.gpu") +def matmul_default( + cfg, + tensor_a, + tensor_b, + bias=None, + out_dtype=None, + transpose_a=False, + transpose_b=False, +): + """Matmul operator on GPU""" + return nn.matmul(tensor_a, tensor_b, bias, out_dtype, transpose_a, transpose_b) + + +@autotvm.register_topi_schedule("matmul_default.gpu") +def schedule_matmul_default(cfg, outs): + """Schedule matmul on GPU""" + outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs + s = te.create_schedule([x.op for x in outs]) + + def _callback(op): + if op.tag == "matmul": + # Temporary use this as a basic schedule for matmul + # TODO(jcf94): Add a more general schedule for matmul + _schedule_dense_small_batch(cfg, s, op.output(0)) + + traverse_inline(s, outs[0].op, _callback) + return s + + def _schedule_dense_small_batch(cfg, s, C): A, weights = C.op.input_tensors _, in_dim_weights = get_const_tuple(weights.shape) diff --git a/python/tvm/topi/nn/dense.py b/python/tvm/topi/nn/dense.py index e8ec476b86a5..58c458a7d676 100644 --- a/python/tvm/topi/nn/dense.py +++ b/python/tvm/topi/nn/dense.py @@ -21,15 +21,23 @@ from .. import tag -def dense(data, weight, bias=None, out_dtype=None, auto_scheduler_rewritten_layout=""): - """The default implementation of dense in topi. 
+def matmul( + tensor_a, + tensor_b, + bias=None, + out_dtype=None, + transpose_a=False, + transpose_b=False, + auto_scheduler_rewritten_layout="", +): + """The default implementation of matmul in topi. Parameters ---------- - data : tvm.te.Tensor + tensor_a : tvm.te.Tensor 2-D with shape [batch, in_dim] - weight : tvm.te.Tensor + tensor_b : tvm.te.Tensor 2-D with shape [out_dim, in_dim] bias : Optional[tvm.te.Tensor] @@ -38,7 +46,13 @@ def dense(data, weight, bias=None, out_dtype=None, auto_scheduler_rewritten_layo out_dtype : Optional[str] The output type. This is used for mixed precision. - auto_scheduler_rewritten_layout: str = "" + transpose_a : Optional[bool] = False + Whether the tensor_a is in transposed format. + + transpose_b : Optional[bool] = False + Whether the tensor_b is in transposed format. + + auto_scheduler_rewritten_layout: Optional[str] = "" The layout after auto-scheduler's layout rewrite pass. Returns @@ -46,42 +60,128 @@ def dense(data, weight, bias=None, out_dtype=None, auto_scheduler_rewritten_layo output : tvm.te.Tensor 2-D with shape [batch, out_dim] """ - assert len(data.shape) == 2, "only support 2-dim dense" + # TODO(jcf94): Add multi-dim support for tensor_a + assert len(tensor_a.shape) == 2, "only support 2-dim matmul" if bias is not None: assert len(bias.shape) == 1 if out_dtype is None: - out_dtype = data.dtype - batch, in_dim = data.shape + out_dtype = tensor_a.dtype + if transpose_a: + in_dim, batch = tensor_a.shape + else: + batch, in_dim = tensor_a.shape if auto_scheduler_rewritten_layout: # Infer shape for the rewritten layout out_dim, red_dim = auto_scheduler.get_shape_from_rewritten_layout( auto_scheduler_rewritten_layout, ["j", "k"] ) - auto_scheduler.remove_index_check(weight) + auto_scheduler.remove_index_check(tensor_b) + elif transpose_b: + out_dim, red_dim = tensor_b.shape else: - out_dim, red_dim = weight.shape + red_dim, out_dim = tensor_b.shape assert in_dim == red_dim k = te.reduce_axis((0, in_dim), name="k") - matmul = te.compute( + if (transpose_a, transpose_b) == (True, True): + compute_lambda = lambda i, j: te.sum( + tensor_a[k, i].astype(out_dtype) * tensor_b[j, k].astype(out_dtype), axis=k + ) + compute_name = "T_matmul_TT" + compute_tag = "matmul" + elif (transpose_a, transpose_b) == (True, False): + compute_lambda = lambda i, j: te.sum( + tensor_a[k, i].astype(out_dtype) * tensor_b[k, j].astype(out_dtype), axis=k + ) + compute_name = "T_matmul_TN" + compute_tag = "matmul" + elif (transpose_a, transpose_b) == (False, True): + compute_lambda = lambda i, j: te.sum( + tensor_a[i, k].astype(out_dtype) * tensor_b[j, k].astype(out_dtype), axis=k + ) + compute_name = "T_matmul_NT" + # TODO(jcf94): Remove `dense` when `matmul` is finally ready + compute_tag = "dense" + else: # (transpose_a, transpose_b) == (False, False): + compute_lambda = lambda i, j: te.sum( + tensor_a[i, k].astype(out_dtype) * tensor_b[k, j].astype(out_dtype), axis=k + ) + compute_name = "T_matmul_NN" + compute_tag = "matmul" + + mat = te.compute( (batch, out_dim), - lambda i, j: te.sum(data[i, k].astype(out_dtype) * weight[j, k].astype(out_dtype), axis=k), - name="T_dense", - tag="dense", - attrs={"layout_free_placeholders": [weight]}, + compute_lambda, + name=compute_name, + tag=compute_tag, + attrs={"layout_free_placeholders": [tensor_b]}, ) + if bias is not None: - matmul = te.compute( + mat = te.compute( (batch, out_dim), - lambda i, j: matmul[i, j] + bias[j].astype(out_dtype), + lambda i, j: mat[i, j] + bias[j].astype(out_dtype), tag=tag.BROADCAST, ) if 
auto_scheduler_rewritten_layout: - matmul = auto_scheduler.rewrite_compute_body(matmul, auto_scheduler_rewritten_layout) + mat = auto_scheduler.rewrite_compute_body(mat, auto_scheduler_rewritten_layout) + + return mat + + +@tvm.target.generic_func +def matmul_legalize(attrs, inputs, types): + """Legalizes matmul op. + + Parameters + ---------- + attrs : tvm.ir.Attrs + Attributes of current matmul + inputs : list of tvm.relay.Expr + The args of the Relay expr to be legalized + types : list of types + List of input and output types + + Returns + ------- + result : tvm.relay.Expr + The legalized expr + """ + # not to change by default + # pylint: disable=unused-argument + return None + + +def dense(data, weight, bias=None, out_dtype=None, auto_scheduler_rewritten_layout=""): + """The default implementation of dense in topi. + This is an alias of matmul_nt operator for data tensor in non-transposed format and weight + tensor in transposed format. + + Parameters + ---------- + data : tvm.te.Tensor + 2-D with shape [batch, in_dim] - return matmul + weight : tvm.te.Tensor + 2-D with shape [out_dim, in_dim] + + bias : Optional[tvm.te.Tensor] + 1-D with shape [out_dim] + + out_dtype : Optional[str] + The output type. This is used for mixed precision. + + auto_scheduler_rewritten_layout: str = "" + The layout after auto-scheduler's layout rewrite pass. + + Returns + ------- + output : tvm.te.Tensor + 2-D with shape [batch, out_dim] + """ + return matmul(data, weight, bias, out_dtype, False, True, auto_scheduler_rewritten_layout) @tvm.target.generic_func diff --git a/python/tvm/topi/x86/dense.py b/python/tvm/topi/x86/dense.py index 4fed4c16464e..189ac5bd34bd 100644 --- a/python/tvm/topi/x86/dense.py +++ b/python/tvm/topi/x86/dense.py @@ -28,7 +28,7 @@ from .utils import get_fp32_len from .injective import schedule_injective_from_existing -from .. import generic, tag +from .. 
import tag from ..utils import traverse_inline, get_const_tuple @@ -281,72 +281,121 @@ def _callback(op): return s -def dense_blas_common(cfg, data, weight, bias, out_dtype, lib): - """Compute dense using a BLAS library""" - M, K = get_const_tuple(data.shape) - N, _ = get_const_tuple(weight.shape) +def matmul_blas_common(cfg, tensor_a, tensor_b, bias, out_dtype, transpose_a, transpose_b, lib): + """Compute matmul/dense using a BLAS library""" + M, K = get_const_tuple(tensor_a.shape) + N, _ = get_const_tuple(tensor_b.shape) if isinstance(M, int) and isinstance(K, int) and isinstance(N, int): cfg.add_flop(M * K * N * 2) - if data.dtype == "uint8" and weight.dtype == "int8" and out_dtype == "int32": + if tensor_a.dtype == "uint8" and tensor_b.dtype == "int8" and out_dtype == "int32": if not hasattr(lib, "matmul_u8s8s32"): raise NotImplementedError( - f"Dense with {lib.__name__} for {data.dtype} is not supported " + f"Matmul/Dense with {lib.__name__} for {tensor_a.dtype} is not supported " "(matmulu8s8s32 not imlemented)" ) - C = lib.matmul_u8s8s32(data, weight, False, True, dtype=out_dtype) - elif data.dtype == "float32" or data.dtype == "float64": - C = lib.matmul(data, weight, False, True) + C = lib.matmul_u8s8s32(tensor_a, tensor_b, transpose_a, transpose_b, dtype=out_dtype) + elif tensor_a.dtype == "float32" or tensor_a.dtype == "float64": + C = lib.matmul(tensor_a, tensor_b, transpose_a, transpose_b) else: - raise NotImplementedError(f"Dense with {lib.__name__} for {data.dtype} is not supported") + raise NotImplementedError( + f"Matmul/Dense with {lib.__name__} for {tensor_a.dtype} is not supported" + ) if bias is not None: C = te.compute(C.shape, lambda i, j: C[i, j] + bias[j].astype(out_dtype), tag=tag.BROADCAST) return C +def schedule_matmul_blas_common(outs): + """Default matmul schedule for BLAS library""" + s = te.create_schedule([x.op for x in outs]) + te.schedule.AutoInlineInjective(s) + + for out in outs: + if "dense" not in out.op.tag and "matmul" not in out.op.tag: + schedule_injective_from_existing(s, out) + return s + + @autotvm.register_topi_compute("dense_cblas.x86") def dense_cblas(cfg, data, weight, bias=None, out_dtype=None): - """Compute dense using a cblas""" - return dense_blas_common(cfg, data, weight, bias, out_dtype, cblas) + """Compute dense using cblas. This is an alias of matmul_nt operator.""" + return matmul_blas_common(cfg, data, weight, bias, out_dtype, False, True, cblas) @autotvm.register_topi_schedule("dense_cblas.x86") def schedule_dense_cblas(_, outs): - """Create schedule for dense_cblas""" - return generic.schedule_extern(outs) + """Create schedule for dense_cblas. This is an alias of matmul_nt operator.""" + return schedule_matmul_blas_common(outs) @autotvm.register_topi_compute("dense_mkl.x86") def dense_mkl(cfg, data, weight, bias=None, out_dtype=None): - """Compute dense using mkl""" - return dense_blas_common(cfg, data, weight, bias, out_dtype, mkl) + """Compute dense using mkl. 
This is an alias of matmul_nt operator.""" + return matmul_blas_common(cfg, data, weight, bias, out_dtype, False, True, mkl) @autotvm.register_topi_schedule("dense_mkl.x86") def schedule_dense_mkl(_, outs): - """Create schedule for dense_mkl""" - # return generic.schedule_extern(outs) - s = te.create_schedule([x.op for x in outs]) - te.schedule.AutoInlineInjective(s) - - def _callback(op): - if "broadcast" in op.tag or "injective" in op.tag or "elemwise" in op.tag: - schedule_injective_from_existing(s, op.output(0)) - - # traverse_inline(s, outs[0].op, _callback) - for out in outs: - if "dense" not in out.op.name: - schedule_injective_from_existing(s, out) - return s + """Create schedule for dense_mkl. This is an alias of matmul_nt operator.""" + return schedule_matmul_blas_common(outs) @autotvm.register_topi_compute("dense_mkldnn.x86") def dense_mkldnn(cfg, data, weight, bias=None, out_dtype=None): - """Compute dense using mkldnn""" - return dense_blas_common(cfg, data, weight, bias, out_dtype, mkldnn) + """Compute dense using mkldnn. This is an alias of matmul_nt operator.""" + return matmul_blas_common(cfg, data, weight, bias, out_dtype, False, True, mkldnn) @autotvm.register_topi_schedule("dense_mkldnn.x86") def schedule_dense_mkldnn(_, outs): - """Create schedule for dense_mkldnn""" - return generic.schedule_extern(outs) + """Create schedule for dense_mkldnn. This is an alias of matmul_nt operator.""" + return schedule_matmul_blas_common(outs) + + +@autotvm.register_topi_compute("matmul_cblas.x86") +def matmul_cblas( + cfg, tensor_a, tensor_b, bias=None, out_dtype=None, transpose_a=False, transpose_b=False +): + """Compute matmul using cblas.""" + return matmul_blas_common( + cfg, tensor_a, tensor_b, bias, out_dtype, transpose_a, transpose_b, cblas + ) + + +@autotvm.register_topi_schedule("matmul_cblas.x86") +def schedule_matmul_cblas(_, outs): + """Create schedule for matmul_cblas.""" + return schedule_matmul_blas_common(outs) + + +@autotvm.register_topi_compute("matmul_mkl.x86") +def matmul_mkl( + cfg, tensor_a, tensor_b, bias=None, out_dtype=None, transpose_a=False, transpose_b=False +): + """Compute matmul using mkl.""" + return matmul_blas_common( + cfg, tensor_a, tensor_b, bias, out_dtype, transpose_a, transpose_b, mkl + ) + + +@autotvm.register_topi_schedule("matmul_mkl.x86") +def schedule_matmul_mkl(_, outs): + """Create schedule for matmul_mkl.""" + return schedule_matmul_blas_common(outs) + + +@autotvm.register_topi_compute("matmul_mkldnn.x86") +def matmul_mkldnn( + cfg, tensor_a, tensor_b, bias=None, out_dtype=None, transpose_a=False, transpose_b=False +): + """Compute matmul using mkldnn.""" + return matmul_blas_common( + cfg, tensor_a, tensor_b, bias, out_dtype, transpose_a, transpose_b, mkldnn + ) + + +@autotvm.register_topi_schedule("matmul_mkldnn.x86") +def schedule_matmul_mkldnn(_, outs): + """Create schedule for matmul_mkldnn.""" + return schedule_matmul_blas_common(outs) diff --git a/rust/tvm/src/ir/relay/attrs/nn.rs b/rust/tvm/src/ir/relay/attrs/nn.rs index f0137fa3cbcc..04320d1f6f85 100644 --- a/rust/tvm/src/ir/relay/attrs/nn.rs +++ b/rust/tvm/src/ir/relay/attrs/nn.rs @@ -54,6 +54,18 @@ pub struct BiasAddAttrsNode { pub axis: i32, } +#[repr(C)] +#[derive(Object, Debug)] +#[ref_name = "MatmulAttrs"] +#[type_key = "relay.attrs.MatmulAttrs"] +pub struct MatmulAttrsNode { + pub base: BaseAttrsNode, + pub units: IndexExpr, + pub out_dtype: DataType, + pub transpose_a: bool, + pub transpose_b: bool, +} + #[repr(C)] #[derive(Object, Debug)] #[ref_name = "DenseAttrs"] 
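Usage note: a minimal sketch of calling the new `relay.nn.matmul` exposed above. The shapes, dtype, and variable names here are illustrative assumptions, not taken from the patch; type inference is only used to show the expected result shape.

import tvm
from tvm import relay

# x is stored as (K, M) = (5, 10) and w as (K, N) = (5, 2); with transpose_a=True
# the op computes x^T @ w, so the inferred output should be Tensor[(10, 2), float32].
x = relay.var("x", shape=(5, 10), dtype="float32")
w = relay.var("w", shape=(5, 2), dtype="float32")
y = relay.nn.matmul(x, w, transpose_a=True)
mod = tvm.IRModule.from_expr(relay.Function([x, w], y))
mod = relay.transform.InferType()(mod)
print(mod["main"].checked_type)

As the Python wrapper in python/tvm/relay/op/nn/nn.py shows, the (transpose_a=False, transpose_b=True) combination currently falls back to `nn.dense`, which has the better-tuned schedules for now.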
diff --git a/src/relay/op/make_op.h b/src/relay/op/make_op.h index 81de4bc90ad7..6f4db5ab268a 100644 --- a/src/relay/op/make_op.h +++ b/src/relay/op/make_op.h @@ -44,6 +44,9 @@ Expr MakeClip(Expr a, double a_min, double a_max); Expr MakeConcatenate(Expr data, int axis); +Expr MakeMatmul(Expr tensor_a, Expr tensor_b, IndexExpr units, DataType out_dtype, bool transpose_a, + bool transpose_b); + Expr MakeDense(Expr data, Expr weight, IndexExpr units, DataType out_dtype); Expr MakeBatchMatmul(Expr lhs, Expr rhs, DataType out_dtype); diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc index 489be15e1643..4eaa12b17d7b 100644 --- a/src/relay/op/nn/nn.cc +++ b/src/relay/op/nn/nn.cc @@ -162,7 +162,39 @@ Useful for .set_support_level(3) .add_type_rel("FIFOBuffer", FIFOBufferRel); -// relay.nn.dense +// ------------------- relay.nn.matmul +TVM_REGISTER_NODE_TYPE(MatmulAttrs); + +Expr MakeMatmul(Expr tensor_a, Expr tensor_b, IndexExpr units, DataType out_dtype, bool transpose_a, + bool transpose_b) { + auto attrs = make_object(); + attrs->units = units; + attrs->out_dtype = out_dtype; + attrs->transpose_a = transpose_a; + attrs->transpose_b = transpose_b; + static const Op& matmul_op = Op::Get("nn.matmul"); + return Call(matmul_op, {tensor_a, tensor_b}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relay.op.nn._make.matmul").set_body_typed(MakeMatmul); + +RELAY_REGISTER_OP("nn.matmul") + .describe(R"code(Applies a linear transformation: :math:`C = A * B`. A & B can be transposed. + +- **tensor_a**: `(x1, x2, ..., xn, input_dim)` or `(x1, x2, ..., input_dim, xn)` +- **tensor_b**: `(input_dim, units)` or `(units, input_dim)` +- **out**: `(x1, x2, ..., xn, units)`. + +)code" TVM_ADD_FILELINE) + .set_attrs_type() + .set_num_inputs(2) + .add_argument("tensor_a", "nD Tensor", "The first input Tensor.") + .add_argument("tensor_b", "2D Tensor", "The second input Tensor.") + .set_support_level(1) + .add_type_rel("Matmul", MatmulRel); +// ------------------- relay.nn.matmul + +// ------------------- relay.nn.dense TVM_REGISTER_NODE_TYPE(DenseAttrs); // Positional relay function to create dense operator used by frontend FFI. @@ -189,9 +221,10 @@ RELAY_REGISTER_OP("nn.dense") .add_argument("data", "nD Tensor", "Input data.") .add_argument("weight", "2D Tensor", "Weight matrix.") .set_support_level(1) - .add_type_rel("Dense", DenseRel); + .add_type_rel("Dense", MatmulRel); +// ------------------- relay.nn.dense -// relay.nn.contrib_dense_pack +// ------------------- relay.nn.contrib_dense_pack // Positional relay function to create dense_pack operator used by frontend FFI. 
Expr MakeDensePack(Expr data, Expr weight, IndexExpr units, DataType out_dtype) { auto attrs = make_object(); @@ -217,6 +250,7 @@ RELAY_REGISTER_OP("nn.contrib_dense_pack") .add_argument("weight", "3D Tensor", "Packed weight matrix.") .set_support_level(10) .add_type_rel("DensePack", DensePackRel); +// ------------------- relay.nn.contrib_dense_pack // relay.leaky_relu TVM_REGISTER_NODE_TYPE(LeakyReluAttrs); diff --git a/src/relay/op/nn/nn.h b/src/relay/op/nn/nn.h index 1ac800f357b0..29f200c67c59 100644 --- a/src/relay/op/nn/nn.h +++ b/src/relay/op/nn/nn.h @@ -36,31 +36,44 @@ namespace tvm { namespace relay { template -bool DenseRel(const Array& types, int num_inputs, const Attrs& attrs, - const TypeReporter& reporter) { +bool MatmulRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { ICHECK_EQ(types.size(), 3); - const auto* data = types[0].as(); - const auto* weight = types[1].as(); - if (data == nullptr) return false; + const auto* tensor_a = types[0].as(); + const auto* tensor_b = types[1].as(); + if (tensor_a == nullptr) return false; + ICHECK(static_cast(tensor_a->shape.size()) != 0); const AttrType* param = attrs.as(); ICHECK(param != nullptr); + // Default set to dense layout + bool transpose_a = false; + bool transpose_b = true; + const auto& mattrs = attrs.as(); + if (mattrs != nullptr) { + transpose_a = mattrs->transpose_a; + transpose_b = mattrs->transpose_b; + } - ICHECK(static_cast(data->shape.size()) != 0); - - Array dshape = data->shape; + const Array& dshape = tensor_a->shape; Array oshape = dshape; + tvm::PrimExpr reduce = dshape[dshape.size() - 1]; + if (transpose_a) { + reduce = dshape[dshape.size() - 2]; + oshape.Set((oshape.size() - 2), dshape[oshape.size() - 1]); + } if (param->units.defined()) { - // validate the weight shape is proper if defined - // Assign weight type - Array wshape({param->units, dshape[dshape.size() - 1]}); - // It is possible for weight to be nullptr in which case we will use - // data dtype as the weight dtype. However if weight dtype is explicitly + // validate the tensor_b shape is proper if defined + // Assign tensor_b type + const Array& wshape = transpose_b ? Array({param->units, reduce}) + : Array({reduce, param->units}); + // It is possible for tensor_b to be nullptr in which case we will use + // data dtype as the tensor_b dtype. However if tensor_b dtype is explicitly // present we will use that. - auto weight_dtype = (weight == nullptr ? data->dtype : weight->dtype); + auto tensor_b_dtype = (tensor_b == nullptr ? tensor_a->dtype : tensor_b->dtype); if (param->auto_scheduler_rewritten_layout.size() == 0) { // Normal case: assign result to reporter - reporter->Assign(types[1], TensorType(wshape, weight_dtype)); + reporter->Assign(types[1], TensorType(wshape, tensor_b_dtype)); } else { // If the layout is rewritten by auto-scheduler, // we just forcly apply the layout provided by auto-scheduler and @@ -69,31 +82,32 @@ bool DenseRel(const Array& types, int num_inputs, const Attrs& attrs, } oshape.Set((oshape.size() - 1), param->units); } else { - if (weight == nullptr) return false; - Array wshape = weight->shape; - // When weight's layout has been rewritten, figure it out based on the + if (tensor_b == nullptr) return false; + const Array& wshape = tensor_b->shape; + // When tensor_b's layout has been rewritten, figure it out based on the // total number of elements and input dimensions. 
if (param->auto_scheduler_rewritten_layout.size() != 0) { - PrimExpr weight_elements = 1; + PrimExpr tensor_b_elements = 1; for (size_t i = 0; i < wshape.size(); i++) { - weight_elements = weight_elements * wshape[i]; + tensor_b_elements = tensor_b_elements * wshape[i]; } - oshape.Set(oshape.size() - 1, weight_elements / dshape[dshape.size() - 1]); - // Otherwise just pull it out of the weight shape directly. + oshape.Set(oshape.size() - 1, tensor_b_elements / dshape[dshape.size() - 1]); + // Otherwise just pull it out of the tensor_b shape directly. } else { - ICHECK(static_cast(weight->shape.size()) == 2); - if (!data->shape.back().as()) { - ICHECK(reporter->AssertEQ(data->shape[data->shape.size() - 1], weight->shape[1])) - << "DenseRel: input dimension doesn't match," - << " data shape=" << data->shape << ", weight shape=" << weight->shape; + ICHECK(static_cast(tensor_b->shape.size()) == 2); + if (!tensor_a->shape.back().as()) { + ICHECK((transpose_b && reporter->AssertEQ(reduce, tensor_b->shape[1])) || + (!transpose_b && reporter->AssertEQ(reduce, tensor_b->shape[0]))) + << "MatmulRel: input dimension doesn't match," + << " tensor_a shape=" << tensor_a->shape << ", tensor_b shape=" << tensor_b->shape; } - oshape.Set((oshape.size() - 1), wshape[0]); + oshape.Set((oshape.size() - 1), transpose_b ? wshape[0] : wshape[1]); } } DataType out_dtype = param->out_dtype; if (out_dtype.bits() == 0) { - out_dtype = data->dtype; + out_dtype = tensor_a->dtype; } // assign output type reporter->Assign(types[2], TensorType(oshape, out_dtype)); diff --git a/src/relay/qnn/op/dense.cc b/src/relay/qnn/op/dense.cc index 6284524bff27..592fa77aed77 100644 --- a/src/relay/qnn/op/dense.cc +++ b/src/relay/qnn/op/dense.cc @@ -70,7 +70,7 @@ bool QnnDenseRel(const Array& types, int num_inputs, const Attrs& attrs, // Collect the input tensor and output tensor devoid of scale and zero points to reuse Relay // Dense infer type function. Array tensor_types = {types[0], types[1], types[6]}; - return DenseRel(tensor_types, 3, attrs, reporter); + return MatmulRel(tensor_types, 3, attrs, reporter); } // Positional relay function to create quantized dense operator used by frontend FFI. 
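Usage note: a minimal TE-level sketch of the `topi.nn.matmul` compute defined above, built with a plain `te.create_schedule` rather than a tuned schedule; the sizes and the llvm target are illustrative assumptions.

import numpy as np
import tvm
from tvm import te, topi

M, N, K = 10, 2, 5
A = te.placeholder((K, M), name="A")        # tensor_a stored transposed (transpose_a=True)
B = te.placeholder((K, N), name="B")        # tensor_b in non-transposed layout
C = topi.nn.matmul(A, B, transpose_a=True)  # the "T_matmul_TN" compute from topi/nn/dense.py
s = te.create_schedule(C.op)
f = tvm.build(s, [A, B, C], "llvm")

dev = tvm.cpu(0)
a = tvm.nd.array(np.random.rand(K, M).astype("float32"), dev)
b = tvm.nd.array(np.random.rand(K, N).astype("float32"), dev)
c = tvm.nd.array(np.zeros((M, N), dtype="float32"), dev)
f(a, b, c)
np.testing.assert_allclose(c.numpy(), a.numpy().T @ b.numpy(), rtol=1e-5)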
diff --git a/src/relay/transforms/auto_scheduler_layout_rewrite.cc b/src/relay/transforms/auto_scheduler_layout_rewrite.cc index edc4119ce859..da0bd35a332a 100644 --- a/src/relay/transforms/auto_scheduler_layout_rewrite.cc +++ b/src/relay/transforms/auto_scheduler_layout_rewrite.cc @@ -87,6 +87,8 @@ class FuncMutator : public ExprMutator { updated_attrs = CopyAttrsWithNewLayout(pattr, new_layout); } else if (auto pattr = call->attrs.as()) { updated_attrs = CopyAttrsWithNewLayout(pattr, new_layout); + } else if (auto pattr = call->attrs.as()) { + updated_attrs = CopyAttrsWithNewLayout(pattr, new_layout); } else if (auto pattr = call->attrs.as()) { updated_attrs = CopyAttrsWithNewLayout(pattr, new_layout); } else if (auto pattr = call->attrs.as()) { @@ -103,9 +105,9 @@ class FuncMutator : public ExprMutator { std::deque ori_layouts_queue_; std::deque new_layouts_queue_; - std::vector target_ops_{"nn.conv2d", "nn.conv3d", - "nn.contrib_conv2d_winograd_without_weight_transform", - "nn.dense", "nn.batch_matmul"}; + std::vector target_ops_{ + "nn.conv2d", "nn.conv3d", "nn.contrib_conv2d_winograd_without_weight_transform", + "nn.matmul", "nn.dense", "nn.batch_matmul"}; }; Expr AutoSchedulerLayoutRewriter::VisitExpr_(const CallNode* n) { @@ -166,6 +168,8 @@ TVM_REGISTER_GLOBAL("relay.attrs.get_auto_scheduler_rewritten_layout") return attrs.as()->auto_scheduler_rewritten_layout; } else if (attrs->IsInstance()) { return attrs.as()->auto_scheduler_rewritten_layout; + } else if (attrs->IsInstance()) { + return attrs.as()->auto_scheduler_rewritten_layout; } else if (attrs->IsInstance()) { return attrs.as()->auto_scheduler_rewritten_layout; } else if (attrs->IsInstance()) { diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 26c9278cb733..c942411471cd 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -117,6 +117,7 @@ def run_tvm_graph( disabled_pass=None, ignore_in_shape=False, serialize=False, + use_dense_op=True, ): """Generic function to compile on relay and execute on tvm""" input_data = convert_to_list(input_data) @@ -131,7 +132,11 @@ def run_tvm_graph( e: i.shape if hasattr(i, "shape") else () for e, i in zip(input_node, input_data) } mod, params = relay.frontend.from_tensorflow( - graph_def, layout=layout, shape=shape_dict, outputs=out_names + graph_def, + layout=layout, + shape=shape_dict, + outputs=out_names, + use_dense_op=use_dense_op, ) dev = tvm.device(target, 0) if mode == "debug": @@ -213,6 +218,7 @@ def compare_tf_with_tvm( add_shapes_to_graph_def=True, targets=None, ignore_in_shape=False, + use_dense_op=True, ): """Generic function to generate and compare tensorflow and TVM output""" @@ -260,6 +266,7 @@ def name_without_num(name): mode=mode, cuda_layout=cuda_layout, ignore_in_shape=ignore_in_shape, + use_dense_op=use_dense_op, ) # since the names from tensorflow and relay runs are not exactly same, # first len(tf_output) will be compared @@ -1810,7 +1817,8 @@ def _test_matmul(i, j, k, dtype, outer=None): A_np = np.random.uniform(high=5.0, size=A_shape).astype(dtype) B_np = np.random.uniform(high=5.0, size=B_shape).astype(dtype) - compare_tf_with_tvm([A_np, B_np], [A.name, B.name], result.name) + compare_tf_with_tvm([A_np, B_np], [A.name, B.name], result.name, use_dense_op=True) + compare_tf_with_tvm([A_np, B_np], [A.name, B.name], result.name, use_dense_op=False) def test_forward_matmul(): diff --git a/tests/python/relay/test_op_grad_level2.py 
b/tests/python/relay/test_op_grad_level2.py index 686fd9834640..c8a94683eec4 100644 --- a/tests/python/relay/test_op_grad_level2.py +++ b/tests/python/relay/test_op_grad_level2.py @@ -199,6 +199,22 @@ def test_dense_grad(): verify_dense_grad((5, 4), (3, 4)) +def verify_matmul_grad(a_shape, b_shape, transpose_a, transpose_b): + tensor_a = relay.var("tensor_a", relay.TensorType(a_shape, "float32")) + tensor_b = relay.var("tensor_b", relay.TensorType(b_shape, "float32")) + fwd_func = relay.Function( + [tensor_a, tensor_b], + relay.nn.matmul(tensor_a, tensor_b, transpose_a=transpose_a, transpose_b=transpose_b), + ) + check_grad(fwd_func) + + +def test_matmul_grad(): + verify_matmul_grad((1, 8), (8, 16), False, False) + verify_matmul_grad((4, 1), (4, 3), True, False) + verify_matmul_grad((4, 5), (3, 4), True, True) + + def verify_batch_flatten_grad(d_shape): data = relay.var("data", relay.TensorType(d_shape, "float32")) fwd_func = relay.Function([data], relay.nn.batch_flatten(data)) @@ -216,4 +232,5 @@ def test_batch_flatten_grad(): test_global_avg_pool2d_grad() test_conv2d_grad() test_dense_grad() + test_matmul_grad() test_batch_flatten_grad() diff --git a/tests/python/relay/test_op_level1.py b/tests/python/relay/test_op_level1.py index 89475ac7df86..cbc3e7fbd1e5 100644 --- a/tests/python/relay/test_op_level1.py +++ b/tests/python/relay/test_op_level1.py @@ -409,6 +409,66 @@ def test_batch_norm(): ) +@pytest.mark.xfail +def test_matmul_type_check(): + dtype = "float16" + n, c, h, w = 2, 2, 2, 2 + x = relay.var("x", relay.TensorType((n, c, h, w), dtype)) + # it should fail since it does not match with m(2) + mismatch_w = 3 + w = relay.var("w", relay.TensorType((mismatch_w, 2), dtype)) + y = relay.nn.matmul(x, w) + yy = run_infer_type(y) + + +@tvm.testing.uses_gpu +def test_matmul(): + for dtype in ["float16", "float32"]: + # Matmul accuracy for float16 is poor + if dtype == "float16": + continue + n, c, h, w = te.size_var("n"), te.size_var("c"), te.size_var("h"), te.size_var("w") + x = relay.var("x", relay.TensorType((n, c, h, w), dtype)) + w = relay.var("w", relay.TensorType((2, w), dtype)) + y = relay.nn.matmul(x, w, units=2, transpose_b=True) + assert "units=2" in y.astext() + yy = run_infer_type(y) + assert yy.checked_type == relay.TensorType((n, c, h, 2), dtype) + + n, c, h, w = te.size_var("n"), te.size_var("c"), te.size_var("h"), 2 + x = relay.var("x", relay.TensorType((n, c, w, h), dtype)) + wh, ww = te.size_var("wh"), te.size_var("ww") + w = relay.var("w", relay.TensorType((wh, ww), dtype)) + y = relay.nn.matmul(x, w, transpose_a=True) + yy = run_infer_type(y) + assert yy.checked_type == relay.TensorType((n, c, h, ww), dtype) + + n, c, h, w = te.size_var("n"), te.size_var("c"), te.size_var("h"), 2 + x = relay.var("x", relay.TensorType((n, c, h, w), dtype)) + w = relay.var("w", relay.IncompleteType()) + y = relay.nn.matmul(x, w, units=2) + yy = run_infer_type(y) + assert yy.checked_type == relay.TensorType((n, c, h, 2), dtype) + + x = relay.var("x", shape=(5, 10), dtype=dtype) + w = relay.var("w", shape=(5, 2), dtype=dtype) + z = relay.nn.matmul(x, w, transpose_a=True) + + # Check result. 
+ func = relay.Function([x, w], z) + x_data = np.random.rand(5, 10).astype(dtype) + w_data = np.random.rand(5, 2).astype(dtype) + ref_res = np.dot(x_data.transpose(), w_data) + + for target, dev in tvm.testing.enabled_targets(): + intrp1 = relay.create_executor("graph", device=dev, target=target) + intrp2 = relay.create_executor("debug", device=dev, target=target) + op_res1 = intrp1.evaluate(func)(x_data, w_data) + tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5) + op_res2 = intrp2.evaluate(func)(x_data, w_data) + tvm.testing.assert_allclose(op_res2.numpy(), ref_res, rtol=1e-5) + + @pytest.mark.xfail def test_dense_type_check(): dtype = "float16" @@ -426,7 +486,7 @@ def test_dense(): for dtype in ["float16", "float32"]: # Dense accuracy for float16 is poor if dtype == "float16": - return + continue n, c, h, w = te.size_var("n"), te.size_var("c"), te.size_var("h"), te.size_var("w") x = relay.var("x", relay.TensorType((n, c, h, w), dtype)) w = relay.var("w", relay.TensorType((2, w), dtype)) @@ -506,6 +566,7 @@ def test_bitserial_dense(): test_log_softmax() test_dropout() test_batch_norm() + test_matmul() test_dense() test_bitserial_dense() test_dense_dtype() diff --git a/tests/python/topi/python/test_topi_matmul.py b/tests/python/topi/python/test_topi_matmul.py index e5a21a3ad3b7..de2d4d3c4c8e 100644 --- a/tests/python/topi/python/test_topi_matmul.py +++ b/tests/python/topi/python/test_topi_matmul.py @@ -41,6 +41,31 @@ def with_tvm(lam, *args): return out_nd.numpy() +def verify_nn_matmul(sa, sb, transp_a, transp_b): + a = np.random.uniform(low=-1.0, high=1.0, size=sa).astype(np.float32) + b = np.random.uniform(low=-1.0, high=1.0, size=sb).astype(np.float32) + c1 = np.matmul(np.transpose(a) if transp_a else a, np.transpose(b) if transp_b else b) + c2 = with_tvm( + lambda A, B: topi.nn.matmul(A, B, transpose_a=transp_a, transpose_b=transp_b), + a, + b, + ) + tvm.testing.assert_allclose(c1, c2, rtol=1e-5, atol=1e-5) + + +def test_nn_matmul(): + verify_nn_matmul((1, 1), (1, 1), False, False) + verify_nn_matmul((1, 1), (1, 1), True, True) + verify_nn_matmul((2, 2), (2, 2), False, False) + verify_nn_matmul((2, 2), (2, 2), True, True) + verify_nn_matmul((2, 3), (3, 5), False, False) + verify_nn_matmul((5, 3), (3, 2), False, False) + verify_nn_matmul((3, 5), (3, 2), True, False) + verify_nn_matmul((3, 5), (2, 3), True, True) + verify_nn_matmul((3, 5), (3, 2), True, False) + verify_nn_matmul((5, 3), (2, 3), False, True) + + def verify_matmul(sa, sb, transp_a, transp_b): a = np.random.uniform(low=-1.0, high=1.0, size=sa).astype(np.float32) b = np.random.uniform(low=-1.0, high=1.0, size=sb).astype(np.float32) @@ -79,5 +104,6 @@ def test_tensordot(): if __name__ == "__main__": + test_nn_matmul() test_matmul() test_tensordot()
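Usage note: the new `use_dense_op` frontend flag can be exercised end to end roughly as below. This is a sketch assuming the TF 1.x graph API already used in test_forward.py; the node names, shapes, and placeholder setup are illustrative assumptions.

import tensorflow.compat.v1 as tf
from tvm import relay

with tf.Graph().as_default() as graph:
    a = tf.placeholder(tf.float32, shape=(4, 8), name="a")
    b = tf.placeholder(tf.float32, shape=(8, 3), name="b")
    tf.matmul(a, b, name="matmul")
    graph_def = graph.as_graph_def()

# use_dense_op=False asks the frontend to emit nn.matmul directly instead of
# transpose + nn.dense (experimental path, see TF_DEFAULT_CONFIGS above).
mod, params = relay.frontend.from_tensorflow(
    graph_def, shape={"a": (4, 8), "b": (8, 3)}, use_dense_op=False
)
print(mod)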