Skip to content

Commit

Permalink
apply feedbacks
Browse files Browse the repository at this point in the history
  • Loading branch information
mehrdadh committed Apr 15, 2021
1 parent 75c1c0a commit 447aa46
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 37 deletions.
14 changes: 8 additions & 6 deletions src/runtime/crt/common/crt_runtime_api.c
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ static tvm_crt_error_t FindFunctionOrSetAPIError(tvm_module_index_t module_index
int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out) {
tvm_crt_error_t to_return =
FindFunctionOrSetAPIError(kGlobalFuncModuleIndex, &global_func_registry.registry, name, out);
// For compatibility with C++
// For compatibility with the C++ runtime equivalent, in src/runtime/registry.cc.
if (to_return == kTvmErrorFunctionNameNotFound) {
*out = NULL;
to_return = kTvmErrorNoError;
Expand Down Expand Up @@ -378,10 +378,12 @@ int TVMFuncFree(TVMFunctionHandle func) {
int RPCTimeEvaluator(TVMValue* args, int* type_codes, int num_args, TVMValue* ret_val,
int* ret_type_code);

// Sends maximum transfer size for RPC.
int RPCGetTransferMaxSize(TVMValue* args, int* type_codes, int num_args, TVMValue* ret_value,
int* ret_type_codes) {
ret_value[0].v_int64 = TVM_CRT_RPC_MAX_TRANSFER_SIZE_BYTES;
// Sends CRT max packet size.
int RPCGetCRTMaxPacketSize(TVMValue* args, int* type_codes, int num_args, TVMValue* ret_value,
int* ret_type_codes) {
// 11 bytes is for microtvm overhead:
// packet start(2), length(4), session header(3), crc(2)
ret_value[0].v_int64 = TVM_CRT_MAX_PACKET_SIZE_BYTES - 11;
ret_type_codes[0] = kTVMArgInt;
return 0;
}
Expand Down Expand Up @@ -427,7 +429,7 @@ tvm_crt_error_t TVMInitializeRuntime() {
}

if (error == kTvmErrorNoError) {
error = TVMFuncRegisterGlobal("tvm.rpc.server.GetTransferMaxSize", &RPCGetTransferMaxSize, 0);
error = TVMFuncRegisterGlobal("tvm.rpc.server.GetCRTMaxPacketSize", &RPCGetCRTMaxPacketSize, 0);
}

if (error != kTvmErrorNoError) {
Expand Down
3 changes: 0 additions & 3 deletions src/runtime/crt/host/crt_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,6 @@
/*! \brief Maximum length of a PackedFunc function name. */
#define TVM_CRT_MAX_FUNCTION_NAME_LENGTH_BYTES 30

/*! Size of the global function for max RPC transfer, in bytes. */
#define TVM_CRT_RPC_MAX_TRANSFER_SIZE_BYTES 2048

// #define TVM_CRT_FRAMER_ENABLE_LOGS

#endif // TVM_RUNTIME_CRT_HOST_CRT_CONFIG_H_
2 changes: 1 addition & 1 deletion src/runtime/crt/host/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ tvm_crt_error_t TVMPlatformGenerateRandom(uint8_t* buffer, size_t num_bytes) {
}
}

uint8_t memory[2048 * 1024];
uint8_t memory[1024 * 1024];

static char** g_argv = NULL;

Expand Down
55 changes: 30 additions & 25 deletions src/runtime/rpc/rpc_endpoint.cc
Original file line number Diff line number Diff line change
Expand Up @@ -800,14 +800,13 @@ void RPCEndpoint::CopyToRemote(void* from_bytes, DLTensor* to, uint64_t nbytes)
std::lock_guard<std::mutex> lock(mutex_);
RPCCode code = RPCCode::kCopyToRemote;

uint64_t tensor_max_size_bytes = static_cast<uint64_t>(GetDataSize(*to));
ICHECK_LE(to->byte_offset + nbytes, tensor_max_size_bytes) << "Overflow in tensor size.";
uint64_t tensor_total_size_bytes = static_cast<uint64_t>(GetDataSize(*to));
ICHECK_LE(to->byte_offset + nbytes, tensor_total_size_bytes)
<< "Overflow in tensor size: (" << to->byte_offset << ", " << nbytes << ", "
<< tensor_total_size_bytes << ")";

uint64_t to_data = reinterpret_cast<uint64_t>(static_cast<char*>(to->data) + to->byte_offset);
uint64_t shape_bytes = to->ndim * sizeof(int64_t);
uint64_t packet_nbytes = sizeof(code) + sizeof(to_data) + sizeof(to->device) + sizeof(to->ndim) +
sizeof(to->dtype) + sizeof(to->byte_offset) + shape_bytes +
sizeof(nbytes) + nbytes;
uint64_t overhead = RemoteCopyCalculatePacketOverheadSize(to, code, nbytes);
uint64_t packet_nbytes = overhead + nbytes;

handler_->Write(packet_nbytes);
handler_->Write(code);
Expand All @@ -824,11 +823,8 @@ void RPCEndpoint::CopyFromRemote(DLTensor* from, void* to_bytes, uint64_t nbytes
uint64_t num_data_bytes = static_cast<uint64_t>(GetDataSize(*from));
CHECK_EQ(nbytes, num_data_bytes);

uint64_t from_data = reinterpret_cast<uint64_t>(from->data);
uint64_t shape_bytes = from->ndim * sizeof(int64_t);
uint64_t packet_nbytes = sizeof(code) + sizeof(from_data) + sizeof(from->device) +
sizeof(from->ndim) + sizeof(from->dtype) + sizeof(from->byte_offset) +
shape_bytes + sizeof(nbytes);
uint64_t overhead = RemoteCopyCalculatePacketOverheadSize(from, code, nbytes);
uint64_t packet_nbytes = overhead;

handler_->Write(packet_nbytes);
handler_->Write(code);
Expand Down Expand Up @@ -967,10 +963,7 @@ class RPCClientSession : public RPCSession, public DeviceAPI {
/*!
* \brief param endpoint The client endpoint of the session.
*/
explicit RPCClientSession(std::shared_ptr<RPCEndpoint> endpoint) : endpoint_(endpoint) {
// update max transfer size if not set already.
SetRPCMaxTransferSize();
}
explicit RPCClientSession(std::shared_ptr<RPCEndpoint> endpoint) : endpoint_(endpoint) {}

// function overrides
PackedFuncHandle GetFunction(const std::string& name) final {
Expand All @@ -983,7 +976,9 @@ class RPCClientSession : public RPCSession, public DeviceAPI {
}

void CopyToRemote(void* local_from_bytes, DLTensor* remote_to, uint64_t nbytes) final {
uint64_t block_size = (uint64_t)rpc_chunk_max_size_bytes_;
RPCCode code = RPCCode::kCopyToRemote;
uint64_t overhead = RemoteCopyCalculatePacketOverheadSize(remote_to, code, nbytes);
const uint64_t block_size = GetRPCMaxTransferSize() - overhead;
uint64_t block_count = 0;
uint64_t num_blocks = nbytes / block_size;

Expand Down Expand Up @@ -1059,26 +1054,36 @@ class RPCClientSession : public RPCSession, public DeviceAPI {
private:
void RPCMaxTransferRemoteReturnValue(TVMArgs args) {
// Use args[1] as return value, args[0] is tcode
rpc_chunk_max_size_bytes_ = (int64_t)args[1];
rpc_chunk_max_size_bytes_ = (uint64_t)args[1];
}

void SetRPCMaxTransferSize() {
PackedFuncHandle rpc_func = GetFunction("tvm.rpc.server.GetTransferMaxSize");
uint64_t GetRPCMaxTransferSize() {
PackedFuncHandle rpc_func = GetFunction("tvm.rpc.server.GetCRTMaxPacketSize");
if (rpc_func == nullptr) {
rpc_chunk_max_size_bytes_ = kRPCMaxTransferSizeDefault;
return;
rpc_chunk_max_size_bytes_ = kRPCMaxTransferSizeBytesDefault;
} else {
CallFunc(rpc_func, nullptr, nullptr, 0,
[this](TVMArgs args) { RPCMaxTransferRemoteReturnValue(args); });
}
CallFunc(rpc_func, nullptr, nullptr, 0,
[this](TVMArgs args) { RPCMaxTransferRemoteReturnValue(args); });
return rpc_chunk_max_size_bytes_;
}

std::shared_ptr<RPCEndpoint> endpoint_;
int64_t rpc_chunk_max_size_bytes_;
uint64_t rpc_chunk_max_size_bytes_;
};

std::shared_ptr<RPCSession> CreateClientSession(std::shared_ptr<RPCEndpoint> endpoint) {
return std::make_shared<RPCClientSession>(endpoint);
}

uint64_t RemoteCopyCalculatePacketOverheadSize(DLTensor* tensor, RPCCode code, uint64_t nbytes) {
uint64_t shape_bytes = tensor->ndim * sizeof(int64_t);
uint64_t to_data = reinterpret_cast<uint64_t>(static_cast<uint8_t*>(tensor->data));
uint64_t overhead = sizeof(code) + sizeof(to_data) + sizeof(tensor->device) +
sizeof(tensor->ndim) + sizeof(tensor->dtype) + sizeof(tensor->byte_offset) +
shape_bytes + sizeof(nbytes);
return overhead;
}

} // namespace runtime
} // namespace tvm
13 changes: 11 additions & 2 deletions src/runtime/rpc/rpc_endpoint.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ const int kRPCSuccess = kRPCMagic + 0;
// cannot found matched key in server
const int kRPCMismatch = kRPCMagic + 2;

// When tvm.rpc.server.GetTransferMaxSize global function is not registered.
const int kRPCMaxTransferSizeDefault = 128000;
// When tvm.rpc.server.GetCRTMaxPacketSize global function is not registered.
const uint64_t kRPCMaxTransferSizeBytesDefault = 128 * 1024;

/*! \brief Enumeration code for the RPC tracker */
enum class TrackerCode : int {
Expand Down Expand Up @@ -207,6 +207,15 @@ template <typename... Args>
inline TVMRetValue RPCEndpoint::SysCallRemote(RPCCode code, Args&&... args) {
return syscall_remote_(static_cast<int>(code), std::forward<Args>(args)...);
}

/*!
* \brief Calculates overhead size of a CopyToRemote packet.
* \param to DLTensor on target to copy.
* \param code RPCCode for this transfer.
* \param nbytes Number of bytes to transfer.
*/
uint64_t RemoteCopyCalculatePacketOverheadSize(DLTensor* tensor, RPCCode code, uint64_t nbytes);

} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_RPC_RPC_ENDPOINT_H_

0 comments on commit 447aa46

Please sign in to comment.