[Runtime][PipelineExecutor] Getting the asynchronous output (apache#10723)

This patch creates a new GlobalRuntime to check whether the output data
is ready and to poll the global output of the pipeline. It also removes
the sequential pipeline execution logic, since the asynchronous logic is
already in place.
huajsj authored and pfk-beta committed Apr 11, 2022
1 parent eff0949 commit 8346f9f
Showing 7 changed files with 222 additions and 208 deletions.
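
For orientation before the diffs: a minimal sketch of how the new asynchronous flow might be driven from Python. It assumes a pipeline module that is already built and fed with inputs (here called `pipe_module`), and that `get_output()` returns an empty list until the pipeline's global output is ready; the polling loop and the surrounding names are illustrative, not part of this patch.

    import time

    # Hypothetical driver, assuming `pipe_module` is a built and configured
    # pipeline_executor module with its inputs already set.
    pipe_module.run()  # Non-blocking after this patch: the `sync` flag is gone.

    # Poll for the global output; an empty result means "not ready yet".
    outputs = pipe_module.get_output()
    while not outputs:
        time.sleep(0.001)
        outputs = pipe_module.get_output()
    print(outputs[0].numpy())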
4 changes: 2 additions & 2 deletions python/tvm/contrib/pipeline_executor.py
@@ -125,9 +125,9 @@ def __init__(self, module):
         self._get_input_pipeline_map = self.module["get_input_pipeline_map"]
         self._get_pipe_execute_count = self.module["get_execute_count"]
 
-    def run(self, sync=False):
+    def run(self):
         """Run the pipeline executor."""
-        self._run(sync)
+        self._run()
 
     def get_input_pipeline_map(self, name):
         """Using the "name" to get the corresponding subgraph index and also get the "input name"
11 changes: 3 additions & 8 deletions src/runtime/pipeline/pipeline_executor.cc
@@ -78,7 +78,7 @@ PackedFunc PipelineExecutor::GetFunction(const std::string& name,
     return PackedFunc(
         [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetOutput(); });
   } else if (name == "run") {
-    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { this->Run(args[0]); });
+    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { this->Run(); });
   } else if (name == "get_execute_count") {
     return PackedFunc(
         [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetExecutionCount(); });
@@ -140,13 +140,8 @@ int PipelineExecutor::GetParamsGroupPipelineMap(const std::string& name) {
   return param_connection_config[name];
 }
 
-/*!
- * \brief Run the pipeline executor.
- * \param serialized_mode Whether run the pipeline executor in serialized mode.
- */
-void PipelineExecutor::Run(bool serialized_mode) {
-  pipeline_scheduler_.PipelineRun(runtimes_, pipeline_config_, serialized_mode);
-}
+/*!\brief Run the pipeline executor.*/
+void PipelineExecutor::Run() { pipeline_scheduler_.PipelineRun(runtimes_, pipeline_config_); }
 /*!
  * \brief return A list of global output data.
  */
7 changes: 2 additions & 5 deletions src/runtime/pipeline/pipeline_executor.h
@@ -113,11 +113,8 @@ class TVM_DLL PipelineExecutor : public ModuleNode {
    * \return The number of outputs.
    */
   int NumOutputs() const { return num_outputs_; }
-  /*!
-   * \brief Run the pipeline executor.
-   * \param serialized_mode Whether run the pipeline executor in serialized mode.
-   */
-  void Run(bool serialized_mode);
+  /*!\brief Run the pipeline executor.*/
+  void Run();
   /*!
    * \brief Get a list output data.
    * \return A list of output data.
61 changes: 8 additions & 53 deletions src/runtime/pipeline/pipeline_scheduler.cc
@@ -32,6 +32,7 @@ std::vector<std::shared_ptr<BackendRuntime>> PipelineScheduler::PipelineInit(
     const std::vector<Module>& modules, const ConfigPipelineExecution& pipeline_config) {
   std::vector<std::shared_ptr<BackendRuntime>> runtimes;
   graph_modules_ = modules;
+  global_runtime_ = std::make_shared<GlobalRuntime>(GLOBAL_MODULE_INDEX);
   // Creating a list of runtimes.
   for (size_t i = 0; i < graph_modules_.size(); i++) {
     auto run_item = std::make_shared<BackendRuntime>(graph_modules_[i], i);
@@ -49,71 +50,25 @@
   }
   // Initializing and then running the worker thread.
   for (auto runtime : runtimes) {
-    runtime->InitializePipeline(pipeline_config, &runtimes);
+    runtime->InitializePipeline(pipeline_config, &runtimes, global_runtime_);
   }
   return runtimes;
 }
-/*!
- * \brief Running the pipeline logic in the sequential mode.
- * \param runtimes A list of backend runtime modules.
- * \param pipeline_config The dependent configuration of each runtime module.
- */
-void PipelineScheduler::PipelineRunSequential(
-    const std::vector<std::shared_ptr<BackendRuntime>>& runtimes,
-    ConfigPipelineExecution pipeline_config) {
-  for (size_t i = 0; i < runtimes.size(); i++) {
-    // The "runtimes" is a list of runtime sorted by the runtime index which should be
-    // contiguous ascend.
-    if (static_cast<int>(i) != runtimes[i]->GetModuleIndex()) {
-      LOG(FATAL) << "Runtime index " << runtimes[i]->GetModuleIndex()
-                 << " is not as same as vector offset value " << i;
-    }
-
-    if (!pipeline_config.FindModuleInConfig(i)) {
-      LOG(FATAL) << "Not find the configuration for the module " << i;
-    }
-
-    runtimes[i]->Run();
-    // Getting the output then forwarding into other module once it is configured as input of
-    // another module or storaging into the "output_array" when the output is a global one.
-    int outputs_num = runtimes[i]->NumOutputs();
-    for (int j = 0; j < outputs_num; j++) {
-      ConfigBindings& out_binding = pipeline_config[i][j];
-      std::unordered_map<int, std::string>& input_connections = out_binding.Get();
-      NDArray output = runtimes[i]->GetOutput(j);
-      for (auto bind : input_connections) {
-        // "bind.first < 0" means the bind is a global bind, by pass the forwarding for
-        // a global bind.
-        if (bind.first < 0) continue;
-        // Setting the output as an input data into the runtime module.
-        runtimes[bind.first]->SetInput(bind.second, const_cast<DLTensor*>(output.operator->()));
-      }
-      // Store the output.
-      if (out_binding.IsGlobalOutput()) {
-        int global_idx = out_binding.GetGlobalOutputIndex();
-        TVMArrayCopyFromTo(const_cast<DLTensor*>(output.operator->()),
-                           const_cast<DLTensor*>(output_arrays_[global_idx].operator->()), nullptr);
-      }
-    }
-  }
-}
 /*!
  * \brief Running pipeline logic.
  * \param runtimes A list of backend runtime modules.
  * \param pipeline_config The dependency configuration of each runtime module.
- * \param sequential_mode Whether the execution is in a sequential mode.
 */
 void PipelineScheduler::PipelineRun(const std::vector<std::shared_ptr<BackendRuntime>>& runtimes,
-                                    ConfigPipelineExecution pipeline_config, bool sequential_mode) {
-  if (!sequential_mode) {
-    runtimes.front()->RunPipeline();
-  } else {
-    PipelineRunSequential(runtimes, pipeline_config);
-  }
+                                    ConfigPipelineExecution pipeline_config) {
+  runtimes.front()->RunPipeline();
 }
 /*!
  * \brief Get a list of output.
 */
-Array<NDArray> PipelineScheduler::PipelineGetOutput() { return output_arrays_; }
+Array<NDArray> PipelineScheduler::PipelineGetOutput() {
+  bool ret = global_runtime_->GetOutput(&output_arrays_);
+  return ret ? output_arrays_ : Array<NDArray>{};
+}
 } // namespace runtime
 } // namespace tvm
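
A note on the `PipelineGetOutput` change above: `global_runtime_->GetOutput(...)` only fills `output_arrays_` once the data is ready, so a caller receives an empty array rather than blocking or raising an error. A hedged sketch of a caller turning that contract into a timeout (the `wait_for_output` helper and its polling interval are illustrative; only the empty-until-ready behavior of `get_output()` mirrors this patch):

    import time

    def wait_for_output(pipe_module, timeout_s=5.0, poll_s=0.001):
        # Hypothetical helper: poll until the GlobalRuntime reports data ready.
        deadline = time.monotonic() + timeout_s
        while time.monotonic() < deadline:
            outputs = pipe_module.get_output()
            if outputs:  # Non-empty means the global output was copied out.
                return outputs
            time.sleep(poll_s)
        raise TimeoutError("pipeline output not ready within timeout")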
12 changes: 3 additions & 9 deletions src/runtime/pipeline/pipeline_scheduler.h
@@ -47,17 +47,9 @@ class PipelineScheduler {
    * \brief Running the pipeline logic.
    * \param runtimes A list of backend runtime modules.
    * \param pipeline_config The dependency configuration of each runtime module.
-   * \param sequential_mode Whether the execution is in a sequential mode.
    */
   void PipelineRun(const std::vector<std::shared_ptr<BackendRuntime>>& runtimes,
-                   ConfigPipelineExecution pipeline_config, bool sequential_mode = false);
-  /*!
-   * \brief Running the pipeline logic in the sequential mode.
-   * \param runtimes A list of backend runtime modules.
-   * \param pipeline_config The dependent configuration of each runtime module.
-   */
-  void PipelineRunSequential(const std::vector<std::shared_ptr<BackendRuntime>>& runtimes,
-                             ConfigPipelineExecution pipeline_config);
+                   ConfigPipelineExecution pipeline_config);
   /*!
    * \brief Get a list of outputs.
    */
@@ -68,6 +60,8 @@
   std::vector<Module> graph_modules_;
   /*!\brief A list of NDArray used to storage outputs.*/
   Array<NDArray> output_arrays_;
+  /*!\brief The global runtime to represent the pipeline executor.*/
+  std::shared_ptr<GlobalRuntime> global_runtime_;
 };
 } // namespace runtime
 } // namespace tvm