diff --git a/python/tvm/relay/op/op.py b/python/tvm/relay/op/op.py index 36409d375d74c..58f8f47e7fa23 100644 --- a/python/tvm/relay/op/op.py +++ b/python/tvm/relay/op/op.py @@ -364,7 +364,7 @@ def register_shape_func(op_name, data_dependant, shape_func=None, level=10): op_name : str The name of the op. - data_dependant : bool + data_dependant : bool or list of bool Whether the shape function depends on input data. shape_func : function (attrs: Attrs, inputs: List[Tensor], out_ndims: List[IndexExpr]) diff --git a/tests/cpp/relay_build_module_test.cc b/tests/cpp/relay_build_module_test.cc index 17b3d3f22737b..00782d6c15b9d 100644 --- a/tests/cpp/relay_build_module_test.cc +++ b/tests/cpp/relay_build_module_test.cc @@ -105,7 +105,9 @@ TEST(Relay, BuildModule) { } auto fgeneric = GenericFunc::Get("test.strategy_generic").set_default(*fs); (*reg)("add", "FTVMStrategy", fgeneric, 10); - (*reg)("add", "TShapeDataDependant", {0}, 10); + Array dep; + dep.push_back(0); + (*reg)("add", "TShapeDataDependant", dep, 10); // build auto pfb = tvm::runtime::Registry::Get("relay.build_module._BuildModule"); tvm::runtime::Module build_mod = (*pfb)();