Skip to content

Commit

Permalink
Consolidate RPC Context helper functions (apache#6915)
Browse files Browse the repository at this point in the history
  • Loading branch information
areusch authored and Trevor Morris committed Dec 4, 2020
1 parent c5aebfa commit 0584634
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 39 deletions.
52 changes: 47 additions & 5 deletions include/tvm/runtime/device_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -240,13 +240,55 @@ inline const char* DeviceName(int type) {
}
}

/*!
* \brief Return true if a TVMContext is owned by an RPC session.
*/
inline bool IsRPCSessionContext(TVMContext ctx) { return (ctx.device_type / kRPCSessMask) > 0; }

/*!
* \brief Return the RPCSessTable index of the RPC Session that owns this context.
* \return the table index.
*/
inline int GetRPCSessionIndex(TVMContext ctx) {
ICHECK(IsRPCSessionContext(ctx)) << "GetRPCSessionIndex: ctx has no RPC session";
return ctx.device_type / kRPCSessMask - 1;
}

/*!
* \brief Remove the RPC session mask from a TVMContext.
* RPC clients typically do this when encoding a TVMContext for transmission to an RPC remote.
* On the wire, RPCContext are expected to be valid on the server without interpretation.
* \param ctx A TVMContext with non-zero RPC Session mask, valid on the RPC client.
* \return A TVMContext without any RPC Session mask, valid on the RPC server.
*/
inline TVMContext RemoveRPCSessionMask(TVMContext ctx) {
ctx.device_type = static_cast<DLDeviceType>(ctx.device_type % kRPCSessMask);
return ctx;
}

inline std::ostream& operator<<(std::ostream& os, DLContext ctx);

/*!
* \brief Add a RPC session mask to a TVMContext.
* RPC clients typically do this when decoding a TVMContext received from a RPC remote.
* \param ctx A TVMContext without any RPC Session mask, valid on the RPC server.
* \param session_table_index Numeric index of the RPC session in the session table.
* \return A TVMContext with RPC session mask added, valid on the RPC client.
*/
inline TVMContext AddRPCSessionMask(TVMContext ctx, int session_table_index) {
CHECK(!IsRPCSessionContext(ctx))
<< "AddRPCSessionMask: ctx already non-zero RPCSessionIndex: " << ctx;
ctx.device_type =
static_cast<DLDeviceType>(ctx.device_type | (kRPCSessMask * (session_table_index + 1)));
return ctx;
}

