diff --git a/include/tvm/relax/vm/bytecode.h b/include/tvm/relax/vm/bytecode.h index a980705194b4..19654c05c038 100644 --- a/include/tvm/relax/vm/bytecode.h +++ b/include/tvm/relax/vm/bytecode.h @@ -57,6 +57,8 @@ using Index = ExecWord; enum class Opcode { Call = 1U, Ret = 2U, + Goto = 3U, + If = 4U, }; @@ -135,6 +137,20 @@ struct Instruction { /*! \brief The return result. */ RegName result; }; + struct /* Goto */ { + /*! \brief The jump offset. */ + Index pc_offset; + }; + struct /* If */ { + /*! \brief The register containing the test value. */ + RegName test; + /*! \brief The register containing the target value. */ + RegName target; + /*! \brief The program counter offset for the true branch. */ + Index true_offset; + /*! \brief The program counter offset for the false branch. */ + Index false_offset; + }; }; /*! * \brief Construct a Call instruction. @@ -153,6 +169,21 @@ struct Instruction { * \return The return instruction. */ static Instruction Ret(RegName result); + /*! + * \brief Construct a goto instruction. + * \param pc_offset The program counter offset to jump by. + * \return The goto instruction. + */ + static Instruction Goto(Index pc_offset); + /*! + * \brief Construct an If instruction. + * \param test The register containing the test value. + * \param target The register containing the target value. + * \param true_offset The program counter offset for the true branch. + * \param false_offset The program counter offset for the false branch. + * \return The If instruction. + */ + static Instruction If(RegName test, RegName target, Index true_offset, Index false_offset); }; } // namespace relax_vm diff --git a/include/tvm/relax/vm/exec_builder.h b/include/tvm/relax/vm/exec_builder.h index c59cba99d86e..1f85245207a8 100644 --- a/include/tvm/relax/vm/exec_builder.h +++ b/include/tvm/relax/vm/exec_builder.h @@ -65,6 +65,19 @@ class ExecBuilderNode : public Object { * \param result The return result. */ void EmitRet(vm::RegName result); + /*! 
+ * \brief Emit a goto instruction. + * \param pc_offset The program counter offset as the jump offset. + */ + void EmitGoto(vm::Index pc_offset); + /*! + * \brief Emit an If instruction. + * \param test The register containing the test value. + * \param target The register containing the target value. + * \param true_offset The program counter offset for the true branch. + * \param false_offset The program counter offset for the false branch. + */ + void EmitIf(vm::RegName test, vm::RegName target, vm::Index true_offset, vm::Index false_offset); /*! * \brief Emit a constant value to the constant pool. * \return The index that represents the constant. diff --git a/python/tvm/relax/exec_builder.py b/python/tvm/relax/exec_builder.py index 53a621d0b29f..314f10737da3 100644 --- a/python/tvm/relax/exec_builder.py +++ b/python/tvm/relax/exec_builder.py @@ -106,6 +106,16 @@ def emit_ret(self, result: int) -> None: self._check_scope() _ffi_api.ExecBuilderEmitRet(self, result) + def emit_goto(self, pc_offset): + """emit a goto instruction""" + self._check_scope() + _ffi_api.ExecBuilderEmitGoto(self, pc_offset) + + def emit_if(self, test, target, true_offset, false_offset): + """emit an if instruction""" + self._check_scope() + _ffi_api.ExecBuilderEmitIf(self, test, target, true_offset, false_offset) + def get(self) -> Executable: """return the executable""" return _ffi_api.ExecBuilderGet(self) diff --git a/src/relax/vm/builtin.cc b/src/relax/vm/builtin.cc index d4f158c34f88..6088d29bc4ba 100644 --- a/src/relax/vm/builtin.cc +++ b/src/relax/vm/builtin.cc @@ -133,7 +133,7 @@ TVM_REGISTER_GLOBAL("vm.call_tir_dyn") } ShapeTuple to_unpack = args[args.size() - 1]; - int num_tensor_args = args.size() - 3; + size_t num_tensor_args = args.size() - 3; std::vector values(num_tensor_args + to_unpack.size()); std::vector tcodes(num_tensor_args + to_unpack.size()); runtime::TVMArgsSetter setter(values.data(), tcodes.data()); diff --git a/src/relax/vm/bytecode.cc 
b/src/relax/vm/bytecode.cc index 6da75f3893dd..97993aebfa6f 100644 --- a/src/relax/vm/bytecode.cc +++ b/src/relax/vm/bytecode.cc @@ -50,6 +50,22 @@ Instruction Instruction::Ret(RegName result) { return instr; } +Instruction Instruction::Goto(Index pc_offset) { + Instruction instr; + instr.op = Opcode::Goto; + instr.pc_offset = pc_offset; + return instr; +} + +Instruction Instruction::If(RegName test, RegName target, Index true_branch, Index false_branch) { + Instruction instr; + instr.op = Opcode::If; + instr.test = test; + instr.target = target; + instr.true_offset = true_branch; + instr.false_offset = false_branch; + return instr; +} } // namespace relax_vm } // namespace runtime } // namespace tvm diff --git a/src/relax/vm/exec_builder.cc b/src/relax/vm/exec_builder.cc index 2b2545f9663d..29414f329290 100644 --- a/src/relax/vm/exec_builder.cc +++ b/src/relax/vm/exec_builder.cc @@ -78,6 +78,22 @@ void ExecBuilderNode::EmitRet(RegName result) { exec->instr_data.push_back(result); } +void ExecBuilderNode::EmitGoto(Index pc_offset) { + exec->instr_offset.push_back(exec->instr_data.size()); + exec->instr_data.push_back(static_cast(Opcode::Goto)); + exec->instr_data.push_back(pc_offset); +} + +void ExecBuilderNode::EmitIf(vm::RegName test, vm::RegName target, vm::Index true_offset, + vm::Index false_offset){ + exec->instr_offset.push_back(exec->instr_data.size()); + exec->instr_data.push_back(static_cast(Opcode::If)); + exec->instr_data.push_back(test); + exec->instr_data.push_back(target); + exec->instr_data.push_back(true_offset); + exec->instr_data.push_back(false_offset); +} + // helper function to check if an executable is legal by checking if registers are used properly bool CheckExecutable(Executable exec) { for (auto it = exec->global_funcs.cbegin(); it != exec->global_funcs.cend(); ++it) { @@ -120,6 +136,17 @@ bool CheckExecutable(Executable exec) { } break; } + case Opcode::Goto: { + ICHECK_GT(instr.pc_offset, 0); + break; + } + case Opcode::If: { + 
ICHECK_GT(instr.true_offset, 0); + ICHECK_GT(instr.false_offset, 0); + arg_registers.emplace(instr.test); + arg_registers.emplace(instr.target); + break; + } default: LOG(FATAL) << "should never hit this case: " << static_cast(instr.op); break; @@ -168,6 +195,12 @@ void ExecBuilderNode::Formalize() { } break; } + case Opcode::Goto: { + break; + } + case Opcode::If: { + break; + } default: LOG(FATAL) << "should never hit this case: " << static_cast(instr.op); break; @@ -189,9 +222,7 @@ TVM_REGISTER_GLOBAL("relax.ExecBuilderEmitConstant") }); TVM_REGISTER_GLOBAL("relax.ExecBuilderFunction") -.set_body_typed([](ExecBuilder builder, String name, int64_t num_inputs) { - return builder->EmitFunction(name, num_inputs); -}); +.set_body_method(&ExecBuilderNode::EmitFunction); TVM_REGISTER_GLOBAL("relax.ExecBuilderEmitCall") .set_body_typed([](ExecBuilder builder, String name, Array args, int64_t dst) { @@ -205,9 +236,13 @@ TVM_REGISTER_GLOBAL("relax.ExecBuilderEmitCall") }); TVM_REGISTER_GLOBAL("relax.ExecBuilderEmitRet") -.set_body_typed([](ExecBuilder builder, int64_t result) { - builder->EmitRet(result); -}); +.set_body_method(&ExecBuilderNode::EmitRet); + +TVM_REGISTER_GLOBAL("relax.ExecBuilderEmitGoto") +.set_body_method(&ExecBuilderNode::EmitGoto); + +TVM_REGISTER_GLOBAL("relax.ExecBuilderEmitIf") +.set_body_method(&ExecBuilderNode::EmitIf); TVM_REGISTER_GLOBAL("relax.ExecBuilderR") .set_body_typed([](ExecBuilder builder, int64_t value) { @@ -225,9 +260,7 @@ TVM_REGISTER_GLOBAL("relax.ExecBuilderC") }); TVM_REGISTER_GLOBAL("relax.ExecBuilderGet") -.set_body_typed([](ExecBuilder builder) { - return builder->Get(); -}); +.set_body_method(&ExecBuilderNode::Get); } // namespace relax } // namespace tvm diff --git a/src/relax/vm/executable.cc b/src/relax/vm/executable.cc index f291373f9d5d..01a6e9b42b80 100644 --- a/src/relax/vm/executable.cc +++ b/src/relax/vm/executable.cc @@ -138,6 +138,17 @@ Instruction ExecutableNode::GetInstruction(Index i) const { RegName result = 
instr_data[offset + 1]; return Instruction::Ret(result); } + case Opcode::Goto: { + Index pc_offset = instr_data[offset + 1]; + return Instruction::Goto(pc_offset); + } + case Opcode::If: { + RegName test = instr_data[offset + 1]; + RegName target = instr_data[offset + 2]; + Index true_branch = instr_data[offset + 3]; + Index false_branch = instr_data[offset + 4]; + return Instruction::If(test, target, true_branch, false_branch); + } default: LOG(FATAL) << "should never hit this case: " << static_cast(op); break; @@ -448,6 +459,19 @@ String ExecutableNode::AsText() const { << "ret " << RegNameToStr(instr.result) << "\n"; break; } + case Opcode::Goto: { + os << std::setw(6) << std::left + << "goto " << instr.pc_offset << "\n"; + break; + } + case Opcode::If: { + os << std::setw(6) << std::left << "If" + << RegNameToStr(instr.test) << ", " + << RegNameToStr(instr.target) << ", " + << instr.true_offset << ", " + << instr.false_offset << "\n"; + break; + } default: LOG(FATAL) << "should never hit this case: " << static_cast(instr.op); break; @@ -485,6 +509,17 @@ String ExecutableNode::AsPython() const { os << " ib.emit_ret(ib.r(" << instr.result << "))\n"; break; } + case Opcode::Goto: { + os << " ib.emit_goto(" << instr.pc_offset << ")\n"; + break; + } + case Opcode::If: { + os << " ib.emit_if(ib.r(" << instr.test + << "), ib.r(" << instr.target + << "), " << instr.true_offset + << ", " << instr.false_offset << ")\n"; + break; + } default: LOG(FATAL) << "should never hit this case: " << static_cast(instr.op); break; diff --git a/src/relax/vm/vm.cc b/src/relax/vm/vm.cc index 2ebc4369116c..8a9852a07ac6 100644 --- a/src/relax/vm/vm.cc +++ b/src/relax/vm/vm.cc @@ -157,6 +157,22 @@ void VirtualMachine::RunLoop() { WriteRegister(caller_return_register, return_value_); break; } + case Opcode::Goto: { + pc_ += instr.pc_offset; + break; + } + case Opcode::If: { + int64_t test_val = ReadRegister(instr.test); + int64_t target_val = ReadRegister(instr.target); + if 
(test_val == target_val) { + ICHECK_NE(instr.true_offset, 0); + pc_ += instr.true_offset; + } else { + ICHECK_NE(instr.false_offset, 0); + pc_ += instr.false_offset; + } + break; + } } } } diff --git a/tests/python/relax/test_vm.py b/tests/python/relax/test_vm.py index 17a8f3a3ee15..58ac63abd3eb 100644 --- a/tests/python/relax/test_vm.py +++ b/tests/python/relax/test_vm.py @@ -28,6 +28,7 @@ from tvm.script import tir as T, relax as R from tvm.relax.testing import nn + @tvm.register_func("test.vm.move") def move(src): return src @@ -128,8 +129,12 @@ def test_vm_constant_serialize(): inp = tvm.nd.array(np.random.rand(4, 6).astype(np.float32)) ib = relax.ExecBuilder() with ib.function("main", num_inputs=1): - ib.emit_call("vm.builtin.alloc_storage", args=[ib.vm_state(), (24,), ib.imm(1), dtype], dst=ib.r(1)) - ib.emit_call("vm.builtin.alloc_tensor", args=[ib.r(1), shape, ib.imm(0), dtype], dst=ib.r(2)) + ib.emit_call( + "vm.builtin.alloc_storage", args=[ib.vm_state(), (24,), ib.imm(1), dtype], dst=ib.r(1) + ) + ib.emit_call( + "vm.builtin.alloc_tensor", args=[ib.r(1), shape, ib.imm(0), dtype], dst=ib.r(2) + ) ib.emit_call("test.vm.identity", args=[ib.r(0), ib.r(2)]) ib.emit_ret(ib.r(2)) exec0 = ib.get() @@ -218,17 +223,69 @@ def test_vm_storage(): shape = (4, 6) ib = relax.ExecBuilder() with ib.function("main", num_inputs=0): - ib.emit_call("vm.builtin.alloc_storage", args=[ib.vm_state(), (24,), ib.imm(1), dtype], dst=ib.r(1)) - ib.emit_call("vm.builtin.alloc_tensor", args=[ib.r(1), shape, ib.imm(0), dtype], dst=ib.r(2)) + ib.emit_call( + "vm.builtin.alloc_storage", args=[ib.vm_state(), (24,), ib.imm(1), dtype], dst=ib.r(1) + ) + ib.emit_call( + "vm.builtin.alloc_tensor", args=[ib.r(1), shape, ib.imm(0), dtype], dst=ib.r(2) + ) ib.emit_ret(ib.r(2)) ex = ib.get() vm = relax.VirtualMachine(ex, tvm.cpu()) - shape_tuple = container.ShapeTuple(shape) res = vm["main"]() assert res.device == tvm.cpu() assert res.shape == shape +def test_vm_goto(): + ib = relax.ExecBuilder() 
+ with ib.function("main", num_inputs=2): + ib.emit_call("test.vm.add", args=[ib.r(0), ib.r(1)], dst=ib.r(2)) + ib.emit_goto(2) + ib.emit_call("test.vm.mul", args=[ib.r(2), ib.r(1)], dst=ib.r(2)) + ib.emit_ret(ib.r(2)) + ex = ib.get() + vm = relax.VirtualMachine(ex, tvm.cpu()) + a = tvm.nd.array( + np.random.rand( + 4, + ) + ) + b = tvm.nd.array( + np.random.rand( + 4, + ) + ) + res = vm["main"](a, b) + np.testing.assert_allclose(res.asnumpy(), a.asnumpy() + b.asnumpy()) + + +def test_vm_if(): + ib = relax.ExecBuilder() + with ib.function("main", num_inputs=4): + ib.emit_if(ib.r(0), ib.r(1), 1, 3) + ib.emit_call("test.vm.add", args=[ib.r(2), ib.r(3)], dst=ib.r(4)) + ib.emit_goto(2) + ib.emit_call("test.vm.mul", args=[ib.r(2), ib.r(3)], dst=ib.r(4)) + ib.emit_ret(ib.r(4)) + ex = ib.get() + vm = relax.VirtualMachine(ex, tvm.cpu()) + a = tvm.nd.array( + np.random.rand( + 4, + ) + ) + b = tvm.nd.array( + np.random.rand( + 4, + ) + ) + res = vm["main"](True, False, a, b) + np.testing.assert_allclose(res.asnumpy(), a.asnumpy() * b.asnumpy()) + res = vm["main"](1, 1, a, b) + np.testing.assert_allclose(res.asnumpy(), a.asnumpy() + b.asnumpy()) + + def test_vm_compile_stage0(): @tvm.script.ir_module class TestVMCompileStage0: @@ -240,8 +297,8 @@ def foo(x: Tensor[(3, 4), "float32"], y: Tensor[(3, 4), "float32"]): mod = TestVMCompileStage0 target = tvm.target.Target("llvm", host="llvm") ex, lib = relax.vm.build(mod, target) - inp1 = tvm.nd.array(np.random.rand(3,4).astype(np.float32)) - inp2 = tvm.nd.array(np.random.rand(3,4).astype(np.float32)) + inp1 = tvm.nd.array(np.random.rand(3, 4).astype(np.float32)) + inp2 = tvm.nd.array(np.random.rand(3, 4).astype(np.float32)) vm = relax.VirtualMachine(ex, tvm.cpu(), mod=lib) vm["foo"](inp1, inp2) np.testing.assert_allclose(inp2.numpy(), inp1.numpy()) @@ -263,12 +320,8 @@ def shape_func0(heap: T.handle) -> None: offset_factor=1, ) # body - T.store( - H.data, T.int64(2), (T.load("int64", H.data, T.int64(0)) * T.int64(2)), True - ) - 
T.store( - H.data, T.int64(3), (T.load("int64", H.data, T.int64(1)) * T.int64(3)), True - ) + T.store(H.data, T.int64(2), (T.load("int64", H.data, T.int64(0)) * T.int64(2)), True) + T.store(H.data, T.int64(3), (T.load("int64", H.data, T.int64(1)) * T.int64(3)), True) @R.function def foo(x: Tensor[_, "float32"]) -> Shape: @@ -356,6 +409,7 @@ def foo(x: Tensor[_, "float32"]) -> Tensor: res = vm["foo"](inp) np.testing.assert_allclose(np.tile(inp.numpy(), (1, 2)), res.numpy()) + def test_vm_compile_e2e_func_param_with_shape(): @tvm.script.ir_module class TestVMCompileE2E2: @@ -365,9 +419,9 @@ def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None: m = T.var("int32") n = T.var("int32") k = T.var("int32") - A = T.match_buffer(x, (m,n)) - B = T.match_buffer(y, (n,k)) - C = T.match_buffer(z, (m,k)) + A = T.match_buffer(x, (m, n)) + B = T.match_buffer(y, (n, k)) + C = T.match_buffer(z, (m, k)) for i, j, k in T.grid(m, k, n): with T.block("matmul"): @@ -377,11 +431,10 @@ def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None: C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] @R.function - def func(x:Tensor[(m, n), "float32"], w:Tensor[(n, k), "float32"]) -> Tensor: + def func(x: Tensor[(m, n), "float32"], w: Tensor[(n, k), "float32"]) -> Tensor: gv0 = R.call_tir((m, k), tir_matmul, (x, w)) return gv0 - mod = TestVMCompileE2E2 target = tvm.target.Target("llvm", host="llvm") @@ -401,11 +454,11 @@ def test_vm_emit_te_extern(): type_anno = relax.DynTensorType(2, "float32") x = relax.Var("x", [n, m], type_anno) y = relax.Var("y", [m, n], type_anno) - + with bb.function("rx_cblas_matmul", [x, y]): out = bb.emit_te(tvm.contrib.cblas.matmul, x, y, transa=False, transb=False) bb.emit_func_output(out) - + mod = bb.get() target = tvm.target.Target("llvm", host="llvm") @@ -418,6 +471,7 @@ def test_vm_emit_te_extern(): expected = np.dot(data.numpy(), weight.numpy()) np.testing.assert_allclose(expected, res.numpy(), rtol=1e-4, atol=1e-4) + def test_vm_emit_te_concat(): # 
concatenate of two vectors of size (n,) and (m,) bb = relax.BlockBuilder() @@ -427,7 +481,7 @@ def test_vm_emit_te_concat(): y = relax.Var("y", [m], type_anno) def te_func(A, B): - C = te.compute((n + m), lambda i: tvm.tir.if_then_else(i < n, A[i], B[i-n])) + C = te.compute((n + m), lambda i: tvm.tir.if_then_else(i < n, A[i], B[i - n])) return C with bb.function("rx_func", [x, y]): @@ -440,12 +494,21 @@ def te_func(A, B): ex, lib = relax.vm.build(mod, target) vm = relax.VirtualMachine(ex, tvm.cpu(), mod=lib) - inp = tvm.nd.array(np.random.rand(1, ).astype(np.float32)) - inp2 = tvm.nd.array(np.random.rand(2, ).astype(np.float32)) + inp = tvm.nd.array( + np.random.rand( + 1, + ).astype(np.float32) + ) + inp2 = tvm.nd.array( + np.random.rand( + 2, + ).astype(np.float32) + ) res = vm["rx_func"](inp, inp2) np.testing.assert_allclose(res.numpy(), np.append(inp.numpy(), inp2.numpy())) + def test_vm_emit_te_floor_symbolic_shape(): bb = relax.BlockBuilder() n = tir.Var("n", "int64") @@ -466,16 +529,17 @@ def te_func(A): ex, lib = relax.vm.build(mod, target) vm = relax.VirtualMachine(ex, tvm.cpu(), mod=lib) - shape = (9, ) + shape = (9,) inp = tvm.nd.array(np.random.rand(*shape).astype(np.float32)) res = vm["rx_func"](inp) def expected_output(): - output_shape = (shape[0] // 2, ) - return inp.numpy()[:output_shape[0]] + 1 + output_shape = (shape[0] // 2,) + return inp.numpy()[: output_shape[0]] + 1 np.testing.assert_allclose(res.numpy(), expected_output()) + def test_vm_relax_symbolic_shape(): bb = relax.BlockBuilder() n = tir.Var("n", "int64") @@ -484,7 +548,7 @@ def test_vm_relax_symbolic_shape(): y = relax.Var("y", [(n // 2) + 1], type_anno) def te_func(A, B): - C = te.compute((n, ), lambda i: A[i] + B[i // 2]) + C = te.compute((n,), lambda i: A[i] + B[i // 2]) return C with bb.function("rx_func", [x, y]): @@ -497,8 +561,8 @@ def te_func(A, B): ex, lib = relax.vm.build(mod, target) vm = relax.VirtualMachine(ex, tvm.cpu(), mod=lib) - shape1 = (5, ) - shape2 = (3, ) + 
shape1 = (5,) + shape2 = (3,) inp = tvm.nd.array(np.random.rand(*shape1).astype(np.float32)) inp2 = tvm.nd.array(np.random.rand(*shape2).astype(np.float32)) res = vm["rx_func"](inp, inp2) @@ -508,6 +572,7 @@ def expected_output(): np.testing.assert_allclose(res.numpy(), expected_output()) + def test_vm_relax_dyn_tir_shape(): # case where TIR variables are unbound in generated PrimFunc bb = relax.BlockBuilder() @@ -532,7 +597,7 @@ def te_func(A): ex.save_to_file("exec.tmp") exec1 = relax.load_exec_from_file("exec.tmp") assert ex.astext() == exec1.astext() - + vm = relax.VirtualMachine(ex, tvm.cpu(), mod=lib) inp = tvm.nd.array(np.random.rand(2).astype(np.float32)) inp2 = tvm.nd.array(np.random.rand(3).astype(np.float32)) @@ -542,6 +607,7 @@ def te_func(A): np.testing.assert_allclose(res.asnumpy(), inp2.asnumpy()) os.remove("exec.tmp") + if __name__ == "__main__": test_vm_execute() test_vm_multiple_func() @@ -552,6 +618,8 @@ def te_func(A): test_vm_constant_serialize() test_vm_shapeof() test_vm_storage() + test_vm_goto() + test_vm_if() test_vm_compile_stage0() test_vm_compile_stage1() test_vm_compile_stage2() @@ -562,4 +630,4 @@ def te_func(A): test_vm_emit_te_concat() test_vm_emit_te_floor_symbolic_shape() test_vm_relax_symbolic_shape() - test_vm_relax_dyn_tir_shape() \ No newline at end of file + test_vm_relax_dyn_tir_shape()