[Matmul] Add matmul op (apache#8234)
* Add Matmul Op

* Recover DenseAttrs

* Add grad for matmul & some update

* Update matmul cuda default schedule

* Add blas support for matmul

* Lint fix add update doc strings
jcf94 authored and ylc committed Sep 29, 2021
1 parent a1f73da commit 8423279
Showing 25 changed files with 842 additions and 115 deletions.
26 changes: 26 additions & 0 deletions include/tvm/relay/attrs/nn.h
@@ -961,6 +961,32 @@ struct AvgPool3DAttrs : public tvm::AttrsNode<AvgPool3DAttrs> {
}
};

/*! \brief Attributes for matmul operator */
struct MatmulAttrs : public tvm::AttrsNode<MatmulAttrs> {
IndexExpr units;
DataType out_dtype;
bool transpose_a;
bool transpose_b;
tvm::String auto_scheduler_rewritten_layout; // The layout after auto-scheduler's layout rewrite

TVM_DECLARE_ATTRS(MatmulAttrs, "relay.attrs.MatmulAttrs") {
TVM_ATTR_FIELD(units).describe("Number of hidden units of the matmul transformation.");

// use 0 bits to indicate none.
TVM_ATTR_FIELD(out_dtype)
.set_default(NullValue<DataType>())
.describe("Output data type, set to explicit type under mixed precision setting");

TVM_ATTR_FIELD(transpose_a)
.set_default(false)
.describe("Whether the first input tensor is in transposed format.");

TVM_ATTR_FIELD(transpose_b)
.set_default(false)
.describe("Whether the second input tensor is in transposed format.");
}
};

/*! \brief Attributes for dense operator */
struct DenseAttrs : public tvm::AttrsNode<DenseAttrs> {
IndexExpr units;
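To illustrate how these attributes surface on the Python side, here is a minimal sketch (not part of the diff, assuming a TVM build that includes this commit) that builds an `nn.matmul` call and reads back the `MatmulAttrs` fields declared above; the variable names are illustrative only.

```python
from tvm import relay

a = relay.var("a", shape=(16, 32), dtype="float32")
b = relay.var("b", shape=(32, 8), dtype="float32")
y = relay.nn.matmul(a, b)  # transpose_a=False, transpose_b=False -> stays nn.matmul

print(y.op.name)            # "nn.matmul"
print(y.attrs.transpose_a)  # default False (may print as 0 depending on TVM version)
print(y.attrs.transpose_b)  # default False
```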
19 changes: 18 additions & 1 deletion python/tvm/relay/frontend/tensorflow.py
@@ -44,6 +44,16 @@

__all__ = ["from_tensorflow"]

# The default configurations of Relay TensorFlow frontend.
TF_DEFAULT_CONFIGS = {
# By default, TVM converts `tf.matmul` to `transpose(weight) + nn.dense`, which introduces
# unnecessary overhead in weight transpose. Change this flag to False to directly convert to
# `nn.matmul` to get rid of the overhead.
# However, please note that `nn.matmul` is still experimental, so it may have some
# performance issues.
"use_dense": True,
}

# compatible operators that do NOT require any conversion.
_identity_list = []

@@ -1204,7 +1214,7 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
return func, self._params


def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None):
def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None, use_dense_op=True):
"""Load tensorflow graph which is a python tensorflow graph object into relay.
The companion parameters will be handled automatically.
@@ -1222,6 +1232,11 @@ def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None):
outputs : List of output tensor names (Optional)
if not specified then the last node is assumed as graph output.
use_dense_op : bool (Optional) = True
True to convert `tf.matmul` to `nn.dense`, else to `nn.matmul`.
The `nn.dense` op requires the data tensor to be non-transposed and the weight tensor to be
transposed, so extra `transpose` ops may be inserted into the original graph.
Returns
-------
mod : tvm.IRModule
@@ -1230,6 +1245,8 @@ def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None):
params : dict of str to tvm.nd.NDArray
Dict of converted parameters stored in tvm.nd.NDArray format
"""
global TF_DEFAULT_CONFIGS
TF_DEFAULT_CONFIGS["use_dense"] = use_dense_op

g = GraphProto()
mod, params = g.from_tensorflow(graph, layout, shape, outputs)
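A minimal usage sketch of the new flag (not part of the diff); `graph_def` and the input name/shape are assumptions, standing in for a frozen TensorFlow 1.x GraphDef that contains a MatMul node:

```python
from tvm import relay

# Default behaviour (use_dense_op=True): tf.matmul becomes nn.dense plus a weight transpose.
# Setting use_dense_op=False converts tf.matmul directly to the experimental nn.matmul.
mod, params = relay.frontend.from_tensorflow(
    graph_def,                 # assumed: a frozen TF 1.x GraphDef with a MatMul node
    layout="NHWC",
    shape={"input": (1, 64)},  # assumed input name and shape
    use_dense_op=False,
)
print(mod["main"])  # expect nn.matmul rather than nn.dense + transpose in the printed IR
```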
20 changes: 15 additions & 5 deletions python/tvm/relay/frontend/tensorflow_ops.py
@@ -1113,13 +1113,23 @@ def _impl(inputs, attr, params, mod):

def _matmul():
def _impl(inputs, attr, params, mod):
from .tensorflow import TF_DEFAULT_CONFIGS

channels = _infer_channels(inputs[1], not attr["transpose_b"])
if attr["transpose_a"]:
inputs[0] = _op.transpose(inputs[0], axes=(1, 0))
if not attr["transpose_b"]:
inputs[1] = _op.transpose(inputs[1], axes=(1, 0))
if TF_DEFAULT_CONFIGS["use_dense"]:
if attr["transpose_a"]:
inputs[0] = _op.transpose(inputs[0], axes=(1, 0))
if not attr["transpose_b"]:
inputs[1] = _op.transpose(inputs[1], axes=(1, 0))
return AttrCvt(
op_name="dense",
extras={"units": channels},
ignores=["transpose_a", "transpose_b", "T"],
)(inputs, attr)
return AttrCvt(
op_name="dense", extras={"units": channels}, ignores=["transpose_a", "transpose_b", "T"]
op_name="matmul",
extras={"units": channels},
ignores=["T"],
)(inputs, attr)

return _impl
29 changes: 29 additions & 0 deletions python/tvm/relay/op/_tensor_grad.py
@@ -554,6 +554,35 @@ def dense_grad(orig, grad):
]


@register_gradient("nn.matmul")
def matmul_grad(orig, grad):
"""Returns [grad' @ tensor_b, tensor_a @ grad']"""
tensor_a, tensor_b = orig.args
if (orig.attrs["transpose_a"], orig.attrs["transpose_b"]) == (True, True):
return [
collapse_sum_like(
_nn.matmul(tensor_b, grad, transpose_a=True, transpose_b=True), tensor_a
),
collapse_sum_like(
_nn.matmul(grad, tensor_a, transpose_a=True, transpose_b=True), tensor_b
),
]
if (orig.attrs["transpose_a"], orig.attrs["transpose_b"]) == (True, False):
return [
collapse_sum_like(_nn.matmul(tensor_b, grad, transpose_b=True), tensor_a),
collapse_sum_like(_nn.matmul(tensor_a, grad), tensor_b),
]
if (orig.attrs["transpose_a"], orig.attrs["transpose_b"]) == (False, True):
# Keep using the dense op here to avoid introducing extra ops
# TODO(jcf94): Merge all to nn.matmul when it is finally ready
return dense_grad(orig, grad)
# (orig.attrs["transpose_a"], orig.attrs["transpose_b"]) == (False, False)
return [
collapse_sum_like(_nn.matmul(grad, tensor_b, transpose_b=True), tensor_a),
collapse_sum_like(_nn.matmul(tensor_a, grad, transpose_a=True), tensor_b),
]
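The four branches above are just the chain rule for C = op(A) · op(B). Below is a NumPy sketch (not part of the diff) that numerically checks the non-transposed branch, i.e. dL/dA = G·Bᵀ and dL/dB = Aᵀ·G for an upstream gradient G:

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((2, 3))
B = rng.standard_normal((3, 4))
G = rng.standard_normal((2, 4))  # upstream gradient dL/dC for C = A @ B

# Analytic gradients, matching the (transpose_a, transpose_b) == (False, False) branch above.
dA = G @ B.T
dB = A.T @ G

# Finite-difference check of dA for the scalar loss L = sum(C * G).
eps = 1e-6
num_dA = np.zeros_like(A)
for i in range(A.shape[0]):
    for j in range(A.shape[1]):
        Ap, Am = A.copy(), A.copy()
        Ap[i, j] += eps
        Am[i, j] -= eps
        num_dA[i, j] = (((Ap @ B) * G).sum() - ((Am @ B) * G).sum()) / (2 * eps)

assert np.allclose(dA, num_dA, atol=1e-5)
assert dB.shape == B.shape
```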


@register_gradient("nn.batch_matmul")
def batch_matmul_grad(orig, grad):
"""gradient for nn.batch_matmul: in einsum LHS_bik,RHS_bjk->RES_bij
63 changes: 56 additions & 7 deletions python/tvm/relay/op/nn/_nn.py
@@ -52,6 +52,32 @@
reg.register_pattern("nn.log_softmax", OpPattern.OPAQUE)


@reg.register_legalize("nn.matmul")
def legalize_matmul(attrs, inputs, types):
"""Legalize matmul op.
Parameters
----------
attrs : tvm.ir.Attrs
Attributes of current matmul
inputs : list of tvm.relay.Expr
The args of the Relay expr to be legalized
types : list of types
List of input and output types
Returns
-------
result : tvm.relay.Expr
The legalized expr
"""
return topi.nn.matmul_legalize(attrs, inputs, types)


# matmul
reg.register_strategy("nn.matmul", strategy.matmul_strategy)
reg.register_pattern("nn.matmul", reg.OpPattern.OUT_ELEMWISE_FUSABLE)


@reg.register_legalize("nn.dense")
def legalize_dense(attrs, inputs, types):
"""Legalize dense op.
@@ -1160,21 +1186,44 @@ def batch_flatten_shape_func(attrs, inputs, _):


@script
def _dense_shape_func(data_shape, weight_shape):
out = output_tensor((data_shape.shape[0],), "int64")
def _matmul_shape_func(tensor_a_shape, tensor_b_shape, transpose_a, transpose_b):
out = output_tensor((tensor_a_shape.shape[0],), "int64")
for i in const_range(out.shape[0] - 1):
out[i] = data_shape[i]
out[out.shape[0] - 1] = weight_shape[0]
out[i] = tensor_a_shape[i]
if transpose_a:
out[out.shape[0] - 2] = out[out.shape[0] - 1]
out[out.shape[0] - 1] = tensor_b_shape[0] if transpose_b else tensor_b_shape[1]

return out


@reg.register_shape_func("nn.matmul", False)
def matmul_shape_func(attrs, inputs, _):
"""Shape function for matmul op."""
ret = [
_matmul_shape_func(
inputs[0],
inputs[1],
expr.IntImm("bool", attrs.transpose_a),
expr.IntImm("bool", attrs.transpose_b),
)
]
return ret


@reg.register_shape_func("nn.dense", False)
def dense_shape_func(attrs, inputs, _):
"""Shape function for dense op. This is an alias of matmul_nt operator for data tensor in
non-transposed format and weight tensor in transposed format.
"""
Shape function for dense op.
"""
ret = [_dense_shape_func(inputs[0], inputs[1])]
ret = [
_matmul_shape_func(
inputs[0],
inputs[1],
expr.IntImm("bool", False),
expr.IntImm("bool", True),
)
]
return ret
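For reference, a plain-Python sketch (not part of the diff) of the output-shape rule the hybrid-script function above encodes: leading dimensions follow tensor_a, and the transpose flags pick which axes of tensor_a and tensor_b supply M and N.

```python
def matmul_out_shape(a_shape, b_shape, transpose_a=False, transpose_b=False):
    """Plain-Python sketch of the intended shape rule (illustration only)."""
    m = a_shape[-1] if transpose_a else a_shape[-2]
    n = b_shape[-2] if transpose_b else b_shape[-1]
    return tuple(a_shape[:-2]) + (m, n)

assert matmul_out_shape((16, 32), (32, 8)) == (16, 8)
assert matmul_out_shape((16, 32), (8, 32), transpose_b=True) == (16, 8)  # dense-style layout
assert matmul_out_shape((4, 16, 32), (32, 8)) == (4, 16, 8)              # N-D data tensor
```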


44 changes: 44 additions & 0 deletions python/tvm/relay/op/nn/nn.py
@@ -1471,6 +1471,50 @@ def bias_add(data, bias, axis=1):
return _make.bias_add(data, bias, axis)


def matmul(tensor_a, tensor_b, units=None, out_dtype="", transpose_a=False, transpose_b=False):
"""Matmul operator.
Applies a general matrix multiplication. Both A and B can be transposed.
.. math::
C = A * B
Parameters
----------
tensor_a : tvm.relay.Expr
The first input of the operator,
of shape `(d_1, d_2, ..., d_n, units_in)` or `(d_1, d_2, ..., units_in, d_n)`.
tensor_b : tvm.relay.Expr
The second input expression, a 2-D matrix,
of shape `(units_in, units)` or `(units, units_in)`.
units : Optional[int]
Number of hidden units of the matmul transformation.
out_dtype : Optional[str]
Specifies the output data type for mixed-precision matmul.
transpose_a : Optional[bool] = False
Whether the first input tensor is in transposed format.
transpose_b : Optional[bool] = False
Whether the second input tensor is in transposed format.
Returns
-------
result : tvm.relay.Expr
The computed result, of shape `(d_1, d_2, ..., d_n, units)`.
"""
# Since `nn.dense` currently has better topi schedule support, prefer `nn.dense`
# over `nn.matmul` for better compatibility when possible.
if not transpose_a and transpose_b:
# TODO(jcf94): Remove this when `nn.matmul` is finally ready
return dense(tensor_a, tensor_b, units, out_dtype)
return _make.matmul(tensor_a, tensor_b, units, out_dtype, transpose_a, transpose_b)
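A short sketch (not part of the diff) of the fallback described in the comment above: with `transpose_a=False, transpose_b=True` the call currently lowers to `nn.dense`, otherwise an `nn.matmul` node is produced.

```python
from tvm import relay

a = relay.var("a", shape=(16, 32))
b = relay.var("b", shape=(32, 8))
b_t = relay.var("b_t", shape=(8, 32))  # weight already in transposed (dense) layout

print(relay.nn.matmul(a, b).op.name)                      # "nn.matmul"
print(relay.nn.matmul(a, b_t, transpose_b=True).op.name)  # "nn.dense" (fallback above)
```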


def dense(data, weight, units=None, out_dtype=""):
"""Dense operator.
Applies a linear transformation
5 changes: 5 additions & 0 deletions python/tvm/relay/op/op_attrs.py
@@ -64,6 +64,11 @@ class BiasAddAttrs(Attrs):
"""Atttribute of nn.bias_add"""


@tvm._ffi.register_object("relay.attrs.MatmulAttrs")
class MatmulAttrs(Attrs):
"""Attributes for nn.matmul"""


@tvm._ffi.register_object("relay.attrs.DenseAttrs")
class DenseAttrs(Attrs):
"""Attributes for nn.dense"""
32 changes: 32 additions & 0 deletions python/tvm/relay/op/strategy/cuda.py
@@ -698,6 +698,38 @@ def conv1d_transpose_strategy_cuda(attrs, inputs, out_type, target):
return strategy


@matmul_strategy.register(["cuda", "gpu"])
def matmul_strategy_cuda(attrs, inputs, out_type, target):
"""Matmul cuda strategy."""
strategy = _op.OpStrategy()

if is_auto_scheduler_enabled():
strategy.add_implementation(
wrap_compute_matmul(topi.nn.matmul),
naive_schedule,
name="matmul.cuda",
)
else:
logger.warning(
"Matmul is not optimized for cuda. Recommend to use cublas for better performance."
)
# Temporarily use this as a basic schedule
strategy.add_implementation(
wrap_compute_matmul(topi.gpu.matmul_default),
wrap_topi_schedule(topi.gpu.schedule_matmul_default),
name="matmul_default.gpu",
)

if target.kind.name == "cuda" and "cublas" in target.libs:
strategy.add_implementation(
wrap_compute_matmul(topi.cuda.matmul_cublas),
wrap_topi_schedule(topi.cuda.schedule_matmul_cublas),
name="matmul_cublas.cuda",
plevel=25,
)
return strategy
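Usage sketch (not part of the diff): compiling with cuBLAS in the target libs so the `matmul_cublas.cuda` implementation above (plevel=25) is picked over the default GPU schedule. Assumes a CUDA-enabled TVM build with cuBLAS support.

```python
import tvm
from tvm import relay

a = relay.var("a", shape=(128, 256), dtype="float32")
b = relay.var("b", shape=(256, 64), dtype="float32")
mod = tvm.IRModule.from_expr(relay.nn.matmul(a, b))

target = tvm.target.Target("cuda -libs=cublas")  # enables the cublas implementation above
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target=target)
```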


@dense_strategy.register(["cuda", "gpu"])
def dense_strategy_cuda(attrs, inputs, out_type, target):
"""dense cuda strategy"""
36 changes: 36 additions & 0 deletions python/tvm/relay/op/strategy/generic.py
@@ -712,6 +712,42 @@ def dilation2d_strategy(attrs, inputs, out_type, target):
return strategy


# matmul
def wrap_compute_matmul(topi_compute, need_auto_scheduler_layout=False):
"""wrap matmul topi compute"""

def _compute_matmul(attrs, inputs, out_type):
"""Compute definition of matmul"""
out_dtype = attrs.out_dtype
out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype
args = [
inputs[0],
inputs[1],
None,
out_dtype,
attrs.transpose_a,
attrs.transpose_b,
]
if need_auto_scheduler_layout:
args.append(get_auto_scheduler_rewritten_layout(attrs))
return [topi_compute(*args)]

return _compute_matmul


@override_native_generic_func("matmul_strategy")
def matmul_strategy(attrs, inputs, out_type, target):
"""matmul generic strategy"""
logger.warning("matmul is not optimized for this platform.")
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_matmul(topi.nn.matmul),
wrap_topi_schedule(topi.generic.schedule_matmul),
name="matmul.generic",
)
return strategy
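For orientation, a TE-level sketch (not part of the diff) of what `wrap_compute_matmul` feeds to the TOPI compute: the two inputs, no bias, an explicit out_dtype, and the two transpose flags, in the positional order of `args` above. The positional signature of `topi.nn.matmul` is assumed from that argument list.

```python
import tvm
from tvm import te, topi

A = te.placeholder((16, 32), name="A", dtype="float32")
B = te.placeholder((8, 32), name="B", dtype="float32")

# args order as in _compute_matmul above: (a, b, bias, out_dtype, transpose_a, transpose_b)
C = topi.nn.matmul(A, B, None, "float32", False, True)  # computes A @ B^T

s = te.create_schedule(C.op)
print(tvm.lower(s, [A, B, C], simple_mode=True))
```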


# dense
def wrap_compute_dense(topi_compute, need_auto_scheduler_layout=False):
"""wrap dense topi compute"""