Skip to content

Commit

Permalink
types: make subscription ID String Cow. (#594)
Browse files Browse the repository at this point in the history
  • Loading branch information
niklasad1 authored Dec 10, 2021
1 parent 9bbfd69 commit 59925c0
Show file tree
Hide file tree
Showing 8 changed files with 59 additions and 30 deletions.
6 changes: 3 additions & 3 deletions types/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ use std::sync::Arc;
#[non_exhaustive]
pub enum SubscriptionKind {
/// Get notifications based on Subscription ID.
Subscription(SubscriptionId),
Subscription(SubscriptionId<'static>),
/// Get notifications based on method name.
Method(String),
}
Expand Down Expand Up @@ -115,7 +115,7 @@ pub struct SubscriptionMessage {
/// If the subscription succeeds, we return a [`mpsc::Receiver`] that will receive notifications.
/// When we get a response from the server about that subscription, we send the result over
/// this channel.
pub send_back: oneshot::Sender<Result<(mpsc::Receiver<JsonValue>, SubscriptionId), Error>>,
pub send_back: oneshot::Sender<Result<(mpsc::Receiver<JsonValue>, SubscriptionId<'static>), Error>>,
}

/// RegisterNotification message.
Expand Down Expand Up @@ -149,7 +149,7 @@ pub enum FrontToBack {
// NOTE: It is not possible to cancel pending subscriptions or pending requests.
// Such operations will be blocked until a response is received or the background
// thread has been terminated.
SubscriptionClosed(SubscriptionId),
SubscriptionClosed(SubscriptionId<'static>),
}

impl<Notif> Subscription<Notif>
Expand Down
41 changes: 36 additions & 5 deletions types/src/v2/params.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ use serde::de::{self, Deserializer, Unexpected, Visitor};
use serde::ser::Serializer;
use serde::{Deserialize, Serialize};
use serde_json::Value as JsonValue;
use std::fmt;
use std::{convert::TryFrom, fmt};

/// JSON-RPC v2 marker type.
#[derive(Clone, Copy, Debug, Default, PartialEq)]
Expand Down Expand Up @@ -288,18 +288,49 @@ impl<'a> From<&'a [JsonValue]> for ParamsSer<'a> {
#[derive(Debug, PartialEq, Clone, Hash, Eq, Deserialize, Serialize)]
#[serde(deny_unknown_fields)]
#[serde(untagged)]
pub enum SubscriptionId {
pub enum SubscriptionId<'a> {
/// Numeric id
Num(u64),
/// String id
Str(String),
#[serde(borrow)]
Str(Cow<'a, str>),
}

impl From<SubscriptionId> for JsonValue {
impl<'a> From<SubscriptionId<'a>> for JsonValue {
fn from(sub_id: SubscriptionId) -> Self {
match sub_id {
SubscriptionId::Num(n) => n.into(),
SubscriptionId::Str(s) => s.into(),
SubscriptionId::Str(s) => s.into_owned().into(),
}
}
}

impl<'a> TryFrom<JsonValue> for SubscriptionId<'a> {
type Error = ();

fn try_from(json: JsonValue) -> Result<SubscriptionId<'a>, ()> {
match json {
JsonValue::String(s) => Ok(SubscriptionId::Str(s.into())),
JsonValue::Number(n) => {
if let Some(n) = n.as_u64() {
Ok(SubscriptionId::Num(n))
} else {
Err(())
}
}
_ => Err(()),
}
}
}

impl<'a> SubscriptionId<'a> {
/// Convert `SubscriptionId<'a>` to `SubscriptionId<'static>` so that it can be moved across threads.
///
/// This can cause an allocation if the id is a string.
pub fn into_owned(self) -> SubscriptionId<'static> {
match self {
SubscriptionId::Num(num) => SubscriptionId::Num(num),
SubscriptionId::Str(s) => SubscriptionId::Str(Cow::owned(s.into_owned())),
}
}
}
Expand Down
7 changes: 4 additions & 3 deletions types/src/v2/response.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,16 @@ impl<'a, T> Response<'a, T> {

/// Return value for subscriptions.
#[derive(Serialize, Deserialize, Debug)]
pub struct SubscriptionPayload<T> {
pub struct SubscriptionPayload<'a, T> {
/// Subscription ID
pub subscription: SubscriptionId,
#[serde(borrow)]
pub subscription: SubscriptionId<'a>,
/// Result.
pub result: T,
}

/// Subscription response object, embedding a [`SubscriptionPayload`] in the `params` member.
pub type SubscriptionResponse<'a, T> = Notification<'a, SubscriptionPayload<T>>;
pub type SubscriptionResponse<'a, T> = Notification<'a, SubscriptionPayload<'a, T>>;

#[cfg(test)]
mod tests {
Expand Down
4 changes: 2 additions & 2 deletions utils/src/server/rpc_module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -827,11 +827,11 @@ impl TestSubscription {
/// # Panics
///
/// If the decoding the value as `T` fails.
pub async fn next<T: DeserializeOwned>(&mut self) -> Option<(T, jsonrpsee_types::v2::SubscriptionId)> {
pub async fn next<T: DeserializeOwned>(&mut self) -> Option<(T, jsonrpsee_types::v2::SubscriptionId<'static>)> {
let raw = self.rx.next().await?;
let val: SubscriptionResponse<T> =
serde_json::from_str(&raw).expect("valid response in TestSubscription::next()");
Some((val.params.result, val.params.subscription))
Some((val.params.result, val.params.subscription.into_owned()))
}
}

Expand Down
1 change: 1 addition & 0 deletions ws-client/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ documentation = "https://docs.rs/jsonrpsee-ws-client"

[dependencies]
async-trait = "0.1"
beef = "0.5.1"
rustc-hash = "1"
futures = { version = "0.3.14", default-features = false, features = ["std"] }
http = "0.2"
Expand Down
10 changes: 6 additions & 4 deletions ws-client/src/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ use crate::types::v2::{
use crate::types::{Error, RequestMessage};
use futures::channel::{mpsc, oneshot};
use serde_json::Value as JsonValue;
use std::convert::TryInto;
use std::time::Duration;

/// Attempts to process a batch response.
Expand Down Expand Up @@ -75,7 +76,7 @@ pub fn process_subscription_response(
manager: &mut RequestManager,
response: SubscriptionResponse<JsonValue>,
) -> Result<(), Option<RequestMessage>> {
let sub_id = response.params.subscription;
let sub_id = response.params.subscription.into_owned();
let request_id = match manager.get_request_id_by_subscription_id(&sub_id) {
Some(request_id) => request_id,
None => return Err(None),
Expand Down Expand Up @@ -144,8 +145,9 @@ pub fn process_single_response(
let (unsub_id, send_back_oneshot, unsubscribe_method) =
manager.complete_pending_subscription(response_id).ok_or(Error::InvalidRequestId)?;

let sub_id: SubscriptionId = match serde_json::from_value(response.result) {
Ok(sub_id) => sub_id,
let sub_id: Result<SubscriptionId, _> = response.result.try_into();
let sub_id = match sub_id {
Ok(sub_id) => sub_id.into_owned(),
Err(_) => {
let _ = send_back_oneshot.send(Err(Error::InvalidSubscriptionId));
return Ok(None);
Expand Down Expand Up @@ -185,7 +187,7 @@ pub async fn stop_subscription(sender: &mut WsSender, manager: &mut RequestManag
pub fn build_unsubscribe_message(
manager: &mut RequestManager,
sub_req_id: u64,
sub_id: SubscriptionId,
sub_id: SubscriptionId<'static>,
) -> Option<RequestMessage> {
let (unsub_req_id, _, unsub, sub_id) = manager.remove_subscription(sub_req_id, sub_id)?;
let sub_id_slice: &[JsonValue] = &[sub_id.into()];
Expand Down
18 changes: 6 additions & 12 deletions ws-client/src/manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ pub enum RequestStatus {

type PendingCallOneshot = Option<oneshot::Sender<Result<JsonValue, Error>>>;
type PendingBatchOneshot = oneshot::Sender<Result<Vec<JsonValue>, Error>>;
type PendingSubscriptionOneshot = oneshot::Sender<Result<(mpsc::Receiver<JsonValue>, SubscriptionId), Error>>;
type PendingSubscriptionOneshot = oneshot::Sender<Result<(mpsc::Receiver<JsonValue>, SubscriptionId<'static>), Error>>;
type SubscriptionSink = mpsc::Sender<JsonValue>;
type UnsubscribeMethod = String;
type RequestId = u64;
Expand All @@ -82,7 +82,7 @@ pub struct RequestManager {
requests: FxHashMap<RequestId, Kind>,
/// Reverse lookup, to find a request ID in constant time by `subscription ID` instead of looking through all
/// requests.
subscriptions: HashMap<SubscriptionId, RequestId>,
subscriptions: HashMap<SubscriptionId<'static>, RequestId>,
/// Pending batch requests
batches: FxHashMap<Vec<RequestId>, BatchState>,
/// Registered Methods for incoming notifications
Expand Down Expand Up @@ -161,7 +161,7 @@ impl RequestManager {
&mut self,
sub_req_id: RequestId,
unsub_req_id: RequestId,
subscription_id: SubscriptionId,
subscription_id: SubscriptionId<'static>,
send_back: SubscriptionSink,
unsubscribe_method: UnsubscribeMethod,
) -> Result<(), SubscriptionSink> {
Expand Down Expand Up @@ -251,7 +251,7 @@ impl RequestManager {
pub fn remove_subscription(
&mut self,
request_id: RequestId,
subscription_id: SubscriptionId,
subscription_id: SubscriptionId<'static>,
) -> Option<(RequestId, SubscriptionSink, UnsubscribeMethod, SubscriptionId)> {
match (self.requests.entry(request_id), self.subscriptions.entry(subscription_id)) {
(Entry::Occupied(request), Entry::Occupied(subscription))
Expand Down Expand Up @@ -329,17 +329,11 @@ mod tests {
let (unsub_req_id, _send_back_oneshot, unsubscribe_method) = manager.complete_pending_subscription(1).unwrap();
assert_eq!(unsub_req_id, 2);
assert!(manager
.insert_subscription(
1,
2,
SubscriptionId::Str("uniq_id_from_server".to_string()),
sub_tx,
unsubscribe_method
)
.insert_subscription(1, 2, SubscriptionId::Str("uniq_id_from_server".into()), sub_tx, unsubscribe_method)
.is_ok());

assert!(manager.as_subscription_mut(&1).is_some());
assert!(manager.remove_subscription(1, SubscriptionId::Str("uniq_id_from_server".to_string())).is_some());
assert!(manager.remove_subscription(1, SubscriptionId::Str("uniq_id_from_server".into())).is_some());
}

#[test]
Expand Down
2 changes: 1 addition & 1 deletion ws-client/src/transport.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,13 @@
// DEALINGS IN THE SOFTWARE.

use crate::{stream::EitherStream, types::CertificateStore};
use beef::Cow;
use futures::io::{BufReader, BufWriter};
use http::Uri;
use soketto::connection;
use soketto::handshake::client::{Client as WsHandshakeClient, Header, ServerResponse};
use std::convert::TryInto;
use std::{
borrow::Cow,
convert::TryFrom,
io,
net::{SocketAddr, ToSocketAddrs},
Expand Down

0 comments on commit 59925c0

Please sign in to comment.