diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py index 75586d09a3e..d0964f48937 100644 --- a/fastdeploy/engine/common_engine.py +++ b/fastdeploy/engine/common_engine.py @@ -142,6 +142,7 @@ def __init__(self, cfg: FDConfig, start_queue=True, use_async_llm=False): self.is_paused = False # pause request generation self._pause_cond = threading.Condition() + self._rejecting_new_requests = False # blocks new requests during abort drain self._ctrl_output_queues = {} self._ctrl_response_mailboxes = collections.defaultdict(collections.OrderedDict) @@ -1325,7 +1326,7 @@ def _insert_zmq_task_to_scheduler(self): trace_print(LoggingEventName.REQUEST_QUEUE_START, data["request_id"], data.get("user", "")) self.llm_logger.debug(f"Receive request from api server: {request}") - if self.is_paused: + if self.is_paused or self._rejecting_new_requests: self.llm_logger.warning(f"Engine is paused, drop request: {request}") self._send_error_response( request.request_id, @@ -1445,39 +1446,20 @@ def _control_pause(self, control_request: ControlRequest): if self.is_paused: self.llm_logger.info("Engine is already paused, no need to pause again.") return - self.is_paused = True + self._rejecting_new_requests = True + self.resource_manager.log_status() - self.llm_logger.info("Abort running requests.") + # Scheduling loop picks them up via _trigger_abort when they enter resource_manager + all_req_ids = list(set(self.resource_manager.requests.keys()) | set(self.scheduler.requests.keys())) + self.llm_logger.info(f"Pause: aborting {len(all_req_ids)} total requests.") + if all_req_ids: + self.resource_manager.add_abort_req_ids(all_req_ids) + self._wait_inflight_drained() - self.resource_manager.log_status() - # preempted all running reqs. preempted reqs will be append to ResourceManager.waiting queue - timeout, count = 60, 0 - while self.engine_worker_queue.exist_tasks(): - time.sleep(0.001) - count += 1 - if count >= timeout * 1000: - break - if count >= timeout * 1000: - error_msg = f"Emptying engine worker queue timed out after {timeout} seconds, worker may hanged!" - self.llm_logger.error(error_msg) - raise Exception(error_msg) - running_reqs = self.resource_manager.preempted_all() - if len(running_reqs) > 0: - self.llm_logger.info(f"Total {len(running_reqs)} requests need to be aborted.") - self.resource_manager.get_real_bsz() - self.engine_worker_queue.put_tasks((running_reqs, self.resource_manager.real_bsz)) - self.resource_manager.wait_worker_inflight_requests_finish(timeout=60) - # self.engine_worker_queue.clear_data() - self.token_processor.clear_data() + with self._pause_cond: + self.is_paused = True self.resource_manager.log_status() - # abort inflight requests to user - inflight_requests = self.scheduler.get_inflight_requests() - self.llm_logger.info(f"Abort inflight requests (total {len(inflight_requests)}).") - for req in inflight_requests: - self._send_error_response(req.request_id, "Request is aborted since engine is paused.") - self.scheduler.reset() - if envs.ENABLE_V1_KVCACHE_MANAGER: self.resource_manager.cache_manager.reset_cache() else: @@ -1500,6 +1482,21 @@ def _control_pause(self, control_request: ControlRequest): self.llm_logger.info("Successfully paused request generation.") return None + def _wait_inflight_drained(self): + """ + Wait until resource_manager.requests is completely empty. + No timeout — abort pipeline will complete. Aligned with SGLang's poll-until-drained. + """ + start_time = time.time() + while ( + self.resource_manager.requests + or self.scheduler.requests + or self.resource_manager.waiting_abort_req_id_set + or self.resource_manager.to_be_aborted_req_id_set + ): + time.sleep(0.005) + self.llm_logger.info(f"All inflight requests drained, take time: {time.time() - start_time:.3f} seconds") + def _control_resume(self, control_request: ControlRequest) -> Optional[dict]: """Control function for resuming request generation. @@ -1515,6 +1512,7 @@ def _control_resume(self, control_request: ControlRequest) -> Optional[dict]: self.llm_logger.info("Engine is not paused, no need to resume.") return None self.is_paused = False + self._rejecting_new_requests = False self._pause_cond.notify_all() # resume cache transfer diff --git a/tests/engine/test_common_engine.py b/tests/engine/test_common_engine.py index ac30f26d9ab..1490ed6cfc9 100644 --- a/tests/engine/test_common_engine.py +++ b/tests/engine/test_common_engine.py @@ -1137,22 +1137,29 @@ def test_control_pause_and_resume_paths(self): eng = self._make_mixed_engine() eng.is_paused = False eng._pause_cond = threading.Condition() - eng.engine_worker_queue = Mock(exist_tasks=Mock(return_value=False), put_tasks=Mock()) + eng.engine_worker_queue = Mock(exist_tasks=Mock(return_value=False)) eng.resource_manager = Mock( - preempted_all=Mock(return_value=[Request(request_id="r1", prompt_token_ids=[1], prompt_token_ids_len=1)]), - get_real_bsz=Mock(), - wait_worker_inflight_requests_finish=Mock(), + requests={"r1": Mock(output_token_ids=[1, 2, 3])}, + waiting_abort_req_id_set=set(), + to_be_aborted_req_id_set=set(), + add_abort_req_ids=Mock(), log_status=Mock(), cache_manager=Mock(reset=Mock()), - real_bsz=1, ) eng.token_processor = Mock(clear_data=Mock()) - eng.scheduler = Mock(get_inflight_requests=Mock(return_value=[]), reset=Mock()) + mock_scheduler = Mock(reset=Mock()) + mock_scheduler.requests = {} + mock_scheduler.mutex = threading.Lock() + mock_scheduler.responses = {} + mock_scheduler.batch_responses_per_step = [] + eng.scheduler = mock_scheduler eng._send_error_response = Mock() + eng._wait_inflight_drained = Mock() with patch("fastdeploy.engine.common_engine.envs.ENABLE_V1_KVCACHE_SCHEDULER", True): eng._control_pause(ControlRequest(request_id="ctrl1", method="pause")) self.assertTrue(eng.is_paused) + eng.resource_manager.add_abort_req_ids.assert_called_once() eng._control_resume(ControlRequest(request_id="ctrl2", method="resume")) self.assertFalse(eng.is_paused)