Skip to content

Commit

Permalink
[Runtime][PipelineExecutor] Added Interface to Track Number of Global…
Browse files Browse the repository at this point in the history
… Inputs (apache#11315)

* [Runtime][PipleineExecutor] Added Interface to Track Number of Global Inputs

Added a feature to PipelineExecutor to track number of Global Inputs.

* Fixed CI Error

* Fixed remaining CI Error
  • Loading branch information
Raghav-Chakravarthy authored Jun 17, 2022
1 parent 648154d commit 8a94b66
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 1 deletion.
11 changes: 11 additions & 0 deletions python/tvm/contrib/pipeline_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def __init__(self, module):
self._get_input = self.module["get_input"]
self._get_output = self.module["get_output"]
self._get_num_outputs = self.module["get_num_outputs"]
self._get_num_inputs = self.module["get_num_inputs"]
self._get_input_pipeline_map = self.module["get_input_pipeline_map"]
self._get_pipe_execute_count = self.module["get_execute_count"]

Expand Down Expand Up @@ -159,6 +160,16 @@ def num_outputs(self):
"""
return self._get_num_outputs()

@property
def num_inputs(self):
"""Get the number of inputs
Returns
-------
count : int
The number of inputs
"""
return self._get_num_inputs()

@staticmethod
def load_library(config_file_name):
"""Import files to create a pipeline executor.
Expand Down
8 changes: 7 additions & 1 deletion src/runtime/pipeline/pipeline_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ PackedFunc PipelineExecutor::GetFunction(const std::string& name,
if (name == "get_num_outputs") {
return PackedFunc(
[sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->NumOutputs(); });
} else if (name == "get_num_inputs") {
return PackedFunc(
[sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->NumInputs(); });
} else if (name == "get_input_pipeline_map") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
if (String::CanConvertFrom(args[0])) {
Expand Down Expand Up @@ -87,7 +90,10 @@ PackedFunc PipelineExecutor::GetFunction(const std::string& name,
return PackedFunc();
}
}

/*!
* brief Returns number of global inputs.
*/
int PipelineExecutor::NumInputs(void) { return input_connection_config_.GetInputNum(); }
/*!
* \brief set input to the runtime module.
* \param input_name The input name.
Expand Down
1 change: 1 addition & 0 deletions src/runtime/pipeline/pipeline_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ class TVM_DLL PipelineExecutor : public ModuleNode {
int NumOutputs() const { return num_outputs_; }
/*!\brief Run the pipeline executor.*/
void Run();
int NumInputs();
/*!
* \brief Get a list output data.
* \return A list of output data.
Expand Down
3 changes: 3 additions & 0 deletions src/runtime/pipeline/pipeline_struct.h
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,9 @@ struct InputConnectionConfig {
}
return input_connection[key];
}
/*!\brief Returns the number of global inputs through the input_runtime_map list size.*/
int GetInputNum() { return input_runtime_map.size(); }

/*!
* \brief Getting the global input index through the input name.
* \param input_name The global input name.
Expand Down
2 changes: 2 additions & 0 deletions tests/python/relay/test_pipeline_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,6 +595,8 @@ def test_pipeline():
if input_map[0] == "0":
input_data = pipeline_module_test.get_input("data_a")
tvm.testing.assert_allclose(data, input_data.numpy())

assert pipeline_module_test.num_inputs == 2
# Running the pipeline executor in the pipeline mode.
pipeline_module_test.run()

Expand Down

0 comments on commit 8a94b66

Please sign in to comment.