Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 113 additions & 1 deletion benches/common/tasks.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use durable::{Task, TaskContext, TaskResult, async_trait};
use durable::{SpawnOptions, Task, TaskContext, TaskHandle, TaskResult, async_trait};
use serde::{Deserialize, Serialize};
use std::borrow::Cow;

Expand Down Expand Up @@ -136,3 +136,115 @@ impl Task<()> for LargePayloadBenchTask {
Ok(params.payload_size)
}
}

// ============================================================================
// Hierarchical Tasks - Parent -> Child -> Grandchild for stress testing
// ============================================================================

// NOTE(review): dead_code is presumably allowed because not every bench target
// that includes this shared module uses every type — confirm.
/// Input for `ParentTask`: controls how wide the spawned task tree is.
#[allow(dead_code)]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParentParams {
/// Number of `ChildTask`s the parent spawns (one per index in `0..num_children`).
pub num_children: u32,
/// Forwarded unchanged to every child as `ChildParams::num_grandchildren`.
pub grandchildren_per_child: u32,
}

// NOTE(review): dead_code is presumably allowed because not every bench target
// that includes this shared module uses every type — confirm.
/// Top level of the Parent -> Child -> Grandchild hierarchy.
///
/// Stateless unit struct; its `Task<()>` implementation reports the task name
/// `"bench-parent"`.
#[allow(dead_code)]
#[derive(Default)]
pub struct ParentTask;

#[async_trait]
impl Task<()> for ParentTask {
fn name(&self) -> Cow<'static, str> {
Cow::Borrowed("bench-parent")
}
type Params = ParentParams;
type Output = u32;

async fn run(
&self,
params: Self::Params,
mut ctx: TaskContext,
_state: (),
) -> TaskResult<Self::Output> {
let mut handles = Vec::new();
for i in 0..params.num_children {
let handle: TaskHandle<u32> = ctx
.spawn::<ChildTask>(
&format!("child-{}", i),
ChildParams {
num_grandchildren: params.grandchildren_per_child,
},
SpawnOptions::default(),
)
.await?;
handles.push(handle);
}

let mut total = 0u32;
for handle in handles {
total += ctx.join(handle).await?;
}
Ok(total)
}
}

// NOTE(review): dead_code is presumably allowed because not every bench target
// that includes this shared module uses every type — confirm.
/// Input for `ChildTask`.
#[allow(dead_code)]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChildParams {
/// Number of `GrandchildTask`s the child spawns; also the value the child returns.
pub num_grandchildren: u32,
}

// NOTE(review): dead_code is presumably allowed because not every bench target
// that includes this shared module uses every type — confirm.
/// Middle level of the Parent -> Child -> Grandchild hierarchy.
///
/// Stateless unit struct; its `Task<()>` implementation reports the task name
/// `"bench-child"`.
#[allow(dead_code)]
#[derive(Default)]
pub struct ChildTask;

#[async_trait]
impl Task<()> for ChildTask {
    type Params = ChildParams;
    type Output = u32;

    /// Name reported for this task: `"bench-child"`.
    fn name(&self) -> Cow<'static, str> {
        Cow::Borrowed("bench-child")
    }

    /// Spawns `num_grandchildren` `GrandchildTask`s, waits for all of them, and
    /// returns `num_grandchildren`.
    ///
    /// Grandchildren are spawned under the names `grandchild-0`,
    /// `grandchild-1`, ... with `()` as their parameters. Every grandchild is
    /// spawned before the first join is awaited.
    ///
    /// # Errors
    ///
    /// The first error returned by a spawn or a join is propagated to the caller.
    async fn run(
        &self,
        params: Self::Params,
        mut ctx: TaskContext,
        _state: (),
    ) -> TaskResult<Self::Output> {
        let wanted = params.num_grandchildren;

        // Fan-out: start every grandchild, keeping the handles in spawn order.
        let mut pending: Vec<TaskHandle<()>> = Vec::new();
        for idx in 0..wanted {
            let label = format!("grandchild-{}", idx);
            let spawned = ctx
                .spawn::<GrandchildTask>(&label, (), SpawnOptions::default())
                .await?;
            pending.push(spawned);
        }

        // Fan-in: grandchildren produce no value, so joining only waits for
        // completion and surfaces errors.
        for grandchild in pending {
            ctx.join(grandchild).await?;
        }

        Ok(wanted)
    }
}

