Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(workflows): filter messages by tags #1142

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion lib/chirp-workflow/core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ indoc = "2.0.5"
lazy_static = "1.4"
prost = "0.12.4"
prost-types = "0.12.4"
rand = "0.8.5"
rivet-cache = { path = "../../cache/build" }
rivet-connection = { path = "../../connection" }
rivet-metrics = { path = "../../metrics" }
Expand Down
2 changes: 1 addition & 1 deletion lib/chirp-workflow/core/src/compat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ where
M: Message,
B: Debug + Clone,
{
let msg_ctx = MessageCtx::new(ctx.conn(), ctx.req_id(), ctx.ray_id())
let msg_ctx = MessageCtx::new(ctx.conn(), ctx.ray_id())
.await
.map_err(GlobalError::raw)?;

Expand Down
6 changes: 3 additions & 3 deletions lib/chirp-workflow/core/src/ctx/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use crate::{
},
db::DatabaseHandle,
error::WorkflowResult,
message::{Message, ReceivedMessage},
message::{Message, NatsMessage},
operation::{Operation, OperationInput},
signal::Signal,
workflow::{Workflow, WorkflowInput},
Expand Down Expand Up @@ -55,7 +55,7 @@ impl ApiCtx {
(),
);

let msg_ctx = MessageCtx::new(&conn, req_id, ray_id).await?;
let msg_ctx = MessageCtx::new(&conn, ray_id).await?;

Ok(ApiCtx {
ray_id,
Expand Down Expand Up @@ -129,7 +129,7 @@ impl ApiCtx {
pub async fn tail_read<M>(
&self,
tags: serde_json::Value,
) -> GlobalResult<Option<ReceivedMessage<M>>>
) -> GlobalResult<Option<NatsMessage<M>>>
where
M: Message,
{
Expand Down
2 changes: 1 addition & 1 deletion lib/chirp-workflow/core/src/ctx/backfill.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use std::{
};
use uuid::Uuid;

use crate::util::Location;
use crate::utils::Location;

// Yes
type Query = Box<
Expand Down
131 changes: 23 additions & 108 deletions lib/chirp-workflow/core/src/ctx/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ use uuid::Uuid;

use crate::{
error::{WorkflowError, WorkflowResult},
message::{self, Message, MessageWrapper, ReceivedMessage, TraceEntry},
message::{redis_keys, Message, NatsMessage, NatsMessageWrapper},
utils,
};

/// Time (in ms) that we subtract from the anchor grace period in order to
Expand All @@ -29,29 +30,15 @@ pub struct MessageCtx {
/// Used for writing to message tails. This cache is ephemeral.
redis_chirp_ephemeral: RedisPool,

req_id: Uuid,
ray_id: Uuid,
trace: Vec<TraceEntry>,
}

impl MessageCtx {
pub async fn new(
conn: &rivet_connection::Connection,
req_id: Uuid,
ray_id: Uuid,
) -> WorkflowResult<Self> {
pub async fn new(conn: &rivet_connection::Connection, ray_id: Uuid) -> WorkflowResult<Self> {
Ok(MessageCtx {
nats: conn.nats().await?,
redis_chirp_ephemeral: conn.redis_chirp_ephemeral().await?,
req_id,
ray_id,
trace: conn
.chirp()
.trace()
.iter()
.cloned()
.map(TryInto::try_into)
.collect::<WorkflowResult<Vec<_>>>()?,
})
}
}
Expand Down Expand Up @@ -109,7 +96,7 @@ impl MessageCtx {
M: Message,
{
let tags_str = cjson::to_string(&tags).map_err(WorkflowError::SerializeMessageTags)?;
let nats_subject = message::serialize_message_nats_subject::<M>(&tags_str);
let nats_subject = M::nats_subject();
let duration_since_epoch = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_else(|err| unreachable!("time is broken: {}", err));
Expand All @@ -124,12 +111,11 @@ impl MessageCtx {

// Serialize message
let req_id = Uuid::new_v4();
let message = MessageWrapper {
let message = NatsMessageWrapper {
req_id: req_id,
ray_id: self.ray_id,
tags: tags.clone(),
tags,
ts,
trace: self.trace.clone(),
allow_recursive: false, // TODO:
body: &body_buf,
};
Expand Down Expand Up @@ -278,8 +264,7 @@ impl MessageCtx {
where
M: Message,
{
let tags_str = cjson::to_string(opts.tags).map_err(WorkflowError::SerializeMessageTags)?;
let nats_subject = message::serialize_message_nats_subject::<M>(&tags_str);
let nats_subject = M::nats_subject();

// Create subscription and flush immediately.
tracing::info!(%nats_subject, tags = ?opts.tags, "creating subscription");
Expand All @@ -296,7 +281,7 @@ impl MessageCtx {
}

// Return handle
let subscription = SubscriptionHandle::new(nats_subject, subscription, self.req_id);
let subscription = SubscriptionHandle::new(nats_subject, subscription, opts.tags.clone());
Ok(subscription)
}

Expand All @@ -305,7 +290,7 @@ impl MessageCtx {
pub async fn tail_read<M>(
&self,
tags: serde_json::Value,
) -> WorkflowResult<Option<ReceivedMessage<M>>>
) -> WorkflowResult<Option<NatsMessage<M>>>
where
M: Message,
{
Expand All @@ -320,7 +305,7 @@ impl MessageCtx {

// Deserialize message
let message = if let Some(message_buf) = message_buf {
let message = ReceivedMessage::<M>::deserialize(message_buf.as_slice())?;
let message = NatsMessage::<M>::deserialize(message_buf.as_slice())?;
tracing::info!(?message, "immediate read tail message");

let recv_lag = (rivet_util::timestamp::now() as f64 - message.ts as f64) / 1000.;
Expand Down Expand Up @@ -410,7 +395,7 @@ where
_guard: DropGuard,
subject: String,
subscription: nats::Subscriber,
req_id: Uuid,
pub tags: serde_json::Value,
}

impl<M> Debug for SubscriptionHandle<M>
Expand All @@ -420,6 +405,7 @@ where
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("SubscriptionHandle")
.field("subject", &self.subject)
.field("tags", &self.tags)
.finish()
}
}
Expand All @@ -429,7 +415,7 @@ where
M: Message,
{
#[tracing::instrument(level = "debug", skip_all)]
fn new(subject: String, subscription: nats::Subscriber, req_id: Uuid) -> Self {
fn new(subject: String, subscription: nats::Subscriber, tags: serde_json::Value) -> Self {
let token = CancellationToken::new();

{
Expand Down Expand Up @@ -458,34 +444,15 @@ where
_guard: token.drop_guard(),
subject,
subscription,
req_id,
tags,
}
}

/// Waits for the next message in the subscription.
///
/// This future can be safely dropped.
#[tracing::instrument]
pub async fn next(&mut self) -> WorkflowResult<ReceivedMessage<M>> {
self.next_inner(false).await
}

// TODO: Add a full config struct to pass to `next` that impl's `Default`
/// Waits for the next message in the subscription that originates from the
/// parent request ID via trace.
///
/// This future can be safely dropped.
#[tracing::instrument]
pub async fn next_with_trace(
&mut self,
filter_trace: bool,
) -> WorkflowResult<ReceivedMessage<M>> {
self.next_inner(filter_trace).await
}

/// This future can be safely dropped.
#[tracing::instrument(level = "trace")]
async fn next_inner(&mut self, filter_trace: bool) -> WorkflowResult<ReceivedMessage<M>> {
pub async fn next(&mut self) -> WorkflowResult<NatsMessage<M>> {
tracing::info!("waiting for message");

loop {
Expand All @@ -501,47 +468,22 @@ where
}
};

if filter_trace {
let message_wrapper =
ReceivedMessage::<M>::deserialize_wrapper(&nats_message.payload[..])?;

// Check if the message trace stack originates from this client
//
// We intentionally use the request ID instead of just checking the ray ID because
// there may be multiple calls to `message_with_subscribe` within the same ray.
// Explicitly checking the parent request ensures the response is unique to this
// message.
if message_wrapper
.trace
.iter()
.rev()
.any(|trace_entry| trace_entry.req_id == self.req_id)
{
let message = ReceivedMessage::<M>::deserialize(&nats_message.payload[..])?;
tracing::info!(?message, "received message");

return Ok(message);
}
} else {
let message = ReceivedMessage::<M>::deserialize(&nats_message.payload[..])?;
tracing::info!(?message, "received message");
let message_wrapper = NatsMessage::<M>::deserialize_wrapper(&nats_message.payload[..])?;

let recv_lag = (rivet_util::timestamp::now() as f64 - message.ts as f64) / 1000.;
crate::metrics::MESSAGE_RECV_LAG
.with_label_values(&[M::NAME])
.observe(recv_lag);
// Check if the subscription tags match a subset of the message tags
if utils::is_value_subset(&self.tags, &message_wrapper.tags) {
let message = NatsMessage::<M>::deserialize_from_wrapper(message_wrapper)?;
tracing::info!(?message, "received message");

return Ok(message);
}

// Message not from parent, continue with loop
// Message tags don't match, continue with loop
}
}

/// Converts the subscription in to a stream.
pub fn into_stream(
self,
) -> impl futures_util::Stream<Item = WorkflowResult<ReceivedMessage<M>>> {
pub fn into_stream(self) -> impl futures_util::Stream<Item = WorkflowResult<NatsMessage<M>>> {
futures_util::stream::try_unfold(self, |mut sub| async move {
let message = sub.next().await?;
Ok(Some((message, sub)))
Expand Down Expand Up @@ -569,7 +511,7 @@ pub enum TailAnchorResponse<M>
where
M: Message + Debug,
{
Message(ReceivedMessage<M>),
Message(NatsMessage<M>),

/// Anchor was older than the TTL of the message.
AnchorExpired,
Expand All @@ -589,30 +531,3 @@ where
}
}
}

mod redis_keys {
use std::{
collections::hash_map::DefaultHasher,
hash::{Hash, Hasher},
};

use crate::message::Message;

/// HASH
pub fn message_tail<M>(tags_str: &str) -> String
where
M: Message,
{
// Get hash of the tags
let mut hasher = DefaultHasher::new();
tags_str.hash(&mut hasher);

format!("{{topic:{}:{:x}}}:tail", M::NAME, hasher.finish())
}

pub mod message_tail {
pub const REQUEST_ID: &str = "r";
pub const TS: &str = "t";
pub const BODY: &str = "b";
}
}
16 changes: 13 additions & 3 deletions lib/chirp-workflow/core/src/ctx/standalone.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ use crate::{
},
db::DatabaseHandle,
error::WorkflowResult,
message::{Message, ReceivedMessage},
listen::Listen,
message::{Message, NatsMessage},
operation::{Operation, OperationInput},
signal::Signal,
workflow::{Workflow, WorkflowInput},
Expand Down Expand Up @@ -54,7 +55,7 @@ impl StandaloneCtx {
(),
);

let msg_ctx = MessageCtx::new(&conn, req_id, ray_id).await?;
let msg_ctx = MessageCtx::new(&conn, ray_id).await?;

Ok(StandaloneCtx {
ray_id,
Expand Down Expand Up @@ -92,6 +93,15 @@ impl StandaloneCtx {
builder::signal::SignalBuilder::new(self.db.clone(), self.ray_id, body)
}

// /// Listens for a signal indefinitely.
// pub async fn listen<T: Listen>(&mut self) -> GlobalResult<T> {
// tracing::info!(name=%self.name, "listening for signal");

// let ctx = ListenCtx::new(self);

// T::listen(&ctx).await
// }

#[tracing::instrument(err, skip_all, fields(operation = I::Operation::NAME))]
pub async fn op<I>(
&self,
Expand Down Expand Up @@ -128,7 +138,7 @@ impl StandaloneCtx {
pub async fn tail_read<M>(
&self,
tags: serde_json::Value,
) -> GlobalResult<Option<ReceivedMessage<M>>>
) -> GlobalResult<Option<NatsMessage<M>>>
where
M: Message,
{
Expand Down
10 changes: 5 additions & 5 deletions lib/chirp-workflow/core/src/ctx/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@ use crate::{
},
db::{DatabaseHandle, DatabasePgNats},
error::WorkflowError,
message::{Message, ReceivedMessage},
message::{Message, NatsMessage},
operation::{Operation, OperationInput},
signal::Signal,
util,
utils,
workflow::{Workflow, WorkflowInput},
};

Expand Down Expand Up @@ -50,7 +50,7 @@ impl TestCtx {
.expect("failed to create chirp client");
let cache =
rivet_cache::CacheInner::from_env(pools.clone()).expect("failed to create cache");
let conn = util::new_conn(
let conn = utils::new_conn(
&shared_client,
&pools,
&cache,
Expand All @@ -73,7 +73,7 @@ impl TestCtx {

let db =
DatabasePgNats::from_pools(pools.crdb().unwrap(), pools.nats_option().clone().unwrap());
let msg_ctx = MessageCtx::new(&conn, req_id, ray_id).await.unwrap();
let msg_ctx = MessageCtx::new(&conn, ray_id).await.unwrap();

TestCtx {
name: service_name,
Expand Down Expand Up @@ -176,7 +176,7 @@ impl TestCtx {
pub async fn tail_read<M>(
&self,
tags: serde_json::Value,
) -> GlobalResult<Option<ReceivedMessage<M>>>
) -> GlobalResult<Option<NatsMessage<M>>>
where
M: Message,
{
Expand Down
Loading
Loading