From b51d0f83a7e071d698019e4803390b0eaa290fc0 Mon Sep 17 00:00:00 2001 From: Mehrdad Hessar Date: Wed, 14 Apr 2021 17:15:34 -0700 Subject: [PATCH] apply feedbacks --- src/runtime/crt/common/crt_runtime_api.c | 14 +++--- src/runtime/crt/host/crt_config.h | 3 -- src/runtime/crt/host/main.cc | 2 +- src/runtime/rpc/rpc_endpoint.cc | 55 +++++++++++++----------- src/runtime/rpc/rpc_endpoint.h | 13 +++++- 5 files changed, 50 insertions(+), 37 deletions(-) diff --git a/src/runtime/crt/common/crt_runtime_api.c b/src/runtime/crt/common/crt_runtime_api.c index c53c8cad8119..93d694e5e81c 100644 --- a/src/runtime/crt/common/crt_runtime_api.c +++ b/src/runtime/crt/common/crt_runtime_api.c @@ -300,7 +300,7 @@ static tvm_crt_error_t FindFunctionOrSetAPIError(tvm_module_index_t module_index int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out) { tvm_crt_error_t to_return = FindFunctionOrSetAPIError(kGlobalFuncModuleIndex, &global_func_registry.registry, name, out); - // For compatibility with C++ + // For compatibility with the C++ runtime equivalent, in src/runtime/registry.cc. if (to_return == kTvmErrorFunctionNameNotFound) { *out = NULL; to_return = kTvmErrorNoError; @@ -378,10 +378,12 @@ int TVMFuncFree(TVMFunctionHandle func) { int RPCTimeEvaluator(TVMValue* args, int* type_codes, int num_args, TVMValue* ret_val, int* ret_type_code); -// Sends maximum transfer size for RPC. -int RPCGetTransferMaxSize(TVMValue* args, int* type_codes, int num_args, TVMValue* ret_value, - int* ret_type_codes) { - ret_value[0].v_int64 = TVM_CRT_RPC_MAX_TRANSFER_SIZE_BYTES; +// Sends CRT max packet size. +int RPCGetCRTMaxPacketSize(TVMValue* args, int* type_codes, int num_args, TVMValue* ret_value, + int* ret_type_codes) { + // 11 bytes is for microtvm overhead: + // packet start(2), length(4), session header(3), crc(2) + ret_value[0].v_int64 = TVM_CRT_MAX_PACKET_SIZE_BYTES - 11; ret_type_codes[0] = kTVMArgInt; return 0; } @@ -427,7 +429,7 @@ tvm_crt_error_t TVMInitializeRuntime() { } if (error == kTvmErrorNoError) { - error = TVMFuncRegisterGlobal("tvm.rpc.server.GetTransferMaxSize", &RPCGetTransferMaxSize, 0); + error = TVMFuncRegisterGlobal("tvm.rpc.server.GetCRTMaxPacketSize", &RPCGetCRTMaxPacketSize, 0); } if (error != kTvmErrorNoError) { diff --git a/src/runtime/crt/host/crt_config.h b/src/runtime/crt/host/crt_config.h index e6987d96bb84..1644d3251057 100644 --- a/src/runtime/crt/host/crt_config.h +++ b/src/runtime/crt/host/crt_config.h @@ -51,9 +51,6 @@ /*! \brief Maximum length of a PackedFunc function name. */ #define TVM_CRT_MAX_FUNCTION_NAME_LENGTH_BYTES 30 -/*! Size of the global function for max RPC transfer, in bytes. */ -#define TVM_CRT_RPC_MAX_TRANSFER_SIZE_BYTES 2048 - // #define TVM_CRT_FRAMER_ENABLE_LOGS #endif // TVM_RUNTIME_CRT_HOST_CRT_CONFIG_H_ diff --git a/src/runtime/crt/host/main.cc b/src/runtime/crt/host/main.cc index c56d3fb3768a..07bc6d15afc8 100644 --- a/src/runtime/crt/host/main.cc +++ b/src/runtime/crt/host/main.cc @@ -110,7 +110,7 @@ tvm_crt_error_t TVMPlatformGenerateRandom(uint8_t* buffer, size_t num_bytes) { } } -uint8_t memory[2048 * 1024]; +uint8_t memory[1024 * 1024]; static char** g_argv = NULL; diff --git a/src/runtime/rpc/rpc_endpoint.cc b/src/runtime/rpc/rpc_endpoint.cc index 48e403384f33..40db8e33c2e8 100644 --- a/src/runtime/rpc/rpc_endpoint.cc +++ b/src/runtime/rpc/rpc_endpoint.cc @@ -800,14 +800,13 @@ void RPCEndpoint::CopyToRemote(void* from_bytes, DLTensor* to, uint64_t nbytes) std::lock_guard lock(mutex_); RPCCode code = RPCCode::kCopyToRemote; - uint64_t tensor_max_size_bytes = static_cast(GetDataSize(*to)); - ICHECK_LE(to->byte_offset + nbytes, tensor_max_size_bytes) << "Overflow in tensor size."; + uint64_t tensor_total_size_bytes = static_cast(GetDataSize(*to)); + ICHECK_LE(to->byte_offset + nbytes, tensor_total_size_bytes) + << "Overflow in tensor size: (" << to->byte_offset << ", " << nbytes << ", " + << tensor_total_size_bytes << ")"; - uint64_t to_data = reinterpret_cast(static_cast(to->data) + to->byte_offset); - uint64_t shape_bytes = to->ndim * sizeof(int64_t); - uint64_t packet_nbytes = sizeof(code) + sizeof(to_data) + sizeof(to->device) + sizeof(to->ndim) + - sizeof(to->dtype) + sizeof(to->byte_offset) + shape_bytes + - sizeof(nbytes) + nbytes; + uint64_t overhead = RemoteCopyCalculatePacketOverheadSize(to, code, nbytes); + uint64_t packet_nbytes = overhead + nbytes; handler_->Write(packet_nbytes); handler_->Write(code); @@ -824,11 +823,8 @@ void RPCEndpoint::CopyFromRemote(DLTensor* from, void* to_bytes, uint64_t nbytes uint64_t num_data_bytes = static_cast(GetDataSize(*from)); CHECK_EQ(nbytes, num_data_bytes); - uint64_t from_data = reinterpret_cast(from->data); - uint64_t shape_bytes = from->ndim * sizeof(int64_t); - uint64_t packet_nbytes = sizeof(code) + sizeof(from_data) + sizeof(from->device) + - sizeof(from->ndim) + sizeof(from->dtype) + sizeof(from->byte_offset) + - shape_bytes + sizeof(nbytes); + uint64_t overhead = RemoteCopyCalculatePacketOverheadSize(from, code, nbytes); + uint64_t packet_nbytes = overhead; handler_->Write(packet_nbytes); handler_->Write(code); @@ -967,10 +963,7 @@ class RPCClientSession : public RPCSession, public DeviceAPI { /*! * \brief param endpoint The client endpoint of the session. */ - explicit RPCClientSession(std::shared_ptr endpoint) : endpoint_(endpoint) { - // update max transfer size if not set already. - SetRPCMaxTransferSize(); - } + explicit RPCClientSession(std::shared_ptr endpoint) : endpoint_(endpoint) {} // function overrides PackedFuncHandle GetFunction(const std::string& name) final { @@ -983,7 +976,9 @@ class RPCClientSession : public RPCSession, public DeviceAPI { } void CopyToRemote(void* local_from_bytes, DLTensor* remote_to, uint64_t nbytes) final { - uint64_t block_size = (uint64_t)rpc_chunk_max_size_bytes_; + RPCCode code = RPCCode::kCopyToRemote; + uint64_t overhead = RemoteCopyCalculatePacketOverheadSize(remote_to, code, nbytes); + const uint64_t block_size = GetRPCMaxTransferSize() - overhead; uint64_t block_count = 0; uint64_t num_blocks = nbytes / block_size; @@ -1059,26 +1054,36 @@ class RPCClientSession : public RPCSession, public DeviceAPI { private: void RPCMaxTransferRemoteReturnValue(TVMArgs args) { // Use args[1] as return value, args[0] is tcode - rpc_chunk_max_size_bytes_ = (int64_t)args[1]; + rpc_chunk_max_size_bytes_ = (uint64_t)args[1]; } - void SetRPCMaxTransferSize() { - PackedFuncHandle rpc_func = GetFunction("tvm.rpc.server.GetTransferMaxSize"); + uint64_t GetRPCMaxTransferSize() { + PackedFuncHandle rpc_func = GetFunction("tvm.rpc.server.GetCRTMaxPacketSize"); if (rpc_func == nullptr) { - rpc_chunk_max_size_bytes_ = kRPCMaxTransferSizeDefault; - return; + rpc_chunk_max_size_bytes_ = kRPCMaxTransferSizeBytesDefault; + } else { + CallFunc(rpc_func, nullptr, nullptr, 0, + [this](TVMArgs args) { RPCMaxTransferRemoteReturnValue(args); }); } - CallFunc(rpc_func, nullptr, nullptr, 0, - [this](TVMArgs args) { RPCMaxTransferRemoteReturnValue(args); }); + return rpc_chunk_max_size_bytes_; } std::shared_ptr endpoint_; - int64_t rpc_chunk_max_size_bytes_; + uint64_t rpc_chunk_max_size_bytes_; }; std::shared_ptr CreateClientSession(std::shared_ptr endpoint) { return std::make_shared(endpoint); } +uint64_t RemoteCopyCalculatePacketOverheadSize(DLTensor* tensor, RPCCode code, uint64_t nbytes) { + uint64_t shape_bytes = tensor->ndim * sizeof(int64_t); + uint64_t to_data = reinterpret_cast(static_cast(tensor->data)); + uint64_t overhead = sizeof(code) + sizeof(to_data) + sizeof(tensor->device) + + sizeof(tensor->ndim) + sizeof(tensor->dtype) + sizeof(tensor->byte_offset) + + shape_bytes + sizeof(nbytes); + return overhead; +} + } // namespace runtime } // namespace tvm diff --git a/src/runtime/rpc/rpc_endpoint.h b/src/runtime/rpc/rpc_endpoint.h index 1fcdcf6400ac..32a225f4dbd5 100644 --- a/src/runtime/rpc/rpc_endpoint.h +++ b/src/runtime/rpc/rpc_endpoint.h @@ -48,8 +48,8 @@ const int kRPCSuccess = kRPCMagic + 0; // cannot found matched key in server const int kRPCMismatch = kRPCMagic + 2; -// When tvm.rpc.server.GetTransferMaxSize global function is not registered. -const int kRPCMaxTransferSizeDefault = 128000; +// When tvm.rpc.server.GetCRTMaxPacketSize global function is not registered. +const uint64_t kRPCMaxTransferSizeBytesDefault = 128 * 1024; /*! \brief Enumeration code for the RPC tracker */ enum class TrackerCode : int { @@ -207,6 +207,15 @@ template inline TVMRetValue RPCEndpoint::SysCallRemote(RPCCode code, Args&&... args) { return syscall_remote_(static_cast(code), std::forward(args)...); } + +/*! + * \brief Calculates overhead size of a CopyToRemote packet. + * \param to DLTensor to copy. + * \param code RPCCode for this transfer. + * \param nbytes Number of bytes to transfer. + */ +uint64_t RemoteCopyCalculatePacketOverheadSize(DLTensor* tensor, RPCCode code, uint64_t nbytes); + } // namespace runtime } // namespace tvm #endif // TVM_RUNTIME_RPC_RPC_ENDPOINT_H_