Skip to content

Commit

Permalink
move timeout management logic out of preSendValidation method
Browse files Browse the repository at this point in the history
Summary:
in this diff, we move timeout management logic out of `preSendValidation` method

the logic to compute clientTimeout is pulled to `getClientTimeout`
the logic to skip setting timeouts in RpcMetadata if `rpcOptions.getClientOnlyTimeouts()` is set is moved to `makeRequestRpcMetadata`

Reviewed By: robertroeser

Differential Revision: D63304569

fbshipit-source-id: 5dfec5496a64e6b9e133ae4bad26ecb7820e3db7
  • Loading branch information
avalonalex authored and facebook-github-bot committed Sep 25, 2024
1 parent 5fcf01b commit f10a1a1
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 34 deletions.
45 changes: 23 additions & 22 deletions thrift/lib/cpp2/async/RocketClientChannel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -857,17 +857,17 @@ void RocketClientChannel::sendRequestStream(
return;
}
preprocessHeader(header.get());
auto firstResponseTimeout = getClientTimeout(rpcOptions);

auto metadata = apache::thrift::detail::makeRequestRpcMetadata(
rpcOptions,
RpcKind::SINGLE_REQUEST_STREAMING_RESPONSE,
static_cast<ProtocolId>(header->getProtocolId()),
methodMetadata.name_managed(),
timeout_,
firstResponseTimeout,
*header);

std::chrono::milliseconds firstResponseTimeout;
preSendValidation(metadata, rpcOptions, firstResponseTimeout);
preSendValidation(metadata, rpcOptions);

