Skip to content

Commit

Permalink
add getInputIndexFromName. lint fix
Browse files Browse the repository at this point in the history
  • Loading branch information
vvchernov committed Feb 21, 2022
1 parent e9eb686 commit cfecf52
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 12 deletions.
4 changes: 3 additions & 1 deletion include/tvm/runtime/vm/vm.h
Original file line number Diff line number Diff line change
Expand Up @@ -296,11 +296,13 @@ class VirtualMachine : public runtime::ModuleNode {
virtual void OpStopHook();

private:
int64_t getInputIndexFromName(const std::string& input_name,
const std::string& func_name) const;
const VMFunction& checkAndGetVMFunction(const std::string& func_name) const;
void SetInputTensorWithIndex(std::vector<ObjectRef>& tensors,
const TVMArgValue& tensor,
int index,
Device dev);
Device dev); // NOLINT(*)

protected:
/*! \brief The virtual machine's packed function table. */
Expand Down
24 changes: 13 additions & 11 deletions src/runtime/vm/vm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -190,17 +190,7 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name,
} else if (name == "get_input_index") {
return TypedPackedFunc<int64_t(std::string, std::string)>(
[this](std::string input_name, std::string func_name) {
auto gvit = exec_->global_map.find(func_name);
ICHECK(gvit != exec_->global_map.end()) << "Cannot find function " << func_name;
auto func_index = gvit->second;
const auto& vm_func = exec_->functions[func_index];
const auto& param_names = vm_func.params;
for (uint64_t i = 0; i < param_names.size(); i++) {
if (input_name == param_names[i]) {
return static_cast<int64_t>(i);
}
}
return static_cast<int64_t>(-1);
return getInputIndexFromName(input_name, func_name);
});
} else if (name == "init") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
Expand Down Expand Up @@ -277,6 +267,18 @@ void VirtualMachine::SetInputWithIndex(std::string func_name, TVMArgs args) {
SetInputTensorWithIndex(inputs_[func_name], args[2], inp_index, dev);
}

int64_t VirtualMachine::getInputIndexFromName(const std::string& input_name,
const std::string& func_name) const {
const auto& vm_func = checkAndGetVMFunction(func_name);
const auto& param_names = vm_func.params;
for (uint64_t i = 0; i < param_names.size(); i++) {
if (input_name == param_names[i]) {
return static_cast<int64_t>(i);
}
}
return static_cast<int64_t>(-1);
}

const VMFunction& VirtualMachine::checkAndGetVMFunction(const std::string& func_name) const {
ICHECK(exec_) << "The executable is not created yet.";
auto gvit = exec_->global_map.find(func_name);
Expand Down

0 comments on commit cfecf52

Please sign in to comment.