
[Matmul] Add matmul op #8234

Merged: 29 commits, Jun 30, 2021. Changes shown from 20 commits.
26 changes: 26 additions & 0 deletions include/tvm/relay/attrs/nn.h
@@ -961,6 +961,32 @@ struct AvgPool3DAttrs : public tvm::AttrsNode<AvgPool3DAttrs> {
}
};

/*! \brief Attributes for matmul operator */
struct MatmulAttrs : public tvm::AttrsNode<MatmulAttrs> {
  IndexExpr units;
  DataType out_dtype;
  bool data_transposed;
  bool weight_transposed;
  tvm::String auto_scheduler_rewritten_layout;  // The layout after auto-scheduler's layout rewrite

  TVM_DECLARE_ATTRS(MatmulAttrs, "relay.attrs.MatmulAttrs") {
    TVM_ATTR_FIELD(units).describe("Number of hidden units of the dense transformation.");

    // use 0 bits to indicate none.
    TVM_ATTR_FIELD(out_dtype)
        .set_default(NullValue<DataType>())
        .describe("Output data type, set to explicit type under mixed precision setting");

    TVM_ATTR_FIELD(data_transposed)
        .set_default(false)
        .describe("Whether the input tensor is in transposed format.");

    TVM_ATTR_FIELD(weight_transposed)
        .set_default(false)
        .describe("Whether the weight tensor is in transposed format.");
  }
};

/*! \brief Attributes for dense operator */
struct DenseAttrs : public tvm::AttrsNode<DenseAttrs> {
IndexExpr units;
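These attributes surface in Python as `relay.nn.matmul` (added later in this PR). As a minimal sketch of a fully transposed call (shapes here are illustrative, not from the PR):

```python
import tvm
from tvm import relay

# data stored transposed as (units_in, d_1); weight transposed as (units, units_in)
data = relay.var("data", shape=(64, 32), dtype="float32")
weight = relay.var("weight", shape=(16, 64), dtype="float32")
out = relay.nn.matmul(data, weight, units=16, data_transposed=True, weight_transposed=True)
mod = tvm.IRModule.from_expr(relay.Function([data, weight], out))
print(relay.transform.InferType()(mod))  # result type: Tensor[(32, 16), float32]
```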
14 changes: 13 additions & 1 deletion python/tvm/relay/frontend/tensorflow.py
@@ -44,6 +44,10 @@

__all__ = ["from_tensorflow"]

# By default, TVM converts `tf.matmul` to the `nn.dense` op, with the data tensor
# non-transposed and the weight tensor transposed.
_USE_DENSE_INSTEAD_OF_MATMUL = True

# compatible operators that do NOT require any conversion.
_identity_list = []

@@ -1204,7 +1208,7 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
return func, self._params


def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None, use_dense_op=True):
Contributor: I don't think we should have a flag here. We should just commit to one codepath.

Contributor (Author): The problem is that we're not able to remove all the nn.dense uses at this moment, and there aren't enough AutoTVM templates for nn.matmul.

So the use of nn.matmul can only be seen as an experimental feature. We should not change the default behavior, in case this affects those who are using nn.dense.

Contributor: Can't we use the dense schedules when A_transpose=false and B_transpose=true? Then we could convert all nn.dense to nn.matmul.

Contributor: This PR already uses the dense schedule for matmul_nt when lowering to TOPI. On the other hand, as @jcf94 mentioned in the PR comment, doing so would affect many more places in the codebase, and it's better to convert them gradually rather than in a single PR. That sounds reasonable to me.

"""Load tensorflow graph which is a python tensorflow graph object into relay.
The companion parameters will be handled automatically.

@@ -1222,6 +1226,11 @@ def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None):
    outputs : List of output tensor names (Optional)
        If not specified, then the last node is assumed as the graph output.

    use_dense_op : bool (Optional)
        True to convert `tf.matmul` to `nn.dense`; otherwise convert to `nn.matmul`.
        The `nn.dense` op requires the data tensor to be non-transposed and the
        weight tensor to be transposed, so this conversion may insert extra
        `transpose` ops into the original graph.

    Returns
    -------
    mod : tvm.IRModule
@@ -1230,6 +1239,9 @@ def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None):
    params : dict of str to tvm.nd.NDArray
        Dict of converted parameters stored in tvm.nd.NDArray format
    """
    global _USE_DENSE_INSTEAD_OF_MATMUL
Contributor: Is it possible to avoid using this global variable? I'm not familiar with the importer, but it would be nice if we could use an importer config dict or something.

Contributor (Author): Yeah, I've also tried several ways, but there seems to be no better solution from my view. A Python module can be seen as a const singleton, so this should be safe as long as the from_tensorflow function is the only entry point.

Contributor: I feel this is confusing too. If _USE_DENSE_INSTEAD_OF_MATMUL is not supposed to be changed by users directly, we should improve the comments on this global variable. Please see my comment at the global variable.

Btw, in this case we can simply write `_USE_DENSE_INSTEAD_OF_MATMUL = use_dense_op` without checking whether they are the same.

    if use_dense_op != _USE_DENSE_INSTEAD_OF_MATMUL:
        _USE_DENSE_INSTEAD_OF_MATMUL = use_dense_op

    g = GraphProto()
    mod, params = g.from_tensorflow(graph, layout, shape, outputs)
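As a minimal usage sketch of the new flag (assuming a frozen `tf.GraphDef` named `graph_def` has already been loaded):

```python
from tvm import relay

# Default path: tf.matmul ops are converted to nn.dense (extra transposes may
# be inserted to match dense's layout convention).
mod, params = relay.frontend.from_tensorflow(graph_def, layout="NHWC")

# Experimental path: keep tf.matmul's transpose attributes on nn.matmul.
mod, params = relay.frontend.from_tensorflow(
    graph_def, layout="NHWC", use_dense_op=False
)
```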
24 changes: 19 additions & 5 deletions python/tvm/relay/frontend/tensorflow_ops.py
@@ -1113,13 +1113,27 @@ def _impl(inputs, attr, params, mod):

def _matmul():
    def _impl(inputs, attr, params, mod):
        from .tensorflow import _USE_DENSE_INSTEAD_OF_MATMUL

        channels = _infer_channels(inputs[1], not attr["transpose_b"])
        if _USE_DENSE_INSTEAD_OF_MATMUL:
            if attr["transpose_a"]:
                inputs[0] = _op.transpose(inputs[0], axes=(1, 0))
            if not attr["transpose_b"]:
                inputs[1] = _op.transpose(inputs[1], axes=(1, 0))
            return AttrCvt(
                op_name="dense",
                extras={"units": channels},
                ignores=["transpose_a", "transpose_b", "T"],
            )(inputs, attr)
        return AttrCvt(
            op_name="matmul",
            extras={
                "units": channels,
                "data_transposed": attr["transpose_a"] or False,
                "weight_transposed": attr["transpose_b"] or False,
            },
            ignores=["transpose_a", "transpose_b", "T"],
        )(inputs, attr)

    return _impl
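A small numpy sketch (hypothetical shapes) of why the two conversion paths produce the same result for `tf.matmul(A, B, transpose_a=True)`:

```python
import numpy as np

# Hypothetical operands: A is stored as (K, M) because transpose_a=True; B is (K, N).
A = np.random.rand(64, 32).astype("float32")
B = np.random.rand(64, 16).astype("float32")

# use_dense_op=True: insert transposes so the operands fit nn.dense's convention
# (data non-transposed, weight transposed), where dense(x, w) == x @ w.T.
data, weight = A.T, B.T
dense_out = data @ weight.T

# use_dense_op=False: keep the transpose flags on nn.matmul itself.
matmul_out = np.matmul(A.T, B)

assert np.allclose(dense_out, matmul_out)
```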
29 changes: 29 additions & 0 deletions python/tvm/relay/op/_tensor_grad.py
@@ -554,6 +554,35 @@ def dense_grad(orig, grad):
]


@register_gradient("nn.matmul")
def matmul_grad(orig, grad):
    """Returns [grad' @ weight, data @ grad']"""
    data, weight = orig.args
    if (orig.attrs["data_transposed"], orig.attrs["weight_transposed"]) == (True, True):
Contributor: Please refactor this to not if/else on every possible combination of transpose.

        return [
            collapse_sum_like(
                _nn.matmul(weight, grad, data_transposed=True, weight_transposed=True), data
            ),
            collapse_sum_like(
                _nn.matmul(grad, data, data_transposed=True, weight_transposed=True), weight
            ),
        ]
    if (orig.attrs["data_transposed"], orig.attrs["weight_transposed"]) == (True, False):
        return [
            collapse_sum_like(_nn.matmul(weight, grad, weight_transposed=True), data),
            collapse_sum_like(_nn.matmul(data, grad), weight),
        ]
    if (orig.attrs["data_transposed"], orig.attrs["weight_transposed"]) == (False, True):
        # Keep using the dense op here to avoid introducing extra ops
        # TODO(jcf94): Merge all to nn.matmul when it is finally ready
        return dense_grad(orig, grad)
    # (orig.attrs["data_transposed"], orig.attrs["weight_transposed"]) == (False, False)
    return [
        collapse_sum_like(_nn.matmul(grad, weight, weight_transposed=True), data),
        collapse_sum_like(_nn.matmul(data, grad, data_transposed=True), weight),
    ]
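As a quick numeric sanity check of the fully transposed branch, a numpy sketch with hypothetical shapes (not part of the PR):

```python
import numpy as np

# (True, True) case: Y = data.T @ weight.T, scalar loss L = sum(Y * G).
rng = np.random.default_rng(0)
data = rng.standard_normal((4, 3))     # stored transposed: (K, M)
weight = rng.standard_normal((5, 4))   # stored transposed: (N, K)
G = rng.standard_normal((3, 5))        # upstream gradient, shaped like Y

# Analytic gradients, matching the branch above:
# grad_data   = matmul(weight, G, data_transposed=True, weight_transposed=True)
# grad_weight = matmul(G, data, data_transposed=True, weight_transposed=True)
grad_data = weight.T @ G.T
grad_weight = G.T @ data.T

# Finite-difference check on one entry of data (the loss is linear in data).
eps = 1e-6
loss = lambda d: np.sum((d.T @ weight.T) * G)
perturbed = data.copy()
perturbed[1, 2] += eps
assert np.isclose((loss(perturbed) - loss(data)) / eps, grad_data[1, 2], atol=1e-4)
```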


@register_gradient("nn.batch_matmul")
def batch_matmul_grad(orig, grad):
"""gradient for nn.batch_matmul: in einsum LHS_bik,RHS_bjk->RES_bij
57 changes: 54 additions & 3 deletions python/tvm/relay/op/nn/_nn.py
@@ -52,6 +52,32 @@
reg.register_pattern("nn.log_softmax", OpPattern.OPAQUE)


@reg.register_legalize("nn.matmul")
def legalize_matmul(attrs, inputs, types):
"""Legalize matmul op.
FrozenGene marked this conversation as resolved.
Show resolved Hide resolved

Parameters
----------
attrs : tvm.ir.Attrs
Attributes of current matmul
inputs : list of tvm.relay.Expr
The args of the Relay expr to be legalized
types : list of types
List of input and output types

Returns
-------
result : tvm.relay.Expr
The legalized expr
"""
return topi.nn.matmul_legalize(attrs, inputs, types)


# matmul
reg.register_strategy("nn.matmul", strategy.matmul_strategy)
reg.register_pattern("nn.matmul", reg.OpPattern.OUT_ELEMWISE_FUSABLE)


@reg.register_legalize("nn.dense")
def legalize_dense(attrs, inputs, types):
"""Legalize dense op.
@@ -1160,21 +1186,46 @@ def batch_flatten_shape_func(attrs, inputs, _):


@script
def _matmul_shape_func(data_shape, weight_shape, data_transposed, weight_transposed):
    out = output_tensor((data_shape.shape[0],), "int64")
    for i in const_range(out.shape[0] - 1):
        out[i] = data_shape[i]
    if data_transposed:
        out[out.shape[0] - 2] = data_shape[out.shape[0] - 1]
    out[out.shape[0] - 1] = weight_shape[0] if weight_transposed else weight_shape[1]
Contributor: This seems really complicated. Shouldn't it just be some part of data_shape and weight_shape depending on the transposes?

Contributor (Author): Since the data tensor can have more than two dimensions, this is the simplest implementation.


    return out
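A plain-Python restatement of the same shape rule (a reference sketch, not hybrid script; `matmul_out_shape` is a hypothetical helper):

```python
def matmul_out_shape(data_shape, weight_shape, data_transposed, weight_transposed):
    """Pure-Python equivalent of _matmul_shape_func, for reference."""
    out = list(data_shape[:-1])
    if data_transposed:
        # The reduction axis moves to the second-to-last position,
        # so the last data dim becomes an output dim.
        out[-1] = data_shape[-1]
    out.append(weight_shape[0] if weight_transposed else weight_shape[1])
    return out

assert matmul_out_shape((32, 64), (64, 16), False, False) == [32, 16]
assert matmul_out_shape((64, 32), (16, 64), True, True) == [32, 16]
```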


@reg.register_shape_func("nn.matmul", False)
def matmul_shape_func(attrs, inputs, _):
"""
Shape function for matmul op.
"""
ret = [
_matmul_shape_func(
inputs[0],
inputs[1],
expr.IntImm("bool", attrs.data_transposed),
expr.IntImm("bool", attrs.weight_transposed),
)
]
return ret


@reg.register_shape_func("nn.dense", False)
def dense_shape_func(attrs, inputs, _):
"""
Shape function for dense op.
"""
ret = [_dense_shape_func(inputs[0], inputs[1])]
ret = [
_matmul_shape_func(
inputs[0],
inputs[1],
expr.IntImm("bool", False),
expr.IntImm("bool", True),
)
]
return ret


41 changes: 41 additions & 0 deletions python/tvm/relay/op/nn/nn.py
@@ -1471,6 +1471,47 @@ def bias_add(data, bias, axis=1):
return _make.bias_add(data, bias, axis)


def matmul(data, weight, units=None, out_dtype="", data_transposed=False, weight_transposed=False):
    """Matmul operator.
    Applies a linear transformation. Both X and W can be in transposed format.

    .. math::

        Y = X * W

    Parameters
    ----------
    data : tvm.relay.Expr
        The input data to the operator,
        of shape `(d_1, d_2, ..., d_n, units_in)` or `(d_1, d_2, ..., units_in, d_n)`.
Contributor: Shouldn't both input shapes be dimension 2?

Contributor (Author): No, the input of matmul is supposed to be a multi-dimensional tensor (not limited to 2-D). This is copied from the original nn.dense. Other frameworks like PyTorch also have such a definition.

Contributor: Can you update the definition of the computation above to reflect these shapes, then?


    weight : tvm.relay.Expr
        The weight expressions, 2-D matrix,
        of shape `(units_in, units)` or `(units, units_in)`.

    units : int, optional
        Number of hidden units of the matmul transformation.
Contributor: What is a unit?

Contributor (Author): I think the doc has explained enough: "The hidden units." This is copied from the original nn.dense.

Contributor: I don't think this is clear at all. Is "hidden units" the inner dimension of the matmul?


    out_dtype : str, optional
        Specifies the output data type for mixed-precision matmul;
        the result has shape `(d_1, d_2, ..., d_n, units)`.

    data_transposed : bool, optional
        Whether the data tensor is in transposed format.

    weight_transposed : bool, optional
        Whether the weight tensor is in transposed format.

    Returns
    -------
    result : tvm.relay.Expr
        The computed result.
    """
    if not data_transposed and weight_transposed:
        return dense(data, weight, units, out_dtype)
    return _make.matmul(data, weight, units, out_dtype, data_transposed, weight_transposed)
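Note the fallback above: a call in the non-transposed-data/transposed-weight (NT) layout lowers straight to `nn.dense`. A minimal sketch with illustrative shapes:

```python
from tvm import relay

x = relay.var("x", shape=(32, 64))   # (M, K), non-transposed
w = relay.var("w", shape=(16, 64))   # (N, K), transposed
y = relay.nn.matmul(x, w, units=16, weight_transposed=True)
print(y)  # prints a call to nn.dense, per the fallback above
```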


def dense(data, weight, units=None, out_dtype=""):
"""Dense operator.
Applies a linear transformation
5 changes: 5 additions & 0 deletions python/tvm/relay/op/op_attrs.py
@@ -64,6 +64,11 @@ class BiasAddAttrs(Attrs):
"""Atttribute of nn.bias_add"""


@tvm._ffi.register_object("relay.attrs.MatmulAttrs")
class MatmulAttrs(Attrs):
"""Attributes for nn.matmul"""


@tvm._ffi.register_object("relay.attrs.DenseAttrs")
class DenseAttrs(Attrs):
"""Attributes for nn.dense"""
30 changes: 30 additions & 0 deletions python/tvm/relay/op/strategy/cuda.py
@@ -698,6 +698,36 @@ def conv1d_transpose_strategy_cuda(attrs, inputs, out_type, target):
return strategy


@matmul_strategy.register(["cuda", "gpu"])
def matmul_strategy_cuda(attrs, inputs, out_type, target):
    """Matmul cuda strategy"""
    strategy = _op.OpStrategy()

    if is_auto_scheduler_enabled():
        strategy.add_implementation(
            wrap_compute_matmul(topi.nn.matmul),
            naive_schedule,
            name="matmul.cuda",
        )
    else:
        logger.warning("Matmul in formats other than NT is not optimized for cuda.")
        # Temporarily use this as a basic schedule
        strategy.add_implementation(
            wrap_compute_matmul(topi.cuda.matmul_default_cuda),
            wrap_topi_schedule(topi.cuda.schedule_matmul_default_cuda),
            name="matmul_default.cuda",
        )

    if target.kind.name == "cuda" and "cublas" in target.libs:
        strategy.add_implementation(
            wrap_compute_matmul(topi.cuda.matmul_cublas),
            wrap_topi_schedule(topi.cuda.schedule_matmul_cublas),
            name="matmul_cublas.cuda",
            plevel=25,
        )
    return strategy
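The cuBLAS implementation above is only added when "cublas" appears in the target's `libs`. A minimal opt-in sketch, assuming `mod` and `params` from an earlier import and a TVM build with `USE_CUBLAS=ON`:

```python
import tvm
from tvm import relay

target = tvm.target.Target("cuda -libs=cublas")
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target=target, params=params)
```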


@dense_strategy.register(["cuda", "gpu"])
def dense_strategy_cuda(attrs, inputs, out_type, target):
"""dense cuda strategy"""
36 changes: 36 additions & 0 deletions python/tvm/relay/op/strategy/generic.py
@@ -712,6 +712,42 @@ def dilation2d_strategy(attrs, inputs, out_type, target):
return strategy


# matmul
def wrap_compute_matmul(topi_compute, need_auto_scheduler_layout=False):
    """wrap matmul topi compute"""

    def _compute_matmul(attrs, inputs, out_type):
        """Compute definition of matmul"""
        out_dtype = attrs.out_dtype
        out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype
        args = [
            inputs[0],
            inputs[1],
            None,
            out_dtype,
            attrs.data_transposed,
            attrs.weight_transposed,
        ]
        if need_auto_scheduler_layout:
            args.append(get_auto_scheduler_rewritten_layout(attrs))
        return [topi_compute(*args)]

    return _compute_matmul


@override_native_generic_func("matmul_strategy")
def matmul_strategy(attrs, inputs, out_type, target):
    """matmul generic strategy"""
    logger.warning("matmul is not optimized for this platform.")
    strategy = _op.OpStrategy()
    strategy.add_implementation(
        wrap_compute_matmul(topi.nn.matmul),
        wrap_topi_schedule(topi.generic.schedule_matmul),
        name="matmul.generic",
    )
    return strategy
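For reference, the positional contract `wrap_compute_matmul` imposes on a TOPI compute mirrors the `args` list above: `(data, weight, bias, out_dtype, data_transposed, weight_transposed)`. A self-contained te sketch of a 2-D compute honoring that contract (names are illustrative, not the PR's TOPI code):

```python
from tvm import te

def my_matmul_compute(data, weight, bias=None, out_dtype=None,
                      data_transposed=False, weight_transposed=False):
    """Illustrative 2-D matmul compute; real TOPI computes also handle >2-D data."""
    # bias is accepted for signature compatibility but unused in this sketch.
    out_dtype = out_dtype or data.dtype
    k_size = data.shape[0] if data_transposed else data.shape[1]
    m_size = data.shape[1] if data_transposed else data.shape[0]
    n_size = weight.shape[0] if weight_transposed else weight.shape[1]
    k = te.reduce_axis((0, k_size), name="k")

    def fcompute(i, j):
        a = data[k, i] if data_transposed else data[i, k]
        b = weight[j, k] if weight_transposed else weight[k, j]
        return te.sum(a.astype(out_dtype) * b.astype(out_dtype), axis=k)

    return te.compute((m_size, n_size), fcompute, name="matmul")
```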


# dense
def wrap_compute_dense(topi_compute, need_auto_scheduler_layout=False):
"""wrap dense topi compute"""