Skip to content

Commit

Permalink
Implement fallible streams for FlightClient::do_put (#3464)
Browse files Browse the repository at this point in the history
* Implement fallible streams for do_put

* Another approach to error wrapping

* implement basic client error test

* Add last error test

* comments

* fix docs

* Simplify

---------

Co-authored-by: Raphael Taylor-Davies <r.taylordavies@googlemail.com>
  • Loading branch information
alamb and tustvold authored Feb 23, 2023
1 parent 47e4b61 commit 0373a9d
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 21 deletions.
58 changes: 44 additions & 14 deletions arrow-flight/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
// specific language governing permissions and limitations
// under the License.

use std::task::Poll;

use crate::{
decode::FlightRecordBatchStream, flight_service_client::FlightServiceClient, Action,
ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo,
Expand All @@ -24,8 +26,9 @@ use arrow_schema::Schema;
use bytes::Bytes;
use futures::{
future::ready,
ready,
stream::{self, BoxStream},
Stream, StreamExt, TryStreamExt,
FutureExt, Stream, StreamExt, TryStreamExt,
};
use tonic::{metadata::MetadataMap, transport::Channel};

Expand Down Expand Up @@ -262,6 +265,15 @@ impl FlightClient {
/// [`Stream`](futures::Stream) of [`FlightData`] and returning a
/// stream of [`PutResult`].
///
/// # Note
///
/// The input stream is [`Result`] so that this can be connected
/// to a streaming data source, such as [`FlightDataEncoder`](crate::encode::FlightDataEncoder),
/// without having to buffer. If the input stream returns an error
/// that error will not be sent to the server, instead it will be
/// placed into the result stream and the server connection
/// terminated.
///
/// # Example:
/// ```no_run
/// # async fn run() {
Expand All @@ -279,9 +291,7 @@ impl FlightClient {
///
/// // encode the batch as a stream of `FlightData`
/// let flight_data_stream = FlightDataEncoderBuilder::new()
/// .build(futures::stream::iter(vec![Ok(batch)]))
/// // data encoder return Results, but do_put requires FlightData
/// .map(|batch|batch.unwrap());
/// .build(futures::stream::iter(vec![Ok(batch)]));
///
/// // send the stream and get the results as `PutResult`
/// let response: Vec<PutResult>= client
Expand All @@ -293,20 +303,40 @@ impl FlightClient {
/// .expect("error calling do_put");
/// # }
/// ```
pub async fn do_put<S: Stream<Item = FlightData> + Send + 'static>(
pub async fn do_put<S: Stream<Item = Result<FlightData>> + Send + 'static>(
&mut self,
request: S,
) -> Result<BoxStream<'static, Result<PutResult>>> {
let request = self.make_request(request);

let response = self
.inner
.do_put(request)
.await?
.into_inner()
.map_err(FlightError::Tonic);
let (sender, mut receiver) = futures::channel::oneshot::channel();

// Intercepts client errors and sends them to the oneshot channel above
let mut request = Box::pin(request); // Pin to heap
let mut sender = Some(sender); // Wrap into Option so can be taken
let request_stream = futures::stream::poll_fn(move |cx| {
Poll::Ready(match ready!(request.poll_next_unpin(cx)) {
Some(Ok(data)) => Some(data),
Some(Err(e)) => {
let _ = sender.take().unwrap().send(e);
None
}
None => None,
})
});

let request = self.make_request(request_stream);
let mut response_stream = self.inner.do_put(request).await?.into_inner();

// Forwards errors from the error oneshot with priority over responses from server
let error_stream = futures::stream::poll_fn(move |cx| {
if let Poll::Ready(Ok(err)) = receiver.poll_unpin(cx) {
return Poll::Ready(Some(Err(err)));
}
let next = ready!(response_stream.poll_next_unpin(cx));
Poll::Ready(next.map(|x| x.map_err(FlightError::Tonic)))
});

Ok(response.boxed())
// combine the response from the server and any error from the client
Ok(error_stream.boxed())
}

/// Make a `DoExchange` call to the server with the provided
Expand Down
99 changes: 92 additions & 7 deletions arrow-flight/tests/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -248,8 +248,10 @@ async fn test_do_put() {
test_server
.set_do_put_response(expected_response.clone().into_iter().map(Ok).collect());

let input_stream = futures::stream::iter(input_flight_data.clone()).map(Ok);

let response_stream = client
.do_put(futures::stream::iter(input_flight_data.clone()))
.do_put(input_stream)
.await
.expect("error making request");

Expand All @@ -266,15 +268,15 @@ async fn test_do_put() {
}

#[tokio::test]
async fn test_do_put_error() {
async fn test_do_put_error_server() {
do_test(|test_server, mut client| async move {
client.add_header("foo-header", "bar-header-value").unwrap();

let input_flight_data = test_flight_data().await;

let response = client
.do_put(futures::stream::iter(input_flight_data.clone()))
.await;
let input_stream = futures::stream::iter(input_flight_data.clone()).map(Ok);

let response = client.do_put(input_stream).await;
let response = match response {
Ok(_) => panic!("unexpected success"),
Err(e) => e,
Expand All @@ -290,7 +292,7 @@ async fn test_do_put_error() {
}

#[tokio::test]
async fn test_do_put_error_stream() {
async fn test_do_put_error_stream_server() {
do_test(|test_server, mut client| async move {
client.add_header("foo-header", "bar-header-value").unwrap();

Expand All @@ -307,8 +309,10 @@ async fn test_do_put_error_stream() {

test_server.set_do_put_response(response);

let input_stream = futures::stream::iter(input_flight_data.clone()).map(Ok);

let response_stream = client
.do_put(futures::stream::iter(input_flight_data.clone()))
.do_put(input_stream)
.await
.expect("error making request");

Expand All @@ -326,6 +330,87 @@ async fn test_do_put_error_stream() {
.await;
}

#[tokio::test]
async fn test_do_put_error_client() {
do_test(|test_server, mut client| async move {
client.add_header("foo-header", "bar-header-value").unwrap();

let e = Status::invalid_argument("bad arg: client");

// input stream to client sends good FlightData followed by an error
let input_flight_data = test_flight_data().await;
let input_stream = futures::stream::iter(input_flight_data.clone())
.map(Ok)
.chain(futures::stream::iter(vec![Err(FlightError::from(
e.clone(),
))]));

// server responds with one good message
let response = vec![Ok(PutResult {
app_metadata: Bytes::from("foo-metadata"),
})];
test_server.set_do_put_response(response);

let response_stream = client
.do_put(input_stream)
.await
.expect("error making request");

let response: Result<Vec<_>, _> = response_stream.try_collect().await;
let response = match response {
Ok(_) => panic!("unexpected success"),
Err(e) => e,
};

// expect to the error made from the client
expect_status(response, e);
// server still got the request messages until the client sent the error
assert_eq!(test_server.take_do_put_request(), Some(input_flight_data));
ensure_metadata(&client, &test_server);
})
.await;
}

#[tokio::test]
async fn test_do_put_error_client_and_server() {
do_test(|test_server, mut client| async move {
client.add_header("foo-header", "bar-header-value").unwrap();

let e_client = Status::invalid_argument("bad arg: client");
let e_server = Status::invalid_argument("bad arg: server");

// input stream to client sends good FlightData followed by an error
let input_flight_data = test_flight_data().await;
let input_stream = futures::stream::iter(input_flight_data.clone())
.map(Ok)
.chain(futures::stream::iter(vec![Err(FlightError::from(
e_client.clone(),
))]));

// server responds with an error (e.g. because it got truncated data)
let response = vec![Err(e_server)];
test_server.set_do_put_response(response);

let response_stream = client
.do_put(input_stream)
.await
.expect("error making request");

let response: Result<Vec<_>, _> = response_stream.try_collect().await;
let response = match response {
Ok(_) => panic!("unexpected success"),
Err(e) => e,
};

// expect to the error made from the client (not the server)
expect_status(response, e_client);
// server still got the request messages until the client sent the error
assert_eq!(test_server.take_do_put_request(), Some(input_flight_data));
ensure_metadata(&client, &test_server);
})
.await;
}

#[tokio::test]
async fn test_do_exchange() {
do_test(|test_server, mut client| async move {
Expand Down

0 comments on commit 0373a9d

Please sign in to comment.