// NOTE(review): dead_code is presumably allowed because not every bench target
// that includes this shared module uses every type — confirm.
/// Leaf level of the Parent -> Child -> Grandchild hierarchy.
///
/// Stateless unit struct; its `Task<()>` implementation reports the task name
/// `"bench-grandchild"`, takes `()` as parameters and returns `()`.
#[allow(dead_code)]
#[derive(Default)]
pub struct GrandchildTask;

#[async_trait]
impl Task<()> for GrandchildTask {
    type Params = ();
    type Output = ();

    /// Name reported for this task: `"bench-grandchild"`.
    fn name(&self) -> Cow<'static, str> {
        Cow::Borrowed("bench-grandchild")
    }

    /// Completes immediately with `Ok(())`.
    ///
    /// The parameters, context and state are all ignored; the body performs no
    /// work of its own.
    async fn run(
        &self,
        _params: Self::Params,
        _ctx: TaskContext,
        _state: (),
    ) -> TaskResult<Self::Output> {
        Ok(())
    }
}
91 changes: 63 additions & 28 deletions sql/schema.sql
Original file line number Diff line number Diff line change
Expand Up @@ -1051,45 +1051,67 @@ create function durable.set_task_checkpoint_state (
as $$
declare
v_now timestamptz := durable.current_time();
v_new_attempt integer;
v_task_state text;
v_attempt integer;
begin
if p_step_name is null or length(trim(p_step_name)) = 0 then
raise exception 'step_name must be provided';
end if;

-- get the current attempt number and task state
-- Lock task row (consistent task -> run lock order)
execute format(
'select r.attempt, t.state
from durable.%I r
join durable.%I t on t.task_id = r.task_id
where r.run_id = $1',
'r_' || p_queue_name,
'select state from durable.%I where task_id = $1 for update',
't_' || p_queue_name
)
into v_new_attempt, v_task_state
using p_owner_run;
into v_task_state
using p_task_id;

if v_new_attempt is null then
raise exception 'Run "%" not found for checkpoint', p_owner_run;
if v_task_state is null then
raise exception 'Task "%" not found in queue "%"', p_task_id, p_queue_name;
end if;

-- if the task was cancelled raise a special error the caller can catch to terminate
if v_task_state = 'cancelled' then
raise exception sqlstate 'AB001' using message = 'Task has been cancelled';
end if;

-- Extend the claim if requested
-- Validate lease and lock run row by conditionally updating it.
if p_extend_claim_by is not null and p_extend_claim_by > 0 then
execute format(
'update durable.%I
set claim_expires_at = $2 + make_interval(secs => $3)
set claim_expires_at = $3 + make_interval(secs => $4)
where run_id = $1
and task_id = $2
and state = ''running''
and claim_expires_at is not null
and claim_expires_at > $3
returning attempt',
'r_' || p_queue_name
)
into v_attempt
using p_owner_run, p_task_id, v_now, p_extend_claim_by;
else
-- Touch row to lock it + validate lease even when not extending.
-- If the run has been cancelled then this row's state will be set to
-- 'failed' and this check will return null
execute format(
'update durable.%I
set claim_expires_at = claim_expires_at
where run_id = $1
and task_id = $2
and state = ''running''
and claim_expires_at is not null',
and claim_expires_at is not null
and claim_expires_at > $3
returning attempt',
'r_' || p_queue_name
)
using p_owner_run, v_now, p_extend_claim_by;
into v_attempt
using p_owner_run, p_task_id, v_now;
end if;

-- If the check above returns null then we shouldn't be running it anymore.
if v_attempt is null then
raise exception sqlstate 'AB002' using message = 'Task lease expired';
end if;

execute format(
Expand All @@ -1108,7 +1130,8 @@ begin
'c_' || p_queue_name,
'r_' || p_queue_name,
'c_' || p_queue_name
) using p_task_id, p_step_name, p_state, p_owner_run, v_now, v_new_attempt;
)
using p_task_id, p_step_name, p_state, p_owner_run, v_now, v_attempt;
end;
$$;

