Skip to content

Commit

Permalink
feat(workflows): add api ctx for workflows (#865)
Browse files Browse the repository at this point in the history
<!-- Please make sure there is an issue that this PR is correlated to. -->

## Changes

<!-- If there are frontend changes, please include screenshots. -->
  • Loading branch information
MasterPtato committed Jun 10, 2024
1 parent 885f8dc commit 1a468d3
Show file tree
Hide file tree
Showing 18 changed files with 358 additions and 65 deletions.
1 change: 1 addition & 0 deletions lib/api-helper/build/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ macros = []
api-helper-macros = { path = "../macros" }
async-trait = "0.1"
chirp-client = { path = "../../chirp/client" }
chirp-workflow = { path = "../../chirp-workflow/core" }
chrono = "0.4"
formatted-error = { path = "../../formatted-error" }
futures-util = "0.3"
Expand Down
15 changes: 8 additions & 7 deletions lib/api-helper/build/src/ctx.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::net::IpAddr;

use chirp_workflow::ctx::ApiCtx;
use rivet_operation::OperationContext;
use types::rivet::backend;
use url::Url;
Expand All @@ -8,7 +9,7 @@ use crate::auth;

pub struct Ctx<A: auth::ApiAuth> {
pub(crate) auth: A,
pub(crate) op_ctx: OperationContext<()>,
pub(crate) internal_ctx: ApiCtx,
pub(crate) user_agent: Option<String>,
pub(crate) origin: Option<Url>,
pub(crate) remote_address: Option<IpAddr>,
Expand All @@ -22,19 +23,19 @@ impl<A: auth::ApiAuth> Ctx<A> {
}

pub fn op_ctx(&self) -> &OperationContext<()> {
&self.op_ctx
self.internal_ctx.op_ctx()
}

pub fn chirp(&self) -> &chirp_client::Client {
self.op_ctx.chirp()
self.op_ctx().chirp()
}

pub fn cache(&self) -> rivet_cache::RequestConfig {
self.op_ctx.cache()
self.op_ctx().cache()
}

pub fn cache_handle(&self) -> rivet_cache::Cache {
self.op_ctx.cache_handle()
self.op_ctx().cache_handle()
}

pub fn client_info(&self) -> backend::net::ClientInfo {
Expand Down Expand Up @@ -69,9 +70,9 @@ impl<A: auth::ApiAuth> Ctx<A> {
}

impl<A: auth::ApiAuth> std::ops::Deref for Ctx<A> {
type Target = OperationContext<()>;
type Target = ApiCtx;

fn deref(&self) -> &Self::Target {
&self.op_ctx
&self.internal_ctx
}
}
19 changes: 6 additions & 13 deletions lib/api-helper/build/src/macro_util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use rivet_operation::prelude::util;
use serde::de::DeserializeOwned;
use url::Url;
use uuid::Uuid;
use chirp_workflow::ctx::ApiCtx;

use crate::{
auth::{self, AuthRateLimitCtx},
Expand Down Expand Up @@ -308,12 +309,12 @@ pub async fn __with_ctx<A: auth::ApiAuth + Send>(
// Create connections
let req_id = Uuid::new_v4();
let ts = rivet_util::timestamp::now();
let svc_name = rivet_util::env::chirp_service_name().to_string();
let svc_name = rivet_util::env::chirp_service_name();
let client = shared_client.wrap(
req_id,
ray_id,
vec![chirp_client::TraceEntry {
context_name: svc_name.clone(),
context_name: svc_name.to_string(),
req_id: Some(req_id.into()),
ts,
run_context: match rivet_util::env::run_context() {
Expand All @@ -323,16 +324,8 @@ pub async fn __with_ctx<A: auth::ApiAuth + Send>(
}],
);
let conn = rivet_connection::Connection::new(client, pools.clone(), cache.clone());
let op_ctx = rivet_operation::OperationContext::new(
svc_name,
std::time::Duration::from_secs(60),
conn,
req_id,
ray_id,
ts,
ts,
(),
);
let db = chirp_workflow::compat::db_from_pools(&pools).await?;
let internal_ctx = ApiCtx::new(db, conn, req_id, ray_id, ts, svc_name);

// Create auth
let rate_limit_ctx = AuthRateLimitCtx {
Expand All @@ -349,7 +342,7 @@ pub async fn __with_ctx<A: auth::ApiAuth + Send>(

Ok(Ctx {
auth,
op_ctx,
internal_ctx,
user_agent,
origin,
remote_address,
Expand Down
2 changes: 1 addition & 1 deletion lib/chirp-workflow/core/src/activity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ pub trait Activity {
const MAX_RETRIES: u32;
const TIMEOUT: std::time::Duration;

async fn run(ctx: &mut ActivityCtx, input: &Self::Input) -> GlobalResult<Self::Output>;
async fn run(ctx: &ActivityCtx, input: &Self::Input) -> GlobalResult<Self::Output>;
}

pub trait ActivityInput: Serialize + DeserializeOwned + Debug + Hash + Send {
Expand Down
32 changes: 23 additions & 9 deletions lib/chirp-workflow/core/src/compat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ use serde::Serialize;
use uuid::Uuid;

use crate::{
DatabaseHandle, DatabasePostgres, Operation, OperationCtx, OperationInput, Signal, Workflow,
WorkflowError, WorkflowInput,
ctx::api::WORKFLOW_TIMEOUT, DatabaseHandle, DatabasePostgres, Operation, OperationCtx,
OperationInput, Signal, Workflow, WorkflowError, WorkflowInput,
};

pub async fn dispatch_workflow<I, B>(
Expand All @@ -35,7 +35,7 @@ where
.map_err(WorkflowError::SerializeWorkflowOutput)
.map_err(GlobalError::raw)?;

db(ctx)
db_from_ctx(ctx)
.await?
.dispatch_workflow(ctx.ray_id(), id, &name, input_val)
.await
Expand All @@ -46,6 +46,8 @@ where
Ok(id)
}

/// Wait for a given workflow to complete.
/// **IMPORTANT:** Has no timeout.
pub async fn wait_for_workflow<W: Workflow, B: Debug + Clone>(
ctx: &rivet_operation::OperationContext<B>,
workflow_id: Uuid,
Expand All @@ -58,7 +60,7 @@ pub async fn wait_for_workflow<W: Workflow, B: Debug + Clone>(
interval.tick().await;

// Check if state finished
let workflow = db(ctx)
let workflow = db_from_ctx(ctx)
.await?
.get_workflow(workflow_id)
.await
Expand All @@ -71,6 +73,7 @@ pub async fn wait_for_workflow<W: Workflow, B: Debug + Clone>(
}
}

/// Dispatch a new workflow and wait for it to complete. Has a 60s timeout.
pub async fn workflow<I, B>(
ctx: &rivet_operation::OperationContext<B>,
input: I,
Expand All @@ -81,8 +84,12 @@ where
B: Debug + Clone,
{
let workflow_id = dispatch_workflow(ctx, input).await?;
let output = wait_for_workflow::<I::Workflow, _>(ctx, workflow_id).await?;
Ok(output)

tokio::time::timeout(
WORKFLOW_TIMEOUT,
wait_for_workflow::<I::Workflow, _>(ctx, workflow_id),
)
.await?
}

pub async fn signal<I: Signal + Serialize, B: Debug + Clone>(
Expand All @@ -103,7 +110,7 @@ pub async fn signal<I: Signal + Serialize, B: Debug + Clone>(
.map_err(WorkflowError::SerializeSignalBody)
.map_err(GlobalError::raw)?;

db(ctx)
db_from_ctx(ctx)
.await?
.publish_signal(ctx.ray_id(), workflow_id, signal_id, I::NAME, input_val)
.await
Expand All @@ -122,7 +129,7 @@ where
B: Debug + Clone,
{
let mut ctx = OperationCtx::new(
db(ctx).await?,
db_from_ctx(ctx).await?,
ctx.conn(),
ctx.ray_id(),
ctx.req_ts(),
Expand All @@ -137,10 +144,17 @@ where
}

// Get crdb pool as a trait object
async fn db<B: Debug + Clone>(
async fn db_from_ctx<B: Debug + Clone>(
ctx: &rivet_operation::OperationContext<B>,
) -> GlobalResult<DatabaseHandle> {
let crdb = ctx.crdb().await?;

Ok(DatabasePostgres::from_pool(crdb))
}

// Get crdb pool as a trait object
pub async fn db_from_pools(pools: &rivet_pools::Pools) -> GlobalResult<DatabaseHandle> {
let crdb = pools.crdb()?;

Ok(DatabasePostgres::from_pool(crdb))
}
2 changes: 2 additions & 0 deletions lib/chirp-workflow/core/src/ctx/activity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,9 @@ impl ActivityCtx {
.map_err(WorkflowError::OperationFailure)
.map_err(GlobalError::raw)
}
}

impl ActivityCtx {
pub fn name(&self) -> &str {
self.name
}
Expand Down
Loading

0 comments on commit 1a468d3

Please sign in to comment.