Skip to content

Commit

Permalink
feat(workflows): add tags
Browse files Browse the repository at this point in the history
  • Loading branch information
MasterPtato committed Jul 3, 2024
1 parent 8413349 commit 050db41
Show file tree
Hide file tree
Showing 15 changed files with 440 additions and 41 deletions.
85 changes: 84 additions & 1 deletion lib/chirp-workflow/core/src/compat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,43 @@ where

db_from_ctx(ctx)
.await?
.dispatch_workflow(ctx.ray_id(), id, &name, input_val)
.dispatch_workflow(ctx.ray_id(), id, &name, None, input_val)
.await
.map_err(GlobalError::raw)?;

tracing::info!(%name, ?id, "workflow dispatched");

Ok(id)
}

pub async fn dispatch_tagged_workflow<I, B>(
ctx: &rivet_operation::OperationContext<B>,
tags: &serde_json::Value,
input: I,
) -> GlobalResult<Uuid>
where
I: WorkflowInput,
<I as WorkflowInput>::Workflow: Workflow<Input = I>,
B: Debug + Clone,
{
if ctx.from_workflow {
bail!("cannot dispatch a workflow from an operation within a workflow execution. trigger it from the workflow's body.");
}

let name = I::Workflow::NAME;

tracing::debug!(%name, ?input, "dispatching workflow");

let id = Uuid::new_v4();

// Serialize input
let input_val = serde_json::to_value(input)
.map_err(WorkflowError::SerializeWorkflowOutput)
.map_err(GlobalError::raw)?;

db_from_ctx(ctx)
.await?
.dispatch_workflow(ctx.ray_id(), id, &name, Some(tags), input_val)
.await
.map_err(GlobalError::raw)?;

Expand Down Expand Up @@ -92,6 +128,26 @@ where
.await?
}

/// Dispatch a new workflow and wait for it to complete. Has a 60s timeout.
pub async fn tagged_workflow<I, B>(
ctx: &rivet_operation::OperationContext<B>,
tags: &serde_json::Value,
input: I,
) -> GlobalResult<<<I as WorkflowInput>::Workflow as Workflow>::Output>
where
I: WorkflowInput,
<I as WorkflowInput>::Workflow: Workflow<Input = I>,
B: Debug + Clone,
{
let workflow_id = dispatch_tagged_workflow(ctx, tags, input).await?;

tokio::time::timeout(
WORKFLOW_TIMEOUT,
wait_for_workflow::<I::Workflow, _>(ctx, workflow_id),
)
.await?
}

