diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index 79eb7e4f19ff..4c9a898f2374 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -391,10 +391,20 @@ def _make_executor(self, expr=None): ret_type = self.mod["main"].checked_type.ret_type if _ty.is_dynamic(ret_type): raise ValueError("Graph Runtime only supports static graphs, got output type", ret_type) - num_outputs = len(ret_type.fields) if isinstance(ret_type, _ty.TupleType) else 1 mod = build(self.mod, target=self.target) gmodule = _graph_rt.GraphModule(mod["default"](self.ctx)) + def _unflatten(flat_iter, cur_type): + if isinstance(cur_type, _ty.TensorType): + return next(flat_iter) + if isinstance(cur_type, _ty.TupleType): + fields = [] + for field_type in cur_type.fields: + field = _unflatten(flat_iter, field_type) + fields.append(field) + return fields + raise ValueError("Return type", ret_type, "contains unsupported type", cur_type) + def _graph_wrapper(*args, **kwargs): args = self._convert_args(self.mod["main"], args, kwargs) # Create map of inputs. @@ -402,13 +412,11 @@ def _graph_wrapper(*args, **kwargs): gmodule.set_input(i, arg) # Run the module, and fetch the output. gmodule.run() - # make a copy so multiple invocation won't hurt perf. - if num_outputs == 1: - return gmodule.get_output(0).copyto(_nd.cpu(0)) - outputs = [] - for i in range(num_outputs): - outputs.append(gmodule.get_output(i).copyto(_nd.cpu(0))) - return outputs + flattened = [] + for i in range(gmodule.get_num_outputs()): + flattened.append(gmodule.get_output(i).copyto(_nd.cpu(0))) + unflattened = _unflatten(iter(flattened), ret_type) + return unflattened return _graph_wrapper diff --git a/tests/python/relay/test_backend_graph_runtime.py b/tests/python/relay/test_backend_graph_runtime.py index 3c42b7b4196f..68708aaeb413 100644 --- a/tests/python/relay/test_backend_graph_runtime.py +++ b/tests/python/relay/test_backend_graph_runtime.py @@ -209,6 +209,27 @@ def test_compile_nested_tuples(): ref = ref + 1 +def test_graph_executor_nested_tuples(): + x, y, z, w = [relay.var(c, shape=(2, 3), dtype="float32") for c in "xyzw"] + out = relay.Tuple([x, relay.Tuple([y, relay.Tuple([z, w])])]) + func = relay.Function([x, y, z, w], out) + + exe = relay.create_executor( + kind="graph", mod=tvm.IRModule.from_expr(func), ctx=tvm.cpu(0), target="llvm" + ) + f = exe.evaluate() + + data = [np.random.uniform(size=(2, 3)).astype("float32") for _ in "xyzw"] + out = f(*data) + assert len(out) == 2 + tvm.testing.assert_allclose(out[0].asnumpy(), data[0]) + assert len(out[1]) == 2 + tvm.testing.assert_allclose(out[1][0].asnumpy(), data[1]) + assert len(out[1][1]) == 2 + tvm.testing.assert_allclose(out[1][1][0].asnumpy(), data[2]) + tvm.testing.assert_allclose(out[1][1][1].asnumpy(), data[3]) + + if __name__ == "__main__": test_plan_memory() test_with_params()