diff --git a/lib/chirp-workflow/core/src/ctx/activity.rs b/lib/chirp-workflow/core/src/ctx/activity.rs
index 858f854141..ec488fa226 100644
--- a/lib/chirp-workflow/core/src/ctx/activity.rs
+++ b/lib/chirp-workflow/core/src/ctx/activity.rs
@@ -2,12 +2,7 @@ use global_error::{GlobalError, GlobalResult};
 use rivet_pools::prelude::*;
 use uuid::Uuid;
 
-use crate::{
-	ctx::{MessageCtx, OperationCtx},
-	error::{WorkflowError, WorkflowResult},
-	message::Message,
-	DatabaseHandle, Operation, OperationInput,
-};
+use crate::{ctx::OperationCtx, error::WorkflowError, DatabaseHandle, Operation, OperationInput};
 
 #[derive(Clone)]
 pub struct ActivityCtx {
@@ -19,21 +14,20 @@ pub struct ActivityCtx {
 	db: DatabaseHandle,
 	conn: rivet_connection::Connection,
 
-	msg_ctx: MessageCtx,
-
 	// Backwards compatibility
 	op_ctx: rivet_operation::OperationContext<()>,
 }
 
 impl ActivityCtx {
-	pub async fn new(
+	pub fn new(
 		workflow_id: Uuid,
 		db: DatabaseHandle,
 		conn: &rivet_connection::Connection,
 		activity_create_ts: i64,
 		ray_id: Uuid,
 		name: &'static str,
-	) -> WorkflowResult<Self> {
+	) -> Self {
 		let ts = rivet_util::timestamp::now();
 		let req_id = Uuid::new_v4();
 		let conn = conn.wrap(req_id, ray_id, name);
@@ -49,9 +43,7 @@ impl ActivityCtx {
 		);
 		op_ctx.from_workflow = true;
 
-		let msg_ctx = MessageCtx::new(&conn, req_id, ray_id).await?;
-
-		Ok(ActivityCtx {
+		ActivityCtx {
 			workflow_id,
 			ray_id,
 			name,
@@ -59,8 +51,7 @@ impl ActivityCtx {
 			db,
 			conn,
 			op_ctx,
-			msg_ctx,
-		})
+		}
 	}
 }
 
@@ -95,26 +86,6 @@ impl ActivityCtx {
 		.await
 		.map_err(GlobalError::raw)
 	}
-
-	pub async fn msg<M>(&self, tags: serde_json::Value, body: M) -> GlobalResult<()>
-	where
-		M: Message,
-	{
-		self.msg_ctx
-			.message(tags, body)
-			.await
-			.map_err(GlobalError::raw)
-	}
-
-	pub async fn msg_wait<M>(&self, tags: serde_json::Value, body: M) -> GlobalResult<()>
-	where
-		M: Message,
-	{
-		self.msg_ctx
-			.message_wait(tags, body)
-			.await
-			.map_err(GlobalError::raw)
-	}
 }
 
 impl ActivityCtx {
diff --git a/lib/chirp-workflow/core/src/ctx/workflow.rs b/lib/chirp-workflow/core/src/ctx/workflow.rs
index a6dbd78f87..a85ae55f54 100644
--- a/lib/chirp-workflow/core/src/ctx/workflow.rs
+++ b/lib/chirp-workflow/core/src/ctx/workflow.rs
@@ -7,10 +7,12 @@ use uuid::Uuid;
 
 use crate::{
 	activity::ActivityId,
+	ctx::{ActivityCtx, MessageCtx},
 	event::Event,
-	util::{self, Location},
-	Activity, ActivityCtx, ActivityInput, DatabaseHandle, Executable, Listen, PulledWorkflow,
-	RegistryHandle, Signal, SignalRow, Workflow, WorkflowError, WorkflowInput, WorkflowResult,
+	message::Message,
+	util::Location,
+	Activity, ActivityInput, DatabaseHandle, Executable, Listen, PulledWorkflow, RegistryHandle,
+	Signal, SignalRow, Workflow, WorkflowError, WorkflowInput, WorkflowResult,
 };
 
 // Time to delay a worker from retrying after an error
@@ -55,15 +57,19 @@ pub struct WorkflowCtx {
 	root_location: Location,
 	location_idx: usize,
+
+	msg_ctx: MessageCtx,
 }
 
 impl WorkflowCtx {
-	pub fn new(
+	pub async fn new(
 		registry: RegistryHandle,
 		db: DatabaseHandle,
 		conn: rivet_connection::Connection,
 		workflow: PulledWorkflow,
 	) -> GlobalResult<Self> {
+		let msg_ctx = MessageCtx::new(&conn, workflow.workflow_id, workflow.ray_id).await?;
+
 		Ok(WorkflowCtx {
 			workflow_id: workflow.workflow_id,
 			name: workflow.workflow_name,
@@ -77,18 +83,13 @@ impl WorkflowCtx {
 
 			conn,
 
-			event_history: Arc::new(
-				util::combine_events(
-					workflow.activity_events,
-					workflow.signal_events,
-					workflow.sub_workflow_events,
-				)
-				.map_err(GlobalError::raw)?,
-			),
+			event_history: Arc::new(workflow.events),
 			input: Arc::new(workflow.input),
 			root_location: Box::new([]),
 			location_idx: 0,
+
+			msg_ctx,
 		})
 	}
 
@@ -117,6 +118,8 @@ impl WorkflowCtx {
 				.chain(std::iter::once(self.location_idx))
 				.collect(),
 			location_idx: 0,
+
+			msg_ctx: self.msg_ctx.clone(),
 		};
 
 		self.location_idx += 1;
 
@@ -258,8 +261,7 @@ impl WorkflowCtx {
 			self.create_ts,
 			self.ray_id,
 			A::NAME,
-		)
-		.await?;
+		);
 
 		let res = tokio::time::timeout(A::TIMEOUT, A::run(&ctx, input))
 			.await
@@ -568,6 +570,8 @@
 				.chain(std::iter::once(self.location_idx))
 				.collect(),
 			location_idx: 0,
+
+			msg_ctx: self.msg_ctx.clone(),
 		};
 
 		self.location_idx += 1;
 
@@ -669,19 +673,47 @@ impl WorkflowCtx {
 		workflow_id: Uuid,
 		body: T,
 	) -> GlobalResult<Uuid> {
-		let signal_id = Uuid::new_v4();
+		let event = { self.relevant_history().nth(self.location_idx) };
 
-		tracing::info!(name=%T::NAME, %workflow_id, %signal_id, "dispatching signal");
+		// Signal sent before
+		let signal_id = if let Some(event) = event {
+			// Validate history is consistent
+			let Event::SignalSend(signal) = event else {
+				return Err(WorkflowError::HistoryDiverged).map_err(GlobalError::raw);
+			};
 
-		// Serialize input
-		let input_val = serde_json::to_value(&body)
-			.map_err(WorkflowError::SerializeSignalBody)
-			.map_err(GlobalError::raw)?;
+			tracing::debug!(id=%self.workflow_id, signal_name=%signal.name, signal_id=%signal.signal_id, "replaying signal dispatch");
 
-		self.db
-			.publish_signal(self.ray_id, workflow_id, signal_id, T::NAME, input_val)
-			.await
-			.map_err(GlobalError::raw)?;
+			signal.signal_id
+		}
+		// Send signal
+		else {
+			let signal_id = Uuid::new_v4();
+			tracing::info!(id=%self.workflow_id, signal_name=%T::NAME, to_workflow_id=%workflow_id, %signal_id, "dispatching signal");
+
+			// Serialize input
+			let input_val = serde_json::to_value(&body)
+				.map_err(WorkflowError::SerializeSignalBody)
+				.map_err(GlobalError::raw)?;
+
+			self.db
+				.publish_signal_from_workflow(
+					self.workflow_id,
+					self.full_location().as_ref(),
+					self.ray_id,
+					workflow_id,
+					signal_id,
+					T::NAME,
+					input_val,
+				)
+				.await
+				.map_err(GlobalError::raw)?;
+
+			signal_id
+		};
+
+		// Move to next event
+		self.location_idx += 1;
 
 		Ok(signal_id)
 	}
@@ -692,19 +724,48 @@
 		tags: &serde_json::Value,
 		body: T,
 	) -> GlobalResult<Uuid> {
-		let signal_id = Uuid::new_v4();
+		let event = { self.relevant_history().nth(self.location_idx) };
 
-		tracing::debug!(name=%T::NAME, ?tags, %signal_id, "dispatching tagged signal");
+		// Signal sent before
+		let signal_id = if let Some(event) = event {
+			// Validate history is consistent
+			let Event::SignalSend(signal) = event else {
+				return Err(WorkflowError::HistoryDiverged).map_err(GlobalError::raw);
+			};
 
-		// Serialize input
-		let input_val = serde_json::to_value(&body)
-			.map_err(WorkflowError::SerializeSignalBody)
-			.map_err(GlobalError::raw)?;
+			tracing::debug!(id=%self.workflow_id, signal_name=%signal.name, signal_id=%signal.signal_id, "replaying tagged signal dispatch");
 
-		self.db
-			.publish_tagged_signal(self.ray_id, tags, signal_id, T::NAME, input_val)
-			.await
-			.map_err(GlobalError::raw)?;
+			signal.signal_id
+		}
+		// Send signal
+		else {
+			let signal_id = Uuid::new_v4();
+
+			tracing::debug!(name=%T::NAME, ?tags, %signal_id, "dispatching tagged signal");
+
+			// Serialize input
+			let input_val = serde_json::to_value(&body)
+				.map_err(WorkflowError::SerializeSignalBody)
+				.map_err(GlobalError::raw)?;
+
+			self.db
+				.publish_tagged_signal_from_workflow(
+					self.workflow_id,
+					self.full_location().as_ref(),
+					self.ray_id,
+					tags,
+					signal_id,
+					T::NAME,
+					input_val,
+				)
+				.await
+				.map_err(GlobalError::raw)?;
+
+			signal_id
+		};
+
+		// Move to next event
+		self.location_idx += 1;
 
 		Ok(signal_id)
 	}
@@ -785,6 +846,98 @@ impl WorkflowCtx {
 
 		Ok(signal)
 	}
 
+	pub async fn msg<M>(&mut self, tags: serde_json::Value, body: M) -> GlobalResult<()>
+	where
+		M: Message,
+	{
+		let event = { self.relevant_history().nth(self.location_idx) };
+
+		// Message sent before
+		if let Some(event) = event {
+			// Validate history is consistent
+			let Event::MessageSend(msg) = event else {
+				return Err(WorkflowError::HistoryDiverged).map_err(GlobalError::raw);
+			};
+
+			tracing::debug!(id=%self.workflow_id, msg_name=%msg.name, "replaying message dispatch");
+		}
+		// Send message
+		else {
+			tracing::info!(id=%self.workflow_id, msg_name=%M::NAME, ?tags, "dispatching message");
+
+			// Serialize body
+			let body_val = serde_json::to_value(&body)
+				.map_err(WorkflowError::SerializeWorkflowOutput)
+				.map_err(GlobalError::raw)?;
+			let location = self.full_location();
+
+			let (msg, write) = tokio::join!(
+				self.db.publish_message_from_workflow(
+					self.workflow_id,
+					location.as_ref(),
+					&tags,
+					M::NAME,
+					body_val
+				),
+				self.msg_ctx.message(tags.clone(), body),
+			);
+
+			msg.map_err(GlobalError::raw)?;
+			write.map_err(GlobalError::raw)?;
+		}
+
+		// Move to next event
+		self.location_idx += 1;
+
+		Ok(())
+	}
+
+	pub async fn msg_wait<M>(&mut self, tags: serde_json::Value, body: M) -> GlobalResult<()>
+	where
+		M: Message,
+	{
+		let event = { self.relevant_history().nth(self.location_idx) };
+
+		// Message sent before
+		if let Some(event) = event {
+			// Validate history is consistent
+			let Event::MessageSend(msg) = event else {
+				return Err(WorkflowError::HistoryDiverged).map_err(GlobalError::raw);
+			};
+
+			tracing::debug!(id=%self.workflow_id, msg_name=%msg.name, "replaying message dispatch");
+		}
+		// Send message
+		else {
+			tracing::info!(id=%self.workflow_id, msg_name=%M::NAME, ?tags, "dispatching message");
+
+			// Serialize body
+			let body_val = serde_json::to_value(&body)
+				.map_err(WorkflowError::SerializeWorkflowOutput)
+				.map_err(GlobalError::raw)?;
+			let location = self.full_location();
+
+			let (msg, write) = tokio::join!(
+				self.db.publish_message_from_workflow(
+					self.workflow_id,
+					location.as_ref(),
+					&tags,
+					M::NAME,
+					body_val
+				),
+				self.msg_ctx.message_wait(tags.clone(), body),
+			);
+
+			msg.map_err(GlobalError::raw)?;
+			write.map_err(GlobalError::raw)?;
+		}
+
+		// Move to next event
+		self.location_idx += 1;
+
+		Ok(())
+	}
+
 	// TODO: sleep_for, sleep_until
 }
diff --git a/lib/chirp-workflow/core/src/db/mod.rs b/lib/chirp-workflow/core/src/db/mod.rs
index e929d9ba74..ee5097efe0 100644
--- a/lib/chirp-workflow/core/src/db/mod.rs
+++ b/lib/chirp-workflow/core/src/db/mod.rs
@@ -1,8 +1,10 @@
-use std::sync::Arc;
+use std::{collections::HashMap, sync::Arc};
 
 use uuid::Uuid;
 
-use crate::{activity::ActivityId, Workflow, WorkflowError, WorkflowResult};
+use crate::{
+	activity::ActivityId, event::Event, util::Location, Workflow, WorkflowError, WorkflowResult,
+};
 
 mod postgres;
 pub use postgres::DatabasePostgres;
@@ -58,6 +60,12 @@ pub trait Database: Send {
 		output: Result<serde_json::Value, &str>,
 	) -> WorkflowResult<()>;
 
+	async fn pull_next_signal(
+		&self,
+		workflow_id: Uuid,
+		filter: &[&str],
+		location: &[usize],
+	) -> WorkflowResult<Option<SignalRow>>;
 	async fn publish_signal(
 		&self,
 		ray_id: Uuid,
@@ -74,12 +82,26 @@ pub trait Database: Send {
 		signal_name: &str,
 		body: serde_json::Value,
 	) -> WorkflowResult<()>;
-	async fn pull_next_signal(
+	async fn publish_signal_from_workflow(
 		&self,
+		from_workflow_id: Uuid,
+		location: &[usize],
+		ray_id: Uuid,
 		workflow_id: Uuid,
-		filter: &[&str],
+		signal_id: Uuid,
+		signal_name: &str,
+		body: serde_json::Value,
+	) -> WorkflowResult<()>;
+	async fn publish_tagged_signal_from_workflow(
+		&self,
+		from_workflow_id: Uuid,
 		location: &[usize],
-	) -> WorkflowResult<Option<SignalRow>>;
+		ray_id: Uuid,
+		tags: &serde_json::Value,
+		signal_id: Uuid,
+		signal_name: &str,
+		body: serde_json::Value,
+	) -> WorkflowResult<()>;
 
 	async fn dispatch_sub_workflow(
 		&self,
@@ -99,6 +121,15 @@
 		input: &serde_json::Value,
 		after_ts: i64,
 	) -> WorkflowResult<Option<(Uuid, i64)>>;
+
+	async fn publish_message_from_workflow(
+		&self,
+		from_workflow_id: Uuid,
+		location: &[usize],
+		tags: &serde_json::Value,
+		message_name: &str,
+		body: serde_json::Value,
+	) -> WorkflowResult<()>;
 }
 
 #[derive(sqlx::FromRow)]
@@ -136,9 +167,7 @@ pub struct PulledWorkflow {
 	pub input: serde_json::Value,
 	pub wake_deadline_ts: Option<i64>,
 
-	pub activity_events: Vec<ActivityEventRow>,
-	pub signal_events: Vec<SignalEventRow>,
-	pub sub_workflow_events: Vec<SubWorkflowEventRow>,
+	pub events: HashMap<Location, Vec<Event>>,
 }
 
 #[derive(sqlx::FromRow)]
@@ -160,6 +189,21 @@ pub struct SignalEventRow {
 	pub body: serde_json::Value,
 }
 
+#[derive(sqlx::FromRow)]
+pub struct SignalSendEventRow {
+	pub workflow_id: Uuid,
+	pub location: Vec<i64>,
+	pub signal_id: Uuid,
+	pub signal_name: String,
+}
+
+#[derive(sqlx::FromRow)]
+pub struct MessageSendEventRow {
+	pub workflow_id: Uuid,
+	pub location: Vec<i64>,
+	pub message_name: String,
+}
+
 #[derive(sqlx::FromRow)]
 pub struct SubWorkflowEventRow {
 	pub workflow_id: Uuid,
diff --git a/lib/chirp-workflow/core/src/db/postgres.rs b/lib/chirp-workflow/core/src/db/postgres.rs
index e057db1692..7fa70d3fc2 100644
--- a/lib/chirp-workflow/core/src/db/postgres.rs
+++ b/lib/chirp-workflow/core/src/db/postgres.rs
@@ -1,14 +1,14 @@
-use std::{collections::HashMap, sync::Arc, time::Duration};
+use std::{sync::Arc, time::Duration};
 
 use indoc::indoc;
 use sqlx::{pool::PoolConnection, PgPool, Postgres};
 use uuid::Uuid;
 
 use super::{
-	ActivityEventRow, Database, PulledWorkflow, PulledWorkflowRow, SignalEventRow, SignalRow,
-	SubWorkflowEventRow, WorkflowRow,
+	ActivityEventRow, Database, MessageSendEventRow, PulledWorkflow, PulledWorkflowRow,
+	SignalEventRow, SignalRow, SignalSendEventRow, SubWorkflowEventRow, WorkflowRow,
 };
-use crate::{activity::ActivityId, WorkflowError, WorkflowResult};
+use crate::{activity::ActivityId, util, WorkflowError, WorkflowResult};
 
 pub struct DatabasePostgres {
 	pool: PgPool,
@@ -106,7 +106,7 @@ impl Database for DatabasePostgres {
 	) -> WorkflowResult<Vec<PulledWorkflow>> {
 		// TODO(RVT-3753): include limit on query to allow better workflow spread between nodes?
 		// Select all workflows that haven't started or that have a wake condition
-		let rows = sqlx::query_as::<_, PulledWorkflowRow>(indoc!(
+		let workflow_rows = sqlx::query_as::<_, PulledWorkflowRow>(indoc!(
 			"
 			WITH
 				pull_workflows AS (
@@ -173,34 +173,24 @@ impl Database for DatabasePostgres {
 		.await
 		.map_err(WorkflowError::Sqlx)?;
 
-		if rows.is_empty() {
+		if workflow_rows.is_empty() {
 			return Ok(Vec::new());
 		}
 
-		// Turn rows into hashmap
-		let workflow_ids = rows.iter().map(|row| row.workflow_id).collect::<Vec<_>>();
-		let mut workflows_by_id = rows
-			.into_iter()
-			.map(|row| {
-				(
-					row.workflow_id,
-					PulledWorkflow {
-						workflow_id: row.workflow_id,
-						workflow_name: row.workflow_name,
-						create_ts: row.create_ts,
-						ray_id: row.ray_id,
-						input: row.input,
-						wake_deadline_ts: row.wake_deadline_ts,
-						activity_events: Vec::new(),
-						signal_events: Vec::new(),
-						sub_workflow_events: Vec::new(),
-					},
-				)
-			})
-			.collect::<HashMap<_, _>>();
+		// Collect workflow ids for the event queries below
+		let workflow_ids = workflow_rows
+			.iter()
+			.map(|row| row.workflow_id)
+			.collect::<Vec<_>>();
 
 		// Fetch all events for all fetched workflows
-		let (activity_events, signal_events, sub_workflow_events) = tokio::try_join!(
+		let (
+			activity_events,
+			signal_events,
+			signal_send_events,
+			msg_send_events,
+			sub_workflow_events,
+		) = tokio::try_join!(
 			async {
 				sqlx::query_as::<_, ActivityEventRow>(indoc!(
 					"
@@ -242,6 +232,36 @@ impl Database for DatabasePostgres {
 				.await
 				.map_err(WorkflowError::Sqlx)
 			},
+			async {
+				sqlx::query_as::<_, SignalSendEventRow>(indoc!(
+					"
+					SELECT
+						workflow_id, location, signal_id, signal_name
+					FROM db_workflow.workflow_signal_send_events
+					WHERE workflow_id = ANY($1)
+					ORDER BY workflow_id, location ASC
+					",
+				))
+				.bind(&workflow_ids)
+				.fetch_all(&mut *self.conn().await?)
+				.await
+				.map_err(WorkflowError::Sqlx)
+			},
+			async {
+				sqlx::query_as::<_, MessageSendEventRow>(indoc!(
+					"
+					SELECT
+						workflow_id, location, message_name
+					FROM db_workflow.workflow_message_send_events
+					WHERE workflow_id = ANY($1)
+					ORDER BY workflow_id, location ASC
+					",
+				))
+				.bind(&workflow_ids)
+				.fetch_all(&mut *self.conn().await?)
+				.await
+				.map_err(WorkflowError::Sqlx)
+			},
 			async {
 				sqlx::query_as::<_, SubWorkflowEventRow>(indoc!(
 					"
@@ -264,30 +284,16 @@
 			}
 		)?;
 
-		// Insert events into hashmap
-		for event in activity_events {
-			workflows_by_id
-				.get_mut(&event.workflow_id)
-				.expect("unreachable, workflow for event not found")
-				.activity_events
-				.push(event);
-		}
-		for event in signal_events {
-			workflows_by_id
-				.get_mut(&event.workflow_id)
-				.expect("unreachable, workflow for event not found")
-				.signal_events
-				.push(event);
-		}
-		for event in sub_workflow_events {
-			workflows_by_id
-				.get_mut(&event.workflow_id)
-				.expect("unreachable, workflow for event not found")
-				.sub_workflow_events
-				.push(event);
-		}
+		let workflows = util::combine_events(
+			workflow_rows,
+			activity_events,
+			signal_events,
+			signal_send_events,
+			msg_send_events,
+			sub_workflow_events,
+		)?;
 
-		Ok(workflows_by_id.into_values().collect())
+		Ok(workflows)
 	}
 
 	async fn commit_workflow(
@@ -565,6 +571,92 @@
 
 		Ok(())
 	}
 
+	async fn publish_signal_from_workflow(
+		&self,
+		from_workflow_id: Uuid,
+		location: &[usize],
+		ray_id: Uuid,
+		to_workflow_id: Uuid,
+		signal_id: Uuid,
+		signal_name: &str,
+		body: serde_json::Value,
+	) -> WorkflowResult<()> {
+		sqlx::query(indoc!(
+			"
+			WITH
+				signal AS (
+					INSERT INTO db_workflow.signals (signal_id, workflow_id, signal_name, body, ray_id, create_ts)
+					VALUES ($1, $2, $3, $4, $5, $6)
+					RETURNING 1
+				),
+				send_event AS (
+					INSERT INTO db_workflow.workflow_signal_send_events(
+						workflow_id, location, signal_id, signal_name, body
+					)
+					VALUES($7, $8, $1, $3, $4)
+					RETURNING 1
+				)
+			SELECT 1
+			",
+		))
+		.bind(signal_id)
+		.bind(to_workflow_id)
+		.bind(signal_name)
+		.bind(body)
+		.bind(ray_id)
+		.bind(rivet_util::timestamp::now())
+		.bind(from_workflow_id)
+		.bind(location.iter().map(|x| *x as i64).collect::<Vec<_>>())
+		.execute(&mut *self.conn().await?)
+		.await
+		.map_err(WorkflowError::Sqlx)?;
+
+		Ok(())
+	}
+
+	async fn publish_tagged_signal_from_workflow(
+		&self,
+		from_workflow_id: Uuid,
+		location: &[usize],
+		ray_id: Uuid,
+		tags: &serde_json::Value,
+		signal_id: Uuid,
+		signal_name: &str,
+		body: serde_json::Value,
+	) -> WorkflowResult<()> {
+		sqlx::query(indoc!(
+			"
+			WITH
+				signal AS (
+					INSERT INTO db_workflow.tagged_signals (signal_id, tags, signal_name, body, ray_id, create_ts)
+					VALUES ($1, $2, $3, $4, $5, $6)
+					RETURNING 1
+				),
+				send_event AS (
+					INSERT INTO db_workflow.workflow_signal_send_events(
+						workflow_id, location, signal_id, signal_name, body
+					)
+					VALUES($7, $8, $1, $3, $4)
+					RETURNING 1
+				)
+			SELECT 1
+			",
+		))
+		.bind(signal_id)
+		.bind(tags)
+		.bind(signal_name)
+		.bind(body)
+		.bind(ray_id)
+		.bind(rivet_util::timestamp::now())
+		.bind(from_workflow_id)
+		.bind(location.iter().map(|x| *x as i64).collect::<Vec<_>>())
+		.execute(&mut *self.conn().await?)
+		.await
+		.map_err(WorkflowError::Sqlx)?;
+
+		Ok(())
+	}
+
 	async fn dispatch_sub_workflow(
 		&self,
@@ -634,4 +726,33 @@
 		.await
 		.map_err(WorkflowError::Sqlx)
 	}
+
+	async fn publish_message_from_workflow(
+		&self,
+		from_workflow_id: Uuid,
+		location: &[usize],
+		tags: &serde_json::Value,
+		message_name: &str,
+		body: serde_json::Value,
+	) -> WorkflowResult<()> {
+		sqlx::query(indoc!(
+			"
+			INSERT INTO db_workflow.workflow_message_send_events(
+				workflow_id, location, tags, message_name, body
+			)
+			VALUES($1, $2, $3, $4, $5)
+			RETURNING 1
+			",
+		))
+		.bind(from_workflow_id)
+		.bind(location.iter().map(|x| *x as i64).collect::<Vec<_>>())
+		.bind(tags)
+		.bind(message_name)
+		.bind(body)
+		.execute(&mut *self.conn().await?)
+		.await
+		.map_err(WorkflowError::Sqlx)?;
+
+		Ok(())
+	}
 }
diff --git a/lib/chirp-workflow/core/src/event.rs b/lib/chirp-workflow/core/src/event.rs
index 4646dba299..8d5c9e8a8d 100644
--- a/lib/chirp-workflow/core/src/event.rs
+++ b/lib/chirp-workflow/core/src/event.rs
@@ -2,8 +2,8 @@ use serde::de::DeserializeOwned;
 use uuid::Uuid;
 
 use crate::{
-	activity::ActivityId, ActivityEventRow, SignalEventRow, SubWorkflowEventRow, WorkflowError,
-	WorkflowResult,
+	activity::ActivityId, ActivityEventRow, MessageSendEventRow, SignalEventRow,
+	SignalSendEventRow, SubWorkflowEventRow, WorkflowError, WorkflowResult,
 };
 
 /// An event that happened in the workflow run.
@@ -13,6 +13,8 @@ use crate::{
 pub enum Event {
 	Activity(ActivityEvent),
 	Signal(SignalEvent),
+	SignalSend(SignalSendEvent),
+	MessageSend(MessageSendEvent),
 	SubWorkflow(SubWorkflowEvent),
 	// Used as a placeholder for branching locations
 	Branch,
@@ -71,6 +73,38 @@ impl TryFrom<SignalEventRow> for SignalEvent {
 	}
 }
 
+#[derive(Debug)]
+pub struct SignalSendEvent {
+	pub signal_id: Uuid,
+	pub name: String,
+}
+
+impl TryFrom<SignalSendEventRow> for SignalSendEvent {
+	type Error = WorkflowError;
+
+	fn try_from(value: SignalSendEventRow) -> WorkflowResult<Self> {
+		Ok(SignalSendEvent {
+			signal_id: value.signal_id,
+			name: value.signal_name,
+		})
+	}
+}
+
+#[derive(Debug)]
+pub struct MessageSendEvent {
+	pub name: String,
+}
+
+impl TryFrom<MessageSendEventRow> for MessageSendEvent {
+	type Error = WorkflowError;
+
+	fn try_from(value: MessageSendEventRow) -> WorkflowResult<Self> {
+		Ok(MessageSendEvent {
+			name: value.message_name,
+		})
+	}
+}
+
 #[derive(Debug)]
 pub struct SubWorkflowEvent {
 	pub sub_workflow_id: Uuid,
diff --git a/lib/chirp-workflow/core/src/util.rs b/lib/chirp-workflow/core/src/util.rs
index decd6166d6..cacb91551c 100644
--- a/lib/chirp-workflow/core/src/util.rs
+++ b/lib/chirp-workflow/core/src/util.rs
@@ -9,8 +9,8 @@ use tokio::time::{self, Duration};
 use uuid::Uuid;
 
 use crate::{
-	error::WorkflowError, event::Event, ActivityEventRow, SignalEventRow, SubWorkflowEventRow,
-	WorkflowResult,
+	error::WorkflowError, event::Event, ActivityEventRow, MessageSendEventRow, PulledWorkflow,
+	PulledWorkflowRow, SignalEventRow, SignalSendEventRow, SubWorkflowEventRow, WorkflowResult,
 };
 
 pub type Location = Box<[usize]>;
@@ -67,16 +67,29 @@ pub async fn sleep_until_ts(ts: i64) {
 /// 	],
 /// }
 pub fn combine_events(
+	workflow_rows: Vec<PulledWorkflowRow>,
 	activity_events: Vec<ActivityEventRow>,
 	signal_events: Vec<SignalEventRow>,
+	signal_send_events: Vec<SignalSendEventRow>,
+	msg_send_events: Vec<MessageSendEventRow>,
 	sub_workflow_events: Vec<SubWorkflowEventRow>,
-) -> WorkflowResult<HashMap<Location, Vec<Event>>> {
-	let mut events_by_location: HashMap<Location, Vec<(i64, Event)>> = HashMap::new();
+) -> WorkflowResult<Vec<PulledWorkflow>> {
+	let mut workflows_by_id = workflow_rows
+		.into_iter()
+		.map(|row| {
+			let events_by_location: HashMap<Location, Vec<(i64, Event)>> = HashMap::new();
+
+			(row.workflow_id, (row, events_by_location))
+		})
+		.collect::<HashMap<_, _>>();
 
 	for event in activity_events {
+		let (_, ref mut events_by_location) = workflows_by_id
+			.get_mut(&event.workflow_id)
+			.expect("unreachable, workflow for event not found");
 		let (root_location, idx) = split_location(&event.location);
 
-		insert_placeholder(&mut events_by_location, &event.location, idx);
+		insert_placeholder(events_by_location, &event.location, idx);
 
 		events_by_location
 			.entry(root_location)
@@ -85,9 +98,12 @@ pub fn combine_events(
 	}
 
 	for event in signal_events {
+		let (_, ref mut events_by_location) = workflows_by_id
+			.get_mut(&event.workflow_id)
+			.expect("unreachable, workflow for event not found");
 		let (root_location, idx) = split_location(&event.location);
 
-		insert_placeholder(&mut events_by_location, &event.location, idx);
+		insert_placeholder(events_by_location, &event.location, idx);
 
 		events_by_location
 			.entry(root_location)
@@ -95,33 +111,79 @@ pub fn combine_events(
 			.push((idx, Event::Signal(event.try_into()?)));
 	}
 
-	for event in sub_workflow_events {
+	for event in signal_send_events {
+		let (_, ref mut events_by_location) = workflows_by_id
+			.get_mut(&event.workflow_id)
+			.expect("unreachable, workflow for event not found");
 		let (root_location, idx) = split_location(&event.location);
 
-		insert_placeholder(&mut events_by_location, &event.location, idx);
+		insert_placeholder(events_by_location, &event.location, idx);
 
 		events_by_location
 			.entry(root_location)
 			.or_default()
-			.push((idx, Event::SubWorkflow(event.try_into()?)));
+			.push((idx, Event::SignalSend(event.try_into()?)));
 	}
 
-	// TODO(RVT-3754): This involves inserting, sorting, then recollecting into lists and recollecting into a
-	// hashmap. Could be improved by iterating over both lists simultaneously and sorting each item at a
-	// time before inserting
-	// Sort all of the events because we are inserting from two different lists. Both lists are already
-	// sorted themselves so this should be fairly cheap
-	for (_, list) in events_by_location.iter_mut() {
-		list.sort_by_key(|(idx, _)| *idx);
+	for event in msg_send_events {
+		let (_, ref mut events_by_location) = workflows_by_id
+			.get_mut(&event.workflow_id)
+			.expect("unreachable, workflow for event not found");
+		let (root_location, idx) = split_location(&event.location);
+
+		insert_placeholder(events_by_location, &event.location, idx);
+
+		events_by_location
+			.entry(root_location)
+			.or_default()
+			.push((idx, Event::MessageSend(event.try_into()?)));
 	}
 
-	// Remove idx from lists
-	let event_history = events_by_location
-		.into_iter()
-		.map(|(k, v)| (k, v.into_iter().map(|(_, v)| v).collect()))
+	for event in sub_workflow_events {
+		let (_, ref mut events_by_location) = workflows_by_id
+			.get_mut(&event.workflow_id)
+			.expect("unreachable, workflow for event not found");
+		let (root_location, idx) = split_location(&event.location);
+
+		insert_placeholder(events_by_location, &event.location, idx);
+
+		events_by_location
+			.entry(root_location)
+			.or_default()
+			.push((idx, Event::SubWorkflow(event.try_into()?)));
+	}
+
+	let workflows = workflows_by_id
+		.into_values()
+		.map(|(row, mut events_by_location)| {
+			// TODO(RVT-3754): This involves inserting, sorting, then recollecting into lists and recollecting into a
+			// hashmap. Could be improved by iterating over both lists simultaneously and sorting each item at a
+			// time before inserting
+			// Sort all of the events because we are inserting from two different lists. Both lists are already
+			// sorted themselves so this should be fairly cheap
+			for (_, list) in events_by_location.iter_mut() {
+				list.sort_by_key(|(idx, _)| *idx);
+			}
+
+			// Remove idx from lists
+			let event_history = events_by_location
+				.into_iter()
+				.map(|(k, v)| (k, v.into_iter().map(|(_, v)| v).collect()))
+				.collect();
+
+			PulledWorkflow {
+				workflow_id: row.workflow_id,
+				workflow_name: row.workflow_name,
+				create_ts: row.create_ts,
+				ray_id: row.ray_id,
+				input: row.input,
+				wake_deadline_ts: row.wake_deadline_ts,
+				events: event_history,
+			}
+		})
 		.collect();
 
-	Ok(event_history)
+	Ok(workflows)
 }
 
 fn split_location(location: &[i64]) -> (Location, i64) {
diff --git a/lib/chirp-workflow/core/src/worker.rs b/lib/chirp-workflow/core/src/worker.rs
index 3c10dd310c..61ac94d9a2 100644
--- a/lib/chirp-workflow/core/src/worker.rs
+++ b/lib/chirp-workflow/core/src/worker.rs
@@ -75,7 +75,8 @@ impl Worker {
 			&workflow.workflow_name,
 		);
 		let wake_deadline_ts = workflow.wake_deadline_ts;
-		let ctx = WorkflowCtx::new(self.registry.clone(), self.db.clone(), conn, workflow)?;
+		let ctx =
+			WorkflowCtx::new(self.registry.clone(), self.db.clone(), conn, workflow).await?;
 
 		tokio::task::spawn(
 			async move {
diff --git a/svc/pkg/workflow/db/workflow/migrations/20240711213725_signal_msg_events.down.sql b/svc/pkg/workflow/db/workflow/migrations/20240711213725_signal_msg_events.down.sql
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/svc/pkg/workflow/db/workflow/migrations/20240711213725_signal_msg_events.up.sql b/svc/pkg/workflow/db/workflow/migrations/20240711213725_signal_msg_events.up.sql
new file mode 100644
index 0000000000..0a6466e47d
--- /dev/null
+++ b/svc/pkg/workflow/db/workflow/migrations/20240711213725_signal_msg_events.up.sql
@@ -0,0 +1,21 @@
+-- Stores sent signals for replay
+CREATE TABLE workflow_signal_send_events (
+    workflow_id UUID NOT NULL REFERENCES workflows,
+    location INT[] NOT NULL,
+    signal_id UUID NOT NULL,
+    signal_name TEXT NOT NULL,
+    body JSONB NOT NULL,
+
+    PRIMARY KEY (workflow_id, location)
+);
+
+-- Stores sent messages for replay
+CREATE TABLE workflow_message_send_events (
+    workflow_id UUID NOT NULL REFERENCES workflows,
+    location INT[] NOT NULL,
+    tags JSONB NOT NULL,
+    message_name TEXT NOT NULL,
+    body JSONB NOT NULL,
+
+    PRIMARY KEY (workflow_id, location)
+);
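
Reviewer note, not part of the patch: the sketch below illustrates the replay contract that the new SignalSend/MessageSend events enforce. Only WorkflowCtx::signal, WorkflowCtx::msg, and WorkflowError::HistoryDiverged come from the diff above; the import, the OtherComplete/UserUpdate types, and the surrounding function are hypothetical stand-ins.

	// Hypothetical workflow body. Assumes `OtherComplete` implements Signal +
	// Serialize and `UserUpdate` implements Message, defined elsewhere in the
	// codebase, and that `use chirp_workflow::prelude::*;` brings the workflow
	// types into scope.
	async fn example_body(
		ctx: &mut WorkflowCtx,
		other_workflow_id: Uuid,
		user_id: Uuid,
	) -> GlobalResult<()> {
		// First execution: allocates a fresh signal_id, then writes the signal row
		// and the workflow_signal_send_events row in a single statement. On replay,
		// the SignalSend event recorded at this location is read back instead, so
		// the same signal_id is returned and nothing is published twice.
		let signal_id = ctx
			.signal(other_workflow_id, OtherComplete { ok: true })
			.await?;
		tracing::info!(%signal_id, "dispatched exactly once across retries");

		// Same idempotency for messages: the workflow_message_send_events row and
		// the pubsub publish happen only on the first execution; a replay just
		// advances the location cursor (location_idx).
		ctx.msg(serde_json::json!({ "user_id": user_id }), UserUpdate { user_id })
			.await?;

		Ok(())
	}

If the event recorded at a location does not match the call being replayed (for example a MessageSend event where a SignalSend is expected), the guards added in workflow.rs return WorkflowError::HistoryDiverged rather than silently re-sending.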