Enforce task lease validity on checkpoint and heartbeat.

What this patch contains:

  * sql/schema.sql and the new migration
    20260129120000_enforce_lease_on_checkpoint.sql:
    durable.set_task_checkpoint_state and durable.extend_claim now lock the
    task row first ("select ... for update"), then conditionally update the
    run row, requiring state = 'running' and claim_expires_at > now. When no
    row matches, they raise SQLSTATE AB002 ('Task lease expired'). A
    cancelled task still raises SQLSTATE AB001.
  * src/error.rs: adds ControlFlow::LeaseExpired and
    is_lease_expired_error(); TaskError::from_sqlx_error maps AB002 to
    TaskError::Control(ControlFlow::LeaseExpired).
  * src/worker.rs: adds a match arm for Control(LeaseExpired) that logs a
    warning (and sets the telemetry outcome to "lease_expired").
  * tests/common/tasks.rs, tests/lease_test.rs: adds SleepThenCheckpointTask
    and two tests asserting the task ends up "failed" with zero checkpoints
    once its lease has expired.
  * benches/common/tasks.rs: adds Parent -> Child -> Grandchild bench tasks.

NOTE(review): the generic arguments on spawn / register / spawn_with_options,
TaskResult and TaskHandle<u32> were reconstructed from how the values are
used (params type passed, joined value summed into a u32) - confirm against
the durable API.
NOTE(review): leading whitespace of context lines was reconstructed; if a
hunk does not match the target file, apply with
`git apply --ignore-whitespace`.

diff --git a/benches/common/tasks.rs b/benches/common/tasks.rs
index b3a14c2..5dfd0c5 100644
--- a/benches/common/tasks.rs
+++ b/benches/common/tasks.rs
@@ -1,4 +1,4 @@
-use durable::{Task, TaskContext, TaskResult, async_trait};
+use durable::{SpawnOptions, Task, TaskContext, TaskHandle, TaskResult, async_trait};
 use serde::{Deserialize, Serialize};
 use std::borrow::Cow;
 
@@ -136,3 +136,115 @@ impl Task<()> for LargePayloadBenchTask {
         Ok(params.payload_size)
     }
 }
