Skip to content

Commit

Permalink
Fix case where resource_name not set in stream error
Browse files Browse the repository at this point in the history
The RBE protocol specifies that if the first message in a stream
has the resource_name set, subsequent messages do not need to set
it. With this patch we now honor that requirement in the GrpcStore
hot path.

closes: #745
  • Loading branch information
allada committed Mar 10, 2024
1 parent 3e6f154 commit 2674584
Show file tree
Hide file tree
Showing 9 changed files with 424 additions and 309 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion nativelink-service/src/bytestream_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ use nativelink_store::grpc_store::GrpcStore;
use nativelink_store::store_manager::StoreManager;
use nativelink_util::buf_channel::{make_buf_channel_pair, DropCloserReadHalf, DropCloserWriteHalf};
use nativelink_util::common::DigestInfo;
use nativelink_util::proto_stream_utils::WriteRequestStreamWrapper;
use nativelink_util::resource_info::ResourceInfo;
use nativelink_util::store_trait::{Store, UploadSizeInfo};
use nativelink_util::write_request_stream_wrapper::WriteRequestStreamWrapper;
use parking_lot::Mutex;
use tokio::task::AbortHandle;
use tokio::time::sleep;
Expand Down
170 changes: 6 additions & 164 deletions nativelink-store/src/grpc_store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;

use async_trait::async_trait;
Expand All @@ -37,10 +36,10 @@ use nativelink_util::buf_channel::{DropCloserReadHalf, DropCloserWriteHalf};
use nativelink_util::common::DigestInfo;
use nativelink_util::grpc_utils::ConnectionManager;
use nativelink_util::health_utils::HealthStatusIndicator;
use nativelink_util::proto_stream_utils::{FirstStream, WriteRequestStreamWrapper, WriteState, WriteStateWrapper};
use nativelink_util::resource_info::ResourceInfo;
use nativelink_util::retry::{Retrier, RetryResult};
use nativelink_util::store_trait::{Store, UploadSizeInfo};
use nativelink_util::write_request_stream_wrapper::WriteRequestStreamWrapper;
use nativelink_util::{default_health_status_indicator, tls_utils};
use parking_lot::Mutex;
use prost::Message;
Expand All @@ -63,157 +62,6 @@ pub struct GrpcStore {
connection_manager: ConnectionManager,
}

/// Stream adapter that replays an already-fetched first response before
/// delegating to the underlying stream.
///
/// This provides a buffer for the first response from GrpcStore.read in order
/// to allow the first read to occur within the retry loop. That means that if
/// the connection establishes fine, but reading the first byte of the file
/// fails we have the ability to retry before returning to the caller.
struct FirstStream {
/// Contains the first response from the stream (which could be an EOF,
/// hence the nested Option). This should be populated on creation and
/// returned as the first result from the stream. Subsequent reads from the
/// stream will use the encapsulated stream. The outer Option becomes None
/// once the buffered value has been handed out.
first_response: Option<Option<ReadResponse>>,
/// The stream to get responses from when first_response is None.
stream: Streaming<ReadResponse>,
}

impl Stream for FirstStream {
    type Item = Result<ReadResponse, Status>;

    /// Yields the buffered first response exactly once, then forwards every
    /// later poll to the wrapped stream.
    ///
    /// A buffered end-of-stream (an inner `None`) is reported as
    /// `Poll::Ready(None)`; a buffered message is reported as `Ok`.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Option<Self::Item>> {
        // Taking the value leaves `None` behind, so the buffered response can
        // never be replayed on a later poll.
        let buffered = self.first_response.take();
        match buffered {
            Some(first) => std::task::Poll::Ready(first.map(Ok)),
            None => Pin::new(&mut self.stream).poll_next(cx),
        }
    }
}

/// This structure wraps all of the information required to perform a write
/// request on the GrpcStore. It stores the last message retrieved which allows
/// the write to resume since the UUID allows upload resume at the server.
///
/// The state is shared with WriteStateWrapper, which pulls requests from
/// `read_stream`, records them in `cached_messages` and reports stream
/// failures through `read_stream_error`.
struct WriteState<T, E>
where
T: Stream<Item = Result<WriteRequest, E>> + Unpin + Send + 'static,
E: Into<Error> + 'static,
{
// The instance name written into the resource name of forwarded requests
// when the incoming request carries a different one.
instance_name: String,
// Set when the incoming stream yields an error or a resource name cannot
// be parsed; while set, can_resume() reports false.
read_stream_error: Option<Error>,
// The incoming stream of write requests to forward.
read_stream: WriteRequestStreamWrapper<T, E>,
// Tonic doesn't appear to report an error until it has taken two messages,
// therefore we are required to buffer the last two messages. Index 0 holds
// the most recent message and index 1 the one before it.
cached_messages: [Option<WriteRequest>; 2],
// When resuming after an error, the previous messages are cloned into this
// queue upfront to allow them to be served back.
resume_queue: [Option<WriteRequest>; 2],
// An optimisation to avoid having to manage resume_queue when it's empty.
is_resumed: bool,
}

