Skip to content

Commit

Permalink
set_input_with_index was implemented for VM
Browse files Browse the repository at this point in the history
  • Loading branch information
vvchernov committed Feb 17, 2022
1 parent f583a70 commit 53eb213
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 0 deletions.
9 changes: 9 additions & 0 deletions include/tvm/runtime/vm/vm.h
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,15 @@ class VirtualMachine : public runtime::ModuleNode {
*/
void SetInput(std::string name, TVMArgs args, int offset);

/*!
* \brief Set input tensor with index to a function.
* \param name The function name
* \param args args[1:] are two arguments (index, tensor) to the
* function. If the tensor is not of the correct device for the function,
* they will be copied to the device.
*/
void SetInputWithIndex(std::string name, TVMArgs args);

/*!
* \brief Internal hook for profiling the start of an op.
*
Expand Down
44 changes: 44 additions & 0 deletions src/runtime/vm/vm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,9 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name,
} else if (name == "set_input") {
return PackedFunc(
[sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { SetInput(args[0], args, 1); });
} else if (name == "set_input_with_index") {
return PackedFunc(
[sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { SetInputWithIndex(args[0], args); });
} else if (name == "load_late_bound_consts") {
return PackedFunc([this](TVMArgs args, TVMRetValue* rv) {
CHECK_EQ(args.size(), 1);
Expand Down Expand Up @@ -267,6 +270,47 @@ void VirtualMachine::SetInput(std::string func_name, TVMArgs args, int offset) {
inputs_.emplace(func_name, func_args);
}

void VirtualMachine::SetInputWithIndex(std::string func_name, TVMArgs args) {
ICHECK(exec_) << "The executable is not created yet.";
auto gvit = exec_->global_map.find(func_name);
ICHECK(gvit != exec_->global_map.end()) << "Cannot find function " << func_name;
auto func_index = gvit->second;
const auto& vm_func = exec_->functions[func_index];
const auto& param_names = vm_func.params;
ICHECK_EQ(args.size(), 3)
<< "The expected number of arguments is 3 (func_name, index, tensor)";
// TODO(vvchernov): Looks like it should be checked earlier and in other place
ICHECK_EQ(param_names.size(), vm_func.param_device_indexes.size())
<< "The number of provided parameters doesn't match the number of assigned devices";
if (inputs_.count(func_name)) {
ICHECK_EQ(inputs_[func_name].size(), param_names.size())
<< "The size of function" << func_name << " doesn't match the number of provided parameters";
} else {
std::vector<ObjectRef> func_args(param_names.size());
inputs_.emplace(func_name, func_args);
}
ICHECK_EQ(args[1].type_code(), kTVMArgInt)
<< "The second argument doesn't match integer index";
int inp_index = args[1];
auto& input_tensors = inputs_[func_name];
Device dev = GetDevice(vm_func.param_device_indexes[inp_index]);

if (args[2].type_code() == kTVMDLTensorHandle) {
// Automatically convert input DLTensors to NDArray
DLTensor* tensor = args[2];
std::vector<int64_t> shape;
for (int64_t i = 0; i < tensor->ndim; i++) {
shape.push_back(tensor->shape[i]);
}
NDArray ary = NDArray::Empty(shape, tensor->dtype, dev);
ary.CopyFrom(tensor);
input_tensors[inp_index] = ary;
} else {
ObjectRef obj = CopyTo(args[2], dev);
input_tensors[inp_index] = obj;
}
}

inline Device VirtualMachine::GetDevice(Index device_index) const {
ICHECK_GE(devices_.size(), device_index) << "invalid device index: " << device_index;
return devices_[device_index];
Expand Down

0 comments on commit 53eb213

Please sign in to comment.