+
+// ============================================================================
+// Hierarchical Tasks - Parent -> Child -> Grandchild for stress testing
+// ============================================================================
+
+#[allow(dead_code)]
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ParentParams {
+    pub num_children: u32,
+    pub grandchildren_per_child: u32,
+}
+
+#[allow(dead_code)]
+#[derive(Default)]
+pub struct ParentTask;
+
+#[async_trait]
+impl Task<()> for ParentTask {
+    fn name(&self) -> Cow<'static, str> {
+        Cow::Borrowed("bench-parent")
+    }
+    type Params = ParentParams;
+    type Output = u32;
+
+    async fn run(
+        &self,
+        params: Self::Params,
+        mut ctx: TaskContext,
+        _state: (),
+    ) -> TaskResult<Self::Output> {
+        let mut handles = Vec::new();
+        for i in 0..params.num_children {
+            let handle: TaskHandle<u32> = ctx
+                .spawn::<ChildTask>(
+                    &format!("child-{}", i),
+                    ChildParams {
+                        num_grandchildren: params.grandchildren_per_child,
+                    },
+                    SpawnOptions::default(),
+                )
+                .await?;
+            handles.push(handle);
+        }
+
+        let mut total = 0u32;
+        for handle in handles {
+            total += ctx.join(handle).await?;
+        }
+        Ok(total)
+    }
+}
+
+#[allow(dead_code)]
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ChildParams {
+    pub num_grandchildren: u32,
+}
+
+#[allow(dead_code)]
+#[derive(Default)]
+pub struct ChildTask;
+
+#[async_trait]
+impl Task<()> for ChildTask {
+    fn name(&self) -> Cow<'static, str> {
+        Cow::Borrowed("bench-child")
+    }
+    type Params = ChildParams;
+    type Output = u32;
+
+    async fn run(
+        &self,
+        params: Self::Params,
+        mut ctx: TaskContext,
+        _state: (),
+    ) -> TaskResult<Self::Output> {
+        let mut handles = Vec::new();
+        for i in 0..params.num_grandchildren {
+            let handle: TaskHandle<()> = ctx
+                .spawn::<GrandchildTask>(&format!("grandchild-{}", i), (), SpawnOptions::default())
+                .await?;
+            handles.push(handle);
+        }
+
+        for handle in handles {
+            ctx.join(handle).await?;
+        }
+        Ok(params.num_grandchildren)
+    }
+}
+
+#[allow(dead_code)]
+#[derive(Default)]
+pub struct GrandchildTask;
+
+#[async_trait]
+impl Task<()> for GrandchildTask {
+    fn name(&self) -> Cow<'static, str> {
+        Cow::Borrowed("bench-grandchild")
+    }
+    type Params = ();
+    type Output = ();
+
+    async fn run(
+        &self,
+        _params: Self::Params,
+        _ctx: TaskContext,
+        _state: (),
+    ) -> TaskResult<Self::Output> {
+        Ok(())
+    }
+}
diff --git a/sql/schema.sql b/sql/schema.sql
index a0b8d66..4d06c04 100644
--- a/sql/schema.sql
+++ b/sql/schema.sql
@@ -1051,27 +1051,23 @@ create function durable.set_task_checkpoint_state (
 as $$
 declare
   v_now timestamptz := durable.current_time();
-  v_new_attempt integer;
   v_task_state text;
+  v_attempt integer;
 begin
   if p_step_name is null or length(trim(p_step_name)) = 0 then
     raise exception 'step_name must be provided';
   end if;
 
-  -- get the current attempt number and task state
+  -- Lock task row (consistent task -> run lock order)
   execute format(
-    'select r.attempt, t.state
-       from durable.%I r
-       join durable.%I t on t.task_id = r.task_id
-      where r.run_id = $1',
-    'r_' || p_queue_name,
+    'select state from durable.%I where task_id = $1 for update',
     't_' || p_queue_name
   )
-  into v_new_attempt, v_task_state
-  using p_owner_run;
+  into v_task_state
+  using p_task_id;
 
-  if v_new_attempt is null then
-    raise exception 'Run "%" not found for checkpoint', p_owner_run;
+  if v_task_state is null then
+    raise exception 'Task "%" not found in queue "%"', p_task_id, p_queue_name;
   end if;
 
   -- if the task was cancelled raise a special error the caller can catch to terminate
@@ -1079,17 +1075,43 @@ begin
     raise exception sqlstate 'AB001' using message = 'Task has been cancelled';
   end if;
 
-  -- Extend the claim if requested
+  -- Validate lease and lock run row by conditionally updating it.
   if p_extend_claim_by is not null and p_extend_claim_by > 0 then
     execute format(
       'update durable.%I
-          set claim_expires_at = $2 + make_interval(secs => $3)
+          set claim_expires_at = $3 + make_interval(secs => $4)
+        where run_id = $1
+          and task_id = $2
+          and state = ''running''
+          and claim_expires_at is not null
+          and claim_expires_at > $3
+        returning attempt',
+      'r_' || p_queue_name
+    )
+    into v_attempt
+    using p_owner_run, p_task_id, v_now, p_extend_claim_by;
+  else
+    -- Touch row to lock it + validate lease even when not extending.
+    -- If the run has been cancelled then this row's state will be set to
+    -- 'failed' and this check will return null
+    execute format(
+      'update durable.%I
+          set claim_expires_at = claim_expires_at
         where run_id = $1
+          and task_id = $2
           and state = ''running''
-          and claim_expires_at is not null',
+          and claim_expires_at is not null
+          and claim_expires_at > $3
+        returning attempt',
       'r_' || p_queue_name
     )
-    using p_owner_run, v_now, p_extend_claim_by;
+    into v_attempt
+    using p_owner_run, p_task_id, v_now;
+  end if;
+
+  -- If the check above returns null then we shouldn't be running it anymore.
+  if v_attempt is null then
+    raise exception sqlstate 'AB002' using message = 'Task lease expired';
   end if;
 
   execute format(
@@ -1108,7 +1130,8 @@ begin
     'c_' || p_queue_name,
     'r_' || p_queue_name,
     'c_' || p_queue_name
-  ) using p_task_id, p_step_name, p_state, p_owner_run, v_now, v_new_attempt;
+  )
+  using p_task_id, p_step_name, p_state, p_owner_run, v_now, v_attempt;
 end;
 $$;
 
