Skip to content

Commit

Permalink
Pubsub: Keep stream running after sink was closed.
Browse files Browse the repository at this point in the history
  • Loading branch information
nihohit committed Sep 14, 2024
1 parent 2e1b5c1 commit 636a749
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 17 deletions.
15 changes: 11 additions & 4 deletions redis/src/aio/pubsub.rs
Original file line number Diff line number Diff line change
Expand Up @@ -247,10 +247,17 @@ impl PubSubSink {
{
let (sender, mut receiver) = unbounded_channel();
let sink = PipelineSink::new(sink_stream, messages_sender);
let f = stream::poll_fn(move |cx| receiver.poll_recv(cx))
.map(Ok)
.forward(sink)
.map(|_| ());
let f = stream::poll_fn(move |cx| {
let res = receiver.poll_recv(cx);
match res {
// We don't want to stop the backing task for the stream, even if the sink was closed.
Poll::Ready(None) => Poll::Pending,
_ => res,
}
})
.map(Ok)
.forward(sink)
.map(|_| ());
(PubSubSink { sender }, f)
}

Expand Down
50 changes: 37 additions & 13 deletions redis/tests/test_async.rs
Original file line number Diff line number Diff line change
Expand Up @@ -686,8 +686,17 @@ mod basic_async {
let mut publish_conn = ctx.async_connection().await?;
let _: () = publish_conn.publish("phonewave", "banana").await?;

let msg_payload: String = pubsub_stream.next().await.unwrap().get_payload()?;
assert_eq!("banana".to_string(), msg_payload);
let repeats = 6;
for _ in 0..repeats {
let _: () = publish_conn.publish("phonewave", "banana").await?;
}

for _ in 0..repeats {
let message: String =
pubsub_stream.next().await.unwrap().get_payload().unwrap();

assert_eq!("banana".to_string(), message);
}

Ok(())
})
Expand Down Expand Up @@ -748,14 +757,18 @@ mod basic_async {
block_on_all(async move {
let (mut sink, mut stream) = ctx.async_pubsub().await?.split();
let mut publish_conn = ctx.async_connection().await?;
let spawned_read = tokio::spawn(async move { stream.next().await });

let _: () = sink.subscribe("phonewave").await?;
let _: () = publish_conn.publish("phonewave", "banana").await?;
let repeats = 6;
for _ in 0..repeats {
let _: () = publish_conn.publish("phonewave", "banana").await?;
}

let message: String = spawned_read.await.unwrap().unwrap().get_payload().unwrap();
for _ in 0..repeats {
let message: String = stream.next().await.unwrap().get_payload().unwrap();

assert_eq!("banana".to_string(), message);
assert_eq!("banana".to_string(), message);
}

Ok(())
})
Expand All @@ -768,15 +781,19 @@ mod basic_async {
block_on_all(async move {
let (mut sink, mut stream) = ctx.async_pubsub().await?.split();
let mut publish_conn = ctx.async_connection().await?;
let spawned_read = tokio::spawn(async move { stream.next().await });

let _: () = sink.subscribe("phonewave").await?;
drop(sink);
let _: () = publish_conn.publish("phonewave", "banana").await?;
let repeats = 6;
for _ in 0..repeats {
let _: () = publish_conn.publish("phonewave", "banana").await?;
}

let message: String = spawned_read.await.unwrap().unwrap().get_payload().unwrap();
for _ in 0..repeats {
let message: String = stream.next().await.unwrap().get_payload().unwrap();

assert_eq!("banana".to_string(), message);
assert_eq!("banana".to_string(), message);
}

Ok(())
})
Expand All @@ -792,11 +809,18 @@ mod basic_async {

let _: () = pubsub.subscribe("phonewave").await?;
let mut stream = pubsub.into_on_message();
let _: () = publish_conn.publish("phonewave", "banana").await?;
// wait a bit
sleep(Duration::from_secs(2).into()).await;
let repeats = 6;
for _ in 0..repeats {
let _: () = publish_conn.publish("phonewave", "banana").await?;
}

let message: String = stream.next().await.unwrap().get_payload().unwrap();
for _ in 0..repeats {
let message: String = stream.next().await.unwrap().get_payload().unwrap();

assert_eq!("banana".to_string(), message);
assert_eq!("banana".to_string(), message);
}

Ok(())
})
Expand Down

0 comments on commit 636a749

Please sign in to comment.