feat(workflows): add sleep fn (#1077)
MasterPtato committed Aug 19, 2024
1 parent 0c58f83 commit c477ba9
Showing 11 changed files with 371 additions and 186 deletions.
12 changes: 12 additions & 0 deletions docs/libraries/workflow/GOTCHAS.md
@@ -122,3 +122,15 @@ ctx
})
.await?;
```

## Nested options with serde

Nested options (`Option<Option<T>>`) do not round-trip consistently through serde, because both `Some(None)` and `None` serialize to `null`.

```rust
Some(Some(1234)) -> "1234" -> Some(Some(1234))
Some(None) -> "null" -> None
None -> "null" -> None
```

Be careful when writing your struct definitions.
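
A minimal, self-contained demonstration of the round-trip behavior, using `serde_json` directly (this example is not part of the diff and assumes `serde_json` is available as a dependency):

```rust
fn main() {
    // Some(Some(_)) round-trips: the inner value serializes to a plain number.
    let set: Option<Option<u32>> = Some(Some(1234));
    let json = serde_json::to_string(&set).unwrap();
    assert_eq!(json, "1234");
    let back: Option<Option<u32>> = serde_json::from_str(&json).unwrap();
    assert_eq!(back, Some(Some(1234)));

    // Some(None) does not: it serializes to "null", which deserializes as None,
    // so the "explicitly set to null" state is lost.
    let unset: Option<Option<u32>> = Some(None);
    let json = serde_json::to_string(&unset).unwrap();
    assert_eq!(json, "null");
    let back: Option<Option<u32>> = serde_json::from_str(&json).unwrap();
    assert_eq!(back, None);
}
```

If the distinction matters, consider modeling the field as an explicit three-state enum (unset / null / value) instead of `Option<Option<T>>`.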
65 changes: 62 additions & 3 deletions lib/chirp-workflow/core/src/ctx/workflow.rs
@@ -18,8 +18,13 @@ use crate::{
metrics,
registry::RegistryHandle,
signal::Signal,
util::{self, GlobalErrorExt, Location},
util::{
self,
time::{DurationToMillis, TsToMillis},
GlobalErrorExt, Location,
},
workflow::{Workflow, WorkflowInput},
worker,
};

// Time to delay a workflow from retrying after an error
@@ -211,7 +216,7 @@ impl WorkflowCtx {
}
Err(err) => {
// Retry the workflow if it's recoverable
let deadline_ts = if let Some(deadline_ts) = err.backoff() {
let deadline_ts = if let Some(deadline_ts) = err.deadline_ts() {
Some(deadline_ts)
} else if err.is_retryable() {
Some(rivet_util::timestamp::now() + RETRY_TIMEOUT_MS as i64)
@@ -1251,7 +1256,61 @@ impl WorkflowCtx {
Ok(output)
}

// TODO: sleep_for, sleep_until
pub async fn sleep<T: DurationToMillis>(&mut self, duration: T) -> GlobalResult<()> {
self.sleep_until(rivet_util::timestamp::now() + duration.to_millis()?)
.await
}

pub async fn sleep_until<T: TsToMillis>(&mut self, time: T) -> GlobalResult<()> {
let event = self.relevant_history().nth(self.location_idx);

// Slept before
if let Some(event) = event {
// Validate history is consistent
let Event::Sleep(_) = event else {
return Err(WorkflowError::HistoryDiverged(format!(
"expected {event} at {}, found sleep",
self.loc(),
)))
.map_err(GlobalError::raw);
};

tracing::debug!(name=%self.name, id=%self.workflow_id, "skipping replayed sleep");
}
// Sleep
else {
let ts = time.to_millis()?;

self.db
.commit_workflow_sleep_event(
self.workflow_id,
self.full_location().as_ref(),
ts,
self.loop_location(),
)
.await?;

let duration = ts - rivet_util::timestamp::now();
if duration < 0 {
// No-op
tracing::warn!("tried to sleep for a negative duration");
} else if duration < worker::TICK_INTERVAL.as_millis() as i64 + 1 {
tracing::info!(name=%self.name, id=%self.workflow_id, until_ts=%ts, "sleeping in memory");

// Sleep in memory if duration is shorter than the worker tick
tokio::time::sleep(std::time::Duration::from_millis(duration.try_into()?)).await;
} else {
tracing::info!(name=%self.name, id=%self.workflow_id, until_ts=%ts, "sleeping");

return Err(WorkflowError::Sleep(ts)).map_err(GlobalError::raw);
}
}

// Move to next event
self.location_idx += 1;

Ok(())
}
}

impl WorkflowCtx {
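
For reference, a minimal usage sketch of the API added above (not part of this diff). The `chirp_workflow::prelude` import and the assumption that `std::time::Duration` implements `DurationToMillis` are mine; only `ctx.sleep`, `ctx.sleep_until`, and the `i64` millisecond-timestamp form are taken from the code in this commit:

```rust
use chirp_workflow::prelude::*; // assumed to re-export WorkflowCtx and GlobalResult

// Sketch of a workflow body; registration/macro boilerplate is elided.
async fn my_workflow_body(ctx: &mut WorkflowCtx) -> GlobalResult<()> {
    // Relative sleep. Durations shorter than the worker tick are awaited in
    // memory; longer ones are committed to history and the workflow is
    // re-pulled once the deadline passes.
    ctx.sleep(std::time::Duration::from_secs(30)).await?;

    // Absolute variant: sleep until a millisecond timestamp, mirroring how
    // `sleep` itself calls `sleep_until(rivet_util::timestamp::now() + ms)`.
    ctx.sleep_until(rivet_util::timestamp::now() + 60_000).await?;

    // On replay, both calls match the recorded sleep events and return immediately.
    Ok(())
}
```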
15 changes: 15 additions & 0 deletions lib/chirp-workflow/core/src/db/mod.rs
@@ -169,6 +169,15 @@ pub trait Database: Send {
output: Option<serde_json::Value>,
loop_location: Option<&[usize]>,
) -> WorkflowResult<()>;

/// Writes a workflow sleep event to history.
async fn commit_workflow_sleep_event(
&self,
from_workflow_id: Uuid,
location: &[usize],
until_ts: i64,
loop_location: Option<&[usize]>,
) -> WorkflowResult<()>;
}

#[derive(sqlx::FromRow)]
@@ -266,3 +275,9 @@ pub struct LoopEventRow {
pub output: Option<serde_json::Value>,
pub iteration: i64,
}

#[derive(sqlx::FromRow)]
pub struct SleepEventRow {
pub workflow_id: Uuid,
pub location: Vec<i64>,
}
198 changes: 127 additions & 71 deletions lib/chirp-workflow/core/src/db/pg_nats.rs
@@ -8,14 +8,14 @@ use uuid::Uuid;

use super::{
ActivityEventRow, Database, LoopEventRow, MessageSendEventRow, PulledWorkflow,
PulledWorkflowRow, SignalEventRow, SignalRow, SignalSendEventRow, SubWorkflowEventRow,
WorkflowRow,
PulledWorkflowRow, SignalEventRow, SignalRow, SignalSendEventRow, SleepEventRow,
SubWorkflowEventRow, WorkflowRow,
};
use crate::{
activity::ActivityId,
error::{WorkflowError, WorkflowResult},
event::combine_events,
message,
message, worker,
};

/// Max amount of workflows pulled from the database with each call to `pull_workflows`.
@@ -152,74 +152,84 @@ impl Database for DatabasePgNats {
filter: &[&str],
) -> WorkflowResult<Vec<PulledWorkflow>> {
// Select all workflows that haven't started or that have a wake condition
let workflow_rows = sqlx::query_as::<_, PulledWorkflowRow>(indoc!(
"
WITH
pull_workflows AS (
UPDATE db_workflow.workflows AS w
-- Assign this node to this workflow
SET worker_instance_id = $1
WHERE
-- Filter
workflow_name = ANY($2) AND
-- Not already complete
output IS NULL AND
-- No assigned node (not running)
worker_instance_id IS NULL AND
-- Check for wake condition
(
-- Immediate
wake_immediate OR
-- After deadline
wake_deadline_ts IS NOT NULL OR
-- Signal exists
(
SELECT true
FROM db_workflow.signals AS s
WHERE
s.workflow_id = w.workflow_id AND
s.signal_name = ANY(w.wake_signals) AND
s.ack_ts IS NULL
LIMIT 1
) OR
-- Tagged signal exists
(
SELECT true
FROM db_workflow.tagged_signals AS s
WHERE
s.signal_name = ANY(w.wake_signals) AND
s.tags <@ w.tags AND
s.ack_ts IS NULL
LIMIT 1
) OR
-- Sub workflow completed
(
SELECT true
FROM db_workflow.workflows AS w2
WHERE
w2.workflow_id = w.wake_sub_workflow_id AND
output IS NOT NULL
)
let workflow_rows = self
.query(|| async {
sqlx::query_as::<_, PulledWorkflowRow>(indoc!(
"
WITH
pull_workflows AS (
UPDATE db_workflow.workflows AS w
-- Assign this node to this workflow
SET worker_instance_id = $1
WHERE
-- Filter
workflow_name = ANY($2) AND
-- Not already complete
output IS NULL AND
-- No assigned node (not running)
worker_instance_id IS NULL AND
-- Check for wake condition
(
-- Immediate
wake_immediate OR
-- After deadline
(
wake_deadline_ts IS NOT NULL AND
$3 > wake_deadline_ts - $4
) OR
-- Signal exists
(
SELECT true
FROM db_workflow.signals AS s
WHERE
s.workflow_id = w.workflow_id AND
s.signal_name = ANY(w.wake_signals) AND
s.ack_ts IS NULL
LIMIT 1
) OR
-- Tagged signal exists
(
SELECT true
FROM db_workflow.tagged_signals AS s
WHERE
s.signal_name = ANY(w.wake_signals) AND
s.tags <@ w.tags AND
s.ack_ts IS NULL
LIMIT 1
) OR
-- Sub workflow completed
(
SELECT true
FROM db_workflow.workflows AS w2
WHERE
w2.workflow_id = w.wake_sub_workflow_id AND
output IS NOT NULL
)
)
LIMIT $5
RETURNING workflow_id, workflow_name, create_ts, ray_id, input, wake_deadline_ts
),
-- Update last ping
worker_instance_update AS (
UPSERT INTO db_workflow.worker_instances (worker_instance_id, last_ping_ts)
VALUES ($1, $3)
RETURNING 1
)
LIMIT $4
RETURNING workflow_id, workflow_name, create_ts, ray_id, input, wake_deadline_ts
),
-- Update last ping
worker_instance_update AS (
UPSERT INTO db_workflow.worker_instances (worker_instance_id, last_ping_ts)
VALUES ($1, $3)
RETURNING 1
)
SELECT * FROM pull_workflows
",
))
.bind(worker_instance_id)
.bind(filter)
.bind(rivet_util::timestamp::now())
.bind(MAX_PULLED_WORKFLOWS)
.fetch_all(&mut *self.conn().await?)
.await
.map_err(WorkflowError::Sqlx)?;
SELECT * FROM pull_workflows
",
))
.bind(worker_instance_id)
.bind(filter)
.bind(rivet_util::timestamp::now())
// Add padding to the tick interval so that the workflow deadline is never passed before it's pulled.
// The worker sleeps internally to handle this.
.bind(worker::TICK_INTERVAL.as_millis() as i64 + 1)
.bind(MAX_PULLED_WORKFLOWS)
.fetch_all(&mut *self.conn().await?)
.await
.map_err(WorkflowError::Sqlx)
})
.await?;

if workflow_rows.is_empty() {
return Ok(Vec::new());
@@ -240,6 +250,7 @@ impl Database for DatabasePgNats {
msg_send_events,
sub_workflow_events,
loop_events,
sleep_events,
) = tokio::try_join!(
async {
sqlx::query_as::<_, ActivityEventRow>(indoc!(
@@ -347,6 +358,21 @@ impl Database for DatabasePgNats {
.await
.map_err(WorkflowError::Sqlx)
},
async {
sqlx::query_as::<_, SleepEventRow>(indoc!(
"
SELECT
workflow_id, location
FROM db_workflow.workflow_sleep_events
WHERE workflow_id = ANY($1) AND forgotten = FALSE
ORDER BY workflow_id, location ASC
",
))
.bind(&workflow_ids)
.fetch_all(&mut *self.conn().await?)
.await
.map_err(WorkflowError::Sqlx)
},
)?;

let workflows = combine_events(
@@ -357,6 +383,7 @@ impl Database for DatabasePgNats {
msg_send_events,
sub_workflow_events,
loop_events,
sleep_events,
)?;

Ok(workflows)
@@ -397,7 +424,6 @@ impl Database for DatabasePgNats {
wake_sub_workflow_id: Option<Uuid>,
error: &str,
) -> WorkflowResult<()> {
// TODO(RVT-3762): Should this compare `wake_deadline_ts` before setting it?
self.query(|| async {
sqlx::query(indoc!(
"
@@ -1017,4 +1043,34 @@

Ok(())
}

async fn commit_workflow_sleep_event(
&self,
from_workflow_id: Uuid,
location: &[usize],
until_ts: i64,
loop_location: Option<&[usize]>,
) -> WorkflowResult<()> {
self.query(|| async {
sqlx::query(indoc!(
"
INSERT INTO db_workflow.workflow_sleep_events(
workflow_id, location, until_ts, loop_location
)
VALUES($1, $2, $3, $4)
RETURNING 1
",
))
.bind(from_workflow_id)
.bind(location.iter().map(|x| *x as i64).collect::<Vec<_>>())
.bind(until_ts)
.bind(loop_location.map(|l| l.iter().map(|x| *x as i64).collect::<Vec<_>>()))
.execute(&mut *self.conn().await?)
.await
.map_err(WorkflowError::Sqlx)
})
.await?;

Ok(())
}
}