-
Notifications
You must be signed in to change notification settings - Fork 175
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat(rpc module): stream API
for SubscriptionSink
#639
Changes from 8 commits
72024ea
fe176df
eebc73d
a29f988
6aa22e2
9bdea0d
79a8e55
7e81acf
6982598
d589d24
92bb97e
b9598ca
bce48da
6d16927
a47a965
07f80e2
bedf808
03ef669
ef7b965
f7aa544
72f4726
7a20a52
de4ac6a
fa15cfb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -38,7 +38,7 @@ use crate::to_json_raw_value; | |
use crate::traits::{IdProvider, ToRpcParams}; | ||
use beef::Cow; | ||
use futures_channel::{mpsc, oneshot}; | ||
use futures_util::{future::BoxFuture, FutureExt, StreamExt}; | ||
use futures_util::{future::BoxFuture, FutureExt, Stream, StreamExt}; | ||
use jsonrpsee_types::error::{invalid_subscription_err, ErrorCode, CALL_EXECUTION_FAILED_CODE}; | ||
use jsonrpsee_types::{ | ||
Id, Params, Request, Response, SubscriptionId as RpcSubscriptionId, SubscriptionPayload, SubscriptionResponse, | ||
|
@@ -51,16 +51,32 @@ use serde::{de::DeserializeOwned, Serialize}; | |
/// implemented as a function pointer to a `Fn` function taking four arguments: | ||
/// the `id`, `params`, a channel the function uses to communicate the result (or error) | ||
/// back to `jsonrpsee`, and the connection ID (useful for the websocket transport). | ||
pub type SyncMethod = Arc<dyn Send + Sync + Fn(Id, Params, &MethodSink, ConnectionId, &dyn IdProvider) -> bool>; | ||
pub type SyncMethod = Arc<dyn Send + Sync + Fn(Id, Params, &MethodSink, Option<ConnState>) -> bool>; | ||
/// Similar to [`SyncMethod`], but represents an asynchronous handler and takes an additional argument containing a [`ResourceGuard`] if configured. | ||
pub type AsyncMethod<'a> = Arc< | ||
dyn Send + Sync + Fn(Id<'a>, Params<'a>, MethodSink, Option<ResourceGuard>, &dyn IdProvider) -> BoxFuture<'a, bool>, | ||
dyn Send + Sync + Fn(Id<'a>, Params<'a>, MethodSink, ConnectionId, Option<ResourceGuard>) -> BoxFuture<'a, bool>, | ||
>; | ||
/// Connection ID, used for stateful protocol such as WebSockets. | ||
/// For stateless protocols such as http it's unused, so feel free to set it some hardcoded value. | ||
pub type ConnectionId = usize; | ||
/// Raw RPC response. | ||
pub type RawRpcResponse = (String, mpsc::UnboundedReceiver<String>, mpsc::UnboundedSender<String>); | ||
pub type RawRpcResponse = (String, mpsc::UnboundedReceiver<String>, async_channel::Sender<()>); | ||
|
||
/// Data for stateful connections. | ||
pub struct ConnState<'a> { | ||
/// Connection ID | ||
pub conn_id: ConnectionId, | ||
/// Channel to know whether the connection is closed or not. | ||
pub close: async_channel::Receiver<()>, | ||
/// ID provider. | ||
pub id_provider: &'a dyn IdProvider, | ||
} | ||
|
||
impl<'a> std::fmt::Debug for ConnState<'a> { | ||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { | ||
f.debug_struct("ConnState").field("conn_id", &self.conn_id).field("close", &self.close).finish() | ||
} | ||
} | ||
|
||
type Subscribers = Arc<Mutex<FxHashMap<SubscriptionKey, (MethodSink, oneshot::Receiver<()>)>>>; | ||
|
||
|
@@ -157,24 +173,23 @@ impl MethodCallback { | |
pub fn execute( | ||
&self, | ||
sink: &MethodSink, | ||
conn_state: Option<ConnState>, | ||
req: Request<'_>, | ||
conn_id: ConnectionId, | ||
claimed: Option<ResourceGuard>, | ||
id_gen: &dyn IdProvider, | ||
) -> MethodResult<bool> { | ||
let id = req.id.clone(); | ||
let params = Params::new(req.params.map(|params| params.get())); | ||
|
||
let result = match &self.callback { | ||
MethodKind::Sync(callback) => { | ||
tracing::trace!( | ||
"[MethodCallback::execute] Executing sync callback, params={:?}, req.id={:?}, conn_id={:?}", | ||
"[MethodCallback::execute] Executing sync callback, params={:?}, req.id={:?}, conn_state={:?}", | ||
params, | ||
id, | ||
conn_id | ||
conn_state | ||
niklasad1 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
); | ||
|
||
let result = (callback)(id, params, sink, conn_id, id_gen); | ||
let result = (callback)(id, params, sink, conn_state); | ||
|
||
// Release claimed resources | ||
drop(claimed); | ||
|
@@ -185,14 +200,15 @@ impl MethodCallback { | |
let sink = sink.clone(); | ||
let params = params.into_owned(); | ||
let id = id.into_owned(); | ||
let conn_id = conn_state.map(|s| s.conn_id).unwrap_or(0); | ||
tracing::trace!( | ||
"[MethodCallback::execute] Executing async callback, params={:?}, req.id={:?}, conn_id={:?}", | ||
"[MethodCallback::execute] Executing async callback, params={:?}, req.id={:?}, conn_state={:?}", | ||
params, | ||
id, | ||
conn_id | ||
conn_id, | ||
); | ||
|
||
MethodResult::Async((callback)(id, params, sink, claimed, id_gen)) | ||
MethodResult::Async((callback)(id, params, sink, conn_id, claimed)) | ||
} | ||
}; | ||
|
||
|
@@ -307,16 +323,10 @@ impl Methods { | |
} | ||
|
||
/// Attempt to execute a callback, sending the resulting JSON (success or error) to the specified sink. | ||
pub fn execute( | ||
&self, | ||
sink: &MethodSink, | ||
req: Request, | ||
conn_id: ConnectionId, | ||
id_gen: &dyn IdProvider, | ||
) -> MethodResult<bool> { | ||
pub fn execute(&self, sink: &MethodSink, conn_state: Option<ConnState>, req: Request) -> MethodResult<bool> { | ||
tracing::trace!("[Methods::execute] Executing request: {:?}", req); | ||
match self.callbacks.get(&*req.method) { | ||
Some(callback) => callback.execute(sink, req, conn_id, None, id_gen), | ||
Some(callback) => callback.execute(sink, conn_state, req, None), | ||
None => { | ||
sink.send_error(req.id, ErrorCode::MethodNotFound.into()); | ||
MethodResult::Sync(false) | ||
|
@@ -329,15 +339,14 @@ impl Methods { | |
pub fn execute_with_resources<'r>( | ||
&self, | ||
sink: &MethodSink, | ||
conn_state: Option<ConnState<'r>>, | ||
req: Request<'r>, | ||
conn_id: ConnectionId, | ||
resources: &Resources, | ||
id_gen: &dyn IdProvider, | ||
) -> Result<(&'static str, MethodResult<bool>), Cow<'r, str>> { | ||
tracing::trace!("[Methods::execute_with_resources] Executing request: {:?}", req); | ||
match self.callbacks.get_key_value(&*req.method) { | ||
Some((&name, callback)) => match callback.claim(&req.method, resources) { | ||
Ok(guard) => Ok((name, callback.execute(sink, req, conn_id, Some(guard), id_gen))), | ||
Ok(guard) => Ok((name, callback.execute(sink, conn_state, req, Some(guard)))), | ||
Err(err) => { | ||
tracing::error!("[Methods::execute_with_resources] failed to lock resources: {:?}", err); | ||
sink.send_error(req.id, ErrorCode::ServerIsBusy.into()); | ||
|
@@ -425,14 +434,17 @@ impl Methods { | |
/// Wrapper over [`Methods::execute`] to execute a callback. | ||
async fn inner_call(&self, req: Request<'_>) -> RawRpcResponse { | ||
let (tx, mut rx) = mpsc::unbounded(); | ||
let sink = MethodSink::new(tx.clone()); | ||
let sink = MethodSink::new(tx); | ||
let (close_tx, close_rx) = async_channel::unbounded(); | ||
|
||
if let MethodResult::Async(fut) = self.execute(&sink, req, 0, &RandomIntegerIdProvider) { | ||
let conn_state = Some(ConnState { conn_id: 0, close: close_rx, id_provider: &RandomIntegerIdProvider }); | ||
|
||
if let MethodResult::Async(fut) = self.execute(&sink, conn_state, req) { | ||
niklasad1 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
fut.await; | ||
} | ||
|
||
let resp = rx.next().await.expect("tx and rx still alive; qed"); | ||
(resp, rx, tx) | ||
(resp, rx, close_tx) | ||
} | ||
|
||
/// Helper to create a subscription on the `RPC module` without having to spin up a server. | ||
|
@@ -527,7 +539,7 @@ impl<Context: Send + Sync + 'static> RpcModule<Context> { | |
let ctx = self.ctx.clone(); | ||
let callback = self.methods.verify_and_insert( | ||
method_name, | ||
MethodCallback::new_sync(Arc::new(move |id, params, sink, _, _| match callback(params, &*ctx) { | ||
MethodCallback::new_sync(Arc::new(move |id, params, sink, _| match callback(params, &*ctx) { | ||
Ok(res) => sink.send_response(id, res), | ||
Err(err) => sink.send_call_error(id, err), | ||
})), | ||
|
@@ -550,7 +562,7 @@ impl<Context: Send + Sync + 'static> RpcModule<Context> { | |
let ctx = self.ctx.clone(); | ||
let callback = self.methods.verify_and_insert( | ||
method_name, | ||
MethodCallback::new_async(Arc::new(move |id, params, sink, claimed, _| { | ||
MethodCallback::new_async(Arc::new(move |id, params, sink, _, claimed| { | ||
let ctx = ctx.clone(); | ||
let future = async move { | ||
let result = match callback(params, ctx).await { | ||
|
@@ -585,7 +597,7 @@ impl<Context: Send + Sync + 'static> RpcModule<Context> { | |
let ctx = self.ctx.clone(); | ||
let callback = self.methods.verify_and_insert( | ||
method_name, | ||
MethodCallback::new_async(Arc::new(move |id, params, sink, claimed, _| { | ||
MethodCallback::new_async(Arc::new(move |id, params, sink, _, claimed| { | ||
let ctx = ctx.clone(); | ||
|
||
tokio::task::spawn_blocking(move || { | ||
|
@@ -671,12 +683,13 @@ impl<Context: Send + Sync + 'static> RpcModule<Context> { | |
let subscribers = subscribers.clone(); | ||
self.methods.mut_callbacks().insert( | ||
subscribe_method_name, | ||
MethodCallback::new_sync(Arc::new(move |id, params, method_sink, conn_id, id_provider| { | ||
MethodCallback::new_sync(Arc::new(move |id, params, method_sink, conn| { | ||
let (conn_tx, conn_rx) = oneshot::channel::<()>(); | ||
let c = conn.expect("conn must be Some; this is bug"); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If it bust be There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. it's way better but I think it requires additional callback types right? see my comment below. thoughts? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ok, I had a look at the code to solve this properly at // server code
let result = match RpcModule::as_callback(&req.method) {
None => {
send_error(sink, msg),
continue
}
Some(MethodKind::Sync(cb)) => (cb)(id, params, sink, conn_id),
Some(MethodKind::Async(cb)) => (cb)(id, params, sink, resources).await,
// should be used for subscriptions...
// servers that don't support subscriptions should throw an error here...
Some(MethodKind::Subscription) => (cb)(id, params, sink, resources, conn_state),
};
// modify RpcModule::register_subscription
pub fn register_subscription<F>(
&mut self,
subscribe_method_name: &'static str,
notif_method_name: &'static str,
unsubscribe_method_name: &'static str,
callback: F,
) -> Result<(), Error> {
....
....
self.methods.mut_callbacks().insert(
subscribe_method_name,
MethodCallback::new_subscription(Arc::new(move |id, params, method_sink, conn| {
...
}
);
}
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. one benefit is not to have to clone this "introduced channel connection_state" for every method instead only for subscriptions where it's actually used... not sure I liked how the old code abstracted this away There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I already thought about splitting async calls from subscriptions when doing my refactoring earlier, and for regular method calls returning values instead of passing in the sink as a parameter, I reckon that would make things much more readable and straight forward, and potentially make the final binary smaller. So if you want to go that route and add another enum variant I think that's cool, and I can do a PR that switches method calls to have a return value later on :). There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ah, I already added the enum variant in this PR but just a hacky draft to check that it works and to show you and David what I had in mind :) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. sounds great to get rid of the sink for the synchronous calls. |
||
|
||
let sub_id = { | ||
let sub_id: RpcSubscriptionId = id_provider.next_id().into_owned(); | ||
let uniq_sub = SubscriptionKey { conn_id, sub_id: sub_id.clone() }; | ||
let sub_id: RpcSubscriptionId = c.id_provider.next_id().into_owned(); | ||
let uniq_sub = SubscriptionKey { conn_id: c.conn_id, sub_id: sub_id.clone() }; | ||
|
||
subscribers.lock().insert(uniq_sub, (method_sink.clone(), conn_rx)); | ||
|
||
|
@@ -687,9 +700,10 @@ impl<Context: Send + Sync + 'static> RpcModule<Context> { | |
|
||
let sink = SubscriptionSink { | ||
inner: method_sink.clone(), | ||
close: c.close, | ||
method: notif_method_name, | ||
subscribers: subscribers.clone(), | ||
uniq_sub: SubscriptionKey { conn_id, sub_id }, | ||
uniq_sub: SubscriptionKey { conn_id: c.conn_id, sub_id }, | ||
is_connected: Some(conn_tx), | ||
}; | ||
if let Err(err) = callback(params, sink, ctx.clone()) { | ||
|
@@ -710,7 +724,9 @@ impl<Context: Send + Sync + 'static> RpcModule<Context> { | |
{ | ||
self.methods.mut_callbacks().insert( | ||
unsubscribe_method_name, | ||
MethodCallback::new_sync(Arc::new(move |id, params, sink, conn_id, _| { | ||
MethodCallback::new_sync(Arc::new(move |id, params, sink, conn_state| { | ||
let c = conn_state.expect("conn must be Some; this is bug"); | ||
|
||
let sub_id = match params.one::<RpcSubscriptionId>() { | ||
Ok(sub_id) => sub_id, | ||
Err(_) => { | ||
|
@@ -727,7 +743,11 @@ impl<Context: Send + Sync + 'static> RpcModule<Context> { | |
}; | ||
let sub_id = sub_id.into_owned(); | ||
|
||
if subscribers.lock().remove(&SubscriptionKey { conn_id, sub_id: sub_id.clone() }).is_some() { | ||
if subscribers | ||
.lock() | ||
.remove(&SubscriptionKey { conn_id: c.conn_id, sub_id: sub_id.clone() }) | ||
.is_some() | ||
{ | ||
sink.send_response(id, "Unsubscribed") | ||
} else { | ||
let err = to_json_raw_value(&format!( | ||
|
@@ -764,6 +784,8 @@ impl<Context: Send + Sync + 'static> RpcModule<Context> { | |
pub struct SubscriptionSink { | ||
/// Sink. | ||
inner: MethodSink, | ||
/// Close | ||
close: async_channel::Receiver<()>, | ||
/// MethodCallback. | ||
method: &'static str, | ||
/// Unique subscription. | ||
|
@@ -786,9 +808,30 @@ impl SubscriptionSink { | |
self.inner_send(msg).map_err(Into::into) | ||
} | ||
|
||
/// Consume the sink by passing a stream to be sent via the sink. | ||
pub async fn add_stream<S, T>(mut self, mut stream: S) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What do you think about calling this There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't like maybe There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fair; I'd have liked Of your suggestions I like There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So, this method consumes a stream, feeding the items into the subscription? I guess I'd go with something like There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
yes I guess we could return a type that impls There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I like |
||
where | ||
S: Stream<Item = T> + Unpin, | ||
T: Serialize, | ||
{ | ||
loop { | ||
tokio::select! { | ||
Some(item) = stream.next() => { | ||
if let Err(Error::SubscriptionClosed(_)) = self.send(&item) { | ||
break; | ||
} | ||
}, | ||
// No messages should be sent over this channel (just ignore and continue) | ||
Some(_) = self.close.next() => {}, | ||
// Stream or connection was dropped => close stream. | ||
else => break, | ||
} | ||
} | ||
} | ||
|
||
/// Returns whether this channel is closed without needing a context. | ||
pub fn is_closed(&self) -> bool { | ||
self.inner.is_closed() | ||
self.inner.is_closed() || self.close.is_closed() | ||
niklasad1 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} | ||
|
||
fn build_message<T: Serialize>(&self, result: &T) -> Result<String, Error> { | ||
|
@@ -806,7 +849,7 @@ impl SubscriptionSink { | |
self.inner.send_raw(msg).map_err(|_| Some(SubscriptionClosedReason::ConnectionReset)) | ||
} | ||
Some(_) => Err(Some(SubscriptionClosedReason::Unsubscribed)), | ||
// NOTE(niklasad1): this should be unreachble, after the first error is detected the subscription is closed. | ||
// NOTE(niklasad1): this should be unreachable, after the first error is detected the subscription is closed. | ||
None => Err(None), | ||
}; | ||
|
||
|
@@ -850,15 +893,15 @@ impl Drop for SubscriptionSink { | |
/// Wrapper struct that maintains a subscription "mainly" for testing. | ||
#[derive(Debug)] | ||
pub struct Subscription { | ||
tx: mpsc::UnboundedSender<String>, | ||
tx: async_channel::Sender<()>, | ||
rx: mpsc::UnboundedReceiver<String>, | ||
sub_id: RpcSubscriptionId<'static>, | ||
} | ||
|
||
impl Subscription { | ||
/// Close the subscription channel. | ||
pub fn close(&mut self) { | ||
self.tx.close_channel(); | ||
self.tx.close(); | ||
} | ||
|
||
/// Get the subscription ID | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -359,6 +359,44 @@ async fn ws_server_should_stop_subscription_after_client_drop() { | |
assert!(matches!(close_err.close_reason(), &SubscriptionClosedReason::ConnectionReset)); | ||
} | ||
|
||
#[tokio::test] | ||
async fn ws_server_cancels_stream_after_reset_conn() { | ||
use futures::{channel::mpsc, SinkExt, StreamExt}; | ||
use jsonrpsee::{ws_server::WsServerBuilder, RpcModule}; | ||
|
||
let server = WsServerBuilder::default().build("127.0.0.1:0").await.unwrap(); | ||
let server_url = format!("ws://{}", server.local_addr().unwrap()); | ||
|
||
let (tx, mut rx) = mpsc::channel(1); | ||
let mut module = RpcModule::new(tx); | ||
|
||
module | ||
.register_subscription("subscribe_never_produce", "n", "unsubscribe_never_produce", |_, sink, mut tx| { | ||
// create stream that doesn't produce items. | ||
let stream = futures::stream::empty::<usize>(); | ||
tokio::spawn(async move { | ||
sink.add_stream(stream).await; | ||
let send_back = Arc::make_mut(&mut tx); | ||
send_back.feed(()).await.unwrap(); | ||
dvdplm marked this conversation as resolved.
Show resolved
Hide resolved
|
||
}); | ||
Ok(()) | ||
}) | ||
.unwrap(); | ||
|
||
server.start(module).unwrap(); | ||
|
||
let client = WsClientBuilder::default().build(&server_url).await.unwrap(); | ||
let _sub1: Subscription<usize> = | ||
client.subscribe("subscribe_never_produce", None, "unsubscribe_never_produce").await.unwrap(); | ||
let _sub2: Subscription<usize> = | ||
client.subscribe("subscribe_never_produce", None, "unsubscribe_never_produce").await.unwrap(); | ||
|
||
// terminate connection. | ||
drop(client); | ||
assert_eq!(Some(()), rx.next().await, "subscription stream should be terminated after the client was dropped"); | ||
assert_eq!(Some(()), rx.next().await, "subscription stream should be terminated after the client was dropped"); | ||
Comment on lines
+396
to
+397
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I felt like I understood the test until I got here. I thought There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. haha, maybe it's more clear as you described it but the test actually sends a message on the channel when the subscription terminated. the reason why is because the |
||
} | ||
|
||
#[tokio::test] | ||
async fn ws_batch_works() { | ||
let server_addr = websocket_server().await; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
NOTE: this panics if count > usize::MAX / 2
but if we reach that we likely have other problems such as OOM :)