auto buf = std::move(request.buffer);
setCompression(metadata, buf->computeChainDataLength());
Expand All @@ -880,7 +880,7 @@ void RocketClientChannel::sendRequestStream(
assert(metadata.name_ref());
return rocket::RocketClient::sendRequestStream(
std::move(payload),
firstResponseTimeout,
firstResponseTimeout.value_or(std::chrono::milliseconds::zero()),
rpcOptions.getChunkTimeout(),
rpcOptions.getChunkBufferSize(),
new FirstRequestProcessorStream(
Expand All @@ -902,16 +902,16 @@ void RocketClientChannel::sendRequestSink(
}
preprocessHeader(header.get());

auto firstResponseTimeout = getClientTimeout(rpcOptions);
auto metadata = apache::thrift::detail::makeRequestRpcMetadata(
rpcOptions,
RpcKind::SINK,
static_cast<ProtocolId>(header->getProtocolId()),
methodMetadata.name_managed(),
timeout_,
firstResponseTimeout,
*header);

std::chrono::milliseconds firstResponseTimeout;
preSendValidation(metadata, rpcOptions, firstResponseTimeout);
preSendValidation(metadata, rpcOptions);

auto buf = std::move(request.buffer);
setCompression(metadata, buf->computeChainDataLength());
Expand All @@ -924,7 +924,7 @@ void RocketClientChannel::sendRequestSink(
assert(metadata.name_ref());
return rocket::RocketClient::sendRequestSink(
std::move(payload),
firstResponseTimeout,
firstResponseTimeout.value_or(std::chrono::milliseconds(0)),
new FirstRequestProcessorSink(
header->getProtocolId(),
std::move(*metadata.name_ref()),
Expand All @@ -947,17 +947,17 @@ void RocketClientChannel::sendThriftRequest(
}
preprocessHeader(header.get());

auto timeout = getClientTimeout(rpcOptions);
auto metadata = apache::thrift::detail::makeRequestRpcMetadata(
rpcOptions,
kind,
static_cast<ProtocolId>(header->getProtocolId()),
std::move(methodName),
timeout_,
timeout,
*header);
header.reset();

std::chrono::milliseconds timeout;
preSendValidation(metadata, rpcOptions, timeout);
preSendValidation(metadata, rpcOptions);

auto buf = std::move(request.buffer);
setCompression(metadata, buf->computeChainDataLength());
Expand All @@ -975,7 +975,7 @@ void RocketClientChannel::sendThriftRequest(
sendSingleRequestSingleResponse(
rpcOptions,
std::move(metadata),
timeout,
timeout.value_or(std::chrono::milliseconds(0)),
std::move(buf),
std::move(clientCallback));
break;
Expand Down Expand Up @@ -1093,19 +1093,20 @@ bool RocketClientChannel::canHandleRequest(CallbackPtr& cb) {
return true;
}

std::optional<std::chrono::milliseconds> RocketClientChannel::getClientTimeout(
const RpcOptions& rpcOptions) const {
if (rpcOptions.getTimeout() > std::chrono::milliseconds::zero()) {
return rpcOptions.getTimeout();
} else if (timeout_ > std::chrono::milliseconds::zero()) {
return timeout_;
}
return std::nullopt;
}

void RocketClientChannel::preSendValidation(
RequestRpcMetadata& metadata,
const RpcOptions& rpcOptions,
std::chrono::milliseconds& firstResponseTimeout) {
RequestRpcMetadata& metadata, const RpcOptions& rpcOptions) {
DCHECK(metadata.kind_ref().has_value());

firstResponseTimeout =
std::chrono::milliseconds(metadata.clientTimeoutMs_ref().value_or(0));
if (rpcOptions.getClientOnlyTimeouts()) {
metadata.clientTimeoutMs_ref().reset();
metadata.queueTimeoutMs_ref().reset();
}

if (auto interactionId = rpcOptions.getInteractionId()) {
evb_->dcheckIsInEventBaseThread();
if (auto* name = folly::get_ptr(pendingInteractions_, interactionId)) {
Expand Down
7 changes: 4 additions & 3 deletions thrift/lib/cpp2/async/RocketClientChannel.h
Original file line number Diff line number Diff line change
Expand Up @@ -181,13 +181,14 @@ class RocketClientChannel final : public ClientChannel,
std::unique_ptr<folly::IOBuf> buf,
RequestClientCallback::Ptr cb);

std::optional<std::chrono::milliseconds> getClientTimeout(
const RpcOptions& rpcOptions) const;

template <typename CallbackPtr>
bool canHandleRequest(CallbackPtr& cb);

void preSendValidation(
RequestRpcMetadata& metadata,
const RpcOptions& rpcOptions,
std::chrono::milliseconds& firstResponseTimeout);
RequestRpcMetadata& metadata, const RpcOptions& rpcOptions);

rocket::SetupFrame makeSetupFrame(RequestSetupMetadata meta);

Expand Down
18 changes: 10 additions & 8 deletions thrift/lib/cpp2/transport/core/RpcMetadataUtil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,20 +38,22 @@ RequestRpcMetadata makeRequestRpcMetadata(
RpcKind kind,
ProtocolId protocolId,
ManagedStringView&& methodName,
std::chrono::milliseconds defaultChannelTimeout,
std::optional<std::chrono::milliseconds> clientTimeout,
transport::THeader& header) {
RequestRpcMetadata metadata;
metadata.protocol_ref() = protocolId;
metadata.kind_ref() = kind;
metadata.name_ref() = ManagedStringViewWithConversions(std::move(methodName));
if (rpcOptions.getTimeout() > std::chrono::milliseconds::zero()) {
metadata.clientTimeoutMs_ref() = rpcOptions.getTimeout().count();
} else if (defaultChannelTimeout > std::chrono::milliseconds::zero()) {
metadata.clientTimeoutMs_ref() = defaultChannelTimeout.count();
}
if (rpcOptions.getQueueTimeout() > std::chrono::milliseconds::zero()) {
metadata.queueTimeoutMs_ref() = rpcOptions.getQueueTimeout().count();

if (!rpcOptions.getClientOnlyTimeouts()) {
if (clientTimeout.has_value()) {
metadata.clientTimeoutMs_ref() = clientTimeout->count();
}
if (rpcOptions.getQueueTimeout() > std::chrono::milliseconds::zero()) {
metadata.queueTimeoutMs_ref() = rpcOptions.getQueueTimeout().count();
}
}

if (rpcOptions.getPriority() < concurrency::N_PRIORITIES) {
metadata.priority_ref() =
static_cast<RpcPriority>(rpcOptions.getPriority());
Expand Down
2 changes: 1 addition & 1 deletion thrift/lib/cpp2/transport/core/RpcMetadataUtil.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ RequestRpcMetadata makeRequestRpcMetadata(
RpcKind kind,
ProtocolId protocolId,
ManagedStringView&& methodName,
std::chrono::milliseconds defaultChannelTimeout,
std::optional<std::chrono::milliseconds> clientTimeout,
transport::THeader& header);

void fillTHeaderFromResponseRpcMetadata(
Expand Down

0 comments on commit f10a1a1

Please sign in to comment.