diff --git a/contributing/PIPELINES.md b/contributing/PIPELINES.md index e41949b4f..7e01c2465 100644 --- a/contributing/PIPELINES.md +++ b/contributing/PIPELINES.md @@ -92,6 +92,10 @@ It's ok not to force all pipelines into one exact shape. When writing processing results, update the main row with a filter by both `id` and `lock_token`. This guarantees that only the worker that still owns the lock can apply its results. If the update affects no rows, treat the item as stale and skip applying other changes (status changes, related updates, events). A stale item means another worker or replica already continued processing. +**Locking related resources before refetch** + +If you first refetch a main resource and only afterwards lock the related resources, you need to ensure the worker doesn't get a stale view of the related resources, or that it works properly even in this case. It's often more robust to first lock related resources and then refetch the main resource with related resources already locked. + **Locking many related resources** A pipeline may need to lock a potentially big set of related resource, e.g. fleet pipeline locking all fleet's instances. For this, do one SELECT FOR UPDATE of non-locked instances and one SELECT to see how many instances there are, and check if you managed to lock all of them. If fail to lock, release the main lock and try processing on another fetch iteration. You may keep `lock_owner` on the main resource or set `lock_owner` on locked related resource and make other pipelines respect that to guarantee the eventual locking of all related resources and avoid lock starvation. 
diff --git a/src/dstack/_internal/server/background/pipeline_tasks/__init__.py b/src/dstack/_internal/server/background/pipeline_tasks/__init__.py index a48aabbad..f0b9c7be3 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/__init__.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/__init__.py @@ -15,6 +15,7 @@ from dstack._internal.server.background.pipeline_tasks.placement_groups import ( PlacementGroupPipeline, ) +from dstack._internal.server.background.pipeline_tasks.runs import RunPipeline from dstack._internal.server.background.pipeline_tasks.volumes import VolumePipeline from dstack._internal.utils.logging import get_logger @@ -32,6 +33,7 @@ def __init__(self) -> None: JobTerminatingPipeline(), InstancePipeline(), PlacementGroupPipeline(), + RunPipeline(), VolumePipeline(), ] self._hinter = PipelineHinter(self._pipelines) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/fleets.py b/src/dstack/_internal/server/background/pipeline_tasks/fleets.py index 51f4230a8..4ba94a103 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/fleets.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/fleets.py @@ -62,7 +62,7 @@ def __init__( workers_num: int = 10, queue_lower_limit_factor: float = 0.5, queue_upper_limit_factor: float = 2.0, - min_processing_interval: timedelta = timedelta(seconds=60), + min_processing_interval: timedelta = timedelta(seconds=30), lock_timeout: timedelta = timedelta(seconds=20), heartbeat_trigger: timedelta = timedelta(seconds=10), ) -> None: @@ -199,19 +199,13 @@ async def process(self, item: PipelineItem): process_context = await _load_process_context(item) if process_context is None: return - result = await _process_fleet( - process_context.fleet_model, - consolidation_fleet_spec=process_context.consolidation_fleet_spec, - consolidation_instances=process_context.consolidation_instances, - ) + result = await _process_fleet(process_context.fleet_model) await 
_apply_process_result(item, process_context, result) @dataclass class _ProcessContext: fleet_model: FleetModel - consolidation_fleet_spec: Optional[FleetSpec] - consolidation_instances: Optional[list[InstanceModel]] locked_instance_ids: set[uuid.UUID] = field(default_factory=set) @@ -260,34 +254,64 @@ def has_changes(self) -> bool: async def _load_process_context(item: PipelineItem) -> Optional[_ProcessContext]: async with get_session_ctx() as session: - fleet_model = await _refetch_locked_fleet(session=session, item=item) + fleet_model = await _refetch_locked_fleet_for_lock_decision(session=session, item=item) if fleet_model is None: log_lock_token_mismatch(logger, item) return None - consolidation_fleet_spec = _get_fleet_spec_if_ready_for_consolidation(fleet_model) - consolidation_instances = None - if consolidation_fleet_spec is not None: - consolidation_instances = await _lock_fleet_instances_for_consolidation( - session=session, - item=item, - ) - if consolidation_instances is None: - return None + locked_instance_ids = await _lock_fleet_instances_for_processing( + session=session, + item=item, + fleet_model=fleet_model, + ) + if locked_instance_ids is None: + return None + + fleet_model = await _refetch_locked_fleet_for_processing(session=session, item=item) + if fleet_model is None: + log_lock_token_mismatch(logger, item) + if locked_instance_ids: + await _unlock_fleet_locked_instances( + session=session, + item=item, + locked_instance_ids=locked_instance_ids, + ) + await session.commit() + return None return _ProcessContext( fleet_model=fleet_model, - consolidation_fleet_spec=consolidation_fleet_spec, - consolidation_instances=consolidation_instances, - locked_instance_ids=( - set() - if consolidation_instances is None - else {i.id for i in consolidation_instances} - ), + locked_instance_ids=locked_instance_ids, ) -async def _refetch_locked_fleet( +async def _refetch_locked_fleet_for_lock_decision( + session: AsyncSession, + item: PipelineItem, +) -> 
Optional[FleetModel]: + res = await session.execute( + select(FleetModel) + .where( + FleetModel.id == item.id, + FleetModel.lock_token == item.lock_token, + ) + .options( + load_only( + FleetModel.id, + FleetModel.status, + FleetModel.spec, + FleetModel.current_master_instance_id, + FleetModel.consolidation_attempt, + FleetModel.last_consolidated_at, + FleetModel.last_processed_at, + ) + ) + .execution_options(populate_existing=True) + ) + return res.unique().scalar_one_or_none() + + +async def _refetch_locked_fleet_for_processing( session: AsyncSession, item: PipelineItem, ) -> Optional[FleetModel]: @@ -308,6 +332,7 @@ async def _refetch_locked_fleet( FleetModel.runs.and_(RunModel.status.not_in(RunStatus.finished_statuses())) ).load_only(RunModel.status) ) + .execution_options(populate_existing=True) ) return res.unique().scalar_one_or_none() @@ -326,10 +351,17 @@ def _get_fleet_spec_if_ready_for_consolidation(fleet_model: FleetModel) -> Optio return consolidation_fleet_spec -async def _lock_fleet_instances_for_consolidation( +async def _lock_fleet_instances_for_processing( session: AsyncSession, item: PipelineItem, -) -> Optional[list[InstanceModel]]: + fleet_model: FleetModel, +) -> Optional[set[uuid.UUID]]: + if _get_fleet_spec_if_ready_for_consolidation(fleet_model) is None: + if fleet_model.current_master_instance_id is None: + return set() + if not _is_cloud_cluster_fleet_spec(get_fleet_spec(fleet_model)): + return set() + instance_lock, _ = get_locker(get_db().dialect_name).get_lockset(InstanceModel.__tablename__) async with instance_lock: res = await session.execute( @@ -347,6 +379,7 @@ async def _lock_fleet_instances_for_consolidation( ), ) .with_for_update(skip_locked=True, key_share=True, of=InstanceModel) + .options(load_only(InstanceModel.id)) ) locked_instance_models = list(res.scalars().all()) locked_instance_ids = {instance_model.id for instance_model in locked_instance_models} @@ -389,7 +422,7 @@ async def _lock_fleet_instances_for_consolidation( 
instance_model.lock_token = item.lock_token instance_model.lock_owner = FleetPipeline.__name__ await session.commit() - return locked_instance_models + return locked_instance_ids async def _apply_process_result( @@ -461,30 +494,29 @@ async def _apply_process_result( async def _process_fleet( fleet_model: FleetModel, - consolidation_fleet_spec: Optional[FleetSpec] = None, - consolidation_instances: Optional[Sequence[InstanceModel]] = None, ) -> _ProcessResult: result = _ProcessResult() - effective_instances = list(consolidation_instances or fleet_model.instances) + consolidation_fleet_spec = _get_fleet_spec_if_ready_for_consolidation(fleet_model) if consolidation_fleet_spec is not None: result = _consolidate_fleet_state_with_spec( fleet_model, consolidation_fleet_spec=consolidation_fleet_spec, - consolidation_instances=effective_instances, + consolidation_instances=fleet_model.instances, ) if len(result.new_instance_creates) == 0 and _should_delete_fleet(fleet_model): result.fleet_update_map["status"] = FleetStatus.TERMINATED result.fleet_update_map["deleted"] = True result.fleet_update_map["deleted_at"] = NOW_PLACEHOLDER + return result _set_fail_instances_on_master_bootstrap_failure( fleet_model=fleet_model, - instance_models=effective_instances, + instance_models=fleet_model.instances, instance_id_to_update_map=result.instance_id_to_update_map, ) _set_current_master_instance_id( fleet_model=fleet_model, fleet_update_map=result.fleet_update_map, - instance_models=effective_instances, + instance_models=fleet_model.instances, instance_id_to_update_map=result.instance_id_to_update_map, new_instance_creates=result.new_instance_creates, ) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/instances/__init__.py b/src/dstack/_internal/server/background/pipeline_tasks/instances/__init__.py index 67f017619..8c31cd537 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/instances/__init__.py +++ 
b/src/dstack/_internal/server/background/pipeline_tasks/instances/__init__.py @@ -40,6 +40,7 @@ ) from dstack._internal.server.db import get_db, get_session_ctx from dstack._internal.server.models import ( + FleetModel, InstanceHealthCheckModel, InstanceModel, JobModel, @@ -147,6 +148,7 @@ async def fetch(self, limit: int) -> list[InstancePipelineItem]: now = get_current_datetime() res = await session.execute( select(InstanceModel) + .join(InstanceModel.fleet, isouter=True) .where( InstanceModel.status.in_( [ @@ -164,6 +166,11 @@ async def fetch(self, limit: int) -> list[InstancePipelineItem]: ) ), InstanceModel.deleted == False, + or_( + # Do not try to lock instances if the fleet is waiting for the lock. + InstanceModel.fleet_id.is_(None), + FleetModel.lock_owner.is_(None), + ), or_( InstanceModel.last_processed_at <= now - self._min_processing_interval, InstanceModel.last_processed_at == InstanceModel.created_at, @@ -239,6 +246,9 @@ async def process(self, item: InstancePipelineItem): if process_context is None: return + # Keep apply centralized here because every instance path returns the same + # `ProcessResult` shape for one primary model, with only a small set of + # optional side effects such as health checks or placement-group scheduling. await _apply_process_result( item=item, instance_model=process_context.instance_model, diff --git a/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py b/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py index 590b9907b..5e061ee77 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py @@ -197,6 +197,8 @@ async def fetch(self, limit: int) -> list[JobRunningPipelineItem]: ), RunModel.status.not_in([RunStatus.TERMINATING]), JobModel.last_processed_at <= now - self._min_processing_interval, + # Do not try to lock jobs if the run is waiting for the lock. 
+ RunModel.lock_owner.is_(None), or_( JobModel.lock_expires_at.is_(None), JobModel.lock_expires_at < now, diff --git a/src/dstack/_internal/server/background/pipeline_tasks/jobs_submitted.py b/src/dstack/_internal/server/background/pipeline_tasks/jobs_submitted.py index 9e93e7b2a..0d79d4bae 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/jobs_submitted.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/jobs_submitted.py @@ -243,6 +243,8 @@ async def fetch(self, limit: int) -> list[JobSubmittedPipelineItem]: JobModel.last_processed_at <= now - self._min_processing_interval, JobModel.last_processed_at == JobModel.submitted_at, ), + # Do not try to lock jobs if the run is waiting for the lock. + RunModel.lock_owner.is_(None), or_( JobModel.lock_expires_at.is_(None), JobModel.lock_expires_at < now, diff --git a/src/dstack/_internal/server/background/pipeline_tasks/jobs_terminating.py b/src/dstack/_internal/server/background/pipeline_tasks/jobs_terminating.py index 5f1a0f74b..2925dd844 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/jobs_terminating.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/jobs_terminating.py @@ -582,7 +582,7 @@ async def _process_terminating_job( """ Stops the job: tells shim to stop the container, detaches the job from the instance, and detaches volumes from the instance. - Graceful stop should already be done by `process_terminating_run`. + Graceful stop should already be done by the run terminating path. 
""" instance_update_map = None if instance_model is None else _InstanceUpdateMap() result = _ProcessResult(instance_update_map=instance_update_map) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/runs/__init__.py b/src/dstack/_internal/server/background/pipeline_tasks/runs/__init__.py new file mode 100644 index 000000000..5c3ec260b --- /dev/null +++ b/src/dstack/_internal/server/background/pipeline_tasks/runs/__init__.py @@ -0,0 +1,939 @@ +import asyncio +import uuid +from dataclasses import dataclass +from datetime import timedelta +from typing import Optional, Sequence + +from sqlalchemy import and_, func, or_, select, update +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import aliased, contains_eager, joinedload, load_only + +import dstack._internal.server.background.pipeline_tasks.runs.active as active +import dstack._internal.server.background.pipeline_tasks.runs.pending as pending +import dstack._internal.server.background.pipeline_tasks.runs.terminating as terminating +from dstack._internal.core.models.runs import JobStatus, RunStatus +from dstack._internal.server.background.pipeline_tasks.base import ( + Fetcher, + Heartbeater, + Pipeline, + PipelineItem, + Worker, + log_lock_token_changed_after_processing, + log_lock_token_changed_on_reset, + log_lock_token_mismatch, + resolve_now_placeholders, + set_processed_update_map_fields, + set_unlock_update_map_fields, +) +from dstack._internal.server.db import get_db, get_session_ctx +from dstack._internal.server.models import InstanceModel, JobModel, ProjectModel, RunModel +from dstack._internal.server.services import events +from dstack._internal.server.services.gateways import get_or_add_gateway_connection +from dstack._internal.server.services.jobs import emit_job_status_change_event +from dstack._internal.server.services.locking import get_locker +from dstack._internal.server.services.prometheus.client_metrics import run_metrics +from 
dstack._internal.server.services.runs import emit_run_status_change_event, get_run_spec +from dstack._internal.server.services.secrets import get_project_secrets_mapping +from dstack._internal.server.utils import sentry_utils +from dstack._internal.utils.common import get_current_datetime +from dstack._internal.utils.logging import get_logger + +logger = get_logger(__name__) + +# No need to lock finished or terminating jobs since run processing does not update them. +JOB_STATUSES_EXCLUDED_FOR_LOCKING = JobStatus.finished_statuses() + [JobStatus.TERMINATING] + + +@dataclass +class RunPipelineItem(PipelineItem): + status: RunStatus + + +class RunPipeline(Pipeline[RunPipelineItem]): + def __init__( + self, + workers_num: int = 10, + queue_lower_limit_factor: float = 0.5, + queue_upper_limit_factor: float = 2.0, + min_processing_interval: timedelta = timedelta(seconds=10), + lock_timeout: timedelta = timedelta(seconds=30), + heartbeat_trigger: timedelta = timedelta(seconds=15), + ) -> None: + super().__init__( + workers_num=workers_num, + queue_lower_limit_factor=queue_lower_limit_factor, + queue_upper_limit_factor=queue_upper_limit_factor, + min_processing_interval=min_processing_interval, + lock_timeout=lock_timeout, + heartbeat_trigger=heartbeat_trigger, + ) + self.__heartbeater = Heartbeater[RunPipelineItem]( + model_type=RunModel, + lock_timeout=self._lock_timeout, + heartbeat_trigger=self._heartbeat_trigger, + ) + self.__fetcher = RunFetcher( + queue=self._queue, + queue_desired_minsize=self._queue_desired_minsize, + min_processing_interval=self._min_processing_interval, + lock_timeout=self._lock_timeout, + heartbeater=self._heartbeater, + ) + self.__workers = [ + RunWorker(queue=self._queue, heartbeater=self._heartbeater) + for _ in range(self._workers_num) + ] + + @property + def hint_fetch_model_name(self) -> str: + return RunModel.__name__ + + @property + def _heartbeater(self) -> Heartbeater[RunPipelineItem]: + return self.__heartbeater + + @property + def 
_fetcher(self) -> Fetcher[RunPipelineItem]: + return self.__fetcher + + @property + def _workers(self) -> Sequence["RunWorker"]: + return self.__workers + + +class RunFetcher(Fetcher[RunPipelineItem]): + def __init__( + self, + queue: asyncio.Queue[RunPipelineItem], + queue_desired_minsize: int, + min_processing_interval: timedelta, + lock_timeout: timedelta, + heartbeater: Heartbeater[RunPipelineItem], + queue_check_delay: float = 1.0, + ) -> None: + super().__init__( + queue=queue, + queue_desired_minsize=queue_desired_minsize, + min_processing_interval=min_processing_interval, + lock_timeout=lock_timeout, + heartbeater=heartbeater, + queue_check_delay=queue_check_delay, + ) + + @sentry_utils.instrument_named_task("pipeline_tasks.RunFetcher.fetch") + async def fetch(self, limit: int) -> list[RunPipelineItem]: + if limit <= 0: + return [] + + run_lock, _ = get_locker(get_db().dialect_name).get_lockset(RunModel.__tablename__) + async with run_lock: + async with get_session_ctx() as session: + now = get_current_datetime() + res = await session.execute( + select(RunModel) + .where( + # Filter out runs that do not need processing. + # This is only to reduce unnecessary fetch/apply churn. + # Otherwise, we could fetch all active runs and filter them in the worker. + or_( + # Active non-pending runs. + RunModel.status.not_in( + RunStatus.finished_statuses() + [RunStatus.PENDING] + ), + # Retrying runs. + and_( + RunModel.status == RunStatus.PENDING, + RunModel.resubmission_attempt > 0, + ), + # Scheduled ready runs. + and_( + RunModel.status == RunStatus.PENDING, + RunModel.resubmission_attempt == 0, + RunModel.next_triggered_at.is_not(None), + RunModel.next_triggered_at < now, + ), + # Scaled-to-zero runs. + # Such runs cannot be scheduled, so we detect them via + # `next_triggered_at is None`. + # If scheduled services ever support downscaling to zero, + # this selector must be revisited. 
+ and_( + RunModel.status == RunStatus.PENDING, + RunModel.resubmission_attempt == 0, + RunModel.next_triggered_at.is_(None), + ), + ), + or_( + RunModel.last_processed_at <= now - self._min_processing_interval, + RunModel.last_processed_at == RunModel.submitted_at, + ), + or_( + RunModel.lock_expires_at.is_(None), + RunModel.lock_expires_at < now, + ), + or_( + RunModel.lock_owner.is_(None), + RunModel.lock_owner == RunPipeline.__name__, + ), + ) + .order_by(RunModel.last_processed_at.asc()) + .limit(limit) + .with_for_update(skip_locked=True, key_share=True, of=RunModel) + .options( + load_only( + RunModel.id, + RunModel.lock_token, + RunModel.lock_expires_at, + RunModel.status, + ) + ) + ) + run_models = list(res.scalars().all()) + lock_expires_at = get_current_datetime() + self._lock_timeout + lock_token = uuid.uuid4() + items = [] + for run_model in run_models: + prev_lock_expired = run_model.lock_expires_at is not None + run_model.lock_expires_at = lock_expires_at + run_model.lock_token = lock_token + run_model.lock_owner = RunPipeline.__name__ + items.append( + RunPipelineItem( + __tablename__=RunModel.__tablename__, + id=run_model.id, + lock_expires_at=lock_expires_at, + lock_token=lock_token, + prev_lock_expired=prev_lock_expired, + status=run_model.status, + ) + ) + await session.commit() + return items + + +class RunWorker(Worker[RunPipelineItem]): + def __init__( + self, + queue: asyncio.Queue[RunPipelineItem], + heartbeater: Heartbeater[RunPipelineItem], + ) -> None: + super().__init__(queue=queue, heartbeater=heartbeater) + + @sentry_utils.instrument_named_task("pipeline_tasks.RunWorker.process") + async def process(self, item: RunPipelineItem): + # Currently `dstack` supports runs with + # * one multi-node replica (multi-node tasks) + # * or multiple single-node replicas (services) + # The multiple multi-node replica is not supported but the most of the processing logic + # is written to be able to handle this generic case. 
+ # + # Different run stats have completely separate load/process/apply phases + # due to distinct processing flows and related-row requirements. + if item.status == RunStatus.PENDING: + await _process_pending_item(item) + return + if item.status in { + RunStatus.SUBMITTED, + RunStatus.PROVISIONING, + RunStatus.RUNNING, + }: + await _process_active_item(item) + return + if item.status == RunStatus.TERMINATING: + await _process_terminating_item(item) + return + + logger.error("Skipping run %s with unexpected status %s", item.id, item.status) + + +async def _process_pending_item(item: RunPipelineItem) -> None: + async with get_session_ctx() as session: + context = await _load_pending_context(session=session, item=item) + if context is None: + return + + result = await pending.process_pending_run(context) + if result is None: + await _apply_noop_result( + item=item, + locked_job_ids=context.locked_job_ids, + ) + return + + await _apply_pending_result(item=item, context=context, result=result) + + +async def _load_pending_context( + session: AsyncSession, + item: RunPipelineItem, +) -> Optional[pending.PendingContext]: + locked_job_ids = await _lock_related_jobs(session=session, item=item) + if locked_job_ids is None: + return None + run_model = await _refetch_locked_run_for_pending(session=session, item=item) + if run_model is None: + log_lock_token_mismatch(logger, item) + await _unlock_related_jobs( + session=session, + item=item, + locked_job_ids=locked_job_ids, + ) + await session.commit() + return None + secrets = await get_project_secrets_mapping(session=session, project=run_model.project) + run_spec = get_run_spec(run_model) + + gateway_stats = None + if run_spec.configuration.type == "service" and run_model.gateway_id is not None: + _, conn = await get_or_add_gateway_connection(session, run_model.gateway_id) + gateway_stats = await conn.get_stats(run_model.project.name, run_model.run_name) + + return pending.PendingContext( + run_model=run_model, + 
run_spec=run_spec, + secrets=secrets, + locked_job_ids=locked_job_ids, + gateway_stats=gateway_stats, + ) + + +async def _refetch_locked_run_for_pending( + session: AsyncSession, + item: RunPipelineItem, +) -> Optional[RunModel]: + latest_sq = _build_latest_submissions_subquery(item.id) + job_alias = aliased(JobModel) + res = await session.execute( + select(RunModel) + .where( + RunModel.id == item.id, + RunModel.lock_token == item.lock_token, + ) + .outerjoin(latest_sq, latest_sq.c.run_id == RunModel.id) + .outerjoin( + job_alias, + and_( + job_alias.run_id == latest_sq.c.run_id, + job_alias.replica_num == latest_sq.c.replica_num, + job_alias.job_num == latest_sq.c.job_num, + job_alias.submission_num == latest_sq.c.max_submission_num, + ), + ) + .options( + joinedload(RunModel.project).load_only( + ProjectModel.id, + ProjectModel.name, + ), + ) + .options(contains_eager(RunModel.jobs, alias=job_alias)) + .execution_options(populate_existing=True) + ) + return res.unique().scalar_one_or_none() + + +def _build_latest_submissions_subquery(run_id: uuid.UUID): + """Subquery selecting only the latest submission per (replica_num, job_num).""" + return ( + select( + JobModel.run_id.label("run_id"), + JobModel.replica_num.label("replica_num"), + JobModel.job_num.label("job_num"), + func.max(JobModel.submission_num).label("max_submission_num"), + ) + .where(JobModel.run_id == run_id) + .group_by(JobModel.run_id, JobModel.replica_num, JobModel.job_num) + .subquery() + ) + + +async def _apply_pending_result( + item: RunPipelineItem, + context: pending.PendingContext, + result: pending.PendingResult, +) -> None: + set_processed_update_map_fields(result.run_update_map) + set_unlock_update_map_fields(result.run_update_map) + + async with get_session_ctx() as session: + now = get_current_datetime() + resolve_now_placeholders(result.run_update_map, now=now) + + res = await session.execute( + update(RunModel) + .where( + RunModel.id == item.id, + RunModel.lock_token == 
item.lock_token, + ) + .values(**result.run_update_map) + .returning(RunModel.id) + ) + updated_run_ids = list(res.scalars().all()) + if len(updated_run_ids) == 0: + log_lock_token_changed_after_processing(logger, item) + await _unlock_related_jobs( + session=session, + item=item, + locked_job_ids=context.locked_job_ids, + ) + await session.commit() + return + + for job_model in result.new_job_models: + session.add(job_model) + events.emit( + session, + f"Job created on new submission. Status: {job_model.status.upper()}", + actor=events.SystemActor(), + targets=[events.Target.from_model(job_model)], + ) + + emit_run_status_change_event( + session=session, + run_model=context.run_model, + old_status=context.run_model.status, + new_status=result.run_update_map.get("status", context.run_model.status), + ) + + await _unlock_related_jobs( + session=session, + item=item, + locked_job_ids=context.locked_job_ids, + ) + await session.commit() + + +async def _apply_noop_result( + item: RunPipelineItem, + locked_job_ids: set[uuid.UUID], +) -> None: + """Unlock the run without changing state. 
Used when processing decides to skip.""" + async with get_session_ctx() as session: + now = get_current_datetime() + await session.execute( + update(RunModel) + .where( + RunModel.id == item.id, + RunModel.lock_token == item.lock_token, + ) + .values( + lock_expires_at=None, + lock_token=None, + lock_owner=None, + last_processed_at=now, + ) + ) + await _unlock_related_jobs( + session=session, + item=item, + locked_job_ids=locked_job_ids, + ) + await session.commit() + + +async def _process_active_item(item: RunPipelineItem) -> None: + async with get_session_ctx() as session: + load_result = await _load_active_context(session=session, item=item) + if load_result is None: + return + context = load_result + + result = await active.process_active_run(context) + await _apply_active_result(item=item, context=context, result=result) + + +async def _load_active_context( + session: AsyncSession, + item: RunPipelineItem, +) -> Optional[active.ActiveContext]: + """Returns None on lock mismatch (already handled). + Returns context when processing should proceed. 
+ """ + locked_job_ids = await _lock_related_jobs(session=session, item=item) + if locked_job_ids is None: + return None + run_model = await _refetch_locked_run_for_active(session=session, item=item) + if run_model is None: + log_lock_token_mismatch(logger, item) + await _unlock_related_jobs( + session=session, + item=item, + locked_job_ids=locked_job_ids, + ) + await session.commit() + return None + secrets = await get_project_secrets_mapping(session=session, project=run_model.project) + run_spec = get_run_spec(run_model) + + gateway_stats = None + if run_spec.configuration.type == "service" and run_model.gateway_id is not None: + _, conn = await get_or_add_gateway_connection(session, run_model.gateway_id) + gateway_stats = await conn.get_stats(run_model.project.name, run_model.run_name) + + return active.ActiveContext( + run_model=run_model, + run_spec=run_spec, + secrets=secrets, + locked_job_ids=locked_job_ids, + gateway_stats=gateway_stats, + ) + + +async def _refetch_locked_run_for_active( + session: AsyncSession, + item: RunPipelineItem, +) -> Optional[RunModel]: + latest_sq = _build_latest_submissions_subquery(item.id) + job_alias = aliased(JobModel) + res = await session.execute( + select(RunModel) + .where( + RunModel.id == item.id, + RunModel.lock_token == item.lock_token, + ) + .outerjoin(latest_sq, latest_sq.c.run_id == RunModel.id) + .outerjoin( + job_alias, + and_( + job_alias.run_id == latest_sq.c.run_id, + job_alias.replica_num == latest_sq.c.replica_num, + job_alias.job_num == latest_sq.c.job_num, + job_alias.submission_num == latest_sq.c.max_submission_num, + ), + ) + .options( + joinedload(RunModel.project).load_only( + ProjectModel.id, + ProjectModel.name, + ), + ) + .options( + contains_eager(RunModel.jobs, alias=job_alias) + .joinedload(JobModel.instance) + .load_only(InstanceModel.fleet_id), + ) + .execution_options(populate_existing=True) + ) + return res.unique().scalar_one_or_none() + + +async def _apply_active_result( + item: 
RunPipelineItem, + context: active.ActiveContext, + result: active.ActiveResult, +) -> None: + run_model = context.run_model + set_processed_update_map_fields(result.run_update_map) + set_unlock_update_map_fields(result.run_update_map) + + async with get_session_ctx() as session: + now = get_current_datetime() + resolve_now_placeholders(result.run_update_map, now=now) + job_update_rows = _build_active_job_update_rows( + job_id_to_update_map=result.job_id_to_update_map, + unlock_job_ids=context.locked_job_ids, + ) + if job_update_rows: + resolve_now_placeholders(job_update_rows, now=now) + + res = await session.execute( + update(RunModel) + .where( + RunModel.id == item.id, + RunModel.lock_token == item.lock_token, + ) + .values(**result.run_update_map) + .returning(RunModel.id) + ) + updated_run_ids = list(res.scalars().all()) + if len(updated_run_ids) == 0: + log_lock_token_changed_after_processing(logger, item) + await _unlock_related_jobs( + session=session, + item=item, + locked_job_ids=context.locked_job_ids, + ) + await session.commit() + return + + if job_update_rows: + await session.execute(update(JobModel), job_update_rows) + + for job_model in result.new_job_models: + session.add(job_model) + events.emit( + session, + f"Job created on retry. Status: {job_model.status.upper()}", + actor=events.SystemActor(), + targets=[events.Target.from_model(job_model)], + ) + + old_status = run_model.status + new_status = result.run_update_map.get("status", old_status) + _emit_active_metrics(run_model, context.run_spec, old_status, new_status) + + _emit_active_job_status_change_events( + session=session, + context=context, + result=result, + ) + # Set termination_reason on the model so emit_run_status_change_event can read it. 
+ if "termination_reason" in result.run_update_map: + run_model.termination_reason = result.run_update_map["termination_reason"] + emit_run_status_change_event( + session=session, + run_model=run_model, + old_status=old_status, + new_status=new_status, + ) + await session.commit() + + +def _emit_active_metrics( + run_model: RunModel, + run_spec, + old_status: RunStatus, + new_status: RunStatus, +) -> None: + if old_status == new_status: + return + project_name = run_model.project.name + run_type = run_spec.configuration.type + if old_status == RunStatus.SUBMITTED and new_status == RunStatus.PROVISIONING: + duration = (get_current_datetime() - run_model.submitted_at).total_seconds() + run_metrics.log_submit_to_provision_duration(duration, project_name, run_type) + if new_status == RunStatus.PENDING: + run_metrics.increment_pending_runs(project_name, run_type) + + +class _ActiveRunJobUpdateRow(active.ActiveRunJobUpdateMap, total=False): + id: uuid.UUID + + +def _build_active_job_update_rows( + job_id_to_update_map: dict[uuid.UUID, active.ActiveRunJobUpdateMap], + unlock_job_ids: set[uuid.UUID], +) -> list[_ActiveRunJobUpdateRow]: + job_update_rows = [] + for job_id in sorted(job_id_to_update_map.keys() | unlock_job_ids): + update_row = _ActiveRunJobUpdateRow(id=job_id) + job_update_map = job_id_to_update_map.get(job_id) + if job_update_map is not None: + for key, value in job_update_map.items(): + update_row[key] = value + if job_id in unlock_job_ids: + set_unlock_update_map_fields(update_row) + set_processed_update_map_fields(update_row) + job_update_rows.append(update_row) + return job_update_rows + + +def _emit_active_job_status_change_events( + session: AsyncSession, + context: active.ActiveContext, + result: active.ActiveResult, +) -> None: + for job_model in context.run_model.jobs: + job_update_map = result.job_id_to_update_map.get(job_model.id) + if job_update_map is None: + continue + emit_job_status_change_event( + session=session, + job_model=job_model, + 
old_status=job_model.status, + new_status=job_update_map.get("status", job_model.status), + termination_reason=job_update_map.get( + "termination_reason", + job_model.termination_reason, + ), + termination_reason_message=job_update_map.get( + "termination_reason_message", + job_model.termination_reason_message, + ), + ) + + +async def _process_terminating_item(item: RunPipelineItem) -> None: + async with get_session_ctx() as session: + context = await _load_terminating_context(session=session, item=item) + if context is None: + return + + result = await terminating.process_terminating_run(context) + await _apply_terminating_result(item=item, context=context, result=result) + + +async def _load_terminating_context( + session: AsyncSession, + item: RunPipelineItem, +) -> Optional[terminating.TerminatingContext]: + locked_job_ids = await _lock_related_jobs( + session=session, + item=item, + ) + if locked_job_ids is None: + return None + run_model = await _refetch_locked_run_for_terminating(session=session, item=item) + if run_model is None: + log_lock_token_mismatch(logger, item) + await _unlock_related_jobs( + session=session, + item=item, + locked_job_ids=locked_job_ids, + ) + await session.commit() + return None + return terminating.TerminatingContext( + run_model=run_model, + locked_job_ids=locked_job_ids, + ) + + +async def _refetch_locked_run_for_terminating( + session: AsyncSession, + item: RunPipelineItem, +) -> Optional[RunModel]: + latest_sq = _build_latest_submissions_subquery(item.id) + job_alias = aliased(JobModel) + res = await session.execute( + select(RunModel) + .where( + RunModel.id == item.id, + RunModel.lock_token == item.lock_token, + ) + .outerjoin(latest_sq, latest_sq.c.run_id == RunModel.id) + .outerjoin( + job_alias, + and_( + job_alias.run_id == latest_sq.c.run_id, + job_alias.replica_num == latest_sq.c.replica_num, + job_alias.job_num == latest_sq.c.job_num, + job_alias.submission_num == latest_sq.c.max_submission_num, + ), + ) + .options( + 
joinedload(RunModel.project).load_only( + ProjectModel.id, + ProjectModel.name, + ), + ) + .options( + contains_eager(RunModel.jobs, alias=job_alias) + .joinedload(JobModel.instance) + .joinedload(InstanceModel.project) + .load_only( + ProjectModel.id, + ProjectModel.ssh_private_key, + ), + ) + .execution_options(populate_existing=True) + ) + return res.unique().scalar_one_or_none() + + +async def _lock_related_jobs( + session: AsyncSession, + item: RunPipelineItem, +) -> Optional[set[uuid.UUID]]: + now = get_current_datetime() + job_lock, _ = get_locker(get_db().dialect_name).get_lockset(JobModel.__tablename__) + async with job_lock: + res = await session.execute( + select(JobModel) + .where( + JobModel.run_id == item.id, + JobModel.status.not_in(JOB_STATUSES_EXCLUDED_FOR_LOCKING), + or_( + JobModel.lock_expires_at.is_(None), + JobModel.lock_expires_at < now, + ), + or_( + JobModel.lock_owner.is_(None), + JobModel.lock_owner == RunPipeline.__name__, + ), + ) + .order_by(JobModel.id) + .with_for_update(skip_locked=True, key_share=True, of=JobModel) + .options(load_only(JobModel.id)) + ) + locked_job_models = list(res.scalars().all()) + locked_job_ids = {job_model.id for job_model in locked_job_models} + + res = await session.execute( + select(JobModel.id).where( + JobModel.run_id == item.id, + JobModel.status.not_in(JOB_STATUSES_EXCLUDED_FOR_LOCKING), + ) + ) + current_job_ids = set(res.scalars().all()) + if current_job_ids != locked_job_ids: + logger.debug( + "Failed to lock run %s jobs. 
The run will be processed later.", + item.id, + ) + await _reset_run_lock_for_retry(session=session, item=item) + return None + for job_model in locked_job_models: + job_model.lock_expires_at = item.lock_expires_at + job_model.lock_token = item.lock_token + job_model.lock_owner = RunPipeline.__name__ + await session.commit() + return {jm.id for jm in locked_job_models} + + +async def _reset_run_lock_for_retry( + session: AsyncSession, + item: RunPipelineItem, +) -> None: + res = await session.execute( + update(RunModel) + .where( + RunModel.id == item.id, + RunModel.lock_token == item.lock_token, + ) + # Keep `lock_owner` so the run remains owned by the run pipeline, + # but unset `lock_expires_at` to retry ASAP and unset `lock_token` + # so heartbeater can no longer update the item. + .values( + lock_expires_at=None, + lock_token=None, + last_processed_at=get_current_datetime(), + ) + .returning(RunModel.id) + ) + updated_ids = list(res.scalars().all()) + if len(updated_ids) == 0: + log_lock_token_changed_on_reset(logger) + + +async def _apply_terminating_result( + item: RunPipelineItem, + context: terminating.TerminatingContext, + result: terminating.TerminatingResult, +) -> None: + run_model = context.run_model + set_processed_update_map_fields(result.run_update_map) + set_unlock_update_map_fields(result.run_update_map) + + async with get_session_ctx() as session: + now = get_current_datetime() + resolve_now_placeholders(result.run_update_map, now=now) + job_update_rows = _build_terminating_job_update_rows( + job_id_to_update_map=result.job_id_to_update_map, + unlock_job_ids=context.locked_job_ids, + ) + if job_update_rows: + resolve_now_placeholders(job_update_rows, now=now) + res = await session.execute( + update(RunModel) + .where( + RunModel.id == item.id, + RunModel.lock_token == item.lock_token, + ) + .values(**result.run_update_map) + .returning(RunModel.id) + ) + updated_run_ids = list(res.scalars().all()) + if len(updated_run_ids) == 0: + # The only 
side-effects are runner stop signal and service deregistration, + # and they are idempotent, so no need for cleanup. + log_lock_token_changed_after_processing(logger, item) + await _unlock_related_jobs( + session=session, + item=item, + locked_job_ids=context.locked_job_ids, + ) + await session.commit() + return + + if job_update_rows: + await session.execute(update(JobModel), job_update_rows) + + if result.service_unregistration is not None: + targets = [events.Target.from_model(run_model)] + if result.service_unregistration.gateway_target is not None: + targets.append(result.service_unregistration.gateway_target) + events.emit( + session, + result.service_unregistration.event_message, + actor=events.SystemActor(), + targets=targets, + ) + + _emit_terminating_job_status_change_events( + session=session, + context=context, + result=result, + ) + emit_run_status_change_event( + session=session, + run_model=context.run_model, + old_status=context.run_model.status, + new_status=result.run_update_map.get("status", context.run_model.status), + ) + await session.commit() + + +class _TerminatingRunJobUpdateRow(terminating.TerminatingRunJobUpdateMap, total=False): + id: uuid.UUID + + +def _build_terminating_job_update_rows( + job_id_to_update_map: dict[uuid.UUID, terminating.TerminatingRunJobUpdateMap], + unlock_job_ids: set[uuid.UUID], +) -> list[_TerminatingRunJobUpdateRow]: + job_update_rows = [] + for job_id in sorted(job_id_to_update_map.keys() | unlock_job_ids): + update_row = _TerminatingRunJobUpdateRow(id=job_id) + job_update_map = job_id_to_update_map.get(job_id) + if job_update_map is not None: + for key, value in job_update_map.items(): + update_row[key] = value + if job_id in unlock_job_ids: + set_unlock_update_map_fields(update_row) + set_processed_update_map_fields(update_row) + job_update_rows.append(update_row) + return job_update_rows + + +def _emit_terminating_job_status_change_events( + session: AsyncSession, + context: terminating.TerminatingContext, + 
result: terminating.TerminatingResult, +) -> None: + for job_model in context.run_model.jobs: + job_update_map = result.job_id_to_update_map.get(job_model.id) + if job_update_map is None: + continue + emit_job_status_change_event( + session=session, + job_model=job_model, + old_status=job_model.status, + new_status=job_update_map.get("status", job_model.status), + termination_reason=job_update_map.get( + "termination_reason", + job_model.termination_reason, + ), + termination_reason_message=job_model.termination_reason_message, + ) + + +async def _unlock_related_jobs( + session: AsyncSession, + item: RunPipelineItem, + locked_job_ids: set[uuid.UUID], +) -> None: + if len(locked_job_ids) == 0: + return + await session.execute( + update(JobModel) + .where( + JobModel.id.in_(locked_job_ids), + JobModel.lock_token == item.lock_token, + JobModel.lock_owner == RunPipeline.__name__, + ) + .values( + lock_expires_at=None, + lock_token=None, + lock_owner=None, + ) + ) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/runs/active.py b/src/dstack/_internal/server/background/pipeline_tasks/runs/active.py new file mode 100644 index 000000000..f29fd51eb --- /dev/null +++ b/src/dstack/_internal/server/background/pipeline_tasks/runs/active.py @@ -0,0 +1,732 @@ +import json +import uuid +from dataclasses import dataclass, field +from datetime import datetime, timedelta +from typing import Dict, List, Optional, Set, Tuple + +from sqlalchemy import select +from sqlalchemy.orm import load_only + +from dstack._internal.core.errors import ServerError +from dstack._internal.core.models.configurations import ServiceConfiguration +from dstack._internal.core.models.profiles import RetryEvent, StopCriteria +from dstack._internal.core.models.runs import ( + JobStatus, + JobTerminationReason, + RunSpec, + RunStatus, + RunTerminationReason, +) +from dstack._internal.proxy.gateway.schemas.stats import PerWindowStats +from dstack._internal.server.background.pipeline_tasks.base 
import ItemUpdateMap +from dstack._internal.server.background.pipeline_tasks.runs.common import ( + PerGroupDesiredCounts, + build_scale_up_job_models, + compute_desired_replica_counts, +) +from dstack._internal.server.db import get_session_ctx +from dstack._internal.server.models import JobModel, RunModel +from dstack._internal.server.services.jobs import ( + get_job_spec, + get_job_specs_from_run_spec, + get_jobs_from_run_spec, + group_jobs_by_replica_latest, +) +from dstack._internal.server.services.runs import create_job_model_for_new_submission +from dstack._internal.server.services.runs.replicas import ( + build_replica_lists, + get_group_rollout_state, + has_out_of_date_replicas, + job_belongs_to_group, +) +from dstack._internal.utils.common import get_current_datetime +from dstack._internal.utils.logging import get_logger + +logger = get_logger(__name__) + +ROLLING_DEPLOYMENT_MAX_SURGE = 1 # at most one extra replica during rolling deployment + + +class ActiveRunUpdateMap(ItemUpdateMap, total=False): + status: RunStatus + termination_reason: Optional[RunTerminationReason] + fleet_id: Optional[uuid.UUID] + resubmission_attempt: int + desired_replica_count: int + desired_replica_counts: Optional[str] # JSON + + +class ActiveRunJobUpdateMap(ItemUpdateMap, total=False): + status: JobStatus + termination_reason: Optional[JobTerminationReason] + termination_reason_message: Optional[str] + deployment_num: int + + +@dataclass +class ActiveContext: + run_model: RunModel + run_spec: RunSpec + secrets: dict + locked_job_ids: set[uuid.UUID] + gateway_stats: Optional[PerWindowStats] = None + + +@dataclass +class ActiveResult: + run_update_map: ActiveRunUpdateMap + new_job_models: list[JobModel] + job_id_to_update_map: dict[uuid.UUID, ActiveRunJobUpdateMap] + + +@dataclass +class _ReplicaAnalysis: + """Per-replica classification of job states for determining the run's next status.""" + + replica_num: int + job_models: List[JobModel] + contributed_statuses: Set[RunStatus] 
= field(default_factory=set) + """`RunStatus` values derived from this replica's jobs. Merged into the run-level + analysis unless the replica is being retried as a whole.""" + termination_reasons: Set[RunTerminationReason] = field(default_factory=set) + """Why the replica failed. Only populated when `FAILED` is in `contributed_statuses`.""" + needs_retry: bool = False + """At least one job failed with a retryable reason and the retry duration hasn't been + exceeded. When `True`, the replica does not contribute its statuses to the run-level + analysis and is added to `replicas_to_retry` instead.""" + + +@dataclass +class _RunAnalysis: + """Aggregated replica analysis used to determine the run's next status. + + Each replica contributes `RunStatus` based on its jobs' statuses. + The run's new status is the highest-priority value across all + contributing replicas: FAILED > RUNNING > PROVISIONING > SUBMITTED > DONE. + Replicas that need full retry do not contribute and instead cause a PENDING transition. 
+ """ + + contributed_statuses: Set[RunStatus] = field(default_factory=set) + termination_reasons: Set[RunTerminationReason] = field(default_factory=set) + replicas_to_retry: List[Tuple[int, List[JobModel]]] = field(default_factory=list) + """Replicas with retryable failures that haven't exceeded the retry duration.""" + + +@dataclass +class _ActiveRunTransition: + new_status: RunStatus + termination_reason: Optional[RunTerminationReason] = None + + +async def process_active_run(context: ActiveContext) -> ActiveResult: + run_model = context.run_model + run_spec = context.run_spec + + fleet_id = _detect_fleet_id_from_jobs(run_model) + analysis = await _analyze_active_run(run_model) + transition = _get_active_run_transition(run_spec, run_model, analysis) + + run_update_map = _build_run_update_map(run_model, run_spec, transition, fleet_id) + new_job_models: list[JobModel] = [] + job_id_to_update_map: Dict[uuid.UUID, ActiveRunJobUpdateMap] = {} + + if transition.new_status == RunStatus.PENDING: + job_id_to_update_map = _build_terminate_retrying_jobs_map(analysis.replicas_to_retry) + elif transition.new_status not in {RunStatus.TERMINATING, RunStatus.PENDING}: + if analysis.replicas_to_retry: + new_job_models = await _build_retry_job_models(context, analysis.replicas_to_retry) + # In a multi-node replica, one job may fail while siblings are still running. + # Terminate those siblings so the entire replica retries cleanly. + job_id_to_update_map = _build_terminate_retrying_jobs_map(analysis.replicas_to_retry) + elif run_spec.configuration.type == "service": + per_group_desired = _apply_desired_counts_to_update_map(run_update_map, context) + # Service processing has multiple stages that never conflict: + # - scaling skips groups with out-of-date replicas (rolling in progress), + # so for those groups only rolling manages replica creation and teardown; + # - cleanup only targets removed groups (not in configuration.replica_groups). 
+ new_job_models, job_id_to_update_map = await _build_service_scaling_maps( + context, per_group_desired + ) + + deployment_maps = await _build_deployment_update_map(context) + job_id_to_update_map.update(deployment_maps) + + rolling_new, rolling_maps = await _build_rolling_deployment_maps( + context, per_group_desired, in_place_bumped_job_ids=set(deployment_maps.keys()) + ) + new_job_models.extend(rolling_new) + job_id_to_update_map.update(rolling_maps) + + cleanup_maps = _build_removed_groups_cleanup_maps(context) + job_id_to_update_map.update(cleanup_maps) + else: + job_id_to_update_map = await _build_deployment_update_map(context) + + return ActiveResult( + run_update_map=run_update_map, + new_job_models=new_job_models, + job_id_to_update_map=job_id_to_update_map, + ) + + +def _detect_fleet_id_from_jobs(run_model: RunModel) -> Optional[uuid.UUID]: + """Detect fleet_id from job instances. Returns the current fleet_id if already set.""" + if run_model.fleet_id is not None: + return run_model.fleet_id + for job_model in run_model.jobs: + if job_model.instance is not None and job_model.instance.fleet_id is not None: + return job_model.instance.fleet_id + return None + + +async def _analyze_active_run(run_model: RunModel) -> _RunAnalysis: + run_analysis = _RunAnalysis() + for replica_num, job_models in group_jobs_by_replica_latest(run_model.jobs): + replica_analysis = await _analyze_active_run_replica( + run_model=run_model, + replica_num=replica_num, + job_models=job_models, + ) + _apply_replica_analysis(run_analysis, replica_analysis) + return run_analysis + + +async def _analyze_active_run_replica( + run_model: RunModel, + replica_num: int, + job_models: List[JobModel], +) -> _ReplicaAnalysis: + contributed_statuses: Set[RunStatus] = set() + termination_reasons: Set[RunTerminationReason] = set() + needs_retry = False + + for job_model in job_models: + if _job_is_done_or_finishing_done(job_model): + contributed_statuses.add(RunStatus.DONE) + continue + + if 
_job_was_scaled_down(job_model): + continue + + replica_status = _get_non_terminal_replica_status(job_model) + if replica_status is not None: + contributed_statuses.add(replica_status) + continue + + if _job_needs_retry_evaluation(job_model): + current_duration = await _should_retry_job(run_model, job_model) + if current_duration is None: + contributed_statuses.add(RunStatus.FAILED) + termination_reasons.add(RunTerminationReason.JOB_FAILED) + elif _is_retry_duration_exceeded(job_model, current_duration): + contributed_statuses.add(RunStatus.FAILED) + termination_reasons.add(RunTerminationReason.RETRY_LIMIT_EXCEEDED) + else: + needs_retry = True + continue + + raise ServerError(f"Unexpected job status {job_model.status}") + + return _ReplicaAnalysis( + replica_num=replica_num, + job_models=job_models, + contributed_statuses=contributed_statuses, + termination_reasons=termination_reasons, + needs_retry=needs_retry, + ) + + +def _apply_replica_analysis( + analysis: _RunAnalysis, + replica_analysis: _ReplicaAnalysis, +) -> None: + if RunStatus.FAILED in replica_analysis.contributed_statuses: + analysis.contributed_statuses.add(RunStatus.FAILED) + analysis.termination_reasons.update(replica_analysis.termination_reasons) + return + + if replica_analysis.needs_retry: + analysis.replicas_to_retry.append( + (replica_analysis.replica_num, replica_analysis.job_models) + ) + + if not replica_analysis.needs_retry: + analysis.contributed_statuses.update(replica_analysis.contributed_statuses) + + +def _job_is_done_or_finishing_done(job_model: JobModel) -> bool: + return job_model.status == JobStatus.DONE or ( + job_model.status == JobStatus.TERMINATING + and job_model.termination_reason == JobTerminationReason.DONE_BY_RUNNER + ) + + +def _job_was_scaled_down(job_model: JobModel) -> bool: + return job_model.termination_reason == JobTerminationReason.SCALED_DOWN + + +def _get_non_terminal_replica_status(job_model: JobModel) -> Optional[RunStatus]: + if job_model.status == 
JobStatus.RUNNING: + return RunStatus.RUNNING + if job_model.status in {JobStatus.PROVISIONING, JobStatus.PULLING}: + return RunStatus.PROVISIONING + if job_model.status == JobStatus.SUBMITTED: + return RunStatus.SUBMITTED + return None + + +def _job_needs_retry_evaluation(job_model: JobModel) -> bool: + return job_model.status == JobStatus.FAILED or ( + job_model.status in [JobStatus.TERMINATING, JobStatus.TERMINATED, JobStatus.ABORTED] + and job_model.termination_reason + not in {JobTerminationReason.DONE_BY_RUNNER, JobTerminationReason.SCALED_DOWN} + ) + + +async def _should_retry_job( + run_model: RunModel, + job_model: JobModel, +) -> Optional[timedelta]: + """ + Checks if the job should be retried. + Returns the current duration of retrying if retry is enabled. + Retrying duration is calculated as the time since `last_processed_at` + of the latest provisioned submission. + """ + job_spec = get_job_spec(job_model) + if job_spec.retry is None: + return None + + last_provisioned = await _load_last_provisioned_job( + run_id=job_model.run_id, + replica_num=job_model.replica_num, + job_num=job_model.job_num, + ) + + if ( + job_model.termination_reason is not None + and job_model.termination_reason.to_retry_event() == RetryEvent.NO_CAPACITY + and last_provisioned is None + and RetryEvent.NO_CAPACITY in job_spec.retry.on_events + ): + return get_current_datetime() - run_model.submitted_at + + if ( + job_model.termination_reason is not None + and job_model.termination_reason.to_retry_event() in job_spec.retry.on_events + and last_provisioned is not None + ): + return get_current_datetime() - last_provisioned.last_processed_at + + return None + + +async def _load_last_provisioned_job( + run_id: uuid.UUID, + replica_num: int, + job_num: int, +) -> Optional[JobModel]: + """Load the last submission with provisioning data for a single (replica_num, job_num).""" + async with get_session_ctx() as session: + res = await session.execute( + select(JobModel) + .where( + 
JobModel.run_id == run_id, + JobModel.replica_num == replica_num, + JobModel.job_num == job_num, + JobModel.job_provisioning_data.is_not(None), + ) + .order_by(JobModel.submission_num.desc()) + .limit(1) + .options(load_only(JobModel.last_processed_at)) + ) + return res.scalar_one_or_none() + + +def _is_retry_duration_exceeded(job_model: JobModel, current_duration: timedelta) -> bool: + job_spec = get_job_spec(job_model) + if job_spec.retry is None: + return True + return current_duration > timedelta(seconds=job_spec.retry.duration) + + +def _should_stop_on_master_done(run_spec: RunSpec, run_model: RunModel) -> bool: + if run_spec.merged_profile.stop_criteria != StopCriteria.MASTER_DONE: + return False + for job_model in run_model.jobs: + if job_model.job_num == 0 and job_model.status == JobStatus.DONE: + return True + return False + + +def _get_active_run_transition( + run_spec: RunSpec, + run_model: RunModel, + analysis: _RunAnalysis, +) -> _ActiveRunTransition: + # Check `analysis.contributed_statuses` in the priority order. 
+ if RunStatus.FAILED in analysis.contributed_statuses: + if RunTerminationReason.JOB_FAILED in analysis.termination_reasons: + termination_reason = RunTerminationReason.JOB_FAILED + elif RunTerminationReason.RETRY_LIMIT_EXCEEDED in analysis.termination_reasons: + termination_reason = RunTerminationReason.RETRY_LIMIT_EXCEEDED + else: + raise ServerError(f"Unexpected termination reason {analysis.termination_reasons}") + return _ActiveRunTransition( + new_status=RunStatus.TERMINATING, + termination_reason=termination_reason, + ) + + if _should_stop_on_master_done(run_spec, run_model): + return _ActiveRunTransition( + new_status=RunStatus.TERMINATING, + termination_reason=RunTerminationReason.ALL_JOBS_DONE, + ) + + if RunStatus.RUNNING in analysis.contributed_statuses: + return _ActiveRunTransition(new_status=RunStatus.RUNNING) + if RunStatus.PROVISIONING in analysis.contributed_statuses: + return _ActiveRunTransition(new_status=RunStatus.PROVISIONING) + if RunStatus.SUBMITTED in analysis.contributed_statuses: + return _ActiveRunTransition(new_status=RunStatus.SUBMITTED) + if RunStatus.DONE in analysis.contributed_statuses and not analysis.replicas_to_retry: + return _ActiveRunTransition( + new_status=RunStatus.TERMINATING, + termination_reason=RunTerminationReason.ALL_JOBS_DONE, + ) + if not analysis.contributed_statuses or analysis.contributed_statuses == {RunStatus.DONE}: + # No active replicas remain — resubmit the entire run. + # `contributed_statuses` is either empty (every replica is retrying) or contains + # only DONE (some replicas finished, others need retry). 
+ return _ActiveRunTransition(new_status=RunStatus.PENDING) + raise ServerError("Failed to determine run transition: unexpected active run state") + + +def _build_run_update_map( + run_model: RunModel, + run_spec: RunSpec, + transition: _ActiveRunTransition, + fleet_id: Optional[uuid.UUID], +) -> ActiveRunUpdateMap: + update_map = ActiveRunUpdateMap() + + if fleet_id != run_model.fleet_id: + update_map["fleet_id"] = fleet_id + + if run_model.status == transition.new_status: + return update_map + + update_map["status"] = transition.new_status + update_map["termination_reason"] = transition.termination_reason + + if transition.new_status == RunStatus.PROVISIONING: + update_map["resubmission_attempt"] = 0 + elif transition.new_status == RunStatus.PENDING: + update_map["resubmission_attempt"] = run_model.resubmission_attempt + 1 + # Unassign run from fleet so that a new fleet can be chosen when retrying + update_map["fleet_id"] = None + + return update_map + + +def _build_terminate_retrying_jobs_map( + replicas_to_retry: List[Tuple[int, List[JobModel]]], +) -> dict[uuid.UUID, ActiveRunJobUpdateMap]: + job_id_to_update_map: dict[uuid.UUID, ActiveRunJobUpdateMap] = {} + for _, replica_jobs in replicas_to_retry: + for job_model in replica_jobs: + if job_model.status.is_finished() or job_model.status == JobStatus.TERMINATING: + continue + job_id_to_update_map[job_model.id] = ActiveRunJobUpdateMap( + status=JobStatus.TERMINATING, + termination_reason=JobTerminationReason.TERMINATED_BY_SERVER, + termination_reason_message="Run is to be resubmitted", + ) + return job_id_to_update_map + + +async def _build_retry_job_models( + context: ActiveContext, + replicas_to_retry: List[Tuple[int, List[JobModel]]], +) -> list[JobModel]: + new_job_models: list[JobModel] = [] + for _, replica_jobs in replicas_to_retry: + job_spec = get_job_spec(replica_jobs[0]) + replica_group_name = job_spec.replica_group + new_jobs = await get_jobs_from_run_spec( + run_spec=context.run_spec, + 
secrets=context.secrets, + replica_num=replica_jobs[0].replica_num, + replica_group_name=replica_group_name, + ) + assert len(new_jobs) == len(replica_jobs), ( + "Changing the number of jobs within a replica is not yet supported" + ) + for old_job_model, new_job in zip(replica_jobs, new_jobs): + # If some jobs in a retry replica are not finished, they must be terminated by the caller. + job_model = create_job_model_for_new_submission( + run_model=context.run_model, + job=new_job, + status=JobStatus.SUBMITTED, + ) + job_model.submission_num = old_job_model.submission_num + 1 + new_job_models.append(job_model) + return new_job_models + + +async def _build_deployment_update_map( + context: ActiveContext, +) -> dict[uuid.UUID, ActiveRunJobUpdateMap]: + """Bump deployment_num for jobs that do not require redeployment.""" + run_model = context.run_model + run_spec = context.run_spec + job_id_to_update_map: dict[uuid.UUID, ActiveRunJobUpdateMap] = {} + + if not has_out_of_date_replicas(run_model): + return job_id_to_update_map + + for replica_num, job_models in group_jobs_by_replica_latest(run_model.jobs): + if all(j.status.is_finished() for j in job_models): + continue + if all(j.deployment_num == run_model.deployment_num for j in job_models): + continue + + replica_group_name = None + if run_spec.configuration.type == "service": + job_spec = get_job_spec(job_models[0]) + replica_group_name = job_spec.replica_group + + new_job_specs = await get_job_specs_from_run_spec( + run_spec=run_spec, + secrets=context.secrets, + replica_num=replica_num, + replica_group_name=replica_group_name, + ) + assert len(new_job_specs) == len(job_models), ( + "Changing the number of jobs within a replica is not yet supported" + ) + can_update_all_jobs = True + for old_job_model, new_job_spec in zip(job_models, new_job_specs): + old_job_spec = get_job_spec(old_job_model) + if new_job_spec != old_job_spec: + can_update_all_jobs = False + break + if can_update_all_jobs: + for job_model in 
job_models: + job_id_to_update_map[job_model.id] = ActiveRunJobUpdateMap( + deployment_num=run_model.deployment_num, + ) + + return job_id_to_update_map + + +def _compute_last_scaled_at(run_model: RunModel) -> Optional[datetime]: + """Compute the timestamp of the most recent scaling event from replica data.""" + timestamps: list[datetime] = [] + active, inactive = build_replica_lists(run_model) + for _, _, _, jobs in active: + timestamps.append(min(j.submitted_at for j in jobs)) + for _, _, _, jobs in inactive: + timestamps.append(max(j.last_processed_at for j in jobs)) + return max(timestamps) if timestamps else None + + +def _apply_desired_counts_to_update_map( + run_update_map: ActiveRunUpdateMap, + context: ActiveContext, +) -> PerGroupDesiredCounts: + """Compute desired counts and add to run_update_map. Returns per-group desired counts.""" + configuration = context.run_spec.configuration + assert isinstance(configuration, ServiceConfiguration) + last_scaled_at = _compute_last_scaled_at(context.run_model) + total, per_group_desired = compute_desired_replica_counts( + context.run_model, configuration, context.gateway_stats, last_scaled_at + ) + run_update_map["desired_replica_count"] = total + run_update_map["desired_replica_counts"] = json.dumps(per_group_desired) + return per_group_desired + + +def _build_scale_down_job_update_maps( + active_replicas: list[tuple[int, bool, int, list[JobModel]]], + count: int, +) -> dict[uuid.UUID, ActiveRunJobUpdateMap]: + """Build job update maps for scaling down the least-important replicas.""" + job_id_to_update_map: dict[uuid.UUID, ActiveRunJobUpdateMap] = {} + if count <= 0: + return job_id_to_update_map + for _, _, _, replica_jobs in reversed(active_replicas[-count:]): + for job in replica_jobs: + if job.status.is_finished() or job.status == JobStatus.TERMINATING: + continue + job_id_to_update_map[job.id] = ActiveRunJobUpdateMap( + status=JobStatus.TERMINATING, + termination_reason=JobTerminationReason.SCALED_DOWN, + ) + 
return job_id_to_update_map + + +async def _build_service_scaling_maps( + context: ActiveContext, + per_group_desired: PerGroupDesiredCounts, +) -> tuple[list[JobModel], dict[uuid.UUID, ActiveRunJobUpdateMap]]: + """Build new jobs for scale-up and update maps for scale-down across all groups.""" + run_model = context.run_model + configuration = context.run_spec.configuration + assert isinstance(configuration, ServiceConfiguration) + new_job_models: list[JobModel] = [] + job_id_to_update_map: dict[uuid.UUID, ActiveRunJobUpdateMap] = {} + + next_replica_num = max((job.replica_num for job in run_model.jobs), default=-1) + 1 + + for group in configuration.replica_groups: + assert group.name is not None + group_desired = per_group_desired.get(group.name, 0) + active_replicas, _ = build_replica_lists(run_model, group_filter=group.name) + diff = group_desired - len(active_replicas) + + if diff == 0: + continue + + # During rolling deployment, skip the group entirely — + # _build_rolling_deployment_maps handles both surge and teardown. 
+ if has_out_of_date_replicas(run_model, group_filter=group.name): + continue + + if diff > 0: + new_jobs = await build_scale_up_job_models( + run_model=run_model, + run_spec=context.run_spec, + secrets=context.secrets, + replicas_diff=diff, + group_name=group.name, + replica_num_start=next_replica_num, + ) + new_job_models.extend(new_jobs) + # Advance next_replica_num past any newly created replicas + if new_jobs: + max_new = max(j.replica_num for j in new_jobs) + next_replica_num = max(next_replica_num, max_new + 1) + else: + scale_down_maps = _build_scale_down_job_update_maps(active_replicas, abs(diff)) + job_id_to_update_map.update(scale_down_maps) + + return ( + new_job_models, + job_id_to_update_map, + ) + + +def _has_out_of_date_replicas( + run_model: RunModel, + group_name: str, + exclude_job_ids: Set[uuid.UUID], +) -> bool: + """Check for out-of-date replicas, treating jobs in `exclude_job_ids` as up-to-date.""" + for job in run_model.jobs: + if job.id in exclude_job_ids: + continue + if not job_belongs_to_group(job, group_name): + continue + if job.deployment_num < run_model.deployment_num and not ( + job.status.is_finished() or job.termination_reason == JobTerminationReason.SCALED_DOWN + ): + return True + return False + + +async def _build_rolling_deployment_maps( + context: ActiveContext, + per_group_desired: PerGroupDesiredCounts, + in_place_bumped_job_ids: Set[uuid.UUID], +) -> tuple[list[JobModel], dict[uuid.UUID, ActiveRunJobUpdateMap]]: + """Build scale-up models and scale-down maps for rolling deployment across all groups. + + Jobs in `in_place_bumped_job_ids` are about to have their deployment_num bumped + in-place. We exclude them from the out-of-date check so rolling deployment only + targets replicas that actually need replacement. 
+ """ + run_model = context.run_model + configuration = context.run_spec.configuration + assert isinstance(configuration, ServiceConfiguration) + new_job_models: list[JobModel] = [] + job_id_to_update_map: dict[uuid.UUID, ActiveRunJobUpdateMap] = {} + + next_replica_num = max((job.replica_num for job in run_model.jobs), default=-1) + 1 + + for group in configuration.replica_groups: + assert group.name is not None + group_desired = per_group_desired.get(group.name, 0) + # Check if there are truly out-of-date replicas (excluding in-place bumped jobs) + if not _has_out_of_date_replicas(run_model, group.name, in_place_bumped_job_ids): + continue + + state = get_group_rollout_state(run_model, group) + group_max = group_desired + ROLLING_DEPLOYMENT_MAX_SURGE + + # Scale up: create new up-to-date replicas if below max + if state.non_terminated_replica_count < group_max: + new_jobs = await build_scale_up_job_models( + run_model=run_model, + run_spec=context.run_spec, + secrets=context.secrets, + replicas_diff=group_max - state.non_terminated_replica_count, + group_name=group.name, + replica_num_start=next_replica_num, + ) + new_job_models.extend(new_jobs) + if new_jobs: + max_new = max(j.replica_num for j in new_jobs) + next_replica_num = max(next_replica_num, max_new + 1) + + # Scale down: terminate unregistered out-of-date + excess registered replicas + replicas_to_stop = state.unregistered_out_of_date_replica_count + replicas_to_stop += max( + 0, + state.registered_non_terminating_replica_count - group_desired, + ) + if replicas_to_stop > 0: + scale_down_maps = _build_scale_down_job_update_maps( + state.active_replicas, replicas_to_stop + ) + job_id_to_update_map.update(scale_down_maps) + + return new_job_models, job_id_to_update_map + + +def _build_removed_groups_cleanup_maps( + context: ActiveContext, +) -> dict[uuid.UUID, ActiveRunJobUpdateMap]: + """Terminate replicas from groups no longer in the configuration.""" + run_model = context.run_model + configuration = 
context.run_spec.configuration + assert isinstance(configuration, ServiceConfiguration) + job_id_to_update_map: dict[uuid.UUID, ActiveRunJobUpdateMap] = {} + + existing_group_names: set[str] = set() + for job in run_model.jobs: + if job.status.is_finished(): + continue + job_spec = get_job_spec(job) + existing_group_names.add(job_spec.replica_group) + + new_group_names = {group.name for group in configuration.replica_groups} + removed_group_names = existing_group_names - new_group_names + + for removed_group_name in removed_group_names: + active_replicas, inactive_replicas = build_replica_lists( + run_model=run_model, + group_filter=removed_group_name, + ) + if active_replicas: + scale_down_maps = _build_scale_down_job_update_maps( + active_replicas, len(active_replicas) + ) + job_id_to_update_map.update(scale_down_maps) + if inactive_replicas: + scale_down_maps = _build_scale_down_job_update_maps( + inactive_replicas, len(inactive_replicas) + ) + job_id_to_update_map.update(scale_down_maps) + + return job_id_to_update_map diff --git a/src/dstack/_internal/server/background/pipeline_tasks/runs/common.py b/src/dstack/_internal/server/background/pipeline_tasks/runs/common.py new file mode 100644 index 000000000..0c6c9730c --- /dev/null +++ b/src/dstack/_internal/server/background/pipeline_tasks/runs/common.py @@ -0,0 +1,117 @@ +import json +from datetime import datetime +from typing import Optional + +from dstack._internal.core.models.configurations import ( + DEFAULT_REPLICA_GROUP_NAME, + ServiceConfiguration, +) +from dstack._internal.core.models.runs import JobStatus, RunSpec +from dstack._internal.proxy.gateway.schemas.stats import PerWindowStats +from dstack._internal.server.models import JobModel, RunModel +from dstack._internal.server.services.jobs import get_job_spec, get_jobs_from_run_spec +from dstack._internal.server.services.runs import create_job_model_for_new_submission +from dstack._internal.server.services.runs.replicas import build_replica_lists 
+from dstack._internal.server.services.services.autoscalers import get_service_scaler + +PerGroupDesiredCounts = dict[str, int] +"""Maps group_name → desired replica count""" + + +def compute_desired_replica_counts( + run_model: RunModel, + configuration: ServiceConfiguration, + gateway_stats: Optional[PerWindowStats], + last_scaled_at: Optional[datetime], +) -> tuple[int, PerGroupDesiredCounts]: + """Returns (total_desired, per_group_desired_counts).""" + replica_groups = configuration.replica_groups + prev_counts: PerGroupDesiredCounts = ( + json.loads(run_model.desired_replica_counts) if run_model.desired_replica_counts else {} + ) + if ( + prev_counts == {} + and len(replica_groups) == 1 + and replica_groups[0].name == DEFAULT_REPLICA_GROUP_NAME + ): + # Special case to avoid dropping the replica count to group.count.min + # when a 0.20.7+ server first processes a service created by a pre-0.20.7 server. + # TODO: remove once most users upgrade to 0.20.7+. + prev_counts = {DEFAULT_REPLICA_GROUP_NAME: run_model.desired_replica_count} + desired_counts: PerGroupDesiredCounts = {} + total = 0 + for group in replica_groups: + scaler = get_service_scaler(group.count, group.scaling) + assert group.name is not None, "Group name is always set" + group_desired = scaler.get_desired_count( + current_desired_count=prev_counts.get(group.name, group.count.min or 0), + stats=gateway_stats, + last_scaled_at=last_scaled_at, + ) + desired_counts[group.name] = group_desired + total += group_desired + return total, desired_counts + + +async def build_scale_up_job_models( + run_model: RunModel, + run_spec: RunSpec, + secrets: dict, + replicas_diff: int, + group_name: Optional[str] = None, + replica_num_start: Optional[int] = None, +) -> list[JobModel]: + """Build new JobModel instances for scaling up.""" + if replicas_diff <= 0: + return [] + + _, inactive_replicas = build_replica_lists(run_model, group_filter=group_name) + new_job_models: list[JobModel] = [] + scheduled_replicas = 0 
+ + # Retry inactive replicas first. + for _, _, replica_num, replica_jobs in inactive_replicas: + if scheduled_replicas == replicas_diff: + break + job_spec = get_job_spec(replica_jobs[0]) + replica_group_name = job_spec.replica_group + new_jobs = await get_jobs_from_run_spec( + run_spec=run_spec, + secrets=secrets, + replica_num=replica_num, + replica_group_name=replica_group_name, + ) + for old_job_model, new_job in zip(replica_jobs, new_jobs): + job_model = create_job_model_for_new_submission( + run_model=run_model, + job=new_job, + status=JobStatus.SUBMITTED, + ) + job_model.submission_num = old_job_model.submission_num + 1 + new_job_models.append(job_model) + scheduled_replicas += 1 + + # Create new replicas for the remainder + if scheduled_replicas < replicas_diff: + if replica_num_start is not None: + first_replica_num = replica_num_start + else: + first_replica_num = max((job.replica_num for job in run_model.jobs), default=-1) + 1 + new_replicas_needed = replicas_diff - scheduled_replicas + for i in range(new_replicas_needed): + new_replica_num = first_replica_num + i + new_jobs = await get_jobs_from_run_spec( + run_spec=run_spec, + secrets=secrets, + replica_num=new_replica_num, + replica_group_name=group_name, + ) + for new_job in new_jobs: + job_model = create_job_model_for_new_submission( + run_model=run_model, + job=new_job, + status=JobStatus.SUBMITTED, + ) + new_job_models.append(job_model) + + return new_job_models diff --git a/src/dstack/_internal/server/background/pipeline_tasks/runs/pending.py b/src/dstack/_internal/server/background/pipeline_tasks/runs/pending.py new file mode 100644 index 000000000..863b36f9a --- /dev/null +++ b/src/dstack/_internal/server/background/pipeline_tasks/runs/pending.py @@ -0,0 +1,142 @@ +import json +import uuid +from dataclasses import dataclass +from datetime import timedelta +from typing import Optional + +from dstack._internal.core.models.configurations import ServiceConfiguration +from 
dstack._internal.core.models.runs import RunSpec, RunStatus +from dstack._internal.proxy.gateway.schemas.stats import PerWindowStats +from dstack._internal.server.background.pipeline_tasks.base import ItemUpdateMap +from dstack._internal.server.background.pipeline_tasks.runs.common import ( + build_scale_up_job_models, + compute_desired_replica_counts, +) +from dstack._internal.server.models import JobModel, RunModel +from dstack._internal.utils.common import get_current_datetime +from dstack._internal.utils.logging import get_logger + +logger = get_logger(__name__) + + +class PendingRunUpdateMap(ItemUpdateMap, total=False): + status: RunStatus + desired_replica_count: int + desired_replica_counts: Optional[str] + + +@dataclass +class PendingContext: + run_model: RunModel + run_spec: RunSpec + secrets: dict + locked_job_ids: set[uuid.UUID] + gateway_stats: Optional[PerWindowStats] = None + + +@dataclass +class PendingResult: + run_update_map: PendingRunUpdateMap + new_job_models: list[JobModel] + + +async def process_pending_run(context: PendingContext) -> Optional[PendingResult]: + """ + Returns None if the run is not ready for processing (retry delay not met, + zero-scaled service, etc.). Otherwise returns a result describing the + desired state change and pre-built job models. 
+ """ + run_model = context.run_model + run_spec = context.run_spec + + if run_model.resubmission_attempt > 0 and not _is_ready_for_resubmission(run_model): + return None + + if run_spec.configuration.type == "service": + return await _process_pending_service(context) + + desired_replica_count = 1 + new_job_models = await build_scale_up_job_models( + run_model=run_model, + run_spec=run_spec, + secrets=context.secrets, + replicas_diff=desired_replica_count, + ) + return PendingResult( + run_update_map=PendingRunUpdateMap( + status=RunStatus.SUBMITTED, + desired_replica_count=desired_replica_count, + ), + new_job_models=new_job_models, + ) + + +async def _process_pending_service(context: PendingContext) -> Optional[PendingResult]: + run_model = context.run_model + run_spec = context.run_spec + assert isinstance(run_spec.configuration, ServiceConfiguration) + configuration = run_spec.configuration + + total, per_group_desired = compute_desired_replica_counts( + run_model=run_model, + configuration=configuration, + gateway_stats=context.gateway_stats, + last_scaled_at=None, + ) + if total == 0: + return None + + all_new_job_models: list[JobModel] = [] + next_replica_num = max((j.replica_num for j in run_model.jobs), default=-1) + 1 + for group in configuration.replica_groups: + assert group.name is not None + group_desired = per_group_desired.get(group.name, 0) + if group_desired <= 0: + continue + new_job_models = await build_scale_up_job_models( + run_model=run_model, + run_spec=run_spec, + secrets=context.secrets, + replicas_diff=group_desired, + group_name=group.name, + replica_num_start=next_replica_num, + ) + next_replica_num += group_desired + all_new_job_models.extend(new_job_models) + + return PendingResult( + run_update_map=PendingRunUpdateMap( + status=RunStatus.SUBMITTED, + desired_replica_count=total, + desired_replica_counts=json.dumps(per_group_desired), + ), + new_job_models=all_new_job_models, + ) + + +def _is_ready_for_resubmission(run_model: 
RunModel) -> bool: + if not run_model.jobs: + # No jobs yet — should not be possible for resubmission, but allow processing. + return True + last_processed_at = max(job.last_processed_at for job in run_model.jobs) + duration_since_processing = get_current_datetime() - last_processed_at + return duration_since_processing >= _get_retry_delay(run_model.resubmission_attempt) + + +# We use exponentially increasing retry delays for pending runs. +# This prevents creation of too many job submissions for runs stuck in pending, +# e.g. when users set retry for a long period without capacity. +_PENDING_RETRY_DELAYS = [ + timedelta(seconds=15), + timedelta(seconds=30), + timedelta(minutes=1), + timedelta(minutes=2), + timedelta(minutes=5), + timedelta(minutes=10), +] + + +def _get_retry_delay(resubmission_attempt: int) -> timedelta: + if resubmission_attempt - 1 < len(_PENDING_RETRY_DELAYS): + return _PENDING_RETRY_DELAYS[resubmission_attempt - 1] + return _PENDING_RETRY_DELAYS[-1] diff --git a/src/dstack/_internal/server/background/pipeline_tasks/runs/terminating.py b/src/dstack/_internal/server/background/pipeline_tasks/runs/terminating.py new file mode 100644 index 000000000..a84d11cbf --- /dev/null +++ b/src/dstack/_internal/server/background/pipeline_tasks/runs/terminating.py @@ -0,0 +1,172 @@ +import uuid +from dataclasses import dataclass, field +from datetime import datetime, timedelta +from typing import Optional + +import httpx + +from dstack._internal.core.errors import GatewayError, SSHError +from dstack._internal.core.models.runs import ( + JobStatus, + JobTerminationReason, + RunStatus, + RunTerminationReason, +) +from dstack._internal.server import models +from dstack._internal.server.background.pipeline_tasks.base import ItemUpdateMap +from dstack._internal.server.db import get_session_ctx +from dstack._internal.server.services import events +from dstack._internal.server.services.gateways import get_or_add_gateway_connection +from 
dstack._internal.server.services.jobs import stop_runner +from dstack._internal.server.services.logging import fmt +from dstack._internal.server.services.runs import _get_next_triggered_at, get_run_spec +from dstack._internal.utils.common import get_current_datetime, get_or_error +from dstack._internal.utils.logging import get_logger + +logger = get_logger(__name__) + + +class TerminatingRunUpdateMap(ItemUpdateMap, total=False): + status: RunStatus + next_triggered_at: Optional[datetime] + fleet_id: Optional[uuid.UUID] + + +class TerminatingRunJobUpdateMap(ItemUpdateMap, total=False): + status: JobStatus + termination_reason: Optional[JobTerminationReason] + remove_at: Optional[datetime] + + +@dataclass +class ServiceUnregistration: + event_message: str + gateway_target: Optional[events.Target] + + +@dataclass +class TerminatingContext: + run_model: models.RunModel + locked_job_ids: set[uuid.UUID] + + +@dataclass +class TerminatingResult: + run_update_map: TerminatingRunUpdateMap = field(default_factory=TerminatingRunUpdateMap) + job_id_to_update_map: dict[uuid.UUID, TerminatingRunJobUpdateMap] = field(default_factory=dict) + service_unregistration: Optional[ServiceUnregistration] = None + + +async def process_terminating_run(context: TerminatingContext) -> TerminatingResult: + """ + Stops the jobs gracefully and marks them as TERMINATING. + Jobs then should be terminated by `JobTerminatingPipeline`. + When all jobs are already terminated, assigns a finished status to the run. + Caller must preload the run, acquire related job locks, and apply the result. 
+ """ + run_model = context.run_model + assert run_model.termination_reason is not None + + job_termination_reason = run_model.termination_reason.to_job_termination_reason() + if len(context.locked_job_ids) > 0: + locked_jobs = [j for j in run_model.jobs if j.id in context.locked_job_ids] + delayed_job_ids = [] + regular_job_ids = [] + for job_model in locked_jobs: + if job_model.status == JobStatus.RUNNING and job_termination_reason not in { + JobTerminationReason.ABORTED_BY_USER, + JobTerminationReason.DONE_BY_RUNNER, + }: + # Send a signal to stop the job gracefully. + await stop_runner( + job_model=job_model, instance_model=get_or_error(job_model.instance) + ) + delayed_job_ids.append(job_model.id) + continue + regular_job_ids.append(job_model.id) + return TerminatingResult( + job_id_to_update_map=_get_job_id_to_update_map( + delayed_job_ids=delayed_job_ids, + regular_job_ids=regular_job_ids, + job_termination_reason=job_termination_reason, + ) + ) + + if any(not job_model.status.is_finished() for job_model in run_model.jobs): + return TerminatingResult() + + service_unregistration = None + if run_model.service_spec is not None: + try: + service_unregistration = await _unregister_service(run_model) + except Exception as e: + logger.warning("%s: failed to unregister service: %s", fmt(run_model), repr(e)) + + return TerminatingResult( + run_update_map=_get_run_update_map(run_model), + service_unregistration=service_unregistration, + ) + + +def _get_job_id_to_update_map( + delayed_job_ids: list[uuid.UUID], + regular_job_ids: list[uuid.UUID], + job_termination_reason: JobTerminationReason, +) -> dict[uuid.UUID, TerminatingRunJobUpdateMap]: + job_id_to_update_map = {} + for job_id in regular_job_ids: + job_id_to_update_map[job_id] = TerminatingRunJobUpdateMap( + status=JobStatus.TERMINATING, + termination_reason=job_termination_reason, + ) + for job_id in delayed_job_ids: + job_id_to_update_map[job_id] = TerminatingRunJobUpdateMap( + status=JobStatus.TERMINATING, + 
termination_reason=job_termination_reason, + remove_at=get_current_datetime() + timedelta(seconds=15), + ) + return job_id_to_update_map + + +def _get_run_update_map(run_model: models.RunModel) -> TerminatingRunUpdateMap: + termination_reason = get_or_error(run_model.termination_reason) + run_spec = get_run_spec(run_model) + if run_spec.merged_profile.schedule is not None and termination_reason not in { + RunTerminationReason.ABORTED_BY_USER, + RunTerminationReason.STOPPED_BY_USER, + }: + return TerminatingRunUpdateMap( + status=RunStatus.PENDING, + next_triggered_at=_get_next_triggered_at(run_spec), + fleet_id=None, + ) + return TerminatingRunUpdateMap(status=termination_reason.to_status()) + + +async def _unregister_service(run_model: models.RunModel) -> Optional[ServiceUnregistration]: + if run_model.gateway_id is None: # in-server proxy + return None + + async with get_session_ctx() as session: + gateway, conn = await get_or_add_gateway_connection(session, run_model.gateway_id) + gateway_target = events.Target.from_model(gateway) + + try: + logger.debug("%s: unregistering service", fmt(run_model)) + async with conn.client() as client: + await client.unregister_service( + project=run_model.project.name, + run_name=run_model.run_name, + ) + event_message = "Service unregistered from gateway" + except GatewayError as e: + # Ignore if the service is not registered. 
+ logger.warning("%s: unregistering service: %s", fmt(run_model), e) + event_message = f"Gateway error when unregistering service: {e}" + except (httpx.RequestError, SSHError) as e: + logger.debug("Gateway request failed", exc_info=True) + raise GatewayError(repr(e)) + return ServiceUnregistration( + event_message=event_message, + gateway_target=gateway_target, + ) diff --git a/src/dstack/_internal/server/background/scheduled_tasks/__init__.py b/src/dstack/_internal/server/background/scheduled_tasks/__init__.py index d1b1eda5a..6bc343a55 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/__init__.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/__init__.py @@ -122,13 +122,13 @@ def start_scheduled_tasks() -> AsyncIOScheduler: # Add multiple copies of tasks if requested. # max_instances=1 for additional copies to avoid running too many tasks. # Move other tasks here when they need per-replica scaling. - _scheduler.add_job( - process_runs, - IntervalTrigger(seconds=2, jitter=1), - kwargs={"batch_size": 5}, - max_instances=2 if replica == 0 else 1, - ) if not FeatureFlags.PIPELINE_PROCESSING_ENABLED: + _scheduler.add_job( + process_runs, + IntervalTrigger(seconds=2, jitter=1), + kwargs={"batch_size": 5}, + max_instances=2 if replica == 0 else 1, + ) _scheduler.add_job( process_instances, IntervalTrigger(seconds=4, jitter=2), diff --git a/src/dstack/_internal/server/migrations/versions/2026/03_19_0924_7099b48e72a9_add_runmodel_pipeline_columns.py b/src/dstack/_internal/server/migrations/versions/2026/03_19_0924_7099b48e72a9_add_runmodel_pipeline_columns.py new file mode 100644 index 000000000..353dbadee --- /dev/null +++ b/src/dstack/_internal/server/migrations/versions/2026/03_19_0924_7099b48e72a9_add_runmodel_pipeline_columns.py @@ -0,0 +1,47 @@ +"""Add RunModel pipeline columns + +Revision ID: 7099b48e72a9 +Revises: 8b6d5d8c1b9a +Create Date: 2026-03-19 09:24:29.042905+00:00 + +""" + +import sqlalchemy as sa +import sqlalchemy_utils 
+from alembic import op + +import dstack._internal.server.models + +# revision identifiers, used by Alembic. +revision = "7099b48e72a9" +down_revision = "8b6d5d8c1b9a" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("runs", schema=None) as batch_op: + batch_op.add_column( + sa.Column( + "lock_expires_at", dstack._internal.server.models.NaiveDateTime(), nullable=True + ) + ) + batch_op.add_column( + sa.Column( + "lock_token", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=True + ) + ) + batch_op.add_column(sa.Column("lock_owner", sa.String(length=100), nullable=True)) + + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("runs", schema=None) as batch_op: + batch_op.drop_column("lock_owner") + batch_op.drop_column("lock_token") + batch_op.drop_column("lock_expires_at") + + # ### end Alembic commands ### diff --git a/src/dstack/_internal/server/migrations/versions/2026/03_24_0528_c1c2ecaee45c_add_ix_runs_pipeline_fetch_q_index.py b/src/dstack/_internal/server/migrations/versions/2026/03_24_0528_c1c2ecaee45c_add_ix_runs_pipeline_fetch_q_index.py new file mode 100644 index 000000000..eb47db4e4 --- /dev/null +++ b/src/dstack/_internal/server/migrations/versions/2026/03_24_0528_c1c2ecaee45c_add_ix_runs_pipeline_fetch_q_index.py @@ -0,0 +1,49 @@ +"""Add ix_runs_pipeline_fetch_q index + +Revision ID: c1c2ecaee45c +Revises: 7099b48e72a9 +Create Date: 2026-03-24 05:28:50.925623+00:00 + +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "c1c2ecaee45c" +down_revision = "7099b48e72a9" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + with op.get_context().autocommit_block(): + op.drop_index( + "ix_runs_pipeline_fetch_q", + table_name="runs", + if_exists=True, + postgresql_concurrently=True, + ) + op.create_index( + "ix_runs_pipeline_fetch_q", + "runs", + [sa.literal_column("last_processed_at ASC")], + unique=False, + sqlite_where=sa.text("(status NOT IN ('TERMINATED', 'FAILED', 'DONE'))"), + postgresql_where=sa.text("(status NOT IN ('TERMINATED', 'FAILED', 'DONE'))"), + postgresql_concurrently=True, + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.get_context().autocommit_block(): + op.drop_index( + "ix_runs_pipeline_fetch_q", + table_name="runs", + if_exists=True, + postgresql_concurrently=True, + ) + # ### end Alembic commands ### diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py index 958a8bacf..b599c4314 100644 --- a/src/dstack/_internal/server/models.py +++ b/src/dstack/_internal/server/models.py @@ -391,7 +391,7 @@ class FileArchiveModel(BaseModel): """`blob` is stored on S3 when it is `None`.""" -class RunModel(BaseModel): +class RunModel(PipelineModelMixin, BaseModel): __tablename__ = "runs" id: Mapped[uuid.UUID] = mapped_column( @@ -443,7 +443,15 @@ class RunModel(BaseModel): ) gateway: Mapped[Optional["GatewayModel"]] = relationship() - __table_args__ = (Index("ix_submitted_at_id", submitted_at.desc(), id),) + __table_args__ = ( + Index("ix_submitted_at_id", submitted_at.desc(), id), + Index( + "ix_runs_pipeline_fetch_q", + last_processed_at.asc(), + postgresql_where=status.not_in(RunStatus.finished_statuses()), + sqlite_where=status.not_in(RunStatus.finished_statuses()), + ), + ) class JobModel(PipelineModelMixin, BaseModel): @@ -697,6 +705,7 @@ class InstanceModel(PipelineModelMixin, BaseModel): back_populates="instances", foreign_keys=[fleet_id], ) + """`fleet` can be `None` only for legacy instances created before fleets.""" compute_group_id: 
Mapped[Optional[uuid.UUID]] = mapped_column(ForeignKey("compute_groups.id")) compute_group: Mapped[Optional["ComputeGroupModel"]] = relationship(back_populates="instances") diff --git a/src/dstack/_internal/server/services/gateways/__init__.py b/src/dstack/_internal/server/services/gateways/__init__.py index 569c3d6d1..bc3cbbdd3 100644 --- a/src/dstack/_internal/server/services/gateways/__init__.py +++ b/src/dstack/_internal/server/services/gateways/__init__.py @@ -614,6 +614,7 @@ async def generate_gateway_name(session: AsyncSession, project: ProjectModel) -> return name +# TODO: Connect to gateway outside session async def get_or_add_gateway_connection( session: AsyncSession, gateway_id: uuid.UUID ) -> tuple[GatewayModel, GatewayConnection]: diff --git a/src/dstack/_internal/server/services/jobs/__init__.py b/src/dstack/_internal/server/services/jobs/__init__.py index 62254d765..94054ca58 100644 --- a/src/dstack/_internal/server/services/jobs/__init__.py +++ b/src/dstack/_internal/server/services/jobs/__init__.py @@ -7,7 +7,7 @@ import requests from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import joinedload, load_only +from sqlalchemy.orm import load_only from dstack._internal.core.consts import DSTACK_RUNNER_HTTP_PORT from dstack._internal.core.errors import ( @@ -302,15 +302,12 @@ def _get_job_configurator( _configuration_type_to_configurator_class_map = {c.TYPE: c for c in _job_configurator_classes} -async def stop_runner(session: AsyncSession, job_model: JobModel): - res = await session.execute( - select(InstanceModel) - .where(InstanceModel.id == job_model.instance_id) - .options(joinedload(InstanceModel.project)) - ) - instance: Optional[InstanceModel] = res.scalar() - - ssh_private_keys = get_instance_ssh_private_keys(common.get_or_error(instance)) +async def stop_runner(job_model: JobModel, instance_model: InstanceModel): + """ + Stops the runner using a preloaded instance model. 
+ `instance_model.project` must be loaded because SSH key resolution uses the project keys. + """ + ssh_private_keys = get_instance_ssh_private_keys(instance_model) try: jpd = get_job_provisioning_data(job_model) if jpd is not None: diff --git a/src/dstack/_internal/server/services/runs/__init__.py b/src/dstack/_internal/server/services/runs/__init__.py index 3fb9468d3..829691aee 100644 --- a/src/dstack/_internal/server/services/runs/__init__.py +++ b/src/dstack/_internal/server/services/runs/__init__.py @@ -40,6 +40,7 @@ from dstack._internal.server.db import get_db, is_db_postgres, is_db_sqlite from dstack._internal.server.models import ( FleetModel, + InstanceModel, JobModel, ProbeModel, ProjectModel, @@ -104,13 +105,47 @@ def switch_run_status( return run_model.status = new_status + emit_run_status_change_event( + session=session, + run_model=run_model, + old_status=old_status, + new_status=new_status, + actor=actor, + ) + + +def emit_run_status_change_event( + session: AsyncSession, + run_model: RunModel, + old_status: RunStatus, + new_status: RunStatus, + actor: events.AnyActor = events.SystemActor(), +) -> None: + if old_status == new_status: + return + events.emit( + session, + get_run_status_change_message( + old_status=old_status, + new_status=new_status, + termination_reason=run_model.termination_reason, + ), + actor=actor, + targets=[events.Target.from_model(run_model)], + ) + +def get_run_status_change_message( + old_status: RunStatus, + new_status: RunStatus, + termination_reason: Optional[RunTerminationReason], +) -> str: msg = f"Run status changed {old_status.upper()} -> {new_status.upper()}" if new_status == RunStatus.TERMINATING: - if run_model.termination_reason is None: + if termination_reason is None: raise ValueError("termination_reason must be set when switching to TERMINATING status") - msg += f". 
Termination reason: {run_model.termination_reason.upper()}" - events.emit(session, msg, actor=actor, targets=[events.Target.from_model(run_model)]) + msg += f". Termination reason: {termination_reason.upper()}" + return msg def get_run_spec(run_model: RunModel) -> RunSpec: @@ -613,6 +648,7 @@ async def submit_run( await session.commit() if pipeline_hinter is not None: pipeline_hinter.hint_fetch(JobModel.__name__) + pipeline_hinter.hint_fetch(RunModel.__name__) await session.refresh(run_model) run = await get_run_by_id(session, project, run_model.id) @@ -1020,7 +1056,10 @@ async def process_terminating_run(session: AsyncSession, run_model: RunModel): JobTerminationReason.DONE_BY_RUNNER, }: # Send a signal to stop the job gracefully - await stop_runner(session, job_model) + instance_model = await _load_job_instance_for_stop( + session=session, job_model=job_model + ) + await stop_runner(job_model, common_utils.get_or_error(instance_model)) delay_job_instance_termination(job_model) job_model.termination_reason = job_termination_reason switch_job_status(session, job_model, JobStatus.TERMINATING) @@ -1063,3 +1102,17 @@ def _get_next_triggered_at(run_spec: RunSpec) -> Optional[datetime]: ) ) return min(fire_times) + + +async def _load_job_instance_for_stop( + session: AsyncSession, + job_model: JobModel, +) -> Optional[InstanceModel]: + if job_model.instance_id is None: + return None + res = await session.execute( + select(InstanceModel) + .where(InstanceModel.id == job_model.instance_id) + .options(joinedload(InstanceModel.project)) + ) + return res.scalar_one_or_none() diff --git a/src/dstack/_internal/server/testing/common.py b/src/dstack/_internal/server/testing/common.py index 87db5ec8e..e38907b01 100644 --- a/src/dstack/_internal/server/testing/common.py +++ b/src/dstack/_internal/server/testing/common.py @@ -322,6 +322,7 @@ async def create_run( deployment_num: int = 0, resubmission_attempt: int = 0, next_triggered_at: Optional[datetime] = None, + 
last_processed_at: Optional[datetime] = None, ) -> RunModel: if run_name is None: run_name = "test-run" @@ -332,6 +333,8 @@ async def create_run( ) if run_id is None: run_id = uuid.uuid4() + if last_processed_at is None: + last_processed_at = submitted_at run = RunModel( id=run_id, deleted=deleted, @@ -344,7 +347,7 @@ async def create_run( status=status, termination_reason=termination_reason, run_spec=run_spec.json(), - last_processed_at=submitted_at, + last_processed_at=last_processed_at, jobs=[], priority=priority, deployment_num=deployment_num, diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_fleets.py b/src/tests/_internal/server/background/pipeline_tasks/test_fleets.py index d2b53226e..ea9788b9d 100644 --- a/src/tests/_internal/server/background/pipeline_tasks/test_fleets.py +++ b/src/tests/_internal/server/background/pipeline_tasks/test_fleets.py @@ -743,6 +743,77 @@ async def test_min_zero_failed_master_terminates_unprovisioned_siblings( assert sibling2.termination_reason == InstanceTerminationReason.MASTER_FAILED assert fleet.current_master_instance_id is None + async def test_master_failure_path_resets_when_sibling_instance_is_locked( + self, test_db, session: AsyncSession, worker: FleetWorker + ): + project = await create_project(session) + fleet = await create_fleet( + session=session, + project=project, + spec=get_fleet_spec( + conf=get_fleet_configuration( + placement=InstanceGroupPlacement.CLUSTER, + nodes=FleetNodesSpec(min=0, target=3, max=3), + ) + ), + ) + failed_master = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.TERMINATED, + job_provisioning_data=None, + offer=None, + instance_num=0, + ) + failed_master.termination_reason = InstanceTerminationReason.NO_OFFERS + locked_sibling = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.PENDING, + job_provisioning_data=None, + offer=None, + instance_num=1, + ) + free_sibling = 
await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.PENDING, + job_provisioning_data=None, + offer=None, + instance_num=2, + ) + original_last_processed_at = fleet.last_processed_at + fleet.current_master_instance_id = failed_master.id + fleet.consolidation_attempt = 1 + fleet.last_consolidated_at = datetime.now(timezone.utc) + await _lock_fleet_for_processing(session, fleet) + fleet.lock_owner = FleetPipeline.__name__ + locked_sibling.lock_token = uuid.uuid4() + locked_sibling.lock_expires_at = get_current_datetime() + timedelta(minutes=1) + locked_sibling.lock_owner = "OtherPipeline" + await session.commit() + + await worker.process(_fleet_to_pipeline_item(fleet)) + + await session.refresh(fleet) + await session.refresh(failed_master) + await session.refresh(locked_sibling) + await session.refresh(free_sibling) + assert fleet.current_master_instance_id == failed_master.id + assert fleet.lock_owner == FleetPipeline.__name__ + assert fleet.lock_token is None + assert fleet.lock_expires_at is None + assert fleet.last_processed_at > original_last_processed_at + assert not failed_master.deleted + assert locked_sibling.status == InstanceStatus.PENDING + assert locked_sibling.termination_reason is None + assert locked_sibling.lock_owner == "OtherPipeline" + assert free_sibling.status == InstanceStatus.PENDING + assert free_sibling.termination_reason is None + async def test_min_zero_failed_master_preserves_provisioned_survivor( self, test_db, session: AsyncSession, worker: FleetWorker ): diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_runs/__init__.py b/src/tests/_internal/server/background/pipeline_tasks/test_runs/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_runs/conftest.py b/src/tests/_internal/server/background/pipeline_tasks/test_runs/conftest.py new file mode 100644 index 000000000..6612232cc --- /dev/null +++ 
b/src/tests/_internal/server/background/pipeline_tasks/test_runs/conftest.py @@ -0,0 +1,23 @@ +import asyncio +from datetime import timedelta +from unittest.mock import Mock + +import pytest + +from dstack._internal.server.background.pipeline_tasks.runs import RunFetcher, RunWorker + + +@pytest.fixture +def fetcher() -> RunFetcher: + return RunFetcher( + queue=asyncio.Queue(), + queue_desired_minsize=1, + min_processing_interval=timedelta(seconds=5), + lock_timeout=timedelta(seconds=30), + heartbeater=Mock(), + ) + + +@pytest.fixture +def worker() -> RunWorker: + return RunWorker(queue=Mock(), heartbeater=Mock()) diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_runs/helpers.py b/src/tests/_internal/server/background/pipeline_tasks/test_runs/helpers.py new file mode 100644 index 000000000..1edf90774 --- /dev/null +++ b/src/tests/_internal/server/background/pipeline_tasks/test_runs/helpers.py @@ -0,0 +1,34 @@ +import datetime as dt +import uuid + +from dstack._internal.server.background.pipeline_tasks.runs import ( + RunPipeline, + RunPipelineItem, +) +from dstack._internal.server.models import RunModel + +LOCK_EXPIRES_AT = dt.datetime(2025, 1, 2, 3, 4, tzinfo=dt.timezone.utc) + + +def run_to_pipeline_item(run_model: RunModel) -> RunPipelineItem: + assert run_model.lock_token is not None + assert run_model.lock_expires_at is not None + return RunPipelineItem( + __tablename__=run_model.__tablename__, + id=run_model.id, + lock_token=run_model.lock_token, + lock_expires_at=run_model.lock_expires_at, + prev_lock_expired=False, + status=run_model.status, + ) + + +def lock_run( + run_model: RunModel, + *, + lock_owner: str = RunPipeline.__name__, + lock_expires_at: dt.datetime = LOCK_EXPIRES_AT, +) -> None: + run_model.lock_token = uuid.uuid4() + run_model.lock_expires_at = lock_expires_at + run_model.lock_owner = lock_owner diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_runs/test_active.py 
b/src/tests/_internal/server/background/pipeline_tasks/test_runs/test_active.py new file mode 100644 index 000000000..e9a55d6f1 --- /dev/null +++ b/src/tests/_internal/server/background/pipeline_tasks/test_runs/test_active.py @@ -0,0 +1,948 @@ +import json +import uuid +from datetime import timedelta +from unittest.mock import AsyncMock, patch + +import pytest +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from dstack._internal.core.models.configurations import ( + ScalingSpec, + ServiceConfiguration, + TaskConfiguration, +) +from dstack._internal.core.models.instances import InstanceStatus +from dstack._internal.core.models.profiles import ( + Profile, + ProfileRetry, + RetryEvent, + StopCriteria, +) +from dstack._internal.core.models.resources import Range +from dstack._internal.core.models.runs import ( + JobStatus, + JobTerminationReason, + RunStatus, + RunTerminationReason, +) +from dstack._internal.server.background.pipeline_tasks.runs import RunWorker +from dstack._internal.server.models import JobModel +from dstack._internal.server.services.jobs import get_job_spec +from dstack._internal.server.testing.common import ( + create_fleet, + create_instance, + create_job, + create_project, + create_repo, + create_run, + create_user, + get_job_provisioning_data, + get_run_spec, +) +from dstack._internal.utils.common import get_current_datetime +from tests._internal.server.background.pipeline_tasks.test_runs.helpers import ( + lock_run, + run_to_pipeline_item, +) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) +@pytest.mark.usefixtures("image_config_mock") +class TestRunActiveWorker: + async def test_transitions_submitted_to_provisioning( + self, test_db, session: AsyncSession, worker: RunWorker + ) -> None: + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + fleet = 
await create_fleet(session=session, project=project) + instance = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.BUSY, + ) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + status=RunStatus.SUBMITTED, + ) + await create_job( + session=session, + run=run, + status=JobStatus.PROVISIONING, + instance=instance, + instance_assigned=True, + ) + lock_run(run) + await session.commit() + + await worker.process(run_to_pipeline_item(run)) + + await session.refresh(run) + assert run.status == RunStatus.PROVISIONING + assert run.lock_token is None + + async def test_transitions_provisioning_to_running( + self, test_db, session: AsyncSession, worker: RunWorker + ) -> None: + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + status=RunStatus.PROVISIONING, + ) + await create_job( + session=session, + run=run, + status=JobStatus.RUNNING, + ) + lock_run(run) + await session.commit() + + await worker.process(run_to_pipeline_item(run)) + + await session.refresh(run) + assert run.status == RunStatus.RUNNING + assert run.lock_token is None + + async def test_terminates_run_when_all_jobs_done( + self, test_db, session: AsyncSession, worker: RunWorker + ) -> None: + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + status=RunStatus.RUNNING, + ) + await create_job( + session=session, + run=run, + status=JobStatus.DONE, + termination_reason=JobTerminationReason.DONE_BY_RUNNER, + ) + lock_run(run) + await session.commit() + + await worker.process(run_to_pipeline_item(run)) + + await session.refresh(run) 
+ assert run.status == RunStatus.TERMINATING + assert run.termination_reason == RunTerminationReason.ALL_JOBS_DONE + assert run.lock_token is None + + async def test_terminates_run_on_job_failure( + self, test_db, session: AsyncSession, worker: RunWorker + ) -> None: + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + status=RunStatus.RUNNING, + ) + await create_job( + session=session, + run=run, + status=JobStatus.FAILED, + termination_reason=JobTerminationReason.CONTAINER_EXITED_WITH_ERROR, + ) + lock_run(run) + await session.commit() + + await worker.process(run_to_pipeline_item(run)) + + await session.refresh(run) + assert run.status == RunStatus.TERMINATING + assert run.termination_reason == RunTerminationReason.JOB_FAILED + assert run.lock_token is None + + async def test_retries_failed_replica_within_retry_duration( + self, test_db, session: AsyncSession, worker: RunWorker + ) -> None: + """When a replica fails within retry duration, the run goes to PENDING with + resubmission_attempt incremented. 
The pending worker then creates the new submission.""" + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run_spec = get_run_spec( + repo_id=repo.name, + profile=Profile( + name="default", + retry=ProfileRetry(duration=3600, on_events=[RetryEvent.ERROR]), + ), + ) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + run_spec=run_spec, + status=RunStatus.RUNNING, + resubmission_attempt=0, + ) + old_time = get_current_datetime() - timedelta(minutes=5) + await create_job( + session=session, + run=run, + status=JobStatus.FAILED, + termination_reason=JobTerminationReason.CONTAINER_EXITED_WITH_ERROR, + job_provisioning_data=get_job_provisioning_data(), + last_processed_at=old_time, + ) + lock_run(run) + await session.commit() + + await worker.process(run_to_pipeline_item(run)) + + await session.refresh(run) + # Retryable failure → PENDING with resubmission_attempt incremented + assert run.status == RunStatus.PENDING + assert run.resubmission_attempt == 1 + assert run.lock_token is None + + async def test_retries_no_capacity_replica_and_keeps_service_running( + self, test_db, session: AsyncSession, worker: RunWorker + ) -> None: + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run_spec = get_run_spec( + repo_id=repo.name, + profile=Profile( + name="default", + retry=ProfileRetry(duration=3600, on_events=[RetryEvent.INTERRUPTION]), + ), + configuration=ServiceConfiguration( + port=8080, + commands=["echo Hi!"], + replicas=Range[int](min=2, max=2), + ), + ) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + run_spec=run_spec, + status=RunStatus.RUNNING, + ) + interrupted_job = await create_job( + session=session, + run=run, + status=JobStatus.TERMINATING, + 
termination_reason=JobTerminationReason.INTERRUPTED_BY_NO_CAPACITY, + submitted_at=run.submitted_at, + last_processed_at=run.last_processed_at, + replica_num=0, + job_provisioning_data=get_job_provisioning_data(), + ) + healthy_job = await create_job( + session=session, + run=run, + status=JobStatus.RUNNING, + submitted_at=run.submitted_at, + last_processed_at=run.last_processed_at, + replica_num=1, + job_provisioning_data=get_job_provisioning_data(), + ) + lock_run(run) + await session.commit() + + now = run.submitted_at + timedelta(minutes=3) + with patch( + "dstack._internal.server.background.pipeline_tasks.runs.active.get_current_datetime", + return_value=now, + ): + await worker.process(run_to_pipeline_item(run)) + + await session.refresh(run) + await session.refresh(interrupted_job) + await session.refresh(healthy_job) + + jobs = list( + ( + await session.execute( + select(JobModel) + .where(JobModel.run_id == run.id) + .order_by(JobModel.replica_num, JobModel.submission_num) + ) + ).scalars() + ) + retried_job = next(job for job in jobs if job.replica_num == 0 and job.submission_num == 1) + + assert run.status == RunStatus.RUNNING + assert interrupted_job.status == JobStatus.TERMINATING + assert ( + interrupted_job.termination_reason == JobTerminationReason.INTERRUPTED_BY_NO_CAPACITY + ) + assert healthy_job.status == JobStatus.RUNNING + assert retried_job.status == JobStatus.SUBMITTED + assert len(jobs) == 3 + + async def test_retrying_multinode_replica_terminates_active_sibling_jobs( + self, test_db, session: AsyncSession, worker: RunWorker + ) -> None: + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run_spec = get_run_spec( + repo_id=repo.name, + profile=Profile( + name="default", + retry=ProfileRetry(duration=3600, on_events=[RetryEvent.ERROR]), + ), + configuration=TaskConfiguration( + commands=["echo hello"], + nodes=2, + ), + ) + run = 
await create_run( + session=session, + project=project, + repo=repo, + user=user, + run_spec=run_spec, + status=RunStatus.RUNNING, + ) + failed_job = await create_job( + session=session, + run=run, + status=JobStatus.FAILED, + termination_reason=JobTerminationReason.CONTAINER_EXITED_WITH_ERROR, + replica_num=0, + job_num=0, + job_provisioning_data=get_job_provisioning_data(), + last_processed_at=run.submitted_at, + ) + running_job = await create_job( + session=session, + run=run, + status=JobStatus.RUNNING, + replica_num=0, + job_num=1, + job_provisioning_data=get_job_provisioning_data(), + last_processed_at=run.submitted_at, + ) + lock_run(run) + await session.commit() + + now = run.submitted_at + timedelta(minutes=1) + with patch( + "dstack._internal.server.background.pipeline_tasks.runs.active.get_current_datetime", + return_value=now, + ): + await worker.process(run_to_pipeline_item(run)) + + await session.refresh(run) + await session.refresh(failed_job) + await session.refresh(running_job) + + assert run.status == RunStatus.PENDING + assert failed_job.status == JobStatus.FAILED + assert running_job.status == JobStatus.TERMINATING + assert running_job.termination_reason == JobTerminationReason.TERMINATED_BY_SERVER + assert running_job.termination_reason_message == "Run is to be resubmitted" + + async def test_transitions_to_pending_when_retry_duration_exceeded( + self, test_db, session: AsyncSession, worker: RunWorker + ) -> None: + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run_spec = get_run_spec( + repo_id=repo.name, + profile=Profile( + name="default", + retry=ProfileRetry(duration=60, on_events=[RetryEvent.ERROR]), + ), + ) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + run_spec=run_spec, + status=RunStatus.RUNNING, + resubmission_attempt=0, + ) + # Last provisioned long ago so retry duration is 
exceeded + very_old_time = get_current_datetime() - timedelta(hours=2) + await create_job( + session=session, + run=run, + status=JobStatus.FAILED, + termination_reason=JobTerminationReason.CONTAINER_EXITED_WITH_ERROR, + job_provisioning_data=get_job_provisioning_data(), + last_processed_at=very_old_time, + ) + lock_run(run) + await session.commit() + + await worker.process(run_to_pipeline_item(run)) + + await session.refresh(run) + assert run.status == RunStatus.TERMINATING + assert run.termination_reason == RunTerminationReason.RETRY_LIMIT_EXCEEDED + assert run.lock_token is None + + async def test_stops_on_master_done( + self, test_db, session: AsyncSession, worker: RunWorker + ) -> None: + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run_spec = get_run_spec( + repo_id=repo.name, + profile=Profile(name="default", stop_criteria=StopCriteria.MASTER_DONE), + ) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + run_spec=run_spec, + status=RunStatus.RUNNING, + ) + # Master job (job_num=0) is done + await create_job( + session=session, + run=run, + status=JobStatus.DONE, + termination_reason=JobTerminationReason.DONE_BY_RUNNER, + job_num=0, + ) + lock_run(run) + await session.commit() + + await worker.process(run_to_pipeline_item(run)) + + await session.refresh(run) + assert run.status == RunStatus.TERMINATING + assert run.termination_reason == RunTerminationReason.ALL_JOBS_DONE + + async def test_sets_fleet_id_from_job_instance( + self, test_db, session: AsyncSession, worker: RunWorker + ) -> None: + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + fleet = await create_fleet(session=session, project=project) + instance = await create_instance( + session=session, + project=project, + fleet=fleet, + 
status=InstanceStatus.BUSY, + ) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + status=RunStatus.SUBMITTED, + ) + assert run.fleet_id is None + await create_job( + session=session, + run=run, + status=JobStatus.PROVISIONING, + instance=instance, + instance_assigned=True, + ) + lock_run(run) + await session.commit() + + await worker.process(run_to_pipeline_item(run)) + + await session.refresh(run) + assert run.fleet_id == fleet.id + + async def test_service_noop_when_at_desired_count( + self, test_db, session: AsyncSession, worker: RunWorker + ) -> None: + """Service with 1 RUNNING replica and desired=1 stays RUNNING, no new jobs.""" + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run_spec = get_run_spec( + repo_id=repo.name, + run_name="service-run", + configuration=ServiceConfiguration( + port=8080, + commands=["echo Hi!"], + ), + ) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + run_name="service-run", + run_spec=run_spec, + status=RunStatus.RUNNING, + ) + await create_job( + session=session, + run=run, + status=JobStatus.RUNNING, + ) + lock_run(run) + await session.commit() + + await worker.process(run_to_pipeline_item(run)) + + await session.refresh(run) + assert run.status == RunStatus.RUNNING + assert run.desired_replica_count == 1 + assert run.desired_replica_counts is not None + counts = json.loads(run.desired_replica_counts) + assert counts == {"0": 1} + assert run.lock_token is None + + async def test_service_scale_up( + self, test_db, session: AsyncSession, worker: RunWorker + ) -> None: + """Service with min=2 and 1 RUNNING replica creates 1 new SUBMITTED job.""" + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run_spec = get_run_spec( + 
repo_id=repo.name, + run_name="service-run", + configuration=ServiceConfiguration( + port=8080, + commands=["echo Hi!"], + replicas=Range[int](min=2, max=2), + ), + ) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + run_name="service-run", + run_spec=run_spec, + status=RunStatus.SUBMITTED, + ) + await create_job( + session=session, + run=run, + status=JobStatus.RUNNING, + replica_num=0, + ) + lock_run(run) + await session.commit() + + await worker.process(run_to_pipeline_item(run)) + + await session.refresh(run) + assert run.status == RunStatus.RUNNING + assert run.desired_replica_count == 2 + + res = await session.execute( + select(JobModel).where(JobModel.run_id == run.id).order_by(JobModel.replica_num) + ) + jobs = list(res.scalars().all()) + assert len(jobs) == 2 + assert jobs[0].status == JobStatus.RUNNING + assert jobs[0].replica_num == 0 + assert jobs[1].status == JobStatus.SUBMITTED + assert jobs[1].replica_num == 1 + + async def test_service_scale_down( + self, test_db, session: AsyncSession, worker: RunWorker + ) -> None: + """Service with min=1 and 2 RUNNING replicas terminates 1 with SCALED_DOWN.""" + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run_spec = get_run_spec( + repo_id=repo.name, + run_name="service-run", + configuration=ServiceConfiguration( + port=8080, + commands=["echo Hi!"], + replicas=Range[int](min=1, max=1), + ), + ) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + run_name="service-run", + run_spec=run_spec, + status=RunStatus.RUNNING, + ) + run.desired_replica_count = 2 + run.desired_replica_counts = json.dumps({"0": 2}) + await create_job( + session=session, + run=run, + status=JobStatus.RUNNING, + replica_num=0, + ) + await create_job( + session=session, + run=run, + status=JobStatus.RUNNING, + replica_num=1, + ) + lock_run(run) + 
await session.commit() + + await worker.process(run_to_pipeline_item(run)) + + await session.refresh(run) + assert run.status == RunStatus.RUNNING + assert run.desired_replica_count == 1 + + res = await session.execute( + select(JobModel).where(JobModel.run_id == run.id).order_by(JobModel.replica_num) + ) + jobs = list(res.scalars().all()) + assert len(jobs) == 2 + # One should remain RUNNING, the other should be TERMINATING with SCALED_DOWN + running = [j for j in jobs if j.status == JobStatus.RUNNING] + terminating = [j for j in jobs if j.status == JobStatus.TERMINATING] + assert len(running) == 1 + assert len(terminating) == 1 + assert terminating[0].termination_reason == JobTerminationReason.SCALED_DOWN + + async def test_service_zero_scale_noop( + self, test_db, session: AsyncSession, worker: RunWorker + ) -> None: + """Active service with 0 desired and no active replicas transitions to PENDING.""" + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run_spec = get_run_spec( + repo_id=repo.name, + run_name="service-run", + configuration=ServiceConfiguration( + port=8080, + commands=["echo Hi!"], + replicas=Range[int](min=0, max=2), + scaling=ScalingSpec(metric="rps", target=10), + ), + ) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + run_name="service-run", + run_spec=run_spec, + status=RunStatus.RUNNING, + ) + run.desired_replica_count = 0 + run.desired_replica_counts = json.dumps({"0": 0}) + # Create a terminated/scaled-down job to have some job history + await create_job( + session=session, + run=run, + status=JobStatus.TERMINATED, + termination_reason=JobTerminationReason.SCALED_DOWN, + replica_num=0, + ) + lock_run(run) + await session.commit() + + await worker.process(run_to_pipeline_item(run)) + + await session.refresh(run) + # All replicas scaled down → transitions to PENDING + assert run.status ==
RunStatus.PENDING + assert run.lock_token is None + + async def test_noops_when_run_lock_changes_after_processing( + self, test_db, session: AsyncSession, worker: RunWorker + ) -> None: + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + status=RunStatus.RUNNING, + ) + await create_job( + session=session, + run=run, + status=JobStatus.DONE, + termination_reason=JobTerminationReason.DONE_BY_RUNNER, + ) + lock_run(run) + await session.commit() + item = run_to_pipeline_item(run) + new_lock_token = uuid.uuid4() + + from dstack._internal.server.background.pipeline_tasks.runs.active import ( + ActiveResult, + ActiveRunUpdateMap, + ) + + async def intercept_process(context): + # Change the lock token to simulate concurrent modification + run.lock_token = new_lock_token + run.lock_expires_at = get_current_datetime() + timedelta(minutes=1) + await session.commit() + return ActiveResult( + run_update_map=ActiveRunUpdateMap( + status=RunStatus.TERMINATING, + termination_reason=RunTerminationReason.ALL_JOBS_DONE, + ), + new_job_models=[], + job_id_to_update_map={}, + ) + + with patch( + "dstack._internal.server.background.pipeline_tasks.runs.active.process_active_run", + new=AsyncMock(side_effect=intercept_process), + ): + await worker.process(item) + + await session.refresh(run) + assert run.status == RunStatus.RUNNING + assert run.lock_token == new_lock_token + + async def test_service_in_place_deployment_bump( + self, test_db, session: AsyncSession, worker: RunWorker + ) -> None: + """Service with 1 RUNNING replica at deployment_num=0, run at deployment_num=1, + same job spec → job gets deployment_num bumped to 1.""" + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + 
run_spec = get_run_spec( + repo_id=repo.name, + run_name="service-run", + configuration=ServiceConfiguration( + port=8080, + commands=["echo Hi!"], + ), + ) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + run_name="service-run", + run_spec=run_spec, + status=RunStatus.RUNNING, + deployment_num=1, + ) + job = await create_job( + session=session, + run=run, + status=JobStatus.RUNNING, + deployment_num=0, + ) + lock_run(run) + await session.commit() + + await worker.process(run_to_pipeline_item(run)) + + await session.refresh(run) + assert run.status == RunStatus.RUNNING + + await session.refresh(job) + assert job.deployment_num == 1 + + async def test_service_rolling_deployment_scale_up( + self, test_db, session: AsyncSession, worker: RunWorker + ) -> None: + """Service with 1 out-of-date RUNNING replica whose spec differs from the new + deployment, desired=1 → creates 1 new replica (surge), old registered replica + untouched.""" + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run_spec = get_run_spec( + repo_id=repo.name, + run_name="service-run", + configuration=ServiceConfiguration( + port=8080, + commands=["echo new!"], + ), + ) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + run_name="service-run", + run_spec=run_spec, + status=RunStatus.RUNNING, + deployment_num=1, + ) + old_job = await create_job( + session=session, + run=run, + status=JobStatus.RUNNING, + deployment_num=0, + registered=True, + replica_num=0, + ) + # Make the old job's spec differ from the current run_spec so in-place bump + # cannot be applied and rolling deployment is triggered instead. 
+ old_spec = get_job_spec(old_job) + old_spec.commands = ["echo old!"] + old_job.job_spec_data = old_spec.json() + await session.commit() + + lock_run(run) + await session.commit() + + await worker.process(run_to_pipeline_item(run)) + + await session.refresh(run) + assert run.status == RunStatus.RUNNING + + res = await session.execute( + select(JobModel).where(JobModel.run_id == run.id).order_by(JobModel.replica_num) + ) + jobs = list(res.scalars().all()) + assert len(jobs) == 2 + # Old replica still RUNNING (registered, not terminated during rolling) + assert jobs[0].status == JobStatus.RUNNING + assert jobs[0].deployment_num == 0 + # New surge replica created + assert jobs[1].status == JobStatus.SUBMITTED + assert jobs[1].deployment_num == 1 + + async def test_service_rolling_deployment_scale_down_old_unregistered( + self, test_db, session: AsyncSession, worker: RunWorker + ) -> None: + """Service with 1 up-to-date RUNNING+registered and 1 out-of-date RUNNING+unregistered + replica (with a different spec) → old unregistered replica terminated.""" + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run_spec = get_run_spec( + repo_id=repo.name, + run_name="service-run", + configuration=ServiceConfiguration( + port=8080, + commands=["echo new!"], + ), + ) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + run_name="service-run", + run_spec=run_spec, + status=RunStatus.RUNNING, + deployment_num=1, + ) + # Up-to-date registered replica + await create_job( + session=session, + run=run, + status=JobStatus.RUNNING, + deployment_num=1, + registered=True, + replica_num=0, + ) + # Out-of-date unregistered replica with different spec + old_job = await create_job( + session=session, + run=run, + status=JobStatus.RUNNING, + deployment_num=0, + registered=False, + replica_num=1, + ) + old_spec = get_job_spec(old_job) + 
old_spec.commands = ["echo old!"] + old_job.job_spec_data = old_spec.json() + await session.commit() + + lock_run(run) + await session.commit() + + await worker.process(run_to_pipeline_item(run)) + + await session.refresh(run) + assert run.status == RunStatus.RUNNING + + await session.refresh(old_job) + assert old_job.status == JobStatus.TERMINATING + assert old_job.termination_reason == JobTerminationReason.SCALED_DOWN + + async def test_service_removed_group_cleanup( + self, test_db, session: AsyncSession, worker: RunWorker + ) -> None: + """Service run with jobs belonging to group "old" not in current config → + those jobs get TERMINATING with SCALED_DOWN.""" + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + # Current config only has group "0" (default) + run_spec = get_run_spec( + repo_id=repo.name, + run_name="service-run", + configuration=ServiceConfiguration( + port=8080, + commands=["echo Hi!"], + ), + ) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + run_name="service-run", + run_spec=run_spec, + status=RunStatus.RUNNING, + ) + # Active replica in current group "0" + await create_job( + session=session, + run=run, + status=JobStatus.RUNNING, + replica_num=0, + ) + # Replica belonging to a removed group "old" — manually set job_spec_data + old_group_job = await create_job( + session=session, + run=run, + status=JobStatus.RUNNING, + replica_num=1, + ) + # Patch the job spec to have replica_group="old" + old_spec = get_job_spec(old_group_job) + old_spec.replica_group = "old" + old_group_job.job_spec_data = old_spec.json() + await session.commit() + + lock_run(run) + await session.commit() + + await worker.process(run_to_pipeline_item(run)) + + await session.refresh(run) + assert run.status == RunStatus.RUNNING + + await session.refresh(old_group_job) + assert old_group_job.status == JobStatus.TERMINATING 
+ assert old_group_job.termination_reason == JobTerminationReason.SCALED_DOWN diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_runs/test_pending.py b/src/tests/_internal/server/background/pipeline_tasks/test_runs/test_pending.py new file mode 100644 index 000000000..6cfe5fe00 --- /dev/null +++ b/src/tests/_internal/server/background/pipeline_tasks/test_runs/test_pending.py @@ -0,0 +1,291 @@ +import json +import uuid +from datetime import timedelta +from unittest.mock import AsyncMock, patch + +import pytest +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from dstack._internal.core.models.configurations import ScalingSpec, ServiceConfiguration +from dstack._internal.core.models.resources import Range +from dstack._internal.core.models.runs import ( + JobStatus, + RunStatus, +) +from dstack._internal.server.background.pipeline_tasks.runs import RunWorker +from dstack._internal.server.models import JobModel +from dstack._internal.server.testing.common import ( + create_job, + create_project, + create_repo, + create_run, + create_user, + get_run_spec, +) +from dstack._internal.utils.common import get_current_datetime +from tests._internal.server.background.pipeline_tasks.test_runs.helpers import ( + lock_run, + run_to_pipeline_item, +) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) +@pytest.mark.usefixtures("image_config_mock") +class TestRunPendingWorker: + async def test_submits_non_service_run_and_creates_job( + self, test_db, session: AsyncSession, worker: RunWorker + ) -> None: + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + status=RunStatus.PENDING, + resubmission_attempt=0, + next_triggered_at=None, + ) + lock_run(run) + await session.commit() + + 
await worker.process(run_to_pipeline_item(run)) + + await session.refresh(run) + assert run.status == RunStatus.SUBMITTED + assert run.desired_replica_count == 1 + assert run.lock_token is None + assert run.lock_expires_at is None + assert run.lock_owner is None + + res = await session.execute(select(JobModel).where(JobModel.run_id == run.id)) + jobs = list(res.scalars().all()) + assert len(jobs) == 1 + assert jobs[0].status == JobStatus.SUBMITTED + assert jobs[0].replica_num == 0 + assert jobs[0].submission_num == 0 + + async def test_skips_retrying_run_when_delay_not_met( + self, test_db, session: AsyncSession, worker: RunWorker + ) -> None: + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + status=RunStatus.PENDING, + resubmission_attempt=1, + ) + # Create a job with recent last_processed_at so retry delay is not met + await create_job( + session=session, + run=run, + status=JobStatus.FAILED, + last_processed_at=get_current_datetime(), + ) + lock_run(run) + await session.commit() + + await worker.process(run_to_pipeline_item(run)) + + await session.refresh(run) + assert run.status == RunStatus.PENDING + assert run.lock_token is None + assert run.lock_expires_at is None + assert run.lock_owner is None + + async def test_resubmits_retrying_run_after_delay( + self, test_db, session: AsyncSession, worker: RunWorker + ) -> None: + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + status=RunStatus.PENDING, + resubmission_attempt=1, + ) + # Create a job with old last_processed_at so retry delay is met (>15s for attempt 1) + old_time = get_current_datetime() - 
timedelta(minutes=1) + old_job = await create_job( + session=session, + run=run, + status=JobStatus.FAILED, + last_processed_at=old_time, + ) + lock_run(run) + await session.commit() + + await worker.process(run_to_pipeline_item(run)) + + await session.refresh(run) + assert run.status == RunStatus.SUBMITTED + assert run.desired_replica_count == 1 + assert run.lock_token is None + assert run.lock_expires_at is None + assert run.lock_owner is None + + # Should have created a new job (retry of the failed one) + res = await session.execute( + select(JobModel) + .where(JobModel.run_id == run.id) + .order_by(JobModel.submitted_at.desc()) + ) + jobs = list(res.scalars().all()) + assert len(jobs) == 2 + new_job = next(j for j in jobs if j.id != old_job.id) + assert new_job.status == JobStatus.SUBMITTED + assert new_job.replica_num == old_job.replica_num + assert new_job.submission_num == old_job.submission_num + 1 + + async def test_noops_when_run_lock_changes_after_processing( + self, test_db, session: AsyncSession, worker: RunWorker + ) -> None: + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + status=RunStatus.PENDING, + resubmission_attempt=0, + next_triggered_at=None, + ) + lock_run(run) + await session.commit() + item = run_to_pipeline_item(run) + new_lock_token = uuid.uuid4() + + from dstack._internal.server.background.pipeline_tasks.runs.pending import ( + PendingResult, + PendingRunUpdateMap, + ) + + async def intercept_process(context): + # Change the lock token to simulate concurrent modification + run.lock_token = new_lock_token + run.lock_expires_at = get_current_datetime() + timedelta(minutes=1) + await session.commit() + # Return a result that would normally cause a state change + return PendingResult( + run_update_map=PendingRunUpdateMap( + 
status=RunStatus.SUBMITTED, + desired_replica_count=1, + ), + new_job_models=[], + ) + + with patch( + "dstack._internal.server.background.pipeline_tasks.runs.pending.process_pending_run", + new=AsyncMock(side_effect=intercept_process), + ): + await worker.process(item) + + await session.refresh(run) + assert run.status == RunStatus.PENDING + assert run.lock_token == new_lock_token + + async def test_submits_service_run_and_creates_jobs( + self, test_db, session: AsyncSession, worker: RunWorker + ) -> None: + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run_spec = get_run_spec( + repo_id=repo.name, + run_name="service-run", + configuration=ServiceConfiguration( + port=8080, + commands=["echo Hi!"], + replicas=Range[int](min=2, max=2), + ), + ) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + run_name="service-run", + run_spec=run_spec, + status=RunStatus.PENDING, + resubmission_attempt=0, + next_triggered_at=None, + ) + lock_run(run) + await session.commit() + + await worker.process(run_to_pipeline_item(run)) + + await session.refresh(run) + assert run.status == RunStatus.SUBMITTED + assert run.desired_replica_count == 2 + assert run.desired_replica_counts is not None + counts = json.loads(run.desired_replica_counts) + assert counts == {"0": 2} + assert run.lock_token is None + assert run.lock_expires_at is None + assert run.lock_owner is None + + res = await session.execute(select(JobModel).where(JobModel.run_id == run.id)) + jobs = list(res.scalars().all()) + assert len(jobs) == 2 + replica_nums = sorted(j.replica_num for j in jobs) + assert replica_nums == [0, 1] + assert all(j.status == JobStatus.SUBMITTED for j in jobs) + assert all(j.submission_num == 0 for j in jobs) + + async def test_noops_for_zero_scaled_service( + self, test_db, session: AsyncSession, worker: RunWorker + ) -> None: + project = 
await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run_spec = get_run_spec( + repo_id=repo.name, + run_name="service-run", + configuration=ServiceConfiguration( + port=8080, + commands=["echo Hi!"], + replicas=Range[int](min=0, max=2), + scaling=ScalingSpec(metric="rps", target=10), + ), + ) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + run_name="service-run", + run_spec=run_spec, + status=RunStatus.PENDING, + resubmission_attempt=0, + next_triggered_at=None, + ) + # Set desired_replica_count=0 and desired_replica_counts to match zero-scaled state. + run.desired_replica_count = 0 + run.desired_replica_counts = json.dumps({"0": 0}) + lock_run(run) + await session.commit() + + await worker.process(run_to_pipeline_item(run)) + + await session.refresh(run) + assert run.status == RunStatus.PENDING + assert run.lock_token is None + assert run.lock_expires_at is None + assert run.lock_owner is None + + res = await session.execute(select(JobModel).where(JobModel.run_id == run.id)) + jobs = list(res.scalars().all()) + assert len(jobs) == 0 diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_runs/test_pipeline.py b/src/tests/_internal/server/background/pipeline_tasks/test_runs/test_pipeline.py new file mode 100644 index 000000000..51a27fcc1 --- /dev/null +++ b/src/tests/_internal/server/background/pipeline_tasks/test_runs/test_pipeline.py @@ -0,0 +1,250 @@ +import datetime as dt +import uuid + +import pytest +from sqlalchemy.ext.asyncio import AsyncSession + +from dstack._internal.core.models.runs import RunStatus +from dstack._internal.server.background.pipeline_tasks.runs import ( + RunFetcher, + RunPipeline, +) +from dstack._internal.server.testing.common import ( + create_project, + create_repo, + create_run, + create_user, +) +from dstack._internal.utils.common import get_current_datetime + + 
+@pytest.mark.asyncio +@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) +class TestRunFetcher: + async def test_fetch_selects_eligible_runs_and_sets_lock_fields( + self, test_db, session: AsyncSession, fetcher: RunFetcher + ): + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + now = get_current_datetime() + stale = now - dt.timedelta(minutes=1) + + submitted = await create_run( + session=session, + project=project, + repo=repo, + user=user, + run_name="submitted", + status=RunStatus.SUBMITTED, + submitted_at=stale - dt.timedelta(seconds=5), + ) + running = await create_run( + session=session, + project=project, + repo=repo, + user=user, + run_name="running", + status=RunStatus.RUNNING, + submitted_at=stale - dt.timedelta(seconds=4), + ) + pending_retry = await create_run( + session=session, + project=project, + repo=repo, + user=user, + run_name="pending-retry", + status=RunStatus.PENDING, + submitted_at=stale - dt.timedelta(seconds=3), + resubmission_attempt=1, + ) + pending_scheduled_ready = await create_run( + session=session, + project=project, + repo=repo, + user=user, + run_name="pending-scheduled-ready", + status=RunStatus.PENDING, + submitted_at=stale - dt.timedelta(seconds=2), + next_triggered_at=stale, + ) + pending_zero_scaled = await create_run( + session=session, + project=project, + repo=repo, + user=user, + run_name="pending-zero-scaled", + status=RunStatus.PENDING, + submitted_at=stale - dt.timedelta(seconds=1), + ) + future_scheduled = await create_run( + session=session, + project=project, + repo=repo, + user=user, + run_name="future-scheduled", + status=RunStatus.PENDING, + submitted_at=stale, + next_triggered_at=now + dt.timedelta(minutes=1), + ) + finished = await create_run( + session=session, + project=project, + repo=repo, + user=user, + run_name="finished", + status=RunStatus.DONE, + submitted_at=stale + 
dt.timedelta(seconds=1), + ) + recent = await create_run( + session=session, + project=project, + repo=repo, + user=user, + run_name="recent", + status=RunStatus.RUNNING, + submitted_at=now, + last_processed_at=now + dt.timedelta(seconds=10), + ) + + items = await fetcher.fetch(limit=10) + + assert {item.id for item in items} == { + submitted.id, + running.id, + pending_retry.id, + pending_scheduled_ready.id, + pending_zero_scaled.id, + } + assert {item.id: item.status for item in items} == { + submitted.id: RunStatus.SUBMITTED, + running.id: RunStatus.RUNNING, + pending_retry.id: RunStatus.PENDING, + pending_scheduled_ready.id: RunStatus.PENDING, + pending_zero_scaled.id: RunStatus.PENDING, + } + + for run in [ + submitted, + running, + pending_retry, + pending_scheduled_ready, + pending_zero_scaled, + future_scheduled, + finished, + recent, + ]: + await session.refresh(run) + + fetched_runs = [ + submitted, + running, + pending_retry, + pending_scheduled_ready, + pending_zero_scaled, + ] + assert all(run.lock_owner == RunPipeline.__name__ for run in fetched_runs) + assert all(run.lock_expires_at is not None for run in fetched_runs) + assert all(run.lock_token is not None for run in fetched_runs) + assert len({run.lock_token for run in fetched_runs}) == 1 + + assert future_scheduled.lock_owner is None + assert finished.lock_owner is None + assert recent.lock_owner is None + + async def test_fetch_respects_order_and_limit( + self, test_db, session: AsyncSession, fetcher: RunFetcher + ): + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + now = get_current_datetime() + + oldest = await create_run( + session=session, + project=project, + repo=repo, + user=user, + run_name="oldest", + status=RunStatus.SUBMITTED, + submitted_at=now - dt.timedelta(minutes=3), + ) + middle = await create_run( + session=session, + project=project, + repo=repo, + user=user, + 
run_name="middle", + status=RunStatus.RUNNING, + submitted_at=now - dt.timedelta(minutes=2), + ) + newest = await create_run( + session=session, + project=project, + repo=repo, + user=user, + run_name="newest", + status=RunStatus.SUBMITTED, + submitted_at=now - dt.timedelta(minutes=1), + ) + + items = await fetcher.fetch(limit=2) + + assert [item.id for item in items] == [oldest.id, middle.id] + + await session.refresh(oldest) + await session.refresh(middle) + await session.refresh(newest) + + assert oldest.lock_owner == RunPipeline.__name__ + assert middle.lock_owner == RunPipeline.__name__ + assert newest.lock_owner is None + + async def test_fetch_retries_expired_same_owner_lock_and_skips_foreign_live_lock( + self, test_db, session: AsyncSession, fetcher: RunFetcher + ): + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + now = get_current_datetime() + stale = now - dt.timedelta(minutes=1) + + expired_same_owner = await create_run( + session=session, + project=project, + repo=repo, + user=user, + run_name="expired-same-owner", + status=RunStatus.RUNNING, + submitted_at=stale, + ) + expired_same_owner.lock_expires_at = stale + expired_same_owner.lock_token = uuid.uuid4() + expired_same_owner.lock_owner = RunPipeline.__name__ + + foreign_locked = await create_run( + session=session, + project=project, + repo=repo, + user=user, + run_name="foreign-locked", + status=RunStatus.SUBMITTED, + submitted_at=stale + dt.timedelta(seconds=1), + ) + foreign_locked.lock_expires_at = now + dt.timedelta(minutes=1) + foreign_locked.lock_token = uuid.uuid4() + foreign_locked.lock_owner = "OtherPipeline" + await session.commit() + + items = await fetcher.fetch(limit=10) + + assert [item.id for item in items] == [expired_same_owner.id] + assert items[0].prev_lock_expired is True + + await session.refresh(expired_same_owner) + await session.refresh(foreign_locked) + + assert 
expired_same_owner.lock_owner == RunPipeline.__name__ + assert expired_same_owner.lock_expires_at is not None + assert expired_same_owner.lock_token is not None + assert foreign_locked.lock_owner == "OtherPipeline" diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_runs/test_termination.py b/src/tests/_internal/server/background/pipeline_tasks/test_runs/test_termination.py new file mode 100644 index 000000000..d056df8a8 --- /dev/null +++ b/src/tests/_internal/server/background/pipeline_tasks/test_runs/test_termination.py @@ -0,0 +1,415 @@ +import uuid +from datetime import datetime, timedelta, timezone +from typing import Optional +from unittest.mock import AsyncMock, patch + +import pytest +from freezegun import freeze_time +from sqlalchemy.ext.asyncio import AsyncSession + +from dstack._internal.core.models.configurations import TaskConfiguration +from dstack._internal.core.models.instances import InstanceStatus +from dstack._internal.core.models.profiles import Schedule +from dstack._internal.core.models.runs import ( + JobStatus, + JobTerminationReason, + RunStatus, + RunTerminationReason, +) +from dstack._internal.server.background.pipeline_tasks.jobs_terminating import ( + JobTerminatingPipeline, +) +from dstack._internal.server.background.pipeline_tasks.runs import RunPipeline, RunWorker +from dstack._internal.server.testing.common import ( + create_fleet, + create_instance, + create_job, + create_project, + create_repo, + create_run, + create_user, + get_job_provisioning_data, + get_run_spec, +) +from dstack._internal.utils.common import get_current_datetime +from tests._internal.server.background.pipeline_tasks.test_runs.helpers import ( + lock_run, + run_to_pipeline_item, +) + + +def _lock_job( + job_model, + *, + lock_owner: str = RunPipeline.__name__, + lock_expires_at: Optional[datetime] = None, +) -> None: + if lock_expires_at is None: + lock_expires_at = get_current_datetime() + timedelta(seconds=30) + job_model.lock_token = 
uuid.uuid4() + job_model.lock_expires_at = lock_expires_at + job_model.lock_owner = lock_owner + + +@pytest.mark.asyncio +@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) +@pytest.mark.usefixtures("image_config_mock") +class TestRunTerminatingWorker: + async def test_transitions_running_jobs_to_terminating( + self, test_db, session: AsyncSession, worker: RunWorker + ) -> None: + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + status=RunStatus.TERMINATING, + termination_reason=RunTerminationReason.JOB_FAILED, + ) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.BUSY, + ) + job = await create_job( + session=session, + run=run, + status=JobStatus.RUNNING, + job_provisioning_data=get_job_provisioning_data(), + instance=instance, + instance_assigned=True, + ) + lock_run(run) + await session.commit() + item = run_to_pipeline_item(run) + observed_job_lock = {} + + async def record_stop_call(**kwargs) -> None: + observed_job_lock["lock_token"] = kwargs["job_model"].lock_token + observed_job_lock["lock_owner"] = kwargs["job_model"].lock_owner + + with patch( + "dstack._internal.server.background.pipeline_tasks.runs.terminating.stop_runner", + new=AsyncMock(side_effect=record_stop_call), + ) as stop_runner: + await worker.process(item) + + assert stop_runner.await_count == 1 + stop_call = stop_runner.await_args + assert stop_call is not None + assert stop_call.kwargs["job_model"].id == job.id + assert observed_job_lock["lock_token"] == item.lock_token + assert observed_job_lock["lock_owner"] == RunPipeline.__name__ + assert stop_call.kwargs["instance_model"].id == instance.id + + await session.refresh(job) + await session.refresh(run) + assert job.status == JobStatus.TERMINATING + assert 
job.termination_reason == JobTerminationReason.TERMINATED_BY_SERVER + assert job.remove_at is not None + assert job.lock_token is None + assert job.lock_expires_at is None + assert job.lock_owner is None + assert run.status == RunStatus.TERMINATING + assert run.lock_token is None + assert run.lock_expires_at is None + assert run.lock_owner is None + + async def test_updates_delayed_and_regular_jobs_separately( + self, test_db, session: AsyncSession, worker: RunWorker + ) -> None: + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + status=RunStatus.TERMINATING, + termination_reason=RunTerminationReason.JOB_FAILED, + ) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.BUSY, + ) + delayed_job = await create_job( + session=session, + run=run, + status=JobStatus.RUNNING, + job_provisioning_data=get_job_provisioning_data(), + instance=instance, + instance_assigned=True, + ) + regular_job = await create_job( + session=session, + run=run, + status=JobStatus.SUBMITTED, + job_num=1, + ) + lock_run(run) + await session.commit() + + with patch( + "dstack._internal.server.background.pipeline_tasks.runs.terminating.stop_runner", + new=AsyncMock(), + ): + await worker.process(run_to_pipeline_item(run)) + + await session.refresh(delayed_job) + await session.refresh(regular_job) + assert delayed_job.status == JobStatus.TERMINATING + assert delayed_job.termination_reason == JobTerminationReason.TERMINATED_BY_SERVER + assert delayed_job.remove_at is not None + assert regular_job.status == JobStatus.TERMINATING + assert regular_job.termination_reason == JobTerminationReason.TERMINATED_BY_SERVER + assert regular_job.remove_at is None + + async def test_finishes_non_scheduled_run_when_all_jobs_are_finished( + self, test_db, session: 
AsyncSession, worker: RunWorker + ) -> None: + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + status=RunStatus.TERMINATING, + termination_reason=RunTerminationReason.JOB_FAILED, + ) + await create_job( + session=session, + run=run, + status=JobStatus.FAILED, + termination_reason=JobTerminationReason.EXECUTOR_ERROR, + ) + lock_run(run) + await session.commit() + + await worker.process(run_to_pipeline_item(run)) + + await session.refresh(run) + assert run.status == RunStatus.FAILED + assert run.lock_token is None + assert run.lock_expires_at is None + assert run.lock_owner is None + + @freeze_time(datetime(2023, 1, 2, 3, 10, tzinfo=timezone.utc)) + async def test_reschedules_scheduled_run_and_clears_fleet( + self, test_db, session: AsyncSession, worker: RunWorker + ) -> None: + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + fleet = await create_fleet(session=session, project=project) + run_spec = get_run_spec( + repo_id=repo.name, + run_name="scheduled-run", + configuration=TaskConfiguration( + nodes=1, + schedule=Schedule(cron="15 * * * *"), + commands=["echo Hi!"], + ), + ) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + fleet=fleet, + run_name="scheduled-run", + run_spec=run_spec, + status=RunStatus.TERMINATING, + termination_reason=RunTerminationReason.ALL_JOBS_DONE, + resubmission_attempt=1, + ) + lock_run(run) + await session.commit() + + await worker.process(run_to_pipeline_item(run)) + + await session.refresh(run) + assert run.status == RunStatus.PENDING + assert run.next_triggered_at == datetime(2023, 1, 2, 3, 15, tzinfo=timezone.utc) + assert run.fleet_id is None + assert run.lock_token is None + 
assert run.lock_expires_at is None + assert run.lock_owner is None + + async def test_noops_when_run_lock_changes_after_processing( + self, test_db, session: AsyncSession, worker: RunWorker + ) -> None: + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + status=RunStatus.TERMINATING, + termination_reason=RunTerminationReason.JOB_FAILED, + ) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.BUSY, + ) + job = await create_job( + session=session, + run=run, + status=JobStatus.RUNNING, + job_provisioning_data=get_job_provisioning_data(), + instance=instance, + instance_assigned=True, + ) + lock_run(run) + await session.commit() + item = run_to_pipeline_item(run) + new_lock_token = uuid.uuid4() + + async def change_run_lock(**kwargs) -> None: + run.lock_token = new_lock_token + run.lock_expires_at = get_current_datetime() + timedelta(minutes=1) + await session.commit() + + with patch( + "dstack._internal.server.background.pipeline_tasks.runs.terminating.stop_runner", + new=AsyncMock(side_effect=change_run_lock), + ): + await worker.process(item) + + await session.refresh(run) + await session.refresh(job) + assert run.status == RunStatus.TERMINATING + assert run.lock_token == new_lock_token + assert job.status == JobStatus.RUNNING + assert job.lock_token is None + assert job.lock_expires_at is None + assert job.lock_owner is None + + async def test_resets_run_lock_when_related_job_is_locked_by_another_pipeline( + self, test_db, session: AsyncSession, worker: RunWorker + ) -> None: + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, 
+ status=RunStatus.TERMINATING, + termination_reason=RunTerminationReason.JOB_FAILED, + ) + job = await create_job( + session=session, + run=run, + status=JobStatus.SUBMITTED, + ) + _lock_job(job, lock_owner=JobTerminatingPipeline.__name__) + lock_run(run) + await session.commit() + + await worker.process(run_to_pipeline_item(run)) + + await session.refresh(run) + await session.refresh(job) + assert run.status == RunStatus.TERMINATING + assert run.lock_owner == RunPipeline.__name__ + assert run.lock_token is None + assert run.lock_expires_at is None + assert job.status == JobStatus.SUBMITTED + assert job.lock_owner == JobTerminatingPipeline.__name__ + + async def test_reclaims_expired_same_owner_related_job_lock( + self, test_db, session: AsyncSession, worker: RunWorker + ) -> None: + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + status=RunStatus.TERMINATING, + termination_reason=RunTerminationReason.JOB_FAILED, + ) + job = await create_job( + session=session, + run=run, + status=JobStatus.SUBMITTED, + ) + _lock_job( + job, + lock_owner=RunPipeline.__name__, + lock_expires_at=get_current_datetime() - timedelta(minutes=1), + ) + lock_run(run) + await session.commit() + + await worker.process(run_to_pipeline_item(run)) + + await session.refresh(run) + await session.refresh(job) + assert job.status == JobStatus.TERMINATING + assert job.termination_reason == JobTerminationReason.TERMINATED_BY_SERVER + assert job.lock_token is None + assert job.lock_expires_at is None + assert job.lock_owner is None + assert run.lock_token is None + assert run.lock_expires_at is None + assert run.lock_owner is None + + async def test_ignores_already_terminating_jobs_when_locking_related_jobs( + self, test_db, session: AsyncSession, worker: RunWorker + ) -> None: + project = await 
create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + status=RunStatus.TERMINATING, + termination_reason=RunTerminationReason.JOB_FAILED, + ) + terminating_job = await create_job( + session=session, + run=run, + status=JobStatus.TERMINATING, + termination_reason=JobTerminationReason.TERMINATED_BY_SERVER, + ) + submitted_job = await create_job( + session=session, + run=run, + status=JobStatus.SUBMITTED, + job_num=1, + ) + _lock_job(terminating_job, lock_owner=JobTerminatingPipeline.__name__) + lock_run(run) + await session.commit() + + await worker.process(run_to_pipeline_item(run)) + + await session.refresh(run) + await session.refresh(terminating_job) + await session.refresh(submitted_job) + assert terminating_job.status == JobStatus.TERMINATING + assert terminating_job.lock_owner == JobTerminatingPipeline.__name__ + assert submitted_job.status == JobStatus.TERMINATING + assert submitted_job.termination_reason == JobTerminationReason.TERMINATED_BY_SERVER + assert submitted_job.lock_token is None + assert submitted_job.lock_expires_at is None + assert submitted_job.lock_owner is None + assert run.lock_token is None + assert run.lock_expires_at is None + assert run.lock_owner is None