From 2dfcfbd43ec26fcce1b277b92907be4c70a40d2c Mon Sep 17 00:00:00 2001 From: Mark Shields <87091372+mbs-octoml@users.noreply.github.com> Date: Tue, 14 Dec 2021 22:38:41 -0800 Subject: [PATCH] [Relay] Support large constants saved/loaded outside of VM executable (#9734) * [Relay] Support large constants. This allows constant tensors at or above a given byte limit to be marked as 'late bound' and saved/reloaded to a file independently of the overall executable. Since the executable is often embedded in the data segment of generated runtime Modules this avoids problems with external tools which can't handle multi-gigabyte data segments. [ACE-466 in OctoML JIRA] * [checkpoint] fix latent bytecode/code bug --- include/tvm/runtime/vm/executable.h | 79 ++++++++++-- include/tvm/runtime/vm/vm.h | 4 +- python/tvm/runtime/vm.py | 14 ++- src/runtime/vm/executable.cc | 187 +++++++++++++++++++++++----- src/runtime/vm/profiler/vm.cc | 6 +- src/runtime/vm/profiler/vm.h | 2 +- src/runtime/vm/vm.cc | 14 ++- tests/python/relay/test_vm.py | 34 +++++ 8 files changed, 284 insertions(+), 56 deletions(-) diff --git a/include/tvm/runtime/vm/executable.h b/include/tvm/runtime/vm/executable.h index f07db36dabc2..6359da0a5375 100644 --- a/include/tvm/runtime/vm/executable.h +++ b/include/tvm/runtime/vm/executable.h @@ -68,12 +68,20 @@ class Executable : public ModuleNode { /*! * \brief Write the Executable to the binary stream in serialized form. + * + * Late-bound constants (if any) must have already been saved by \p + * MoveLateBoundConstantsToBinary. + * * \param stream The binary stream to save the executable to. */ void SaveToBinary(dmlc::Stream* stream) final; /*! - * \brief Write the Executable to the provided path as a file contianing its serialized content. + * \brief Write the Executable to the provided path as a file containing its serialized content. + * + * Late-bound constants (if any) must have already been saved by \p + * MoveLateBoundConstantsToBinary. + * * \param path The path to write the serialized data to. * \param format The format of the serialized blob. */ @@ -81,7 +89,10 @@ class Executable : public ModuleNode { /*! * \brief Serialize the executable into global section, constant section, and - * code section. + * code section. This object must outlive the returned byte array. + * + * Late-bound constants (if any) must have already been saved by \p + * MoveLateBoundConstantsToBinary. * * \return The binary representation of the VM. */ @@ -90,6 +101,8 @@ class Executable : public ModuleNode { /*! * \brief Load the saved VM executable. * + * Late-bound constants (if any) must then be loaded by \p LoadLateBoundConstantsFromBinary. + * * \param code The bytecode in string. * \param lib The compiled runtime library. * @@ -97,6 +110,35 @@ class Executable : public ModuleNode { */ static runtime::Module Load(const std::string& code, const runtime::Module lib); + /*! + * \brief Returns the late-bound constants for the executable (if any) as a byte-stream. + * Leaves the executable's late-bound constants map empty. Only constants who's byte + * tensor size is greater than or equal to \p byte_limit are marked as late-bound. \p byte_limit + * may be zero. + * + * Must be called before \p SaveToBinary and friends if late-bound constants are + * desired. Otherwise can be ignore. + */ + void MoveLateBoundConstantsToStream(dmlc::Stream* stream, size_t byte_limit); + + /*! + * \brief As for \p MoveLateBoundConstantsToStream, but save to file at \p path. + */ + void MoveLateBoundConstantsToFile(const std::string& path, size_t byte_limit); + + /*! + * \brief Restores the late-bound constants for the executable (if any) from given byte-stream. + * + * Must be called after \p Load but before any other methods if \p MoveLateBoundConstantsToBinary + * was used when saving. Otherwise can be ignored. + */ + void LoadLateBoundConstantsFromStream(dmlc::Stream* stream); + + /*! + * \brief As for \p LoadLateBoundConstantsFromStream, but load from file at \p path. + */ + void LoadLateBoundConstantsFromFile(const std::string& path); + /*! * \brief Get the serialized form of the `functions`. This is * essentially bytecode serialization. @@ -125,7 +167,7 @@ class Executable : public ModuleNode { * example, `DLDataType` will be unpacked into three fields (code, bits, lanes). * 4. The rest of the line indicates the field with variable length, e.g., * the shape of a tensor, the args used by an `InvokPacked` instruction, etc. - + * * The field starting from # is only used for debugging. The serialized code * doesn't contain it, therefore the deserializer doens't need to handle it. */ @@ -205,8 +247,19 @@ class Executable : public ModuleNode { * shape-related data and code. */ int host_device_index = -1; - /*! \brief The global constant pool. */ + /*! + * \brief The global constant array. + * + * LoadConst instructions indexes are w.r.t. this vector. Late-bound constants are removed + * from this table after saving late-bound constants. + */ std::vector constants; + /*! + * \brief For each constant index the name of the late-bound constant, or null if constant is + * immediate. Only populated after loading executable but before loading late-bound constants. + */ + std::vector late_bound_constant_names; + /*! \brief A map from globals (as strings) to their index in the Relay function map. */ std::unordered_map global_map; /*! \brief A mapping from the packed function's global name (as string) to the index that @@ -238,9 +291,16 @@ class Executable : public ModuleNode { /*! * \brief Save the constant pool. * - * \param strm The output stream. + * \param stream The output stream. + */ + void SaveConstantSection(dmlc::Stream* stream); + + /*! + * \brief Load the constant pool. + * + * \param stream The input stream. */ - void SaveConstantSection(dmlc::Stream* strm); + void LoadConstantSection(dmlc::Stream* stream); /*! * \brief Save primitive op names. @@ -270,13 +330,6 @@ class Executable : public ModuleNode { */ void LoadGlobalSection(dmlc::Stream* strm); - /*! - * \brief Load the constant pool. - * - * \param strm The input stream. - */ - void LoadConstantSection(dmlc::Stream* strm); - /*! * \brief Load primitive op names. * diff --git a/include/tvm/runtime/vm/vm.h b/include/tvm/runtime/vm/vm.h index 604c97330d99..67c21a1b479f 100644 --- a/include/tvm/runtime/vm/vm.h +++ b/include/tvm/runtime/vm/vm.h @@ -174,7 +174,7 @@ class VirtualMachine : public runtime::ModuleNode { * \brief load the executable for the virtual machine. * \param exec The executable. */ - virtual void LoadExecutable(const Executable* exec); + virtual void LoadExecutable(Executable* exec); protected: /*! \brief Push a call frame on to the call stack. */ @@ -300,7 +300,7 @@ class VirtualMachine : public runtime::ModuleNode { /*! \brief The special return register. */ ObjectRef return_register_; /*! \brief The executable the VM will operate on. */ - const Executable* exec_; + Executable* exec_; /*! \brief The function name to inputs mapping. */ std::unordered_map> inputs_; /*! diff --git a/python/tvm/runtime/vm.py b/python/tvm/runtime/vm.py index 65609ffa2b48..d9cab84a5000 100644 --- a/python/tvm/runtime/vm.py +++ b/python/tvm/runtime/vm.py @@ -77,6 +77,8 @@ def __init__(self, mod): self._get_stats = self.mod["get_stats"] self._get_function_arity = self.mod["get_function_arity"] self._get_function_param_name = self.mod["get_function_param_name"] + self._move_late_bound_consts = self.mod["move_late_bound_consts"] + self._load_late_bound_consts = self.mod["load_late_bound_consts"] def save(self): """Save the Relay VM Executable. @@ -162,11 +164,11 @@ def load_exec(bytecode, lib): An executable constructed using the provided artifacts. """ if isinstance(bytecode, (bytes, str)): - code = bytearray(bytecode) + bytecode = bytearray(bytecode) elif not isinstance(bytecode, (bytearray, TVMByteArray)): raise TypeError( "bytecode is expected to be the type of bytearray " - + "or TVMByteArray, but received {}".format(type(code)) + + "or TVMByteArray, but received {}".format(type(bytecode)) ) if lib is not None and not isinstance(lib, tvm.runtime.Module): @@ -298,6 +300,14 @@ def get_function_params(self, func_name): self._function_params[func_name] = params return params + def move_late_bound_consts(self, path, byte_limit): + """Move all constants of byte size greater or equal to byte_limit to file at path""" + return self._move_late_bound_consts(path, byte_limit) + + def load_late_bound_consts(self, path): + """Re-load constants previously saved to file at path""" + return self._load_late_bound_consts(path, bytes) + class VirtualMachine(object): """Relay VM runtime. diff --git a/src/runtime/vm/executable.cc b/src/runtime/vm/executable.cc index 44971c0bcee9..76c385ae9918 100644 --- a/src/runtime/vm/executable.cc +++ b/src/runtime/vm/executable.cc @@ -88,6 +88,19 @@ PackedFunc Executable::GetFunction(const std::string& name, const ObjectPtrLoadExecutable(this); *rv = Module(vm); }); + } else if (name == "move_late_bound_consts") { + return PackedFunc([this](TVMArgs args, TVMRetValue* rv) { + CHECK_EQ(args.size(), 2); + std::string path = args[0]; + uint64_t byte_limit = args[1]; + MoveLateBoundConstantsToFile(path, static_cast(byte_limit)); + }); + } else if (name == "load_late_bound_consts") { + return PackedFunc([this](TVMArgs args, TVMRetValue* rv) { + CHECK_EQ(args.size(), 1); + std::string path = args[0]; + LoadLateBoundConstantsFromFile(path); + }); } else { LOG(FATAL) << "Unknown packed function: " << name; return PackedFunc(nullptr); @@ -306,6 +319,68 @@ void Executable::SaveVirtualDevicesSection(dmlc::Stream* strm) { strm->Write(host_device_index); } +void Executable::MoveLateBoundConstantsToStream(dmlc::Stream* stream, size_t byte_limit) { + ICHECK(late_bound_constant_names.empty()); + late_bound_constant_names.reserve(constants.size()); + Map map; + size_t total_late_bound_bytes = 0; + for (size_t const_index = 0; const_index < constants.size(); ++const_index) { + const auto ndarray = Downcast(constants[const_index]); + ICHECK(ndarray.defined()) << "Undefined constant at index " << const_index; + size_t num_bytes = runtime::GetDataSize(*ndarray.operator->()); + if (num_bytes < byte_limit) { + // Leave as immediate. + late_bound_constant_names.emplace_back(nullptr); + continue; + } + total_late_bound_bytes += num_bytes; + std::ostringstream os; + os << "const_" << const_index; + String name = os.str(); + map.Set(name, Downcast(std::move(constants[const_index]))); + late_bound_constant_names.emplace_back(std::move(name)); + } + VLOG(1) << "moved " << map.size() << " constants of " << total_late_bound_bytes + << " bytes (out of " << constants.size() << " overall) to be late-bound"; + runtime::SaveParams(stream, map); +} + +void Executable::MoveLateBoundConstantsToFile(const std::string& path, size_t byte_limit) { + std::string bytes; + dmlc::MemoryStringStream stream(&bytes); + MoveLateBoundConstantsToStream(&stream, byte_limit); + SaveBinaryToFile(path, bytes); +} + +void Executable::LoadLateBoundConstantsFromStream(dmlc::Stream* stream) { + ICHECK_EQ(late_bound_constant_names.size(), constants.size()); + Map map = runtime::LoadParams(stream); + VLOG(1) << "loaded " << map.size() << " late-bound constants"; + for (size_t const_index = 0; const_index < constants.size(); ++const_index) { + if (!late_bound_constant_names[const_index].defined()) { + ICHECK(constants[const_index].defined()) + << "Undefined immediate constant at index " << const_index; + continue; + } + const String& name = late_bound_constant_names[const_index]; + ICHECK(!constants[const_index].defined()) << "Unexpected constant at index " << const_index; + auto itr = map.find(name); + ICHECK(itr != map.end()) << "No binding for late-bound constant at index " << const_index + << " with name '" << name << "'"; + constants[const_index] = (*itr).second; + map.erase(name); + } + late_bound_constant_names.clear(); + ICHECK(map.empty()) << "Have " << map.size() << " unused late-bound constants"; +} + +void Executable::LoadLateBoundConstantsFromFile(const std::string& path) { + std::string bytes; + LoadBinaryFromFile(path, &bytes); + dmlc::MemoryStringStream stream(&bytes); + LoadLateBoundConstantsFromStream(&stream); +} + void Executable::SaveGlobalSection(dmlc::Stream* strm) { std::vector> globals(this->global_map.begin(), this->global_map.end()); @@ -321,19 +396,88 @@ void Executable::SaveGlobalSection(dmlc::Stream* strm) { strm->Write(glbs); } -void Executable::SaveConstantSection(dmlc::Stream* strm) { - std::vector arrays; - for (const auto& obj : this->constants) { - const auto cell = Downcast(obj); - arrays.push_back(const_cast(cell.operator->())); - } - strm->Write(static_cast(this->constants.size())); - for (const auto& it : arrays) { - runtime::SaveDLTensor(strm, it); +namespace { +// Tags to distinguish immediate vs late-bound constants in constants table bytestream. +constexpr uint32_t kImmediateConstTag = 0; +constexpr uint32_t kLateBoundConstTag = 1; +} // namespace + +void Executable::SaveConstantSection(dmlc::Stream* stream) { + // Save the overall number of constants. + stream->Write(static_cast(constants.size())); + + for (size_t const_index = 0; const_index < constants.size(); ++const_index) { + if (late_bound_constant_names.empty() || !late_bound_constant_names[const_index].defined()) { + // Tag immediate constants by 0. + stream->Write(kImmediateConstTag); + // Write as DLTensor. + const auto ndarray = Downcast(constants[const_index]); + ICHECK(ndarray.defined()); + runtime::SaveDLTensor(stream, ndarray.operator->()); + VLOG(1) << "save " << const_index << " as immediate"; + } else { + // Tag late-bound constants by 1. + const String& name = late_bound_constant_names[const_index]; + ICHECK(!constants[const_index].defined()); + stream->Write(kLateBoundConstTag); + // Write a string. + stream->Write(std::string(name)); + VLOG(1) << "save " << const_index << " as late-bound"; + } } + VLOG(1) << "saved " << constants.size() << " constants"; + // Save the const to device index mapping. - strm->Write(this->const_device_indexes); + stream->Write(const_device_indexes); +} + +void Executable::LoadConstantSection(dmlc::Stream* stream) { + uint64_t sz; + // Load the overall number of constants. + STREAM_CHECK(stream->Read(&sz, sizeof(sz)), "constants table size"); + size_t size = static_cast(sz); + + VLOG(1) << "loading " << size << " constants"; + + constants.resize(size); + late_bound_constant_names.resize(size); + bool any_late_bound = false; + + // Load each of the constants. + for (size_t const_index = 0; const_index < size; const_index++) { + uint32_t tag; + STREAM_CHECK(stream->Read(&tag, sizeof(tag)), "constant tag"); + if (tag == kImmediateConstTag) { + // Immediate constants tagged by 0. + VLOG(1) << "load " << const_index << " as immediate"; + runtime::NDArray ndarray; + STREAM_CHECK(ndarray.Load(stream), "constant tensor"); + constants[const_index] = std::move(ndarray); + late_bound_constant_names[const_index] = String(ObjectPtr(nullptr)); + } else if (tag == kLateBoundConstTag) { + // Late-bound constants tagged by 1. + VLOG(1) << "load " << const_index << " as late-bound"; + std::string name; + STREAM_CHECK(stream->Read(&name), "late-bound constant name"); + constants[const_index] = NDArray(nullptr); + late_bound_constant_names[const_index] = std::move(name); + any_late_bound = true; + } else { + STREAM_CHECK(false, "constant tag"); + } + } + + if (!any_late_bound) { + late_bound_constant_names.clear(); + } + + // Load the const to device index mapping. + std::vector indexes; + indexes.reserve(size); + STREAM_CHECK(stream->Read(&indexes), "constant devices"); + ICHECK_EQ(size, indexes.size()); + const_device_indexes = std::move(indexes); } void Executable::SavePrimitiveOpNames(dmlc::Stream* strm) { @@ -597,7 +741,7 @@ runtime::Module Executable::Load(const std::string& code, const runtime::Module auto exec = make_object(); // Support null-initialization of lib, to enable initialization during - // deserialization before we have we have deserialized the imports. + // deserialization before we have deserialized the imports. if (lib.defined()) { exec->SetLib(lib); } @@ -640,27 +784,6 @@ void Executable::LoadGlobalSection(dmlc::Stream* strm) { } } -void Executable::LoadConstantSection(dmlc::Stream* strm) { - uint64_t sz; - // Load the number of constants. - STREAM_CHECK(strm->Read(&sz, sizeof(sz)), "constant"); - - size_t size = static_cast(sz); - // Load each of the constants. - for (size_t i = 0; i < size; i++) { - runtime::NDArray constant; - STREAM_CHECK(constant.Load(strm), "constant"); - this->constants.emplace_back(std::move(constant)); - } - - // Load the const to device index mapping. - std::vector const_device_indexes; - const_device_indexes.reserve(size); - STREAM_CHECK(strm->Read(&const_device_indexes), "constant"); - ICHECK_EQ(size, const_device_indexes.size()); - this->const_device_indexes = std::move(const_device_indexes); -} - void Executable::LoadPrimitiveOpNames(dmlc::Stream* strm) { std::vector primitive_names; STREAM_CHECK(strm->Read(&primitive_names), "primitive name"); diff --git a/src/runtime/vm/profiler/vm.cc b/src/runtime/vm/profiler/vm.cc index fe27595052be..67344df7dbe6 100644 --- a/src/runtime/vm/profiler/vm.cc +++ b/src/runtime/vm/profiler/vm.cc @@ -90,7 +90,7 @@ PackedFunc VirtualMachineDebug::GetFunction(const std::string& name, } } -void VirtualMachineDebug::LoadExecutable(const Executable* exec) { +void VirtualMachineDebug::LoadExecutable(Executable* exec) { VirtualMachine::LoadExecutable(exec); ICHECK(exec_); for (auto kv : exec_->primitive_map) { @@ -202,7 +202,7 @@ void VirtualMachineDebug::InvokePacked(Index packed_index, const PackedFunc& fun } } -runtime::Module CreateVirtualMachineDebug(const Executable* exec) { +runtime::Module CreateVirtualMachineDebug(Executable* exec) { auto vm = make_object(); vm->LoadExecutable(exec); return runtime::Module(vm); @@ -210,7 +210,7 @@ runtime::Module CreateVirtualMachineDebug(const Executable* exec) { TVM_REGISTER_GLOBAL("runtime._VirtualMachineDebug").set_body([](TVMArgs args, TVMRetValue* rv) { runtime::Module mod = args[0]; - const auto* exec = dynamic_cast(mod.operator->()); + auto* exec = dynamic_cast(mod.operator->()); ICHECK(exec) << "Virtual machine has not been defined yet." << "\n"; *rv = CreateVirtualMachineDebug(exec); diff --git a/src/runtime/vm/profiler/vm.h b/src/runtime/vm/profiler/vm.h index 4325fa8a7999..4a09b51fb86e 100644 --- a/src/runtime/vm/profiler/vm.h +++ b/src/runtime/vm/profiler/vm.h @@ -44,7 +44,7 @@ class VirtualMachineDebug : public VirtualMachine { PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final; - void LoadExecutable(const Executable* exec) final; + void LoadExecutable(Executable* exec) final; ~VirtualMachineDebug() {} diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index e0570226d455..acbbec0d2991 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -219,6 +219,12 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name, } else if (name == "set_input") { return PackedFunc( [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { SetInput(args[0], args, 1); }); + } else if (name == "load_late_bound_consts") { + return PackedFunc([this](TVMArgs args, TVMRetValue* rv) { + CHECK_EQ(args.size(), 1); + std::string path = args[0]; + exec_->LoadLateBoundConstantsFromFile(path); + }); } else { LOG(FATAL) << "Unknown packed function: " << name; return PackedFunc([sptr_to_self, name](TVMArgs args, TVMRetValue* rv) {}); @@ -365,8 +371,10 @@ void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func, In } } -void VirtualMachine::LoadExecutable(const Executable* exec) { +void VirtualMachine::LoadExecutable(Executable* exec) { ICHECK(exec) << "The executable is not created yet."; + ICHECK(exec->late_bound_constant_names.empty()) + << "Need to load late-bound-constants before creating VM"; exec_ = exec; runtime::Module lib = exec_->GetLib(); @@ -753,7 +761,7 @@ void VirtualMachine::RunLoop() { } } -runtime::Module CreateVirtualMachine(const Executable* exec) { +runtime::Module CreateVirtualMachine(Executable* exec) { auto vm = make_object(); vm->LoadExecutable(exec); return runtime::Module(vm); @@ -761,7 +769,7 @@ runtime::Module CreateVirtualMachine(const Executable* exec) { TVM_REGISTER_GLOBAL("runtime._VirtualMachine").set_body([](TVMArgs args, TVMRetValue* rv) { runtime::Module mod = args[0]; - const auto* exec = dynamic_cast(mod.operator->()); + auto* exec = dynamic_cast(mod.operator->()); ICHECK(exec) << "The virtual machine executable has not been defined yet."; *rv = CreateVirtualMachine(exec); }); diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py index ea1f4ddb3b62..1c60702982cc 100644 --- a/tests/python/relay/test_vm.py +++ b/tests/python/relay/test_vm.py @@ -1112,6 +1112,40 @@ def test_multi_targets(): tvm.testing.assert_allclose(actual_result.numpy(), expected_result) +def test_large_constants(): + """Large constants can be serialized outside of executable""" + target = tvm.target.Target("llvm") + dev = tvm.cpu() + + # fn(x) { add(x, ) } + x = relay.var("x", shape=(1000, 1000)) + const_data = np.random.rand(1000, 1000).astype("float32") + const = relay.const(const_data, dtype="float32") + func = relay.Function([x], relay.op.add(x, const)) + mod = tvm.IRModule.from_expr(func) + + # Compile to executable. + vm_exec = vm.compile(mod, target=target) + + # Save to constants and library files + temp = utils.tempdir() + path_consts = temp.relpath("consts") + vm_exec.move_late_bound_consts(path_consts, byte_limit=256) + path_dso = temp.relpath("lib.so") + vm_exec.mod.export_library(path_dso) + + # Load library files and constants + mod = runtime.load_module(path_dso) + mod["load_late_bound_consts"](path_consts) + + # Test main + x_data = np.random.rand(1000, 1000).astype("float32") + the_vm = runtime.vm.VirtualMachine(mod, dev) + actual = the_vm.invoke("main", x_data) + expected = x_data + const_data + tvm.testing.assert_allclose(expected, actual.numpy()) + + if __name__ == "__main__": import sys