Skip to content

Commit

Permalink
Additional API support (#30)
Browse files Browse the repository at this point in the history
get_output_index support added.

Co-authored-by: Siva <quic_sivb@quicinc.com>
  • Loading branch information
srkreddy1238 committed Nov 11, 2024
1 parent e3665ae commit 74ff9bb
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 0 deletions.
35 changes: 35 additions & 0 deletions python/tvm/contrib/graph_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,9 @@ def __init__(self, module):
self._get_input = module["get_input"]
self._get_num_outputs = module["get_num_outputs"]
self._get_input_index = module["get_input_index"]
self._get_output_index = module["get_output_index"]
self._get_input_info = module["get_input_info"]
self._get_output_info = module["get_output_info"]
self._get_num_inputs = module["get_num_inputs"]
self._load_params = module["load_params"]
self._share_params = module["share_params"]
Expand Down Expand Up @@ -315,6 +317,21 @@ def get_input_index(self, name):
"""
return self._get_input_index(name)

def get_output_index(self, name):
"""Get outputs index via output name.
Parameters
----------
name : str
The output key name
Returns
-------
index: int
The output index. -1 will be returned if the given output name is not found.
"""
return self._get_output_index(name)

def get_input_info(self):
"""Return the 'shape' and 'dtype' dictionaries of the graph.
Expand All @@ -341,6 +358,24 @@ def get_input_info(self):

return shape_dict, dtype_dict

def get_output_info(self):
"""Return the 'shape' and 'dtype' dictionaries of the graph.
Returns
-------
shape_dict : Map
Shape dictionary - {output_name: tuple}.
dtype_dict : Map
dtype dictionary - {output_name: dtype}.
"""
output_info = self._get_output_info()
assert "shape" in output_info
shape_dict = output_info["shape"]
assert "dtype" in output_info
dtype_dict = output_info["dtype"]

return shape_dict, dtype_dict

def get_output(self, index, out=None):
"""Get index-th output to out
Expand Down
12 changes: 12 additions & 0 deletions src/runtime/graph_executor/graph_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -745,6 +745,18 @@ PackedFunc GraphExecutor::GetFunction(const String& name, const ObjectPtr<Object
CHECK(String::CanConvertFrom(args[0])) << "Input key is not a string";
*rv = this->GetInputIndex(args[0].operator String());
});
} else if (name == "get_output_index") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
CHECK(String::CanConvertFrom(args[0])) << "Output key is not a string";
int out_idx = -1;
for (size_t i = 0; i < outputs_.size(); i++) {
std::string& name = nodes_[outputs_[i].node_id].name;
if (args[0].operator String() == name) {
out_idx = i;
}
}
*rv = out_idx;
});
} else if (name == "get_input_info") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
auto [shape_info, dtype_info] = this->GetInputInfo();
Expand Down
6 changes: 6 additions & 0 deletions tests/python/relay/test_backend_graph_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,12 @@ def test_graph_executor_api():
assert isinstance(dtype_dict[name], tvm.runtime.container.String)
assert dtype_dict[name] == ty.dtype

shape_dict, dtype_dict = mod.get_output_info()
assert isinstance(shape_dict, tvm.container.Map)
assert isinstance(dtype_dict, tvm.container.Map)
for i, key in enumerate(shape_dict):
assert mod.get_output_index(key) == i


@tvm.testing.requires_llvm
def test_benchmark():
Expand Down

0 comments on commit 74ff9bb

Please sign in to comment.