Skip to content

Commit

Permalink
[Relay][Autoscheduler] Fix autoscheduler matmul without units. (apach…
Browse files Browse the repository at this point in the history
…e#7957)

* Fix autoscheduler matmul without units.

* Fix lint.
  • Loading branch information
Josh Fromm authored and trevor-m committed May 11, 2021
1 parent b6f4cd8 commit 762ef80
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 9 deletions.
27 changes: 19 additions & 8 deletions src/relay/op/nn/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@ bool DenseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,

ICHECK(static_cast<int>(data->shape.size()) != 0);

Array<tvm::PrimExpr> oshape = data->shape;
Array<tvm::PrimExpr> dshape = data->shape;
Array<tvm::PrimExpr> oshape = dshape;
if (param->units.defined()) {
Array<tvm::PrimExpr> dshape = data->shape;
// validate the weight shape is proper if defined
// Assign weight type
Array<IndexExpr> wshape({param->units, dshape[dshape.size() - 1]});
Expand All @@ -72,13 +72,24 @@ bool DenseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
} else {
if (weight == nullptr) return false;
Array<tvm::PrimExpr> wshape = weight->shape;
ICHECK(static_cast<int>(weight->shape.size()) == 2);
if (!data->shape.back().as<tir::AnyNode>()) {
ICHECK(reporter->AssertEQ(data->shape[data->shape.size() - 1], weight->shape[1]))
<< "DenseRel: input dimension doesn't match,"
<< " data shape=" << data->shape << ", weight shape=" << weight->shape;
// When weight's layout has been rewritten, figure it out based on the
// total number of elements and input dimensions.
if (param->auto_scheduler_rewritten_layout.size() != 0) {
PrimExpr weight_elements = 1;
for (size_t i = 0; i < wshape.size(); i++) {
weight_elements = weight_elements * wshape[i];
}
oshape.Set(oshape.size() - 1, weight_elements / dshape[dshape.size() - 1]);
// Otherwise just pull it out of the weight shape directly.
} else {
ICHECK(static_cast<int>(weight->shape.size()) == 2);
if (!data->shape.back().as<tir::AnyNode>()) {
ICHECK(reporter->AssertEQ(data->shape[data->shape.size() - 1], weight->shape[1]))
<< "DenseRel: input dimension doesn't match,"
<< " data shape=" << data->shape << ", weight shape=" << weight->shape;
}
oshape.Set((oshape.size() - 1), wshape[0]);
}
oshape.Set((oshape.size() - 1), wshape[0]);
}

DataType out_dtype = param->out_dtype;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def get_relay_dense(m=128, n=128, k=128):
dtype = "float32"
d = relay.var("data", shape=(m, k), dtype=dtype)
w = relay.var("weight", shape=(n, k), dtype=dtype)
y = relay.nn.dense(d, w, units=n)
y = relay.nn.dense(d, w)
mod = tvm.IRModule()
mod["main"] = relay.Function([d, w], y)
data, weight = get_np_array(d, dtype), get_np_array(w, dtype)
Expand Down

0 comments on commit 762ef80

Please sign in to comment.