Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GRAPH RT] Additional API support #17513

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions python/tvm/contrib/graph_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,9 @@ def __init__(self, module):
self._get_input = module["get_input"]
self._get_num_outputs = module["get_num_outputs"]
self._get_input_index = module["get_input_index"]
self._get_output_index = module["get_output_index"]
self._get_input_info = module["get_input_info"]
self._get_output_info = module["get_output_info"]
self._get_num_inputs = module["get_num_inputs"]
self._load_params = module["load_params"]
self._share_params = module["share_params"]
Expand Down Expand Up @@ -315,6 +317,21 @@ def get_input_index(self, name):
"""
return self._get_input_index(name)

def get_output_index(self, name):
"""Get outputs index via output name.

Parameters
----------
name : str
The output key name

Returns
-------
index: int
The output index. -1 will be returned if the given output name is not found.
"""
return self._get_output_index(name)

def get_input_info(self):
"""Return the 'shape' and 'dtype' dictionaries of the graph.

Expand All @@ -341,6 +358,24 @@ def get_input_info(self):

return shape_dict, dtype_dict

def get_output_info(self):
"""Return the 'shape' and 'dtype' dictionaries of the graph.

Returns
-------
shape_dict : Map
Shape dictionary - {output_name: tuple}.
dtype_dict : Map
dtype dictionary - {output_name: dtype}.
"""
output_info = self._get_output_info()
assert "shape" in output_info
shape_dict = output_info["shape"]
assert "dtype" in output_info
dtype_dict = output_info["dtype"]

return shape_dict, dtype_dict

def get_output(self, index, out=None):
"""Get index-th output to out

Expand Down
6 changes: 5 additions & 1 deletion src/runtime/crt/aot_executor_module/aot_executor_module.c
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,8 @@ static const TVMBackendPackedCFunc aot_executor_registry_funcs[] = {
&TVMAotExecutorModule_NotImplemented, // set_input (implemented via python wrapper)
&TVMAotExecutorModule_NotImplemented, // share_params (do not implement)
&TVMAotExecutorModule_GetInputName, // get_input_name
&TVMAotExecutorModule_NotImplemented, // get_output_index
&TVMAotExecutorModule_NotImplemented, // get_output_info
};

static const TVMFuncRegistry aot_executor_registry = {
Expand All @@ -223,7 +225,9 @@ static const TVMFuncRegistry aot_executor_registry = {
"run\0"
"set_input\0"
"share_params\0"
"get_input_name\0",
"get_input_name\0"
"get_output_index\0"
"get_output_info\0",
aot_executor_registry_funcs};

tvm_crt_error_t TVMAotExecutorModule_Register() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,8 @@ static const TVMBackendPackedCFunc graph_executor_registry_funcs[] = {
&TVMGraphExecutorModule_Run,
&TVMGraphExecutorModule_SetInput,
&TVMGraphExecutorModule_NotImplemented, // share_params
&TVMGraphExecutorModule_NotImplemented, // get_output_index
&TVMGraphExecutorModule_NotImplemented, // get_output_info
};

static const TVMFuncRegistry graph_executor_registry = {
Expand All @@ -247,7 +249,9 @@ static const TVMFuncRegistry graph_executor_registry = {
"load_params\0"
"run\0"
"set_input\0"
"share_params\0",
"share_params\0"
"get_output_index\0"
"get_output_info\0",
graph_executor_registry_funcs};

tvm_crt_error_t TVMGraphExecutorModule_Register() {
Expand Down
12 changes: 12 additions & 0 deletions src/runtime/graph_executor/graph_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -745,6 +745,18 @@ PackedFunc GraphExecutor::GetFunction(const String& name, const ObjectPtr<Object
CHECK(String::CanConvertFrom(args[0])) << "Input key is not a string";
*rv = this->GetInputIndex(args[0].operator String());
});
} else if (name == "get_output_index") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
CHECK(String::CanConvertFrom(args[0])) << "Output key is not a string";
int out_idx = -1;
for (size_t i = 0; i < outputs_.size(); i++) {
std::string& name = nodes_[outputs_[i].node_id].name;
if (args[0].operator String() == name) {
out_idx = i;
}
}
*rv = out_idx;
});
} else if (name == "get_input_info") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
auto [shape_info, dtype_info] = this->GetInputInfo();
Expand Down
6 changes: 6 additions & 0 deletions tests/python/relay/test_backend_graph_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,12 @@ def test_graph_executor_api():
assert isinstance(dtype_dict[name], tvm.runtime.container.String)
assert dtype_dict[name] == ty.dtype

shape_dict, dtype_dict = mod.get_output_info()
assert isinstance(shape_dict, tvm.container.Map)
assert isinstance(dtype_dict, tvm.container.Map)
for i, key in enumerate(shape_dict):
assert mod.get_output_index(key) == i


@tvm.testing.requires_llvm
def test_benchmark():
Expand Down
Loading