Skip to content

Commit

Permalink
Code refine
Browse files (browse the repository at this point in the history)
  • Loading branch information
jcf94 committed Jun 22, 2021
1 parent 6c736fb commit 93787e4
Show file tree
Hide file tree
Showing 14 changed files with 50 additions and 18 deletions.
2 changes: 1 addition & 1 deletion python/tvm/relay/op/op_attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ class MatmulAttrs(Attrs):
class DenseAttrs(Attrs):
"""Attributes for nn.dense
This is still used by qnn.dense. TODO: Rename `qnn.dense` to `qnn.matmul`.
This is still used by qnn.dense. TODO: Rewrite `qnn.dense` to `qnn.matmul`.
"""


Expand Down
6 changes: 4 additions & 2 deletions python/tvm/relay/op/strategy/bifrost.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,9 @@ def conv2d_winograd_without_weight_transfrom_strategy_bifrost(attrs, inputs, out


def dense_strategy_bifrost(attrs, inputs, out_type, target):
"""dense mali(bifrost) strategy"""
"""Dense mali(bifrost) strategy.
This is a specialized case for Matmul with data non-transposed and weight transposed.
"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_dense(topi.bifrost.dense),
Expand All @@ -122,7 +124,7 @@ def dense_strategy_bifrost(attrs, inputs, out_type, target):

@matmul_strategy.register("bifrost")
def matmul_strategy_bifrost(attrs, inputs, out_type, target):
"""matmul mali(bifrost) strategy"""
"""Matmul mali(bifrost) strategy"""

if not attrs.data_transposed and attrs.weight_transposed:
# Specialized schedule for dense(matmul-NT)
Expand Down
8 changes: 5 additions & 3 deletions python/tvm/relay/op/strategy/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,7 +699,9 @@ def conv1d_transpose_strategy_cuda(attrs, inputs, out_type, target):


def dense_strategy_cuda(attrs, inputs, out_type, target):
"""dense cuda strategy"""
"""Dense cuda strategy.
This is a specialized case for Matmul with data non-transposed and weight transposed.
"""
strategy = _op.OpStrategy()
data, weights = inputs
b, i = get_const_tuple(data.shape)
Expand Down Expand Up @@ -759,13 +761,13 @@ def dense_strategy_cuda(attrs, inputs, out_type, target):

@matmul_strategy.register(["cuda", "gpu"])
def matmul_strategy_cuda(attrs, inputs, out_type, target):
"""matmul cuda strategy"""
"""Matmul cuda strategy"""

if not attrs.data_transposed and attrs.weight_transposed:
# Specialized schedule for dense(matmul-NT)
strategy = dense_strategy_cuda(attrs, inputs, out_type, target)
else:
logger.warning("Matmul other than NT format is not optimized for x86.")
logger.warning("Matmul other than NT format is not optimized for cuda.")
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_matmul(topi.nn.matmul),
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/op/strategy/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -752,7 +752,7 @@ def _compute_matmul(attrs, inputs, out_type):

@override_native_generic_func("matmul_strategy")
def matmul_strategy(attrs, inputs, out_type, target):
"""matmul generic strategy"""
"""Matmul generic strategy"""
logger.warning("matmul is not optimized for this platform.")
strategy = _op.OpStrategy()
strategy.add_implementation(
Expand Down
6 changes: 4 additions & 2 deletions python/tvm/relay/op/strategy/hls.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,9 @@ def conv2d_transpose_strategy_hls(attrs, inputs, out_type, target):


def dense_strategy_hls(attrs, inputs, out_type, target):
"""dense hls strategy"""
"""Dense hls strategy.
This is a specialized case for Matmul with data non-transposed and weight transposed.
"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_dense(topi.nn.dense),
Expand All @@ -171,7 +173,7 @@ def dense_strategy_hls(attrs, inputs, out_type, target):

@matmul_strategy.register("hls")
def matmul_strategy_hls(attrs, inputs, out_type, target):
"""matmul hls strategy"""
"""Matmul hls strategy"""

if not attrs.data_transposed and attrs.weight_transposed:
# Specialized schedule for dense(matmul-NT)
Expand Down
6 changes: 4 additions & 2 deletions python/tvm/relay/op/strategy/mali.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,9 @@ def conv2d_winograd_without_weight_transfrom_strategy_mali(attrs, inputs, out_ty


def dense_strategy_mali(attrs, inputs, out_type, target):
"""dense mali strategy"""
"""Dense mali strategy.
This is a specialized case for Matmul with data non-transposed and weight transposed.
"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_dense(topi.mali.dense),
Expand All @@ -185,7 +187,7 @@ def dense_strategy_mali(attrs, inputs, out_type, target):

@matmul_strategy.register("mali")
def matmul_strategy_mali(attrs, inputs, out_type, target):
"""dense mali strategy"""
"""Matmul mali strategy"""

if not attrs.data_transposed and attrs.weight_transposed:
# Specialized schedule for dense(matmul-NT)
Expand Down
4 changes: 3 additions & 1 deletion python/tvm/relay/op/strategy/rocm.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,9 @@ def conv2d_strategy_rocm(attrs, inputs, out_type, target):


def dense_strategy_rocm(attrs, inputs, out_type, target):
"""Dense strategy for ROCM"""
"""Dense strategy for ROCM.
This is a specialized case for Matmul with data non-transposed and weight transposed.
"""
assert len(inputs[0].shape) == 2 and len(inputs[1].shape) == 2, "Only support 2-dim dense"
strategy = _op.OpStrategy()
strategy.add_implementation(
Expand Down
11 changes: 8 additions & 3 deletions python/tvm/relay/op/strategy/x86.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,8 +371,8 @@ def conv1d_strategy_cpu(attrs, inputs, out_type, target):


def dense_strategy_cpu(attrs, inputs, out_type, target):
"""Dense x86 strategy. This is a specialized case for Matmul with data non-transposed and
weight transposed.
"""Dense x86 strategy.
This is a specialized case for Matmul with data non-transposed and weight transposed.
"""

strategy = _op.OpStrategy()
Expand All @@ -394,14 +394,19 @@ def dense_strategy_cpu(attrs, inputs, out_type, target):

@matmul_strategy.register("cpu")
def matmul_strategy_cpu(attrs, inputs, out_type, target):
"""Matmul x86 strategy."""
"""Matmul x86 strategy"""

if not attrs.data_transposed and attrs.weight_transposed:
# Specialized schedule for dense(matmul-NT)
strategy = dense_strategy_cpu(attrs, inputs, out_type, target)
else:
logger.warning("Matmul other than NT format is not optimized for x86.")
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_matmul(topi.nn.matmul),
naive_schedule,
name="matmul.generic",
)

same_type = inputs[0].dtype == inputs[1].dtype == out_type.dtype
dtype = inputs[0].dtype
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/relay/qnn/op/legalizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ def _qnn_dense_legalize_arm_cpu(attrs, inputs, types):
out_dtype=attrs["out_dtype"],
data_transposed=False,
weight_transposed=True,
) # TODO: Rename qnn.dense to qnn.matmul and remove this transformation
) # TODO: Rewrite qnn.dense to qnn.matmul and remove this transformation
return helper_no_fast_int8_hw_legalization(mattrs, inputs, types, relay.nn.matmul)