@@ -1123,22 +1146,26 @@ create function durable.extend_claim (
 as $$
 declare
   v_now timestamptz := durable.current_time();
-  v_extend_by integer;
-  v_claim_timeout integer;
-  v_rows_updated integer;
+  v_task_id uuid;
   v_task_state text;
+  v_attempt integer;
 begin
+  -- Lock task row first (consistent task -> run lock order)
   execute format(
-    'select t.state
-       from durable.%I r
-       join durable.%I t on t.task_id = r.task_id
-      where r.run_id = $1',
-    'r_' || p_queue_name,
-    't_' || p_queue_name
+    'select task_id, state
+       from durable.%I
+      where task_id = (select task_id from durable.%I where run_id = $1)
+        for update',
+    't_' || p_queue_name,
+    'r_' || p_queue_name
   )
-  into v_task_state
+  into v_task_id, v_task_state
   using p_run_id;
 
+  if v_task_state is null then
+    raise exception 'Run "%" not found in queue "%"', p_run_id, p_queue_name;
+  end if;
+
   if v_task_state = 'cancelled' then
     raise exception sqlstate 'AB001' using message = 'Task has been cancelled';
   end if;
@@ -1147,11 +1174,19 @@ begin
     'update durable.%I
         set claim_expires_at = $2 + make_interval(secs => $3)
       where run_id = $1
+        and task_id = $4
         and state = ''running''
-        and claim_expires_at is not null',
+        and claim_expires_at is not null
+        and claim_expires_at > $2
+      returning attempt',
     'r_' || p_queue_name
   )
-  using p_run_id, v_now, p_extend_by;
+  into v_attempt
+  using p_run_id, v_now, p_extend_by, v_task_id;
+
+  if v_attempt is null then
+    raise exception sqlstate 'AB002' using message = 'Task lease expired';
+  end if;
 end;
 $$;
 
diff --git a/src/error.rs b/src/error.rs
index b5a69dc..d238bbe 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -19,6 +19,12 @@ pub enum ControlFlow {
     /// Detected when database operations return error code AB001, indicating
     /// the task was cancelled via [`Durable::cancel_task`](crate::Durable::cancel_task).
     Cancelled,
+    /// Task lease expired (claim lost).
+    ///
+    /// Detected when database operations return error code AB002. Treated as control
+    /// flow to avoid double-failing runs that were already failed by `claim_task`,
+    /// and to let the next claim sweep fail the run if it hasn't happened yet.
+    LeaseExpired,
 }
 
 /// Error type for task execution.
