Skip to content

Commit

Permalink
fix(workflows): allow op ctx to do all the things
Browse files Browse the repository at this point in the history
  • Loading branch information
MasterPtato committed Feb 1, 2025
1 parent 9883dbd commit d541a3f
Show file tree
Hide file tree
Showing 11 changed files with 131 additions and 59 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,16 @@ use serde::Serialize;

use crate::{builder::BuilderError, ctx::MessageCtx, message::Message, metrics};

pub struct MessageBuilder<'a, M: Message> {
msg_ctx: &'a MessageCtx,
pub struct MessageBuilder<M: Message> {
msg_ctx: MessageCtx,
body: M,
tags: serde_json::Map<String, serde_json::Value>,
wait: bool,
error: Option<BuilderError>,
}

impl<'a, M: Message> MessageBuilder<'a, M> {
pub(crate) fn new(msg_ctx: &'a MessageCtx, body: M) -> Self {
impl<M: Message> MessageBuilder<M> {
pub(crate) fn new(msg_ctx: MessageCtx, body: M) -> Self {
MessageBuilder {
msg_ctx,
body,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,15 @@ pub struct SignalBuilder<T: Signal + Serialize> {
}

impl<T: Signal + Serialize> SignalBuilder<T> {
pub(crate) fn new(db: DatabaseHandle, ray_id: Uuid, body: T) -> Self {
pub(crate) fn new(db: DatabaseHandle, ray_id: Uuid, body: T, from_workflow: bool) -> Self {
SignalBuilder {
db,
ray_id,
body,
to_workflow_name: None,
to_workflow_id: None,
tags: serde_json::Map::new(),
error: None,
error: from_workflow.then_some(BuilderError::CannotDispatchFromOpInWorkflow),
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,14 @@ impl<I: WorkflowInput> WorkflowBuilder<I>
where
<I as WorkflowInput>::Workflow: Workflow<Input = I>,
{
pub(crate) fn new(db: DatabaseHandle, ray_id: Uuid, input: I) -> Self {
pub(crate) fn new(db: DatabaseHandle, ray_id: Uuid, input: I, from_workflow: bool) -> Self {
WorkflowBuilder {
db,
ray_id,
input,
tags: serde_json::Map::new(),
unique: false,
error: None,
error: from_workflow.then_some(BuilderError::CannotDispatchFromOpInWorkflow),
}
}

Expand Down
27 changes: 17 additions & 10 deletions packages/common/chirp-workflow/core/src/compat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ use uuid::Uuid;

use crate::{
builder::common as builder,
builder::BuilderError,
ctx::{
common,
message::{MessageCtx, SubscriptionHandle},
Expand Down Expand Up @@ -40,16 +39,13 @@ where
<I as WorkflowInput>::Workflow: Workflow<Input = I>,
B: Debug + Clone,
{
if ctx.from_workflow {
return Err(BuilderError::CannotDispatchFromOpInWorkflow.into());
}

let db = db_from_ctx(ctx).await?;

Ok(builder::workflow::WorkflowBuilder::new(
db,
ctx.ray_id(),
input,
ctx.from_workflow,
))
}

Expand All @@ -58,13 +54,24 @@ pub async fn signal<T: Signal + Serialize, B: Debug + Clone>(
ctx: &rivet_operation::OperationContext<B>,
body: T,
) -> GlobalResult<builder::signal::SignalBuilder<T>> {
if ctx.from_workflow {
return Err(BuilderError::CannotDispatchFromOpInWorkflow.into());
}

let db = db_from_ctx(ctx).await?;

Ok(builder::signal::SignalBuilder::new(db, ctx.ray_id(), body))
Ok(builder::signal::SignalBuilder::new(
db,
ctx.ray_id(),
body,
ctx.from_workflow,
))
}

/// Creates a message builder.
pub async fn msg<M: Message, B: Debug + Clone>(
ctx: &rivet_operation::OperationContext<B>,
body: M,
) -> GlobalResult<builder::message::MessageBuilder<M>> {
let msg_ctx = MessageCtx::new(ctx.conn(), ctx.ray_id()).await?;

Ok(builder::message::MessageBuilder::new(msg_ctx, body))
}

#[tracing::instrument(err, skip_all, fields(operation = I::Operation::NAME))]
Expand Down
11 changes: 4 additions & 7 deletions packages/common/chirp-workflow/core/src/ctx/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,12 @@ impl ApiCtx {
I: WorkflowInput,
<I as WorkflowInput>::Workflow: Workflow<Input = I>,
{
builder::workflow::WorkflowBuilder::new(self.db.clone(), self.ray_id, input)
builder::workflow::WorkflowBuilder::new(self.db.clone(), self.ray_id, input, false)
}

/// Creates a signal builder.
pub fn signal<T: Signal + Serialize>(&self, body: T) -> builder::signal::SignalBuilder<T> {
builder::signal::SignalBuilder::new(self.db.clone(), self.ray_id, body)
builder::signal::SignalBuilder::new(self.db.clone(), self.ray_id, body, false)
}

#[tracing::instrument(err, skip_all, fields(operation = I::Operation::NAME))]
Expand All @@ -117,11 +117,8 @@ impl ApiCtx {
}

/// Creates a message builder.
pub fn msg<M>(&self, body: M) -> builder::message::MessageBuilder<M>
where
M: Message,
{
builder::message::MessageBuilder::new(&self.msg_ctx, body)
pub fn msg<M: Message>(&self, body: M) -> builder::message::MessageBuilder<M> {
builder::message::MessageBuilder::new(self.msg_ctx.clone(), body)
}

pub async fn subscribe<M>(&self, tags: impl AsTags) -> GlobalResult<SubscriptionHandle<M>>
Expand Down
4 changes: 3 additions & 1 deletion packages/common/chirp-workflow/core/src/ctx/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,9 @@ where
req_ts,
from_workflow,
I::Operation::NAME,
);
)
.await
.map_err(GlobalError::raw)?;

let res = tokio::time::timeout(I::Operation::TIMEOUT, I::Operation::run(&ctx, &input))
.await
Expand Down
98 changes: 88 additions & 10 deletions packages/common/chirp-workflow/core/src/ctx/operation.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,21 @@
use global_error::GlobalResult;
use global_error::{GlobalError, GlobalResult};
use rivet_pools::prelude::*;
use serde::Serialize;
use uuid::Uuid;

use crate::{
builder::common as builder,
ctx::common,
ctx::{
common,
message::{SubscriptionHandle, TailAnchor, TailAnchorResponse},
MessageCtx,
},
db::DatabaseHandle,
error::WorkflowResult,
message::{AsTags, Message, NatsMessage},
operation::{Operation, OperationInput},
signal::Signal,
workflow::{Workflow, WorkflowInput},
};

#[derive(Clone)]
Expand All @@ -21,21 +28,22 @@ pub struct OperationCtx {

config: rivet_config::Config,
conn: rivet_connection::Connection,
msg_ctx: MessageCtx,

// Backwards compatibility
op_ctx: rivet_operation::OperationContext<()>,
}

impl OperationCtx {
pub fn new(
pub async fn new(
db: DatabaseHandle,
config: &rivet_config::Config,
conn: &rivet_connection::Connection,
ray_id: Uuid,
req_ts: i64,
from_workflow: bool,
name: &'static str,
) -> Self {
) -> WorkflowResult<Self> {
let ts = rivet_util::timestamp::now();
let req_id = Uuid::new_v4();
let conn = conn.wrap(req_id, ray_id, name);
Expand All @@ -52,19 +60,56 @@ impl OperationCtx {
);
op_ctx.from_workflow = from_workflow;

OperationCtx {
let msg_ctx = MessageCtx::new(&conn, ray_id).await?;

Ok(OperationCtx {
ray_id,
name,
ts,
db,
config: config.clone(),
conn,
op_ctx,
}
msg_ctx,
})
}
}

impl OperationCtx {
/// Wait for a given workflow to complete.
/// 60 second timeout.
pub async fn wait_for_workflow<W: Workflow>(
&self,
workflow_id: Uuid,
) -> GlobalResult<W::Output> {
common::wait_for_workflow::<W>(&self.db, workflow_id).await
}

/// Creates a workflow builder.
pub fn workflow<I>(&self, input: I) -> builder::workflow::WorkflowBuilder<I>
where
I: WorkflowInput,
<I as WorkflowInput>::Workflow: Workflow<Input = I>,
{
builder::workflow::WorkflowBuilder::new(
self.db.clone(),
self.ray_id,
input,
self.op_ctx.from_workflow,
)
}

/// Creates a signal builder.
pub fn signal<T: Signal + Serialize>(&self, body: T) -> builder::signal::SignalBuilder<T> {
// TODO: Add check for from_workflow so you cant dispatch a signal
builder::signal::SignalBuilder::new(
self.db.clone(),
self.ray_id,
body,
self.op_ctx.from_workflow,
)
}

#[tracing::instrument(err, skip_all, fields(operation = I::Operation::NAME))]
pub async fn op<I>(
&self,
Expand All @@ -86,10 +131,43 @@ impl OperationCtx {
.await
}

/// Creates a signal builder.
pub fn signal<T: Signal + Serialize>(&self, body: T) -> builder::signal::SignalBuilder<T> {
// TODO: Add check for from_workflow so you cant dispatch a signal
builder::signal::SignalBuilder::new(self.db.clone(), self.ray_id, body)
/// Creates a message builder.
pub fn msg<M: Message>(&self, body: M) -> builder::message::MessageBuilder<M> {
builder::message::MessageBuilder::new(self.msg_ctx.clone(), body)
}

pub async fn subscribe<M>(&self, tags: impl AsTags) -> GlobalResult<SubscriptionHandle<M>>
where
M: Message,
{
self.msg_ctx
.subscribe::<M>(tags)
.await
.map_err(GlobalError::raw)
}

pub async fn tail_read<M>(&self, tags: impl AsTags) -> GlobalResult<Option<NatsMessage<M>>>
where
M: Message,
{
self.msg_ctx
.tail_read::<M>(tags)
.await
.map_err(GlobalError::raw)
}

pub async fn tail_anchor<M>(
&self,
tags: impl AsTags,
anchor: &TailAnchor,
) -> GlobalResult<TailAnchorResponse<M>>
where
M: Message,
{
self.msg_ctx
.tail_anchor::<M>(tags, anchor)
.await
.map_err(GlobalError::raw)
}
}

Expand Down
11 changes: 4 additions & 7 deletions packages/common/chirp-workflow/core/src/ctx/standalone.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,12 @@ impl StandaloneCtx {
I: WorkflowInput,
<I as WorkflowInput>::Workflow: Workflow<Input = I>,
{
builder::workflow::WorkflowBuilder::new(self.db.clone(), self.ray_id, input)
builder::workflow::WorkflowBuilder::new(self.db.clone(), self.ray_id, input, false)
}

/// Creates a signal builder.
pub fn signal<T: Signal + Serialize>(&self, body: T) -> builder::signal::SignalBuilder<T> {
builder::signal::SignalBuilder::new(self.db.clone(), self.ray_id, body)
builder::signal::SignalBuilder::new(self.db.clone(), self.ray_id, body, false)
}

#[tracing::instrument(err, skip_all, fields(operation = I::Operation::NAME))]
Expand All @@ -118,11 +118,8 @@ impl StandaloneCtx {
}

/// Creates a message builder.
pub fn msg<M>(&mut self, body: M) -> builder::message::MessageBuilder<M>
where
M: Message,
{
builder::message::MessageBuilder::new(&self.msg_ctx, body)
pub fn msg<M: Message>(&self, body: M) -> builder::message::MessageBuilder<M> {
builder::message::MessageBuilder::new(self.msg_ctx.clone(), body)
}

pub async fn subscribe<M>(&self, tags: impl AsTags) -> GlobalResult<SubscriptionHandle<M>>
Expand Down
11 changes: 4 additions & 7 deletions packages/common/chirp-workflow/core/src/ctx/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,12 +112,12 @@ impl TestCtx {
I: WorkflowInput,
<I as WorkflowInput>::Workflow: Workflow<Input = I>,
{
builder::workflow::WorkflowBuilder::new(self.db.clone(), self.ray_id, input)
builder::workflow::WorkflowBuilder::new(self.db.clone(), self.ray_id, input, false)
}

/// Creates a signal builder.
pub fn signal<T: Signal + Serialize>(&self, body: T) -> builder::signal::SignalBuilder<T> {
builder::signal::SignalBuilder::new(self.db.clone(), self.ray_id, body)
builder::signal::SignalBuilder::new(self.db.clone(), self.ray_id, body, false)
}

#[tracing::instrument(err, skip_all, fields(operation = I::Operation::NAME))]
Expand All @@ -141,11 +141,8 @@ impl TestCtx {
.await
}

pub fn msg<M>(&self, body: M) -> builder::message::MessageBuilder<M>
where
M: Message,
{
builder::message::MessageBuilder::new(&self.msg_ctx, body)
pub fn msg<M: Message>(&self, body: M) -> builder::message::MessageBuilder<M> {
builder::message::MessageBuilder::new(self.msg_ctx.clone(), body)
}

pub async fn subscribe<M>(&self, tags: impl AsTags) -> GlobalResult<SubscriptionHandle<M>>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,11 +120,8 @@ impl<'a> VersionedWorkflowCtx<'a> {
})
}

/// Creates a signal builder.
pub fn msg<M>(&mut self, body: M) -> builder::message::MessageBuilder<M>
where
M: Message,
{
/// Creates a message builder.
pub fn msg<M: Message>(&mut self, body: M) -> builder::message::MessageBuilder<M> {
builder::message::MessageBuilder::new(self.inner, self.version(), body)
}

Expand Down
5 changes: 1 addition & 4 deletions packages/common/chirp-workflow/core/src/ctx/workflow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -722,10 +722,7 @@ impl WorkflowCtx {
}

/// Creates a message builder.
pub fn msg<M>(&mut self, body: M) -> builder::message::MessageBuilder<M>
where
M: Message,
{
pub fn msg<M: Message>(&mut self, body: M) -> builder::message::MessageBuilder<M> {
builder::message::MessageBuilder::new(self, self.version, body)
}

Expand Down

0 comments on commit d541a3f

Please sign in to comment.