diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc
index 767cb6f644de..c8327de94232 100644
--- a/src/relay/backend/compile_engine.cc
+++ b/src/relay/backend/compile_engine.cc
@@ -420,8 +420,28 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator<Array<te::Tensor>> {
   Array<te::Tensor> VisitExpr_(const ConstantNode* op) final {
     using tir::make_const;
     ICHECK(data_dependants_.size());
-    ICHECK(op->is_scalar());
     bool data_dependant = data_dependants_.back();
+    if (!op->is_scalar()) {
+      // This is a constant weight, extract the shape of the weight tensor.
+      // This can not be data dependent.
+      CHECK(!data_dependant);
+      auto ttype = op->checked_type().as<TensorTypeNode>();
+      int ndim = static_cast<int>(ttype->shape.size());
+      Array<PrimExpr> out_shape{ndim};
+      te::Tensor value = tvm::te::compute(
+          out_shape,
+          [&](const Array<tvm::tir::Var>& indices) {
+            auto idx = indices[0];
+            PrimExpr ret = make_const(DataType::Int(64), 0);
+            for (int i = 0; i < ndim; i++) {
+              ret = tvm::if_then_else(idx == i, ttype->shape[i], ret);
+            }
+            return ret;
+          },
+          "shape_const", topi::kBroadcast);
+      scalars_.push_back(value);
+      return {value};
+    }
     if (data_dependant) {
       void* data = op->data->data;
       DataType dtype = DataType(op->data->dtype);
diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py
index 92d6e8e55db4..6958010176e3 100644
--- a/tests/python/relay/test_vm.py
+++ b/tests/python/relay/test_vm.py
@@ -770,5 +770,30 @@ def test_vm_reshape_tuple(x_shape=(1, 4, 2), y_shape=(1, 2, 10)):
     tvm.testing.assert_allclose(res.asnumpy(), np.reshape(x_data, (1, -1)))
 
 
+def test_constant_shape_with_external_codegen():
+    mod = tvm.IRModule()
+    shape = (relay.Any(), 25)
+    dtype = "float32"
+
+    # external function
+    x = relay.var("x", shape=shape, dtype=dtype)
+    weight = relay.const(np.random.rand(5, 25).astype("float32"), dtype="float32")
+    out = relay.nn.dense(x, weight)
+    f1 = relay.Function([x], out)
+    f1 = f1.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
+    f1 = f1.with_attr("Inline", tvm.tir.IntImm("int32", 1))
+    f1 = f1.with_attr("Compiler", "a")
+    glb_f1 = relay.GlobalVar("f1")
+    mod[glb_f1] = f1
+    mod = relay.transform.InferType()(mod)
+
+    # Main function
+    x = relay.var("x", shape=shape, dtype=dtype)
+    mod["main"] = relay.Function([x], glb_f1(x))
+    comp = relay.vm.VMCompiler()
+    opt_mod, _ = comp.optimize(mod, target="llvm")
+    assert "shape_func" in opt_mod.astext(False)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])