Skip to content

Commit

Permalink
feat(workflows): add message and signal history (#987)
Browse files Browse the repository at this point in the history
<!-- Please make sure there is an issue that this PR is correlated to. -->

## Changes

<!-- If there are frontend changes, please include screenshots. -->
  • Loading branch information
MasterPtato committed Aug 2, 2024
1 parent 38c1171 commit 0003acc
Show file tree
Hide file tree
Showing 9 changed files with 557 additions and 150 deletions.
39 changes: 5 additions & 34 deletions lib/chirp-workflow/core/src/ctx/activity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,7 @@ use global_error::{GlobalError, GlobalResult};
use rivet_pools::prelude::*;
use uuid::Uuid;

use crate::{
ctx::{MessageCtx, OperationCtx},
error::{WorkflowError, WorkflowResult},
message::Message,
DatabaseHandle, Operation, OperationInput,
};
use crate::{ctx::OperationCtx, error::WorkflowError, DatabaseHandle, Operation, OperationInput};

#[derive(Clone)]
pub struct ActivityCtx {
Expand All @@ -19,21 +14,20 @@ pub struct ActivityCtx {
db: DatabaseHandle,

conn: rivet_connection::Connection,
msg_ctx: MessageCtx,

// Backwards compatibility
op_ctx: rivet_operation::OperationContext<()>,
}

impl ActivityCtx {
pub async fn new(
pub fn new(
workflow_id: Uuid,
db: DatabaseHandle,
conn: &rivet_connection::Connection,
activity_create_ts: i64,
ray_id: Uuid,
name: &'static str,
) -> WorkflowResult<Self> {
) -> Self {
let ts = rivet_util::timestamp::now();
let req_id = Uuid::new_v4();
let conn = conn.wrap(req_id, ray_id, name);
Expand All @@ -49,18 +43,15 @@ impl ActivityCtx {
);
op_ctx.from_workflow = true;

let msg_ctx = MessageCtx::new(&conn, req_id, ray_id).await?;

Ok(ActivityCtx {
ActivityCtx {
workflow_id,
ray_id,
name,
ts,
db,
conn,
op_ctx,
msg_ctx,
})
}
}
}

Expand Down Expand Up @@ -95,26 +86,6 @@ impl ActivityCtx {
.await
.map_err(GlobalError::raw)
}

pub async fn msg<M>(&self, tags: serde_json::Value, body: M) -> GlobalResult<()>
where
M: Message,
{
self.msg_ctx
.message(tags, body)
.await
.map_err(GlobalError::raw)
}

pub async fn msg_wait<M>(&self, tags: serde_json::Value, body: M) -> GlobalResult<()>
where
M: Message,
{
self.msg_ctx
.message_wait(tags, body)
.await
.map_err(GlobalError::raw)
}
}

impl ActivityCtx {
Expand Down
221 changes: 187 additions & 34 deletions lib/chirp-workflow/core/src/ctx/workflow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@ use uuid::Uuid;

use crate::{
activity::ActivityId,
ctx::{ActivityCtx, MessageCtx},
event::Event,
util::{self, Location},
Activity, ActivityCtx, ActivityInput, DatabaseHandle, Executable, Listen, PulledWorkflow,
RegistryHandle, Signal, SignalRow, Workflow, WorkflowError, WorkflowInput, WorkflowResult,
message::Message,
util::Location,
Activity, ActivityInput, DatabaseHandle, Executable, Listen, PulledWorkflow, RegistryHandle,
Signal, SignalRow, Workflow, WorkflowError, WorkflowInput, WorkflowResult,
};

// Time to delay a worker from retrying after an error
Expand Down Expand Up @@ -55,15 +57,19 @@ pub struct WorkflowCtx {

root_location: Location,
location_idx: usize,

msg_ctx: MessageCtx,
}

impl WorkflowCtx {
pub fn new(
pub async fn new(
registry: RegistryHandle,
db: DatabaseHandle,
conn: rivet_connection::Connection,
workflow: PulledWorkflow,
) -> GlobalResult<Self> {
let msg_ctx = MessageCtx::new(&conn, workflow.workflow_id, workflow.ray_id).await?;

Ok(WorkflowCtx {
workflow_id: workflow.workflow_id,
name: workflow.workflow_name,
Expand All @@ -77,18 +83,13 @@ impl WorkflowCtx {

conn,

event_history: Arc::new(
util::combine_events(
workflow.activity_events,
workflow.signal_events,
workflow.sub_workflow_events,
)
.map_err(GlobalError::raw)?,
),
event_history: Arc::new(workflow.events),
input: Arc::new(workflow.input),

root_location: Box::new([]),
location_idx: 0,

msg_ctx,
})
}

Expand Down Expand Up @@ -117,6 +118,8 @@ impl WorkflowCtx {
.chain(std::iter::once(self.location_idx))
.collect(),
location_idx: 0,

msg_ctx: self.msg_ctx.clone(),
};

self.location_idx += 1;
Expand Down Expand Up @@ -258,8 +261,7 @@ impl WorkflowCtx {
self.create_ts,
self.ray_id,
A::NAME,
)
.await?;
);

let res = tokio::time::timeout(A::TIMEOUT, A::run(&ctx, input))
.await
Expand Down Expand Up @@ -568,6 +570,8 @@ impl WorkflowCtx {
.chain(std::iter::once(self.location_idx))
.collect(),
location_idx: 0,

msg_ctx: self.msg_ctx.clone(),
};

self.location_idx += 1;
Expand Down Expand Up @@ -669,19 +673,47 @@ impl WorkflowCtx {
workflow_id: Uuid,
body: T,
) -> GlobalResult<Uuid> {
let signal_id = Uuid::new_v4();
let event = { self.relevant_history().nth(self.location_idx) };

tracing::info!(name=%T::NAME, %workflow_id, %signal_id, "dispatching signal");
// Signal sent before
let signal_id = if let Some(event) = event {
// Validate history is consistent
let Event::SignalSend(signal) = event else {
return Err(WorkflowError::HistoryDiverged).map_err(GlobalError::raw);
};

// Serialize input
let input_val = serde_json::to_value(&body)
.map_err(WorkflowError::SerializeSignalBody)
.map_err(GlobalError::raw)?;
tracing::debug!(id=%self.workflow_id, signal_name=%signal.name, signal_id=%signal.signal_id, "replaying signal dispatch");

self.db
.publish_signal(self.ray_id, workflow_id, signal_id, T::NAME, input_val)
.await
.map_err(GlobalError::raw)?;
signal.signal_id
}
// Send signal
else {
let signal_id = Uuid::new_v4();
tracing::info!(id=%self.workflow_id, signal_name=%T::NAME, to_workflow_id=%workflow_id, %signal_id, "dispatching signal");

// Serialize input
let input_val = serde_json::to_value(&body)
.map_err(WorkflowError::SerializeSignalBody)
.map_err(GlobalError::raw)?;

self.db
.publish_signal_from_workflow(
self.workflow_id,
self.full_location().as_ref(),
self.ray_id,
workflow_id,
signal_id,
T::NAME,
input_val,
)
.await
.map_err(GlobalError::raw)?;

signal_id
};

// Move to next event
self.location_idx += 1;

Ok(signal_id)
}
Expand All @@ -692,19 +724,48 @@ impl WorkflowCtx {
tags: &serde_json::Value,
body: T,
) -> GlobalResult<Uuid> {
let signal_id = Uuid::new_v4();
let event = { self.relevant_history().nth(self.location_idx) };

tracing::debug!(name=%T::NAME, ?tags, %signal_id, "dispatching tagged signal");
// Signal sent before
let signal_id = if let Some(event) = event {
// Validate history is consistent
let Event::SignalSend(signal) = event else {
return Err(WorkflowError::HistoryDiverged).map_err(GlobalError::raw);
};

// Serialize input
let input_val = serde_json::to_value(&body)
.map_err(WorkflowError::SerializeSignalBody)
.map_err(GlobalError::raw)?;
tracing::debug!(id=%self.workflow_id, signal_name=%signal.name, signal_id=%signal.signal_id, "replaying tagged signal dispatch");

self.db
.publish_tagged_signal(self.ray_id, tags, signal_id, T::NAME, input_val)
.await
.map_err(GlobalError::raw)?;
signal.signal_id
}
// Send signal
else {
let signal_id = Uuid::new_v4();

tracing::debug!(name=%T::NAME, ?tags, %signal_id, "dispatching tagged signal");

// Serialize input
let input_val = serde_json::to_value(&body)
.map_err(WorkflowError::SerializeSignalBody)
.map_err(GlobalError::raw)?;

self.db
.publish_tagged_signal_from_workflow(
self.workflow_id,
self.full_location().as_ref(),
self.ray_id,
tags,
signal_id,
T::NAME,
input_val,
)
.await
.map_err(GlobalError::raw)?;

signal_id
};

// Move to next event
self.location_idx += 1;

Ok(signal_id)
}
Expand Down Expand Up @@ -785,6 +846,98 @@ impl WorkflowCtx {
Ok(signal)
}

pub async fn msg<M>(&mut self, tags: serde_json::Value, body: M) -> GlobalResult<()>
where
M: Message,
{
let event = { self.relevant_history().nth(self.location_idx) };

// Message sent before
if let Some(event) = event {
// Validate history is consistent
let Event::MessageSend(msg) = event else {
return Err(WorkflowError::HistoryDiverged).map_err(GlobalError::raw);
};

tracing::debug!(id=%self.workflow_id, msg_name=%msg.name, "replaying message dispatch");
}
// Send message
else {
tracing::info!(id=%self.workflow_id, msg_name=%M::NAME, ?tags, "dispatching message");

// Serialize body
let body_val = serde_json::to_value(&body)
.map_err(WorkflowError::SerializeWorkflowOutput)
.map_err(GlobalError::raw)?;
let location = self.full_location();

let (msg, write) = tokio::join!(
self.db.publish_message_from_workflow(
self.workflow_id,
location.as_ref(),
&tags,
M::NAME,
body_val
),
self.msg_ctx.message(tags.clone(), body),
);

msg.map_err(GlobalError::raw)?;
write.map_err(GlobalError::raw)?;
}

// Move to next event
self.location_idx += 1;

Ok(())
}

pub async fn msg_wait<M>(&mut self, tags: serde_json::Value, body: M) -> GlobalResult<()>
where
M: Message,
{
let event = { self.relevant_history().nth(self.location_idx) };

// Message sent before
if let Some(event) = event {
// Validate history is consistent
let Event::MessageSend(msg) = event else {
return Err(WorkflowError::HistoryDiverged).map_err(GlobalError::raw);
};

tracing::debug!(id=%self.workflow_id, msg_name=%msg.name, "replaying message dispatch");
}
// Send message
else {
tracing::info!(id=%self.workflow_id, msg_name=%M::NAME, ?tags, "dispatching message");

// Serialize body
let body_val = serde_json::to_value(&body)
.map_err(WorkflowError::SerializeWorkflowOutput)
.map_err(GlobalError::raw)?;
let location = self.full_location();

let (msg, write) = tokio::join!(
self.db.publish_message_from_workflow(
self.workflow_id,
location.as_ref(),
&tags,
M::NAME,
body_val
),
self.msg_ctx.message_wait(tags.clone(), body),
);

msg.map_err(GlobalError::raw)?;
write.map_err(GlobalError::raw)?;
}

// Move to next event
self.location_idx += 1;

Ok(())
}

// TODO: sleep_for, sleep_until
}

Expand Down
Loading

0 comments on commit 0003acc

Please sign in to comment.