Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(client): adjust TransportSenderT #852

Merged
merged 10 commits into from
Aug 12, 2022
9 changes: 0 additions & 9 deletions client/transport/src/web.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,6 @@ impl TransportSenderT for Sender {
self.0.send(Message::Text(msg)).await.map_err(|e| Error::WebSocket(e))?;
Ok(())
}

async fn send_ping(&mut self) -> Result<(), Self::Error> {
tracing::trace!("send ping - not implemented for wasm");
Err(Error::NotSupported)
}

async fn close(&mut self) -> Result<(), Error> {
Ok(())
}
}

#[async_trait(?Send)]
Expand Down
4 changes: 2 additions & 2 deletions client/transport/src/ws/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ impl TransportSenderT for Sender {

/// Sends out a ping request. Returns a `Future` that finishes when the request has been
/// successfully sent.
async fn send_ping(&mut self) -> Result<(), Self::Error> {
async fn optional_send_ping(&mut self) -> Result<(), Self::Error> {
tracing::debug!("send ping");
// Submit empty slice as "optional" parameter.
let slice: &[u8] = &[];
Expand All @@ -211,7 +211,7 @@ impl TransportSenderT for Sender {
}

/// Send a close message and close the connection.
async fn close(&mut self) -> Result<(), WsError> {
async fn optional_close(&mut self) -> Result<(), WsError> {
self.inner.close().await.map_err(Into::into)
}
}
Expand Down
277 changes: 143 additions & 134 deletions core/src/client/async_client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -243,104 +243,110 @@ impl Drop for Client {
#[async_trait]
impl ClientT for Client {
async fn notification<'a>(&self, method: &'a str, params: Option<ParamsSer<'a>>) -> Result<(), Error> {
// NOTE: we use this to guard against max number of concurrent requests.
let _req_id = self.id_manager.next_request_id()?;
let notif = NotificationSer::new(method, params);
let trace = RpcTracing::batch();

async {
let raw = serde_json::to_string(&notif).map_err(Error::ParseError)?;
tx_log_from_str(&raw, self.max_log_length);

let mut sender = self.to_back.clone();
let fut = sender.send(FrontToBack::Notification(raw));

match future::select(fut, Delay::new(self.request_timeout)).await {
Either::Left((Ok(()), _)) => Ok(()),
Either::Left((Err(_), _)) => Err(self.read_error_from_backend().await),
Either::Right((_, _)) => Err(Error::RequestTimeout),
}
}.instrument(trace.into_span()).await
}
// NOTE: we use this to guard against max number of concurrent requests.
let _req_id = self.id_manager.next_request_id()?;
let notif = NotificationSer::new(method, params);
let trace = RpcTracing::batch();

async {
let raw = serde_json::to_string(&notif).map_err(Error::ParseError)?;
tx_log_from_str(&raw, self.max_log_length);

let mut sender = self.to_back.clone();
let fut = sender.send(FrontToBack::Notification(raw));

match future::select(fut, Delay::new(self.request_timeout)).await {
Either::Left((Ok(()), _)) => Ok(()),
Either::Left((Err(_), _)) => Err(self.read_error_from_backend().await),
Either::Right((_, _)) => Err(Error::RequestTimeout),
}
}
.instrument(trace.into_span())
.await
}

async fn request<'a, R>(&self, method: &'a str, params: Option<ParamsSer<'a>>) -> Result<R, Error>
where
R: DeserializeOwned,
{
let (send_back_tx, send_back_rx) = oneshot::channel();
let guard = self.id_manager.next_request_id()?;
let id = guard.inner();
let trace = RpcTracing::method_call(method);

async {
let raw = serde_json::to_string(&RequestSer::new(&id, method, params)).map_err(Error::ParseError)?;
tx_log_from_str(&raw, self.max_log_length);

if self
.to_back
.clone()
.send(FrontToBack::Request(RequestMessage { raw, id: id.clone(), send_back: Some(send_back_tx) }))
.await
.is_err()
{
return Err(self.read_error_from_backend().await);
}

let res = call_with_timeout(self.request_timeout, send_back_rx).await;
let json_value = match res {
Ok(Ok(v)) => v,
Ok(Err(err)) => return Err(err),
Err(_) => return Err(self.read_error_from_backend().await),
};

rx_log_from_json(&Response::new(&json_value, id), self.max_log_length);

serde_json::from_value(json_value).map_err(Error::ParseError)
}.instrument(trace.into_span()).await
}
{
let (send_back_tx, send_back_rx) = oneshot::channel();
let guard = self.id_manager.next_request_id()?;
let id = guard.inner();
let trace = RpcTracing::method_call(method);

async {
let raw = serde_json::to_string(&RequestSer::new(&id, method, params)).map_err(Error::ParseError)?;
tx_log_from_str(&raw, self.max_log_length);

if self
.to_back
.clone()
.send(FrontToBack::Request(RequestMessage { raw, id: id.clone(), send_back: Some(send_back_tx) }))
.await
.is_err()
{
return Err(self.read_error_from_backend().await);
}

let res = call_with_timeout(self.request_timeout, send_back_rx).await;
let json_value = match res {
Ok(Ok(v)) => v,
Ok(Err(err)) => return Err(err),
Err(_) => return Err(self.read_error_from_backend().await),
};

rx_log_from_json(&Response::new(&json_value, id), self.max_log_length);

serde_json::from_value(json_value).map_err(Error::ParseError)
}
.instrument(trace.into_span())
.await
}

async fn batch_request<'a, R>(&self, batch: Vec<(&'a str, Option<ParamsSer<'a>>)>) -> Result<Vec<R>, Error>
where
R: DeserializeOwned + Default + Clone,
{
let trace = RpcTracing::batch();
async {
{
let trace = RpcTracing::batch();
async {
let guard = self.id_manager.next_request_ids(batch.len())?;
let batch_ids: Vec<Id> = guard.inner();
let mut batches = Vec::with_capacity(batch.len());
for (idx, (method, params)) in batch.into_iter().enumerate() {
batches.push(RequestSer::new(&batch_ids[idx], method, params));
}

let (send_back_tx, send_back_rx) = oneshot::channel();

let raw = serde_json::to_string(&batches).map_err(Error::ParseError)?;

tx_log_from_str(&raw, self.max_log_length);

if self
.to_back
.clone()
.send(FrontToBack::Batch(BatchMessage { raw, ids: batch_ids, send_back: send_back_tx }))
.await
.is_err()
{
return Err(self.read_error_from_backend().await);
}

let res = call_with_timeout(self.request_timeout, send_back_rx).await;
let json_values = match res {
Ok(Ok(v)) => v,
Ok(Err(err)) => return Err(err),
Err(_) => return Err(self.read_error_from_backend().await),
};

rx_log_from_json(&json_values, self.max_log_length);

let values: Result<_, _> =
json_values.into_iter().map(|val| serde_json::from_value(val).map_err(Error::ParseError)).collect();
Ok(values?)
}.instrument(trace.into_span()).await
batches.push(RequestSer::new(&batch_ids[idx], method, params));
}

let (send_back_tx, send_back_rx) = oneshot::channel();

let raw = serde_json::to_string(&batches).map_err(Error::ParseError)?;

tx_log_from_str(&raw, self.max_log_length);

if self
.to_back
.clone()
.send(FrontToBack::Batch(BatchMessage { raw, ids: batch_ids, send_back: send_back_tx }))
.await
.is_err()
{
return Err(self.read_error_from_backend().await);
}

let res = call_with_timeout(self.request_timeout, send_back_rx).await;
let json_values = match res {
Ok(Ok(v)) => v,
Ok(Err(err)) => return Err(err),
Err(_) => return Err(self.read_error_from_backend().await),
};

rx_log_from_json(&json_values, self.max_log_length);

let values: Result<_, _> =
json_values.into_iter().map(|val| serde_json::from_value(val).map_err(Error::ParseError)).collect();
Ok(values?)
}
.instrument(trace.into_span())
.await
}
}

Expand All @@ -358,52 +364,55 @@ impl SubscriptionClientT for Client {
) -> Result<Subscription<N>, Error>
where
N: DeserializeOwned,
{
if subscribe_method == unsubscribe_method {
return Err(Error::SubscriptionNameConflict(unsubscribe_method.to_owned()));
}

let guard = self.id_manager.next_request_ids(2)?;
let mut ids: Vec<Id> = guard.inner();
let trace = RpcTracing::method_call(subscribe_method);

async {
let id = ids[0].clone();

let raw = serde_json::to_string(&RequestSer::new(&id, subscribe_method, params)).map_err(Error::ParseError)?;

tx_log_from_str(&raw, self.max_log_length);

let (send_back_tx, send_back_rx) = oneshot::channel();
if self
.to_back
.clone()
.send(FrontToBack::Subscribe(SubscriptionMessage {
raw,
subscribe_id: ids.swap_remove(0),
unsubscribe_id: ids.swap_remove(0),
unsubscribe_method: unsubscribe_method.to_owned(),
send_back: send_back_tx,
}))
.await
.is_err()
{
return Err(self.read_error_from_backend().await);
}

let res = call_with_timeout(self.request_timeout, send_back_rx).await;

let (notifs_rx, sub_id) = match res {
Ok(Ok(val)) => val,
Ok(Err(err)) => return Err(err),
Err(_) => return Err(self.read_error_from_backend().await),
};

rx_log_from_json(&Response::new(&sub_id, id), self.max_log_length);

Ok(Subscription::new(self.to_back.clone(), notifs_rx, SubscriptionKind::Subscription(sub_id)))
}.instrument(trace.into_span()).await
}
{
if subscribe_method == unsubscribe_method {
return Err(Error::SubscriptionNameConflict(unsubscribe_method.to_owned()));
}

let guard = self.id_manager.next_request_ids(2)?;
let mut ids: Vec<Id> = guard.inner();
let trace = RpcTracing::method_call(subscribe_method);

async {
let id = ids[0].clone();

let raw =
serde_json::to_string(&RequestSer::new(&id, subscribe_method, params)).map_err(Error::ParseError)?;

tx_log_from_str(&raw, self.max_log_length);

let (send_back_tx, send_back_rx) = oneshot::channel();
if self
.to_back
.clone()
.send(FrontToBack::Subscribe(SubscriptionMessage {
raw,
subscribe_id: ids.swap_remove(0),
unsubscribe_id: ids.swap_remove(0),
unsubscribe_method: unsubscribe_method.to_owned(),
send_back: send_back_tx,
}))
.await
.is_err()
{
return Err(self.read_error_from_backend().await);
}

let res = call_with_timeout(self.request_timeout, send_back_rx).await;

let (notifs_rx, sub_id) = match res {
Ok(Ok(val)) => val,
Ok(Err(err)) => return Err(err),
Err(_) => return Err(self.read_error_from_backend().await),
};

rx_log_from_json(&Response::new(&sub_id, id), self.max_log_length);

Ok(Subscription::new(self.to_back.clone(), notifs_rx, SubscriptionKind::Subscription(sub_id)))
}
.instrument(trace.into_span())
.await
}

/// Subscribe to a specific method.
async fn subscribe_to_method<'a, N>(&self, method: &'a str) -> Result<Subscription<N>, Error>
Expand Down Expand Up @@ -683,7 +692,7 @@ async fn background_task<S, R>(
}
// Submit ping interval was triggered if enabled.
Either::Right((_, next_message_fut)) => {
if let Err(e) = sender.send_ping().await {
if let Err(e) = sender.optional_send_ping().await {
tracing::warn!("[backend]: client send ping failed: {:?}", e);
let _ = front_error.send(Error::Custom("Could not send ping frame".into()));
break;
Expand All @@ -693,5 +702,5 @@ async fn background_task<S, R>(
};
}
// Send close message to the server.
let _ = sender.close().await;
let _ = sender.optional_close().await;
}
14 changes: 11 additions & 3 deletions core/src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,11 +132,19 @@ pub trait TransportSenderT: MaybeSend + 'static {
/// Send.
async fn send(&mut self, msg: String) -> Result<(), Self::Error>;

/// This is optional because it's most likely on relevant for WebSocket transports only.
niklasad1 marked this conversation as resolved.
Show resolved Hide resolved
/// You should only implement this is your transport supports sending periodic pings.
///
/// Send ping frame (opcode of 0x9).
async fn send_ping(&mut self) -> Result<(), Self::Error>;
async fn optional_send_ping(&mut self) -> Result<(), Self::Error> {
Ok(())
}

/// If the transport supports sending customized close messages.
async fn close(&mut self) -> Result<(), Self::Error> {
/// This is optional because it's most likely on relevant for WebSocket transports only.
niklasad1 marked this conversation as resolved.
Show resolved Hide resolved
/// You should only implement this is your transport supports sending periodic pings.
niklasad1 marked this conversation as resolved.
Show resolved Hide resolved
///
/// Send customized close message.
async fn optional_close(&mut self) -> Result<(), Self::Error> {
Ok(())
}
}
Expand Down