Expand Down Expand Up @@ -342,7 +342,7 @@ def _qnn_dense_legalize_intel_cpu(attrs, inputs, types):
out_dtype=attrs["out_dtype"],
data_transposed=False,
weight_transposed=True,
) # TODO: Rename qnn.dense to qnn.matmul and remove this transformation
) # TODO: Rewrite qnn.dense to qnn.matmul and remove this transformation
return helper_no_fast_int8_hw_legalization(mattrs, inputs, types, relay.nn.matmul)


Expand Down
1 change: 1 addition & 0 deletions python/tvm/topi/cuda/tensorcore_alter_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def _batch_matmul_legalize(attrs, inputs, arg_types):

def _dense_legalize(attrs, inputs, arg_types):
"""Legalizes dense op.
This is a specialized case for Matmul with data non-transposed and weight transposed.
Parameters
----------
Expand Down
10 changes: 10 additions & 0 deletions rust/tvm/src/ir/relay/attrs/nn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,16 @@ pub struct BiasAddAttrsNode {
pub axis: i32,
}

#[repr(C)]
#[derive(Object, Debug)]
#[ref_name = "DenseAttrs"]
#[type_key = "relay.attrs.DenseAttrs"]
pub struct DenseAttrsNode {
pub base: BaseAttrsNode,
pub units: IndexExpr,
pub out_dtype: DataType,
}

#[repr(C)]
#[derive(Object, Debug)]
#[ref_name = "MatmulAttrs"]
Expand Down
2 changes: 1 addition & 1 deletion src/relay/op/nn/nn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ Useful for

// ------------------- relay.nn.matmul
TVM_REGISTER_NODE_TYPE(MatmulAttrs);
TVM_REGISTER_NODE_TYPE(DenseAttrs); // Used by qnn.dense. TODO: Rename qnn.dense to qnn.matmul
TVM_REGISTER_NODE_TYPE(DenseAttrs); // Used by qnn.dense. TODO: Rewrite qnn.dense to qnn.matmul

Expr MakeMatmul(Expr data, Expr weight, IndexExpr units, DataType out_dtype, bool data_transposed,
bool weight_transposed) {
Expand Down
4 changes: 4 additions & 0 deletions src/relay/transforms/auto_scheduler_layout_rewrite.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ class FuncMutator : public ExprMutator {
updated_attrs = CopyAttrsWithNewLayout(pattr, new_layout);
} else if (auto pattr = call->attrs.as<Conv3DAttrs>()) {
updated_attrs = CopyAttrsWithNewLayout(pattr, new_layout);
} else if (auto pattr = call->attrs.as<DenseAttrs>()) {
updated_attrs = CopyAttrsWithNewLayout(pattr, new_layout);
} else if (auto pattr = call->attrs.as<MatmulAttrs>()) {
updated_attrs = CopyAttrsWithNewLayout(pattr, new_layout);
} else if (auto pattr = call->attrs.as<BatchMatmulAttrs>()) {
Expand Down Expand Up @@ -166,6 +168,8 @@ TVM_REGISTER_GLOBAL("relay.attrs.get_auto_scheduler_rewritten_layout")
return attrs.as<Conv2DWinogradAttrs>()->auto_scheduler_rewritten_layout;
} else if (attrs->IsInstance<Conv3DAttrs>()) {
return attrs.as<Conv3DAttrs>()->auto_scheduler_rewritten_layout;
} else if (attrs->IsInstance<DenseAttrs>()) {
return attrs.as<DenseAttrs>()->auto_scheduler_rewritten_layout;
} else if (attrs->IsInstance<MatmulAttrs>()) {
return attrs.as<MatmulAttrs>()->auto_scheduler_rewritten_layout;
} else if (attrs->IsInstance<BatchMatmulAttrs>()) {
Expand Down
2 changes: 2 additions & 0 deletions src/relay/transforms/to_mixed_precision.cc
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,8 @@ class MixedPrecisionPass : public MixedModeMutator {
return ModifyAttrsOutputDType(attrs, accumulation_dtype);
} else if (auto attrs = cur_attrs.as<Conv3DWinogradAttrs>()) {
return ModifyAttrsOutputDType(attrs, accumulation_dtype);
} else if (auto attrs = cur_attrs.as<DenseAttrs>()) {
return ModifyAttrsOutputDType(attrs, accumulation_dtype);
} else if (auto attrs = cur_attrs.as<MatmulAttrs>()) {
return ModifyAttrsOutputDType(attrs, accumulation_dtype);
} else if (auto attrs = cur_attrs.as<BatchMatmulAttrs>()) {
Expand Down

0 comments on commit 93787e4

Please sign in to comment.