impl<T, E> WriteState<T, E>
where
    T: Stream<Item = Result<WriteRequest, E>> + Unpin + Send + 'static,
    E: Into<Error> + 'static,
{
    /// Creates a fresh state for forwarding `read_stream` under
    /// `instance_name`, with nothing cached and no resume in progress.
    fn new(instance_name: String, read_stream: WriteRequestStreamWrapper<T, E>) -> Self {
        Self {
            instance_name,
            read_stream_error: None,
            read_stream,
            cached_messages: [None, None],
            resume_queue: [None, None],
            is_resumed: false,
        }
    }

    /// Records `message` as the most recently forwarded request.
    ///
    /// Only the last two requests are retained: the previous newest moves to
    /// slot 1 and whatever was in slot 1 is overwritten on the next push.
    fn push_message(&mut self, message: WriteRequest) {
        self.cached_messages.swap(0, 1);
        self.cached_messages[0] = Some(message);
    }

    /// Returns the next request to replay after `resume()`, oldest first, or
    /// `None` when no resume is in progress or the queue has been drained.
    fn resumed_message(&mut self) -> Option<WriteRequest> {
        if !self.is_resumed {
            return None;
        }
        // The resume_queue is a circular buffer that we have to shift; since
        // it has only two elements, that is a trivial swap.
        self.resume_queue.swap(0, 1);
        let front = self.resume_queue[0].take();
        // When only a single request was cached before the failure, the swap
        // moves it into slot 1 and leaves slot 0 empty. Fall back to slot 1 so
        // that request is replayed instead of being silently dropped.
        let message = match front {
            Some(message) => Some(message),
            None => self.resume_queue[1].take(),
        };
        if message.is_none() {
            // Queue fully drained; later polls go straight to the read stream.
            self.is_resumed = false;
        }
        message
    }

    /// Reports whether the write may be retried: no stream error has been
    /// recorded and either a request is cached for replay or the read stream
    /// is still at its first message.
    fn can_resume(&self) -> bool {
        self.read_stream_error.is_none() && (self.cached_messages[0].is_some() || self.read_stream.is_first_msg())
    }

    /// Arms a resume by copying the cached requests into the replay queue.
    fn resume(&mut self) {
        self.resume_queue = self.cached_messages.clone();
        self.is_resumed = true;
    }
}

/// A wrapper around WriteState to allow it to be reclaimed from the underlying
/// write call in the case of failure.
///
/// The wrapper holds only a shared handle to the state; its Stream
/// implementation locks that state on every poll.
struct WriteStateWrapper<T, E>
where
T: Stream<Item = Result<WriteRequest, E>> + Unpin + Send + 'static,
E: Into<Error> + 'static,
{
// Shared handle to the write state that this wrapper polls.
shared_state: Arc<Mutex<WriteState<T, E>>>,
}

impl<T, E> Stream for WriteStateWrapper<T, E>
where
    T: Stream<Item = Result<WriteRequest, E>> + Unpin + Send + 'static,
    E: Into<Error> + 'static,
{
    type Item = WriteRequest;

    /// Produces the next write request to send.
    ///
    /// Requests queued for replay after a failure are served first. Otherwise
    /// the next request is pulled from the read stream, its instance name is
    /// rewritten if needed, and a copy is cached so the upload can be resumed.
    /// A stream error or an unparsable resource name is stored in
    /// `read_stream_error` and ends the stream with `None`.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        const IS_UPLOAD_TRUE: bool = true;

        // This should be an uncontended lock since write was called.
        let mut local_state = self.shared_state.lock();
        // If this is the first or second call after a failure and we have
        // cached messages, then use the cached write requests.
        let cached_message = local_state.resumed_message();
        if cached_message.is_some() {
            return Poll::Ready(cached_message);
        }
        // Read a new write request from the downstream.
        let Poll::Ready(maybe_message) = Pin::new(&mut local_state.read_stream).poll_next(cx) else {
            return Poll::Pending;
        };
        let result = match maybe_message {
            Some(Ok(mut message)) => {
                // Only the first message of a stream is required to carry a
                // resource_name; later messages may leave it empty. Parsing an
                // empty name would be reported as a stream error, so only
                // inspect and rewrite the name when one is present.
                if !message.resource_name.is_empty() {
                    match ResourceInfo::new(&message.resource_name, IS_UPLOAD_TRUE) {
                        Ok(mut resource_name) => {
                            // Update the instance name in the write request
                            // when it differs from the configured one.
                            if resource_name.instance_name != local_state.instance_name {
                                resource_name.instance_name = &local_state.instance_name;
                                message.resource_name = resource_name.to_string(IS_UPLOAD_TRUE);
                            }
                        }
                        Err(err) => {
                            local_state.read_stream_error = Some(err);
                            return Poll::Ready(None);
                        }
                    }
                }
                // Cache the last request in case there is an error to allow
                // the upload to be resumed.
                local_state.push_message(message.clone());
                Some(message)
            }
            Some(Err(err)) => {
                local_state.read_stream_error = Some(err);
                None
            }
            None => None,
        };
        Poll::Ready(result)
    }
}

