From 516847bd567330bf98d2953715fd6fadeef56c0e Mon Sep 17 00:00:00 2001 From: Niklas Date: Wed, 8 Dec 2021 16:41:26 +0100 Subject: [PATCH] types: make subscription ID String Cow. --- types/src/client.rs | 6 ++--- types/src/v2/params.rs | 41 +++++++++++++++++++++++++++++----- types/src/v2/response.rs | 7 +++--- utils/src/server/rpc_module.rs | 4 ++-- ws-client/Cargo.toml | 1 + ws-client/src/helpers.rs | 10 +++++---- ws-client/src/manager.rs | 18 +++++---------- ws-client/src/transport.rs | 2 +- 8 files changed, 59 insertions(+), 30 deletions(-) diff --git a/types/src/client.rs b/types/src/client.rs index c0a0cf0685..901bd0de39 100644 --- a/types/src/client.rs +++ b/types/src/client.rs @@ -38,7 +38,7 @@ use std::sync::Arc; #[non_exhaustive] pub enum SubscriptionKind { /// Get notifications based on Subscription ID. - Subscription(SubscriptionId), + Subscription(SubscriptionId<'static>), /// Get notifications based on method name. Method(String), } @@ -115,7 +115,7 @@ pub struct SubscriptionMessage { /// If the subscription succeeds, we return a [`mpsc::Receiver`] that will receive notifications. /// When we get a response from the server about that subscription, we send the result over /// this channel. - pub send_back: oneshot::Sender, SubscriptionId), Error>>, + pub send_back: oneshot::Sender, SubscriptionId<'static>), Error>>, } /// RegisterNotification message. @@ -149,7 +149,7 @@ pub enum FrontToBack { // NOTE: It is not possible to cancel pending subscriptions or pending requests. // Such operations will be blocked until a response is received or the background // thread has been terminated. - SubscriptionClosed(SubscriptionId), + SubscriptionClosed(SubscriptionId<'static>), } impl Subscription diff --git a/types/src/v2/params.rs b/types/src/v2/params.rs index ea399be59b..6286caf2e1 100644 --- a/types/src/v2/params.rs +++ b/types/src/v2/params.rs @@ -35,7 +35,7 @@ use serde::de::{self, Deserializer, Unexpected, Visitor}; use serde::ser::Serializer; use serde::{Deserialize, Serialize}; use serde_json::Value as JsonValue; -use std::fmt; +use std::{convert::TryFrom, fmt}; /// JSON-RPC v2 marker type. #[derive(Clone, Copy, Debug, Default, PartialEq)] @@ -288,18 +288,49 @@ impl<'a> From<&'a [JsonValue]> for ParamsSer<'a> { #[derive(Debug, PartialEq, Clone, Hash, Eq, Deserialize, Serialize)] #[serde(deny_unknown_fields)] #[serde(untagged)] -pub enum SubscriptionId { +pub enum SubscriptionId<'a> { /// Numeric id Num(u64), /// String id - Str(String), + #[serde(borrow)] + Str(Cow<'a, str>), } -impl From for JsonValue { +impl<'a> From> for JsonValue { fn from(sub_id: SubscriptionId) -> Self { match sub_id { SubscriptionId::Num(n) => n.into(), - SubscriptionId::Str(s) => s.into(), + SubscriptionId::Str(s) => s.into_owned().into(), + } + } +} + +impl<'a> TryFrom for SubscriptionId<'a> { + type Error = (); + + fn try_from(json: JsonValue) -> Result, ()> { + match json { + JsonValue::String(s) => Ok(SubscriptionId::Str(s.into())), + JsonValue::Number(n) => { + if let Some(n) = n.as_u64() { + Ok(SubscriptionId::Num(n)) + } else { + Err(()) + } + } + _ => Err(()), + } + } +} + +impl<'a> SubscriptionId<'a> { + /// Convert `SubscriptionId<'a>` to `SubscriptionId<'static>` so that it can be moved across threads. + /// + /// This can cause an allocation if the id is a string. + pub fn into_owned(self) -> SubscriptionId<'static> { + match self { + SubscriptionId::Num(num) => SubscriptionId::Num(num), + SubscriptionId::Str(s) => SubscriptionId::Str(Cow::owned(s.into_owned())), } } } diff --git a/types/src/v2/response.rs b/types/src/v2/response.rs index 3ce3324912..60c338440e 100644 --- a/types/src/v2/response.rs +++ b/types/src/v2/response.rs @@ -54,15 +54,16 @@ impl<'a, T> Response<'a, T> { /// Return value for subscriptions. #[derive(Serialize, Deserialize, Debug)] -pub struct SubscriptionPayload { +pub struct SubscriptionPayload<'a, T> { /// Subscription ID - pub subscription: SubscriptionId, + #[serde(borrow)] + pub subscription: SubscriptionId<'a>, /// Result. pub result: T, } /// Subscription response object, embedding a [`SubscriptionPayload`] in the `params` member. -pub type SubscriptionResponse<'a, T> = Notification<'a, SubscriptionPayload>; +pub type SubscriptionResponse<'a, T> = Notification<'a, SubscriptionPayload<'a, T>>; #[cfg(test)] mod tests { diff --git a/utils/src/server/rpc_module.rs b/utils/src/server/rpc_module.rs index 8efcdaeaf7..913577c049 100644 --- a/utils/src/server/rpc_module.rs +++ b/utils/src/server/rpc_module.rs @@ -827,11 +827,11 @@ impl TestSubscription { /// # Panics /// /// If the decoding the value as `T` fails. - pub async fn next(&mut self) -> Option<(T, jsonrpsee_types::v2::SubscriptionId)> { + pub async fn next(&mut self) -> Option<(T, jsonrpsee_types::v2::SubscriptionId<'static>)> { let raw = self.rx.next().await?; let val: SubscriptionResponse = serde_json::from_str(&raw).expect("valid response in TestSubscription::next()"); - Some((val.params.result, val.params.subscription)) + Some((val.params.result, val.params.subscription.into_owned())) } } diff --git a/ws-client/Cargo.toml b/ws-client/Cargo.toml index cdc461b82c..6bcc868556 100644 --- a/ws-client/Cargo.toml +++ b/ws-client/Cargo.toml @@ -11,6 +11,7 @@ documentation = "https://docs.rs/jsonrpsee-ws-client" [dependencies] async-trait = "0.1" +beef = "0.5.1" rustc-hash = "1" futures = { version = "0.3.14", default-features = false, features = ["std"] } http = "0.2" diff --git a/ws-client/src/helpers.rs b/ws-client/src/helpers.rs index 7be177ff7c..eed9a60318 100644 --- a/ws-client/src/helpers.rs +++ b/ws-client/src/helpers.rs @@ -32,6 +32,7 @@ use crate::types::v2::{ use crate::types::{Error, RequestMessage}; use futures::channel::{mpsc, oneshot}; use serde_json::Value as JsonValue; +use std::convert::TryInto; use std::time::Duration; /// Attempts to process a batch response. @@ -75,7 +76,7 @@ pub fn process_subscription_response( manager: &mut RequestManager, response: SubscriptionResponse, ) -> Result<(), Option> { - let sub_id = response.params.subscription; + let sub_id = response.params.subscription.into_owned(); let request_id = match manager.get_request_id_by_subscription_id(&sub_id) { Some(request_id) => request_id, None => return Err(None), @@ -144,8 +145,9 @@ pub fn process_single_response( let (unsub_id, send_back_oneshot, unsubscribe_method) = manager.complete_pending_subscription(response_id).ok_or(Error::InvalidRequestId)?; - let sub_id: SubscriptionId = match serde_json::from_value(response.result) { - Ok(sub_id) => sub_id, + let sub_id: Result = response.result.try_into(); + let sub_id = match sub_id { + Ok(sub_id) => sub_id.into_owned(), Err(_) => { let _ = send_back_oneshot.send(Err(Error::InvalidSubscriptionId)); return Ok(None); @@ -185,7 +187,7 @@ pub async fn stop_subscription(sender: &mut WsSender, manager: &mut RequestManag pub fn build_unsubscribe_message( manager: &mut RequestManager, sub_req_id: u64, - sub_id: SubscriptionId, + sub_id: SubscriptionId<'static>, ) -> Option { let (unsub_req_id, _, unsub, sub_id) = manager.remove_subscription(sub_req_id, sub_id)?; let sub_id_slice: &[JsonValue] = &[sub_id.into()]; diff --git a/ws-client/src/manager.rs b/ws-client/src/manager.rs index 45e8df4ca4..faaa8f02cf 100644 --- a/ws-client/src/manager.rs +++ b/ws-client/src/manager.rs @@ -59,7 +59,7 @@ pub enum RequestStatus { type PendingCallOneshot = Option>>; type PendingBatchOneshot = oneshot::Sender, Error>>; -type PendingSubscriptionOneshot = oneshot::Sender, SubscriptionId), Error>>; +type PendingSubscriptionOneshot = oneshot::Sender, SubscriptionId<'static>), Error>>; type SubscriptionSink = mpsc::Sender; type UnsubscribeMethod = String; type RequestId = u64; @@ -82,7 +82,7 @@ pub struct RequestManager { requests: FxHashMap, /// Reverse lookup, to find a request ID in constant time by `subscription ID` instead of looking through all /// requests. - subscriptions: HashMap, + subscriptions: HashMap, RequestId>, /// Pending batch requests batches: FxHashMap, BatchState>, /// Registered Methods for incoming notifications @@ -161,7 +161,7 @@ impl RequestManager { &mut self, sub_req_id: RequestId, unsub_req_id: RequestId, - subscription_id: SubscriptionId, + subscription_id: SubscriptionId<'static>, send_back: SubscriptionSink, unsubscribe_method: UnsubscribeMethod, ) -> Result<(), SubscriptionSink> { @@ -251,7 +251,7 @@ impl RequestManager { pub fn remove_subscription( &mut self, request_id: RequestId, - subscription_id: SubscriptionId, + subscription_id: SubscriptionId<'static>, ) -> Option<(RequestId, SubscriptionSink, UnsubscribeMethod, SubscriptionId)> { match (self.requests.entry(request_id), self.subscriptions.entry(subscription_id)) { (Entry::Occupied(request), Entry::Occupied(subscription)) @@ -329,17 +329,11 @@ mod tests { let (unsub_req_id, _send_back_oneshot, unsubscribe_method) = manager.complete_pending_subscription(1).unwrap(); assert_eq!(unsub_req_id, 2); assert!(manager - .insert_subscription( - 1, - 2, - SubscriptionId::Str("uniq_id_from_server".to_string()), - sub_tx, - unsubscribe_method - ) + .insert_subscription(1, 2, SubscriptionId::Str("uniq_id_from_server".into()), sub_tx, unsubscribe_method) .is_ok()); assert!(manager.as_subscription_mut(&1).is_some()); - assert!(manager.remove_subscription(1, SubscriptionId::Str("uniq_id_from_server".to_string())).is_some()); + assert!(manager.remove_subscription(1, SubscriptionId::Str("uniq_id_from_server".into())).is_some()); } #[test] diff --git a/ws-client/src/transport.rs b/ws-client/src/transport.rs index f33885422d..8721f93858 100644 --- a/ws-client/src/transport.rs +++ b/ws-client/src/transport.rs @@ -25,13 +25,13 @@ // DEALINGS IN THE SOFTWARE. use crate::{stream::EitherStream, types::CertificateStore}; +use beef::Cow; use futures::io::{BufReader, BufWriter}; use http::Uri; use soketto::connection; use soketto::handshake::client::{Client as WsHandshakeClient, Header, ServerResponse}; use std::convert::TryInto; use std::{ - borrow::Cow, convert::TryFrom, io, net::{SocketAddr, ToSocketAddrs},