Skip to content

Commit

Permalink
[VM] Add control flow to relax vm (apache#61)
Browse files Browse the repository at this point in the history
* Add if and goto instr.

* Update python/tvm/relax/exec_builder.py

Co-authored-by: Yong Wu <yongcale@gmail.com>

* use set_body_method.

Co-authored-by: Yong Wu <yongcale@gmail.com>
  • Loading branch information
YuchenJin and yongwww committed Nov 17, 2022
1 parent 7f8cb36 commit e87a582
Show file tree
Hide file tree
Showing 9 changed files with 263 additions and 41 deletions.
31 changes: 31 additions & 0 deletions include/tvm/relax/vm/bytecode.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ using Index = ExecWord;
enum class Opcode {
Call = 1U,
Ret = 2U,
Goto = 3U,
If = 4U,
};


Expand Down Expand Up @@ -135,6 +137,20 @@ struct Instruction {
/*! \brief The return result. */
RegName result;
};
struct /* Goto */ {
/*! \brief The jump offset. */
Index pc_offset;
};
struct /* If */ {
/*! \brief The register containing the test value. */
RegName test;
/*! \brief The register containing the target value. */
RegName target;
/*! \brief The program counter offset for the true branch. */
Index true_offset;
/*! \brief The program counter offset for the false branch. */
Index false_offset;
};
};
/*!
* \brief Construct a Call instruction.
Expand All @@ -153,6 +169,21 @@ struct Instruction {
* \return The return instruction.
*/
static Instruction Ret(RegName result);
/*!
* \brief Construct a goto instruction.
* \param pc_offset The register containing the jump offset.
* \return The goto instruction.
*/
static Instruction Goto(RegName pc_offset);
/*!
* \brief Construct an If instruction.
* \param test The register containing the test value.
* \param target The register containing the target value.
* \param true_offset The program counter offset for the true branch.
* \param false_offset The program counter offset for the false branch.
* \return The If instruction.
*/
static Instruction If(RegName test, RegName target, Index true_offset, Index false_offset);
};

} // namespace relax_vm
Expand Down
13 changes: 13 additions & 0 deletions include/tvm/relax/vm/exec_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,19 @@ class ExecBuilderNode : public Object {
* \param result The return result.
*/
void EmitRet(vm::RegName result);
/*!
* \brief Emit a goto instruction.
* \param pc_offset The program counter offset as the jump offset.
*/
void EmitGoto(vm::Index pc_offset);
/*!
* \brief Emit an If instruction.
* \param test The register containing the test value.
* \param target The register containing the target value.
* \param true_offset The program counter offset for the true branch.
* \param false_offset The program counter offset for the false branch.
*/
void EmitIf(vm::RegName test, vm::RegName target, vm::Index true_offset, vm::Index false_offset);
/*!
* \brief Emit a constant value to the constant pool.
* \return The index that represents the constant.
Expand Down
10 changes: 10 additions & 0 deletions python/tvm/relax/exec_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,16 @@ def emit_ret(self, result: int) -> None:
self._check_scope()
_ffi_api.ExecBuilderEmitRet(self, result)

def emit_goto(self, pc_offset):
"""emit a goto instruction"""
self._check_scope()
_ffi_api.ExecBuilderEmitGoto(self, pc_offset)

def emit_if(self, test, target, true_offset, false_offset):
"""emit an if instruction"""
self._check_scope()
_ffi_api.ExecBuilderEmitIf(self, test, target, true_offset, false_offset)

def get(self) -> Executable:
"""return the executable"""
return _ffi_api.ExecBuilderGet(self)
2 changes: 1 addition & 1 deletion src/relax/vm/builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ TVM_REGISTER_GLOBAL("vm.call_tir_dyn")
}

ShapeTuple to_unpack = args[args.size() - 1];
int num_tensor_args = args.size() - 3;
size_t num_tensor_args = args.size() - 3;
std::vector<TVMValue> values(num_tensor_args + to_unpack.size());
std::vector<int> tcodes(num_tensor_args + to_unpack.size());
runtime::TVMArgsSetter setter(values.data(), tcodes.data());
Expand Down
16 changes: 16 additions & 0 deletions src/relax/vm/bytecode.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,22 @@ Instruction Instruction::Ret(RegName result) {
return instr;
}

Instruction Instruction::Goto(Index pc_offset) {
Instruction instr;
instr.op = Opcode::Goto;
instr.pc_offset = pc_offset;
return instr;
}

Instruction Instruction::If(RegName test, RegName target, Index true_branch, Index false_branch) {
Instruction instr;
instr.op = Opcode::If;
instr.test = test;
instr.target = target;
instr.true_offset = true_branch;
instr.false_offset = false_branch;
return instr;
}
} // namespace relax_vm
} // namespace runtime
} // namespace tvm
51 changes: 42 additions & 9 deletions src/relax/vm/exec_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,22 @@ void ExecBuilderNode::EmitRet(RegName result) {
exec->instr_data.push_back(result);
}

void ExecBuilderNode::EmitGoto(Index pc_offset) {
exec->instr_offset.push_back(exec->instr_data.size());
exec->instr_data.push_back(static_cast<ExecWord>(Opcode::Goto));
exec->instr_data.push_back(pc_offset);
}

void ExecBuilderNode::EmitIf(vm::RegName test, vm::RegName target, vm::Index true_offset,
vm::Index false_offset){
exec->instr_offset.push_back(exec->instr_data.size());
exec->instr_data.push_back(static_cast<ExecWord>(Opcode::If));
exec->instr_data.push_back(test);
exec->instr_data.push_back(target);
exec->instr_data.push_back(true_offset);
exec->instr_data.push_back(false_offset);
}

// helper function to check if an executable is legal by checking if registers are used properly
bool CheckExecutable(Executable exec) {
for (auto it = exec->global_funcs.cbegin(); it != exec->global_funcs.cend(); ++it) {
Expand Down Expand Up @@ -120,6 +136,17 @@ bool CheckExecutable(Executable exec) {
}
break;
}
case Opcode::Goto: {
ICHECK_GT(instr.pc_offset, 0);
break;
}
case Opcode::If: {
ICHECK_GT(instr.true_offset, 0);
ICHECK_GT(instr.false_offset, 0);
arg_registers.emplace(instr.test);
arg_registers.emplace(instr.target);
break;
}
default:
LOG(FATAL) << "should never hit this case: " << static_cast<int>(instr.op);
break;
Expand Down Expand Up @@ -168,6 +195,12 @@ void ExecBuilderNode::Formalize() {
}
break;
}
case Opcode::Goto: {
break;
}
case Opcode::If: {
break;
}
default:
LOG(FATAL) << "should never hit this case: " << static_cast<int>(instr.op);
break;
Expand All @@ -189,9 +222,7 @@ TVM_REGISTER_GLOBAL("relax.ExecBuilderEmitConstant")
});

TVM_REGISTER_GLOBAL("relax.ExecBuilderFunction")
.set_body_typed([](ExecBuilder builder, String name, int64_t num_inputs) {
return builder->EmitFunction(name, num_inputs);
});
.set_body_method<ExecBuilder>(&ExecBuilderNode::EmitFunction);

TVM_REGISTER_GLOBAL("relax.ExecBuilderEmitCall")
.set_body_typed([](ExecBuilder builder, String name, Array<IntImm> args, int64_t dst) {
Expand All @@ -205,9 +236,13 @@ TVM_REGISTER_GLOBAL("relax.ExecBuilderEmitCall")
});

TVM_REGISTER_GLOBAL("relax.ExecBuilderEmitRet")
.set_body_typed([](ExecBuilder builder, int64_t result) {
builder->EmitRet(result);
});
.set_body_method<ExecBuilder>(&ExecBuilderNode::EmitRet);

TVM_REGISTER_GLOBAL("relax.ExecBuilderEmitGoto")
.set_body_method<ExecBuilder>(&ExecBuilderNode::EmitGoto);

TVM_REGISTER_GLOBAL("relax.ExecBuilderEmitIf")
.set_body_method<ExecBuilder>(&ExecBuilderNode::EmitIf);

TVM_REGISTER_GLOBAL("relax.ExecBuilderR")
.set_body_typed([](ExecBuilder builder, int64_t value) {
Expand All @@ -225,9 +260,7 @@ TVM_REGISTER_GLOBAL("relax.ExecBuilderC")
});

TVM_REGISTER_GLOBAL("relax.ExecBuilderGet")
.set_body_typed([](ExecBuilder builder) {
return builder->Get();
});
.set_body_method<ExecBuilder>(&ExecBuilderNode::Get);

} // namespace relax
} // namespace tvm
35 changes: 35 additions & 0 deletions src/relax/vm/executable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,17 @@ Instruction ExecutableNode::GetInstruction(Index i) const {
RegName result = instr_data[offset + 1];
return Instruction::Ret(result);
}
case Opcode::Goto: {
Index pc_offset = instr_data[offset + 1];
return Instruction::Goto(pc_offset);
}
case Opcode::If: {
RegName test = instr_data[offset + 1];
RegName target = instr_data[offset + 2];
Index true_branch = instr_data[offset + 3];
Index false_branch = instr_data[offset + 4];
return Instruction::If(test, target, true_branch, false_branch);
}
default:
LOG(FATAL) << "should never hit this case: " << static_cast<int>(op);
break;
Expand Down Expand Up @@ -448,6 +459,19 @@ String ExecutableNode::AsText() const {
<< "ret " << RegNameToStr(instr.result) << "\n";
break;
}
case Opcode::Goto: {
os << std::setw(6) << std::left << "goto"
<< "goto " << instr.pc_offset << "\n";
break;
}
case Opcode::If: {
os << std::setw(6) << std::left << "If"
<< RegNameToStr(instr.test) << ", "
<< RegNameToStr(instr.target) << ", "
<< instr.true_offset << ", "
<< instr.false_offset << "\n";
break;
}
default:
LOG(FATAL) << "should never hit this case: " << static_cast<int>(instr.op);
break;
Expand Down Expand Up @@ -485,6 +509,17 @@ String ExecutableNode::AsPython() const {
os << " ib.emit_ret(ib.r(" << instr.result << "))\n";
break;
}
case Opcode::Goto: {
os << " ib.emit_goto(" << instr.pc_offset << ")\n";
break;
}
case Opcode::If: {
os << " ib.emit_if(ib.r(" << instr.test
<< "), ib.r(" << instr.target
<< "), " << instr.true_offset
<< ", " << instr.false_offset << ")\n";
break;
}
default:
LOG(FATAL) << "should never hit this case: " << static_cast<int>(instr.op);
break;
Expand Down
16 changes: 16 additions & 0 deletions src/relax/vm/vm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,22 @@ void VirtualMachine::RunLoop() {
WriteRegister(caller_return_register, return_value_);
break;
}
case Opcode::Goto: {
pc_ += instr.pc_offset;
break;
}
case Opcode::If: {
int64_t test_val = ReadRegister(instr.test);
int64_t target_val = ReadRegister(instr.target);
if (test_val == target_val) {
ICHECK_NE(instr.true_offset, 0);
pc_ += instr.true_offset;
} else {
ICHECK_NE(instr.false_offset, 0);
pc_ += instr.false_offset;
}
break;
}
}
}
}
Expand Down
Loading

0 comments on commit e87a582

Please sign in to comment.