inline std::ostream& operator<<(std::ostream& os, DLContext ctx) { // NOLINT(*)
int device_type = static_cast<int>(ctx.device_type);
if (device_type > kRPCSessMask) {
os << "remote[" << (device_type / kRPCSessMask) << "]-";
device_type = device_type % kRPCSessMask;
if (IsRPCSessionContext(ctx)) {
os << "remote[" << GetRPCSessionIndex(ctx) << "]-";
ctx = RemoveRPCSessionMask(ctx);
}
os << runtime::DeviceName(device_type) << "(" << ctx.device_id << ")";
os << runtime::DeviceName(static_cast<int>(ctx.device_type)) << "(" << ctx.device_id << ")";
return os;
}
} // namespace runtime
Expand Down
35 changes: 13 additions & 22 deletions src/runtime/rpc/rpc_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,19 +34,19 @@ namespace runtime {
class RPCDeviceAPI final : public DeviceAPI {
public:
void SetDevice(TVMContext ctx) final {
auto remote_ctx = RemoveSessMask(ctx);
auto remote_ctx = RemoveRPCSessionMask(ctx);
GetSess(ctx)->GetDeviceAPI(remote_ctx)->SetDevice(remote_ctx);
}

void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final {
auto remote_ctx = RemoveSessMask(ctx);
auto remote_ctx = RemoveRPCSessionMask(ctx);
GetSess(ctx)->GetDeviceAPI(remote_ctx)->GetAttr(remote_ctx, kind, rv);
}

void* AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment,
DLDataType type_hint) final {
auto sess = GetSess(ctx);
auto remote_ctx = RemoveSessMask(ctx);
auto remote_ctx = RemoveRPCSessionMask(ctx);
void* data =
sess->GetDeviceAPI(remote_ctx)->AllocDataSpace(remote_ctx, nbytes, alignment, type_hint);

Expand All @@ -57,7 +57,7 @@ class RPCDeviceAPI final : public DeviceAPI {
}
void FreeDataSpace(TVMContext ctx, void* ptr) final {
RemoteSpace* space = static_cast<RemoteSpace*>(ptr);
auto remote_ctx = RemoveSessMask(ctx);
auto remote_ctx = RemoveRPCSessionMask(ctx);
try {
GetSess(ctx)->GetDeviceAPI(remote_ctx)->FreeDataSpace(remote_ctx, space->data);
} catch (const dmlc::Error& e) {
Expand All @@ -68,26 +68,24 @@ class RPCDeviceAPI final : public DeviceAPI {
void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size,
TVMContext ctx_from, TVMContext ctx_to, DLDataType type_hint,
TVMStreamHandle stream) final {
int from_dev_type = ctx_from.device_type;
int to_dev_type = ctx_to.device_type;
if (from_dev_type > kRPCSessMask && to_dev_type > kRPCSessMask) {
if (IsRPCSessionContext(ctx_from) && IsRPCSessionContext(ctx_to)) {
ICHECK(ctx_from.device_type == ctx_to.device_type)
<< "Cannot copy across two different remote session";
auto remote_ctx_from = RemoveSessMask(ctx_from);
auto remote_ctx_to = RemoveSessMask(ctx_to);
auto remote_ctx_from = RemoveRPCSessionMask(ctx_from);
auto remote_ctx_to = RemoveRPCSessionMask(ctx_to);
auto remote_ctx = remote_ctx_from;
if (remote_ctx.device_type == kDLCPU) remote_ctx = remote_ctx_to;
GetSess(ctx_from)
->GetDeviceAPI(remote_ctx)
->CopyDataFromTo(static_cast<const RemoteSpace*>(from)->data, from_offset,
static_cast<const RemoteSpace*>(to)->data, to_offset, size,
remote_ctx_from, remote_ctx_to, type_hint, stream);
} else if (from_dev_type > kRPCSessMask && to_dev_type == kDLCPU) {
auto remote_ctx_from = RemoveSessMask(ctx_from);
} else if (IsRPCSessionContext(ctx_from) && ctx_to.device_type == kDLCPU) {
auto remote_ctx_from = RemoveRPCSessionMask(ctx_from);
GetSess(ctx_from)->CopyFromRemote(static_cast<const RemoteSpace*>(from)->data, from_offset,
to, to_offset, size, remote_ctx_from, type_hint);
} else if (from_dev_type == kDLCPU && to_dev_type > kRPCSessMask) {
auto remote_ctx_to = RemoveSessMask(ctx_to);
} else if (ctx_from.device_type == kDLCPU && IsRPCSessionContext(ctx_to)) {
auto remote_ctx_to = RemoveRPCSessionMask(ctx_to);
GetSess(ctx_to)->CopyToRemote(const_cast<void*>(from), from_offset,
static_cast<const RemoteSpace*>(to)->data, to_offset, size,
remote_ctx_to, type_hint);
Expand All @@ -97,22 +95,15 @@ class RPCDeviceAPI final : public DeviceAPI {
}

void StreamSync(TVMContext ctx, TVMStreamHandle stream) final {
auto remote_ctx = RemoveSessMask(ctx);
auto remote_ctx = RemoveRPCSessionMask(ctx);
GetSess(ctx)->GetDeviceAPI(remote_ctx)->StreamSync(remote_ctx, stream);
}

private:
std::shared_ptr<RPCSession> GetSess(TVMContext ctx) {
int dev_type = ctx.device_type;
ICHECK_GE(dev_type, kRPCSessMask);
int tbl_index = dev_type / kRPCSessMask - 1;
int tbl_index = GetRPCSessionIndex(ctx);
return RPCSession::Get(tbl_index);
}

static TVMContext RemoveSessMask(TVMContext ctx) {
ctx.device_type = static_cast<DLDeviceType>(ctx.device_type % kRPCSessMask);
return ctx;
}
};

TVM_REGISTER_GLOBAL("device_api.rpc").set_body([](TVMArgs args, TVMRetValue* rv) {
Expand Down
2 changes: 1 addition & 1 deletion src/runtime/rpc/rpc_endpoint.cc
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ class RPCEndpoint::EventHandler : public dmlc::Stream {
<< args[i].AsObjectRef<ObjectRef>()->GetTypeKey() << " is not supported by RPC";
} else if (tcode == kTVMContext) {
DLContext ctx = args[i];
ICHECK_LT(static_cast<int>(ctx.device_type), kRPCSessMask)
ICHECK(!IsRPCSessionContext(ctx))
<< "InternalError: cannot pass RPC context in the channel";
}
}
Expand Down
19 changes: 8 additions & 11 deletions src/runtime/rpc/rpc_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -108,11 +108,10 @@ class RPCWrappedFunc : public Object {

// remove a remote session mask
TVMContext RemoveSessMask(TVMContext ctx) const {
int dev_type = ctx.device_type;
ICHECK_EQ(dev_type / kRPCSessMask, sess_->table_index() + 1)
<< "Can not pass in local context or context with a different remote session";
ctx.device_type = static_cast<DLDeviceType>(ctx.device_type % kRPCSessMask);
return ctx;
ICHECK(IsRPCSessionContext(ctx)) << "Can not pass in local context";
ICHECK_EQ(GetRPCSessionIndex(ctx), sess_->table_index())
<< "Can not pass in context with a different remote session";
return RemoveRPCSessionMask(ctx);
}

// deleter of RPC remote array
Expand Down Expand Up @@ -141,13 +140,12 @@ class RPCWrappedFunc : public Object {
// setup dtype
data->dl_tensor.dtype = tensor->dtype;
// setup ctx, encode as remote session
data->dl_tensor.ctx.device_id = tensor->ctx.device_id;
data->dl_tensor.ctx.device_type = static_cast<DLDeviceType>(
static_cast<int>(tensor->ctx.device_type) + kRPCSessMask * (sess_->table_index() + 1));
data->dl_tensor.ctx = AddRPCSessionMask(tensor->ctx, sess_->table_index());
// check strides.
ICHECK(tensor->strides == nullptr);
// setup byteoffset
data->dl_tensor.byte_offset = tensor->byte_offset;

return ret;
}
};
Expand Down Expand Up @@ -189,10 +187,9 @@ class RPCModuleNode final : public ModuleNode {
int min_repeat_ms, const std::string& f_preproc_name) {
InitRemoteFunc(&remote_get_time_evaluator_, "runtime.RPCTimeEvaluator");
// Remove session mask because we pass ctx by parts.
int dev_type = ctx.device_type;
ICHECK_EQ(dev_type / kRPCSessMask, sess_->table_index() + 1)
ICHECK_EQ(GetRPCSessionIndex(ctx), sess_->table_index())
<< "ValueError: Need to pass the matched remote context to RPCModule.GetTimeEvaluator";
ctx.device_type = static_cast<DLDeviceType>(ctx.device_type % kRPCSessMask);
ctx = RemoveRPCSessionMask(ctx);

if (module_handle_ != nullptr) {
return remote_get_time_evaluator_(GetRef<Module>(this), name,
Expand Down

0 comments on commit 0584634

Please sign in to comment.