Expand All @@ -1123,22 +1146,26 @@ create function durable.extend_claim (
as $$
declare
v_now timestamptz := durable.current_time();
v_extend_by integer;
v_claim_timeout integer;
v_rows_updated integer;
v_task_id uuid;
v_task_state text;
v_attempt integer;
begin
-- Lock task row first (consistent task -> run lock order)
execute format(
'select t.state
from durable.%I r
join durable.%I t on t.task_id = r.task_id
where r.run_id = $1',
'r_' || p_queue_name,
't_' || p_queue_name
'select task_id, state
from durable.%I
where task_id = (select task_id from durable.%I where run_id = $1)
for update',
't_' || p_queue_name,
'r_' || p_queue_name
)
into v_task_state
into v_task_id, v_task_state
using p_run_id;

if v_task_state is null then
raise exception 'Run "%" not found in queue "%"', p_run_id, p_queue_name;
end if;

if v_task_state = 'cancelled' then
raise exception sqlstate 'AB001' using message = 'Task has been cancelled';
end if;
Expand All @@ -1147,11 +1174,19 @@ begin
'update durable.%I
set claim_expires_at = $2 + make_interval(secs => $3)
where run_id = $1
and task_id = $4
and state = ''running''
and claim_expires_at is not null',
and claim_expires_at is not null
and claim_expires_at > $2
returning attempt',
'r_' || p_queue_name
)
using p_run_id, v_now, p_extend_by;
into v_attempt
using p_run_id, v_now, p_extend_by, v_task_id;

if v_attempt is null then
raise exception sqlstate 'AB002' using message = 'Task lease expired';
end if;
end;
$$;

Expand Down
18 changes: 18 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,12 @@ pub enum ControlFlow {
/// Detected when database operations return error code AB001, indicating
/// the task was cancelled via [`Durable::cancel_task`](crate::Durable::cancel_task).
Cancelled,
/// Task lease expired (claim lost).
///
/// Detected when database operations return error code AB002. Treated as control
/// flow to avoid double-failing runs that were already failed by `claim_task`,
/// and to let the next claim sweep fail the run if it hasn't happened yet.
LeaseExpired,
}

/// Error type for task execution.
Expand All @@ -28,6 +34,7 @@ pub enum ControlFlow {
///
/// - `Control(Suspend)` - Task is waiting; worker does nothing (scheduler will resume it)
/// - `Control(Cancelled)` - Task was cancelled; worker does nothing
/// - `Control(LeaseExpired)` - Task lost its lease; worker stops without failing the run
/// - All other variants - Actual errors; worker records failure and may retry
///
/// # Example
Expand Down Expand Up @@ -218,6 +225,8 @@ impl TaskError {
pub(crate) fn from_sqlx_error(err: sqlx::Error) -> Self {
if is_cancelled_error(&err) {
TaskError::Control(ControlFlow::Cancelled)
} else if is_lease_expired_error(&err) {
TaskError::Control(ControlFlow::LeaseExpired)
} else {
TaskError::Database(err)
}
Expand All @@ -233,6 +242,15 @@ pub fn is_cancelled_error(err: &sqlx::Error) -> bool {
}
}

/// Check if a sqlx error indicates lease expiration (error code AB002)
pub fn is_lease_expired_error(err: &sqlx::Error) -> bool {
if let sqlx::Error::Database(db_err) = err {
db_err.code().is_some_and(|c| c == "AB002")
} else {
false
}
}

/// Serialize a TaskError for storage in fail_run
pub fn serialize_task_error(err: &TaskError) -> JsonValue {
match err {
Expand Down
Loading