diff --git a/include/tvm/runtime/vm/executable.h b/include/tvm/runtime/vm/executable.h index 2405b3c0ba8c..fdbc1769c353 100644 --- a/include/tvm/runtime/vm/executable.h +++ b/include/tvm/runtime/vm/executable.h @@ -126,6 +126,11 @@ class TVM_DLL Executable : public ModuleNode { */ void MoveLateBoundConstantsToFile(const std::string& path, size_t byte_limit); + /*! + * \brief Get a map of all constants larger than byte_limit in size. + */ + Map GetLateBoundConstants(size_t byte_limit); + /*! * \brief Restores the late-bound constants for the executable (if any) from given byte-stream. * @@ -134,6 +139,14 @@ class TVM_DLL Executable : public ModuleNode { */ void LoadLateBoundConstantsFromStream(dmlc::Stream* stream); + /*! + * \brief Restores the late-bound constants for the executable (if any) from given map. + * + * Must be called after \p Load but before any other methods if \p MoveLateBoundConstantsToBinary + * was used when saving. Otherwise can be ignored. + */ + void LoadLateBoundConstantsFromMap(Map map); + /*! * \brief As for \p LoadLateBoundConstantsFromStream, but load from file at \p path. */ diff --git a/python/tvm/runtime/vm.py b/python/tvm/runtime/vm.py index c065d77a7c9f..615f66fdcc1c 100644 --- a/python/tvm/runtime/vm.py +++ b/python/tvm/runtime/vm.py @@ -86,7 +86,9 @@ def __init__(self, mod): self._get_function_arity = self.mod["get_function_arity"] self._get_function_param_name = self.mod["get_function_param_name"] self._move_late_bound_consts = self.mod["move_late_bound_consts"] + self._get_late_bound_consts = self.mod["get_late_bound_consts"] self._load_late_bound_consts = self.mod["load_late_bound_consts"] + self._load_late_bound_consts_from_map = self.mod["load_late_bound_consts_from_map"] def save(self): """Save the Relay VM Executable. 
@@ -312,10 +314,18 @@ def move_late_bound_consts(self, path, byte_limit): """Move all constants of byte size greater or equal to byte_limit to file at path""" return self._move_late_bound_consts(path, byte_limit) + def get_late_bound_consts(self, byte_limit): + """Return all constants of byte size greater or equal to byte_limit""" + return self._get_late_bound_consts(byte_limit) + def load_late_bound_consts(self, path): """Re-load constants previously saved to file at path""" return self._load_late_bound_consts(path) + def load_late_bound_consts_from_map(self, map): + """Re-load constants supplied in map""" + return self._load_late_bound_consts_from_map(map) + class VirtualMachine(object): """Relay VM runtime. diff --git a/src/runtime/vm/executable.cc b/src/runtime/vm/executable.cc index 85dad2839a8a..2484ece3081d 100644 --- a/src/runtime/vm/executable.cc +++ b/src/runtime/vm/executable.cc @@ -97,12 +97,25 @@ PackedFunc Executable::GetFunction(const std::string& name, const ObjectPtr(byte_limit)); }); + } else if (name == "get_late_bound_consts") { + return PackedFunc([this](TVMArgs args, TVMRetValue* rv) { + CHECK_EQ(args.size(), 1); + uint64_t byte_limit = args[0]; + Map consts = GetLateBoundConstants(static_cast(byte_limit)); + *rv = consts; + }); } else if (name == "load_late_bound_consts") { return PackedFunc([this](TVMArgs args, TVMRetValue* rv) { CHECK_EQ(args.size(), 1); std::string path = args[0]; LoadLateBoundConstantsFromFile(path); }); + } else if (name == "load_late_bound_consts_from_map") { + return PackedFunc([this](TVMArgs args, TVMRetValue* rv) { + CHECK_EQ(args.size(), 1); + Map map = args[0]; + LoadLateBoundConstantsFromMap(map); + }); } else { LOG(FATAL) << "Unknown packed function: " << name; return PackedFunc(); @@ -300,7 +313,7 @@ void Executable::SaveVirtualDevicesSection(dmlc::Stream* strm) { strm->Write(host_device_index); } -void Executable::MoveLateBoundConstantsToStream(dmlc::Stream* stream, size_t byte_limit) { +Map 
Executable::GetLateBoundConstants(size_t byte_limit) { ICHECK(late_bound_constant_names.empty()); late_bound_constant_names.reserve(constants.size()); Map map; @@ -323,6 +336,11 @@ void Executable::MoveLateBoundConstantsToStream(dmlc::Stream* stream, size_t byt } VLOG(1) << "moved " << map.size() << " constants of " << total_late_bound_bytes << " bytes (out of " << constants.size() << " overall) to be late-bound"; + return map; +} + +void Executable::MoveLateBoundConstantsToStream(dmlc::Stream* stream, size_t byte_limit) { + Map map = GetLateBoundConstants(byte_limit); runtime::SaveParams(stream, map); } @@ -341,6 +359,10 @@ void Executable::LoadLateBoundConstantsFromStream(dmlc::Stream* stream) { ICHECK_EQ(late_bound_constant_names.size(), constants.size()); Map map = runtime::LoadParams(stream); VLOG(1) << "loaded " << map.size() << " late-bound constants"; + LoadLateBoundConstantsFromMap(map); +} + +void Executable::LoadLateBoundConstantsFromMap(Map map) { for (size_t const_index = 0; const_index < constants.size(); ++const_index) { if (!late_bound_constant_names[const_index].defined()) { ICHECK(constants[const_index].defined()) diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py index 4f649ad9beba..0b62db85c904 100644 --- a/tests/python/relay/test_vm.py +++ b/tests/python/relay/test_vm.py @@ -1405,5 +1405,85 @@ def test_vm_save_and_load_without_designating_late_bound_consts(): tvm.testing.assert_allclose(expected, actual.numpy()) +def test_load_and_save_constants_via_map(): + """Large constants can be serialized outside of executable""" + target = tvm.target.Target("llvm") + dev = tvm.cpu() + + # fn(x) { add(x, ) } + x = relay.var("x", shape=(1000, 1000)) + const_data = np.random.rand(1000, 1000).astype("float32") + const = relay.const(const_data, dtype="float32") + func = relay.Function([x], relay.op.add(x, const)) + mod = tvm.IRModule.from_expr(func) + + # Compile to executable. 
+ vm_exec = vm.compile(mod, target=target) + + consts_map = vm_exec.get_late_bound_consts(byte_limit=256) + + # Export the library file; the constants are kept in consts_map + temp = utils.tempdir() + path_dso = temp.relpath("lib.so") + vm_exec.mod.export_library(path_dso) + + # Load library files and constants + mod = runtime.load_module(path_dso) + mod["load_late_bound_consts_from_map"](consts_map) + + # Test main + x_data = np.random.rand(1000, 1000).astype("float32") + the_vm = runtime.vm.VirtualMachine(mod, dev) + actual = the_vm.invoke("main", x_data) + expected = x_data + const_data + tvm.testing.assert_allclose(expected, actual.numpy()) + + # We load the mod again so it's missing the consts. + mod = runtime.load_module(path_dso) + exe = runtime.vm.Executable(mod) + + # Also test loading consts via the VM's wrapper API. + exe.load_late_bound_consts_from_map(consts_map) + + # Test main again with consts now loaded via the above API. + x_data = np.random.rand(1000, 1000).astype("float32") + the_vm = runtime.vm.VirtualMachine(exe, dev) + actual = the_vm.invoke("main", x_data) + expected = x_data + const_data + tvm.testing.assert_allclose(expected, actual.numpy()) + + +def test_load_late_bound_consts_via_map_with_no_late_bound_consts(): + """Check that load_late_bound_consts_from_map handles a model with no late bound consts.""" + target = tvm.target.Target("llvm") + dev = tvm.cpu() + + const_data = np.random.rand(1).astype("float64") + x = relay.var("x", shape=(1,), dtype="float64") + const = relay.const(const_data, dtype="float64") + + func = relay.Function([x], relay.op.add(x, const)) + mod = tvm.IRModule.from_expr(func) + + vm_exec = vm.compile(mod, target=target) + + temp = utils.tempdir() + path_dso = temp.relpath("lib.so") + + # Ensure const_data is below the byte threshold for a late-bound const. 
+ byte_limit = len(const_data.tobytes()) + 1 + consts_map = vm_exec.get_late_bound_consts(byte_limit=byte_limit) + vm_exec.mod.export_library(path_dso) + + mod = runtime.load_module(path_dso) + mod["load_late_bound_consts_from_map"](consts_map) + + x_data = np.random.rand(1).astype("float64") + loaded_vm = runtime.vm.VirtualMachine(mod, dev) + actual = loaded_vm.invoke("main", x_data) + expected = x_data + const_data + tvm.testing.assert_allclose(expected, actual.numpy()) + + if __name__ == "__main__": tvm.testing.main()