diff --git a/include/tvm/runtime/vm/vm.h b/include/tvm/runtime/vm/vm.h index cd519f09d6e89..c977957f9077f 100644 --- a/include/tvm/runtime/vm/vm.h +++ b/include/tvm/runtime/vm/vm.h @@ -294,7 +294,12 @@ class VirtualMachine : public runtime::ModuleNode { * \brief Internal hook for profiling the end of an op. */ virtual void OpStopHook(); - + private: + const VMFunction& checkAndGetVMFunction(const std::string& func_name) const; + void SetInputTensorWithIndex(std::vector& tensors, + const TVMArgValue& tensor, + int index, + Device dev); protected: /*! \brief The virtual machine's packed function table. */ std::vector packed_funcs_; diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index c1d2c6bb5c916..b440a63c1705f 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -237,45 +237,25 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name, } void VirtualMachine::SetInput(std::string func_name, TVMArgs args, int offset) { - ICHECK(exec_) << "The executable is not created yet."; - auto gvit = exec_->global_map.find(func_name); - ICHECK(gvit != exec_->global_map.end()) << "Cannot find function " << func_name; - auto func_index = gvit->second; - const auto& vm_func = exec_->functions[func_index]; + const auto& vm_func = checkAndGetVMFunction(func_name); const auto& param_names = vm_func.params; ICHECK_EQ(args.size() - offset, param_names.size()) << "The number of provided parameters doesn't match the number of arguments"; + // TODO(vvchernov): Looks like it should be checked earlier and in other place ICHECK_EQ(param_names.size(), vm_func.param_device_indexes.size()) << "The number of provided parameters doesn't match the number of assigned devices"; std::vector func_args(param_names.size()); for (int i = offset; i < args.size(); ++i) { - Device dev = GetDevice(vm_func.param_device_indexes[i - offset]); - - if (args[i].type_code() == kTVMDLTensorHandle) { - // Automatically convert input DLTensors to NDArray - DLTensor* tensor = args[i]; - std::vector shape; - for (int64_t i = 0; i < tensor->ndim; i++) { - shape.push_back(tensor->shape[i]); - } - NDArray ary = NDArray::Empty(shape, tensor->dtype, dev); - ary.CopyFrom(tensor); - func_args[i - offset] = ary; - } else { - ObjectRef obj = CopyTo(args[i], dev); - func_args[i - offset] = obj; - } + int index = i - offset; + Device dev = GetDevice(vm_func.param_device_indexes[index]); + SetInputTensorWithIndex(func_args, args[i], index, dev); } inputs_.erase(func_name); inputs_.emplace(func_name, func_args); } void VirtualMachine::SetInputWithIndex(std::string func_name, TVMArgs args) { - ICHECK(exec_) << "The executable is not created yet."; - auto gvit = exec_->global_map.find(func_name); - ICHECK(gvit != exec_->global_map.end()) << "Cannot find function " << func_name; - auto func_index = gvit->second; - const auto& vm_func = exec_->functions[func_index]; + const auto& vm_func = checkAndGetVMFunction(func_name); const auto& param_names = vm_func.params; ICHECK_EQ(args.size(), 3) << "The expected number of arguments is 3 (func_name, index, tensor)"; // TODO(vvchernov): Looks like it should be checked earlier and in other place @@ -289,24 +269,37 @@ void VirtualMachine::SetInputWithIndex(std::string func_name, TVMArgs args) { std::vector func_args(param_names.size()); inputs_.emplace(func_name, func_args); } - ICHECK_EQ(args[1].type_code(), kTVMArgInt) << "The second argument doesn't match integer index"; + ICHECK_EQ(args[1].type_code(), kTVMArgInt) << "The second argument doesn't match integer"; int inp_index = args[1]; - auto& input_tensors = inputs_[func_name]; Device dev = GetDevice(vm_func.param_device_indexes[inp_index]); - if (args[2].type_code() == kTVMDLTensorHandle) { + SetInputTensorWithIndex(inputs_[func_name], args[2], inp_index, dev); +} + +const VMFunction& VirtualMachine::checkAndGetVMFunction(const std::string& func_name) const { + ICHECK(exec_) << "The executable is not created yet."; + auto gvit = exec_->global_map.find(func_name); + ICHECK(gvit != exec_->global_map.end()) << "Cannot find function " << func_name; + auto func_index = gvit->second; + return exec_->functions[func_index]; +} + +void VirtualMachine::SetInputTensorWithIndex(std::vector& tensors, + const TVMArgValue& inp_tensor, + int index, + Device dev) { + if (inp_tensor.type_code() == kTVMDLTensorHandle) { // Automatically convert input DLTensors to NDArray - DLTensor* tensor = args[2]; + DLTensor* tensor = inp_tensor; std::vector shape; for (int64_t i = 0; i < tensor->ndim; i++) { shape.push_back(tensor->shape[i]); } NDArray ary = NDArray::Empty(shape, tensor->dtype, dev); ary.CopyFrom(tensor); - input_tensors[inp_index] = ary; + tensors[index] = ary; } else { - ObjectRef obj = CopyTo(args[2], dev); - input_tensors[inp_index] = obj; + tensors[index] = CopyTo(inp_tensor, dev); } }