impl GrpcStore {
pub async fn new(config: &nativelink_config::stores::GrpcStore) -> Result<Self, Error> {
let jitter_amt = config.retry.jitter;
Expand Down Expand Up @@ -402,10 +250,7 @@ impl GrpcStore {
.message()
.await
.err_tip(|| "Fetching first chunk in GrpcStore::read()")?;
Ok(FirstStream {
first_response: Some(first_response),
stream: response,
})
Ok(FirstStream::new(first_response, response))
}

pub async fn read(
Expand Down Expand Up @@ -444,19 +289,16 @@ impl GrpcStore {
// has completed. There is no way to get the value back
// from the client.
let result = ByteStreamClient::new(channel)
.write(WriteStateWrapper {
shared_state: local_state.clone(),
})
.write(WriteStateWrapper::new(local_state.clone()))
.await;

// Get the state back from StateWrapper, this should be
// uncontended since write has returned.
let mut local_state_locked = local_state.lock();

let result = if let Some(err) = local_state_locked.read_stream_error.take() {
// If there was an error with the stream, then don't
// retry.
RetryResult::Err(err)
let result = if let Some(err) = local_state_locked.take_read_stream_error() {
// If there was an error with the stream, then don't retry.
RetryResult::Err(err.append("Where read_stream_error was set"))
} else {
// On error determine whether it is possible to retry.
match result.err_tip(|| "in GrpcStore::write") {
Expand Down
6 changes: 5 additions & 1 deletion nativelink-util/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@ rust_library(
"src/lib.rs",
"src/metrics_utils.rs",
"src/platform_properties.rs",
"src/proto_stream_utils.rs",
"src/resource_info.rs",
"src/retry.rs",
"src/store_trait.rs",
"src/tls_utils.rs",
"src/write_counter.rs",
"src/write_request_stream_wrapper.rs",
],
proc_macro_deps = [
"@crates//:async-trait",
Expand Down Expand Up @@ -67,6 +67,7 @@ rust_test_suite(
"tests/fastcdc_test.rs",
"tests/fs_test.rs",
"tests/health_utils_test.rs",
"tests/proto_stream_utils_test.rs",
"tests/resource_info_test.rs",
"tests/retry_test.rs",
],
Expand All @@ -80,14 +81,17 @@ rust_test_suite(
":nativelink-util",
"//nativelink-config",
"//nativelink-error",
"//nativelink-proto",
"@crates//:bytes",
"@crates//:futures",
"@crates//:hex",
"@crates//:mock_instant",
"@crates//:parking_lot",
"@crates//:pretty_assertions",
"@crates//:rand",
"@crates//:sha2",
"@crates//:tokio",
"@crates//:tokio-stream",
"@crates//:tokio-util",
],
)
Expand Down
1 change: 1 addition & 0 deletions nativelink-util/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ rand = "0.8.5"
serde = { version = "1.0.197", features = ["derive"] }
sha2 = "0.10.8"
tokio = { version = "1.36.0", features = [ "sync", "fs", "rt", "time", "io-util", "macros" ] }
tokio-stream = { version = "0.1.14", features = ["sync"] }
tokio-util = { version = "0.7.10" }
tonic = { version = "0.11.0", features = ["tls"] }
tracing = "0.1.40"
Expand Down
2 changes: 1 addition & 1 deletion nativelink-util/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ pub mod grpc_utils;
pub mod health_utils;
pub mod metrics_utils;
pub mod platform_properties;
pub mod proto_stream_utils;
pub mod resource_info;
pub mod retry;
pub mod store_trait;
pub mod tls_utils;
pub mod write_counter;
pub mod write_request_stream_wrapper;
Loading

0 comments on commit 2674584

Please sign in to comment.