@@ -28,6 +34,7 @@ pub enum ControlFlow {
 ///
 /// - `Control(Suspend)` - Task is waiting; worker does nothing (scheduler will resume it)
 /// - `Control(Cancelled)` - Task was cancelled; worker does nothing
+/// - `Control(LeaseExpired)` - Task lost its lease; worker stops without failing the run
 /// - All other variants - Actual errors; worker records failure and may retry
 ///
 /// # Example
@@ -218,6 +225,8 @@ impl TaskError {
     pub(crate) fn from_sqlx_error(err: sqlx::Error) -> Self {
         if is_cancelled_error(&err) {
             TaskError::Control(ControlFlow::Cancelled)
+        } else if is_lease_expired_error(&err) {
+            TaskError::Control(ControlFlow::LeaseExpired)
         } else {
             TaskError::Database(err)
         }
@@ -233,6 +242,15 @@ pub fn is_cancelled_error(err: &sqlx::Error) -> bool {
     }
 }
 
+/// Check if a sqlx error indicates lease expiration (error code AB002)
+pub fn is_lease_expired_error(err: &sqlx::Error) -> bool {
+    if let sqlx::Error::Database(db_err) = err {
+        db_err.code().is_some_and(|c| c == "AB002")
+    } else {
+        false
+    }
+}
+
 /// Serialize a TaskError for storage in fail_run
 pub fn serialize_task_error(err: &TaskError) -> JsonValue {
     match err {
diff --git a/src/postgres/migrations/20260129120000_enforce_lease_on_checkpoint.sql b/src/postgres/migrations/20260129120000_enforce_lease_on_checkpoint.sql
new file mode 100644
index 0000000..6842e04
--- /dev/null
+++ b/src/postgres/migrations/20260129120000_enforce_lease_on_checkpoint.sql
@@ -0,0 +1,152 @@
+-- Enforce lease validity on checkpoint/heartbeat by raising AB002 when expired.
+
+create or replace function durable.set_task_checkpoint_state (
+  p_queue_name text,
+  p_task_id uuid,
+  p_step_name text,
+  p_state jsonb,
+  p_owner_run uuid,
+  p_extend_claim_by integer default null -- seconds
+)
+  returns void
+  language plpgsql
+as $$
+declare
+  v_now timestamptz := durable.current_time();
+  v_task_state text;
+  v_attempt integer;
+begin
+  if p_step_name is null or length(trim(p_step_name)) = 0 then
+    raise exception 'step_name must be provided';
+  end if;
+
+  -- Lock task row (consistent task -> run lock order)
+  execute format(
+    'select state from durable.%I where task_id = $1 for update',
+    't_' || p_queue_name
+  )
+  into v_task_state
+  using p_task_id;
+
+  if v_task_state is null then
+    raise exception 'Task "%" not found in queue "%"', p_task_id, p_queue_name;
+  end if;
+
+  -- if the task was cancelled raise a special error the caller can catch to terminate
+  if v_task_state = 'cancelled' then
+    raise exception sqlstate 'AB001' using message = 'Task has been cancelled';
+  end if;
+
+  -- Validate lease and lock run row by conditionally updating it.
+  if p_extend_claim_by is not null and p_extend_claim_by > 0 then
+    execute format(
+      'update durable.%I
+          set claim_expires_at = $3 + make_interval(secs => $4)
+        where run_id = $1
+          and task_id = $2
+          and state = ''running''
+          and claim_expires_at is not null
+          and claim_expires_at > $3
+        returning attempt',
+      'r_' || p_queue_name
+    )
+    into v_attempt
+    using p_owner_run, p_task_id, v_now, p_extend_claim_by;
+  else
+    -- Touch row to lock it + validate lease even when not extending.
+    -- If the run has been cancelled then this row's state will be set to
+    -- 'failed' and this check will return null
+    execute format(
+      'update durable.%I
+          set claim_expires_at = claim_expires_at
+        where run_id = $1
+          and task_id = $2
+          and state = ''running''
+          and claim_expires_at is not null
+          and claim_expires_at > $3
+        returning attempt',
+      'r_' || p_queue_name
+    )
+    into v_attempt
+    using p_owner_run, p_task_id, v_now;
+  end if;
+
+  -- If the check above returns null then we shouldn't be running it anymore.
+  if v_attempt is null then
+    raise exception sqlstate 'AB002' using message = 'Task lease expired';
+  end if;
+
+  execute format(
+    'insert into durable.%I (task_id, checkpoint_name, state, owner_run_id, updated_at)
+     values ($1, $2, $3, $4, $5)
+     on conflict (task_id, checkpoint_name)
+     do update set state = excluded.state,
+                   owner_run_id = excluded.owner_run_id,
+                   updated_at = excluded.updated_at
+     where $6 >= coalesce(
+       (select r.attempt
+          from durable.%I r
+         where r.run_id = durable.%I.owner_run_id),
+       $6
+     )',
+    'c_' || p_queue_name,
+    'r_' || p_queue_name,
+    'c_' || p_queue_name
+  )
+  using p_task_id, p_step_name, p_state, p_owner_run, v_now, v_attempt;
+end;
+$$;
+
+create or replace function durable.extend_claim (
+  p_queue_name text,
+  p_run_id uuid,
+  p_extend_by integer
+)
+  returns void
+  language plpgsql
+as $$
+declare
+  v_now timestamptz := durable.current_time();
+  v_task_id uuid;
+  v_task_state text;
+  v_attempt integer;
+begin
+  -- Lock task row first (consistent task -> run lock order)
+  execute format(
+    'select task_id, state
+       from durable.%I
+      where task_id = (select task_id from durable.%I where run_id = $1)
+        for update',
+    't_' || p_queue_name,
+    'r_' || p_queue_name
+  )
+  into v_task_id, v_task_state
+  using p_run_id;
+
+  if v_task_state is null then
+    raise exception 'Run "%" not found in queue "%"', p_run_id, p_queue_name;
+  end if;
+
+  if v_task_state = 'cancelled' then
+    raise exception sqlstate 'AB001' using message = 'Task has been cancelled';
+  end if;
+
+  execute format(
+    'update durable.%I
+        set claim_expires_at = $2 + make_interval(secs => $3)
+      where run_id = $1
+        and task_id = $4
+        and state = ''running''
+        and claim_expires_at is not null
+        and claim_expires_at > $2
+      returning attempt',
+    'r_' || p_queue_name
+  )
+  into v_attempt
+  using p_run_id, v_now, p_extend_by, v_task_id;
+
+  if v_attempt is null then
+    raise exception sqlstate 'AB002' using message = 'Task lease expired';
+  end if;
+end;
+$$;
diff --git a/src/worker.rs b/src/worker.rs
index 0da6563..2fa7695 100644
--- a/src/worker.rs
+++ b/src/worker.rs
@@ -498,6 +498,14 @@ impl Worker {
                 }
                 tracing::info!("Task {} was cancelled", task_label);
             }
+            Err(TaskError::Control(ControlFlow::LeaseExpired)) => {
+                // Lease expired - stop execution without double-failing the run.
+                #[cfg(feature = "telemetry")]
+                {
+                    outcome = "lease_expired";
+                }
+                tracing::warn!("Task {} lease expired", task_label);
+            }
             Err(ref e) => {
                 // All other errors are failures (Timeout, Database, Serialization, ChildFailed, etc.)
                 #[cfg(feature = "telemetry")]
diff --git a/tests/common/tasks.rs b/tests/common/tasks.rs
index 9b84b7c..2005b1b 100644
--- a/tests/common/tasks.rs
+++ b/tests/common/tasks.rs
@@ -1057,6 +1057,43 @@ impl Task<()> for SlowNoHeartbeatTask {
     }
 }
 
+// ============================================================================
+// SleepThenCheckpointTask - Sleep past lease then attempt a checkpoint
+// ============================================================================
+
+#[allow(dead_code)]
+#[derive(Default)]
+pub struct SleepThenCheckpointTask;
+
+#[allow(dead_code)]
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct SleepThenCheckpointParams {
+    /// Duration to sleep in milliseconds (should be > claim_timeout)
+    pub sleep_ms: u64,
+}
+
+#[async_trait]
+impl Task<()> for SleepThenCheckpointTask {
+    fn name(&self) -> Cow<'static, str> {
+        Cow::Borrowed("sleep-then-checkpoint")
+    }
+    type Params = SleepThenCheckpointParams;
+    type Output = String;
+
+    async fn run(
+        &self,
+        params: Self::Params,
+        mut ctx: TaskContext,
+        _state: (),
+    ) -> TaskResult<Self::Output> {
+        tokio::time::sleep(std::time::Duration::from_millis(params.sleep_ms)).await;
+        let _: String = ctx
+            .step("after-sleep", (), |_, _| async { Ok("ok".to_string()) })
+            .await?;
+        Ok("done".to_string())
+    }
+}
+
 // ============================================================================
 // CheckpointReplayTask - Tracks execution count via external counter
 // ============================================================================
diff --git a/tests/lease_test.rs b/tests/lease_test.rs
index 3cbd81b..7e0a36d 100644
--- a/tests/lease_test.rs
+++ b/tests/lease_test.rs
@@ -3,10 +3,14 @@
 mod common;
 
 use common::helpers::{
-    get_claim_expires_at, get_last_run_id, get_task_state, set_fake_time, wait_for_task_terminal,
+    get_checkpoint_count, get_claim_expires_at, get_last_run_id, get_task_state, set_fake_time,
+    wait_for_task_state, wait_for_task_terminal,
 };
-use common::tasks::{LongRunningHeartbeatParams, LongRunningHeartbeatTask};
-use durable::{Durable, MIGRATOR, WorkerOptions};
+use common::tasks::{
+    LongRunningHeartbeatParams, LongRunningHeartbeatTask, SleepThenCheckpointParams,
+    SleepThenCheckpointTask,
+};
+use durable::{Durable, MIGRATOR, RetryStrategy, SpawnOptions, WorkerOptions};
 use sqlx::PgPool;
 use std::time::Duration;
 
@@ -228,6 +232,157 @@ async fn test_checkpoint_extends_lease(pool: PgPool) -> sqlx::Result<()> {
     Ok(())
 }
 
+/// Test that checkpoints are rejected once the lease has expired.
+#[sqlx::test(migrator = "MIGRATOR")]
+async fn test_checkpoint_rejected_after_lease_expired(pool: PgPool) -> sqlx::Result<()> {
+    let client = create_client(pool.clone(), "lease_expired_ckpt").await;
+    client.create_queue(None).await.unwrap();
+    client.register::<SleepThenCheckpointTask>().await.unwrap();
+
+    let claim_timeout = Duration::from_secs(1);
+
+    let spawn_result = client
+        .spawn_with_options::<SleepThenCheckpointTask>(
+            SleepThenCheckpointParams { sleep_ms: 1500 },
+            {
+                let mut opts = SpawnOptions::default();
+                opts.retry_strategy = Some(RetryStrategy::Fixed {
+                    base_delay: Duration::from_secs(0),
+                });
+                opts.max_attempts = Some(1);
+                opts
+            },
+        )
+        .await
+        .expect("Failed to spawn task");
+
+    let worker1 = client
+        .start_worker(WorkerOptions {
+            poll_interval: Duration::from_millis(50),
+            claim_timeout,
+            ..Default::default()
+        })
+        .await
+        .unwrap();
+
+    let running = wait_for_task_state(
+        &pool,
+        "lease_expired_ckpt",
+        spawn_result.task_id,
+        "running",
+        Duration::from_secs(2),
+    )
+    .await?;
+    assert!(running, "Task should enter running state");
+
+    // Wait for lease to expire but before the task wakes to checkpoint.
+    tokio::time::sleep(claim_timeout + Duration::from_millis(200)).await;
+
+    // Second worker polls to fail expired lease.
+    let worker2 = client
+        .start_worker(WorkerOptions {
+            poll_interval: Duration::from_millis(50),
+            claim_timeout: Duration::from_secs(5),
+            ..Default::default()
+        })
+        .await
+        .unwrap();
+
+    let terminal = wait_for_task_terminal(
+        &pool,
+        "lease_expired_ckpt",
+        spawn_result.task_id,
+        Duration::from_secs(5),
+    )
+    .await?;
+    assert_eq!(terminal, Some("failed".to_string()));
+
+    let checkpoint_count =
+        get_checkpoint_count(&pool, "lease_expired_ckpt", spawn_result.task_id).await?;
+    assert_eq!(
+        checkpoint_count, 0,
+        "Checkpoint should not be written after lease expiry"
+    );
+
+    worker2.shutdown().await;
+    worker1.shutdown().await;
+
+    Ok(())
+}
+
+/// Test that checkpoints are rejected once the lease has expired - single worker variant.
+/// Unlike the multi-worker test, this verifies that a single worker's poll loop
+/// can detect and fail the expired run after the task stops due to lease expiration.
+#[sqlx::test(migrator = "MIGRATOR")]
+async fn test_checkpoint_rejected_after_lease_expired_single_worker(
+    pool: PgPool,
+) -> sqlx::Result<()> {
+    let client = create_client(pool.clone(), "lease_expired_single").await;
+    client.create_queue(None).await.unwrap();
+    client.register::<SleepThenCheckpointTask>().await.unwrap();
+
+    let claim_timeout = Duration::from_secs(1);
+
+    let spawn_result = client
+        .spawn_with_options::<SleepThenCheckpointTask>(
+            SleepThenCheckpointParams { sleep_ms: 1500 },
+            {
+                let mut opts = SpawnOptions::default();
+                opts.retry_strategy = Some(RetryStrategy::Fixed {
+                    base_delay: Duration::from_secs(0),
+                });
+                opts.max_attempts = Some(1);
+                opts
+            },
+        )
+        .await
+        .expect("Failed to spawn task");
+
+    // Single worker handles both running the task and detecting the expired lease
+    let worker = client
+        .start_worker(WorkerOptions {
+            poll_interval: Duration::from_millis(50),
+            claim_timeout,
+            ..Default::default()
+        })
+        .await
+        .unwrap();
+
+    let running = wait_for_task_state(
+        &pool,
+        "lease_expired_single",
+        spawn_result.task_id,
+        "running",
+        Duration::from_secs(2),
+    )
+    .await?;
+    assert!(running, "Task should enter running state");
+
+    // Wait for:
+    // 1. Lease to expire (1s)
+    // 2. Task to wake and attempt checkpoint (at 1.5s) - gets AB002 error
+    // 3. Worker's next poll to claim and fail the expired run
+    let terminal = wait_for_task_terminal(
+        &pool,
+        "lease_expired_single",
+        spawn_result.task_id,
+        Duration::from_secs(5),
+    )
+    .await?;
+    assert_eq!(terminal, Some("failed".to_string()));
+
+    let checkpoint_count =
+        get_checkpoint_count(&pool, "lease_expired_single", spawn_result.task_id).await?;
+    assert_eq!(
+        checkpoint_count, 0,
+        "Checkpoint should not be written after lease expiry"
+    );
+
+    worker.shutdown().await;
+
+    Ok(())
+}
+
 /// Test that heartbeat detects if the task has been cancelled.
 #[sqlx::test(migrator = "MIGRATOR")]
 async fn test_heartbeat_detects_cancellation(pool: PgPool) -> sqlx::Result<()> {