From 10f0c114579bdfb6d3a1a9bd1f9fbb6c73bc06e7 Mon Sep 17 00:00:00 2001
From: "chengfan.jcf"
Date: Mon, 21 Jun 2021 14:46:33 +0800
Subject: [PATCH] Update all "nn.dense" to "nn.matmul"

Update grad for matmul op
Recover DenseAttrs
Bug fix for qnn.dense
Code refine
---
 include/tvm/relay/attrs/nn.h                  | 19 +++++++
 python/tvm/contrib/target/onnx.py             |  2 +-
 python/tvm/relay/frontend/tensorflow_ops.py   |  9 ++++
 python/tvm/relay/op/_tensor_grad.py           | 42 ++++++++++++---
 .../tvm/relay/op/contrib/arm_compute_lib.py   |  8 +--
 python/tvm/relay/op/contrib/bnns.py           |  6 +--
 python/tvm/relay/op/contrib/dnnl.py           |  2 +-
 python/tvm/relay/op/contrib/tensorrt.py       |  4 +-
 python/tvm/relay/op/nn/nn.py                  | 14 +++--
 python/tvm/relay/op/op_attrs.py               |  8 +++
 python/tvm/relay/op/strategy/bifrost.py       |  6 ++-
 python/tvm/relay/op/strategy/cuda.py          |  8 +--
 python/tvm/relay/op/strategy/generic.py       |  2 +-
 python/tvm/relay/op/strategy/hls.py           |  6 ++-
 python/tvm/relay/op/strategy/mali.py          |  6 ++-
 python/tvm/relay/op/strategy/rocm.py          |  4 +-
 python/tvm/relay/op/strategy/x86.py           | 11 ++--
 python/tvm/relay/qnn/op/legalizations.py      | 18 ++++++-
 python/tvm/relay/quantize/_annotate.py        |  2 +-
 python/tvm/relay/transform/mixed_precision.py |  2 +-
 python/tvm/topi/cuda/tensorcore_alter_op.py   |  1 +
 python/tvm/topi/rocm/dense.py                 |  8 +--
 python/tvm/topi/x86/dense.py                  | 51 ++++++++++---------
 python/tvm/topi/x86/dense_alter_op.py         |  2 +-
 rust/tvm/src/ir/relay/attrs/nn.rs             | 10 ++++
 src/relay/analysis/mac_count.cc               | 34 +++++++++----
 .../contrib/arm_compute_lib/codegen.cc        |  4 +-
 src/relay/backend/contrib/bnns/codegen.cc     |  4 +-
 src/relay/backend/contrib/dnnl/codegen.cc     |  2 +-
 src/relay/op/nn/nn.cc                         |  1 +
 src/relay/op/nn/nn.h                          | 10 +++-
 src/relay/qnn/op/dense.cc                     | 14 ++---
 src/relay/quantize/realize.cc                 |  2 +-
 .../auto_scheduler_layout_rewrite.cc          | 10 ++--
 .../transforms/combine_parallel_dense.cc      |  8 +--
 src/relay/transforms/convert_sparse_dense.cc  |  4 +-
 src/relay/transforms/simplify_fc_transpose.cc |  4 +-
 src/relay/transforms/to_mixed_precision.cc    |  2 +
 .../contrib/arm_compute_lib/acl_runtime.cc    |  2 +-
 src/runtime/contrib/bnns/bnns_json_runtime.cc |  2 +-
 src/runtime/contrib/dnnl/dnnl_json_runtime.cc |  2 +-
 src/runtime/contrib/tensorrt/tensorrt_ops.cc  |  2 +-
 .../test_arm_compute_lib/test_dense.py        |  8 ++-
 tests/python/contrib/test_bnns/test_dense.py  |  4 +-
 tests/python/frontend/pytorch/test_forward.py | 12 +++--
 .../relay/test_autotvm_task_extraction.py     |  4 +-
 tests/python/relay/test_dataflow_pattern.py   |  2 +-
 tests/python/relay/test_ir_parser.py          |  2 +-
 tests/python/relay/test_layer_count.py        |  2 +-
 tests/python/relay/test_op_grad_level2.py     | 17 +++++++
 .../python/relay/test_pass_alter_op_layout.py |  2 +-
 tests/python/relay/test_pass_auto_quantize.py |  2 +-
 tests/python/relay/test_pass_gradient.py      |  2 +-
 .../relay/test_pass_legalize_tensorcore.py    |  4 +-
 54 files changed, 291 insertions(+), 128 deletions(-)

diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h
index bac82b2f8373..8824f13d6a0a 100644
--- a/include/tvm/relay/attrs/nn.h
+++ b/include/tvm/relay/attrs/nn.h
@@ -987,6 +987,25 @@ struct MatmulAttrs : public tvm::AttrsNode<MatmulAttrs> {
   }
 };
 
+/*!
+ * \brief Attributes for dense operator.
+ * \note This attr is still used by `qnn.dense`. TODO: Rewrite `qnn.dense` to `qnn.matmul`.
+ */
+struct DenseAttrs : public tvm::AttrsNode<DenseAttrs> {
+  IndexExpr units;
+  tvm::String auto_scheduler_rewritten_layout;  // The layout after auto-scheduler's layout rewrite
+  DataType out_dtype;
+
+  TVM_DECLARE_ATTRS(DenseAttrs, "relay.attrs.DenseAttrs") {
+    TVM_ATTR_FIELD(units).describe("Number of hidden units of the dense transformation.");
+
+    // use 0 bits to indicate none.
+    TVM_ATTR_FIELD(out_dtype)
+        .set_default(NullValue<DataType>())
+        .describe("Output data type, set to explicit type under mixed precision setting");
+  }
+};
+
 /*! \brief Attributes for batch matmul operator */
 struct BatchMatmulAttrs : public tvm::AttrsNode<BatchMatmulAttrs> {
   tvm::String auto_scheduler_rewritten_layout;  // The layout after auto-scheduler's layout rewrite
diff --git a/python/tvm/contrib/target/onnx.py b/python/tvm/contrib/target/onnx.py
index a38bcf5bcefa..d60a67f49d8a 100644
--- a/python/tvm/contrib/target/onnx.py
+++ b/python/tvm/contrib/target/onnx.py
@@ -623,7 +623,7 @@ def convert_attributes(cls, attrs):
     "add": rename("Add"),
     "nn.relu": rename("Relu"),
     "transpose": Transpose,
-    "nn.dense": MatMul,
+    "nn.matmul": MatMul,
     "nn.max_pool2d": MaxPool,
     "nn.batch_flatten": Flatten,
     "multiply": rename("Mul"),
diff --git a/python/tvm/relay/frontend/tensorflow_ops.py b/python/tvm/relay/frontend/tensorflow_ops.py
index 45486c163cbf..e6fe0a699b4a 100644
--- a/python/tvm/relay/frontend/tensorflow_ops.py
+++ b/python/tvm/relay/frontend/tensorflow_ops.py
@@ -1121,6 +1121,15 @@ def _impl(inputs, attr, params, mod):
             inputs[0] = _op.transpose(inputs[0], axes=(1, 0))
         if not attr["transpose_b"]:
             inputs[1] = _op.transpose(inputs[1], axes=(1, 0))
+        return AttrCvt(
+            op_name="matmul",
+            extras={
+                "units": channels,
+                "data_transposed": False,
+                "weight_transposed": True,
+            },
+            ignores=["transpose_a", "transpose_b", "T"],
+        )(inputs, attr)
         return AttrCvt(
             op_name="matmul",
             extras={
diff --git a/python/tvm/relay/op/_tensor_grad.py b/python/tvm/relay/op/_tensor_grad.py
index d5b891088933..f54b1c6202c0 100644
--- a/python/tvm/relay/op/_tensor_grad.py
+++ b/python/tvm/relay/op/_tensor_grad.py
@@ -543,17 +543,43 @@ def bias_add_grad(orig, grad):
     ]
 
 
-@register_gradient("nn.dense")
-def dense_grad(orig, grad):
+@register_gradient("nn.matmul")
+def matmul_grad(orig, grad):
     """Returns [grad' @ weight, data @ grad']"""
     data, weight = orig.args
+    if (orig.attrs["data_transposed"], orig.attrs["weight_transposed"]) == (True, True):
+        return [
+            collapse_sum_like(
+                _nn.matmul(weight, grad, data_transposed=True, weight_transposed=True), data
+            ),
+            collapse_sum_like(
+                _nn.matmul(grad, data, data_transposed=True, weight_transposed=True), weight
+            ),
+        ]
+    if (orig.attrs["data_transposed"], orig.attrs["weight_transposed"]) == (True, False):
+        return [
+            collapse_sum_like(_nn.matmul(weight, grad, weight_transposed=True), data),
+            collapse_sum_like(_nn.matmul(data, grad), weight),
+        ]
+    if (orig.attrs["data_transposed"], orig.attrs["weight_transposed"]) == (False, True):
+        # Keep using the dense op here to avoid introducing extra transpose ops
+        # TODO: Merge all cases into nn.matmul when it is finally ready
+        return [
+            collapse_sum_like(
+                _nn.dense(grad, transpose(weight), units=weight.checked_type.shape[1]),
+                # i.e. _nn.matmul(grad, weight)
+                data,
+            ),
+            collapse_sum_like(
+                _nn.dense(transpose(grad), transpose(data), units=data.checked_type.shape[1]),
+                # i.e. _nn.matmul(grad, data, data_transposed=True)
+                weight,
+            ),
+        ]
+    # (orig.attrs["data_transposed"], orig.attrs["weight_transposed"]) == (False, False)
     return [
-        collapse_sum_like(
-            _nn.dense(grad, transpose(weight), units=weight.checked_type.shape[1]), data
-        ),
-        collapse_sum_like(
-            _nn.dense(transpose(grad), transpose(data), units=data.checked_type.shape[1]), weight
-        ),
+        collapse_sum_like(_nn.matmul(grad, weight, weight_transposed=True), data),
+        collapse_sum_like(_nn.matmul(data, grad, data_transposed=True), weight),
     ]
diff --git a/python/tvm/relay/op/contrib/arm_compute_lib.py b/python/tvm/relay/op/contrib/arm_compute_lib.py
index 9f3c1cdec0f7..0f4f4b37cbb7 100644
--- a/python/tvm/relay/op/contrib/arm_compute_lib.py
+++ b/python/tvm/relay/op/contrib/arm_compute_lib.py
@@ -172,7 +172,7 @@ def dense_pattern():
     pattern : dataflow_pattern.AltPattern
         Denotes the dense pattern.
     """
-    pattern = is_op("nn.dense")(wildcard(), is_constant())
+    pattern = is_op("nn.matmul")(wildcard(), is_constant())
     pattern = pattern.optional(lambda x: is_op("nn.bias_add")(x, is_constant()))
     return pattern
 
@@ -237,9 +237,9 @@ def check_qnn_conv(extract):
         return qnn_conv2d(call)
 
     def check_dense(extract):
-        """Check conv pattern is supported by ACL."""
+        """Check dense pattern is supported by ACL."""
         call = extract
-        while call.op.name != "nn.dense":
+        while call.op.name != "nn.matmul":
             call = call.args[0]
         return dense(call)
 
@@ -368,7 +368,7 @@ def depthwise_conv2d(attrs, args):
     return True
 
 
-@tvm.ir.register_op_attr("nn.dense", "target.arm_compute_lib")
+@tvm.ir.register_op_attr("nn.matmul", "target.arm_compute_lib")
 def dense(expr):
     """Check if the external ACL codegen for dense should be used."""
     attrs, args = expr.attrs, expr.args
diff --git a/python/tvm/relay/op/contrib/bnns.py b/python/tvm/relay/op/contrib/bnns.py
index 2ace502e6528..c2c77e3b85e6 100644
--- a/python/tvm/relay/op/contrib/bnns.py
+++ b/python/tvm/relay/op/contrib/bnns.py
@@ -191,7 +191,7 @@ def bias_check(expr):
     return False
 
 
-@tvm.ir.register_op_attr("nn.dense", "target.bnns")
+@tvm.ir.register_op_attr("nn.matmul", "target.bnns")
 def dense(expr):
     """Check if the dense can be used in BNNS."""
     attrs, args = expr.attrs, expr.args
@@ -239,7 +239,7 @@ def make_dense_bias_pattern():
     data = wildcard()
     weight = wildcard()
     bias = wildcard()
-    d = is_op("nn.dense")(data, weight)
+    d = is_op("nn.matmul")(data, weight)
     return is_op("add")(d, bias)
 
 
@@ -263,7 +263,7 @@ def make_dense_bias_gelu_pattern():
 def check_dense(extract):
     """Check dense pattern is supported by BNNS."""
     call = extract
-    while call.op.name != "nn.dense":
+    while call.op.name != "nn.matmul":
         call = call.args[0]
     return dense(call)
diff --git a/python/tvm/relay/op/contrib/dnnl.py b/python/tvm/relay/op/contrib/dnnl.py
index 79bd02db164b..7364107a905b 100644
--- a/python/tvm/relay/op/contrib/dnnl.py
+++ b/python/tvm/relay/op/contrib/dnnl.py
@@ -61,7 +61,7 @@ def _func_wrapper(expr):
 
 _register_external_op_helper("nn.batch_norm")
 _register_external_op_helper("nn.conv2d")
-_register_external_op_helper("nn.dense")
+_register_external_op_helper("nn.matmul")
 _register_external_op_helper("nn.relu")
 _register_external_op_helper("add")
 _register_external_op_helper("subtract")
diff --git a/python/tvm/relay/op/contrib/tensorrt.py b/python/tvm/relay/op/contrib/tensorrt.py
index cbe6a22f4a4d..7ffa869f031e 100644
--- a/python/tvm/relay/op/contrib/tensorrt.py
+++ b/python/tvm/relay/op/contrib/tensorrt.py
@@ -393,7 +393,7 @@ def conv2d_annotate_fn(expr):  # pylint: disable=unused-variable
     return True
 
 
-@_register_external_dynamic_check_func("nn.dense")
+@_register_external_dynamic_check_func("nn.matmul")
 def dense_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if dense is supported by TensorRT."""
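The four branches added to matmul_grad in python/tvm/relay/op/_tensor_grad.py above follow from
differentiating Y = op(X) * op(W) case by case: with G = dL/dY, the rules are
NN: dX = G W^T, dW = X^T G; NT (dense): dX = G W, dW = G^T X; TN: dX = W G^T, dW = X G;
TT: dX = W^T G^T, dW = G^T X^T. A minimal NumPy sketch (illustrative only; numeric_grad is a
hypothetical helper, not TVM API) that checks the TT rules numerically:

    import numpy as np

    def numeric_grad(f, x, eps=1e-6):
        """Central-difference gradient of the scalar function f() w.r.t. array x."""
        g = np.zeros_like(x)
        it = np.nditer(x, flags=["multi_index"])
        while not it.finished:
            i = it.multi_index
            saved = x[i]
            x[i] = saved + eps
            fp = f()
            x[i] = saved - eps
            fm = f()
            x[i] = saved
            g[i] = (fp - fm) / (2 * eps)
            it.iternext()
        return g

    M, K, N = 2, 3, 4
    X = np.random.randn(K, M)  # data stored transposed (data_transposed=True)
    W = np.random.randn(N, K)  # weight stored transposed (weight_transposed=True)
    loss = lambda: (X.T @ W.T).sum()  # L = sum(Y), so G = dL/dY is all ones
    G = np.ones((M, N))

    # The rules used by the (True, True) branch of matmul_grad:
    assert np.allclose(numeric_grad(loss, X), W.T @ G.T)  # matmul(weight, grad, T, T)
    assert np.allclose(numeric_grad(loss, W), G.T @ X.T)  # matmul(grad, data, T, T)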
@@ -915,7 +915,7 @@ def visit_call(self, call): "nn.conv2d_transpose", "nn.conv3d", "nn.conv3d_transpose", - "nn.dense", + "nn.matmul", "nn.batch_matmul", "sum", "prod", diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py index b15c3dbd34dd..43c939fb57bc 100644 --- a/python/tvm/relay/op/nn/nn.py +++ b/python/tvm/relay/op/nn/nn.py @@ -1473,11 +1473,13 @@ def bias_add(data, bias, axis=1): def matmul(data, weight, units=None, out_dtype="", data_transposed=False, weight_transposed=False): """Matmul operator. - Applies a linear transformation. The X & W can be transposed. + Applies a linear transformation: .. math:: - `Y = X * W` + `Y = X * W` + + The X & W can be transposed. Parameters ---------- @@ -1512,11 +1514,15 @@ def matmul(data, weight, units=None, out_dtype="", data_transposed=False, weight def dense(data, weight, units=None, out_dtype=""): """Dense operator. - Applies a linear transformation + Applies a linear transformation: .. math:: - `Y = X * W^T` + `Y = X * W^T` + + .. note:: + This is an alias of `nn.matmul` when data is in non-transposed format and weight is + transposed. Parameters ---------- diff --git a/python/tvm/relay/op/op_attrs.py b/python/tvm/relay/op/op_attrs.py index ffea2cb39157..1bd8d54786d9 100644 --- a/python/tvm/relay/op/op_attrs.py +++ b/python/tvm/relay/op/op_attrs.py @@ -69,6 +69,14 @@ class MatmulAttrs(Attrs): """Attributes for nn.matmul""" +@tvm._ffi.register_object("relay.attrs.DenseAttrs") +class DenseAttrs(Attrs): + """Attributes for nn.dense + + This is still used by qnn.dense. TODO: Rewrite `qnn.dense` to `qnn.matmul`. + """ + + @tvm._ffi.register_object("relay.attrs.SoftmaxAttrs") class SoftmaxAttrs(Attrs): """Attributes for nn.softmax""" diff --git a/python/tvm/relay/op/strategy/bifrost.py b/python/tvm/relay/op/strategy/bifrost.py index 56c634e6af68..1e6e68002bfe 100644 --- a/python/tvm/relay/op/strategy/bifrost.py +++ b/python/tvm/relay/op/strategy/bifrost.py @@ -110,7 +110,9 @@ def conv2d_winograd_without_weight_transfrom_strategy_bifrost(attrs, inputs, out def dense_strategy_bifrost(attrs, inputs, out_type, target): - """dense mali(bifrost) strategy""" + """Dense mali(bifrost) strategy. + This is a specialized case for Matmul with data non-transposed and weight transposed. + """ strategy = _op.OpStrategy() strategy.add_implementation( wrap_compute_dense(topi.bifrost.dense), @@ -122,7 +124,7 @@ def dense_strategy_bifrost(attrs, inputs, out_type, target): @matmul_strategy.register("bifrost") def matmul_strategy_bifrost(attrs, inputs, out_type, target): - """matmul mali(bifrost) strategy""" + """Matmul mali(bifrost) strategy""" if not attrs.data_transposed and attrs.weight_transposed: # Specialized schedule for dense(matmul-NT) diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index f4300357e0f4..0357cb44685e 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -699,7 +699,9 @@ def conv1d_transpose_strategy_cuda(attrs, inputs, out_type, target): def dense_strategy_cuda(attrs, inputs, out_type, target): - """dense cuda strategy""" + """Dense cuda strategy. + This is a specialized case for Matmul with data non-transposed and weight transposed. 
+ """ strategy = _op.OpStrategy() data, weights = inputs b, i = get_const_tuple(data.shape) @@ -759,13 +761,13 @@ def dense_strategy_cuda(attrs, inputs, out_type, target): @matmul_strategy.register(["cuda", "gpu"]) def matmul_strategy_cuda(attrs, inputs, out_type, target): - """matmul cuda strategy""" + """Matmul cuda strategy""" if not attrs.data_transposed and attrs.weight_transposed: # Specialized schedule for dense(matmul-NT) strategy = dense_strategy_cuda(attrs, inputs, out_type, target) else: - logger.warning("Matmul other than NT format is not optimized for x86.") + logger.warning("Matmul other than NT format is not optimized for cuda.") strategy = _op.OpStrategy() strategy.add_implementation( wrap_compute_matmul(topi.nn.matmul), diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index 5e52f1d13c11..128012656b83 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -752,7 +752,7 @@ def _compute_matmul(attrs, inputs, out_type): @override_native_generic_func("matmul_strategy") def matmul_strategy(attrs, inputs, out_type, target): - """matmul generic strategy""" + """Matmul generic strategy""" logger.warning("matmul is not optimized for this platform.") strategy = _op.OpStrategy() strategy.add_implementation( diff --git a/python/tvm/relay/op/strategy/hls.py b/python/tvm/relay/op/strategy/hls.py index 961448fa033d..bb51973d8ef2 100644 --- a/python/tvm/relay/op/strategy/hls.py +++ b/python/tvm/relay/op/strategy/hls.py @@ -159,7 +159,9 @@ def conv2d_transpose_strategy_hls(attrs, inputs, out_type, target): def dense_strategy_hls(attrs, inputs, out_type, target): - """dense hls strategy""" + """Dense hls strategy. + This is a specialized case for Matmul with data non-transposed and weight transposed. + """ strategy = _op.OpStrategy() strategy.add_implementation( wrap_compute_dense(topi.nn.dense), @@ -171,7 +173,7 @@ def dense_strategy_hls(attrs, inputs, out_type, target): @matmul_strategy.register("hls") def matmul_strategy_hls(attrs, inputs, out_type, target): - """matmul hls strategy""" + """Matmul hls strategy""" if not attrs.data_transposed and attrs.weight_transposed: # Specialized schedule for dense(matmul-NT) diff --git a/python/tvm/relay/op/strategy/mali.py b/python/tvm/relay/op/strategy/mali.py index 92e94e195d8f..c4cb52b9522e 100644 --- a/python/tvm/relay/op/strategy/mali.py +++ b/python/tvm/relay/op/strategy/mali.py @@ -173,7 +173,9 @@ def conv2d_winograd_without_weight_transfrom_strategy_mali(attrs, inputs, out_ty def dense_strategy_mali(attrs, inputs, out_type, target): - """dense mali strategy""" + """Dense mali strategy. + This is a specialized case for Matmul with data non-transposed and weight transposed. 
+ """ strategy = _op.OpStrategy() strategy.add_implementation( wrap_compute_dense(topi.mali.dense), @@ -185,7 +187,7 @@ def dense_strategy_mali(attrs, inputs, out_type, target): @matmul_strategy.register("mali") def matmul_strategy_mali(attrs, inputs, out_type, target): - """dense mali strategy""" + """Matmul mali strategy""" if not attrs.data_transposed and attrs.weight_transposed: # Specialized schedule for dense(matmul-NT) diff --git a/python/tvm/relay/op/strategy/rocm.py b/python/tvm/relay/op/strategy/rocm.py index d249e4e8a5f2..54cdd979af7f 100644 --- a/python/tvm/relay/op/strategy/rocm.py +++ b/python/tvm/relay/op/strategy/rocm.py @@ -183,7 +183,9 @@ def conv2d_strategy_rocm(attrs, inputs, out_type, target): def dense_strategy_rocm(attrs, inputs, out_type, target): - """Dense strategy for ROCM""" + """Dense strategy for ROCM. + This is a specialized case for Matmul with data non-transposed and weight transposed. + """ assert len(inputs[0].shape) == 2 and len(inputs[1].shape) == 2, "Only support 2-dim dense" strategy = _op.OpStrategy() strategy.add_implementation( diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py index eb69c6466eab..7f1a8139c059 100644 --- a/python/tvm/relay/op/strategy/x86.py +++ b/python/tvm/relay/op/strategy/x86.py @@ -371,8 +371,8 @@ def conv1d_strategy_cpu(attrs, inputs, out_type, target): def dense_strategy_cpu(attrs, inputs, out_type, target): - """Dense x86 strategy. This is a specialized case for Matmul with data non-transposed and - weight transposed. + """Dense x86 strategy. + This is a specialized case for Matmul with data non-transposed and weight transposed. """ strategy = _op.OpStrategy() @@ -394,7 +394,7 @@ def dense_strategy_cpu(attrs, inputs, out_type, target): @matmul_strategy.register("cpu") def matmul_strategy_cpu(attrs, inputs, out_type, target): - """Matmul x86 strategy.""" + """Matmul x86 strategy""" if not attrs.data_transposed and attrs.weight_transposed: # Specialized schedule for dense(matmul-NT) @@ -402,6 +402,11 @@ def matmul_strategy_cpu(attrs, inputs, out_type, target): else: logger.warning("Matmul other than NT format is not optimized for x86.") strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_matmul(topi.nn.matmul), + naive_schedule, + name="matmul.generic", + ) same_type = inputs[0].dtype == inputs[1].dtype == out_type.dtype dtype = inputs[0].dtype diff --git a/python/tvm/relay/qnn/op/legalizations.py b/python/tvm/relay/qnn/op/legalizations.py index 3c4d2ddcd0ec..768637250aec 100644 --- a/python/tvm/relay/qnn/op/legalizations.py +++ b/python/tvm/relay/qnn/op/legalizations.py @@ -308,7 +308,14 @@ def _qnn_dense_legalize_arm_cpu(attrs, inputs, types): # ARM prefers the dtypes to be same. if is_fast_int8_on_arm(): return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.dense) - return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.dense) + mattrs = tvm.ir.make_node( + "relay.attrs.MatmulAttrs", + units=attrs["units"], + out_dtype=attrs["out_dtype"], + data_transposed=False, + weight_transposed=True, + ) # TODO: Rewrite qnn.dense to qnn.matmul and remove this transformation + return helper_no_fast_int8_hw_legalization(mattrs, inputs, types, relay.nn.matmul) ########################## @@ -329,7 +336,14 @@ def _qnn_dense_legalize_intel_cpu(attrs, inputs, types): # The VNNI transformations prefer uint8 x int8 datatypes. 
if is_fast_int8_on_intel(): return helper_change_dtypes_to_uint8_int8(attrs, inputs, types, relay.qnn.op.dense) - return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.dense) + mattrs = tvm.ir.make_node( + "relay.attrs.MatmulAttrs", + units=attrs["units"], + out_dtype=attrs["out_dtype"], + data_transposed=False, + weight_transposed=True, + ) # TODO: Rewrite qnn.dense to qnn.matmul and remove this transformation + return helper_no_fast_int8_hw_legalization(mattrs, inputs, types, relay.nn.matmul) ##################### diff --git a/python/tvm/relay/quantize/_annotate.py b/python/tvm/relay/quantize/_annotate.py index ff673d23144a..3c814a43633e 100644 --- a/python/tvm/relay/quantize/_annotate.py +++ b/python/tvm/relay/quantize/_annotate.py @@ -197,7 +197,7 @@ def conv1d_rewrite(ref_call, new_args, ctx): return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION) -@register_annotate_function("nn.dense") +@register_annotate_function("nn.matmul") def dense_rewrite(ref_call, new_args, ctx): """Rewrite function for dense. Lhs of dense will be quantized to input field, and rhs of dense will be quantized to weight field. Output would be in activation field.""" diff --git a/python/tvm/relay/transform/mixed_precision.py b/python/tvm/relay/transform/mixed_precision.py index 6aa3ac09cfee..8b7744a4ef2d 100644 --- a/python/tvm/relay/transform/mixed_precision.py +++ b/python/tvm/relay/transform/mixed_precision.py @@ -39,7 +39,7 @@ "nn.conv1d_transpose", "nn.conv2d_transpose", "nn.conv3d_transpose", - "nn.dense", + "nn.matmul", # "nn.batch_matmul", # Handled by a special case ] DEFAULT_FOLLOW_LIST = [ diff --git a/python/tvm/topi/cuda/tensorcore_alter_op.py b/python/tvm/topi/cuda/tensorcore_alter_op.py index 24bbb4332ddd..5e4bf0fe6e6c 100644 --- a/python/tvm/topi/cuda/tensorcore_alter_op.py +++ b/python/tvm/topi/cuda/tensorcore_alter_op.py @@ -99,6 +99,7 @@ def _batch_matmul_legalize(attrs, inputs, arg_types): def _dense_legalize(attrs, inputs, arg_types): """Legalizes dense op. + This is a specialized case for Matmul with data non-transposed and weight transposed. Parameters ---------- diff --git a/python/tvm/topi/rocm/dense.py b/python/tvm/topi/rocm/dense.py index 590805d47f9c..ab6c124549e7 100644 --- a/python/tvm/topi/rocm/dense.py +++ b/python/tvm/topi/rocm/dense.py @@ -102,13 +102,7 @@ def _callback(op): @autotvm.register_topi_compute("matmul_rocblas.rocm") def matmul_rocblas( - cfg, - data, - weight, - bias=None, - out_dtype=None, - data_transposed=False, - weight_transposed=False + cfg, data, weight, bias=None, out_dtype=None, data_transposed=False, weight_transposed=False ): """Matmul operator for rocm backend with cblas. 
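A common thread in the strategy and legalization hunks above is the invariant that the old
nn.dense(X, W) is exactly the NT special case of nn.matmul, i.e. Y = X * W^T with
data_transposed=False and weight_transposed=True. A small sketch (assuming this patch's Python
API; not part of the patch itself) showing that the two spellings infer the same output type:

    import tvm
    from tvm import relay

    x = relay.var("x", shape=(8, 16), dtype="float32")
    w = relay.var("w", shape=(4, 16), dtype="float32")  # (units, in_dim), already transposed

    dense_y = relay.nn.dense(x, w, units=4)
    matmul_y = relay.nn.matmul(x, w, units=4, data_transposed=False, weight_transposed=True)

    for y in (dense_y, matmul_y):
        mod = relay.transform.InferType()(tvm.IRModule.from_expr(relay.Function([x, w], y)))
        print(mod["main"].ret_type)  # both should infer Tensor[(8, 4), float32]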
diff --git a/python/tvm/topi/x86/dense.py b/python/tvm/topi/x86/dense.py
index 53a89cc334c1..46b3cd7f7d16 100644
--- a/python/tvm/topi/x86/dense.py
+++ b/python/tvm/topi/x86/dense.py
@@ -306,17 +306,18 @@ def matmul_blas_common(cfg, data, weight, bias, out_dtype, lib, data_transposed,
 
 @autotvm.register_topi_compute("matmul_cblas.x86")
 def matmul_cblas(
-    cfg,
-    data,
-    weight,
-    bias=None,
-    out_dtype=None,
-    data_transposed=False,
-    weight_transposed=False
+    cfg, data, weight, bias=None, out_dtype=None, data_transposed=False, weight_transposed=False
 ):
     """Compute matmul using cblas"""
     return matmul_blas_common(
-        cfg, data, weight, bias, out_dtype, cblas, data_transposed, weight_transposed,
+        cfg,
+        data,
+        weight,
+        bias,
+        out_dtype,
+        cblas,
+        data_transposed,
+        weight_transposed,
     )
 
 
@@ -328,17 +329,18 @@ def schedule_matmul_cblas(_, outs):
 
 @autotvm.register_topi_compute("matmul_mkl.x86")
 def matmul_mkl(
-    cfg,
-    data,
-    weight,
-    bias=None,
-    out_dtype=None,
-    data_transposed=False,
-    weight_transposed=False
+    cfg, data, weight, bias=None, out_dtype=None, data_transposed=False, weight_transposed=False
 ):
     """Compute matmul using mkl"""
     return matmul_blas_common(
-        cfg, data, weight, bias, out_dtype, mkl, data_transposed, weight_transposed,
+        cfg,
+        data,
+        weight,
+        bias,
+        out_dtype,
+        mkl,
+        data_transposed,
+        weight_transposed,
     )
 
 
@@ -362,17 +364,18 @@ def _callback(op):
 
 @autotvm.register_topi_compute("matmul_mkldnn.x86")
 def matmul_mkldnn(
-    cfg,
-    data,
-    weight,
-    bias=None,
-    out_dtype=None,
-    data_transposed=False,
-    weight_transposed=False
+    cfg, data, weight, bias=None, out_dtype=None, data_transposed=False, weight_transposed=False
 ):
     """Compute matmul using mkldnn"""
     return matmul_blas_common(
-        cfg, data, weight, bias, out_dtype, mkldnn, data_transposed, weight_transposed,
+        cfg,
+        data,
+        weight,
+        bias,
+        out_dtype,
+        mkldnn,
+        data_transposed,
+        weight_transposed,
     )
 
 
diff --git a/python/tvm/topi/x86/dense_alter_op.py b/python/tvm/topi/x86/dense_alter_op.py
index fd81eefec6f0..ad4bb500e545 100644
--- a/python/tvm/topi/x86/dense_alter_op.py
+++ b/python/tvm/topi/x86/dense_alter_op.py
@@ -35,7 +35,7 @@ def _alter_dense_layout(attrs, inputs, tinfos, out_type):
     N, _ = get_const_tuple(weight_tensor.shape)
 
     impl, outs = relay.backend.compile_engine.select_implementation(
-        relay.op.get("nn.dense"), attrs, tinfos, out_type, target
+        relay.op.get("nn.matmul"), attrs, tinfos, out_type, target
     )
     workload = autotvm.task.get_workload(outs)
     if workload:
diff --git a/rust/tvm/src/ir/relay/attrs/nn.rs b/rust/tvm/src/ir/relay/attrs/nn.rs
index e77972e45f86..6a703aa51476 100644
--- a/rust/tvm/src/ir/relay/attrs/nn.rs
+++ b/rust/tvm/src/ir/relay/attrs/nn.rs
@@ -54,6 +54,16 @@ pub struct BiasAddAttrsNode {
     pub axis: i32,
 }
 
+#[repr(C)]
+#[derive(Object, Debug)]
+#[ref_name = "DenseAttrs"]
+#[type_key = "relay.attrs.DenseAttrs"]
+pub struct DenseAttrsNode {
+    pub base: BaseAttrsNode,
+    pub units: IndexExpr,
+    pub out_dtype: DataType,
+}
+
 #[repr(C)]
 #[derive(Object, Debug)]
 #[ref_name = "MatmulAttrs"]
diff --git a/src/relay/analysis/mac_count.cc b/src/relay/analysis/mac_count.cc
index 29edf55812cc..2006e9da932e 100644
--- a/src/relay/analysis/mac_count.cc
+++ b/src/relay/analysis/mac_count.cc
@@ -119,26 +119,38 @@ int64_t Conv2dTransposeMacCount(const Call& call_node) {
   return count;
 }
 
-int64_t DenseMacCount(const Call& call_node) {
+int64_t MatmulMacCount(const Call& call_node) {
   if (!call_node->checked_type_.defined()) {
     LOG(WARNING) << "The infer type pass should be called before the mac count pass";
     return 0;
   }
   Array<Expr> args = call_node->args;
-  ICHECK_EQ(args.size(), 2) << "The number of input arguments of a Dense node should be 2.";
+  ICHECK_EQ(args.size(), 2) << "The number of input arguments of a Matmul node should be 2.";
   const auto* data_type = args[0]->checked_type().as<TensorTypeNode>();
   const auto* weight_type = args[1]->checked_type().as<TensorTypeNode>();
   Array<IndexExpr> data_shape = data_type->shape;
   Array<IndexExpr> weight_shape = weight_type->shape;
   ICHECK(data_shape.size() == 2 && weight_shape.size() == 2)
-      << "The dimension of an input tensor to Dense node should be 2.";
-  int64_t d1 = static_cast<int64_t>(data_shape[0].as<IntImmNode>()->value);
-  int64_t d2 = static_cast<int64_t>(data_shape[1].as<IntImmNode>()->value);
-  int64_t d3 = static_cast<int64_t>(weight_shape[0].as<IntImmNode>()->value);
-  int64_t d4 = static_cast<int64_t>(weight_shape[1].as<IntImmNode>()->value);
-  ICHECK_EQ(d2, d4) << "The dimensions of input arguments do not match.";
-  int64_t count = d1 * d2 * d3;
-  return count;
+      << "The dimension of an input tensor to Matmul node should be 2.";
+  const auto* mattr = call_node->attrs.as<MatmulAttrs>();
+  ICHECK(mattr != nullptr);
+  int64_t data_m, data_k, weight_k, weight_n;
+  if (mattr->data_transposed) {
+    data_k = static_cast<int64_t>(data_shape[0].as<IntImmNode>()->value);
+    data_m = static_cast<int64_t>(data_shape[1].as<IntImmNode>()->value);
+  } else {
+    data_m = static_cast<int64_t>(data_shape[0].as<IntImmNode>()->value);
+    data_k = static_cast<int64_t>(data_shape[1].as<IntImmNode>()->value);
+  }
+  if (mattr->weight_transposed) {
+    weight_n = static_cast<int64_t>(weight_shape[0].as<IntImmNode>()->value);
+    weight_k = static_cast<int64_t>(weight_shape[1].as<IntImmNode>()->value);
+  } else {
+    weight_k = static_cast<int64_t>(weight_shape[0].as<IntImmNode>()->value);
+    weight_n = static_cast<int64_t>(weight_shape[1].as<IntImmNode>()->value);
+  }
+  ICHECK_EQ(data_k, weight_k) << "The dimensions of input arguments do not match.";
+  return data_m * data_k * weight_n;
 }
 
 int64_t BatchMatmulMacCount(const Call& call_node) {
@@ -161,7 +173,7 @@ RELAY_REGISTER_OP("nn.conv2d").set_attr<FMacCount>("FMacCount", ConvMacCount);
 
 RELAY_REGISTER_OP("nn.conv2d_transpose").set_attr<FMacCount>("FMacCount", Conv2dTransposeMacCount);
 
-RELAY_REGISTER_OP("nn.dense").set_attr<FMacCount>("FMacCount", DenseMacCount);
+RELAY_REGISTER_OP("nn.matmul").set_attr<FMacCount>("FMacCount", MatmulMacCount);
 
 RELAY_REGISTER_OP("nn.batch_matmul").set_attr<FMacCount>("FMacCount", BatchMatmulMacCount);
 
diff --git a/src/relay/backend/contrib/arm_compute_lib/codegen.cc b/src/relay/backend/contrib/arm_compute_lib/codegen.cc
index 8098c8d51274..2ecacb998627 100644
--- a/src/relay/backend/contrib/arm_compute_lib/codegen.cc
+++ b/src/relay/backend/contrib/arm_compute_lib/codegen.cc
@@ -247,7 +247,7 @@ class ACLJSONSerializer : public backend::contrib::JSONSerializer {
     if (nodes.requantize) {
       ICHECK(backend::IsOp(current_call, "qnn.dense"));
     } else {
-      ICHECK(backend::IsOp(current_call, "nn.dense"));
+      ICHECK(backend::IsOp(current_call, "nn.matmul"));
     }
     nodes.dense = current_call;
     return nodes;
@@ -261,7 +261,7 @@ class ACLJSONSerializer : public backend::contrib::JSONSerializer {
    */
   std::shared_ptr<JSONGraphNode> CreateCompositeDenseJSONNode(const CallNode* cn) {
     CompositeDenseNode nodes = UnpackCompositeDense(cn);
-    std::string name = "nn.dense";
+    std::string name = "nn.matmul";
 
     // Inputs must be added in the same order they appear in the relay graph.
     std::vector<JSONGraphNodeEntry> inputs;
diff --git a/src/relay/backend/contrib/bnns/codegen.cc b/src/relay/backend/contrib/bnns/codegen.cc
index 72c32fb5b19e..92868de8ccd3 100644
--- a/src/relay/backend/contrib/bnns/codegen.cc
+++ b/src/relay/backend/contrib/bnns/codegen.cc
@@ -102,9 +102,9 @@ class BNNSJSONSerializer : public backend::contrib::JSONSerializer {
       call = GetRootCall(body, 1, {"nn.conv2d", "sigmoid"});
       ICHECK(call->op.as<OpNode>()) << "Not op node";
     } else if (name == "bnns.dense_bias") {
-      call = GetRootCall(fn->body.as<CallNode>(), 1, {"nn.dense", "add"});
+      call = GetRootCall(fn->body.as<CallNode>(), 1, {"nn.matmul", "add"});
     } else if (name == "bnns.dense_bias_gelu") {
-      call = FindCallWithName(fn->body.as<CallNode>(), 10, "nn.dense");
+      call = FindCallWithName(fn->body.as<CallNode>(), 10, "nn.matmul");
     } else {
       LOG(FATAL) << "Unrecognized BNNS pattern: " << name;
     }
diff --git a/src/relay/backend/contrib/dnnl/codegen.cc b/src/relay/backend/contrib/dnnl/codegen.cc
index e96255e976e9..f17d504959c6 100644
--- a/src/relay/backend/contrib/dnnl/codegen.cc
+++ b/src/relay/backend/contrib/dnnl/codegen.cc
@@ -238,7 +238,7 @@ class CodegenDNNL : public MemoizedExprTranslator<std::vector<Output>>, public C
   using ArgFunType = std::function<std::vector<std::string>(const CallNode*)>;
   static const std::map<std::string, std::pair<std::string, ArgFunType>> op_map = {
       {"nn.conv2d", {"dnnl_conv2d", Conv2d}},
-      {"nn.dense", {"dnnl_dense", Dense}},
+      {"nn.matmul", {"dnnl_dense", Dense}},
       {"nn.relu", {"dnnl_relu", Relu}},
      {"nn.batch_norm", {"dnnl_bn", BatchNorm}},
      {"add", {"dnnl_add", Add}},
diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc
index 442b31510888..1cf02c95a43f 100644
--- a/src/relay/op/nn/nn.cc
+++ b/src/relay/op/nn/nn.cc
@@ -164,6 +164,7 @@ Useful for
 // ------------------- relay.nn.matmul
 TVM_REGISTER_NODE_TYPE(MatmulAttrs);
+TVM_REGISTER_NODE_TYPE(DenseAttrs);  // Used by qnn.dense. TODO: Rewrite qnn.dense to qnn.matmul
 
 Expr MakeMatmul(Expr data, Expr weight, IndexExpr units, DataType out_dtype, bool data_transposed,
                 bool weight_transposed) {
diff --git a/src/relay/op/nn/nn.h b/src/relay/op/nn/nn.h
index 552fac3631fb..046cb234e079 100644
--- a/src/relay/op/nn/nn.h
+++ b/src/relay/op/nn/nn.h
@@ -48,10 +48,16 @@ bool MatmulRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
 
   ICHECK(static_cast<int>(data->shape.size()) != 0);
 
+  // Default to the dense layout (data non-transposed, weight transposed)
+  bool data_transposed = false;
+  bool weight_transposed = true;
+  const MatmulAttrs* mattrs = attrs.as<MatmulAttrs>();
+  if (mattrs != nullptr) {
+    data_transposed = mattrs->data_transposed;
+    weight_transposed = mattrs->weight_transposed;
+  }
   const Array<tvm::PrimExpr>& dshape = data->shape;
   Array<tvm::PrimExpr> oshape = dshape;
-  bool data_transposed = param->data_transposed;
-  bool weight_transposed = param->weight_transposed;
   tvm::PrimExpr reduce = dshape[dshape.size() - 1];
   if (data_transposed) {
     reduce = dshape[dshape.size() - 2];
diff --git a/src/relay/qnn/op/dense.cc b/src/relay/qnn/op/dense.cc
index aa64df25cb2e..592fa77aed77 100644
--- a/src/relay/qnn/op/dense.cc
+++ b/src/relay/qnn/op/dense.cc
@@ -45,8 +45,8 @@ bool QnnDenseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   const auto* data = types[0].as<TensorTypeNode>();
   const auto* weight = types[1].as<TensorTypeNode>();
   if (data == nullptr || weight == nullptr) return false;
-  const auto* param = attrs.as<MatmulAttrs>();
-  ICHECK(param != nullptr) << "MatmulAttrs cannot be nullptr.";
+  const auto* param = attrs.as<DenseAttrs>();
+  ICHECK(param != nullptr) << "DenseAttrs cannot be nullptr.";
   ICHECK(data->dtype == DataType::Int(8) || data->dtype == DataType::UInt(8))
       << "Expected quantized dense type(int8, uint8) for input but was " << data->dtype;
   ICHECK(weight->dtype == DataType::Int(8) || weight->dtype == DataType::UInt(8))
@@ -70,13 +70,13 @@ bool QnnDenseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   // Collect the input tensor and output tensor devoid of scale and zero points to reuse Relay
   // Dense infer type function.
   Array<Type> tensor_types = {types[0], types[1], types[6]};
-  return MatmulRel<MatmulAttrs>(tensor_types, 3, attrs, reporter);
+  return MatmulRel<DenseAttrs>(tensor_types, 3, attrs, reporter);
 }
 
 // Positional relay function to create quantized dense operator used by frontend FFI.
 Expr MakeQuantizedDense(Expr data, Expr weight, Expr input_zero_point, Expr kernel_zero_point,
                         Expr input_scale, Expr kernel_scale, IndexExpr units, DataType out_dtype) {
-  auto attrs = make_object<MatmulAttrs>();
+  auto attrs = make_object<DenseAttrs>();
   attrs->units = std::move(units);
   attrs->out_dtype = out_dtype;
   static const Op& op = Op::Get("qnn.dense");
@@ -85,7 +85,7 @@ Expr MakeQuantizedDense(Expr data, Expr weight, Expr input_zero_point, Expr kern
 }
 
 Expr DenseFirstTerm(const Expr& quantized_data, const Expr& quantized_kernel,
-                    const MatmulAttrs* attrs) {
+                    const DenseAttrs* attrs) {
   return Dense(quantized_data, quantized_kernel, attrs->units, attrs->out_dtype);
 }
 
@@ -161,7 +161,7 @@ Expr QnnDenseCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
   const auto in_shape = get_shape(arg_types[0]);
   const int reduction_dim_size = get_const_int(in_shape[1]);
 
-  const auto* qnn_dense_attrs = attrs.as<MatmulAttrs>();
+  const auto* qnn_dense_attrs = attrs.as<DenseAttrs>();
 
   auto term1 = DenseFirstTerm(quantized_data, quantized_kernel, qnn_dense_attrs);
   auto term2 = DenseSecondTerm(quantized_data, kernel_zero_point);
@@ -204,7 +204,7 @@ RELAY_REGISTER_OP("qnn.dense")
 - **weight**: quantized(int8, uint8) `(units, input_dim)`
 - **out**: quantized(int32) `(x1, x2, ..., xn, units)`.
 )code" TVM_ADD_FILELINE)
-    .set_attrs_type<MatmulAttrs>()
+    .set_attrs_type<DenseAttrs>()
     .set_num_inputs(6)
     .add_argument("data", "quantized nD Tensor", "Input data.")
     .add_argument("weight", "quantized 2D Tensor", "Weight matrix.")
diff --git a/src/relay/quantize/realize.cc b/src/relay/quantize/realize.cc
index 24c2ff22d2fc..44366897583d 100644
--- a/src/relay/quantize/realize.cc
+++ b/src/relay/quantize/realize.cc
@@ -292,7 +292,7 @@ Expr DenseRealize(const Call& ref_call, const Array<Expr>& new_args, const Objec
   return QRealizeIntExpr(ret, dom_scale, out_dtype);
 }
 
-RELAY_REGISTER_OP("nn.dense").set_attr<FForwardRewrite>("FQRealizeRewrite", DenseRealize);
+RELAY_REGISTER_OP("nn.matmul").set_attr<FForwardRewrite>("FQRealizeRewrite", DenseRealize);
 
 Expr MulRealize(const Call& ref_call, const Array<Expr>& new_args, const ObjectRef& ctx) {
   const QConfig& cfg = QConfig::Current();
diff --git a/src/relay/transforms/auto_scheduler_layout_rewrite.cc b/src/relay/transforms/auto_scheduler_layout_rewrite.cc
index 20341277033c..b468e60c054f 100644
--- a/src/relay/transforms/auto_scheduler_layout_rewrite.cc
+++ b/src/relay/transforms/auto_scheduler_layout_rewrite.cc
@@ -87,6 +87,8 @@ class FuncMutator : public ExprMutator {
       updated_attrs = CopyAttrsWithNewLayout(pattr, new_layout);
     } else if (auto pattr = call->attrs.as()) {
       updated_attrs = CopyAttrsWithNewLayout(pattr, new_layout);
+    } else if (auto pattr = call->attrs.as<MatmulAttrs>()) {
+      updated_attrs = CopyAttrsWithNewLayout(pattr, new_layout);
     } else if (auto pattr = call->attrs.as()) {
       updated_attrs = CopyAttrsWithNewLayout(pattr, new_layout);
     } else if (auto pattr = call->attrs.as()) {
@@ -103,9 +105,9 @@ class FuncMutator : public ExprMutator {
   std::deque<std::string> ori_layouts_queue_;
   std::deque<std::string> new_layouts_queue_;
 
-  std::vector<std::string> target_ops_{
-      "nn.conv2d", "nn.conv3d", "nn.contrib_conv2d_winograd_without_weight_transform",
-      "nn.matmul", "nn.batch_matmul"};
"nn.matmul", "nn.batch_matmul"}; + std::vector target_ops_{"nn.conv2d", "nn.conv3d", + "nn.contrib_conv2d_winograd_without_weight_transform", + "nn.matmul", "nn.batch_matmul"}; }; Expr AutoSchedulerLayoutRewriter::VisitExpr_(const CallNode* n) { @@ -166,6 +168,8 @@ TVM_REGISTER_GLOBAL("relay.attrs.get_auto_scheduler_rewritten_layout") return attrs.as()->auto_scheduler_rewritten_layout; } else if (attrs->IsInstance()) { return attrs.as()->auto_scheduler_rewritten_layout; + } else if (attrs->IsInstance()) { + return attrs.as()->auto_scheduler_rewritten_layout; } else if (attrs->IsInstance()) { return attrs.as()->auto_scheduler_rewritten_layout; } else if (attrs->IsInstance()) { diff --git a/src/relay/transforms/combine_parallel_dense.cc b/src/relay/transforms/combine_parallel_dense.cc index 8901526e0a8f..c46f83b46056 100644 --- a/src/relay/transforms/combine_parallel_dense.cc +++ b/src/relay/transforms/combine_parallel_dense.cc @@ -54,7 +54,7 @@ namespace relay { class ParallelDenseToBatchCombiner : public ParallelOpBatchCombiner { public: explicit ParallelDenseToBatchCombiner(uint64_t min_num_branches) - : ParallelOpBatchCombiner("nn.dense", "nn.batch_matmul", min_num_branches) {} + : ParallelOpBatchCombiner("nn.matmul", "nn.batch_matmul", min_num_branches) {} protected: Call MakeCombinedOp(const Group& branches) { @@ -96,7 +96,7 @@ class ParallelDenseToBatchCombiner : public ParallelOpBatchCombiner { class ParallelDenseToDenseCombiner : public ParallelOpCombiner { public: explicit ParallelDenseToDenseCombiner(uint64_t min_num_branches) - : ParallelOpCombiner("nn.dense", min_num_branches) {} + : ParallelOpCombiner("nn.matmul", min_num_branches) {} protected: bool IsSupportedOp(const CallNode* n) { return true; } @@ -113,7 +113,7 @@ class ParallelDenseToDenseCombiner : public ParallelOpCombiner { } Call MakeCombinedOp(const Group& branches) { - const Op& dense_op = Op::Get("nn.dense"); + const Op& dense_op = Op::Get("nn.matmul"); Expr input = branches[0][0]->args[0]; Expr new_weight; IndexExpr new_output_dims; @@ -124,6 +124,8 @@ class ParallelDenseToDenseCombiner : public ParallelOpCombiner { const auto dense_attrs = make_object(); dense_attrs->units = new_output_dims; dense_attrs->out_dtype = origin_attrs->out_dtype; + dense_attrs->data_transposed = origin_attrs->data_transposed; + dense_attrs->weight_transposed = origin_attrs->weight_transposed; return Call(dense_op, {input, new_weight}, Attrs{dense_attrs}, {}); } diff --git a/src/relay/transforms/convert_sparse_dense.cc b/src/relay/transforms/convert_sparse_dense.cc index 26a4d487196d..ba1dae6b2f31 100644 --- a/src/relay/transforms/convert_sparse_dense.cc +++ b/src/relay/transforms/convert_sparse_dense.cc @@ -40,7 +40,7 @@ namespace relay { // Search dense op weight name from Expr class DenseOpWeightVisitor : private ExprVisitor { public: - DenseOpWeightVisitor() : dense_op_(Op::Get("nn.dense")) {} + DenseOpWeightVisitor() : dense_op_(Op::Get("nn.matmul")) {} Array Search(const Expr& expr) { VisitExpr(expr); @@ -74,7 +74,7 @@ class DenseToSparseDenseMutator : public ExprRewriter { public: DenseToSparseDenseMutator(const Array& weight_name, const Array >& weight_shape) - : dense_op_(Op::Get("nn.dense")), sparse_dense_op_(Op::Get("nn.sparse_dense")) { + : dense_op_(Op::Get("nn.matmul")), sparse_dense_op_(Op::Get("nn.sparse_dense")) { ICHECK_EQ(weight_name.size(), weight_shape.size()); for (size_t i = 0; i < weight_name.size(); ++i) { ICHECK(weight_name[i]->IsInstance()); diff --git a/src/relay/transforms/simplify_fc_transpose.cc 
b/src/relay/transforms/simplify_fc_transpose.cc index b5090e7e6fe4..028871bd711d 100644 --- a/src/relay/transforms/simplify_fc_transpose.cc +++ b/src/relay/transforms/simplify_fc_transpose.cc @@ -41,7 +41,7 @@ namespace relay { // Find name of weight in ```y = nn.dense(x, tranpose(w, [1, 0]))``` class FCTransposeVisitor : private ExprVisitor { public: - FCTransposeVisitor() : dense_op_(Op::Get("nn.dense")), transpose_op_(Op::Get("transpose")) {} + FCTransposeVisitor() : dense_op_(Op::Get("nn.matmul")), transpose_op_(Op::Get("transpose")) {} Array Search(const Expr& expr) { VisitExpr(expr); @@ -79,7 +79,7 @@ TVM_REGISTER_GLOBAL("relay.analysis.search_fc_transpose").set_body_typed(SearchF class FCTransposeMutator : public ExprRewriter { public: explicit FCTransposeMutator(const Array& target_weights) - : dense_op_(Op::Get("nn.dense")), transpose_op_(Op::Get("transpose")) { + : dense_op_(Op::Get("nn.matmul")), transpose_op_(Op::Get("transpose")) { for (size_t i = 0; i < target_weights.size(); ++i) { ICHECK(target_weights[i]->IsInstance()); std::string k = target_weights[i].as()->data; diff --git a/src/relay/transforms/to_mixed_precision.cc b/src/relay/transforms/to_mixed_precision.cc index ae10c937ff1c..1156daf72c8f 100644 --- a/src/relay/transforms/to_mixed_precision.cc +++ b/src/relay/transforms/to_mixed_precision.cc @@ -134,6 +134,8 @@ class MixedPrecisionPass : public MixedModeMutator { return ModifyAttrsOutputDType(attrs, accumulation_dtype); } else if (auto attrs = cur_attrs.as()) { return ModifyAttrsOutputDType(attrs, accumulation_dtype); + } else if (auto attrs = cur_attrs.as()) { + return ModifyAttrsOutputDType(attrs, accumulation_dtype); } else if (auto attrs = cur_attrs.as()) { return ModifyAttrsOutputDType(attrs, accumulation_dtype); } diff --git a/src/runtime/contrib/arm_compute_lib/acl_runtime.cc b/src/runtime/contrib/arm_compute_lib/acl_runtime.cc index 5bbc536afaca..1d616bab4e33 100644 --- a/src/runtime/contrib/arm_compute_lib/acl_runtime.cc +++ b/src/runtime/contrib/arm_compute_lib/acl_runtime.cc @@ -135,7 +135,7 @@ class ACLRuntime : public JSONRuntimeBase { } else if ("nn.depthwise_conv2d" == op_name || "qnn.depthwise_conv2d" == op_name) { CreateDepthwiseConvolution2DLayer(&layer_, node, mm); num_pools++; - } else if ("nn.dense" == op_name || "qnn.dense" == op_name) { + } else if ("nn.matmul" == op_name || "qnn.dense" == op_name) { CreateFullyConnectedLayer(&layer_, node, mm); num_pools++; } else if ("nn.max_pool2d" == op_name || "nn.avg_pool2d" == op_name || diff --git a/src/runtime/contrib/bnns/bnns_json_runtime.cc b/src/runtime/contrib/bnns/bnns_json_runtime.cc index 87b01567cd30..e9afd9c7126f 100644 --- a/src/runtime/contrib/bnns/bnns_json_runtime.cc +++ b/src/runtime/contrib/bnns/bnns_json_runtime.cc @@ -176,7 +176,7 @@ class BNNSJSONRuntime : public JSONRuntimeBase { Conv2d(nid, true, "sigmoid"); } else if ("bnns.conv2d_bias" == op_name) { Conv2d(nid, true); - } else if ("nn.dense" == op_name) { + } else if ("nn.matmul" == op_name) { Dense(nid); } else if ("bnns.dense_bias" == op_name) { Dense(nid, true); diff --git a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc index eef67a702d9c..1f60272da3dc 100644 --- a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc +++ b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc @@ -106,7 +106,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase { Conv2d(nid, true, false); } else if ("dnnl.conv2d_bias_relu" == op_name) { Conv2d(nid, true, true); - } else if ("nn.dense" == op_name) { + } else if 
("nn.matmul" == op_name) { Dense(nid); } else if ("nn.batch_norm" == op_name) { BatchNorm(nid); diff --git a/src/runtime/contrib/tensorrt/tensorrt_ops.cc b/src/runtime/contrib/tensorrt/tensorrt_ops.cc index 7197172d73db..f18d5e2020cc 100644 --- a/src/runtime/contrib/tensorrt/tensorrt_ops.cc +++ b/src/runtime/contrib/tensorrt/tensorrt_ops.cc @@ -1194,7 +1194,7 @@ GetOpConverters() { map->emplace("nn.layer_norm", std::make_shared()); map->emplace("nn.softmax", std::make_shared()); map->emplace("nn.conv2d", std::make_shared()); - map->emplace("nn.dense", std::make_shared()); + map->emplace("nn.matmul", std::make_shared()); map->emplace("nn.bias_add", std::make_shared()); map->emplace("add", std::make_shared()); map->emplace("subtract", std::make_shared()); diff --git a/tests/python/contrib/test_arm_compute_lib/test_dense.py b/tests/python/contrib/test_arm_compute_lib/test_dense.py index e6620a4bc1cb..e909063527b7 100644 --- a/tests/python/contrib/test_arm_compute_lib/test_dense.py +++ b/tests/python/contrib/test_arm_compute_lib/test_dense.py @@ -115,7 +115,7 @@ def _get_expected_codegen(shape, weight_shape, units, dtype, has_bias=False): node = { "op": "kernel", - "name": "nn.dense", + "name": "nn.matmul", "inputs": [], "attrs": { "num_outputs": "1", @@ -123,6 +123,8 @@ def _get_expected_codegen(shape, weight_shape, units, dtype, has_bias=False): "shape": [[list(output_shape)]], "dtype": [[dtype]], "units": [[str(units)]], + "data_transposed": [["0"]], + "weight_transposed": [["1"]], }, } @@ -138,6 +140,10 @@ def _get_expected_codegen(shape, weight_shape, units, dtype, has_bias=False): # qnn.dense params, input and kernel if dtype == "uint8": node["name"] = "qnn.dense" + # These two attrs are not in DenseAttrs + # TODO: Rewrite qnn.dense to qnn.matmul, so it can use MatmulAttrs + node["attrs"].pop("data_transposed") + node["attrs"].pop("weight_transposed") for param_dtype in ["int32", "float32"]: for _ in range(2): inputs.append( diff --git a/tests/python/contrib/test_bnns/test_dense.py b/tests/python/contrib/test_bnns/test_dense.py index c2cf9bf71373..61c44b79badc 100644 --- a/tests/python/contrib/test_bnns/test_dense.py +++ b/tests/python/contrib/test_bnns/test_dense.py @@ -61,7 +61,7 @@ def _get_model(shape, weight_shape, units, dtype, var_names, has_bias=False, has def _get_expected_codegen(shape, weight_shape, units, dtype, has_bias=False, has_gelu=False): output_shape = (shape[0], units) - name = "nn.dense" + name = "nn.matmul" if has_bias is True: name = "bnns.dense_bias" if has_bias is True and has_gelu is True: @@ -77,6 +77,8 @@ def _get_expected_codegen(shape, weight_shape, units, dtype, has_bias=False, has "shape": [[list(output_shape)]], "dtype": [[dtype]], "units": [[str(units)]], + "data_transposed": [["0"]], + "weight_transposed": [["1"]], }, } diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index be4d74ed205a..860b61b8e531 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -3375,7 +3375,9 @@ def forward(self, *args): # matrix x matrix tensor1 = torch.randn(10, 4) tensor2 = torch.randn(4, 10) - verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2], expected_ops=["nn.dense"]) + verify_model( + MatMul1().float().eval(), input_data=[tensor1, tensor2], expected_ops=["nn.matmul"] + ) # batched matrix x batched matrix tensor1 = torch.randn(10, 3, 4) @@ -3387,12 +3389,16 @@ def forward(self, *args): # batched matrix x broadcasted matrix tensor1 = 
torch.randn(10, 3, 4) tensor2 = torch.randn(4, 5) - verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2], expected_ops=["nn.dense"]) + verify_model( + MatMul1().float().eval(), input_data=[tensor1, tensor2], expected_ops=["nn.matmul"] + ) # broadcasted matrix x batched matrix tensor1 = torch.randn(10, 4) tensor2 = torch.randn(3, 4, 5) - verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2], expected_ops=["nn.dense"]) + verify_model( + MatMul1().float().eval(), input_data=[tensor1, tensor2], expected_ops=["nn.matmul"] + ) # batched matrix x batched matrix tensor1 = torch.randn(1, 12, 14, 64) diff --git a/tests/python/relay/test_autotvm_task_extraction.py b/tests/python/relay/test_autotvm_task_extraction.py index b3f1868969cc..7277f0158843 100644 --- a/tests/python/relay/test_autotvm_task_extraction.py +++ b/tests/python/relay/test_autotvm_task_extraction.py @@ -46,7 +46,7 @@ def test_task_extraction(): conv2d = relay.op.get("nn.conv2d") conv3d = relay.op.get("nn.conv3d") conv2d_transpose = relay.op.get("nn.conv2d_transpose") - dense = relay.op.get("nn.dense") + dense = relay.op.get("nn.matmul") mod, params, _ = get_network("resnet-18", batch_size=1) tasks = autotvm.task.extract_from_program( @@ -104,7 +104,7 @@ def test_task_extraction(): def test_task_extraction_for_dense_int8_cuda(): target = "cuda" - dense = relay.op.get("nn.dense") + dense = relay.op.get("nn.matmul") def get_net(batch, in_dim, out_dim, dtype, out_dtype): data = tvm.relay.var("data", shape=[batch, in_dim], dtype=dtype) diff --git a/tests/python/relay/test_dataflow_pattern.py b/tests/python/relay/test_dataflow_pattern.py index 1c721f40d129..c2d26b0b894e 100644 --- a/tests/python/relay/test_dataflow_pattern.py +++ b/tests/python/relay/test_dataflow_pattern.py @@ -444,7 +444,7 @@ def test_match_op_attr(): def test_no_match_op_attr(): - op = is_op("nn.dense").has_attr({"TOpPattern": K_ELEMWISE}) + op = is_op("nn.matmul").has_attr({"TOpPattern": K_ELEMWISE}) op_pat = op(wildcard(), wildcard()) x = relay.var("x") y = relay.var("y") diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py index 099e127aeba9..894bae892aae 100644 --- a/tests/python/relay/test_ir_parser.py +++ b/tests/python/relay/test_ir_parser.py @@ -939,7 +939,7 @@ def get_func(shape, dtype): program = """ def @main(%p0: Tensor[(2, 4), float32], %p1: Tensor[(2, 4), float32]) { %2 = fn (%data: Tensor[(2, 4), float32], %weight: Tensor[(2, 4), float32]) { - %0 = nn.dense(%data, %weight, units=None); + %0 = nn.matmul(%data, %weight, units=None, data_transposed=False, weight_transposed=True); %1 = nn.relu(%0); add(%1, 1f) }; diff --git a/tests/python/relay/test_layer_count.py b/tests/python/relay/test_layer_count.py index f680bb2725f2..db89479b1266 100644 --- a/tests/python/relay/test_layer_count.py +++ b/tests/python/relay/test_layer_count.py @@ -23,7 +23,7 @@ def verify(num_layers): # Load a resnet with a known number of layers. mod, _ = resnet.get_workload(num_layers=num_layers) # Count the number of conv and dense layers. 
-        count = count_layers(mod, valid_ops=["nn.conv2d", "nn.dense"])
+        count = count_layers(mod, valid_ops=["nn.conv2d", "nn.matmul"])
         assert count == num_layers
 
     verify(18)
diff --git a/tests/python/relay/test_op_grad_level2.py b/tests/python/relay/test_op_grad_level2.py
index 686fd9834640..b9dff0464f07 100644
--- a/tests/python/relay/test_op_grad_level2.py
+++ b/tests/python/relay/test_op_grad_level2.py
@@ -199,6 +199,22 @@ def test_dense_grad():
     verify_dense_grad((5, 4), (3, 4))
 
 
+def verify_matmul_grad(d_shape, w_shape, d_transposed, w_transposed):
+    data = relay.var("data", relay.TensorType(d_shape, "float32"))
+    weight = relay.var("weight", relay.TensorType(w_shape, "float32"))
+    fwd_func = relay.Function(
+        [data, weight],
+        relay.nn.matmul(data, weight, data_transposed=d_transposed, weight_transposed=w_transposed),
+    )
+    check_grad(fwd_func)
+
+
+def test_matmul_grad():
+    verify_matmul_grad((1, 8), (8, 16), False, False)
+    verify_matmul_grad((4, 1), (4, 3), True, False)
+    verify_matmul_grad((4, 5), (3, 4), True, True)
+
+
 def verify_batch_flatten_grad(d_shape):
     data = relay.var("data", relay.TensorType(d_shape, "float32"))
     fwd_func = relay.Function([data], relay.nn.batch_flatten(data))
@@ -216,4 +232,5 @@ def test_batch_flatten_grad():
     test_global_avg_pool2d_grad()
     test_conv2d_grad()
     test_dense_grad()
+    test_matmul_grad()
     test_batch_flatten_grad()
diff --git a/tests/python/relay/test_pass_alter_op_layout.py b/tests/python/relay/test_pass_alter_op_layout.py
index 5c2793c607a9..a97769cbd8dd 100644
--- a/tests/python/relay/test_pass_alter_op_layout.py
+++ b/tests/python/relay/test_pass_alter_op_layout.py
@@ -1323,7 +1323,7 @@ def expected():
     for target, _ in tvm.testing.enabled_targets():
         with tvm.target.Target(target):
             with TempOpAttr(
-                "nn.dense", "FTVMAlterOpLayout", topi.x86.dense_alter_op._alter_dense_layout
+                "nn.matmul", "FTVMAlterOpLayout", topi.x86.dense_alter_op._alter_matmul_layout
             ):
                 a = before()
                 a = run_opt_pass(a, transform.AlterOpLayout())
diff --git a/tests/python/relay/test_pass_auto_quantize.py b/tests/python/relay/test_pass_auto_quantize.py
index 51b9f5f24d1d..52d54daffcce 100644
--- a/tests/python/relay/test_pass_auto_quantize.py
+++ b/tests/python/relay/test_pass_auto_quantize.py
@@ -390,7 +390,7 @@ def test_dense_conv2d_rewrite():
 
     def _check_dense(node):
         if isinstance(node, Call):
-            if node.op.name == "nn.dense":
+            if node.op.name == "nn.matmul":
                 assert node.args[0].checked_type.dtype == "int8"
                 assert node.args[1].checked_type.dtype == "int8"
                 assert node.checked_type.dtype == "int32"
diff --git a/tests/python/relay/test_pass_gradient.py b/tests/python/relay/test_pass_gradient.py
index 6228c5fc157b..24a3c971c4e5 100644
--- a/tests/python/relay/test_pass_gradient.py
+++ b/tests/python/relay/test_pass_gradient.py
@@ -426,7 +426,7 @@ def test_no_duplication():
     gr = tvm.relay.transform.gradient(fn, mode="first_order")
     counts = count_ops(gr)
-    assert counts["nn.dense"] == 3, "We expect 3 dense (1 forward, two backward)"
+    assert counts["nn.matmul"] == 3, "We expect 3 matmul ops (1 forward, 2 backward)"
 
 
 def test_global_function():
diff --git a/tests/python/relay/test_pass_legalize_tensorcore.py b/tests/python/relay/test_pass_legalize_tensorcore.py
index 1312b396fe4c..affd520ff072 100644
--- a/tests/python/relay/test_pass_legalize_tensorcore.py
+++ b/tests/python/relay/test_pass_legalize_tensorcore.py
@@ -222,7 +222,7 @@ def before():
 
     def legalize_dense(attrs, inputs, types):
         with tvm.target.Target("cuda"):
-            return topi.nn.dense_legalize(attrs, inputs, types)
+            return topi.nn.matmul_legalize(attrs, inputs, types)
 
     def expected():
         if not do_pad:
@@ -248,7 +248,7 @@ def expected():
             y = relay.Function([x, weight], y)
             return y
 
-    with TempOpAttr("nn.dense", "FTVMLegalize", legalize_dense):
+    with TempOpAttr("nn.matmul", "FTVMLegalize", legalize_dense):
         a = before()
         a = run_opt_pass(a, transform.Legalize())
         b = run_opt_pass(expected(), transform.InferType())
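For completeness, the new gradient registration can be exercised end to end the same way
verify_matmul_grad does in tests/python/relay/test_op_grad_level2.py above. A hedged sketch for
the TN case (assuming this patch's Python API; the exact gradient-pass invocation may vary):

    import tvm
    from tvm import relay

    data = relay.var("data", relay.TensorType((4, 1), "float32"))    # stored transposed
    weight = relay.var("weight", relay.TensorType((4, 3), "float32"))
    fwd = relay.Function(
        [data, weight],
        relay.nn.matmul(data, weight, data_transposed=True, weight_transposed=False),
    )
    mod = relay.transform.InferType()(tvm.IRModule.from_expr(fwd))
    bwd = relay.transform.gradient(mod["main"], mode="first_order")
    print(bwd)  # forward value plus gradients w.r.t. data and weight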