Skip to content

Commit

Permalink
clean code
Browse files Browse the repository at this point in the history
  • Loading branch information
vvchernov committed Feb 21, 2022
1 parent 1b705f8 commit 56b028f
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 34 deletions.
7 changes: 6 additions & 1 deletion include/tvm/runtime/vm/vm.h
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,12 @@ class VirtualMachine : public runtime::ModuleNode {
* \brief Internal hook for profiling the end of an op.
*/
virtual void OpStopHook();

private:
const VMFunction& checkAndGetVMFunction(const std::string& func_name) const;
void SetInputTensorWithIndex(std::vector<ObjectRef>& tensors,
const TVMArgValue& tensor,
int index,
Device dev);
protected:
/*! \brief The virtual machine's packed function table. */
std::vector<PackedFunc> packed_funcs_;
Expand Down
59 changes: 26 additions & 33 deletions src/runtime/vm/vm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -237,45 +237,25 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name,
}

void VirtualMachine::SetInput(std::string func_name, TVMArgs args, int offset) {
ICHECK(exec_) << "The executable is not created yet.";
auto gvit = exec_->global_map.find(func_name);
ICHECK(gvit != exec_->global_map.end()) << "Cannot find function " << func_name;
auto func_index = gvit->second;
const auto& vm_func = exec_->functions[func_index];
const auto& vm_func = checkAndGetVMFunction(func_name);
const auto& param_names = vm_func.params;
ICHECK_EQ(args.size() - offset, param_names.size())
<< "The number of provided parameters doesn't match the number of arguments";
// TODO(vvchernov): Looks like it should be checked earlier and in other place
ICHECK_EQ(param_names.size(), vm_func.param_device_indexes.size())
<< "The number of provided parameters doesn't match the number of assigned devices";
std::vector<ObjectRef> func_args(param_names.size());
for (int i = offset; i < args.size(); ++i) {
Device dev = GetDevice(vm_func.param_device_indexes[i - offset]);

if (args[i].type_code() == kTVMDLTensorHandle) {
// Automatically convert input DLTensors to NDArray
DLTensor* tensor = args[i];
std::vector<int64_t> shape;
for (int64_t i = 0; i < tensor->ndim; i++) {
shape.push_back(tensor->shape[i]);
}
NDArray ary = NDArray::Empty(shape, tensor->dtype, dev);
ary.CopyFrom(tensor);
func_args[i - offset] = ary;
} else {
ObjectRef obj = CopyTo(args[i], dev);
func_args[i - offset] = obj;
}
int index = i - offset;
Device dev = GetDevice(vm_func.param_device_indexes[index]);
SetInputTensorWithIndex(func_args, args[i], index, dev);
}
inputs_.erase(func_name);
inputs_.emplace(func_name, func_args);
}

void VirtualMachine::SetInputWithIndex(std::string func_name, TVMArgs args) {
ICHECK(exec_) << "The executable is not created yet.";
auto gvit = exec_->global_map.find(func_name);
ICHECK(gvit != exec_->global_map.end()) << "Cannot find function " << func_name;
auto func_index = gvit->second;
const auto& vm_func = exec_->functions[func_index];
const auto& vm_func = checkAndGetVMFunction(func_name);
const auto& param_names = vm_func.params;
ICHECK_EQ(args.size(), 3) << "The expected number of arguments is 3 (func_name, index, tensor)";
// TODO(vvchernov): Looks like it should be checked earlier and in other place
Expand All @@ -289,24 +269,37 @@ void VirtualMachine::SetInputWithIndex(std::string func_name, TVMArgs args) {
std::vector<ObjectRef> func_args(param_names.size());
inputs_.emplace(func_name, func_args);
}
ICHECK_EQ(args[1].type_code(), kTVMArgInt) << "The second argument doesn't match integer index";
ICHECK_EQ(args[1].type_code(), kTVMArgInt) << "The second argument doesn't match integer";
int inp_index = args[1];
auto& input_tensors = inputs_[func_name];
Device dev = GetDevice(vm_func.param_device_indexes[inp_index]);

if (args[2].type_code() == kTVMDLTensorHandle) {
SetInputTensorWithIndex(inputs_[func_name], args[2], inp_index, dev);
}

const VMFunction& VirtualMachine::checkAndGetVMFunction(const std::string& func_name) const {
ICHECK(exec_) << "The executable is not created yet.";
auto gvit = exec_->global_map.find(func_name);
ICHECK(gvit != exec_->global_map.end()) << "Cannot find function " << func_name;
auto func_index = gvit->second;
return exec_->functions[func_index];
}

void VirtualMachine::SetInputTensorWithIndex(std::vector<ObjectRef>& tensors,
const TVMArgValue& inp_tensor,
int index,
Device dev) {
if (inp_tensor.type_code() == kTVMDLTensorHandle) {
// Automatically convert input DLTensors to NDArray
DLTensor* tensor = args[2];
DLTensor* tensor = inp_tensor;
std::vector<int64_t> shape;
for (int64_t i = 0; i < tensor->ndim; i++) {
shape.push_back(tensor->shape[i]);
}
NDArray ary = NDArray::Empty(shape, tensor->dtype, dev);
ary.CopyFrom(tensor);
input_tensors[inp_index] = ary;
tensors[index] = ary;
} else {
ObjectRef obj = CopyTo(args[2], dev);
input_tensors[inp_index] = obj;
tensors[index] = CopyTo(inp_tensor, dev);
}
}

Expand Down

0 comments on commit 56b028f

Please sign in to comment.