diff --git a/src/relay/op/nn/nn.h b/src/relay/op/nn/nn.h index 8802cd903b01..38cb763883b7 100644 --- a/src/relay/op/nn/nn.h +++ b/src/relay/op/nn/nn.h @@ -49,9 +49,9 @@ bool DenseRel(const Array& types, int num_inputs, const Attrs& attrs, ICHECK(static_cast(data->shape.size()) != 0); - Array oshape = data->shape; + Array dshape = data->shape; + Array oshape = dshape; if (param->units.defined()) { - Array dshape = data->shape; // validate the weight shape is proper if defined // Assign weight type Array wshape({param->units, dshape[dshape.size() - 1]}); @@ -72,13 +72,24 @@ bool DenseRel(const Array& types, int num_inputs, const Attrs& attrs, } else { if (weight == nullptr) return false; Array wshape = weight->shape; - ICHECK(static_cast(weight->shape.size()) == 2); - if (!data->shape.back().as()) { - ICHECK(reporter->AssertEQ(data->shape[data->shape.size() - 1], weight->shape[1])) - << "DenseRel: input dimension doesn't match," - << " data shape=" << data->shape << ", weight shape=" << weight->shape; + // When weight's layout has been rewritten, figure it out based on the + // total number of elements and input dimensions. + if (param->auto_scheduler_rewritten_layout.size() != 0) { + PrimExpr weight_elements = 1; + for (size_t i = 0; i < wshape.size(); i++) { + weight_elements = weight_elements * wshape[i]; + } + oshape.Set(oshape.size() - 1, weight_elements / dshape[dshape.size() - 1]); + // Otherwise just pull it out of the weight shape directly. + } else { + ICHECK(static_cast(weight->shape.size()) == 2); + if (!data->shape.back().as()) { + ICHECK(reporter->AssertEQ(data->shape[data->shape.size() - 1], weight->shape[1])) + << "DenseRel: input dimension doesn't match," + << " data shape=" << data->shape << ", weight shape=" << weight->shape; + } + oshape.Set((oshape.size() - 1), wshape[0]); } - oshape.Set((oshape.size() - 1), wshape[0]); } DataType out_dtype = param->out_dtype; diff --git a/tests/python/relay/test_auto_scheduler_layout_rewrite_networks.py b/tests/python/relay/test_auto_scheduler_layout_rewrite_networks.py index 8466fc1700b0..106b4bb50346 100644 --- a/tests/python/relay/test_auto_scheduler_layout_rewrite_networks.py +++ b/tests/python/relay/test_auto_scheduler_layout_rewrite_networks.py @@ -117,7 +117,7 @@ def get_relay_dense(m=128, n=128, k=128): dtype = "float32" d = relay.var("data", shape=(m, k), dtype=dtype) w = relay.var("weight", shape=(n, k), dtype=dtype) - y = relay.nn.dense(d, w, units=n) + y = relay.nn.dense(d, w) mod = tvm.IRModule() mod["main"] = relay.Function([d, w], y) data, weight = get_np_array(d, dtype), get_np_array(w, dtype)