Update all "nn.dense" to "nn.matmul"
Update grad for matmul op

Recover DenseAttrs

Bug fix for qnn.dense

Refine code
jcf94 committed Jun 22, 2021
1 parent 6822e97 commit 10f0c11
Showing 54 changed files with 291 additions and 128 deletions.
19 changes: 19 additions & 0 deletions include/tvm/relay/attrs/nn.h
@@ -987,6 +987,25 @@ struct MatmulAttrs : public tvm::AttrsNode<MatmulAttrs> {
}
};

/*!
* \brief Attributes for dense operator.
* \note This attr is still used by `qnn.dense`. TODO: Rewrite `qnn.dense` to `qnn.matmul`.
*/
struct DenseAttrs : public tvm::AttrsNode<DenseAttrs> {
IndexExpr units;
tvm::String auto_scheduler_rewritten_layout; // The layout after auto-scheduler's layout rewrite
DataType out_dtype;

TVM_DECLARE_ATTRS(DenseAttrs, "relay.attrs.DenseAttrs") {
TVM_ATTR_FIELD(units).describe("Number of hidden units of the dense transformation.");

// use 0 bits to indicate none.
TVM_ATTR_FIELD(out_dtype)
.set_default(NullValue<DataType>())
.describe("Output data type, set to explicit type under mixed precision setting");
}
};

/*! \brief Attributes for batch matmul operator */
struct BatchMatmulAttrs : public tvm::AttrsNode<BatchMatmulAttrs> {
tvm::String auto_scheduler_rewritten_layout; // The layout after auto-scheduler's layout rewrite
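The commit message's "Recover DenseAttrs" refers to the note above: `qnn.dense` still constructs calls carrying these attrs. A minimal sketch (not part of the diff; API names per TVM at the time of this commit) of inspecting them:

from tvm import relay

x = relay.var("x", shape=(2, 4), dtype="int8")
w = relay.var("w", shape=(8, 4), dtype="int8")
zero = relay.const(0, "int32")
scale = relay.const(1.0, "float32")

# qnn.dense(data, weight, input_zp, kernel_zp, input_scale, kernel_scale, units)
call = relay.qnn.op.dense(x, w, zero, zero, scale, scale, units=8)
print(type(call.attrs))  # expected: DenseAttrs, not MatmulAttrs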
2 changes: 1 addition & 1 deletion python/tvm/contrib/target/onnx.py
@@ -623,7 +623,7 @@ def convert_attributes(cls, attrs):
"add": rename("Add"),
"nn.relu": rename("Relu"),
"transpose": Transpose,
"nn.dense": MatMul,
"nn.matmul": MatMul,
"nn.max_pool2d": MaxPool,
"nn.batch_flatten": Flatten,
"multiply": rename("Mul"),
9 changes: 9 additions & 0 deletions python/tvm/relay/frontend/tensorflow_ops.py
@@ -1121,6 +1121,15 @@ def _impl(inputs, attr, params, mod):
inputs[0] = _op.transpose(inputs[0], axes=(1, 0))
if not attr["transpose_b"]:
inputs[1] = _op.transpose(inputs[1], axes=(1, 0))
return AttrCvt(
op_name="matmul",
extras={
"units": channels,
"data_transposed": False,
"weight_transposed": True,
},
ignores=["transpose_a", "transpose_b", "T"],
)(inputs, attr)
return AttrCvt(
op_name="matmul",
extras={
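A quick NumPy check (not part of this commit) of the normalization above: when the TF graph computes `a @ b`, the converter pre-transposes `b` so the result can be expressed as a matmul with `weight_transposed=True`, i.e. `data @ weight.T`:

import numpy as np

def matmul_nt(data, weight):
    # stand-in for nn.matmul(..., data_transposed=False, weight_transposed=True)
    return data @ weight.T

a = np.random.rand(3, 4)
b = np.random.rand(4, 5)
np.testing.assert_allclose(matmul_nt(a, b.T), a @ b)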
42 changes: 34 additions & 8 deletions python/tvm/relay/op/_tensor_grad.py
@@ -543,17 +543,43 @@ def bias_add_grad(orig, grad):
]


@register_gradient("nn.dense")
def dense_grad(orig, grad):
@register_gradient("nn.matmul")
def matmul_grad(orig, grad):
"""Returns [grad' @ weight, data @ grad']"""
data, weight = orig.args
if (orig.attrs["data_transposed"], orig.attrs["weight_transposed"]) == (True, True):
return [
collapse_sum_like(
_nn.matmul(weight, grad, data_transposed=True, weight_transposed=True), data
),
collapse_sum_like(
_nn.matmul(grad, data, data_transposed=True, weight_transposed=True), weight
),
]
if (orig.attrs["data_transposed"], orig.attrs["weight_transposed"]) == (True, False):
return [
collapse_sum_like(_nn.matmul(weight, grad, weight_transposed=True), data),
collapse_sum_like(_nn.matmul(data, grad), weight),
]
if (orig.attrs["data_transposed"], orig.attrs["weight_transposed"]) == (False, True):
# Keep using the dense op here to avoid introducing extra transpose ops
# TODO: Merge this case into nn.matmul when it is fully ready
return [
collapse_sum_like(
_nn.dense(grad, transpose(weight), units=weight.checked_type.shape[1]),
# equivalent: _nn.matmul(grad, weight)
data,
),
collapse_sum_like(
_nn.dense(transpose(grad), transpose(data), units=data.checked_type.shape[1]),
# equivalent: _nn.matmul(grad, data, data_transposed=True)
weight,
),
]
# (orig.attrs["data_transposed"], orig.attrs["weight_transposed"]) == (False, False)
return [
collapse_sum_like(
_nn.dense(grad, transpose(weight), units=weight.checked_type.shape[1]), data
),
collapse_sum_like(
_nn.dense(transpose(grad), transpose(data), units=data.checked_type.shape[1]), weight
),
collapse_sum_like(_nn.matmul(grad, weight, weight_transposed=True), data),
collapse_sum_like(_nn.matmul(data, grad, data_transposed=True), weight),
]


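The gradient formulas above can be sanity-checked numerically. A self-contained NumPy sketch (not part of the commit) for the non-transposed (False, False) branch; the loss L = sum(G * (X @ W)) is linear in X, so a finite difference recovers dX almost exactly:

import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((3, 4))
w = rng.standard_normal((4, 5))
g = rng.standard_normal((3, 5))  # upstream gradient dL/dY for Y = X @ W

# Formulas from the final return above:
#   dX = G @ W.T  ->  matmul(grad, weight, weight_transposed=True)
#   dW = X.T @ G  ->  matmul(data, grad, data_transposed=True)
dx = g @ w.T
dw = x.T @ g

def loss(xv):
    return np.sum(g * (xv @ w))

eps = 1e-6
fd = np.zeros_like(x)
for i in range(x.shape[0]):
    for j in range(x.shape[1]):
        xp = x.copy()
        xp[i, j] += eps
        fd[i, j] = (loss(xp) - loss(x)) / eps

np.testing.assert_allclose(dx, fd, rtol=1e-4, atol=1e-4)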
8 changes: 4 additions & 4 deletions python/tvm/relay/op/contrib/arm_compute_lib.py
@@ -172,7 +172,7 @@ def dense_pattern():
pattern : dataflow_pattern.AltPattern
Denotes the convolution pattern.
"""
pattern = is_op("nn.dense")(wildcard(), is_constant())
pattern = is_op("nn.matmul")(wildcard(), is_constant())
pattern = pattern.optional(lambda x: is_op("nn.bias_add")(x, is_constant()))
return pattern

@@ -237,9 +237,9 @@ def check_qnn_conv(extract):
return qnn_conv2d(call)

def check_dense(extract):
"""Check conv pattern is supported by ACL."""
"""Check dense pattern is supported by ACL."""
call = extract
while call.op.name != "nn.dense":
while call.op.name != "nn.matmul":
call = call.args[0]
return dense(call)

@@ -368,7 +368,7 @@ def depthwise_conv2d(attrs, args):
return True


@tvm.ir.register_op_attr("nn.dense", "target.arm_compute_lib")
@tvm.ir.register_op_attr("nn.matmul", "target.arm_compute_lib")
def dense(expr):
"""Check if the external ACL codegen for dense should be used."""
attrs, args = expr.attrs, expr.args
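A small usage sketch (not part of this commit) of the updated ACL pattern above: it matches a constant-weight `nn.matmul`, optionally followed by `nn.bias_add`. Keyword names follow this commit:

import numpy as np
from tvm import relay
from tvm.relay.dataflow_pattern import is_constant, is_op, wildcard

pattern = is_op("nn.matmul")(wildcard(), is_constant())
pattern = pattern.optional(lambda x: is_op("nn.bias_add")(x, is_constant()))

x = relay.var("x", shape=(1, 16))
w = relay.const(np.zeros((16, 8), dtype="float32"))
y = relay.nn.matmul(x, w)  # defaults: neither input transposed
assert pattern.match(y)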
6 changes: 3 additions & 3 deletions python/tvm/relay/op/contrib/bnns.py
@@ -191,7 +191,7 @@ def bias_check(expr):
return False


@tvm.ir.register_op_attr("nn.dense", "target.bnns")
@tvm.ir.register_op_attr("nn.matmul", "target.bnns")
def dense(expr):
"""Check if the dense can be used in BNNS."""
attrs, args = expr.attrs, expr.args
@@ -239,7 +239,7 @@ def make_dense_bias_pattern():
data = wildcard()
weight = wildcard()
bias = wildcard()
d = is_op("nn.dense")(data, weight)
d = is_op("nn.matmul")(data, weight)
return is_op("add")(d, bias)


@@ -263,7 +263,7 @@ def make_dense_bias_gelu_pattern():
def check_dense(extract):
"""Check dense pattern is supported by BNNS."""
call = extract
while call.op.name != "nn.dense":
while call.op.name != "nn.matmul":
call = call.args[0]
return dense(call)

2 changes: 1 addition & 1 deletion python/tvm/relay/op/contrib/dnnl.py
@@ -61,7 +61,7 @@ def _func_wrapper(expr):

_register_external_op_helper("nn.batch_norm")
_register_external_op_helper("nn.conv2d")
_register_external_op_helper("nn.dense")
_register_external_op_helper("nn.matmul")
_register_external_op_helper("nn.relu")
_register_external_op_helper("add")
_register_external_op_helper("subtract")
4 changes: 2 additions & 2 deletions python/tvm/relay/op/contrib/tensorrt.py
@@ -393,7 +393,7 @@ def conv2d_annotate_fn(expr): # pylint: disable=unused-variable
return True


@_register_external_dynamic_check_func("nn.dense")
@_register_external_dynamic_check_func("nn.matmul")
def dense_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if dense is supported by TensorRT."""

@@ -915,7 +915,7 @@ def visit_call(self, call):
"nn.conv2d_transpose",
"nn.conv3d",
"nn.conv3d_transpose",
"nn.dense",
"nn.matmul",
"nn.batch_matmul",
"sum",
"prod",
14 changes: 10 additions & 4 deletions python/tvm/relay/op/nn/nn.py
@@ -1473,11 +1473,13 @@ def bias_add(data, bias, axis=1):

def matmul(data, weight, units=None, out_dtype="", data_transposed=False, weight_transposed=False):
"""Matmul operator.
Applies a linear transformation. The X & W can be transposed.
Applies a linear transformation:
.. math::
`Y = X * W`
The X & W can be transposed.
Parameters
----------
@@ -1512,11 +1514,15 @@ def matmul(data, weight, units=None, out_dtype="", data_transposed=False, weight_transposed=False):

def dense(data, weight, units=None, out_dtype=""):
"""Dense operator.
Applies a linear transformation
Applies a linear transformation:
.. math::
`Y = X * W^T`
.. note::
This is an alias of `nn.matmul` when data is in non-transposed format and weight is
transposed.
Parameters
----------
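Per the note added above, `nn.dense` is `nn.matmul` with non-transposed data and a transposed weight. A sketch of the correspondence (keyword names per this commit; later TVM releases renamed them to `transpose_a`/`transpose_b`):

from tvm import relay

x = relay.var("x", shape=(8, 16))
w = relay.var("w", shape=(32, 16))  # dense weight layout: (units, in_features)

y_dense = relay.nn.dense(x, w, units=32)  # computes x @ w.T
y_matmul = relay.nn.matmul(
    x, w, units=32, data_transposed=False, weight_transposed=True
)
# both yield an (8, 32) tensor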
8 changes: 8 additions & 0 deletions python/tvm/relay/op/op_attrs.py
@@ -69,6 +69,14 @@ class MatmulAttrs(Attrs):
"""Attributes for nn.matmul"""


@tvm._ffi.register_object("relay.attrs.DenseAttrs")
class DenseAttrs(Attrs):
"""Attributes for nn.dense
This is still used by qnn.dense. TODO: Rewrite `qnn.dense` to `qnn.matmul`.
"""


@tvm._ffi.register_object("relay.attrs.SoftmaxAttrs")
class SoftmaxAttrs(Attrs):
"""Attributes for nn.softmax"""
6 changes: 4 additions & 2 deletions python/tvm/relay/op/strategy/bifrost.py
@@ -110,7 +110,9 @@ def conv2d_winograd_without_weight_transfrom_strategy_bifrost(attrs, inputs, out_type, target):


def dense_strategy_bifrost(attrs, inputs, out_type, target):
"""dense mali(bifrost) strategy"""
"""Dense mali(bifrost) strategy.
This is a specialized case for Matmul with data non-transposed and weight transposed.
"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_dense(topi.bifrost.dense),
@@ -122,7 +124,7 @@ def dense_strategy_bifrost(attrs, inputs, out_type, target):

@matmul_strategy.register("bifrost")
def matmul_strategy_bifrost(attrs, inputs, out_type, target):
"""matmul mali(bifrost) strategy"""
"""Matmul mali(bifrost) strategy"""

if not attrs.data_transposed and attrs.weight_transposed:
# Specialized schedule for dense(matmul-NT)
8 changes: 5 additions & 3 deletions python/tvm/relay/op/strategy/cuda.py
@@ -699,7 +699,9 @@ def conv1d_transpose_strategy_cuda(attrs, inputs, out_type, target):


def dense_strategy_cuda(attrs, inputs, out_type, target):
"""dense cuda strategy"""
"""Dense cuda strategy.
This is a specialized case for Matmul with data non-transposed and weight transposed.
"""
strategy = _op.OpStrategy()
data, weights = inputs
b, i = get_const_tuple(data.shape)
@@ -759,13 +761,13 @@ def dense_strategy_cuda(attrs, inputs, out_type, target):

@matmul_strategy.register(["cuda", "gpu"])
def matmul_strategy_cuda(attrs, inputs, out_type, target):
"""matmul cuda strategy"""
"""Matmul cuda strategy"""

if not attrs.data_transposed and attrs.weight_transposed:
# Specialized schedule for dense(matmul-NT)
strategy = dense_strategy_cuda(attrs, inputs, out_type, target)
else:
logger.warning("Matmul other than NT format is not optimized for x86.")
logger.warning("Matmul other than NT format is not optimized for cuda.")
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_matmul(topi.nn.matmul),
2 changes: 1 addition & 1 deletion python/tvm/relay/op/strategy/generic.py
@@ -752,7 +752,7 @@ def _compute_matmul(attrs, inputs, out_type):

@override_native_generic_func("matmul_strategy")
def matmul_strategy(attrs, inputs, out_type, target):
"""matmul generic strategy"""
"""Matmul generic strategy"""
logger.warning("matmul is not optimized for this platform.")
strategy = _op.OpStrategy()
strategy.add_implementation(
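Every per-target matmul strategy in this commit (bifrost, cuda, hls, mali, x86) repeats the same dispatch, distilled in this sketch (illustrative, not code from the commit): the NT layout reuses the tuned dense schedules, and any other layout falls back to a naive generic schedule:

import logging

logger = logging.getLogger(__name__)

def pick_matmul_strategy(attrs, dense_strategy, generic_strategy):
    # NT layout (data non-transposed, weight transposed) is exactly nn.dense
    if not attrs.data_transposed and attrs.weight_transposed:
        return dense_strategy
    logger.warning("Matmul other than NT format is not optimized for this target.")
    return generic_strategy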
6 changes: 4 additions & 2 deletions python/tvm/relay/op/strategy/hls.py
@@ -159,7 +159,7 @@ def conv2d_transpose_strategy_hls(attrs, inputs, out_type, target):


def dense_strategy_hls(attrs, inputs, out_type, target):
"""dense hls strategy"""
"""Dense hls strategy.
This is a specialized case for Matmul with data non-transposed and weight transposed.
"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_dense(topi.nn.dense),
@@ -171,7 +173,7 @@ def dense_strategy_hls(attrs, inputs, out_type, target):

@matmul_strategy.register("hls")
def matmul_strategy_hls(attrs, inputs, out_type, target):
"""matmul hls strategy"""
"""Matmul hls strategy"""

if not attrs.data_transposed and attrs.weight_transposed:
# Specialized schedule for dense(matmul-NT)
6 changes: 4 additions & 2 deletions python/tvm/relay/op/strategy/mali.py
@@ -173,7 +173,9 @@ def conv2d_winograd_without_weight_transfrom_strategy_mali(attrs, inputs, out_type, target):


def dense_strategy_mali(attrs, inputs, out_type, target):
"""dense mali strategy"""
"""Dense mali strategy.
This is a specialized case for Matmul with data non-transposed and weight transposed.
"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_dense(topi.mali.dense),
@@ -185,7 +187,7 @@ def dense_strategy_mali(attrs, inputs, out_type, target):

@matmul_strategy.register("mali")
def matmul_strategy_mali(attrs, inputs, out_type, target):
"""dense mali strategy"""
"""Matmul mali strategy"""

if not attrs.data_transposed and attrs.weight_transposed:
# Specialized schedule for dense(matmul-NT)
4 changes: 3 additions & 1 deletion python/tvm/relay/op/strategy/rocm.py
@@ -183,7 +183,9 @@ def conv2d_strategy_rocm(attrs, inputs, out_type, target):


def dense_strategy_rocm(attrs, inputs, out_type, target):
"""Dense strategy for ROCM"""
"""Dense strategy for ROCM.
This is a specialized case for Matmul with data non-transposed and weight transposed.
"""
assert len(inputs[0].shape) == 2 and len(inputs[1].shape) == 2, "Only support 2-dim dense"
strategy = _op.OpStrategy()
strategy.add_implementation(
11 changes: 8 additions & 3 deletions python/tvm/relay/op/strategy/x86.py
@@ -371,8 +371,8 @@ def conv1d_strategy_cpu(attrs, inputs, out_type, target):


def dense_strategy_cpu(attrs, inputs, out_type, target):
"""Dense x86 strategy. This is a specialized case for Matmul with data non-transposed and
weight transposed.
"""Dense x86 strategy.
This is a specialized case for Matmul with data non-transposed and weight transposed.
"""

strategy = _op.OpStrategy()
@@ -394,14 +394,19 @@ def dense_strategy_cpu(attrs, inputs, out_type, target):

@matmul_strategy.register("cpu")
def matmul_strategy_cpu(attrs, inputs, out_type, target):
"""Matmul x86 strategy."""
"""Matmul x86 strategy"""

if not attrs.data_transposed and attrs.weight_transposed:
# Specialized schedule for dense(matmul-NT)
strategy = dense_strategy_cpu(attrs, inputs, out_type, target)
else:
logger.warning("Matmul other than NT format is not optimized for x86.")
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_matmul(topi.nn.matmul),
naive_schedule,
name="matmul.generic",
)

same_type = inputs[0].dtype == inputs[1].dtype == out_type.dtype
dtype = inputs[0].dtype
(Diff truncated: the remaining changed files are not shown.)