diff --git a/benches/helpers.rs b/benches/helpers.rs index 7935650d77..1ed7bc80d3 100644 --- a/benches/helpers.rs +++ b/benches/helpers.rs @@ -132,7 +132,7 @@ pub async fn http_server(handle: tokio::runtime::Handle) -> (String, jsonrpsee:: /// Run jsonrpsee WebSocket server for benchmarks. #[cfg(not(feature = "jsonrpc-crate"))] pub async fn ws_server(handle: tokio::runtime::Handle) -> (String, jsonrpsee::server::ServerHandle) { - use jsonrpsee::server::ServerBuilder; + use jsonrpsee::{core::server::rpc_module::SubscriptionMessage, server::ServerBuilder}; let server = ServerBuilder::default() .max_request_body_size(u32::MAX) @@ -146,11 +146,17 @@ pub async fn ws_server(handle: tokio::runtime::Handle) -> (String, jsonrpsee::se let mut module = gen_rpc_module(); module - .register_subscription(SUB_METHOD_NAME, SUB_METHOD_NAME, UNSUB_METHOD_NAME, |_params, mut sink, _ctx| { - let x = "Hello"; - tokio::spawn(async move { sink.send(&x) }); - Ok(()) - }) + .register_subscription( + SUB_METHOD_NAME, + SUB_METHOD_NAME, + UNSUB_METHOD_NAME, + |_params, pending, _ctx| async move { + let sink = pending.accept().await?; + let msg = SubscriptionMessage::from_json(&"Hello")?; + sink.send(msg).await?; + Ok(()) + }, + ) .unwrap(); let addr = format!("ws://{}", server.local_addr().unwrap()); diff --git a/core/Cargo.toml b/core/Cargo.toml index e1234f94ae..257f3d9ada 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -27,7 +27,6 @@ soketto = { version = "0.7.1", optional = true } parking_lot = { version = "0.12", optional = true } tokio = { version = "1.16", optional = true } wasm-bindgen-futures = { version = "0.4.19", optional = true } -futures-channel = { version = "0.3.14", optional = true } futures-timer = { version = "3", optional = true } globset = { version = "0.4", optional = true } http = { version = "0.2.7", optional = true } @@ -45,7 +44,8 @@ server = [ "rand", "tokio/rt", "tokio/sync", - "futures-channel", + "tokio/macros", + "tokio/time", ] client = ["futures-util/sink", "tokio/sync"] async-client = [ diff --git a/core/src/client/async_client/mod.rs b/core/src/client/async_client/mod.rs index 63fd0aa016..874cdaef87 100644 --- a/core/src/client/async_client/mod.rs +++ b/core/src/client/async_client/mod.rs @@ -40,7 +40,7 @@ use tracing::instrument; use super::{generate_batch_id_range, FrontToBack, IdKind, RequestIdManager}; -/// Wrapper over a [`oneshot::Receiver`](futures_channel::oneshot::Receiver) that reads +/// Wrapper over a [`oneshot::Receiver`](tokio::sync::oneshot::Receiver) that reads /// the underlying channel once and then stores the result in String. /// It is possible that the error is read more than once if several calls are made /// when the background thread has been terminated. diff --git a/core/src/error.rs b/core/src/error.rs index 0dd1ab55ce..9ffd7360be 100644 --- a/core/src/error.rs +++ b/core/src/error.rs @@ -27,8 +27,7 @@ use std::fmt; use jsonrpsee_types::error::{ - CallError, ErrorObject, ErrorObjectOwned, CALL_EXECUTION_FAILED_CODE, INVALID_PARAMS_CODE, SUBSCRIPTION_CLOSED, - UNKNOWN_ERROR_CODE, + CallError, ErrorObject, ErrorObjectOwned, CALL_EXECUTION_FAILED_CODE, INVALID_PARAMS_CODE, UNKNOWN_ERROR_CODE, }; /// Convenience type for displaying errors. @@ -108,21 +107,6 @@ pub enum Error { /// Access control verification of HTTP headers failed. #[error("HTTP header: `{0}` value: `{1}` verification failed")] HttpHeaderRejected(&'static str, String), - /// Failed to execute a method because a resource was already at capacity - #[error("Resource at capacity: {0}")] - ResourceAtCapacity(&'static str), - /// Failed to register a resource due to a name conflict - #[error("Resource name already taken: {0}")] - ResourceNameAlreadyTaken(&'static str), - /// Failed to initialize resources for a method at startup - #[error("Resource name `{0}` not found for method `{1}`")] - ResourceNameNotFoundForMethod(&'static str, &'static str), - /// Trying to claim resources for a method execution, but the method resources have not been initialized - #[error("Method `{0}` has uninitialized resources")] - UninitializedMethod(Box), - /// Failed to register a resource due to a maximum number of resources already registered - #[error("Maximum number of resources reached")] - MaxResourcesReached, /// Custom error. #[error("Custom error: {0}")] Custom(String), @@ -161,34 +145,6 @@ impl From for ErrorObjectOwned { } } -/// A type to represent when a subscription gets closed -/// by either the server or client side. -#[derive(Clone, Debug)] -pub enum SubscriptionClosed { - /// The remote peer closed the connection or called the unsubscribe method. - RemotePeerAborted, - /// The subscription was completed successfully by the server. - Success, - /// The subscription failed during execution by the server. - Failed(ErrorObject<'static>), -} - -impl From for ErrorObjectOwned { - fn from(err: SubscriptionClosed) -> Self { - match err { - SubscriptionClosed::RemotePeerAborted => { - ErrorObject::owned(SUBSCRIPTION_CLOSED, "Subscription was closed by the remote peer", None::<()>) - } - SubscriptionClosed::Success => ErrorObject::owned( - SUBSCRIPTION_CLOSED, - "Subscription was completed by the server successfully", - None::<()>, - ), - SubscriptionClosed::Failed(err) => err, - } - } -} - /// Generic transport error. #[derive(Debug, thiserror::Error)] pub enum GenericTransportError { @@ -229,3 +185,79 @@ impl From for Error { Error::Transport(hyper_err.into()) } } + +/// The error returned by the subscription's method for the rpc server implementation. +/// +/// It provides an abstraction to make the API more ergonomic while handling errors +/// that may occur during the subscription callback. +#[derive(Debug)] +pub enum SubscriptionCallbackError { + /// Error cause is propagated by other code or connection related. + None, + /// Some error happened to be logged. + Some(String), +} + +// User defined error. +impl From for SubscriptionCallbackError { + fn from(e: anyhow::Error) -> Self { + Self::Some(format!("Other: {e}")) + } +} + +// User defined error. +impl From> for SubscriptionCallbackError { + fn from(e: Box) -> Self { + Self::Some(format!("Other: {e}")) + } +} + +impl From for SubscriptionCallbackError { + fn from(e: CallError) -> Self { + Self::Some(e.to_string()) + } +} + +impl From for SubscriptionCallbackError { + fn from(_: SubscriptionAcceptRejectError) -> Self { + Self::None + } +} + +impl From for SubscriptionCallbackError { + fn from(e: serde_json::Error) -> Self { + Self::Some(format!("Failed to parse SubscriptionMessage::from_json: {e}")) + } +} + +#[cfg(feature = "server")] +impl From for SubscriptionCallbackError { + fn from(e: crate::server::rpc_module::TrySendError) -> Self { + Self::Some(format!("SubscriptionSink::try_send failed: {e}")) + } +} + +#[cfg(feature = "server")] +impl From for SubscriptionCallbackError { + fn from(e: crate::server::rpc_module::DisconnectError) -> Self { + Self::Some(format!("SubscriptionSink::send failed: {e}")) + } +} + +#[cfg(feature = "server")] +impl From for SubscriptionCallbackError { + fn from(e: crate::server::rpc_module::SendTimeoutError) -> Self { + Self::Some(format!("SubscriptionSink::send_timeout failed: {e}")) + } +} + +/// The error returned while accepting or rejecting a subscription. +#[derive(Debug, Copy, Clone)] +pub enum SubscriptionAcceptRejectError { + /// The method was already called. + AlreadyCalled, + /// The remote peer closed the connection or called the unsubscribe method. + RemotePeerAborted, + /// The subscription response message was too large. + MessageTooLarge, +} diff --git a/core/src/lib.rs b/core/src/lib.rs index 7a626aad04..96919676f7 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -58,11 +58,19 @@ cfg_client! { /// Shared tracing helpers to trace RPC calls. pub mod tracing; pub use async_trait::async_trait; -pub use error::Error; +pub use error::{Error, SubscriptionAcceptRejectError, SubscriptionCallbackError}; /// JSON-RPC result. pub type RpcResult = std::result::Result; +/// The return type of the subscription's method for the rpc server implementation. +/// +/// **Note**: The error does not contain any data and is discarded on drop. +pub type SubscriptionResult = Result<(), SubscriptionCallbackError>; + +/// Empty server `RpcParams` type to use while registering modules. +pub type EmptyServerParams = Vec<()>; + /// Re-exports for proc-macro library to not require any additional /// dependencies to be explicitly added on the client side. #[doc(hidden)] diff --git a/core/src/server/helpers.rs b/core/src/server/helpers.rs index 4b32be19f1..92f166faa2 100644 --- a/core/src/server/helpers.rs +++ b/core/src/server/helpers.rs @@ -26,15 +26,18 @@ use std::io; use std::sync::Arc; +use std::time::Duration; use crate::tracing::tx_log_from_str; use crate::Error; -use futures_channel::mpsc; use jsonrpsee_types::error::{ErrorCode, ErrorObject, ErrorResponse, OVERSIZED_RESPONSE_CODE, OVERSIZED_RESPONSE_MSG}; use jsonrpsee_types::{Id, InvalidRequest, Response}; use serde::Serialize; +use tokio::sync::mpsc::{self, Permit}; use tokio::sync::{Notify, OwnedSemaphorePermit, Semaphore}; +use super::rpc_module::{DisconnectError, SendTimeoutError, SubscriptionMessage, TrySendError}; + /// Bounded writer that allows writing at most `max_len` bytes. /// /// ``` @@ -84,7 +87,7 @@ impl<'a> io::Write for &'a mut BoundedWriter { #[derive(Clone, Debug)] pub struct MethodSink { /// Channel sender. - tx: mpsc::UnboundedSender, + tx: mpsc::Sender, /// Max response size in bytes for a executed call. max_response_size: u32, /// Max log length. @@ -93,12 +96,12 @@ pub struct MethodSink { impl MethodSink { /// Create a new `MethodSink` with unlimited response size. - pub fn new(tx: mpsc::UnboundedSender) -> Self { + pub fn new(tx: mpsc::Sender) -> Self { MethodSink { tx, max_response_size: u32::MAX, max_log_length: u32::MAX } } /// Create a new `MethodSink` with a limited response size. - pub fn new_with_limit(tx: mpsc::UnboundedSender, max_response_size: u32, max_log_length: u32) -> Self { + pub fn new_with_limit(tx: mpsc::Sender, max_response_size: u32, max_log_length: u32) -> Self { MethodSink { tx, max_response_size, max_log_length } } @@ -107,47 +110,75 @@ impl MethodSink { self.tx.is_closed() } - /// Send a JSON-RPC error to the client - pub fn send_error(&self, id: Id, error: ErrorObject) -> bool { - let json = match serde_json::to_string(&ErrorResponse::borrowed(error, id)) { - Ok(json) => json, - Err(err) => { - tracing::error!("Error serializing response: {:?}", err); + /// Same as [`tokio::sync::mpsc::Sender::closed`]. + /// + /// # Cancel safety + /// This method is cancel safe. Once the channel is closed, + /// it stays closed forever and all future calls to closed will return immediately. + pub async fn closed(&self) { + self.tx.closed().await + } - return false; - } - }; + /// Get the max response size. + pub const fn max_response_size(&self) -> u32 { + self.max_response_size + } - tx_log_from_str(&json, self.max_log_length); + /// Attempts to send out the message immediately and fails if the underlying + /// connection has been closed or if the message buffer is full. + /// + /// Returns the message if the send fails such that either can be thrown away or re-sent later. + pub fn try_send(&mut self, msg: String) -> Result<(), TrySendError> { + tx_log_from_str(&msg, self.max_log_length); + self.tx.try_send(msg).map_err(Into::into) + } - if let Err(err) = self.send_raw(json) { - tracing::warn!("Error sending response {:?}", err); - false - } else { - true + /// Async send which will wait until there is space in channel buffer or that the subscription is disconnected. + pub async fn send(&self, msg: String) -> Result<(), DisconnectError> { + tx_log_from_str(&msg, self.max_log_length); + self.tx.send(msg).await.map_err(Into::into) + } + + /// Similar to to `MethodSink::send` but only waits for a limited time. + pub async fn send_timeout(&self, msg: String, timeout: Duration) -> Result<(), SendTimeoutError> { + tx_log_from_str(&msg, self.max_log_length); + self.tx.send_timeout(msg, timeout).await.map_err(Into::into) + } + + /// Waits for channel capacity. Once capacity to send one message is available, it is reserved for the caller. + pub async fn reserve(&self) -> Result { + match self.tx.reserve().await { + Ok(permit) => Ok(MethodSinkPermit { tx: permit, max_log_length: self.max_log_length }), + Err(_) => Err(DisconnectError(SubscriptionMessage::empty())), } } +} + +/// A method sink with reserved spot in the bounded queue. +#[derive(Debug)] +pub struct MethodSinkPermit<'a> { + tx: Permit<'a, String>, + max_log_length: u32, +} + +impl<'a> MethodSinkPermit<'a> { + /// Send a JSON-RPC error to the client + pub fn send_error(self, id: Id, error: ErrorObject) { + let json = serde_json::to_string(&ErrorResponse::borrowed(error, id)).expect("valid JSON; qed"); + + self.send_raw(json) + } /// Helper for sending the general purpose `Error` as a JSON-RPC errors to the client. - pub fn send_call_error(&self, id: Id, err: Error) -> bool { + pub fn send_call_error(self, id: Id, err: Error) { self.send_error(id, err.into()) } - /// Send a raw JSON-RPC message to the client, `MethodSink` does not check verify the validity + /// Send a raw JSON-RPC message to the client, `MethodSink` does not check the validity /// of the JSON being sent. - pub fn send_raw(&self, json: String) -> Result<(), mpsc::TrySendError> { + pub fn send_raw(self, json: String) { + self.tx.send(json.clone()); tx_log_from_str(&json, self.max_log_length); - self.tx.unbounded_send(json) - } - - /// Close the channel for any further messages. - pub fn close(&self) { - self.tx.close_channel(); - } - - /// Get the maximum number of permitted subscriptions. - pub const fn max_response_size(&self) -> u32 { - self.max_response_size } } @@ -331,7 +362,6 @@ impl BatchResponse { #[cfg(test)] mod tests { - use crate::server::helpers::BoundedSubscriptions; use super::{BatchResponseBuilder, BoundedWriter, Id, MethodResponse, Response}; @@ -351,20 +381,6 @@ mod tests { assert!(serde_json::to_writer(&mut writer, &"x".repeat(99)).is_err()); } - #[test] - fn bounded_subscriptions_work() { - let subs = BoundedSubscriptions::new(5); - let mut handles = Vec::new(); - - for _ in 0..5 { - handles.push(subs.acquire().unwrap()); - } - - assert!(subs.acquire().is_none()); - handles.swap_remove(0); - assert!(subs.acquire().is_some()); - } - #[test] fn batch_with_single_works() { let method = MethodResponse::response(Id::Number(1), "a", usize::MAX); diff --git a/core/src/server/mod.rs b/core/src/server/mod.rs index b761a8db23..1e5a748076 100644 --- a/core/src/server/mod.rs +++ b/core/src/server/mod.rs @@ -30,7 +30,5 @@ pub mod helpers; /// Host filtering. pub mod host_filtering; -/// Resource limiting. Create generic "resources" and configure their limits to ensure servers are not overloaded. -pub mod resource_limiting; /// JSON-RPC "modules" group sets of methods that belong together and handles method/subscription registration. pub mod rpc_module; diff --git a/core/src/server/resource_limiting.rs b/core/src/server/resource_limiting.rs deleted file mode 100644 index e50398b82c..0000000000 --- a/core/src/server/resource_limiting.rs +++ /dev/null @@ -1,175 +0,0 @@ -// Copyright 2019-2021 Parity Technologies (UK) Ltd. -// -// Permission is hereby granted, free of charge, to any -// person obtaining a copy of this software and associated -// documentation files (the "Software"), to deal in the -// Software without restriction, including without -// limitation the rights to use, copy, modify, merge, -// publish, distribute, sublicense, and/or sell copies of -// the Software, and to permit persons to whom the Software -// is furnished to do so, subject to the following -// conditions: -// -// The above copyright notice and this permission notice -// shall be included in all copies or substantial portions -// of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF -// ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED -// TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A -// PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT -// SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY -// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR -// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -// DEALINGS IN THE SOFTWARE. - -//! # Resource Limiting -//! -//! This module handles limiting the capacity of the server to respond to requests. -//! -//! `jsonrpsee` is agnostic about the types of resources available on the server, and the units used are arbitrary. -//! The units are used to model the availability of a resource, be it something mundane like CPU or Memory, -//! or more exotic things like remote API access to a 3rd party service, or use of some external hardware -//! that's under the control of the server. -//! -//! To get the most out of this feature, we suggest benchmarking individual methods to see how many resources they -//! consume, in particular anything critical that is expected to result in a lot of stress on the server, -//! and then defining your units such that the limits (`capacity`) can be adjusted for different hardware configurations. -//! -//! Up to 8 resources can be defined using the [`ServerBuilder::register_resource`](../../../jsonrpsee_server/struct.ServerBuilder.html#method.register_resource) -//! -//! -//! Each method will claim the specified number of units (or the default) for the duration of its execution. -//! Any method execution that would cause the total sum of claimed resource units to exceed -//! the `capacity` of that resource will be denied execution, immediately returning JSON-RPC error object with code `-32604`. -//! -//! Setting the execution cost to `0` equates to the method effectively not being limited by a given resource. Likewise setting the -//! `capacity` to `0` disables any limiting for a given resource. -//! -//! To specify a different than default number of units a method should use, use the `resources` argument in the -//! `#[method]` attribute: -//! -//! ``` -//! # use jsonrpsee::{core::RpcResult, proc_macros::rpc}; -//! # -//! #[rpc(server)] -//! pub trait Rpc { -//! #[method(name = "my_expensive_method", resources("cpu" = 5, "mem" = 2))] -//! async fn my_expensive_method(&self) -> RpcResult<&'static str> { -//! // Do work -//! Ok("hello") -//! } -//! } -//! ``` -//! -//! Alternatively, you can use the `resource` method when creating a module manually without the help of the macro: -//! -//! ``` -//! # use jsonrpsee::{RpcModule, core::RpcResult}; -//! # -//! # fn main() -> RpcResult<()> { -//! # -//! let mut module = RpcModule::new(()); -//! -//! module -//! .register_async_method("my_expensive_method", |_, _| async move { -//! // Do work -//! Result::<_, jsonrpsee::core::Error>::Ok("hello") -//! })? -//! .resource("cpu", 5)? -//! .resource("mem", 2)?; -//! # Ok(()) -//! # } -//! ``` -//! -//! Each resource needs to have a unique name, such as `"cpu"` or `"memory"`, which can then be used across all -//! [`RpcModule`s](crate::server::rpc_module::RpcModule). In case a module definition uses a resource label not -//! defined on the server, starting the server with such a module will result in a runtime error containing the -//! information about the offending method. - -use std::sync::Arc; - -use crate::Error; -use arrayvec::ArrayVec; -use parking_lot::Mutex; - -// The number of kinds of resources that can be used for limiting. -const RESOURCE_COUNT: usize = 8; - -/// Fixed size table, mapping a resource to a (unitless) value indicating the amount of the resource that is available to RPC calls. -pub type ResourceTable = [u16; RESOURCE_COUNT]; -/// Variable size table, mapping a resource to a (unitless) value indicating the amount of the resource that is available to RPC calls. -pub type ResourceVec = ArrayVec; - -/// User defined resources available to be used by calls on the JSON-RPC server. -/// Each of the 8 possible resource kinds, for instance "cpu", "io", "nanobots", -/// store a maximum `capacity` and a default. A value of `0` means no limits for the given resource. -#[derive(Debug, Default, Clone)] -pub struct Resources { - /// Resources currently in use by executing calls. 0 for unused resource kinds. - totals: Arc>, - /// Max capacity for all resource kinds - pub capacities: ResourceTable, - /// Default value for all resource kinds; unless a method has a resource limit defined, this is the cost of a call (0 means no default limit) - pub defaults: ResourceTable, - /// Labels for every registered resource - pub labels: ResourceVec<&'static str>, -} - -impl Resources { - /// Register a new resource kind. Errors if `label` is already registered, or if the total number of - /// registered resources would exceed 8. - pub fn register(&mut self, label: &'static str, capacity: u16, default: u16) -> Result<(), Error> { - if self.labels.iter().any(|&l| l == label) { - return Err(Error::ResourceNameAlreadyTaken(label)); - } - - let idx = self.labels.len(); - - self.labels.try_push(label).map_err(|_| Error::MaxResourcesReached)?; - - self.capacities[idx] = capacity; - self.defaults[idx] = default; - - Ok(()) - } - - /// Attempt to claim `units` units for each resource, incrementing current totals. - /// If successful, returns a [`ResourceGuard`] which decrements the totals by the same - /// amounts once dropped. - pub fn claim(&self, units: ResourceTable) -> Result { - let mut totals = self.totals.lock(); - let mut sum = *totals; - - for (idx, sum) in sum.iter_mut().enumerate() { - match sum.checked_add(units[idx]) { - Some(s) if s <= self.capacities[idx] => *sum = s, - _ => { - let label = self.labels.get(idx).copied().unwrap_or(""); - - return Err(Error::ResourceAtCapacity(label)); - } - } - } - - *totals = sum; - - Ok(ResourceGuard { totals: self.totals.clone(), units }) - } -} - -/// RAII style "lock" for claimed resources, will automatically release them once dropped. -#[derive(Debug)] -pub struct ResourceGuard { - totals: Arc>, - units: ResourceTable, -} - -impl Drop for ResourceGuard { - fn drop(&mut self) { - for (sum, claimed) in self.totals.lock().iter_mut().zip(self.units) { - *sum -= claimed; - } - } -} diff --git a/core/src/server/rpc_module.rs b/core/src/server/rpc_module.rs index 8a0fff850d..dbe7206e4d 100644 --- a/core/src/server/rpc_module.rs +++ b/core/src/server/rpc_module.rs @@ -29,47 +29,38 @@ use std::fmt::{self, Debug}; use std::future::Future; use std::ops::{Deref, DerefMut}; use std::sync::Arc; +use std::time::Duration; -use crate::error::{Error, SubscriptionClosed}; +use crate::error::{Error, SubscriptionAcceptRejectError}; use crate::id_providers::RandomIntegerIdProvider; -use crate::server::helpers::{BoundedSubscriptions, MethodSink, SubscriptionPermit}; -use crate::server::resource_limiting::{ResourceGuard, ResourceTable, ResourceVec, Resources}; +use crate::server::helpers::{BoundedSubscriptions, MethodSink}; use crate::traits::{IdProvider, ToRpcParams}; -use futures_channel::{mpsc, oneshot}; +use crate::{SubscriptionCallbackError, SubscriptionResult}; use futures_util::future::Either; -use futures_util::pin_mut; -use futures_util::{future::BoxFuture, FutureExt, Stream, StreamExt, TryStream, TryStreamExt}; -use jsonrpsee_types::error::{ - CallError, ErrorCode, ErrorObject, ErrorObjectOwned, SubscriptionAcceptRejectError, INTERNAL_ERROR_CODE, - SUBSCRIPTION_CLOSED_WITH_ERROR, -}; +use futures_util::{future::BoxFuture, FutureExt}; +use jsonrpsee_types::error::{CallError, ErrorCode, ErrorObject, ErrorObjectOwned}; use jsonrpsee_types::response::{SubscriptionError, SubscriptionPayloadError}; use jsonrpsee_types::{ - ErrorResponse, Id, Params, Request, Response, SubscriptionId as RpcSubscriptionId, SubscriptionPayload, - SubscriptionResponse, SubscriptionResult, + ErrorResponse, Id, Params, Request, Response, SubscriptionId as RpcSubscriptionId, SubscriptionResponse, }; use parking_lot::Mutex; use rustc_hash::FxHashMap; use serde::{de::DeserializeOwned, Serialize}; -use tokio::sync::watch; +use tokio::sync::{mpsc, oneshot}; -use super::helpers::MethodResponse; +use super::helpers::{MethodResponse, SubscriptionPermit}; /// A `MethodCallback` is an RPC endpoint, callable with a standard JSON-RPC request, /// implemented as a function pointer to a `Fn` function taking four arguments: /// the `id`, `params`, a channel the function uses to communicate the result (or error) /// back to `jsonrpsee`, and the connection ID (useful for the websocket transport). pub type SyncMethod = Arc MethodResponse>; -/// Similar to [`SyncMethod`], but represents an asynchronous handler and takes an additional argument containing a [`ResourceGuard`] if configured. -pub type AsyncMethod<'a> = Arc< - dyn Send - + Sync - + Fn(Id<'a>, Params<'a>, ConnectionId, MaxResponseSize, Option) -> BoxFuture<'a, MethodResponse>, ->; +/// Similar to [`SyncMethod`], but represents an asynchronous handler. +pub type AsyncMethod<'a> = + Arc, Params<'a>, ConnectionId, MaxResponseSize) -> BoxFuture<'a, MethodResponse>>; /// Method callback for subscriptions. -pub type SubscriptionMethod<'a> = Arc< - dyn Send + Sync + Fn(Id, Params, MethodSink, ConnState, Option) -> BoxFuture<'a, MethodResponse>, ->; +pub type SubscriptionMethod<'a> = + Arc BoxFuture<'a, SubscriptionAnswered>>; // Method callback to unsubscribe. type UnsubscriptionMethod = Arc MethodResponse>; @@ -85,16 +76,77 @@ pub type MaxResponseSize = usize; /// - Call result as a `String`, /// - a [`mpsc::UnboundedReceiver`] to receive future subscription results /// - a [`crate::server::helpers::SubscriptionPermit`] to allow subscribers to notify their [`SubscriptionSink`] when they disconnect. -pub type RawRpcResponse = (MethodResponse, mpsc::UnboundedReceiver, SubscriptionPermit); +pub type RawRpcResponse = (MethodResponse, mpsc::Receiver, SubscriptionPermit, mpsc::Sender); + +/// Error that may occur during `SubscriptionSink::try_send`. +#[derive(Debug)] +pub enum TrySendError { + /// The channel is closed. + Closed(SubscriptionMessage), + /// The channel is full. + Full(SubscriptionMessage), +} + +impl std::fmt::Display for TrySendError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let msg = match self { + Self::Closed(_) => "closed", + Self::Full(_) => "full", + }; + f.write_str(msg) + } +} + +#[derive(Debug, Clone)] +/// Represents whether a subscription was answered or not. +pub enum SubscriptionAnswered { + /// The subscription was already answered and doesn't need to answered again. + /// The response is kept to be logged. + Yes(MethodResponse), + /// The subscription was never answered and needs to be answered. + /// + /// This may occur if a subscription dropped without calling `PendingSubscriptionSink::accept` or `PendingSubscriptionSink::reject`. + No(MethodResponse), +} + +/// Error that may occur during `MethodSink::send` or `SubscriptionSink::send`. +#[derive(Debug)] +pub struct DisconnectError(pub SubscriptionMessage); + +impl std::fmt::Display for DisconnectError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("closed") + } +} + +/// Error that may occur during `SubscriptionSink::send_timeout`. +#[derive(Debug)] +pub enum SendTimeoutError { + /// The data could not be sent because the timeout elapsed + /// which most likely is that the channel is full. + Timeout(SubscriptionMessage), + /// The channel is full. + Closed(SubscriptionMessage), +} + +impl std::fmt::Display for SendTimeoutError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let msg = match self { + Self::Timeout(_) => "timed out waiting on send operation", + Self::Closed(_) => "closed", + }; + f.write_str(msg) + } +} /// Helper struct to manage subscriptions. pub struct ConnState<'a> { /// Connection ID pub conn_id: ConnectionId, - /// Get notified when the connection to subscribers is closed. - pub close_notify: SubscriptionPermit, /// ID provider. pub id_provider: &'a dyn IdProvider, + /// Subscription limit + pub subscription_permit: SubscriptionPermit, } /// Outcome of a successful terminated subscription. @@ -106,13 +158,103 @@ pub enum InnerSubscriptionResult { Aborted, } +impl From> for DisconnectError { + fn from(e: mpsc::error::SendError) -> Self { + DisconnectError(SubscriptionMessage::from_complete_message(e.0)) + } +} + +impl From> for TrySendError { + fn from(e: mpsc::error::TrySendError) -> Self { + match e { + mpsc::error::TrySendError::Closed(m) => Self::Closed(SubscriptionMessage::from_complete_message(m)), + mpsc::error::TrySendError::Full(m) => Self::Full(SubscriptionMessage::from_complete_message(m)), + } + } +} + +impl From> for SendTimeoutError { + fn from(e: mpsc::error::SendTimeoutError) -> Self { + match e { + mpsc::error::SendTimeoutError::Closed(m) => Self::Closed(SubscriptionMessage::from_complete_message(m)), + mpsc::error::SendTimeoutError::Timeout(m) => Self::Timeout(SubscriptionMessage::from_complete_message(m)), + } + } +} + impl<'a> std::fmt::Debug for ConnState<'a> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("ConnState").field("conn_id", &self.conn_id).field("close", &self.close_notify).finish() + f.debug_struct("ConnState").field("conn_id", &self.conn_id).finish() + } +} + +type Subscribers = Arc)>>>; + +/// This represent a response to a RPC call +/// and `Subscribe` calls are handled differently +/// because we want to prevent subscriptions to start +/// before the actual subscription call has been answered. +#[derive(Debug, Clone)] +pub enum CallOrSubscription { + /// The subscription callback itself sends back the result + /// so it must not be sent back again. + Subscription(SubscriptionAnswered), + + /// Treat it as ordinary call. + Call(MethodResponse), +} + +impl CallOrSubscription { + /// Extract the JSON-RPC response. + pub fn as_response(&self) -> &MethodResponse { + match &self { + Self::Subscription(r) => match r { + SubscriptionAnswered::Yes(r) => r, + SubscriptionAnswered::No(r) => r, + }, + Self::Call(r) => r, + } + } + + /// Convert the `CallOrSubscription` to JSON-RPC response. + pub fn into_response(self) -> MethodResponse { + match self { + Self::Subscription(r) => match r { + SubscriptionAnswered::Yes(r) => r, + SubscriptionAnswered::No(r) => r, + }, + Self::Call(r) => r, + } } } -type Subscribers = Arc)>>>; +/// A complete subscription message or partial subscription message. +#[derive(Debug, Clone)] +pub enum SubscriptionMessageInner { + /// Complete JSON message. + Complete(String), + /// Need subscription ID and method name. + NeedsData(String), +} + +/// Subscription message. +#[derive(Debug, Clone)] +pub struct SubscriptionMessage(pub(crate) SubscriptionMessageInner); + +impl SubscriptionMessage { + /// Create a new subscription message from JSON. + pub fn from_json(t: &impl Serialize) -> Result { + serde_json::to_string(t).map(|json| SubscriptionMessage(SubscriptionMessageInner::NeedsData(json))) + } + + pub(crate) fn from_complete_message(msg: String) -> Self { + SubscriptionMessage(SubscriptionMessageInner::Complete(msg)) + } + + pub(crate) fn empty() -> Self { + Self::from_complete_message(String::new()) + } +} /// Represent a unique subscription entry based on [`RpcSubscriptionId`] and [`ConnectionId`]. #[derive(Clone, Debug, PartialEq, Eq, Hash)] @@ -134,21 +276,10 @@ pub enum MethodKind { Unsubscription(UnsubscriptionMethod), } -/// Information about resources the method uses during its execution. Initialized when the the server starts. -#[derive(Clone, Debug)] -enum MethodResources { - /// Uninitialized resource table, mapping string label to units. - Uninitialized(Box<[(&'static str, u16)]>), - /// Initialized resource table containing units for each `ResourceId`. - Initialized(ResourceTable), -} - /// Method callback wrapper that contains a sync or async closure, -/// plus a table with resources it needs to claim to run #[derive(Clone, Debug)] pub struct MethodCallback { callback: MethodKind, - resources: MethodResources, } /// Result of a method, either direct value or a future of one. @@ -168,57 +299,21 @@ impl Debug for MethodResult { } } -/// Builder for configuring resources used by a method. -#[derive(Debug)] -pub struct MethodResourcesBuilder<'a> { - build: ResourceVec<(&'static str, u16)>, - callback: &'a mut MethodCallback, -} - -impl<'a> MethodResourcesBuilder<'a> { - /// Define how many units of a given named resource the method uses during its execution. - pub fn resource(mut self, label: &'static str, units: u16) -> Result { - self.build.try_push((label, units)).map_err(|_| Error::MaxResourcesReached)?; - Ok(self) - } -} - -impl<'a> Drop for MethodResourcesBuilder<'a> { - fn drop(&mut self) { - self.callback.resources = MethodResources::Uninitialized(self.build[..].into()); - } -} - impl MethodCallback { fn new_sync(callback: SyncMethod) -> Self { - MethodCallback { callback: MethodKind::Sync(callback), resources: MethodResources::Uninitialized([].into()) } + MethodCallback { callback: MethodKind::Sync(callback) } } fn new_async(callback: AsyncMethod<'static>) -> Self { - MethodCallback { callback: MethodKind::Async(callback), resources: MethodResources::Uninitialized([].into()) } + MethodCallback { callback: MethodKind::Async(callback) } } fn new_subscription(callback: SubscriptionMethod<'static>) -> Self { - MethodCallback { - callback: MethodKind::Subscription(callback), - resources: MethodResources::Uninitialized([].into()), - } + MethodCallback { callback: MethodKind::Subscription(callback) } } fn new_unsubscription(callback: UnsubscriptionMethod) -> Self { - MethodCallback { - callback: MethodKind::Unsubscription(callback), - resources: MethodResources::Uninitialized([].into()), - } - } - - /// Attempt to claim resources prior to executing a method. On success returns a guard that releases - /// claimed resources when dropped. - pub fn claim(&self, name: &str, resources: &Resources) -> Result { - match self.resources { - MethodResources::Uninitialized(_) => Err(Error::UninitializedMethod(name.into())), - MethodResources::Initialized(units) => resources.claim(units), - } + MethodCallback { callback: MethodKind::Unsubscription(callback) } } /// Get handle to the callback. @@ -271,36 +366,6 @@ impl Methods { } } - /// Initialize resources for all methods in this collection. This method has no effect if called more than once. - pub fn initialize_resources(mut self, resources: &Resources) -> Result { - let callbacks = self.mut_callbacks(); - - for (&method_name, callback) in callbacks.iter_mut() { - if let MethodResources::Uninitialized(uninit) = &callback.resources { - let mut map = resources.defaults; - - for &(label, units) in uninit.iter() { - let idx = match resources.labels.iter().position(|&l| l == label) { - Some(idx) => idx, - None => return Err(Error::ResourceNameNotFoundForMethod(label, method_name)), - }; - - // If resource capacity set to `0`, we ignore the unit value of the method - // and set it to `0` as well, effectively making the resource unlimited. - if resources.capacities[idx] == 0 { - map[idx] = 0; - } else { - map[idx] = units; - } - } - - callback.resources = MethodResources::Initialized(map); - } - } - - Ok(self) - } - /// Helper for obtaining a mut ref to the callbacks HashMap. fn mut_callbacks(&mut self) -> &mut FxHashMap<&'static str, MethodCallback> { Arc::make_mut(&mut self.callbacks) @@ -365,7 +430,7 @@ impl Methods { let params = params.to_rpc_params()?; let req = Request::new(method.into(), params.as_ref().map(|p| p.as_ref()), Id::Number(0)); tracing::trace!("[Methods::call] Method: {:?}, params: {:?}", method, params); - let (resp, _, _) = self.inner_call(req).await; + let (resp, _, _, _) = self.inner_call(req, 1).await; if resp.success { serde_json::from_str::>(&resp.result).map(|r| r.result).map_err(Into::into) @@ -386,18 +451,20 @@ impl Methods { /// ``` /// #[tokio::main] /// async fn main() { - /// use jsonrpsee::RpcModule; + /// use jsonrpsee::{RpcModule, SubscriptionMessage}; /// use jsonrpsee::types::Response; /// use futures_util::StreamExt; /// /// let mut module = RpcModule::new(()); - /// module.register_subscription("hi", "hi", "goodbye", |_, mut sink, _| { - /// sink.send(&"one answer").unwrap(); + /// module.register_subscription("hi", "hi", "goodbye", |_, pending, _| async { + /// let sink = pending.accept().await?; + /// let msg = SubscriptionMessage::from_json(&"one answer")?; + /// sink.send(msg).await?; /// Ok(()) /// }).unwrap(); - /// let (resp, mut stream) = module.raw_json_request(r#"{"jsonrpc":"2.0","method":"hi","id":0}"#).await.unwrap(); + /// let (resp, mut stream) = module.raw_json_request(r#"{"jsonrpc":"2.0","method":"hi","id":0}"#, 1).await.unwrap(); /// let resp = serde_json::from_str::>(&resp.result).unwrap(); - /// let sub_resp = stream.next().await.unwrap(); + /// let sub_resp = stream.recv().await.unwrap(); /// assert_eq!( /// format!(r#"{{"jsonrpc":"2.0","method":"hi","params":{{"subscription":{},"result":"one answer"}}}}"#, resp.result), /// sub_resp @@ -407,92 +474,110 @@ impl Methods { pub async fn raw_json_request( &self, request: &str, - ) -> Result<(MethodResponse, mpsc::UnboundedReceiver), Error> { + buf_size: usize, + ) -> Result<(MethodResponse, mpsc::Receiver), Error> { tracing::trace!("[Methods::raw_json_request] Request: {:?}", request); let req: Request = serde_json::from_str(request)?; - let (resp, rx, _) = self.inner_call(req).await; + let (resp, rx, _, _) = self.inner_call(req, buf_size).await; + Ok((resp, rx)) } /// Execute a callback. - async fn inner_call(&self, req: Request<'_>) -> RawRpcResponse { - let (tx_sink, mut rx_sink) = mpsc::unbounded(); - let sink = MethodSink::new(tx_sink); + async fn inner_call(&self, req: Request<'_>, buf_size: usize) -> RawRpcResponse { + let (tx, mut rx) = mpsc::channel(buf_size); let id = req.id.clone(); let params = Params::new(req.params.map(|params| params.get())); let bounded_subs = BoundedSubscriptions::new(u32::MAX); - let close_notify = bounded_subs.acquire().expect("u32::MAX permits is sufficient; qed"); - let notify = bounded_subs.acquire().expect("u32::MAX permits is sufficient; qed"); + let p1 = bounded_subs.acquire().expect("u32::MAX permits is sufficient; qed"); + let p2 = bounded_subs.acquire().expect("u32::MAX permits is sufficient; qed"); let response = match self.method(&req.method).map(|c| &c.callback) { None => MethodResponse::error(req.id, ErrorObject::from(ErrorCode::MethodNotFound)), Some(MethodKind::Sync(cb)) => (cb)(id, params, usize::MAX), - Some(MethodKind::Async(cb)) => (cb)(id.into_owned(), params.into_owned(), 0, usize::MAX, None).await, + Some(MethodKind::Async(cb)) => (cb)(id.into_owned(), params.into_owned(), 0, usize::MAX).await, Some(MethodKind::Subscription(cb)) => { - let conn_state = ConnState { conn_id: 0, close_notify, id_provider: &RandomIntegerIdProvider }; - let res = (cb)(id, params, sink.clone(), conn_state, None).await; + let conn_state = + ConnState { conn_id: 0, id_provider: &RandomIntegerIdProvider, subscription_permit: p1 }; + let res = (cb)(id, params, MethodSink::new(tx.clone()), conn_state).await; // This message is not used because it's used for metrics so we discard in other to // not read once this is used for subscriptions. // // The same information is part of `res` above. - let _ = rx_sink.next().await.expect("Every call must at least produce one response; qed"); + let _ = rx.recv().await.expect("Every call must at least produce one response; qed"); - res + match res { + SubscriptionAnswered::Yes(r) => r, + SubscriptionAnswered::No(r) => r, + } } Some(MethodKind::Unsubscription(cb)) => (cb)(id, params, 0, usize::MAX), }; tracing::trace!("[Methods::inner_call] Method: {}, response: {:?}", req.method, response); - (response, rx_sink, notify) + (response, rx, p2, tx) } /// Helper to create a subscription on the `RPC module` without having to spin up a server. /// /// The params must be serializable as JSON array, see [`ToRpcParams`] for further documentation. /// - /// Returns [`Subscription`] on success which can used to get results from the subscriptions. + /// Returns [`Subscription`] on success which can used to get results from the subscription. /// /// # Examples /// /// ``` /// #[tokio::main] /// async fn main() { - /// use jsonrpsee::{RpcModule, types::EmptyServerParams}; + /// use jsonrpsee::{RpcModule, core::EmptyServerParams, SubscriptionMessage}; /// /// let mut module = RpcModule::new(()); - /// module.register_subscription("hi", "hi", "goodbye", |_, mut sink, _| { - /// sink.send(&"one answer").unwrap(); + /// module.register_subscription("hi", "hi", "goodbye", |_, pending, _| async move { + /// let sink = pending.accept().await?; + /// + /// let msg = SubscriptionMessage::from_json(&"one answer")?; + /// sink.send(msg).await?; /// Ok(()) + /// /// }).unwrap(); /// - /// let mut sub = module.subscribe("hi", EmptyServerParams::new()).await.unwrap(); + /// let mut sub = module.subscribe_unbounded("hi", EmptyServerParams::new()).await.unwrap(); /// // In this case we ignore the subscription ID, /// let (sub_resp, _sub_id) = sub.next::().await.unwrap().unwrap(); /// assert_eq!(&sub_resp, "one answer"); /// } /// ``` - pub async fn subscribe(&self, sub_method: &str, params: impl ToRpcParams) -> Result { + pub async fn subscribe_unbounded(&self, sub_method: &str, params: impl ToRpcParams) -> Result { + self.subscribe(sub_method, params, u32::MAX as usize).await + } + + /// Similar to [`Methods::subscribe_unbounded`] but it's using a bounded channel and the buffer capacity must be provided. + pub async fn subscribe( + &self, + sub_method: &str, + params: impl ToRpcParams, + buf_size: usize, + ) -> Result { let params = params.to_rpc_params()?; let req = Request::new(sub_method.into(), params.as_ref().map(|p| p.as_ref()), Id::Number(0)); tracing::trace!("[Methods::subscribe] Method: {}, params: {:?}", sub_method, params); - let (response, rx, close_notify) = self.inner_call(req).await; + let (resp, rx, permit, tx) = self.inner_call(req, buf_size).await; - let subscription_response = match serde_json::from_str::>(&response.result) { + let subscription_response = match serde_json::from_str::>(&resp.result) { Ok(r) => r, - Err(_) => match serde_json::from_str::(&response.result) { + Err(_) => match serde_json::from_str::(&resp.result) { Ok(err) => return Err(Error::Call(CallError::Custom(err.error_object().clone().into_owned()))), Err(err) => return Err(err.into()), }, }; let sub_id = subscription_response.result.into_owned(); - let close_notify = Some(close_notify); - Ok(Subscription { sub_id, rx, close_notify }) + Ok(Subscription { sub_id, rx, tx: MethodSink::new(tx), _permit: permit }) } /// Returns an `Iterator` with all the method names registered on this server. @@ -550,22 +635,20 @@ impl RpcModule { &mut self, method_name: &'static str, callback: F, - ) -> Result + ) -> Result<&mut MethodCallback, Error> where Context: Send + Sync + 'static, R: Serialize, F: Fn(Params, &Context) -> Result + Send + Sync + 'static, { let ctx = self.ctx.clone(); - let callback = self.methods.verify_and_insert( + self.methods.verify_and_insert( method_name, MethodCallback::new_sync(Arc::new(move |id, params, max_response_size| match callback(params, &*ctx) { Ok(res) => MethodResponse::response(id, res, max_response_size), Err(err) => MethodResponse::error(id, err), })), - )?; - - Ok(MethodResourcesBuilder { build: ResourceVec::new(), callback }) + ) } /// Register a new asynchronous RPC method, which computes the response with the given callback. @@ -573,7 +656,7 @@ impl RpcModule { &mut self, method_name: &'static str, callback: Fun, - ) -> Result + ) -> Result<&mut MethodCallback, Error> where R: Serialize + Send + Sync + 'static, E: Into, @@ -581,28 +664,21 @@ impl RpcModule { Fun: (Fn(Params<'static>, Arc) -> Fut) + Clone + Send + Sync + 'static, { let ctx = self.ctx.clone(); - let callback = self.methods.verify_and_insert( + self.methods.verify_and_insert( method_name, - MethodCallback::new_async(Arc::new(move |id, params, _, max_response_size, claimed| { + MethodCallback::new_async(Arc::new(move |id, params, _, max_response_size| { let ctx = ctx.clone(); let callback = callback.clone(); let future = async move { - let result = match callback(params, ctx).await { + match callback(params, ctx).await { Ok(res) => MethodResponse::response(id, res, max_response_size), Err(err) => MethodResponse::error(id, err.into()), - }; - - // Release claimed resources - drop(claimed); - - result + } }; future.boxed() })), - )?; - - Ok(MethodResourcesBuilder { build: ResourceVec::new(), callback }) + ) } /// Register a new **blocking** synchronous RPC method, which computes the response with the given callback. @@ -611,7 +687,7 @@ impl RpcModule { &mut self, method_name: &'static str, callback: F, - ) -> Result + ) -> Result<&mut MethodCallback, Error> where Context: Send + Sync + 'static, R: Serialize, @@ -621,20 +697,13 @@ impl RpcModule { let ctx = self.ctx.clone(); let callback = self.methods.verify_and_insert( method_name, - MethodCallback::new_async(Arc::new(move |id, params, _, max_response_size, claimed| { + MethodCallback::new_async(Arc::new(move |id, params, _, max_response_size| { let ctx = ctx.clone(); let callback = callback.clone(); - tokio::task::spawn_blocking(move || { - let result = match callback(params, ctx) { - Ok(result) => MethodResponse::response(id, result, max_response_size), - Err(err) => MethodResponse::error(id, err.into()), - }; - - // Release claimed resources - drop(claimed); - - result + tokio::task::spawn_blocking(move || match callback(params, ctx) { + Ok(result) => MethodResponse::response(id, result, max_response_size), + Err(err) => MethodResponse::error(id, err.into()), }) .map(|result| match result { Ok(r) => r, @@ -647,7 +716,7 @@ impl RpcModule { })), )?; - Ok(MethodResourcesBuilder { build: ResourceVec::new(), callback }) + Ok(callback) } /// Register a new publish/subscribe interface using JSON-RPC notifications. @@ -657,7 +726,7 @@ impl RpcModule { /// /// Furthermore, it generates the `unsubscribe implementation` where a `bool` is used as /// the result to indicate whether the subscription was successfully unsubscribed to or not. - /// For instance an `unsubscribe call` may fail if a non-existent subscriptionID is used in the call. + /// For instance an `unsubscribe call` may fail if a non-existent subscription ID is used in the call. /// /// This method ensures that the `subscription_method_name` and `unsubscription_method_name` are unique. /// The `notif_method_name` argument sets the content of the `method` field in the JSON document that @@ -671,45 +740,54 @@ impl RpcModule { /// * `unsubscription_method` - name of the method to call to terminate a subscription /// * `callback` - A callback to invoke on each subscription; it takes three parameters: /// - [`Params`]: JSON-RPC parameters in the subscription call. - /// - [`SubscriptionSink`]: A sink to send messages to the subscriber. + /// - [`PendingSubscriptionSink`]: A pending subscription waiting to be accepted, in order to send out messages on the subscription /// - Context: Any type that can be embedded into the [`RpcModule`]. + /// + /// # Returns + /// + /// An async block which returns `Result<(), SubscriptionCallbackError>` the error is simply + /// for a more ergonomic API and is not used (except logged for user-related caused errors). + /// By default jsonrpsee doesn't send any special close notification, + /// it can be a footgun if one wants to send out a "special notification" to indicate that an error occurred. + /// + /// If you want to a special error notification use `SubscriptionSink::close` or + /// `SubscriptionSink::send` before returning from the async block. /// /// # Examples /// /// ```no_run /// - /// use jsonrpsee_core::server::rpc_module::{RpcModule, SubscriptionSink}; + /// use jsonrpsee_core::server::rpc_module::{RpcModule, SubscriptionSink, SubscriptionMessage}; /// use jsonrpsee_core::Error; /// /// let mut ctx = RpcModule::new(99_usize); - /// ctx.register_subscription("sub", "notif_name", "unsub", |params, mut sink, ctx| { - /// let x = match params.one::() { - /// Ok(x) => x, - /// Err(e) => { - /// let err: Error = e.into(); - /// sink.reject(err); - /// return Ok(()); - /// } - /// }; - /// // Sink is accepted on the first `send` call. - /// std::thread::spawn(move || { - /// let sum = x + (*ctx); - /// let _ = sink.send(&sum); - /// }); + /// ctx.register_subscription("sub", "notif_name", "unsub", |params, pending, ctx| async move { + /// let x = params.one::()?; + /// + /// // mark the subscription is accepted after the params has been parsed successful. + /// let sink = pending.accept().await?; + /// + /// let sum = x + (*ctx); + /// + /// // NOTE: the error handling here is for easy of use + /// // and are thrown away + /// let msg = SubscriptionMessage::from_json(&sum)?; + /// sink.send(msg).await?; /// /// Ok(()) /// }); /// ``` - pub fn register_subscription( + pub fn register_subscription( &mut self, subscribe_method_name: &'static str, notif_method_name: &'static str, unsubscribe_method_name: &'static str, callback: F, - ) -> Result + ) -> Result<&mut MethodCallback, Error> where Context: Send + Sync + 'static, - F: Fn(Params, SubscriptionSink, Arc) -> SubscriptionResult + Send + Sync + 'static, + F: (Fn(Params<'static>, PendingSubscriptionSink, Arc) -> Fut) + Send + Sync + Clone + 'static, + Fut: Future + Send + 'static, { if subscribe_method_name == unsubscribe_method_name { return Err(Error::SubscriptionNameConflict(subscribe_method_name.into())); @@ -745,14 +823,13 @@ impl RpcModule { let result = subscribers.lock().remove(&key).is_some(); if !result { - tracing::warn!( + tracing::debug!( "Unsubscribe call `{}` subscription key={:?} not an active subscription", unsubscribe_method_name, key, ); } - // TODO: register as failed in !result. MethodResponse::response(id, result, max_response_size) })), ); @@ -762,34 +839,43 @@ impl RpcModule { let callback = { self.methods.verify_and_insert( subscribe_method_name, - MethodCallback::new_subscription(Arc::new(move |id, params, method_sink, conn, claimed| { + MethodCallback::new_subscription(Arc::new(move |id, params, method_sink, conn| { let uniq_sub = SubscriptionKey { conn_id: conn.conn_id, sub_id: conn.id_provider.next_id() }; // response to the subscription call. let (tx, rx) = oneshot::channel(); - let sink = SubscriptionSink { + let sink = PendingSubscriptionSink { inner: method_sink, - close_notify: Some(conn.close_notify), method: notif_method_name, subscribers: subscribers.clone(), uniq_sub, - id: Some((id.clone().into_owned(), tx)), - unsubscribe: None, - _claimed: claimed, + id: id.clone().into_owned(), + subscribe: tx, + permit: conn.subscription_permit, }; - // The callback returns a `SubscriptionResult` for better ergonomics and is not propagated further. - if callback(params, sink, ctx.clone()).is_err() { - tracing::warn!("Subscribe call `{}` failed", subscribe_method_name); - } + // The subscription callback is a future from the subscription + // definition and not the as same when the subscription call has been completed. + // + // This runs until the subscription callback has completed. + let sub_fut = callback(params.into_owned(), sink, ctx.clone()); + + tokio::spawn(async move { + if let Err(SubscriptionCallbackError::Some(msg)) = sub_fut.await { + tracing::warn!("Subscribe call `{subscribe_method_name}` failed: {msg}"); + } + }); let id = id.clone().into_owned(); let result = async move { match rx.await { - Ok(result) => result, - Err(_) => MethodResponse::error(id, ErrorObject::from(ErrorCode::InternalError)), + Ok(r) => SubscriptionAnswered::Yes(r), + Err(_) => { + let response = MethodResponse::error(id, ErrorObject::from(ErrorCode::InternalError)); + SubscriptionAnswered::No(response) + } } }; @@ -798,7 +884,7 @@ impl RpcModule { )? }; - Ok(MethodResourcesBuilder { build: ResourceVec::new(), callback }) + Ok(callback) } /// Register an alias for an existing_method. Alias uniqueness is enforced. @@ -816,270 +902,196 @@ impl RpcModule { } } -/// Returns once the unsubscribe method has been called. -type UnsubscribeCall = Option>; +/// Represents a subscription until it is unsubscribed. +/// +// NOTE: The reason why we use `mpsc` here is because it allows `IsUnsubscribed::unsubscribed` +// to be &self instead of &mut self. +#[derive(Debug, Clone)] +struct IsUnsubscribed(mpsc::Sender<()>); -/// Represents a single subscription. +impl IsUnsubscribed { + /// Returns true if the unsubscribe method has been invoked or the subscription has been canceled. + /// + /// This can be called multiple times as the element in the channel is never + /// removed. + fn is_unsubscribed(&self) -> bool { + self.0.is_closed() + } + + /// Wrapper over [`tokio::sync::mpsc::Sender::closed`] + /// + /// # Cancel safety + /// + /// This method is cancel safe. Once the channel is closed, + /// it stays closed forever and all future calls to closed will return immediately. + async fn unsubscribed(&self) { + self.0.closed().await; + } +} + +/// Represents a single subscription that is waiting to be accepted or rejected. +/// +/// If this is dropped without calling `PendingSubscription::reject` or `PendingSubscriptionSink::accept` +/// a default error is sent out as response to the subscription call. +/// +/// Thus, if you want a customized error message then `PendingSubscription::reject` must be called. #[derive(Debug)] -pub struct SubscriptionSink { +#[must_use = "PendningSubscriptionSink does nothing unless `accept` or `reject` is called"] +pub struct PendingSubscriptionSink { /// Sink. inner: MethodSink, - /// Get notified when subscribers leave so we can exit - close_notify: Option, /// MethodCallback. method: &'static str, /// Shared Mutex of subscriptions for this method. subscribers: Subscribers, /// Unique subscription. uniq_sub: SubscriptionKey, - /// Id of the `subscription call` (i.e. not the same as subscription id) which is used + /// ID of the `subscription call` (i.e. not the same as subscription id) which is used /// to reply to subscription method call and must only be used once. - /// - /// *Note*: Having some value means the subscription was not accepted or rejected yet. - id: Option<(Id<'static>, oneshot::Sender)>, - /// Having some value means the subscription was accepted. - unsubscribe: UnsubscribeCall, - /// Claimed resources. - _claimed: Option, + id: Id<'static>, + /// Sender to answer the subscribe call. + subscribe: oneshot::Sender, + /// Subscription permit. + permit: SubscriptionPermit, } -impl SubscriptionSink { - /// Reject the subscription call from [`ErrorObject`]. - pub fn reject(&mut self, err: impl Into) -> Result<(), SubscriptionAcceptRejectError> { - let (id, subscribe_call) = self.id.take().ok_or(SubscriptionAcceptRejectError::AlreadyCalled)?; +impl PendingSubscriptionSink { + /// Reject the subscription call with the error from [`ErrorObject`]. + pub async fn reject(self, err: impl Into) -> Result<(), SubscriptionAcceptRejectError> { + let err = MethodResponse::error(self.id, err.into()); + self.inner.send(err.result.clone()).await.map_err(|_| SubscriptionAcceptRejectError::RemotePeerAborted)?; + self.subscribe.send(err).map_err(|_| SubscriptionAcceptRejectError::RemotePeerAborted)?; - let err = MethodResponse::error(id, err.into()); - - if self.answer_subscription(err, subscribe_call) { - Ok(()) - } else { - Err(SubscriptionAcceptRejectError::RemotePeerAborted) - } + Ok(()) } /// Attempt to accept the subscription and respond the subscription method call. /// - /// Fails if the connection was closed, or if called multiple times. - pub fn accept(&mut self) -> Result<(), SubscriptionAcceptRejectError> { - let (id, subscribe_call) = self.id.take().ok_or(SubscriptionAcceptRejectError::AlreadyCalled)?; - - let response = MethodResponse::response(id, &self.uniq_sub.sub_id, self.inner.max_response_size() as usize); + /// Fails if the connection was closed or the message was too large. + pub async fn accept(self) -> Result { + let response = + MethodResponse::response(self.id, &self.uniq_sub.sub_id, self.inner.max_response_size() as usize); let success = response.success; - - let sent = self.answer_subscription(response, subscribe_call); - - if sent && success { - let (tx, rx) = watch::channel(()); - self.subscribers.lock().insert(self.uniq_sub.clone(), (self.inner.clone(), tx)); - self.unsubscribe = Some(rx); - Ok(()) + self.inner.send(response.result.clone()).await.map_err(|_| SubscriptionAcceptRejectError::RemotePeerAborted)?; + self.subscribe.send(response).map_err(|_| SubscriptionAcceptRejectError::RemotePeerAborted)?; + + if success { + let (tx, rx) = mpsc::channel(1); + self.subscribers.lock().insert(self.uniq_sub.clone(), (self.inner.clone(), rx)); + Ok(SubscriptionSink { + inner: self.inner, + method: self.method, + subscribers: self.subscribers, + uniq_sub: self.uniq_sub, + unsubscribe: IsUnsubscribed(tx), + _permit: Arc::new(self.permit), + }) } else { - Err(SubscriptionAcceptRejectError::RemotePeerAborted) + Err(SubscriptionAcceptRejectError::MessageTooLarge) } } +} - /// Return the subscription ID if the the subscription was accepted. - /// - /// [`SubscriptionSink::accept`] should be called prior to this method. - pub fn subscription_id(&self) -> Option> { - if self.id.is_some() { - // Subscription was not accepted. - None - } else { - Some(self.uniq_sub.sub_id.clone()) - } - } +/// Represents a single subscription that hasn't been processed yet. +#[derive(Debug, Clone)] +pub struct SubscriptionSink { + /// Sink. + inner: MethodSink, + /// MethodCallback. + method: &'static str, + /// Shared Mutex of subscriptions for this method. + subscribers: Subscribers, + /// Unique subscription. + uniq_sub: SubscriptionKey, + /// A future to that fires once the unsubscribe method has been called. + unsubscribe: IsUnsubscribed, + /// Subscription permit + _permit: Arc, +} - /// Send a message back to subscribers. - /// - /// Returns - /// - `Ok(true)` if the message could be send. - /// - `Ok(false)` if the sink was closed (either because the subscription was closed or the connection was terminated), - /// or the subscription could not be accepted. - /// - `Err(err)` if the message could not be serialized. - pub fn send(&mut self, result: &T) -> Result { - // Cannot accept the subscription. - if let Err(SubscriptionAcceptRejectError::RemotePeerAborted) = self.accept() { - return Ok(false); - } +impl SubscriptionSink { + /// Get the subscription ID. + pub fn subscription_id(&self) -> RpcSubscriptionId<'static> { + self.uniq_sub.sub_id.clone() + } - self.send_without_accept(result) + /// Get the method name. + pub fn method_name(&self) -> &str { + self.method } - /// Reads data from the `stream` and sends back data on the subscription - /// when items gets produced by the stream. - /// The underlying stream must produce `Result values, see [`futures_util::TryStream`] for further information. - /// - /// Returns `Ok(())` if the stream or connection was terminated. - /// Returns `Err(_)` immediately if the underlying stream returns an error or if an item from the stream could not be serialized. - /// - /// # Examples + /// Send out a response on the subscription and wait until there is capacity. /// - /// ```no_run - /// - /// use jsonrpsee_core::server::rpc_module::RpcModule; - /// use jsonrpsee_core::error::{Error, SubscriptionClosed}; - /// use jsonrpsee_types::ErrorObjectOwned; - /// use anyhow::anyhow; /// - /// let mut m = RpcModule::new(()); - /// m.register_subscription("sub", "_", "unsub", |params, mut sink, _| { - /// let stream = futures_util::stream::iter(vec![Ok(1_u32), Ok(2), Err("error on the stream")]); - /// // This will return send `[Ok(1_u32), Ok(2_u32), Err(Error::SubscriptionClosed))]` to the subscriber - /// // because after the `Err(_)` the stream is terminated. - /// let stream = futures_util::stream::iter(vec![Ok(1_u32), Ok(2), Err("error on the stream")]); + /// Returns + /// - `Ok(())` if the message could be sent. + /// - `Err(err)` if the connection or subscription was closed. /// - /// tokio::spawn(async move { + /// # Cancel safety /// - /// // jsonrpsee doesn't send an error notification unless `close` is explicitly called. - /// // If we pipe messages to the sink, we can inspect why it ended: - /// match sink.pipe_from_try_stream(stream).await { - /// SubscriptionClosed::Success => { - /// let err_obj: ErrorObjectOwned = SubscriptionClosed::Success.into(); - /// sink.close(err_obj); - /// } - /// // we don't want to send close reason when the client is unsubscribed or disconnected. - /// SubscriptionClosed::RemotePeerAborted => (), - /// SubscriptionClosed::Failed(e) => { - /// sink.close(e); - /// } - /// } - /// }); - /// Ok(()) - /// }); - /// ``` - pub async fn pipe_from_try_stream(&mut self, mut stream: S) -> SubscriptionClosed - where - S: TryStream + Unpin, - T: Serialize, - E: std::fmt::Display, - { - if let Err(SubscriptionAcceptRejectError::RemotePeerAborted) = self.accept() { - return SubscriptionClosed::RemotePeerAborted; + /// This method is cancel-safe and dropping a future loses its spot in the waiting queue. + pub async fn send(&self, msg: SubscriptionMessage) -> Result<(), DisconnectError> { + // Only possible to trigger when the connection is dropped. + if self.is_closed() { + return Err(DisconnectError(msg)); } - let conn_closed = match self.close_notify.as_ref().map(|cn| cn.handle()) { - Some(cn) => cn, - None => return SubscriptionClosed::RemotePeerAborted, - }; - - let mut sub_closed = match self.unsubscribe.as_ref() { - Some(rx) => rx.clone(), - _ => { - return SubscriptionClosed::Failed(ErrorObject::owned( - INTERNAL_ERROR_CODE, - "Unsubscribe watcher not set after accepting the subscription".to_string(), - None::<()>, - )) - } - }; - - let sub_closed_fut = sub_closed.changed(); - - let conn_closed_fut = conn_closed.notified(); - pin_mut!(conn_closed_fut); - pin_mut!(sub_closed_fut); - - let mut stream_item = stream.try_next(); - let mut closed_fut = futures_util::future::select(conn_closed_fut, sub_closed_fut); - - loop { - match futures_util::future::select(stream_item, closed_fut).await { - // The app sent us a value to send back to the subscribers - Either::Left((Ok(Some(result)), next_closed_fut)) => { - match self.send_without_accept(&result) { - Ok(true) => (), - Ok(false) => { - break SubscriptionClosed::RemotePeerAborted; - } - Err(err) => { - let err = ErrorObject::owned(SUBSCRIPTION_CLOSED_WITH_ERROR, err.to_string(), None::<()>); - break SubscriptionClosed::Failed(err); - } - }; - stream_item = stream.try_next(); - closed_fut = next_closed_fut; - } - // Stream canceled because of error. - Either::Left((Err(err), _)) => { - let err = ErrorObject::owned(SUBSCRIPTION_CLOSED_WITH_ERROR, err.to_string(), None::<()>); - break SubscriptionClosed::Failed(err); - } - Either::Left((Ok(None), _)) => break SubscriptionClosed::Success, - Either::Right((_, _)) => { - break SubscriptionClosed::RemotePeerAborted; - } - } - } + let json = self.sub_message_to_json(msg); + self.inner.send(json).await.map_err(Into::into) } - /// Similar to [`SubscriptionSink::pipe_from_try_stream`] but it doesn't require the stream return `Result`. - /// - /// Warning: it's possible to pass in a stream that returns `Result` if `Result: Serialize` is satisfied - /// but it won't cancel the stream when an error occurs. If you want the stream to be canceled when an - /// error occurs use [`SubscriptionSink::pipe_from_try_stream`] instead. - /// - /// # Examples - /// - /// ```no_run - /// - /// use jsonrpsee_core::server::rpc_module::RpcModule; - /// - /// let mut m = RpcModule::new(()); - /// m.register_subscription("sub", "_", "unsub", |params, mut sink, _| { - /// let stream = futures_util::stream::iter(vec![1_usize, 2, 3]); - /// tokio::spawn(async move { sink.pipe_from_stream(stream).await; }); - /// Ok(()) - /// }); - /// ``` - pub async fn pipe_from_stream(&mut self, stream: S) -> SubscriptionClosed - where - S: Stream + Unpin, - T: Serialize, - { - self.pipe_from_try_stream::<_, _, Error>(stream.map(|item| Ok(item))).await - } + /// Similar to to `SubscriptionSink::send` but only waits for a limited time. + pub async fn send_timeout(&self, msg: SubscriptionMessage, timeout: Duration) -> Result<(), SendTimeoutError> { + // Only possible to trigger when the connection is dropped. + if self.is_closed() { + return Err(SendTimeoutError::Closed(msg)); + } - /// Returns whether the subscription is closed. - pub fn is_closed(&self) -> bool { - self.inner.is_closed() || self.close_notify.is_none() || !self.is_active_subscription() + let json = self.sub_message_to_json(msg); + self.inner.send_timeout(json, timeout).await.map_err(Into::into) } - /// Send a message back to subscribers. + /// Attempts to immediately send out the message as JSON string to the subscribers but fails if the + /// channel is full or the connection/subscription is closed /// - /// This is similar to the [`SubscriptionSink::send`], but it does not try to accept - /// the subscription prior to sending. - #[inline] - fn send_without_accept(&mut self, result: &T) -> Result { + /// + /// This differs from [`SubscriptionSink::send`] where it will until there is capacity + /// in the channel. + pub fn try_send(&mut self, msg: SubscriptionMessage) -> Result<(), TrySendError> { // Only possible to trigger when the connection is dropped. if self.is_closed() { - return Ok(false); + return Err(TrySendError::Closed(msg)); } - let msg = self.build_message(result)?; - Ok(self.inner.send_raw(msg).is_ok()) + let json = self.sub_message_to_json(msg); + self.inner.try_send(json).map_err(Into::into) } - fn is_active_subscription(&self) -> bool { - match self.unsubscribe.as_ref() { - Some(unsubscribe) => unsubscribe.has_changed().is_ok(), - _ => false, - } + /// Returns whether the subscription is closed. + pub fn is_closed(&self) -> bool { + self.inner.is_closed() || !self.is_active_subscription() } - fn answer_subscription(&self, response: MethodResponse, subscribe_call: oneshot::Sender) -> bool { - let ws_send = self.inner.send_raw(response.result.clone()).is_ok(); - let logger_call = subscribe_call.send(response).is_ok(); - - ws_send && logger_call + /// Completes when the subscription has been closed. + pub async fn closed(&self) { + // Both are cancel-safe thus ok to use select here. + tokio::select! { + _ = self.inner.closed() => (), + _ = self.unsubscribe.unsubscribed() => (), + } } - fn build_message(&self, result: &T) -> Result { - serde_json::to_string(&SubscriptionResponse::new( - self.method.into(), - SubscriptionPayload { subscription: self.uniq_sub.sub_id.clone(), result }, - )) - .map_err(Into::into) + fn sub_message_to_json(&self, msg: SubscriptionMessage) -> String { + match msg.0 { + SubscriptionMessageInner::Complete(msg) => msg, + SubscriptionMessageInner::NeedsData(result) => { + let sub_id = serde_json::to_string(&self.uniq_sub.sub_id).expect("valid JSON; qed"); + let method = self.method; + format!( + r#"{{"jsonrpc":"2.0","method":"{method}","params":{{"subscription":{sub_id},"result":{result}}}}}"#, + ) + } + } } fn build_error_message(&self, error: &T) -> Result { @@ -1110,29 +1122,29 @@ impl SubscriptionSink { /// } /// ``` /// - pub fn close(self, err: impl Into) -> bool { + pub fn close(self, err: impl Into) -> impl Future { if self.is_active_subscription() { if let Some((sink, _)) = self.subscribers.lock().remove(&self.uniq_sub) { tracing::debug!("Closing subscription: {:?}", self.uniq_sub.sub_id); let msg = self.build_error_message(&err.into()).expect("valid json infallible; qed"); - return sink.send_raw(msg).is_ok(); + + return Either::Right(async move { + let _ = sink.send(msg).await; + }); } } - false + Either::Left(futures_util::future::ready(())) + } + + fn is_active_subscription(&self) -> bool { + !self.unsubscribe.is_unsubscribed() } } impl Drop for SubscriptionSink { fn drop(&mut self) { - if let Some((id, subscribe_call)) = self.id.take() { - // Subscription was never accepted / rejected. As such, - // we default to assuming that the params were invalid, - // because that's how the previous PendingSubscription logic - // worked. - let err = MethodResponse::error(id, ErrorObject::from(ErrorCode::InvalidParams)); - self.answer_subscription(err, subscribe_call); - } else if self.is_active_subscription() { + if self.is_active_subscription() { self.subscribers.lock().remove(&self.uniq_sub); } } @@ -1141,18 +1153,17 @@ impl Drop for SubscriptionSink { /// Wrapper struct that maintains a subscription "mainly" for testing. #[derive(Debug)] pub struct Subscription { - close_notify: Option, - rx: mpsc::UnboundedReceiver, + tx: MethodSink, + rx: mpsc::Receiver, sub_id: RpcSubscriptionId<'static>, + _permit: SubscriptionPermit, } impl Subscription { /// Close the subscription channel. pub fn close(&mut self) { tracing::trace!("[Subscription::close] Notifying"); - if let Some(n) = self.close_notify.take() { - n.handle().notify_one() - } + self.rx.close(); } /// Get the subscription ID @@ -1162,7 +1173,7 @@ impl Subscription { /// Check whether the subscription is closed. pub fn is_closed(&self) -> bool { - self.close_notify.is_none() + self.tx.is_closed() } /// Returns `Some((val, sub_id))` for the next element of type T from the underlying stream, @@ -1172,11 +1183,7 @@ impl Subscription { /// /// If the decoding the value as `T` fails. pub async fn next(&mut self) -> Option), Error>> { - if self.close_notify.is_none() { - tracing::debug!("[Subscription::next] Closed."); - return None; - } - let raw = self.rx.next().await?; + let raw = self.rx.recv().await?; tracing::debug!("[Subscription::next]: rx {}", raw); let res = match serde_json::from_str::>(&raw) { diff --git a/examples/Cargo.toml b/examples/Cargo.toml index d2da324b8a..a69d893d07 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -18,4 +18,4 @@ serde_json = { version = "1" } tower-http = { version = "0.3.4", features = ["full"] } tower = { version = "0.4.13", features = ["full"] } hyper = "0.14.20" -console-subscriber = "0.1.8" +console-subscriber = "0.1.8" \ No newline at end of file diff --git a/examples/examples/proc_macro.rs b/examples/examples/proc_macro.rs index eebedc3cbb..165f612139 100644 --- a/examples/examples/proc_macro.rs +++ b/examples/examples/proc_macro.rs @@ -26,10 +26,10 @@ use std::net::SocketAddr; -use jsonrpsee::core::{async_trait, client::Subscription, Error}; +use jsonrpsee::core::server::rpc_module::SubscriptionMessage; +use jsonrpsee::core::{async_trait, client::Subscription, Error, SubscriptionResult}; use jsonrpsee::proc_macros::rpc; -use jsonrpsee::server::{ServerBuilder, SubscriptionSink}; -use jsonrpsee::types::SubscriptionResult; +use jsonrpsee::server::{PendingSubscriptionSink, ServerBuilder}; use jsonrpsee::ws_client::WsClientBuilder; type ExampleHash = [u8; 32]; @@ -46,7 +46,7 @@ where /// Subscription that takes a `StorageKey` as input and produces a `Vec`. #[subscription(name = "subscribeStorage" => "override", item = Vec)] - fn subscribe_storage(&self, keys: Option>); + async fn subscribe_storage(&self, keys: Option>); } pub struct RpcServerImpl; @@ -62,12 +62,14 @@ impl RpcServer for RpcServerImpl { } // Note that the server's subscription method must return `SubscriptionResult`. - fn subscribe_storage( + async fn subscribe_storage( &self, - mut sink: SubscriptionSink, + pending: PendingSubscriptionSink, _keys: Option>, ) -> SubscriptionResult { - let _ = sink.send(&vec![[0; 32]]); + let sink = pending.accept().await?; + let msg = SubscriptionMessage::from_json(&vec![[0; 32]])?; + sink.send(msg).await?; Ok(()) } } diff --git a/examples/examples/ws_pubsub_broadcast.rs b/examples/examples/ws_pubsub_broadcast.rs index 9aaca568c3..3ff2de6162 100644 --- a/examples/examples/ws_pubsub_broadcast.rs +++ b/examples/examples/ws_pubsub_broadcast.rs @@ -28,13 +28,16 @@ use std::net::SocketAddr; -use futures::future; +use futures::future::{self, Either}; use futures::StreamExt; use jsonrpsee::core::client::{Subscription, SubscriptionClientT}; -use jsonrpsee::core::error::SubscriptionClosed; +use jsonrpsee::core::server::rpc_module::SubscriptionMessage; + +use jsonrpsee::core::SubscriptionResult; use jsonrpsee::rpc_params; use jsonrpsee::server::{RpcModule, ServerBuilder}; use jsonrpsee::ws_client::WsClientBuilder; +use jsonrpsee::PendingSubscriptionSink; use tokio::sync::broadcast; use tokio_stream::wrappers::BroadcastStream; @@ -64,29 +67,23 @@ async fn main() -> anyhow::Result<()> { } async fn run_server() -> anyhow::Result { - let server = ServerBuilder::default().build("127.0.0.1:0").await?; - let mut module = RpcModule::new(()); - let (tx, _rx) = broadcast::channel(16); - let tx2 = tx.clone(); + // let's configure the server only hold 5 messages in memory. + let server = ServerBuilder::default().set_message_buffer_capacity(5).build("127.0.0.1:0").await?; + let (tx, _rx) = broadcast::channel::(16); - std::thread::spawn(move || produce_items(tx2)); + let mut module = RpcModule::new(tx.clone()); - module.register_subscription("subscribe_hello", "s_hello", "unsubscribe_hello", move |_, mut sink, _| { - let rx = BroadcastStream::new(tx.clone().subscribe()); + std::thread::spawn(move || produce_items(tx)); - tokio::spawn(async move { - match sink.pipe_from_try_stream(rx).await { - SubscriptionClosed::Success => { - sink.close(SubscriptionClosed::Success); - } - SubscriptionClosed::RemotePeerAborted => (), - SubscriptionClosed::Failed(err) => { - sink.close(err); - } - }; - }); - Ok(()) - })?; + module + .register_subscription("subscribe_hello", "s_hello", "unsubscribe_hello", |_, pending, tx| async move { + let rx = tx.subscribe(); + let stream = BroadcastStream::new(rx); + pipe_from_stream_with_bounded_buffer(pending, stream).await?; + + Ok(()) + }) + .unwrap(); let addr = server.local_addr()?; let handle = server.start(module)?; @@ -97,10 +94,48 @@ async fn run_server() -> anyhow::Result { Ok(addr) } +async fn pipe_from_stream_with_bounded_buffer( + pending: PendingSubscriptionSink, + stream: BroadcastStream, +) -> SubscriptionResult { + let sink = pending.accept().await?; + let closed = sink.closed(); + + futures::pin_mut!(closed, stream); + + loop { + match future::select(closed, stream.next()).await { + // subscription closed. + Either::Left((_, _)) => break, + + // received new item from the stream. + Either::Right((Some(Ok(item)), c)) => { + let notif = SubscriptionMessage::from_json(&item)?; + + // NOTE: this will block until there a spot in the queue + // and you might want to do something smarter if it's + // critical that "the most recent item" must be sent when it is produced. + if sink.send(notif).await.is_err() { + break; + } + + closed = c; + } + + // stream is closed or some error, just quit. + Either::Right((_, _)) => { + break; + } + } + } + + Ok(()) +} + // Naive example that broadcasts the produced values to all active subscribers. fn produce_items(tx: broadcast::Sender) { for c in 1..=100 { - std::thread::sleep(std::time::Duration::from_secs(1)); + std::thread::sleep(std::time::Duration::from_millis(1)); // This might fail if no receivers are alive, could occur if no subscriptions are active... // Also be aware that this will succeed when at least one receiver is alive diff --git a/examples/examples/ws_pubsub_with_params.rs b/examples/examples/ws_pubsub_with_params.rs index 1e56352ba4..0b4d910489 100644 --- a/examples/examples/ws_pubsub_with_params.rs +++ b/examples/examples/ws_pubsub_with_params.rs @@ -27,12 +27,14 @@ use std::net::SocketAddr; use std::time::Duration; -use futures::StreamExt; +use futures::{Stream, StreamExt}; use jsonrpsee::core::client::{Subscription, SubscriptionClientT}; -use jsonrpsee::core::error::SubscriptionClosed; -use jsonrpsee::rpc_params; +use jsonrpsee::core::server::rpc_module::{SubscriptionMessage, TrySendError}; +use jsonrpsee::core::{Serialize, SubscriptionResult}; use jsonrpsee::server::{RpcModule, ServerBuilder}; +use jsonrpsee::types::{ErrorObject, ErrorObjectOwned}; use jsonrpsee::ws_client::WsClientBuilder; +use jsonrpsee::{rpc_params, PendingSubscriptionSink}; use tokio::time::interval; use tokio_stream::wrappers::IntervalStream; @@ -63,38 +65,37 @@ async fn main() -> anyhow::Result<()> { async fn run_server() -> anyhow::Result { const LETTERS: &str = "abcdefghijklmnopqrstuvxyz"; - let server = ServerBuilder::default().build("127.0.0.1:0").await?; + let server = ServerBuilder::default().set_message_buffer_capacity(10).build("127.0.0.1:0").await?; let mut module = RpcModule::new(()); module - .register_subscription("sub_one_param", "sub_one_param", "unsub_one_param", |params, mut sink, _| { - let idx = params.one()?; + .register_subscription("sub_one_param", "sub_one_param", "unsub_one_param", |params, pending, _| async move { + // we are doing this verbose way to get a customized reject error on the subscription. + let idx = match params.one::() { + Ok(p) => p, + Err(e) => { + let _ = pending.reject(ErrorObjectOwned::from(e)).await; + return Ok(()); + } + }; + let item = LETTERS.chars().nth(idx); let interval = interval(Duration::from_millis(200)); let stream = IntervalStream::new(interval).map(move |_| item); - tokio::spawn(async move { - if let SubscriptionClosed::Failed(err) = sink.pipe_from_stream(stream).await { - sink.close(err); - } - }); + pipe_from_stream_and_drop(pending, stream).await?; + Ok(()) }) .unwrap(); module - .register_subscription("sub_params_two", "params_two", "unsub_params_two", |params, mut sink, _| { + .register_subscription("sub_params_two", "params_two", "unsub_params_two", |params, pending, _| async move { let (one, two) = params.parse::<(usize, usize)>()?; let item = &LETTERS[one..two]; - let interval = interval(Duration::from_millis(200)); let stream = IntervalStream::new(interval).map(move |_| item); - - tokio::spawn(async move { - if let SubscriptionClosed::Failed(err) = sink.pipe_from_stream(stream).await { - sink.close(err); - } - }); + pipe_from_stream_and_drop(pending, stream).await?; Ok(()) }) @@ -109,3 +110,41 @@ async fn run_server() -> anyhow::Result { Ok(addr) } + +async fn pipe_from_stream_and_drop(pending: PendingSubscriptionSink, mut stream: S) -> SubscriptionResult +where + S: Stream + Unpin, + T: Serialize, +{ + let mut sink = pending.accept().await?; + + loop { + tokio::select! { + _ = sink.closed() => break, + maybe_item = stream.next() => { + let item = match maybe_item { + Some(item) => item, + None => break, + }; + let msg = match SubscriptionMessage::from_json(&item) { + Ok(msg) => msg, + Err(e) => { + sink.close(ErrorObject::owned(1, e.to_string(), None::<()>)).await; + return Err(e.into()); + } + }; + + match sink.try_send(msg) { + Ok(_) => (), + Err(TrySendError::Closed(_)) => break, + // channel is full, let's be naive an just drop the message. + Err(TrySendError::Full(_)) => (), + } + } + } + } + + sink.close(ErrorObject::owned(1, "Ok", None::<()>)).await; + + Ok(()) +} diff --git a/jsonrpsee/Cargo.toml b/jsonrpsee/Cargo.toml index 8700ac8cc6..1dd7f166c5 100644 --- a/jsonrpsee/Cargo.toml +++ b/jsonrpsee/Cargo.toml @@ -21,6 +21,7 @@ jsonrpsee-proc-macros = { path = "../proc-macros", version = "0.16.2", optional jsonrpsee-core = { path = "../core", version = "0.16.2", optional = true } jsonrpsee-types = { path = "../types", version = "0.16.2", optional = true } tracing = { version = "0.1.34", optional = true } +tokio = { version = "1.16", features = ["net", "rt-multi-thread", "macros", "time"], optional = true } [features] client-ws-transport-native-tls = ["jsonrpsee-client-transport/ws", "jsonrpsee-client-transport/native-tls"] @@ -36,7 +37,7 @@ macros = ["jsonrpsee-proc-macros", "jsonrpsee-types", "tracing"] client = ["http-client", "ws-client", "wasm-client", "client-ws-transport-native-tls", "client-ws-transport-webpki-tls", "client-web-transport", "async-client", "async-wasm-client", "client-core"] client-core = ["jsonrpsee-core/client"] -server = ["jsonrpsee-server", "server-core", "jsonrpsee-types"] +server = ["jsonrpsee-server", "server-core", "jsonrpsee-types", "tokio"] server-core = ["jsonrpsee-core/server"] full = ["client", "server", "macros"] diff --git a/jsonrpsee/src/lib.rs b/jsonrpsee/src/lib.rs index 2a0efbc21d..5f44308047 100644 --- a/jsonrpsee/src/lib.rs +++ b/jsonrpsee/src/lib.rs @@ -90,7 +90,8 @@ cfg_types! { } cfg_server! { - pub use jsonrpsee_core::server::rpc_module::{RpcModule, SubscriptionSink}; + pub use jsonrpsee_core::server::rpc_module::{RpcModule, SubscriptionSink, PendingSubscriptionSink, SubscriptionMessage, DisconnectError, TrySendError, SendTimeoutError}; + pub use tokio; } cfg_client_or_server! { diff --git a/proc-macros/src/attributes.rs b/proc-macros/src/attributes.rs index 139f25e7d4..30c0797bd2 100644 --- a/proc-macros/src/attributes.rs +++ b/proc-macros/src/attributes.rs @@ -29,7 +29,7 @@ use std::{fmt, iter}; use proc_macro2::{Span, TokenStream as TokenStream2, TokenTree}; use syn::parse::{Parse, ParseStream, Parser}; use syn::punctuated::Punctuated; -use syn::{spanned::Spanned, Attribute, Error, LitInt, LitStr, Token}; +use syn::{spanned::Spanned, Attribute, Error, LitStr, Token}; pub(crate) struct AttributeMeta { pub path: syn::Path, @@ -47,13 +47,6 @@ pub enum ParamKind { Map, } -#[derive(Debug, Clone)] -pub struct Resource { - pub name: LitStr, - pub assign: Token![=], - pub value: LitInt, -} - pub struct NameMapping { pub name: String, pub mapped: Option, @@ -93,12 +86,6 @@ impl Parse for Argument { } } -impl Parse for Resource { - fn parse(input: ParseStream) -> syn::Result { - Ok(Resource { name: input.parse()?, assign: input.parse()?, value: input.parse()? }) - } -} - impl Parse for NameMapping { fn parse(input: ParseStream) -> syn::Result { let name = input.parse::()?.value(); diff --git a/proc-macros/src/helpers.rs b/proc-macros/src/helpers.rs index 5e5492c1bb..b1b3764916 100644 --- a/proc-macros/src/helpers.rs +++ b/proc-macros/src/helpers.rs @@ -71,7 +71,8 @@ fn find_jsonrpsee_crate(http_name: &str, ws_name: &str) -> Result { @@ -79,7 +80,7 @@ fn find_jsonrpsee_crate(http_name: &str, ws_name: &str) -> Result RpcResult; /// /// #[subscription(name = "subscribe", item = Vec)] -/// fn sub(&self); +/// async fn sub(&self); /// } /// ``` /// diff --git a/proc-macros/src/lib.rs b/proc-macros/src/lib.rs index fa3b3fab7b..e9bdcb4f6b 100644 --- a/proc-macros/src/lib.rs +++ b/proc-macros/src/lib.rs @@ -86,7 +86,7 @@ pub(crate) mod visitor; /// fn sync_method(&self) -> String; /// /// #[subscription(name = "subscribe", item = "String")] -/// fn sub(&self); +/// async fn sub(&self); /// } /// ``` /// @@ -99,8 +99,8 @@ pub(crate) mod visitor; /// async fn async_method(&self, param_a: u8, param_b: String) -> u16; /// fn sync_method(&self) -> String; /// -/// // Note that `subscription_sink` and `SubscriptionResult` were added automatically. -/// fn sub(&self, subscription_sink: SubscriptionResult) -> SubscriptionResult; +/// // Note that `pending_subscription_sink` and `SubscriptionResult` were added automatically. +/// async fn sub(&self, pending: PendingSubscriptionSink) -> SubscriptionResult; /// /// fn into_rpc(self) -> Result { /// // Actual implementation stripped, but inside we will create @@ -219,8 +219,8 @@ pub(crate) mod visitor; /// /// // RPC is put into a separate module to clearly show names of generated entities. /// mod rpc_impl { -/// use jsonrpsee::{proc_macros::rpc, core::async_trait, core::RpcResult, server::SubscriptionSink}; -/// use jsonrpsee::types::SubscriptionResult; +/// use jsonrpsee::{proc_macros::rpc, server::PendingSubscriptionSink, server::SubscriptionMessage}; +/// use jsonrpsee::core::{async_trait, SubscriptionResult, RpcResult}; /// /// // Generate both server and client implementations, prepend all the methods with `foo_` prefix. /// #[rpc(client, server, namespace = "foo")] @@ -248,7 +248,7 @@ pub(crate) mod visitor; /// /// } /// /// ``` /// #[subscription(name = "sub" => "subNotif", unsubscribe = "unsub", item = String)] -/// fn sub_override_notif_method(&self); +/// async fn sub_override_notif_method(&self); /// /// /// Use the same method name for both the `subscribe call` and `notifications` /// /// @@ -265,7 +265,7 @@ pub(crate) mod visitor; /// /// } /// /// ``` /// #[subscription(name = "subscribe", item = String)] -/// fn sub(&self); +/// async fn sub(&self); /// } /// /// // Structure that will implement the `MyRpcServer` trait. @@ -292,20 +292,25 @@ pub(crate) mod visitor; /// /// // The stream API can be used to pipe items from the underlying stream /// // as subscription responses. -/// fn sub_override_notif_method(&self, mut sink: SubscriptionSink) -> SubscriptionResult { -/// tokio::spawn(async move { -/// let stream = futures_util::stream::iter(["one", "two", "three"]); -/// sink.pipe_from_stream(stream).await; -/// }); +/// async fn sub_override_notif_method(&self, pending: PendingSubscriptionSink) -> SubscriptionResult { +/// let mut sink = pending.accept().await?; +/// +/// let msg = SubscriptionMessage::from_json(&"Response_A")?; +/// sink.send(msg).await?; +/// /// Ok(()) /// } /// -/// // We could've spawned a `tokio` future that yields values while our program works, -/// // but for simplicity of the example we will only send two values and then close -/// // the subscription. -/// fn sub(&self, mut sink: SubscriptionSink) -> SubscriptionResult { -/// let _ = sink.send(&"Response_A"); -/// let _ = sink.send(&"Response_B"); +/// // Send out two values on the subscription. +/// async fn sub(&self, pending: PendingSubscriptionSink) -> SubscriptionResult { +/// let sink = pending.accept().await?; +/// +/// let msg1 = SubscriptionMessage::from_json(&"Response_A")?; +/// let msg2 = SubscriptionMessage::from_json(&"Response_B")?; +/// +/// sink.send(msg1).await?; +/// sink.send(msg2).await?; +/// /// Ok(()) /// } /// } diff --git a/proc-macros/src/render_server.rs b/proc-macros/src/render_server.rs index 054e1fade4..49f13326c9 100644 --- a/proc-macros/src/render_server.rs +++ b/proc-macros/src/render_server.rs @@ -28,12 +28,11 @@ use std::collections::HashSet; use std::str::FromStr; use super::RpcDescription; -use crate::attributes::Resource; use crate::helpers::{generate_where_clause, is_option}; use proc_macro2::{Span, TokenStream as TokenStream2}; use quote::{quote, quote_spanned}; use syn::punctuated::Punctuated; -use syn::{parse_quote, token, AttrStyle, Attribute, Path, PathSegment, ReturnType, Token}; +use syn::{parse_quote, token, AttrStyle, Attribute, Path, PathSegment, ReturnType}; impl RpcDescription { pub(super) fn render_server(&self) -> Result { @@ -72,13 +71,13 @@ impl RpcDescription { let subscriptions = self.subscriptions.iter().map(|sub| { let docs = &sub.docs; - let subscription_sink_ty = self.jrps_server_item(quote! { SubscriptionSink }); + let subscription_sink_ty = self.jrps_server_item(quote! { PendingSubscriptionSink }); // Add `SubscriptionSink` as the second input parameter to the signature. let subscription_sink: syn::FnArg = syn::parse_quote!(subscription_sink: #subscription_sink_ty); let mut sub_sig = sub.signature.clone(); // For ergonomic reasons, the server's subscription method should return `SubscriptionResult`. - let return_ty = self.jrps_server_item(quote! { types::SubscriptionResult }); + let return_ty = self.jrps_server_item(quote! { core::SubscriptionResult }); let output: ReturnType = parse_quote! { -> #return_ty }; sub_sig.sig.output = output; @@ -121,28 +120,6 @@ impl RpcDescription { }} } - /// Helper that will parse the resources passed to the macro and call the appropriate resource - /// builder to register the resource limits. - fn handle_resource_limits(resources: &Punctuated) -> TokenStream2 { - // Nothing to be done if no resources were set. - if resources.is_empty() { - return quote! {}; - } - - // Transform each resource into a call to `.resource(name, value)`. - let resources = resources.iter().map(|resource| { - let Resource { name, value, .. } = resource; - quote! { .resource(#name, #value)? } - }); - - quote! { - .and_then(|resource_builder| { - resource_builder #(#resources)*; - Ok(()) - }) - } - } - let methods = self .methods .iter() @@ -159,15 +136,12 @@ impl RpcDescription { check_name(&rpc_method_name, rust_method_name.span()); - let resources = handle_resource_limits(&method.resources); - if method.signature.sig.asyncness.is_some() { handle_register_result(quote! { rpc.register_async_method(#rpc_method_name, |params, context| async move { #parsing context.as_ref().#rust_method_name(#params_seq).await }) - #resources }) } else { let register_kind = @@ -178,7 +152,6 @@ impl RpcDescription { #parsing context.#rust_method_name(#params_seq) }) - #resources }) } }) @@ -213,14 +186,11 @@ impl RpcDescription { None => rpc_sub_name.clone(), }; - let resources = handle_resource_limits(&sub.resources); - handle_register_result(quote! { - rpc.register_subscription(#rpc_sub_name, #rpc_notif_name, #rpc_unsub_name, |params, mut subscription_sink, context| { + rpc.register_subscription(#rpc_sub_name, #rpc_notif_name, #rpc_unsub_name, |params, mut subscription_sink, context| async move { #parsing - context.as_ref().#rust_method_name(subscription_sink, #params_seq) + context.as_ref().#rust_method_name(subscription_sink, #params_seq).await }) - #resources }) }) .collect::>(); @@ -320,7 +290,8 @@ impl RpcDescription { let params_fields = quote! { #(#params_fields_seq),* }; let tracing = self.jrps_server_item(quote! { tracing }); let err = self.jrps_server_item(quote! { core::Error }); - let sub_err = self.jrps_server_item(quote! { types::SubscriptionEmptyError }); + let sub_err = self.jrps_server_item(quote! { core::SubscriptionCallbackError::None }); + let tokio = self.jrps_server_item(quote! { tokio }); // Code to decode sequence of parameters from a JSON array. let decode_array = { @@ -330,9 +301,11 @@ impl RpcDescription { let #name: #ty = match seq.optional_next() { Ok(v) => v, Err(e) => { - #tracing::error!(concat!("Error parsing optional \"", stringify!(#name), "\" as \"", stringify!(#ty), "\": {:?}"), e); + #tracing::warn!(concat!("Error parsing optional \"", stringify!(#name), "\" as \"", stringify!(#ty), "\": {:?}"), e); let _e: #err = e.into(); - #pending.reject(_e)?; + #tokio::spawn(async move { + let _ = #pending.reject(_e).await; + }); return Err(#sub_err); } }; @@ -343,7 +316,7 @@ impl RpcDescription { let #name: #ty = match seq.optional_next() { Ok(v) => v, Err(e) => { - #tracing::error!(concat!("Error parsing optional \"", stringify!(#name), "\" as \"", stringify!(#ty), "\": {:?}"), e); + #tracing::warn!(concat!("Error parsing optional \"", stringify!(#name), "\" as \"", stringify!(#ty), "\": {:?}"), e); return Err(e.into()) } }; @@ -354,9 +327,11 @@ impl RpcDescription { let #name: #ty = match seq.next() { Ok(v) => v, Err(e) => { - #tracing::error!(concat!("Error parsing \"", stringify!(#name), "\" as \"", stringify!(#ty), "\": {:?}"), e); + #tracing::warn!(concat!("Error parsing \"", stringify!(#name), "\" as \"", stringify!(#ty), "\": {:?}"), e); let _e: #err = e.into(); - #pending.reject(_e)?; + #tokio::spawn(async move { + let _ = #pending.reject(_e).await; + }); return Err(#sub_err); } }; @@ -367,7 +342,7 @@ impl RpcDescription { let #name: #ty = match seq.next() { Ok(v) => v, Err(e) => { - #tracing::error!(concat!("Error parsing \"", stringify!(#name), "\" as \"", stringify!(#ty), "\": {:?}"), e); + #tracing::warn!(concat!("Error parsing \"", stringify!(#name), "\" as \"", stringify!(#ty), "\": {:?}"), e); return Err(e.into()) } }; @@ -433,9 +408,11 @@ impl RpcDescription { let parsed: ParamsObject<#(#types,)*> = match params.parse() { Ok(p) => p, Err(e) => { - #tracing::error!("Failed to parse JSON-RPC params as object: {}", e); + #tracing::warn!("Failed to parse JSON-RPC params as object: {}", e); let _e: #err = e.into(); - #pending.reject(_e)?; + #tokio::spawn(async move { + let _ = #pending.reject(_e).await; + }); return Err(#sub_err); } }; @@ -451,7 +428,7 @@ impl RpcDescription { } let parsed: ParamsObject<#(#types,)*> = params.parse().map_err(|e| { - #tracing::error!("Failed to parse JSON-RPC params as object: {}", e); + #tracing::warn!("Failed to parse JSON-RPC params as object: {}", e); e })?; (#(#destruct),*) diff --git a/proc-macros/src/rpc_macro.rs b/proc-macros/src/rpc_macro.rs index 841ddece15..7003727268 100644 --- a/proc-macros/src/rpc_macro.rs +++ b/proc-macros/src/rpc_macro.rs @@ -29,7 +29,7 @@ use std::borrow::Cow; use crate::attributes::{ - optional, parse_param_kind, Aliases, Argument, AttributeMeta, MissingArgument, NameMapping, ParamKind, Resource, + optional, parse_param_kind, Aliases, Argument, AttributeMeta, MissingArgument, NameMapping, ParamKind, }; use crate::helpers::extract_doc_comments; use proc_macro2::TokenStream as TokenStream2; @@ -48,19 +48,17 @@ pub struct RpcMethod { pub returns: Option, pub signature: syn::TraitItemMethod, pub aliases: Vec, - pub resources: Punctuated, } impl RpcMethod { pub fn from_item(attr: Attribute, mut method: syn::TraitItemMethod) -> syn::Result { - let [aliases, blocking, name, param_kind, resources] = - AttributeMeta::parse(attr)?.retain(["aliases", "blocking", "name", "param_kind", "resources"])?; + let [aliases, blocking, name, param_kind] = + AttributeMeta::parse(attr)?.retain(["aliases", "blocking", "name", "param_kind"])?; let aliases = parse_aliases(aliases)?; let blocking = optional(blocking, Argument::flag)?.is_some(); let name = name?.string()?; let param_kind = parse_param_kind(param_kind)?; - let resources = optional(resources, Argument::group)?.unwrap_or_default(); let sig = method.sig.clone(); let docs = extract_doc_comments(&method.attrs); @@ -100,18 +98,7 @@ impl RpcMethod { // We've analyzed attributes and don't need them anymore. method.attrs.clear(); - Ok(Self { - aliases, - blocking, - name, - params, - param_kind, - returns, - signature: method, - docs, - resources, - deprecated, - }) + Ok(Self { aliases, blocking, name, params, param_kind, returns, signature: method, docs, deprecated }) } } @@ -133,21 +120,12 @@ pub struct RpcSubscription { pub signature: syn::TraitItemMethod, pub aliases: Vec, pub unsubscribe_aliases: Vec, - pub resources: Punctuated, } impl RpcSubscription { pub fn from_item(attr: syn::Attribute, mut sub: syn::TraitItemMethod) -> syn::Result { - let [aliases, item, name, param_kind, unsubscribe, unsubscribe_aliases, resources] = - AttributeMeta::parse(attr)?.retain([ - "aliases", - "item", - "name", - "param_kind", - "unsubscribe", - "unsubscribe_aliases", - "resources", - ])?; + let [aliases, item, name, param_kind, unsubscribe, unsubscribe_aliases] = AttributeMeta::parse(attr)? + .retain(["aliases", "item", "name", "param_kind", "unsubscribe", "unsubscribe_aliases"])?; let aliases = parse_aliases(aliases)?; let map = name?.value::()?; @@ -156,7 +134,6 @@ impl RpcSubscription { let item = item?.value()?; let param_kind = parse_param_kind(param_kind)?; let unsubscribe_aliases = parse_aliases(unsubscribe_aliases)?; - let resources = optional(resources, Argument::group)?.unwrap_or_default(); let sig = sub.sig.clone(); let docs = extract_doc_comments(&sub.attrs); @@ -193,7 +170,6 @@ impl RpcSubscription { signature: sub, aliases, docs, - resources, }) } } @@ -296,14 +272,11 @@ impl RpcDescription { } if !matches!(method.sig.output, syn::ReturnType::Default) { - return Err(syn::Error::new_spanned( - method, - "Subscription methods must not return anything; the error must send via subscription via either `SubscriptionSink::reject` or `SubscriptionSink::close`", - )); + return Err(syn::Error::new_spanned(method, "Subscription methods must not return anything")); } - if method.sig.asyncness.is_some() { - return Err(syn::Error::new_spanned(method, "Subscription methods must not be `async`")); + if method.sig.asyncness.is_none() { + return Err(syn::Error::new_spanned(method, "Subscription methods must be `async`")); } let sub_data = RpcSubscription::from_item(attr.clone(), method.clone())?; diff --git a/proc-macros/tests/ui/correct/alias_doesnt_use_namespace.rs b/proc-macros/tests/ui/correct/alias_doesnt_use_namespace.rs index f68be69dcc..102487f850 100644 --- a/proc-macros/tests/ui/correct/alias_doesnt_use_namespace.rs +++ b/proc-macros/tests/ui/correct/alias_doesnt_use_namespace.rs @@ -8,7 +8,7 @@ pub trait Rpc { async fn async_method(&self, param_a: u8, param_b: String) -> RpcResult; #[subscription(name = "subscribeGetFood", item = String, aliases = ["getFood"], unsubscribe_aliases = ["unsubscribegetFood"])] - fn sub(&self); + async fn sub(&self); } fn main() {} diff --git a/proc-macros/tests/ui/correct/basic.rs b/proc-macros/tests/ui/correct/basic.rs index e8a58e97c9..0365e30de9 100644 --- a/proc-macros/tests/ui/correct/basic.rs +++ b/proc-macros/tests/ui/correct/basic.rs @@ -3,12 +3,12 @@ use std::net::SocketAddr; use jsonrpsee::core::params::ArrayParams; +use jsonrpsee::core::SubscriptionResult; use jsonrpsee::core::{async_trait, client::ClientT, RpcResult}; use jsonrpsee::proc_macros::rpc; -use jsonrpsee::server::ServerBuilder; -use jsonrpsee::types::SubscriptionResult; +use jsonrpsee::server::{ServerBuilder, SubscriptionMessage}; use jsonrpsee::ws_client::*; -use jsonrpsee::{rpc_params, SubscriptionSink}; +use jsonrpsee::{rpc_params, PendingSubscriptionSink}; #[rpc(client, server, namespace = "foo")] pub trait Rpc { @@ -28,15 +28,15 @@ pub trait Rpc { fn sync_method(&self) -> RpcResult; #[subscription(name = "subscribe", item = String)] - fn sub(&self); + async fn sub(&self); #[subscription(name = "echo", unsubscribe = "unsubscribeEcho", aliases = ["ECHO"], item = u32, unsubscribe_aliases = ["NotInterested", "listenNoMore"])] - fn sub_with_params(&self, val: u32); + async fn sub_with_params(&self, val: u32); // This will send data to subscribers with the `method` field in the JSON payload set to `foo_subscribe_override` // because it's in the `foo` namespace. #[subscription(name = "subscribe_method" => "subscribe_override", item = u32)] - fn sub_with_override_notif_method(&self); + async fn sub_with_override_notif_method(&self); } pub struct RpcServerImpl; @@ -65,20 +65,34 @@ impl RpcServer for RpcServerImpl { Ok(10u16) } - fn sub(&self, mut sink: SubscriptionSink) -> SubscriptionResult { - let _ = sink.send(&"Response_A"); - let _ = sink.send(&"Response_B"); + async fn sub(&self, pending: PendingSubscriptionSink) -> SubscriptionResult { + let sink = pending.accept().await?; + + let msg1 = SubscriptionMessage::from_json(&"Response_A").unwrap(); + let msg2 = SubscriptionMessage::from_json(&"Response_B").unwrap(); + + sink.send(msg1).await.unwrap(); + sink.send(msg2).await.unwrap(); + Ok(()) } - fn sub_with_params(&self, mut sink: SubscriptionSink, val: u32) -> SubscriptionResult { - let _ = sink.send(&val); - let _ = sink.send(&val); + async fn sub_with_params(&self, pending: PendingSubscriptionSink, val: u32) -> SubscriptionResult { + let sink = pending.accept().await?; + + let msg = SubscriptionMessage::from_json(&val).unwrap(); + + sink.send(msg.clone()).await.unwrap(); + sink.send(msg).await.unwrap(); + Ok(()) } - fn sub_with_override_notif_method(&self, mut sink: SubscriptionSink) -> SubscriptionResult { - let _ = sink.send(&1); + async fn sub_with_override_notif_method(&self, pending: PendingSubscriptionSink) -> SubscriptionResult { + let sink = pending.accept().await?; + let msg = SubscriptionMessage::from_json(&1).unwrap(); + sink.send(msg).await.unwrap(); + Ok(()) } } diff --git a/proc-macros/tests/ui/correct/only_client.rs b/proc-macros/tests/ui/correct/only_client.rs index 05bedba2b5..bc9a4872f4 100644 --- a/proc-macros/tests/ui/correct/only_client.rs +++ b/proc-macros/tests/ui/correct/only_client.rs @@ -1,4 +1,4 @@ -//! Example of using proc macro to generate working client and server. +//! Example of using proc macro to generate working client. use jsonrpsee::{core::RpcResult, proc_macros::rpc}; @@ -11,7 +11,7 @@ pub trait Rpc { fn sync_method(&self) -> RpcResult; #[subscription(name = "subscribe", item = String)] - fn sub(&self); + async fn sub(&self); } fn main() {} diff --git a/proc-macros/tests/ui/correct/only_server.rs b/proc-macros/tests/ui/correct/only_server.rs index 45f251688e..89fa961d52 100644 --- a/proc-macros/tests/ui/correct/only_server.rs +++ b/proc-macros/tests/ui/correct/only_server.rs @@ -1,9 +1,8 @@ use std::net::SocketAddr; -use jsonrpsee::core::{async_trait, RpcResult}; +use jsonrpsee::core::{async_trait, RpcResult, SubscriptionResult}; use jsonrpsee::proc_macros::rpc; -use jsonrpsee::server::{ServerBuilder, SubscriptionSink}; -use jsonrpsee::types::SubscriptionResult; +use jsonrpsee::server::{PendingSubscriptionSink, ServerBuilder, SubscriptionMessage}; #[rpc(server)] pub trait Rpc { @@ -14,7 +13,7 @@ pub trait Rpc { fn sync_method(&self) -> RpcResult; #[subscription(name = "subscribe", item = String)] - fn sub(&self); + async fn sub(&self); } pub struct RpcServerImpl; @@ -29,11 +28,15 @@ impl RpcServer for RpcServerImpl { Ok(10u16) } - fn sub(&self, mut sink: SubscriptionSink) -> SubscriptionResult { - sink.accept()?; + async fn sub(&self, pending: PendingSubscriptionSink) -> SubscriptionResult { + let sink = pending.accept().await?; + + let msg1 = SubscriptionMessage::from_json(&"Response_A").unwrap(); + let msg2 = SubscriptionMessage::from_json(&"Response_B").unwrap(); + + sink.send(msg1).await.unwrap(); + sink.send(msg2).await.unwrap(); - let _ = sink.send(&"Response_A"); - let _ = sink.send(&"Response_B"); Ok(()) } } diff --git a/proc-macros/tests/ui/correct/parse_angle_brackets.rs b/proc-macros/tests/ui/correct/parse_angle_brackets.rs index 24061b2095..c597b7afc6 100644 --- a/proc-macros/tests/ui/correct/parse_angle_brackets.rs +++ b/proc-macros/tests/ui/correct/parse_angle_brackets.rs @@ -12,6 +12,6 @@ fn main() { // angle braces need to be accounted for manually. item = TransactionStatus, )] - fn dummy_subscription(&self); + async fn dummy_subscription(&self); } } diff --git a/proc-macros/tests/ui/correct/rpc_deny_missing_docs.rs b/proc-macros/tests/ui/correct/rpc_deny_missing_docs.rs index ee61fa6bfe..b8b35aba98 100644 --- a/proc-macros/tests/ui/correct/rpc_deny_missing_docs.rs +++ b/proc-macros/tests/ui/correct/rpc_deny_missing_docs.rs @@ -13,7 +13,7 @@ pub trait ApiWithDocumentation { /// Subscription docs. #[subscription(name = "sub", unsubscribe = "unsub", item = String)] - fn sub(&self); + async fn sub(&self); } fn main() {} diff --git a/proc-macros/tests/ui/incorrect/method/method_unexpected_field.stderr b/proc-macros/tests/ui/incorrect/method/method_unexpected_field.stderr index 57c82ce5eb..8fecc8437a 100644 --- a/proc-macros/tests/ui/incorrect/method/method_unexpected_field.stderr +++ b/proc-macros/tests/ui/incorrect/method/method_unexpected_field.stderr @@ -1,4 +1,4 @@ -error: Unknown argument `magic`, expected one of: `aliases`, `blocking`, `name`, `param_kind`, `resources` +error: Unknown argument `magic`, expected one of: `aliases`, `blocking`, `name`, `param_kind` --> $DIR/method_unexpected_field.rs:6:25 | 6 | #[method(name = "foo", magic = false)] diff --git a/proc-macros/tests/ui/incorrect/sub/sub_async.rs b/proc-macros/tests/ui/incorrect/sub/sub_async.rs deleted file mode 100644 index b77d9ed945..0000000000 --- a/proc-macros/tests/ui/incorrect/sub/sub_async.rs +++ /dev/null @@ -1,10 +0,0 @@ -use jsonrpsee::proc_macros::rpc; - -// Subscription method must not be async. -#[rpc(client, server)] -pub trait AsyncSub { - #[subscription(name = "sub", item = u8)] - async fn sub(&self); -} - -fn main() {} diff --git a/proc-macros/tests/ui/incorrect/sub/sub_async.stderr b/proc-macros/tests/ui/incorrect/sub/sub_async.stderr deleted file mode 100644 index 352c2410dd..0000000000 --- a/proc-macros/tests/ui/incorrect/sub/sub_async.stderr +++ /dev/null @@ -1,6 +0,0 @@ -error: Subscription methods must not be `async` - --> $DIR/sub_async.rs:6:2 - | -6 | / #[subscription(name = "sub", item = u8)] -7 | | async fn sub(&self); - | |________________________^ diff --git a/proc-macros/tests/ui/incorrect/sub/sub_conflicting_alias.rs b/proc-macros/tests/ui/incorrect/sub/sub_conflicting_alias.rs index 119c1a1a45..600ca43df3 100644 --- a/proc-macros/tests/ui/incorrect/sub/sub_conflicting_alias.rs +++ b/proc-macros/tests/ui/incorrect/sub/sub_conflicting_alias.rs @@ -3,7 +3,7 @@ use jsonrpsee::proc_macros::rpc; #[rpc(client, server)] pub trait DuplicatedSubAlias { #[subscription(name = "subscribeAlias", item = String, aliases = ["hello_is_goodbye"], unsubscribe_aliases = ["hello_is_goodbye"])] - fn async_method(&self); + async fn async_method(&self); } fn main() {} diff --git a/proc-macros/tests/ui/incorrect/sub/sub_conflicting_alias.stderr b/proc-macros/tests/ui/incorrect/sub/sub_conflicting_alias.stderr index 7c08822b38..e30ce56f7c 100644 --- a/proc-macros/tests/ui/incorrect/sub/sub_conflicting_alias.stderr +++ b/proc-macros/tests/ui/incorrect/sub/sub_conflicting_alias.stderr @@ -1,5 +1,5 @@ error: "hello_is_goodbye" is already defined - --> $DIR/sub_conflicting_alias.rs:6:5 + --> $DIR/sub_conflicting_alias.rs:6:11 | -6 | fn async_method(&self); - | ^^^^^^^^^^^^ +6 | async fn async_method(&self); + | ^^^^^^^^^^^^ diff --git a/proc-macros/tests/ui/incorrect/sub/sub_dup_name_override.rs b/proc-macros/tests/ui/incorrect/sub/sub_dup_name_override.rs index cce07c544e..e81a9fed7d 100644 --- a/proc-macros/tests/ui/incorrect/sub/sub_dup_name_override.rs +++ b/proc-macros/tests/ui/incorrect/sub/sub_dup_name_override.rs @@ -4,9 +4,9 @@ use jsonrpsee::proc_macros::rpc; #[rpc(client, server)] pub trait DupOverride { #[subscription(name = "subscribeOne" => "override", item = u8)] - fn one(&self); + async fn one(&self); #[subscription(name = "subscribeTwo" => "override", item = u8)] - fn two(&self); + async fn two(&self); } fn main() {} diff --git a/proc-macros/tests/ui/incorrect/sub/sub_dup_name_override.stderr b/proc-macros/tests/ui/incorrect/sub/sub_dup_name_override.stderr index b17dba0119..50efc5f4d5 100644 --- a/proc-macros/tests/ui/incorrect/sub/sub_dup_name_override.stderr +++ b/proc-macros/tests/ui/incorrect/sub/sub_dup_name_override.stderr @@ -1,5 +1,5 @@ error: "override" is already defined - --> $DIR/sub_dup_name_override.rs:9:5 + --> $DIR/sub_dup_name_override.rs:9:11 | -9 | fn two(&self); - | ^^^ +9 | async fn two(&self); + | ^^^ diff --git a/proc-macros/tests/ui/incorrect/sub/sub_empty_attr.rs b/proc-macros/tests/ui/incorrect/sub/sub_empty_attr.rs index 55124fe914..e11dbe25b8 100644 --- a/proc-macros/tests/ui/incorrect/sub/sub_empty_attr.rs +++ b/proc-macros/tests/ui/incorrect/sub/sub_empty_attr.rs @@ -4,7 +4,7 @@ use jsonrpsee::proc_macros::rpc; #[rpc(client, server)] pub trait SubEmptyAttr { #[subscription()] - fn sub(&self); + async fn sub(&self); } fn main() {} diff --git a/proc-macros/tests/ui/incorrect/sub/sub_name_override.rs b/proc-macros/tests/ui/incorrect/sub/sub_name_override.rs index f46c313c9f..50c786ba35 100644 --- a/proc-macros/tests/ui/incorrect/sub/sub_name_override.rs +++ b/proc-macros/tests/ui/incorrect/sub/sub_name_override.rs @@ -4,7 +4,7 @@ use jsonrpsee::proc_macros::rpc; #[rpc(client, server)] pub trait DupName { #[subscription(name = "one" => "one", unsubscribe = "unsubscribeOne", item = u8)] - fn one(&self); + async fn one(&self); } fn main() {} diff --git a/proc-macros/tests/ui/incorrect/sub/sub_name_override.stderr b/proc-macros/tests/ui/incorrect/sub/sub_name_override.stderr index cfdd2afbbe..e4af34bcc3 100644 --- a/proc-macros/tests/ui/incorrect/sub/sub_name_override.stderr +++ b/proc-macros/tests/ui/incorrect/sub/sub_name_override.stderr @@ -1,5 +1,5 @@ error: "one" is already defined - --> $DIR/sub_name_override.rs:7:5 + --> $DIR/sub_name_override.rs:7:11 | -7 | fn one(&self); - | ^^^ +7 | async fn one(&self); + | ^^^ diff --git a/proc-macros/tests/ui/incorrect/sub/sub_no_item.rs b/proc-macros/tests/ui/incorrect/sub/sub_no_item.rs index fbb65eff3f..d2c750a194 100644 --- a/proc-macros/tests/ui/incorrect/sub/sub_no_item.rs +++ b/proc-macros/tests/ui/incorrect/sub/sub_no_item.rs @@ -4,7 +4,7 @@ use jsonrpsee::proc_macros::rpc; #[rpc(client, server)] pub trait NoSubItem { #[subscription(name = "sub")] - fn sub(&self); + async fn sub(&self); } fn main() {} diff --git a/proc-macros/tests/ui/incorrect/sub/sub_no_name.rs b/proc-macros/tests/ui/incorrect/sub/sub_no_name.rs index 3bab992b0b..9a03ad21a5 100644 --- a/proc-macros/tests/ui/incorrect/sub/sub_no_name.rs +++ b/proc-macros/tests/ui/incorrect/sub/sub_no_name.rs @@ -4,7 +4,7 @@ use jsonrpsee::proc_macros::rpc; #[rpc(client, server)] pub trait NoSubName { #[subscription(item = String)] - fn async_method(&self); + async fn async_method(&self); } fn main() {} diff --git a/proc-macros/tests/ui/incorrect/sub/sub_unsupported_field.rs b/proc-macros/tests/ui/incorrect/sub/sub_unsupported_field.rs index be5c9472be..fbb22e367d 100644 --- a/proc-macros/tests/ui/incorrect/sub/sub_unsupported_field.rs +++ b/proc-macros/tests/ui/incorrect/sub/sub_unsupported_field.rs @@ -4,7 +4,7 @@ use jsonrpsee::proc_macros::rpc; #[rpc(client, server)] pub trait UnsupportedField { #[subscription(name = "sub", unsubscribe = "unsub", item = u8, magic = true)] - fn sub(&self); + async fn sub(&self); } fn main() {} diff --git a/proc-macros/tests/ui/incorrect/sub/sub_unsupported_field.stderr b/proc-macros/tests/ui/incorrect/sub/sub_unsupported_field.stderr index 3eea012b51..70ad7bf7d0 100644 --- a/proc-macros/tests/ui/incorrect/sub/sub_unsupported_field.stderr +++ b/proc-macros/tests/ui/incorrect/sub/sub_unsupported_field.stderr @@ -1,4 +1,4 @@ -error: Unknown argument `magic`, expected one of: `aliases`, `item`, `name`, `param_kind`, `unsubscribe`, `unsubscribe_aliases`, `resources` +error: Unknown argument `magic`, expected one of: `aliases`, `item`, `name`, `param_kind`, `unsubscribe`, `unsubscribe_aliases` --> tests/ui/incorrect/sub/sub_unsupported_field.rs:6:65 | 6 | #[subscription(name = "sub", unsubscribe = "unsub", item = u8, magic = true)] diff --git a/server/Cargo.toml b/server/Cargo.toml index 6b9a162dc5..bcfc6ff916 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -10,7 +10,6 @@ homepage = "https://github.com/paritytech/jsonrpsee" documentation = "https://docs.rs/jsonrpsee-server" [dependencies] -futures-channel = "0.3.14" futures-util = { version = "0.3.14", default-features = false, features = ["io", "async-await-macro"] } jsonrpsee-types = { path = "../types", version = "0.16.2" } jsonrpsee-core = { path = "../core", version = "0.16.2", features = ["server", "soketto", "http-helpers"] } diff --git a/server/src/lib.rs b/server/src/lib.rs index d661e4b83b..1100f89ff7 100644 --- a/server/src/lib.rs +++ b/server/src/lib.rs @@ -43,7 +43,10 @@ mod tests; pub use future::ServerHandle; pub use jsonrpsee_core::server::host_filtering::AllowHosts; -pub use jsonrpsee_core::server::rpc_module::{RpcModule, SubscriptionSink}; +pub use jsonrpsee_core::server::rpc_module::{ + DisconnectError, PendingSubscriptionSink, RpcModule, SendTimeoutError, SubscriptionMessage, SubscriptionSink, + TrySendError, +}; pub use jsonrpsee_core::{id_providers::*, traits::IdProvider}; pub use jsonrpsee_types as types; pub use server::{Builder as ServerBuilder, Server}; diff --git a/server/src/server.rs b/server/src/server.rs index ee80c08194..4511a25cb5 100644 --- a/server/src/server.rs +++ b/server/src/server.rs @@ -42,9 +42,7 @@ use futures_util::io::{BufReader, BufWriter}; use hyper::body::HttpBody; use jsonrpsee_core::id_providers::RandomIntegerIdProvider; -use jsonrpsee_core::server::helpers::MethodResponse; use jsonrpsee_core::server::host_filtering::AllowHosts; -use jsonrpsee_core::server::resource_limiting::Resources; use jsonrpsee_core::server::rpc_module::Methods; use jsonrpsee_core::traits::IdProvider; use jsonrpsee_core::{http_helpers, Error, TEN_MB_SIZE_BYTES}; @@ -64,7 +62,6 @@ const MAX_CONNECTIONS: u32 = 100; pub struct Server { listener: TcpListener, cfg: Settings, - resources: Resources, logger: L, id_provider: Arc, service_builder: tower::ServiceBuilder, @@ -76,7 +73,6 @@ impl std::fmt::Debug for Server { .field("listener", &self.listener) .field("cfg", &self.cfg) .field("id_provider", &self.id_provider) - .field("resources", &self.resources) .finish() } } @@ -107,7 +103,7 @@ where /// /// This will run on the tokio runtime until the server is stopped or the `ServerHandle` is dropped. pub fn start(mut self, methods: impl Into) -> Result { - let methods = methods.into().initialize_resources(&self.resources)?; + let methods = methods.into(); let (stop_tx, stop_rx) = watch::channel(()); let stop_handle = StopHandle::new(stop_rx); @@ -124,12 +120,11 @@ where let max_request_body_size = self.cfg.max_request_body_size; let max_response_body_size = self.cfg.max_response_body_size; let max_log_length = self.cfg.max_log_length; + let max_subscriptions_per_connection = self.cfg.max_subscriptions_per_connection; let allow_hosts = self.cfg.allow_hosts; - let resources = self.resources; let logger = self.logger; let batch_requests_supported = self.cfg.batch_requests_supported; let id_provider = self.id_provider; - let max_subscriptions_per_connection = self.cfg.max_subscriptions_per_connection; let mut id: u32 = 0; let connection_guard = ConnectionGuard::new(self.cfg.max_connections as usize); @@ -143,20 +138,20 @@ where remote_addr, methods: methods.clone(), allow_hosts: allow_hosts.clone(), - resources: resources.clone(), max_request_body_size, max_response_body_size, max_log_length, + max_subscriptions_per_connection, batch_requests_supported, id_provider: id_provider.clone(), ping_interval: self.cfg.ping_interval, stop_handle: stop_handle.clone(), - max_subscriptions_per_connection, conn_id: id, logger: logger.clone(), max_connections: self.cfg.max_connections, enable_http: self.cfg.enable_http, enable_ws: self.cfg.enable_ws, + message_buffer_capacity: self.cfg.message_buffer_capacity, }; process_connection(&self.service_builder, &connection_guard, data, socket, &mut connections); id = id.wrapping_add(1); @@ -181,12 +176,12 @@ struct Settings { max_response_body_size: u32, /// Maximum number of incoming connections allowed. max_connections: u32, - /// Maximum number of subscriptions per connection. - max_subscriptions_per_connection: u32, /// Max length for logging for requests and responses /// /// Logs bigger than this limit will be truncated. max_log_length: u32, + /// Maximum number of subscriptions per connection. + max_subscriptions_per_connection: u32, /// Host filtering. allow_hosts: AllowHosts, /// Whether batch requests are supported by this server or not. @@ -199,6 +194,8 @@ struct Settings { enable_http: bool, /// Enable WS. enable_ws: bool, + /// Number of messages that server is allowed to `buffer` until backpressure kicks in. + message_buffer_capacity: u32, } impl Default for Settings { @@ -207,14 +204,15 @@ impl Default for Settings { max_request_body_size: TEN_MB_SIZE_BYTES, max_response_body_size: TEN_MB_SIZE_BYTES, max_log_length: 4096, - max_subscriptions_per_connection: 1024, max_connections: MAX_CONNECTIONS, + max_subscriptions_per_connection: 1024, batch_requests_supported: true, allow_hosts: AllowHosts::Any, tokio_runtime: None, ping_interval: Duration::from_secs(60), enable_http: true, enable_ws: true, + message_buffer_capacity: 1024, } } } @@ -223,7 +221,6 @@ impl Default for Settings { #[derive(Debug)] pub struct Builder { settings: Settings, - resources: Resources, logger: L, id_provider: Arc, service_builder: tower::ServiceBuilder, @@ -233,7 +230,6 @@ impl Default for Builder { fn default() -> Self { Builder { settings: Settings::default(), - resources: Resources::default(), logger: (), id_provider: Arc::new(RandomIntegerIdProvider), service_builder: tower::ServiceBuilder::new(), @@ -280,16 +276,6 @@ impl Builder { self } - /// Register a new resource kind. Errors if `label` is already registered, or if the number of - /// registered resources on this server instance would exceed 8. - /// - /// See the module documentation for [`resurce_limiting`](../jsonrpsee_utils/server/resource_limiting/index.html#resource-limiting) - /// for details. - pub fn register_resource(mut self, label: &'static str, capacity: u16, default: u16) -> Result { - self.resources.register(label, capacity, default)?; - Ok(self) - } - /// Add a logger to the builder [`Logger`](../jsonrpsee_core/logger/trait.Logger.html). /// /// ``` @@ -334,7 +320,6 @@ impl Builder { pub fn set_logger(self, logger: T) -> Builder { Builder { settings: self.settings, - resources: self.resources, logger, id_provider: self.id_provider, service_builder: self.service_builder, @@ -424,13 +409,7 @@ impl Builder { /// } /// ``` pub fn set_middleware(self, service_builder: tower::ServiceBuilder) -> Builder { - Builder { - settings: self.settings, - resources: self.resources, - logger: self.logger, - id_provider: self.id_provider, - service_builder, - } + Builder { settings: self.settings, logger: self.logger, id_provider: self.id_provider, service_builder } } /// Configure the server to only serve JSON-RPC HTTP requests. @@ -453,6 +432,29 @@ impl Builder { self } + /// The server enforces backpressure which means that + /// `n` messages can be buffered and if the client + /// can't keep with up the server. + /// + /// This `capacity` is applied per connection and + /// applies globally on the connection which implies + /// all JSON-RPC messages. + /// + /// For example if a subscription produces plenty of new items + /// and the client can't keep up then no new messages are handled. + /// + /// If this limit is exceeded then the server will "back-off" + /// and only accept new messages once the client reads pending messages. + /// + /// # Panics + /// + /// Panics if the buffer capacity is 0. + /// + pub fn set_message_buffer_capacity(mut self, c: u32) -> Self { + self.settings.message_buffer_capacity = c; + self + } + /// Set maximum length for logging calls and responses. /// /// Logs bigger than this limit will be truncated. @@ -483,7 +485,6 @@ impl Builder { Ok(Server { listener, cfg: self.settings, - resources: self.resources, logger: self.logger, id_provider: self.id_provider, service_builder: self.service_builder, @@ -519,7 +520,6 @@ impl Builder { Ok(Server { listener, cfg: self.settings, - resources: self.resources, logger: self.logger, id_provider: self.id_provider, service_builder: self.service_builder, @@ -527,27 +527,6 @@ impl Builder { } } -pub(crate) enum MethodResult { - JustLogger(MethodResponse), - SendAndLogger(MethodResponse), -} - -impl MethodResult { - pub(crate) fn as_inner(&self) -> &MethodResponse { - match &self { - Self::JustLogger(r) => r, - Self::SendAndLogger(r) => r, - } - } - - pub(crate) fn into_inner(self) -> MethodResponse { - match self { - Self::JustLogger(r) => r, - Self::SendAndLogger(r) => r, - } - } -} - /// Data required by the server to handle requests. #[derive(Debug, Clone)] pub(crate) struct ServiceData { @@ -557,8 +536,6 @@ pub(crate) struct ServiceData { pub(crate) methods: Methods, /// Access control. pub(crate) allow_hosts: AllowHosts, - /// Tracker for currently used resources on the server. - pub(crate) resources: Resources, /// Max request body size. pub(crate) max_request_body_size: u32, /// Max response body size. @@ -567,6 +544,8 @@ pub(crate) struct ServiceData { /// /// Logs bigger than this limit will be truncated. pub(crate) max_log_length: u32, + /// Maximum number of subscriptions per connection. + pub(crate) max_subscriptions_per_connection: u32, /// Whether batch requests are supported by this server or not. pub(crate) batch_requests_supported: bool, /// Subscription ID provider. @@ -575,8 +554,6 @@ pub(crate) struct ServiceData { pub(crate) ping_interval: Duration, /// Stop handle. pub(crate) stop_handle: StopHandle, - /// Max subscriptions per connection. - pub(crate) max_subscriptions_per_connection: u32, /// Connection ID pub(crate) conn_id: u32, /// Logger. @@ -587,6 +564,8 @@ pub(crate) struct ServiceData { pub(crate) enable_http: bool, /// Enable WS. pub(crate) enable_ws: bool, + /// Number of messages that server is allowed `buffer` until backpressure kicks in. + pub(crate) message_buffer_capacity: u32, } /// JsonRPSee service compatible with `tower`. @@ -672,7 +651,6 @@ impl hyper::service::Service> for TowerSe // The request wasn't an upgrade request; let's treat it as a standard HTTP request: let data = http::HandleRequest { methods: self.inner.methods.clone(), - resources: self.inner.resources.clone(), max_request_body_size: self.inner.max_request_body_size, max_response_body_size: self.inner.max_response_body_size, max_log_length: self.inner.max_log_length, @@ -748,8 +726,6 @@ struct ProcessConnection { methods: Methods, /// Access control. allow_hosts: AllowHosts, - /// Tracker for currently used resources on the server. - resources: Resources, /// Max request body size. max_request_body_size: u32, /// Max response body size. @@ -758,6 +734,8 @@ struct ProcessConnection { /// /// Logs bigger than this limit will be truncated. max_log_length: u32, + /// Maximum number of subscriptions per connection. + max_subscriptions_per_connection: u32, /// Whether batch requests are supported by this server or not. batch_requests_supported: bool, /// Subscription ID provider. @@ -766,8 +744,6 @@ struct ProcessConnection { ping_interval: Duration, /// Stop handle. stop_handle: StopHandle, - /// Max subscriptions per connection. - max_subscriptions_per_connection: u32, /// Max connections, max_connections: u32, /// Connection ID @@ -778,6 +754,8 @@ struct ProcessConnection { enable_http: bool, /// Allow JSON-RPC WS request and WS upgrade requests. enable_ws: bool, + /// Number of messages that server is allowed `buffer` until backpressure kicks in. + message_buffer_capacity: u32, } #[instrument(name = "connection", skip_all, fields(remote_addr = %cfg.remote_addr, conn_id = %cfg.conn_id), level = "INFO")] @@ -823,20 +801,20 @@ fn process_connection<'a, L: Logger, B, U>( remote_addr: cfg.remote_addr, methods: cfg.methods, allow_hosts: cfg.allow_hosts, - resources: cfg.resources, max_request_body_size: cfg.max_request_body_size, max_response_body_size: cfg.max_response_body_size, max_log_length: cfg.max_log_length, + max_subscriptions_per_connection: cfg.max_subscriptions_per_connection, batch_requests_supported: cfg.batch_requests_supported, id_provider: cfg.id_provider, ping_interval: cfg.ping_interval, stop_handle: cfg.stop_handle.clone(), - max_subscriptions_per_connection: cfg.max_subscriptions_per_connection, conn_id: cfg.conn_id, logger: cfg.logger, conn: Arc::new(conn), enable_http: cfg.enable_http, enable_ws: cfg.enable_ws, + message_buffer_capacity: cfg.message_buffer_capacity, }, }; diff --git a/server/src/tests/helpers.rs b/server/src/tests/helpers.rs index 6259d6697b..ccad473cd5 100644 --- a/server/src/tests/helpers.rs +++ b/server/src/tests/helpers.rs @@ -82,18 +82,16 @@ pub(crate) async fn server_with_handles() -> (SocketAddr, ServerHandle) { }) .unwrap(); module - .register_subscription("subscribe_hello", "subscribe_hello", "unsubscribe_hello", |_, mut sink, _| { - sink.accept()?; - - tokio::spawn(async move { - loop { - let _ = &sink; - tokio::time::sleep(std::time::Duration::from_secs(30)).await; - } - }); - Ok(()) + .register_subscription("subscribe_hello", "subscribe_hello", "unsubscribe_hello", |_, pending, _| async move { + let sink = pending.accept().await?; + + loop { + let _ = &sink; + tokio::time::sleep(std::time::Duration::from_secs(30)).await; + } }) .unwrap(); + module.register_method("notif", |_, _| Ok("")).unwrap(); module .register_method("should_err", |_, ctx| { diff --git a/server/src/tests/ws.rs b/server/src/tests/ws.rs index 8795490bf9..94f4f612bb 100644 --- a/server/src/tests/ws.rs +++ b/server/src/tests/ws.rs @@ -27,10 +27,12 @@ use crate::tests::helpers::{deser_call, init_logger, server_with_context}; use crate::types::SubscriptionId; use crate::{RpcModule, ServerBuilder}; +use jsonrpsee_core::server::rpc_module::{SendTimeoutError, SubscriptionMessage}; use jsonrpsee_core::{traits::IdProvider, Error}; use jsonrpsee_test_utils::helpers::*; use jsonrpsee_test_utils::mocks::{Id, WebSocketTestClient, WebSocketTestError}; use jsonrpsee_test_utils::TimeoutFutureExt; +use jsonrpsee_types::SubscriptionResponse; use serde_json::Value as JsonValue; use super::helpers::server; @@ -402,10 +404,10 @@ async fn register_methods_works() { assert!(module.register_method("say_hello", |_, _| Ok("lo")).is_ok()); assert!(module.register_method("say_hello", |_, _| Ok("lo")).is_err()); assert!(module - .register_subscription("subscribe_hello", "subscribe_hello", "unsubscribe_hello", |_, _, _| { Ok(()) }) + .register_subscription("subscribe_hello", "subscribe_hello", "unsubscribe_hello", |_, _, _| async { Ok(()) }) .is_ok()); assert!(module - .register_subscription("subscribe_hello_again", "subscribe_hello_again", "unsubscribe_hello", |_, _, _| { + .register_subscription("subscribe_hello_again", "subscribe_hello_again", "unsubscribe_hello", |_, _, _| async { Ok(()) }) .is_err()); @@ -419,7 +421,8 @@ async fn register_methods_works() { async fn register_same_subscribe_unsubscribe_is_err() { let mut module = RpcModule::new(()); assert!(matches!( - module.register_subscription("subscribe_hello", "subscribe_hello", "subscribe_hello", |_, _, _| { Ok(()) }), + module + .register_subscription("subscribe_hello", "subscribe_hello", "subscribe_hello", |_, _, _| async { Ok(()) }), Err(Error::SubscriptionNameConflict(_)) )); } @@ -546,23 +549,15 @@ async fn custom_subscription_id_works() { let addr = server.local_addr().unwrap(); let mut module = RpcModule::new(()); module - .register_subscription("subscribe_hello", "subscribe_hello", "unsubscribe_hello", |_, mut sink, _| { - // There is no subscription ID prior to calling accept. - let sub_id = sink.subscription_id(); - assert!(sub_id.is_none()); + .register_subscription("subscribe_hello", "subscribe_hello", "unsubscribe_hello", |_, sink, _| async { + let sink = sink.accept().await.unwrap(); - sink.accept()?; + assert!(matches!(sink.subscription_id(), SubscriptionId::Str(id) if id == "0xdeadbeef")); - let sub_id = sink.subscription_id(); - assert!(matches!(sub_id, Some(SubscriptionId::Str(id)) if id == "0xdeadbeef")); - - tokio::spawn(async move { - loop { - let _ = &sink; - tokio::time::sleep(std::time::Duration::from_secs(30)).await; - } - }); - Ok(()) + loop { + let _ = &sink; + tokio::time::sleep(std::time::Duration::from_secs(30)).await; + } }) .unwrap(); let _handle = server.start(module).unwrap(); @@ -653,3 +648,100 @@ async fn batch_with_mixed_calls() { let response = client.send_request_text(req.to_string()).with_default_timeout().await.unwrap().unwrap(); assert_eq!(response, res); } + +#[tokio::test] +async fn ws_server_backpressure_works() { + init_logger(); + + let (backpressure_tx, mut backpressure_rx) = tokio::sync::mpsc::channel::<()>(1); + + let server = ServerBuilder::default() + .set_message_buffer_capacity(5) + .build("127.0.0.1:0") + .with_default_timeout() + .await + .unwrap() + .unwrap(); + + let mut module = RpcModule::new(backpressure_tx); + + module + .register_subscription( + "subscribe_with_backpressure_aggregation", + "n", + "unsubscribe_with_backpressure_aggregation", + move |_, pending, mut backpressure_tx| async move { + let sink = pending.accept().await?; + let n = SubscriptionMessage::from_json(&1).unwrap(); + let bp = SubscriptionMessage::from_json(&2).unwrap(); + + let mut msg = n.clone(); + + loop { + tokio::select! { + biased; + _ = sink.closed() => { + // User closed connection. + break; + }, + res = sink.send_timeout(msg.clone(), std::time::Duration::from_millis(100)) => { + match res { + // msg == 1 + Ok(_) => { + msg = n.clone(); + } + Err(SendTimeoutError::Closed(_)) => break, + // msg == 2 + Err(SendTimeoutError::Timeout(_)) => { + let b_tx = std::sync::Arc::make_mut(&mut backpressure_tx); + let _ = b_tx.send(()).await; + msg = bp.clone(); + } + }; + }, + } + } + Ok(()) + }, + ) + .unwrap(); + let addr = server.local_addr().unwrap(); + + let _server_handle = server.start(module).unwrap(); + + // Send a valid batch. + let mut client = WebSocketTestClient::new(addr).with_default_timeout().await.unwrap().unwrap(); + let req = r#" + {"jsonrpc":"2.0","method":"subscribe_with_backpressure_aggregation", "params":[],"id":1}"#; + client.send(req).with_default_timeout().await.unwrap().unwrap(); + + backpressure_rx.recv().await.unwrap(); + + let now = std::time::Instant::now(); + let mut msg; + + // Assert that first `item == 2` was sent and then + // the client start reading the socket again the buffered items should be sent. + // Thus, eventually `item == 1` should be sent again. + let mut seen_backpressure_item = false; + let mut seen_item_after_backpressure = false; + + while now.elapsed() < std::time::Duration::from_secs(10) { + msg = client.receive().with_default_timeout().await.unwrap().unwrap(); + if let Ok(sub_notif) = serde_json::from_str::>(&msg) { + match sub_notif.params.result { + 1 if seen_backpressure_item => { + seen_item_after_backpressure = true; + break; + } + 2 => { + seen_backpressure_item = true; + } + _ => (), + } + } + } + + assert!(seen_backpressure_item); + assert!(seen_item_after_backpressure); +} diff --git a/server/src/transport/http.rs b/server/src/transport/http.rs index ad693bd708..9337b026ac 100644 --- a/server/src/transport/http.rs +++ b/server/src/transport/http.rs @@ -11,7 +11,7 @@ use jsonrpsee_core::error::GenericTransportError; use jsonrpsee_core::http_helpers::read_body; use jsonrpsee_core::server::helpers::{prepare_error, BatchResponse, BatchResponseBuilder, MethodResponse}; use jsonrpsee_core::server::rpc_module::MethodKind; -use jsonrpsee_core::server::{resource_limiting::Resources, rpc_module::Methods}; +use jsonrpsee_core::server::rpc_module::Methods; use jsonrpsee_core::tracing::{rx_log_from_json, tx_log_from_str}; use jsonrpsee_core::JsonRawValue; use jsonrpsee_types::error::{ErrorCode, BATCHES_NOT_SUPPORTED_CODE, BATCHES_NOT_SUPPORTED_MSG}; @@ -51,7 +51,6 @@ pub(crate) struct ProcessValidatedRequest<'a, L: Logger> { pub(crate) request: hyper::Request, pub(crate) logger: &'a L, pub(crate) methods: Methods, - pub(crate) resources: Resources, pub(crate) max_request_body_size: u32, pub(crate) max_response_body_size: u32, pub(crate) max_log_length: u32, @@ -67,7 +66,6 @@ pub(crate) async fn process_validated_request( request, logger, methods, - resources, max_request_body_size, max_response_body_size, max_log_length, @@ -89,15 +87,8 @@ pub(crate) async fn process_validated_request( // Single request or notification if is_single { - let call = CallData { - conn_id: 0, - logger, - methods: &methods, - max_response_body_size, - max_log_length, - resources: &resources, - request_start, - }; + let call = + CallData { conn_id: 0, logger, methods: &methods, max_response_body_size, max_log_length, request_start }; let response = process_single_request(body, call).await; logger.on_response(&response.result, request_start, TransportProtocol::Http); response::ok_response(response.result) @@ -121,7 +112,6 @@ pub(crate) async fn process_validated_request( methods: &methods, max_response_body_size, max_log_length, - resources: &resources, request_start, }, }) @@ -144,7 +134,6 @@ pub(crate) struct CallData<'a, L: Logger> { methods: &'a Methods, max_response_body_size: u32, max_log_length: u32, - resources: &'a Resources, request_start: L::Instant, } @@ -221,7 +210,7 @@ pub(crate) async fn execute_call_with_tracing<'a, L: Logger>( } pub(crate) async fn execute_call(req: Request<'_>, call: CallData<'_, L>) -> MethodResponse { - let CallData { resources, methods, logger, max_response_body_size, max_log_length, conn_id, request_start } = call; + let CallData { methods, logger, max_response_body_size, max_log_length, conn_id, request_start } = call; rx_log_from_json(&req, call.max_log_length); @@ -237,33 +226,15 @@ pub(crate) async fn execute_call(req: Request<'_>, call: CallData<'_, Some((name, method)) => match &method.inner() { MethodKind::Sync(callback) => { logger.on_call(name, params.clone(), logger::MethodKind::MethodCall, TransportProtocol::Http); - - match method.claim(name, resources) { - Ok(guard) => { - let r = (callback)(id, params, max_response_body_size as usize); - drop(guard); - r - } - Err(err) => { - tracing::error!("[Methods::execute_with_resources] failed to lock resources: {}", err); - MethodResponse::error(id, ErrorObject::from(ErrorCode::ServerIsBusy)) - } - } + (callback)(id, params, max_response_body_size as usize) } MethodKind::Async(callback) => { logger.on_call(name, params.clone(), logger::MethodKind::MethodCall, TransportProtocol::Http); - match method.claim(name, resources) { - Ok(guard) => { - let id = id.into_owned(); - let params = params.into_owned(); - - (callback)(id, params, conn_id, max_response_body_size as usize, Some(guard)).await - } - Err(err) => { - tracing::error!("[Methods::execute_with_resources] failed to lock resources: {}", err); - MethodResponse::error(id, ErrorObject::from(ErrorCode::ServerIsBusy)) - } - } + + let id = id.into_owned(); + let params = params.into_owned(); + + (callback)(id, params, conn_id, max_response_body_size as usize).await } MethodKind::Subscription(_) | MethodKind::Unsubscription(_) => { logger.on_call(name, params.clone(), logger::MethodKind::Unknown, TransportProtocol::Http); @@ -288,7 +259,6 @@ fn execute_notification(notif: Notif, max_log_length: u32) -> MethodResponse { pub(crate) struct HandleRequest { pub(crate) methods: Methods, - pub(crate) resources: Resources, pub(crate) max_request_body_size: u32, pub(crate) max_response_body_size: u32, pub(crate) max_log_length: u32, @@ -304,7 +274,6 @@ pub(crate) async fn handle_request( ) -> hyper::Response { let HandleRequest { methods, - resources, max_request_body_size, max_response_body_size, max_log_length, @@ -322,7 +291,6 @@ pub(crate) async fn handle_request( process_validated_request(ProcessValidatedRequest { request, methods, - resources, max_request_body_size, max_response_body_size, max_log_length, diff --git a/server/src/transport/ws.rs b/server/src/transport/ws.rs index f2ed5b375d..479f23d639 100644 --- a/server/src/transport/ws.rs +++ b/server/src/transport/ws.rs @@ -4,9 +4,8 @@ use std::time::Duration; use crate::future::{FutureDriver, StopHandle}; use crate::logger::{self, Logger, TransportProtocol}; -use crate::server::{MethodResult, ServiceData}; +use crate::server::ServiceData; -use futures_channel::mpsc; use futures_util::future::{self, Either}; use futures_util::io::{BufReader, BufWriter}; use futures_util::stream::FuturesOrdered; @@ -15,8 +14,7 @@ use hyper::upgrade::Upgraded; use jsonrpsee_core::server::helpers::{ prepare_error, BatchResponse, BatchResponseBuilder, BoundedSubscriptions, MethodResponse, MethodSink, }; -use jsonrpsee_core::server::resource_limiting::Resources; -use jsonrpsee_core::server::rpc_module::{ConnState, MethodKind, Methods}; +use jsonrpsee_core::server::rpc_module::{CallOrSubscription, ConnState, MethodKind, Methods, SubscriptionAnswered}; use jsonrpsee_core::tracing::{rx_log_from_json, tx_log_from_str}; use jsonrpsee_core::traits::IdProvider; use jsonrpsee_core::{Error, JsonRawValue}; @@ -27,7 +25,8 @@ use jsonrpsee_types::error::{ use jsonrpsee_types::{ErrorObject, Id, InvalidRequest, Notification, Params, Request}; use soketto::connection::Error as SokettoError; use soketto::data::ByteSlice125; -use tokio_stream::wrappers::IntervalStream; +use tokio::sync::{mpsc, oneshot}; +use tokio_stream::wrappers::{IntervalStream, ReceiverStream}; use tokio_util::compat::Compat; use tracing::instrument; @@ -63,7 +62,6 @@ pub(crate) struct CallData<'a, L: Logger> { pub(crate) methods: &'a Methods, pub(crate) max_response_body_size: u32, pub(crate) max_log_length: u32, - pub(crate) resources: &'a Resources, pub(crate) sink: &'a MethodSink, pub(crate) logger: &'a L, pub(crate) request_start: L::Instant, @@ -118,7 +116,7 @@ pub(crate) async fn process_batch_request(b: Batch<'_, L>) -> Option< .into_iter() .filter_map(|v| { if let Ok(req) = serde_json::from_str::(v.get()) { - Some(Either::Right(async { execute_call(req, call.clone()).await.into_inner() })) + Some(Either::Right(async { execute_call(req, call.clone()).await.into_response() })) } else if let Ok(_notif) = serde_json::from_str::>(v.get()) { // notifications should not be answered. got_notif = true; @@ -153,17 +151,20 @@ pub(crate) async fn process_batch_request(b: Batch<'_, L>) -> Option< } } -pub(crate) async fn process_single_request(data: Vec, call: CallData<'_, L>) -> MethodResult { +pub(crate) async fn process_single_request(data: Vec, call: CallData<'_, L>) -> CallOrSubscription { if let Ok(req) = serde_json::from_slice::(&data) { execute_call_with_tracing(req, call).await } else { let (id, code) = prepare_error(&data); - MethodResult::SendAndLogger(MethodResponse::error(id, ErrorObject::from(code))) + CallOrSubscription::Call(MethodResponse::error(id, ErrorObject::from(code))) } } #[instrument(name = "method_call", fields(method = req.method.as_ref()), skip(call, req), level = "TRACE")] -pub(crate) async fn execute_call_with_tracing<'a, L: Logger>(req: Request<'a>, call: CallData<'_, L>) -> MethodResult { +pub(crate) async fn execute_call_with_tracing<'a, L: Logger>( + req: Request<'a>, + call: CallData<'_, L>, +) -> CallOrSubscription { execute_call(req, call).await } @@ -172,18 +173,17 @@ pub(crate) async fn execute_call_with_tracing<'a, L: Logger>(req: Request<'a>, c /// /// Returns `(MethodResponse, None)` on every call that isn't a subscription /// Otherwise `(MethodResponse, Some(PendingSubscriptionCallTx)`. -pub(crate) async fn execute_call<'a, L: Logger>(req: Request<'a>, call: CallData<'_, L>) -> MethodResult { +pub(crate) async fn execute_call<'a, L: Logger>(req: Request<'a>, call: CallData<'_, L>) -> CallOrSubscription { let CallData { - resources, methods, max_response_body_size, max_log_length, conn_id, - bounded_subscriptions, id_provider, sink, logger, request_start, + bounded_subscriptions, } = call; rx_log_from_json(&req, call.max_log_length); @@ -196,61 +196,33 @@ pub(crate) async fn execute_call<'a, L: Logger>(req: Request<'a>, call: CallData None => { logger.on_call(name, params.clone(), logger::MethodKind::Unknown, TransportProtocol::WebSocket); let response = MethodResponse::error(id, ErrorObject::from(ErrorCode::MethodNotFound)); - MethodResult::SendAndLogger(response) + CallOrSubscription::Call(response) } Some((name, method)) => match &method.inner() { MethodKind::Sync(callback) => { logger.on_call(name, params.clone(), logger::MethodKind::MethodCall, TransportProtocol::WebSocket); - match method.claim(name, resources) { - Ok(guard) => { - let r = (callback)(id, params, max_response_body_size as usize); - drop(guard); - MethodResult::SendAndLogger(r) - } - Err(err) => { - tracing::error!("[Methods::execute_with_resources] failed to lock resources: {}", err); - let response = MethodResponse::error(id, ErrorObject::from(ErrorCode::ServerIsBusy)); - MethodResult::SendAndLogger(response) - } - } + CallOrSubscription::Call((callback)(id, params, max_response_body_size as usize)) } MethodKind::Async(callback) => { logger.on_call(name, params.clone(), logger::MethodKind::MethodCall, TransportProtocol::WebSocket); - match method.claim(name, resources) { - Ok(guard) => { - let id = id.into_owned(); - let params = params.into_owned(); - - let response = - (callback)(id, params, conn_id, max_response_body_size as usize, Some(guard)).await; - MethodResult::SendAndLogger(response) - } - Err(err) => { - tracing::error!("[Methods::execute_with_resources] failed to lock resources: {}", err); - let response = MethodResponse::error(id, ErrorObject::from(ErrorCode::ServerIsBusy)); - MethodResult::SendAndLogger(response) - } - } + + let id = id.into_owned(); + let params = params.into_owned(); + + let response = (callback)(id, params, conn_id, max_response_body_size as usize).await; + CallOrSubscription::Call(response) } MethodKind::Subscription(callback) => { logger.on_call(name, params.clone(), logger::MethodKind::Subscription, TransportProtocol::WebSocket); - match method.claim(name, resources) { - Ok(guard) => { - if let Some(cn) = bounded_subscriptions.acquire() { - let conn_state = ConnState { conn_id, close_notify: cn, id_provider }; - let response = callback(id.clone(), params, sink.clone(), conn_state, Some(guard)).await; - MethodResult::JustLogger(response) - } else { - let response = - MethodResponse::error(id, reject_too_many_subscriptions(bounded_subscriptions.max())); - MethodResult::SendAndLogger(response) - } - } - Err(err) => { - tracing::error!("[Methods::execute_with_resources] failed to lock resources: {}", err); - let response = MethodResponse::error(id, ErrorObject::from(ErrorCode::ServerIsBusy)); - MethodResult::SendAndLogger(response) - } + + if let Some(p) = bounded_subscriptions.acquire() { + let conn_state = ConnState { conn_id, id_provider, subscription_permit: p }; + let response = callback(id.clone(), params, sink.clone(), conn_state).await; + CallOrSubscription::Subscription(response) + } else { + let response = + MethodResponse::error(id, reject_too_many_subscriptions(bounded_subscriptions.max())); + CallOrSubscription::Call(response) } } MethodKind::Unsubscription(callback) => { @@ -258,12 +230,12 @@ pub(crate) async fn execute_call<'a, L: Logger>(req: Request<'a>, call: CallData // Don't adhere to any resource or subscription limits; always let unsubscribing happen! let result = callback(id, params, conn_id, max_response_body_size as usize); - MethodResult::SendAndLogger(result) + CallOrSubscription::Call(result) } }, }; - let r = response.as_inner(); + let r = response.as_response(); tx_log_from_str(&r.result, max_log_length); logger.on_result(name, r.success, request_start, TransportProtocol::WebSocket); @@ -277,28 +249,29 @@ pub(crate) async fn background_task( ) -> Result<(), Error> { let ServiceData { methods, - resources, max_request_body_size, max_response_body_size, max_log_length, + max_subscriptions_per_connection, batch_requests_supported, stop_handle, id_provider, ping_interval, - max_subscriptions_per_connection, conn_id, logger, remote_addr, + message_buffer_capacity, conn, .. } = svc; - let (tx, rx) = mpsc::unbounded::(); - let bounded_subscriptions = BoundedSubscriptions::new(max_subscriptions_per_connection); + let (tx, rx) = mpsc::channel::(message_buffer_capacity as usize); + let (conn_tx, conn_rx) = oneshot::channel(); let sink = MethodSink::new_with_limit(tx, max_response_body_size, max_log_length); + let bounded_subscriptions = BoundedSubscriptions::new(max_subscriptions_per_connection); // Spawn another task that sends out the responses on the Websocket. - tokio::spawn(send_task(rx, sender, stop_handle.clone(), ping_interval)); + tokio::spawn(send_task(rx, sender, stop_handle.clone(), ping_interval, conn_rx)); // Buffer for incoming data. let mut data = Vec::with_capacity(100); @@ -308,6 +281,20 @@ pub(crate) async fn background_task( let result = loop { data.clear(); + let sink_permit_fut = sink.reserve(); + + tokio::pin!(sink_permit_fut); + + // Wait until there is a slot in the bounded channel which means that + // the underlying TCP socket won't be read. + // + // This will force the client to read socket on the other side + // otherwise the socket will not be read again. + let sink_permit = match method_executors.select_with(Monitored::new(sink_permit_fut, &stop_handle)).await { + Ok(permit) => permit, + Err(_) => break Ok(()), + }; + { // Need the extra scope to drop this pinned future and reclaim access to `data` let receive = async { @@ -334,18 +321,19 @@ pub(crate) async fn background_task( break Ok(()); } MonitoredError::Selector(SokettoError::MessageTooLarge { current, maximum }) => { - tracing::warn!( + tracing::debug!( "WS transport error: request length: {} exceeded max limit: {} bytes", current, maximum ); - sink.send_error(Id::Null, reject_too_big_request(max_request_body_size)); + sink_permit.send_error(Id::Null, reject_too_big_request(max_request_body_size)); + continue; } // These errors can not be gracefully handled, so just log them and terminate the connection. MonitoredError::Selector(err) => { - tracing::error!("WS transport error: {}; terminate connection: {}", err, conn_id); + tracing::debug!("WS transport error: {}; terminate connection: {}", err, conn_id); break Err(err.into()); } MonitoredError::Shutdown => { @@ -362,19 +350,17 @@ pub(crate) async fn background_task( Some(b'{') => { let data = std::mem::take(&mut data); let sink = sink.clone(); - let resources = &resources; let methods = &methods; - let bounded_subscriptions = bounded_subscriptions.clone(); let id_provider = &*id_provider; + let bounded_subscriptions = bounded_subscriptions.clone(); let fut = async move { let call = CallData { conn_id: conn_id as usize, - resources, + bounded_subscriptions, max_response_body_size, max_log_length, methods, - bounded_subscriptions, sink: &sink, id_provider, logger, @@ -382,12 +368,16 @@ pub(crate) async fn background_task( }; match process_single_request(data, call).await { - MethodResult::JustLogger(r) => { + CallOrSubscription::Subscription(SubscriptionAnswered::Yes(r)) => { logger.on_response(&r.result, request_start, TransportProtocol::WebSocket); } - MethodResult::SendAndLogger(r) => { + CallOrSubscription::Subscription(SubscriptionAnswered::No(r)) => { logger.on_response(&r.result, request_start, TransportProtocol::WebSocket); - let _ = sink.send_raw(r.result); + sink_permit.send_raw(r.result); + } + CallOrSubscription::Call(r) => { + logger.on_response(&r.result, request_start, TransportProtocol::WebSocket); + sink_permit.send_raw(r.result); } }; } @@ -401,27 +391,25 @@ pub(crate) async fn background_task( ErrorObject::borrowed(BATCHES_NOT_SUPPORTED_CODE, &BATCHES_NOT_SUPPORTED_MSG, None), ); logger.on_response(&response.result, request_start, TransportProtocol::WebSocket); - let _ = sink.send_raw(response.result); + sink_permit.send_raw(response.result); } Some(b'[') => { // Make sure the following variables are not moved into async closure below. - let resources = &resources; let methods = &methods; - let bounded_subscriptions = bounded_subscriptions.clone(); let sink = sink.clone(); let id_provider = id_provider.clone(); let data = std::mem::take(&mut data); + let bounded_subscriptions = bounded_subscriptions.clone(); let fut = async move { let response = process_batch_request(Batch { data, call: CallData { conn_id: conn_id as usize, - resources, + bounded_subscriptions, max_response_body_size, max_log_length, methods, - bounded_subscriptions, sink: &sink, id_provider: &*id_provider, logger, @@ -433,14 +421,14 @@ pub(crate) async fn background_task( if let Some(response) = response { tx_log_from_str(&response.result, max_log_length); logger.on_response(&response.result, request_start, TransportProtocol::WebSocket); - let _ = sink.send_raw(response.result); + sink_permit.send_raw(response.result); } }; method_executors.add(Box::pin(fut)); } _ => { - sink.send_error(Id::Null, ErrorCode::ParseError.into()); + sink_permit.send_error(Id::Null, ErrorCode::ParseError.into()); } } }; @@ -452,10 +440,7 @@ pub(crate) async fn background_task( // proper drop behaviour. method_executors.await; - // Notify all listeners and close down associated tasks. - sink.close(); - bounded_subscriptions.close(); - + let _ = conn_tx.send(()); drop(conn); result @@ -463,22 +448,23 @@ pub(crate) async fn background_task( /// A task that waits for new messages via the `rx channel` and sends them out on the `WebSocket`. async fn send_task( - mut rx: mpsc::UnboundedReceiver, + rx: mpsc::Receiver, mut ws_sender: Sender, mut stop_handle: StopHandle, ping_interval: Duration, + conn_closed: oneshot::Receiver<()>, ) { - // Received messages from the WebSocket. - let mut rx_item = rx.next(); - // Interval to send out continuously `pings`. let ping_interval = IntervalStream::new(tokio::time::interval(ping_interval)); let stopped = stop_handle.shutdown(); + let rx = ReceiverStream::new(rx); - tokio::pin!(ping_interval, stopped); + tokio::pin!(ping_interval, stopped, rx, conn_closed); + // Received messages from the WebSocket. + let mut rx_item = rx.next(); let next_ping = ping_interval.next(); - let mut futs = future::select(next_ping, stopped); + let mut futs = future::select(next_ping, future::select(stopped, conn_closed)); loop { // Ensure select is cancel-safe by fetching and storing the `rx_item` that did not finish yet. @@ -488,9 +474,10 @@ async fn send_task( Either::Left((Some(response), not_ready)) => { // If websocket message send fail then terminate the connection. if let Err(err) = send_message(&mut ws_sender, response).await { - tracing::error!("WS transport error: send failed: {}", err); + tracing::debug!("WS transport error: send failed: {}", err); break; } + rx_item = rx.next(); futs = not_ready; } @@ -503,15 +490,15 @@ async fn send_task( // Handle timer intervals. Either::Right((Either::Left((_, stop)), next_rx)) => { if let Err(err) = send_ping(&mut ws_sender).await { - tracing::error!("WS transport error: send ping failed: {}", err); + tracing::debug!("WS transport error: send ping failed: {}", err); break; } rx_item = next_rx; futs = future::select(ping_interval.next(), stop); } - // Server is closed - Either::Right((Either::Right((_, _)), _)) => { + // Server is stopped or closed + Either::Right((Either::Right(_), _)) => { break; } } @@ -519,4 +506,5 @@ async fn send_task( // Terminate connection and send close message. let _ = ws_sender.close().await; + rx.close(); } diff --git a/test-utils/src/mocks.rs b/test-utils/src/mocks.rs index 38819ddbf3..e5bdb4a806 100644 --- a/test-utils/src/mocks.rs +++ b/test-utils/src/mocks.rs @@ -118,16 +118,24 @@ impl WebSocketTestClient { } pub async fn send_request_text(&mut self, msg: impl AsRef) -> Result { - self.tx.send_text(msg).await?; - self.tx.flush().await?; + self.send(msg).await?; let mut data = Vec::new(); self.rx.receive_data(&mut data).await?; String::from_utf8(data).map_err(Into::into) } + pub async fn send(&mut self, msg: impl AsRef) -> Result<(), Error> { + self.tx.send_text(msg).await?; + self.tx.flush().await.map_err(Into::into) + } + pub async fn send_request_binary(&mut self, msg: &[u8]) -> Result { self.tx.send_binary(msg).await?; self.tx.flush().await?; + self.receive().await + } + + pub async fn receive(&mut self) -> Result { let mut data = Vec::new(); self.rx.receive_data(&mut data).await?; String::from_utf8(data).map_err(Into::into) diff --git a/tests/tests/helpers.rs b/tests/tests/helpers.rs index 2009167a1f..380c92d14e 100644 --- a/tests/tests/helpers.rs +++ b/tests/tests/helpers.rs @@ -29,13 +29,16 @@ use std::net::SocketAddr; use std::time::Duration; -use futures::{SinkExt, StreamExt}; -use jsonrpsee::core::error::{Error, SubscriptionClosed}; +use futures::{SinkExt, Stream, StreamExt}; use jsonrpsee::core::server::host_filtering::AllowHosts; +use jsonrpsee::core::server::rpc_module::{SubscriptionMessage, TrySendError}; +use jsonrpsee::core::{Error, SubscriptionResult}; use jsonrpsee::server::middleware::proxy_get_request::ProxyGetRequestLayer; use jsonrpsee::server::{ServerBuilder, ServerHandle}; -use jsonrpsee::types::error::{ErrorObject, SUBSCRIPTION_CLOSED_WITH_ERROR}; -use jsonrpsee::RpcModule; +use jsonrpsee::types::error::ErrorObject; +use jsonrpsee::types::ErrorObjectOwned; +use jsonrpsee::{PendingSubscriptionSink, RpcModule}; +use serde::Serialize; use tokio::time::interval; use tokio_stream::wrappers::IntervalStream; use tower_http::cors::CorsLayer; @@ -48,25 +51,21 @@ pub async fn server_with_subscription_and_handle() -> (SocketAddr, ServerHandle) module.register_method("say_hello", |_, _| Ok("hello")).unwrap(); module - .register_subscription("subscribe_hello", "subscribe_hello", "unsubscribe_hello", |_, mut sink, _| { + .register_subscription("subscribe_hello", "subscribe_hello", "unsubscribe_hello", |_, pending, _| async move { let interval = interval(Duration::from_millis(50)); let stream = IntervalStream::new(interval).map(move |_| &"hello from subscription"); + pipe_from_stream_and_drop(pending, stream).await?; - tokio::spawn(async move { - sink.pipe_from_stream(stream).await; - }); Ok(()) }) .unwrap(); module - .register_subscription("subscribe_foo", "subscribe_foo", "unsubscribe_foo", |_, mut sink, _| { + .register_subscription("subscribe_foo", "subscribe_foo", "unsubscribe_foo", |_, pending, _| async { let interval = interval(Duration::from_millis(100)); let stream = IntervalStream::new(interval).map(move |_| 1337_usize); + pipe_from_stream_and_drop(pending, stream).await?; - tokio::spawn(async move { - sink.pipe_from_stream(stream).await; - }); Ok(()) }) .unwrap(); @@ -76,102 +75,47 @@ pub async fn server_with_subscription_and_handle() -> (SocketAddr, ServerHandle) "subscribe_add_one", "subscribe_add_one", "unsubscribe_add_one", - |params, mut sink, _| { - let count = params.one::().map(|c| c.wrapping_add(1))?; + |params, pending, _| async move { + let count = match params.one::().map(|c| c.wrapping_add(1)) { + Ok(count) => count, + Err(e) => { + let _ = pending.reject(ErrorObjectOwned::from(e)).await; + return Ok(()); + } + }; let wrapping_counter = futures::stream::iter((count..).cycle()); let interval = interval(Duration::from_millis(100)); let stream = IntervalStream::new(interval).zip(wrapping_counter).map(move |(_, c)| c); + pipe_from_stream_and_drop(pending, stream).await?; - tokio::spawn(async move { - sink.pipe_from_stream(stream).await; - }); Ok(()) }, ) .unwrap(); module - .register_subscription("subscribe_noop", "subscribe_noop", "unsubscribe_noop", |_, mut sink, _| { - sink.accept().unwrap(); - - tokio::spawn(async move { - tokio::time::sleep(Duration::from_secs(1)).await; - let err = ErrorObject::owned( - SUBSCRIPTION_CLOSED_WITH_ERROR, - "Server closed the stream because it was lazy", - None::<()>, - ); - sink.close(err); - }); - Ok(()) - }) - .unwrap(); + .register_subscription("subscribe_noop", "subscribe_noop", "unsubscribe_noop", |_, pending, _| async { + let sink = pending.accept().await.unwrap(); + tokio::time::sleep(Duration::from_secs(1)).await; + let err = ErrorObject::owned(1, "Server closed the stream because it was lazy", None::<()>); + sink.close(err).await; - module - .register_subscription("subscribe_5_ints", "n", "unsubscribe_5_ints", |_, mut sink, _| { - tokio::spawn(async move { - let interval = interval(Duration::from_millis(50)); - let stream = IntervalStream::new(interval).zip(futures::stream::iter(1..=5)).map(|(_, c)| c); - - match sink.pipe_from_stream(stream).await { - SubscriptionClosed::Success => { - sink.close(SubscriptionClosed::Success); - } - _ => unreachable!(), - } - }); Ok(()) }) .unwrap(); module - .register_subscription("can_reuse_subscription", "n", "u_can_reuse_subscription", |_, mut sink, _| { - tokio::spawn(async move { - let stream1 = IntervalStream::new(interval(Duration::from_millis(50))) - .zip(futures::stream::iter(1..=5)) - .map(|(_, c)| c); - let stream2 = IntervalStream::new(interval(Duration::from_millis(50))) - .zip(futures::stream::iter(6..=10)) - .map(|(_, c)| c); - - let result = sink.pipe_from_stream(stream1).await; - assert!(matches!(result, SubscriptionClosed::Success)); - - match sink.pipe_from_stream(stream2).await { - SubscriptionClosed::Success => { - sink.close(SubscriptionClosed::Success); - } - _ => unreachable!(), - } - }); + .register_subscription("subscribe_5_ints", "n", "unsubscribe_5_ints", |_, pending, _| async move { + let interval = interval(Duration::from_millis(50)); + let stream = IntervalStream::new(interval).zip(futures::stream::iter(1..=5)).map(|(_, c)| c); + tracing::info!("pipe_from_stream"); + pipe_from_stream_and_drop(pending, stream).await?; + Ok(()) }) .unwrap(); - module - .register_subscription( - "subscribe_with_err_on_stream", - "n", - "unsubscribe_with_err_on_stream", - move |_, mut sink, _| { - let err: &'static str = "error on the stream"; - - // Create stream that produce an error which will cancel the subscription. - let stream = futures::stream::iter(vec![Ok(1_u32), Err(err), Ok(2), Ok(3)]); - tokio::spawn(async move { - match sink.pipe_from_try_stream(stream).await { - SubscriptionClosed::Failed(e) => { - sink.close(e); - } - _ => unreachable!(), - } - }); - Ok(()) - }, - ) - .unwrap(); - let addr = server.local_addr().unwrap(); let server_handle = server.start(module).unwrap(); @@ -220,15 +164,15 @@ pub async fn server_with_sleeping_subscription(tx: futures::channel::mpsc::Sende let mut module = RpcModule::new(tx); module - .register_subscription("subscribe_sleep", "n", "unsubscribe_sleep", |_, mut sink, mut tx| { - tokio::spawn(async move { - let interval = interval(Duration::from_secs(60 * 60)); - let stream = IntervalStream::new(interval).zip(futures::stream::iter(1..=5)).map(|(_, c)| c); - - sink.pipe_from_stream(stream).await; - let send_back = std::sync::Arc::make_mut(&mut tx); - send_back.send(()).await.unwrap(); - }); + .register_subscription("subscribe_sleep", "n", "unsubscribe_sleep", |_, pending, mut tx| async move { + let interval = interval(Duration::from_secs(60 * 60)); + let stream = IntervalStream::new(interval).zip(futures::stream::iter(1..=5)).map(|(_, c)| c); + + pipe_from_stream_and_drop(pending, stream).await?; + + let send_back = std::sync::Arc::make_mut(&mut tx); + send_back.send(()).await.unwrap(); + Ok(()) }) .unwrap(); @@ -273,3 +217,40 @@ pub fn init_logger() { .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) .try_init(); } + +pub async fn pipe_from_stream_and_drop( + pending: PendingSubscriptionSink, + mut stream: impl Stream + Unpin, +) -> SubscriptionResult { + let mut sink = pending.accept().await?; + + loop { + tokio::select! { + _ = sink.closed() => break, + maybe_item = stream.next() => { + let item = match maybe_item { + Some(item) => item, + None => break, + }; + let msg = match SubscriptionMessage::from_json(&item) { + Ok(msg) => msg, + Err(e) => { + sink.close(ErrorObject::owned(1, e.to_string(), None::<()>)).await; + return Err(e.into()); + } + }; + + match sink.try_send(msg) { + Ok(_) => (), + Err(TrySendError::Closed(_)) => break, + // channel is full, let's be naive an just drop the message. + Err(TrySendError::Full(_)) => (), + } + } + } + } + + sink.close(ErrorObject::owned(1, "Subscription executed successful", None::<()>)).await; + + Ok(()) +} diff --git a/tests/tests/integration_tests.rs b/tests/tests/integration_tests.rs index 8069c1887d..4d4381d4ec 100644 --- a/tests/tests/integration_tests.rs +++ b/tests/tests/integration_tests.rs @@ -34,13 +34,13 @@ use std::time::Duration; use futures::{channel::mpsc, StreamExt, TryStreamExt}; use helpers::{ - init_logger, server, server_with_access_control, server_with_health_api, server_with_subscription, - server_with_subscription_and_handle, + init_logger, pipe_from_stream_and_drop, server, server_with_access_control, server_with_health_api, + server_with_subscription, server_with_subscription_and_handle, }; use hyper::http::HeaderValue; use jsonrpsee::core::client::{ClientT, IdKind, Subscription, SubscriptionClientT}; -use jsonrpsee::core::error::SubscriptionClosed; use jsonrpsee::core::params::{ArrayParams, BatchRequestBuilder}; +use jsonrpsee::core::server::rpc_module::SubscriptionMessage; use jsonrpsee::core::{Error, JsonValue}; use jsonrpsee::http_client::HttpClientBuilder; use jsonrpsee::rpc_params; @@ -434,20 +434,21 @@ async fn ws_server_should_stop_subscription_after_client_drop() { let mut module = RpcModule::new(tx); module - .register_subscription("subscribe_hello", "subscribe_hello", "unsubscribe_hello", |_, mut sink, mut tx| { - sink.accept().unwrap(); - tokio::spawn(async move { - let close_err = loop { - if !sink.send(&1_usize).expect("usize can be serialized; qed") { - break ErrorObject::borrowed(0, &"Subscription terminated successfully", None); - } - tokio::time::sleep(Duration::from_millis(100)).await; - }; + .register_subscription( + "subscribe_hello", + "subscribe_hello", + "unsubscribe_hello", + |_, pending, mut tx| async move { + let sink = pending.accept().await.unwrap(); + let msg = SubscriptionMessage::from_json(&1).unwrap(); + sink.send(msg).await.unwrap(); + sink.closed().await; let send_back = Arc::make_mut(&mut tx); - send_back.feed(close_err).await.unwrap(); - }); - Ok(()) - }) + send_back.feed("Subscription terminated by remote peer").await.unwrap(); + + Ok(()) + }, + ) .unwrap(); let _handle = server.start(module).unwrap(); @@ -464,7 +465,28 @@ async fn ws_server_should_stop_subscription_after_client_drop() { let close_err = rx.next().await.unwrap(); // assert that the server received `SubscriptionClosed` after the client was dropped. - assert_eq!(close_err, ErrorObject::borrowed(0, &"Subscription terminated successfully", None)); + assert_eq!(close_err, "Subscription terminated by remote peer"); +} + +#[tokio::test] +async fn ws_server_stop_subscription_when_dropped() { + use jsonrpsee::{server::ServerBuilder, RpcModule}; + + init_logger(); + + let server = ServerBuilder::default().build("127.0.0.1:0").await.unwrap(); + let server_url = format!("ws://{}", server.local_addr().unwrap()); + + let mut module = RpcModule::new(()); + + module + .register_subscription("subscribe_nop", "h", "unsubscribe_nop", |_params, _pending, _ctx| async { Ok(()) }) + .unwrap(); + + let _handle = server.start(module).unwrap(); + let client = WsClientBuilder::default().build(&server_url).await.unwrap(); + + assert!(client.subscribe::("subscribe_nop", rpc_params![], "unsubscribe_nop").await.is_err()); } #[tokio::test] @@ -568,24 +590,6 @@ async fn ws_server_cancels_subscriptions_on_reset_conn() { assert_eq!(rx_len, 10); } -#[tokio::test] -async fn ws_server_cancels_sub_stream_after_err() { - init_logger(); - - let addr = server_with_subscription().await; - let server_url = format!("ws://{}", addr); - - let client = WsClientBuilder::default().build(&server_url).await.unwrap(); - let mut sub: Subscription = client - .subscribe("subscribe_with_err_on_stream", rpc_params![], "unsubscribe_with_err_on_stream") - .await - .unwrap(); - - assert_eq!(sub.next().await.unwrap().unwrap(), 1); - // The server closed down the subscription with the underlying error from the stream. - assert!(sub.next().await.is_none()); -} - #[tokio::test] async fn ws_server_subscribe_with_stream() { init_logger(); @@ -642,22 +646,6 @@ async fn ws_server_pipe_from_stream_should_cancel_tasks_immediately() { assert_eq!(rx_len, 10); } -#[tokio::test] -async fn ws_server_pipe_from_stream_can_be_reused() { - init_logger(); - - let addr = server_with_subscription().await; - let client = WsClientBuilder::default().build(&format!("ws://{}", addr)).await.unwrap(); - let sub = client - .subscribe::("can_reuse_subscription", rpc_params![], "u_can_reuse_subscription") - .await - .unwrap(); - - let items = sub.fold(0, |acc, _| async move { acc + 1 }).await; - - assert_eq!(items, 10); -} - #[tokio::test] async fn ws_batch_works() { init_logger(); @@ -752,19 +740,11 @@ async fn ws_server_limit_subs_per_conn_works() { let mut module = RpcModule::new(()); module - .register_subscription("subscribe_forever", "n", "unsubscribe_forever", |_, mut sink, _| { - tokio::spawn(async move { - let interval = interval(Duration::from_millis(50)); - let stream = IntervalStream::new(interval).map(move |_| 0_usize); - - match sink.pipe_from_stream(stream).await { - SubscriptionClosed::Success => { - sink.close(SubscriptionClosed::Success); - } - _ => unreachable!(), - }; - }); - Ok(()) + .register_subscription("subscribe_forever", "n", "unsubscribe_forever", |_, pending, _| async move { + let interval = interval(Duration::from_millis(50)); + let stream = IntervalStream::new(interval).map(move |_| 0_usize); + + pipe_from_stream_and_drop(pending, stream).await }) .unwrap(); let _handle = server.start(module).unwrap(); @@ -815,19 +795,11 @@ async fn ws_server_unsub_methods_should_ignore_sub_limit() { let mut module = RpcModule::new(()); module - .register_subscription("subscribe_forever", "n", "unsubscribe_forever", |_, mut sink, _| { - tokio::spawn(async move { - let interval = interval(Duration::from_millis(50)); - let stream = IntervalStream::new(interval).map(move |_| 0_usize); - - match sink.pipe_from_stream(stream).await { - SubscriptionClosed::RemotePeerAborted => { - sink.close(SubscriptionClosed::RemotePeerAborted); - } - _ => unreachable!(), - }; - }); - Ok(()) + .register_subscription("subscribe_forever", "n", "unsubscribe_forever", |_, pending, _| async { + let interval = interval(Duration::from_millis(50)); + let stream = IntervalStream::new(interval).map(move |_| 0_usize); + + pipe_from_stream_and_drop(pending, stream).await }) .unwrap(); let _handle = server.start(module).unwrap(); diff --git a/tests/tests/metrics.rs b/tests/tests/metrics.rs index 8baf5d4f90..c5d7e32afa 100644 --- a/tests/tests/metrics.rs +++ b/tests/tests/metrics.rs @@ -113,12 +113,7 @@ fn test_module() -> RpcModule<()> { } async fn websocket_server(module: RpcModule<()>, counter: Counter) -> Result<(SocketAddr, ServerHandle), Error> { - let server = ServerBuilder::default() - .register_resource("CPU", 6, 2)? - .register_resource("MEM", 10, 1)? - .set_logger(counter) - .build("127.0.0.1:0") - .await?; + let server = ServerBuilder::default().set_logger(counter).build("127.0.0.1:0").await?; let addr = server.local_addr()?; let handle = server.start(module)?; @@ -127,12 +122,7 @@ async fn websocket_server(module: RpcModule<()>, counter: Counter) -> Result<(So } async fn http_server(module: RpcModule<()>, counter: Counter) -> Result<(SocketAddr, ServerHandle), Error> { - let server = ServerBuilder::default() - .register_resource("CPU", 6, 2)? - .register_resource("MEM", 10, 1)? - .set_logger(counter) - .build("127.0.0.1:0") - .await?; + let server = ServerBuilder::default().set_logger(counter).build("127.0.0.1:0").await?; let addr = server.local_addr()?; let handle = server.start(module)?; diff --git a/tests/tests/proc_macros.rs b/tests/tests/proc_macros.rs index 17d5b66ee7..d9997bee94 100644 --- a/tests/tests/proc_macros.rs +++ b/tests/tests/proc_macros.rs @@ -42,10 +42,10 @@ use jsonrpsee::ws_client::*; use serde_json::json; mod rpc_impl { - use jsonrpsee::core::{async_trait, RpcResult}; + use jsonrpsee::core::server::rpc_module::SubscriptionMessage; + use jsonrpsee::core::{async_trait, RpcResult, SubscriptionResult}; use jsonrpsee::proc_macros::rpc; - use jsonrpsee::types::SubscriptionResult; - use jsonrpsee::SubscriptionSink; + use jsonrpsee::PendingSubscriptionSink; #[rpc(client, server, namespace = "foo")] pub trait Rpc { @@ -56,10 +56,10 @@ mod rpc_impl { fn sync_method(&self) -> RpcResult; #[subscription(name = "sub", unsubscribe = "unsub", item = String)] - fn sub(&self); + async fn sub(&self); #[subscription(name = "echo", unsubscribe = "unsubscribe_echo", aliases = ["alias_echo"], item = u32)] - fn sub_with_params(&self, val: u32); + async fn sub_with_params(&self, val: u32); #[method(name = "params")] fn params(&self, a: u8, b: &str) -> RpcResult { @@ -116,7 +116,7 @@ mod rpc_impl { /// All head subscription #[subscription(name = "subscribeAllHeads", item = Header)] - fn subscribe_all_heads(&self, hash: Hash); + async fn subscribe_all_heads(&self, hash: Hash); } /// Trait to ensure that the trait bounds are correct. @@ -131,7 +131,7 @@ mod rpc_impl { pub trait OnlyGenericSubscription { /// Get header of a relay chain block. #[subscription(name = "sub", unsubscribe = "unsub", item = Vec)] - fn sub(&self, hash: Input); + async fn sub(&self, hash: Input); } /// Trait to ensure that the trait bounds are correct. @@ -168,15 +168,22 @@ mod rpc_impl { Ok(10u16) } - fn sub(&self, mut sink: SubscriptionSink) -> SubscriptionResult { - let _ = sink.send(&"Response_A"); - let _ = sink.send(&"Response_B"); + async fn sub(&self, pending: PendingSubscriptionSink) -> SubscriptionResult { + let sink = pending.accept().await.unwrap(); + + let _ = sink.send(SubscriptionMessage::from_json(&"Response_A").unwrap()).await; + let _ = sink.send(SubscriptionMessage::from_json(&"Response_B").unwrap()).await; + Ok(()) } - fn sub_with_params(&self, mut sink: SubscriptionSink, val: u32) -> SubscriptionResult { - let _ = sink.send(&val); - let _ = sink.send(&val); + async fn sub_with_params(&self, pending: PendingSubscriptionSink, val: u32) -> SubscriptionResult { + let sink = pending.accept().await.unwrap(); + let msg = SubscriptionMessage::from_json(&val).unwrap(); + + let _ = sink.send(msg.clone()).await; + let _ = sink.send(msg).await; + Ok(()) } } @@ -190,8 +197,11 @@ mod rpc_impl { #[async_trait] impl OnlyGenericSubscriptionServer for RpcServerImpl { - fn sub(&self, mut sink: SubscriptionSink, _: String) -> SubscriptionResult { - let _ = sink.send(&"hello"); + async fn sub(&self, pending: PendingSubscriptionSink, _: String) -> SubscriptionResult { + let sink = pending.accept().await.unwrap(); + let msg = SubscriptionMessage::from_json(&"hello").unwrap(); + let _ = sink.send(msg).await.unwrap(); + Ok(()) } } @@ -261,7 +271,7 @@ async fn macro_optional_param_parsing() { // Named params using a map let (resp, _) = module - .raw_json_request(r#"{"jsonrpc":"2.0","method":"foo_optional_params","params":{"a":22,"c":50},"id":0}"#) + .raw_json_request(r#"{"jsonrpc":"2.0","method":"foo_optional_params","params":{"a":22,"c":50},"id":0}"#, 1) .await .unwrap(); assert_eq!(resp.result, r#"{"jsonrpc":"2.0","result":"Called with: 22, None, Some(50)","id":0}"#); @@ -278,10 +288,12 @@ async fn macro_lifetimes_parsing() { #[tokio::test] async fn macro_zero_copy_cow() { + init_logger(); + let module = RpcServerImpl.into_rpc(); let (resp, _) = module - .raw_json_request(r#"{"jsonrpc":"2.0","method":"foo_zero_copy_cow","params":["foo", "bar"],"id":0}"#) + .raw_json_request(r#"{"jsonrpc":"2.0","method":"foo_zero_copy_cow","params":["foo", "bar"],"id":0}"#, 1) .await .unwrap(); @@ -290,7 +302,7 @@ async fn macro_zero_copy_cow() { // serde_json will have to allocate a new string to replace `\t` with byte 0x09 (tab) let (resp, _) = module - .raw_json_request(r#"{"jsonrpc":"2.0","method":"foo_zero_copy_cow","params":["\tfoo", "\tbar"],"id":0}"#) + .raw_json_request(r#"{"jsonrpc":"2.0","method":"foo_zero_copy_cow","params":["\tfoo", "\tbar"],"id":0}"#, 1) .await .unwrap(); assert_eq!(resp.result, r#"{"jsonrpc":"2.0","result":"Zero copy params: false, false","id":0}"#); @@ -300,7 +312,7 @@ async fn macro_zero_copy_cow() { #[cfg(not(target_os = "macos"))] #[tokio::test] async fn multiple_blocking_calls_overlap() { - use jsonrpsee::types::EmptyServerParams; + use jsonrpsee::core::EmptyServerParams; use std::time::{Duration, Instant}; let module = RpcServerImpl.into_rpc(); diff --git a/tests/tests/resource_limiting.rs b/tests/tests/resource_limiting.rs deleted file mode 100644 index 14bfd11a2b..0000000000 --- a/tests/tests/resource_limiting.rs +++ /dev/null @@ -1,296 +0,0 @@ -// Copyright 2019-2021 Parity Technologies (UK) Ltd. -// -// Permission is hereby granted, free of charge, to any -// person obtaining a copy of this software and associated -// documentation files (the "Software"), to deal in the -// Software without restriction, including without -// limitation the rights to use, copy, modify, merge, -// publish, distribute, sublicense, and/or sell copies of -// the Software, and to permit persons to whom the Software -// is furnished to do so, subject to the following -// conditions: -// -// The above copyright notice and this permission notice -// shall be included in all copies or substantial portions -// of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF -// ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED -// TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A -// PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT -// SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY -// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR -// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -// DEALINGS IN THE SOFTWARE. - -use std::net::SocketAddr; -use std::time::Duration; - -use futures::StreamExt; -use jsonrpsee::core::client::{ClientT, SubscriptionClientT}; -use jsonrpsee::core::params::ArrayParams; -use jsonrpsee::core::Error; -use jsonrpsee::http_client::HttpClientBuilder; -use jsonrpsee::proc_macros::rpc; -use jsonrpsee::server::{ServerBuilder, ServerHandle}; -use jsonrpsee::types::error::CallError; -use jsonrpsee::types::SubscriptionResult; -use jsonrpsee::ws_client::WsClientBuilder; -use jsonrpsee::{rpc_params, RpcModule, SubscriptionSink}; -use tokio::time::{interval, sleep}; -use tokio_stream::wrappers::IntervalStream; - -fn module_manual() -> Result, Error> { - let mut module = RpcModule::new(()); - - module.register_async_method("say_hello", |_, _| async move { - sleep(Duration::from_millis(50)).await; - Result::<_, Error>::Ok("hello") - })?; - - module - .register_async_method("expensive_call", |_, _| async move { - sleep(Duration::from_millis(50)).await; - Result::<_, Error>::Ok("hello expensive call") - })? - .resource("CPU", 3)?; - - module - .register_async_method("memory_hog", |_, _| async move { - sleep(Duration::from_millis(50)).await; - Result::<_, Error>::Ok("hello memory hog") - })? - .resource("CPU", 0)? - .resource("MEM", 8)?; - - // Drop the `SubscriptionSink` to cause the internal `ResourceGuard` allocated per subscription call - // to get dropped. This is the equivalent of not having any resource limits (ie, sink is never used). - module - .register_subscription("subscribe_hello", "s_hello", "unsubscribe_hello", move |_, mut sink, _| { - sink.accept()?; - Ok(()) - })? - .resource("SUB", 3)?; - - // Keep the `SubscriptionSink` alive for a bit to validate that `ResourceGuard` is alive - // and the subscription method gets limited. - module - .register_subscription("subscribe_hello_limit", "s_hello", "unsubscribe_hello_limit", move |_, mut sink, _| { - tokio::spawn(async move { - for val in 0..10 { - // Sink is accepted on the first `send` call. - sink.send(&val).unwrap(); - sleep(Duration::from_secs(1)).await; - } - }); - - Ok(()) - })? - .resource("SUB", 3)?; - - Ok(module) -} - -fn module_macro() -> RpcModule<()> { - #[rpc(server)] - pub trait Rpc { - #[method(name = "say_hello")] - async fn hello(&self) -> Result<&'static str, Error> { - sleep(Duration::from_millis(50)).await; - Ok("hello") - } - - #[method(name = "expensive_call", resources("CPU" = 3))] - async fn expensive(&self) -> Result<&'static str, Error> { - sleep(Duration::from_millis(50)).await; - Ok("hello expensive call") - } - - #[method(name = "memory_hog", resources("CPU" = 0, "MEM" = 8))] - async fn memory(&self) -> Result<&'static str, Error> { - sleep(Duration::from_millis(50)).await; - Ok("hello memory hog") - } - - #[subscription(name = "subscribe_hello", item = String, resources("SUB" = 3))] - fn sub_hello(&self); - - #[subscription(name = "subscribe_hello_limit", item = String, resources("SUB" = 3))] - fn sub_hello_limit(&self); - } - - impl RpcServer for () { - fn sub_hello(&self, mut sink: SubscriptionSink) -> SubscriptionResult { - sink.accept()?; - Ok(()) - } - - fn sub_hello_limit(&self, mut sink: SubscriptionSink) -> SubscriptionResult { - tokio::spawn(async move { - let interval = interval(Duration::from_secs(1)); - let stream = IntervalStream::new(interval).map(move |_| 1); - - sink.pipe_from_stream(stream).await; - }); - - Ok(()) - } - } - - ().into_rpc() -} - -async fn websocket_server(module: RpcModule<()>) -> Result<(SocketAddr, ServerHandle), Error> { - let server = ServerBuilder::default() - .register_resource("CPU", 6, 2)? - .register_resource("MEM", 10, 1)? - .register_resource("SUB", 6, 1)? - .build("127.0.0.1:0") - .await?; - - let addr = server.local_addr()?; - let handle = server.start(module)?; - - Ok((addr, handle)) -} - -async fn http_server(module: RpcModule<()>) -> Result<(SocketAddr, ServerHandle), Error> { - let server = ServerBuilder::default() - .register_resource("CPU", 6, 2)? - .register_resource("MEM", 10, 1)? - .register_resource("SUB", 6, 1)? - .build("127.0.0.1:0") - .await?; - - let addr = server.local_addr()?; - let handle = server.start(module)?; - - Ok((addr, handle)) -} - -fn assert_server_busy(fail: Result) { - match fail { - Err(Error::Call(CallError::Custom(err))) => { - assert_eq!(err.code(), -32604); - assert_eq!(err.message(), "Server is busy, try again later"); - } - fail => panic!("Expected error, got: {:?}", fail), - } -} - -async fn run_tests_on_ws_server(server_addr: SocketAddr, server_handle: ServerHandle) { - let server_url = format!("ws://{}", server_addr); - let client = WsClientBuilder::default().build(&server_url).await.unwrap(); - - // 2 CPU units (default) per call, so 4th call exceeds cap - let (pass1, pass2, pass3, fail) = tokio::join!( - client.request::("say_hello", rpc_params!()), - client.request::("say_hello", rpc_params![]), - client.request::("say_hello", rpc_params![]), - client.request::("say_hello", rpc_params![]), - ); - - assert!(pass1.is_ok()); - assert!(pass2.is_ok()); - assert!(pass3.is_ok()); - assert_server_busy(fail); - - // 3 CPU units per call, so 3rd call exceeds CPU cap, but we can still get on MEM - let (pass_cpu1, pass_cpu2, fail_cpu, pass_mem, fail_mem) = tokio::join!( - client.request::("expensive_call", rpc_params![]), - client.request::("expensive_call", rpc_params![]), - client.request::("expensive_call", rpc_params![]), - client.request::("memory_hog", rpc_params![]), - client.request::("memory_hog", rpc_params![]), - ); - - assert!(pass_cpu1.is_ok()); - assert!(pass_cpu2.is_ok()); - assert_server_busy(fail_cpu); - assert!(pass_mem.is_ok()); - assert_server_busy(fail_mem); - - // If we issue multiple subscription requests at the same time from the same client, - // but the subscriptions drop their sinks when the subscription has been accepted or rejected. - // - // Thus, we can't assume that all subscriptions drop their resources instantly anymore. - let (pass1, pass2) = tokio::join!( - client.subscribe::("subscribe_hello", rpc_params![], "unsubscribe_hello"), - client.subscribe::("subscribe_hello", rpc_params![], "unsubscribe_hello"), - ); - - assert!(pass1.is_ok()); - assert!(pass2.is_ok()); - - // 3 CPU units (manually set for subscriptions) per call, so 3th call exceeds cap - let (pass1, pass2, fail) = tokio::join!( - client.subscribe::("subscribe_hello_limit", rpc_params![], "unsubscribe_hello_limit"), - client.subscribe::("subscribe_hello_limit", rpc_params![], "unsubscribe_hello_limit"), - client.subscribe::("subscribe_hello_limit", rpc_params![], "unsubscribe_hello_limit"), - ); - - assert!(pass1.is_ok()); - assert!(pass2.is_ok()); - assert_server_busy(fail); - - server_handle.stop().unwrap(); - server_handle.stopped().await; -} - -async fn run_tests_on_http_server(server_addr: SocketAddr, server_handle: ServerHandle) { - let server_url = format!("http://{}", server_addr); - let client = HttpClientBuilder::default().build(&server_url).unwrap(); - - // 2 CPU units (default) per call, so 4th call exceeds cap - let (a, b, c, d) = tokio::join!( - client.request::("say_hello", rpc_params![]), - client.request::("say_hello", rpc_params![]), - client.request::("say_hello", rpc_params![]), - client.request::("say_hello", rpc_params![]), - ); - - // HTTP does not guarantee ordering - let mut passes = 0; - - for result in [a, b, c, d] { - if result.is_ok() { - passes += 1; - } else { - assert_server_busy(result); - } - } - - assert_eq!(passes, 3); - - server_handle.stop().unwrap(); - server_handle.stopped().await; -} - -#[tokio::test] -async fn ws_server_with_manual_module() { - let (server_addr, server_handle) = websocket_server(module_manual().unwrap()).await.unwrap(); - - run_tests_on_ws_server(server_addr, server_handle).await; -} - -#[tokio::test] -async fn ws_server_with_macro_module() { - let (server_addr, server_handle) = websocket_server(module_macro()).await.unwrap(); - - run_tests_on_ws_server(server_addr, server_handle).await; -} - -#[tokio::test] -async fn http_server_with_manual_module() { - let (server_addr, server_handle) = http_server(module_manual().unwrap()).await.unwrap(); - - run_tests_on_http_server(server_addr, server_handle).await; -} - -#[tokio::test] -async fn http_server_with_macro_module() { - let (server_addr, server_handle) = http_server(module_macro()).await.unwrap(); - - run_tests_on_http_server(server_addr, server_handle).await; -} diff --git a/tests/tests/rpc_module.rs b/tests/tests/rpc_module.rs index 28b6b38653..01f46447a7 100644 --- a/tests/tests/rpc_module.rs +++ b/tests/tests/rpc_module.rs @@ -26,16 +26,18 @@ mod helpers; -use std::collections::HashMap; +use std::collections::{HashMap, VecDeque}; use std::time::Duration; use futures::StreamExt; -use helpers::init_logger; -use jsonrpsee::core::error::{Error, SubscriptionClosed}; +use helpers::{init_logger, pipe_from_stream_and_drop}; +use jsonrpsee::core::error::{Error, SubscriptionCallbackError}; use jsonrpsee::core::server::rpc_module::*; +use jsonrpsee::core::EmptyServerParams; use jsonrpsee::types::error::{CallError, ErrorCode, ErrorObject, PARSE_ERROR_CODE}; -use jsonrpsee::types::{EmptyServerParams, Params}; +use jsonrpsee::types::{ErrorObjectOwned, Params}; use serde::{Deserialize, Serialize}; +use tokio::sync::mpsc; use tokio::time::interval; use tokio_stream::wrappers::IntervalStream; @@ -72,7 +74,7 @@ fn flatten_rpc_modules() { #[test] fn rpc_context_modules_can_register_subscriptions() { let mut cxmodule = RpcModule::new(()); - cxmodule.register_subscription("hi", "hi", "goodbye", |_, _, _| Ok(())).unwrap(); + cxmodule.register_subscription("hi", "hi", "goodbye", |_, _, _| async { Ok(()) }).unwrap(); assert!(cxmodule.method("hi").is_some()); assert!(cxmodule.method("goodbye").is_some()); @@ -233,24 +235,26 @@ async fn subscribing_without_server() { let mut module = RpcModule::new(()); module - .register_subscription("my_sub", "my_sub", "my_unsub", |_, mut sink, _| { + .register_subscription("my_sub", "my_sub", "my_unsub", |_, pending, _| async move { let mut stream_data = vec!['0', '1', '2']; - sink.accept()?; - tokio::spawn(async move { - while let Some(letter) = stream_data.pop() { - tracing::debug!("This is your friendly subscription sending data."); - let _ = sink.send(&letter); - tokio::time::sleep(std::time::Duration::from_millis(500)).await; - } - let close = ErrorObject::borrowed(0, &"closed successfully", None); - sink.close(close.into_owned()); - }); + let sink = pending.accept().await.unwrap(); + + while let Some(letter) = stream_data.pop() { + tracing::debug!("This is your friendly subscription sending data."); + let msg = SubscriptionMessage::from_json(&letter).unwrap(); + let _ = sink.send(msg).await.unwrap(); + tokio::time::sleep(std::time::Duration::from_millis(500)).await; + } + let close = ErrorObject::borrowed(0, &"closed successfully", None); + let _ = sink.close(close.into_owned()).await; + Ok(()) }) .unwrap(); - let mut my_sub = module.subscribe("my_sub", EmptyServerParams::new()).await.unwrap(); + let mut my_sub = module.subscribe_unbounded("my_sub", EmptyServerParams::new()).await.unwrap(); + for i in (0..=2).rev() { let (val, id) = my_sub.next::().await.unwrap().unwrap(); assert_eq!(val, std::char::from_digit(i, 10).unwrap()); @@ -266,30 +270,28 @@ async fn close_test_subscribing_without_server() { let mut module = RpcModule::new(()); module - .register_subscription("my_sub", "my_sub", "my_unsub", |_, mut sink, _| { - sink.accept()?; - - tokio::spawn(async move { - // make sure to only send one item - sink.send(&"lo").unwrap(); - while !sink.is_closed() { - tracing::debug!("[test] Sink is open, sleeping"); - tokio::time::sleep(std::time::Duration::from_millis(500)).await; - } - // Get the close reason. - if !sink.send(&"lo").expect("str serializable; qed") { - sink.close(SubscriptionClosed::RemotePeerAborted); - } - }); + .register_subscription("my_sub", "my_sub", "my_unsub", |_, pending, _| async move { + let sink = pending.accept().await.unwrap(); + let msg = SubscriptionMessage::from_json(&"lo").unwrap(); + + // make sure to only send one item + sink.send(msg.clone()).await.unwrap(); + sink.closed().await; + + match sink.send(msg).await { + Ok(_) => panic!("The sink should be closed"), + Err(DisconnectError(_)) => {} + } Ok(()) }) .unwrap(); - let mut my_sub = module.subscribe("my_sub", EmptyServerParams::new()).await.unwrap(); + let mut my_sub = module.subscribe_unbounded("my_sub", EmptyServerParams::new()).await.unwrap(); let (val, id) = my_sub.next::().await.unwrap().unwrap(); assert_eq!(&val, "lo"); assert_eq!(&id, my_sub.subscription_id()); - let mut my_sub2 = std::mem::ManuallyDrop::new(module.subscribe("my_sub", EmptyServerParams::new()).await.unwrap()); + let mut my_sub2 = + std::mem::ManuallyDrop::new(module.subscribe_unbounded("my_sub", EmptyServerParams::new()).await.unwrap()); // Close the subscription to ensure it doesn't return any items. my_sub.close(); @@ -313,23 +315,25 @@ async fn close_test_subscribing_without_server() { async fn subscribing_without_server_bad_params() { let mut module = RpcModule::new(()); module - .register_subscription("my_sub", "my_sub", "my_unsub", |params, mut sink, _| { + .register_subscription("my_sub", "my_sub", "my_unsub", |params, pending, _| async move { let p = match params.one::() { Ok(p) => p, Err(e) => { - let err: Error = e.into(); - let _ = sink.reject(err); - return Ok(()); + let err: ErrorObjectOwned = e.into(); + let _ = pending.reject(err).await; + return Err(SubscriptionCallbackError::None); } }; - sink.accept()?; - sink.send(&p).unwrap(); + let sink = pending.accept().await.unwrap(); + let msg = SubscriptionMessage::from_json(&p).unwrap(); + sink.send(msg).await.unwrap(); + Ok(()) }) .unwrap(); - let sub = module.subscribe("my_sub", EmptyServerParams::new()).await.unwrap_err(); + let sub = module.subscribe_unbounded("my_sub", EmptyServerParams::new()).await.unwrap_err(); assert!( matches!(sub, Error::Call(CallError::Custom(e)) if e.message().contains("invalid length 0, expected an array of length 1 at line 1 column 2") && e.code() == ErrorCode::InvalidParams.code()) @@ -340,33 +344,30 @@ async fn subscribing_without_server_bad_params() { async fn subscribe_unsubscribe_without_server() { let mut module = RpcModule::new(()); module - .register_subscription("my_sub", "my_sub", "my_unsub", |_, mut sink, _| { + .register_subscription("my_sub", "my_sub", "my_unsub", |_, pending, _| async move { let interval = interval(Duration::from_millis(200)); let stream = IntervalStream::new(interval).map(move |_| 1); + pipe_from_stream_and_drop(pending, stream).await?; - tokio::spawn(async move { - sink.pipe_from_stream(stream).await; - }); Ok(()) }) .unwrap(); async fn subscribe_and_assert(module: &RpcModule<()>) { - let sub = module.subscribe("my_sub", EmptyServerParams::new()).await.unwrap(); - + let sub = module.subscribe_unbounded("my_sub", EmptyServerParams::new()).await.unwrap(); let ser_id = serde_json::to_string(sub.subscription_id()).unwrap(); assert!(!sub.is_closed()); // Unsubscribe should be valid. let unsub_req = format!("{{\"jsonrpc\":\"2.0\",\"method\":\"my_unsub\",\"params\":[{}],\"id\":1}}", ser_id); - let (resp, _) = module.raw_json_request(&unsub_req).await.unwrap(); + let (resp, _) = module.raw_json_request(&unsub_req, 1).await.unwrap(); assert_eq!(resp.result, r#"{"jsonrpc":"2.0","result":true,"id":1}"#); // Unsubscribe already performed; should be error. let unsub_req = format!("{{\"jsonrpc\":\"2.0\",\"method\":\"my_unsub\",\"params\":[{}],\"id\":1}}", ser_id); - let (resp, _) = module.raw_json_request(&unsub_req).await.unwrap(); + let (resp, _) = module.raw_json_request(&unsub_req, 2).await.unwrap(); assert_eq!(resp.result, r#"{"jsonrpc":"2.0","result":false,"id":1}"#); } @@ -378,82 +379,104 @@ async fn subscribe_unsubscribe_without_server() { } #[tokio::test] -async fn empty_subscription_without_server() { +async fn rejected_subscription_without_server() { let mut module = RpcModule::new(()); module - .register_subscription("my_sub", "my_sub", "my_unsub", |_, mut _sink, _| { - // Sink was never accepted or rejected. Expected to return `InvalidParams`. + .register_subscription("my_sub", "my_sub", "my_unsub", |_, pending, _| async move { + let err = ErrorObject::borrowed(PARSE_ERROR_CODE, &"rejected", None); + let _ = pending.reject(err.into_owned()).await; + Ok(()) }) .unwrap(); - let sub_err = module.subscribe("my_sub", EmptyServerParams::new()).await.unwrap_err(); + let sub_err = module.subscribe_unbounded("my_sub", EmptyServerParams::new()).await.unwrap_err(); assert!( - matches!(sub_err, Error::Call(CallError::Custom(e)) if e.message().contains("Invalid params") && e.code() == ErrorCode::InvalidParams.code()) + matches!(sub_err, Error::Call(CallError::Custom(e)) if e.message().contains("rejected") && e.code() == PARSE_ERROR_CODE) ); } #[tokio::test] -async fn rejected_subscription_without_server() { +async fn reject_works() { let mut module = RpcModule::new(()); module - .register_subscription("my_sub", "my_sub", "my_unsub", |_, mut sink, _| { + .register_subscription("my_sub", "my_sub", "my_unsub", |_, pending, _| async move { let err = ErrorObject::borrowed(PARSE_ERROR_CODE, &"rejected", None); - sink.reject(err.into_owned())?; + let res = pending.reject(err.into_owned()).await; + assert!(matches!(res, Ok(()))); + Ok(()) }) .unwrap(); - let sub_err = module.subscribe("my_sub", EmptyServerParams::new()).await.unwrap_err(); + let sub_err = module.subscribe_unbounded("my_sub", EmptyServerParams::new()).await.unwrap_err(); assert!( matches!(sub_err, Error::Call(CallError::Custom(e)) if e.message().contains("rejected") && e.code() == PARSE_ERROR_CODE) ); } #[tokio::test] -async fn accepted_twice_subscription_without_server() { - let mut module = RpcModule::new(()); - module - .register_subscription("my_sub", "my_sub", "my_unsub", |_, mut sink, _| { - let res = sink.accept(); - assert!(matches!(res, Ok(_))); - - let res = sink.accept(); - assert!(matches!(res, Err(_))); - - let err = ErrorObject::borrowed(PARSE_ERROR_CODE, &"rejected", None); - let res = sink.reject(err.into_owned()); - assert!(matches!(res, Err(_))); - - Ok(()) - }) - .unwrap(); +async fn bounded_subscription_works() { + init_logger(); - let _ = module.subscribe("my_sub", EmptyServerParams::new()).await.expect("Subscription should not fail"); -} + let (tx, mut rx) = mpsc::unbounded_channel::(); + let mut module = RpcModule::new(tx); -#[tokio::test] -async fn reject_twice_subscription_without_server() { - let mut module = RpcModule::new(()); module - .register_subscription("my_sub", "my_sub", "my_unsub", |_, mut sink, _| { - let err = ErrorObject::borrowed(PARSE_ERROR_CODE, &"rejected", None); - let res = sink.reject(err.into_owned()); - assert!(matches!(res, Ok(()))); + .register_subscription("my_sub", "my_sub", "my_unsub", |_, pending, mut ctx| async move { + println!("accept"); + let mut sink = pending.accept().await?; + + let mut stream = IntervalStream::new(interval(std::time::Duration::from_millis(100))) + .enumerate() + .map(|(n, _)| n) + .take(6); + let fail = std::sync::Arc::make_mut(&mut ctx); + let mut buf = VecDeque::new(); + + while let Some(n) = stream.next().await { + let msg = SubscriptionMessage::from_json(&n).expect("usize infallible; qed"); + + match sink.try_send(msg) { + Err(TrySendError::Closed(_)) => panic!("This is a bug"), + Err(TrySendError::Full(m)) => { + buf.push_back(m); + } + Ok(_) => (), + } + } - let err = ErrorObject::borrowed(PARSE_ERROR_CODE, &"rejected", None); - let res = sink.reject(err.into_owned()); - assert!(matches!(res, Err(_))); + if !buf.is_empty() { + fail.send("Full".to_string()).unwrap(); + } + + while let Some(m) = buf.pop_front() { + match sink.try_send(m) { + Err(TrySendError::Closed(_)) => panic!("This is a bug"), + Err(TrySendError::Full(m)) => { + buf.push_front(m); + } + Ok(_) => (), + } - let res = sink.accept(); - assert!(matches!(res, Err(_))); + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + } Ok(()) }) .unwrap(); - let sub_err = module.subscribe("my_sub", EmptyServerParams::new()).await.unwrap_err(); - assert!( - matches!(sub_err, Error::Call(CallError::Custom(e)) if e.message().contains("rejected") && e.code() == PARSE_ERROR_CODE) - ); + // create a bounded subscription and don't poll it + // after 3 items has been produced messages will be dropped. + let mut sub = module.subscribe("my_sub", EmptyServerParams::new(), 3).await.unwrap(); + + // assert that some items couldn't be sent. + assert_eq!(rx.recv().await, Some("Full".to_string())); + + // the subscription should continue produce items are consumed + // and the failed messages should be able to go succeed. + for exp in 0..6 { + let (item, _) = sub.next::().await.unwrap().unwrap(); + assert_eq!(item, exp); + } } diff --git a/types/Cargo.toml b/types/Cargo.toml index ae61b067fe..f1d7106909 100644 --- a/types/Cargo.toml +++ b/types/Cargo.toml @@ -15,4 +15,4 @@ beef = { version = "0.5.1", features = ["impl_serde"] } tracing = { version = "0.1.34", default-features = false } serde = { version = "1", default-features = false, features = ["derive"] } serde_json = { version = "1", default-features = false, features = ["alloc", "raw_value", "std"] } -thiserror = "1.0" +thiserror = "1.0" \ No newline at end of file diff --git a/types/src/error.rs b/types/src/error.rs index 30b4e78ec1..8a57debe80 100644 --- a/types/src/error.rs +++ b/types/src/error.rs @@ -79,50 +79,6 @@ impl<'a> fmt::Display for ErrorResponse<'a> { write!(f, "{}", serde_json::to_string(&self).expect("infallible; qed")) } } -/// The return type of the subscription's method for the rpc server implementation. -/// -/// **Note**: The error does not contain any data and is discarded on drop. -pub type SubscriptionResult = Result<(), SubscriptionEmptyError>; - -/// The error returned by the subscription's method for the rpc server implementation. -/// -/// It contains no data, and neither is the error utilized. It provides an abstraction to make the -/// API more ergonomic while handling errors that may occur during the subscription callback. -#[derive(Debug, Clone, Copy)] -pub struct SubscriptionEmptyError; - -impl From for SubscriptionEmptyError { - fn from(_: anyhow::Error) -> Self { - SubscriptionEmptyError - } -} - -impl From for SubscriptionEmptyError { - fn from(_: CallError) -> Self { - SubscriptionEmptyError - } -} - -impl<'a> From> for SubscriptionEmptyError { - fn from(_: ErrorObject<'a>) -> Self { - SubscriptionEmptyError - } -} - -impl From for SubscriptionEmptyError { - fn from(_: SubscriptionAcceptRejectError) -> Self { - SubscriptionEmptyError - } -} - -/// The error returned while accepting or rejecting a subscription. -#[derive(Debug, Copy, Clone)] -pub enum SubscriptionAcceptRejectError { - /// The method was already called. - AlreadyCalled, - /// The remote peer closed the connection or called the unsubscribe method. - RemotePeerAborted, -} /// Owned variant of [`ErrorObject`]. pub type ErrorObjectOwned = ErrorObject<'static>; @@ -230,10 +186,6 @@ pub const SERVER_IS_BUSY_CODE: i32 = -32604; pub const CALL_EXECUTION_FAILED_CODE: i32 = -32000; /// Unknown error. pub const UNKNOWN_ERROR_CODE: i32 = -32001; -/// Subscription got closed by the server. -pub const SUBSCRIPTION_CLOSED: i32 = -32003; -/// Subscription got closed by the server. -pub const SUBSCRIPTION_CLOSED_WITH_ERROR: i32 = -32004; /// Batched requests are not supported by the server. pub const BATCHES_NOT_SUPPORTED_CODE: i32 = -32005; /// Subscription limit per connection was exceeded. diff --git a/types/src/lib.rs b/types/src/lib.rs index 528482747d..b68af06059 100644 --- a/types/src/lib.rs +++ b/types/src/lib.rs @@ -43,10 +43,7 @@ pub mod response; /// JSON-RPC response error object related types. pub mod error; -pub use error::{ErrorObject, ErrorObjectOwned, ErrorResponse, SubscriptionEmptyError, SubscriptionResult}; +pub use error::{ErrorObject, ErrorObjectOwned, ErrorResponse}; pub use params::{Id, Params, ParamsSequence, SubscriptionId, TwoPointZero}; pub use request::{InvalidRequest, Notification, NotificationSer, Request, RequestSer}; pub use response::{Response, SubscriptionPayload, SubscriptionResponse}; - -/// Empty server `RpcParams` type to use while registering modules. -pub type EmptyServerParams = Vec<()>;