pub async fn signal<I: Signal + Serialize, B: Debug + Clone>(
ctx: &rivet_operation::OperationContext<B>,
workflow_id: Uuid,
Expand Down Expand Up @@ -119,6 +175,33 @@ pub async fn signal<I: Signal + Serialize, B: Debug + Clone>(
Ok(signal_id)
}

pub async fn tagged_signal<I: Signal + Serialize, B: Debug + Clone>(
ctx: &rivet_operation::OperationContext<B>,
tags: &serde_json::Value,
input: I,
) -> GlobalResult<Uuid> {
if ctx.from_workflow {
bail!("cannot dispatch a signal from an operation within a workflow execution. trigger it from the workflow's body.");
}

tracing::debug!(name=%I::NAME, ?tags, "dispatching tagged signal");

let signal_id = Uuid::new_v4();

// Serialize input
let input_val = serde_json::to_value(input)
.map_err(WorkflowError::SerializeSignalBody)
.map_err(GlobalError::raw)?;

db_from_ctx(ctx)
.await?
.publish_tagged_signal(ctx.ray_id(), tags, signal_id, I::NAME, input_val)
.await
.map_err(GlobalError::raw)?;

Ok(signal_id)
}

pub async fn op<I, B>(
ctx: &rivet_operation::OperationContext<B>,
input: I,
Expand Down
69 changes: 68 additions & 1 deletion lib/chirp-workflow/core/src/ctx/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,33 @@ impl ApiCtx {
.map_err(GlobalError::raw)?;

self.db
.dispatch_workflow(self.ray_id, id, &name, input_val)
.dispatch_workflow(self.ray_id, id, &name, None, input_val)
.await
.map_err(GlobalError::raw)?;

tracing::info!(%name, ?id, "workflow dispatched");

Ok(id)
}

pub async fn dispatch_tagged_workflow<I>(&self, tags: &serde_json::Value, input: I) -> GlobalResult<Uuid>
where
I: WorkflowInput,
<I as WorkflowInput>::Workflow: Workflow<Input = I>,
{
let name = I::Workflow::NAME;

tracing::debug!(%name, ?tags, ?input, "dispatching tagged workflow");

let id = Uuid::new_v4();

// Serialize input
let input_val = serde_json::to_value(input)
.map_err(WorkflowError::SerializeWorkflowOutput)
.map_err(GlobalError::raw)?;

self.db
.dispatch_workflow(self.ray_id, id, &name, Some(tags), input_val)
.await
.map_err(GlobalError::raw)?;

Expand Down Expand Up @@ -128,6 +154,25 @@ impl ApiCtx {
.await?
}

/// Dispatch a new workflow with tags and wait for it to complete. Has a 60s timeout.
pub async fn tagged_workflow<I>(
&self,
tags: &serde_json::Value,
input: I,
) -> GlobalResult<<<I as WorkflowInput>::Workflow as Workflow>::Output>
where
I: WorkflowInput,
<I as WorkflowInput>::Workflow: Workflow<Input = I>,
{
let workflow_id = self.dispatch_tagged_workflow(tags, input).await?;

tokio::time::timeout(
WORKFLOW_TIMEOUT,
self.wait_for_workflow::<I::Workflow>(workflow_id),
)
.await?
}

pub async fn signal<T: Signal + Serialize>(
&self,
workflow_id: Uuid,
Expand All @@ -150,6 +195,28 @@ impl ApiCtx {
Ok(signal_id)
}

pub async fn tagged_signal<T: Signal + Serialize>(
&self,
tags: &serde_json::Value,
input: T,
) -> GlobalResult<Uuid> {
tracing::debug!(name=%T::NAME, ?tags, "dispatching tagged signal");

let signal_id = Uuid::new_v4();

// Serialize input
let input_val = serde_json::to_value(input)
.map_err(WorkflowError::SerializeSignalBody)
.map_err(GlobalError::raw)?;

self.db
.publish_tagged_signal(self.ray_id, tags, signal_id, T::NAME, input_val)
.await
.map_err(GlobalError::raw)?;

Ok(signal_id)
}

pub async fn op<I>(
&self,
input: I,
Expand Down
64 changes: 63 additions & 1 deletion lib/chirp-workflow/core/src/ctx/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,33 @@ impl TestCtx {
.map_err(GlobalError::raw)?;

self.db
.dispatch_workflow(self.ray_id, id, &name, input_val)
.dispatch_workflow(self.ray_id, id, &name, None, input_val)
.await
.map_err(GlobalError::raw)?;

tracing::info!(%name, ?id, "workflow dispatched");

Ok(id)
}

pub async fn dispatch_tagged_workflow<I>(&self, tags: &serde_json::Value, input: I) -> GlobalResult<Uuid>
where
I: WorkflowInput,
<I as WorkflowInput>::Workflow: Workflow<Input = I>,
{
let name = I::Workflow::NAME;

tracing::debug!(%name, ?tags, ?input, "dispatching tagged workflow");

let id = Uuid::new_v4();

// Serialize input
let input_val = serde_json::to_value(input)
.map_err(WorkflowError::SerializeWorkflowOutput)
.map_err(GlobalError::raw)?;

self.db
.dispatch_workflow(self.ray_id, id, &name, Some(tags), input_val)
.await
.map_err(GlobalError::raw)?;

Expand Down Expand Up @@ -137,6 +163,20 @@ impl TestCtx {
Ok(output)
}

pub async fn tagged_workflow<I>(
&self,
tags: &serde_json::Value,
input: I,
) -> GlobalResult<<<I as WorkflowInput>::Workflow as Workflow>::Output>
where
I: WorkflowInput,
<I as WorkflowInput>::Workflow: Workflow<Input = I>,
{
let workflow_id = self.dispatch_tagged_workflow(tags, input).await?;
let output = self.wait_for_workflow::<I::Workflow>(workflow_id).await?;
Ok(output)
}

pub async fn signal<T: Signal + Serialize>(
&self,
workflow_id: Uuid,
Expand All @@ -159,6 +199,28 @@ impl TestCtx {
Ok(signal_id)
}

pub async fn tagged_signal<T: Signal + Serialize>(
&self,
tags: &serde_json::Value,
input: T,
) -> GlobalResult<Uuid> {
tracing::debug!(name=%T::NAME, ?tags, "dispatching tagged signal");

let signal_id = Uuid::new_v4();

// Serialize input
let input_val = serde_json::to_value(input)
.map_err(WorkflowError::SerializeSignalBody)
.map_err(GlobalError::raw)?;

self.db
.publish_tagged_signal(self.ray_id, tags, signal_id, T::NAME, input_val)
.await
.map_err(GlobalError::raw)?;

Ok(signal_id)
}

/// Waits for a workflow to be triggered with a superset of given input. Strictly for tests.
pub fn observe<W: Workflow>(&self, input: serde_json::Value) -> GlobalResult<ObserveHandle> {
// Serialize input
Expand Down
Loading

0 comments on commit 050db41

Please sign in to comment.