[RL] pause: use abort pipeline with scheduling loop alive for gracefu… #7753
base: develop
Changes from all commits
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -44,11 +44,9 @@ | |
| from fastdeploy.config import FDConfig | ||
| from fastdeploy.engine.register_manager import RegisterManager | ||
| from fastdeploy.engine.request import ( | ||
| CompletionOutput, | ||
| ControlRequest, | ||
| ControlResponse, | ||
| Request, | ||
| RequestMetrics, | ||
| RequestOutput, | ||
| RequestStatus, | ||
| RequestType, | ||
|
|
@@ -142,6 +140,7 @@ def __init__(self, cfg: FDConfig, start_queue=True, use_async_llm=False): | |
|
|
||
| self.is_paused = False # pause request generation | ||
| self._pause_cond = threading.Condition() | ||
| self._rejecting_new_requests = False # blocks new requests during abort drain | ||
|
|
||
| self._ctrl_output_queues = {} | ||
| self._ctrl_response_mailboxes = collections.defaultdict(collections.OrderedDict) | ||
|
|
@@ -1305,10 +1304,26 @@ def _insert_zmq_task_to_scheduler(self): | |
| self.request_worker_map[req_id_for_map] = worker_pid | ||
| status_value = data.get("status", None) | ||
| if status_value is not None and status_value == RequestStatus.ABORT.value: | ||
| req_id = data["request_id"] | ||
| self.llm_logger.info(f"Receive abort request, req_id: {req_id}") | ||
| if envs.ENABLE_V1_KVCACHE_SCHEDULER: | ||
| self.resource_manager.add_abort_req_ids(req_id) | ||
| if not envs.ENABLE_V1_KVCACHE_SCHEDULER: | ||
| self.llm_logger.info("abort requests only supported in ENABLE_V1_KVCACHE_SCHEDULER") | ||
| else: | ||
| abort_all = data.get("abort_all", False) | ||
| req_ids = data.get("req_ids", []) | ||
| if abort_all or req_ids: | ||
| target_req_ids = self._resolve_abort_targets(abort_all, req_ids) | ||
| self.llm_logger.info( | ||
| f"Receive abort_reqs, abort_all={abort_all}, " | ||
| f"input={len(req_ids)}, resolved={len(target_req_ids)}" | ||
| ) | ||
| self.resource_manager.add_abort_req_ids(target_req_ids) | ||
| else: | ||
| req_id = data.get("request_id", None) | ||
| if not req_id: | ||
| self.llm_logger.warning( | ||
| "Receive abort request without request_id, skip invalid abort message" | ||
| ) | ||
| self.llm_logger.info(f"Receive abort request, req_id: {req_id}") | ||
|
🔴 Bug

Suggested fix:

```python
req_id = data.get("request_id", None)
if not req_id:
    self.llm_logger.warning(
        "Receive abort request without request_id, skip invalid abort message"
    )
    continue  # actually skip, consistent with the warning's semantics
self.llm_logger.info(f"Receive abort request, req_id: {req_id}")
self.resource_manager.add_abort_req_ids(req_id)
```
||
| self.resource_manager.add_abort_req_ids(req_id) | ||
| continue | ||
| err_msg = None | ||
| try: | ||
|
|
@@ -1325,7 +1340,7 @@ def _insert_zmq_task_to_scheduler(self): | |
| trace_print(LoggingEventName.REQUEST_QUEUE_START, data["request_id"], data.get("user", "")) | ||
| self.llm_logger.debug(f"Receive request from api server: {request}") | ||
|
|
||
| if self.is_paused: | ||
| if self.is_paused or self._rejecting_new_requests: | ||
| self.llm_logger.warning(f"Engine is paused, drop request: {request}") | ||
| self._send_error_response( | ||
| request.request_id, | ||
|
|
@@ -1445,39 +1460,19 @@ def _control_pause(self, control_request: ControlRequest): | |
| if self.is_paused: | ||
| self.llm_logger.info("Engine is already paused, no need to pause again.") | ||
| return | ||
| self.is_paused = True | ||
|
|
||
| self.llm_logger.info("Abort running requests.") | ||
|
|
||
| self.resource_manager.log_status() | ||
| # preempted all running reqs. preempted reqs will be append to ResourceManager.waiting queue | ||
| timeout, count = 60, 0 | ||
| while self.engine_worker_queue.exist_tasks(): | ||
| time.sleep(0.001) | ||
| count += 1 | ||
| if count >= timeout * 1000: | ||
| break | ||
| if count >= timeout * 1000: | ||
| error_msg = f"Emptying engine worker queue timed out after {timeout} seconds, worker may hanged!" | ||
| self.llm_logger.error(error_msg) | ||
| raise Exception(error_msg) | ||
| running_reqs = self.resource_manager.preempted_all() | ||
| if len(running_reqs) > 0: | ||
| self.llm_logger.info(f"Total {len(running_reqs)} requests need to be aborted.") | ||
| self.resource_manager.get_real_bsz() | ||
| self.engine_worker_queue.put_tasks((running_reqs, self.resource_manager.real_bsz)) | ||
| self.resource_manager.wait_worker_inflight_requests_finish(timeout=60) | ||
| # self.engine_worker_queue.clear_data() | ||
| self.token_processor.clear_data() | ||
| self._rejecting_new_requests = True | ||
| self.resource_manager.log_status() | ||
|
|
||
| # abort inflight requests to user | ||
| inflight_requests = self.scheduler.get_inflight_requests() | ||
| self.llm_logger.info(f"Abort inflight requests (total {len(inflight_requests)}).") | ||
| for req in inflight_requests: | ||
| self._send_error_response(req.request_id, "Request is aborted since engine is paused.") | ||
| self.scheduler.reset() | ||
| all_req_ids = list(set(self.resource_manager.requests.keys()) | set(self.scheduler.requests.keys())) | ||
| self.llm_logger.info(f"Pause: aborting {len(all_req_ids)} total requests.") | ||
| if all_req_ids: | ||
| self.resource_manager.add_abort_req_ids(all_req_ids) | ||
| self._wait_inflight_drained() | ||
|
|
||
| with self._pause_cond: | ||
| self.is_paused = True | ||
|
|
||
| self.resource_manager.log_status() | ||
| if envs.ENABLE_V1_KVCACHE_MANAGER: | ||
| self.resource_manager.cache_manager.reset_cache() | ||
| else: | ||
|
|
@@ -1500,6 +1495,16 @@ def _control_pause(self, control_request: ControlRequest): | |
| self.llm_logger.info("Successfully paused request generation.") | ||
| return None | ||
|
|
||
| def _wait_inflight_drained(self): | ||
| """ | ||
| Wait until resource_manager.requests is completely empty. | ||
| No timeout — abort pipeline will complete. Aligned with SGLang's poll-until-drained. | ||
||
| """ | ||
| start_time = time.time() | ||
| while self.resource_manager.requests or self.scheduler.requests: | ||
🟡 Suggestion

The current implementation relies on the assumption that the abort pipeline always completes, but there are scenarios where this can block forever.

Suggest keeping a timeout, as in the original implementation:

```python
timeout = 60
start_time = time.time()
while self.resource_manager.requests or self.scheduler.requests:
    time.sleep(0.005)
    if time.time() - start_time > timeout:
        self.llm_logger.error(
            f"_wait_inflight_drained timeout after {timeout}s, forcing pause"
        )
        break
```
||
| time.sleep(0.005) | ||
| self.llm_logger.info(f"All inflight requests drained, takes {time.time() - start_time:.1f} seconds.") | ||
|
|
||
| def _control_resume(self, control_request: ControlRequest) -> Optional[dict]: | ||
| """Control function for resuming request generation. | ||
|
|
||
|
|
@@ -1514,6 +1519,7 @@ def _control_resume(self, control_request: ControlRequest) -> Optional[dict]: | |
| if not self.is_paused: | ||
| self.llm_logger.info("Engine is not paused, no need to resume.") | ||
| return None | ||
| self._rejecting_new_requests = False | ||
||
| self.is_paused = False | ||
| self._pause_cond.notify_all() | ||
|
|
||
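The `is_paused` flag and `_pause_cond` condition variable added above suggest the usual pause/resume pattern: the pause and resume handlers flip the flag under the condition, and a loop that should stay alive while paused waits on it. The scheduling-loop side is not part of these hunks, so the `step()` body below is only a sketch of that pattern, not the actual FastDeploy code.

```python
import threading


class PausableLoop:
    """Minimal sketch of the is_paused / _pause_cond pattern (illustrative only)."""

    def __init__(self):
        self.is_paused = False
        self._pause_cond = threading.Condition()

    def step(self):
        # The loop stays alive; it only blocks while paused.
        with self._pause_cond:
            while self.is_paused:
                self._pause_cond.wait()
        # ... schedule one iteration of work here ...

    def pause(self):
        with self._pause_cond:
            self.is_paused = True

    def resume(self):
        with self._pause_cond:
            self.is_paused = False
            self._pause_cond.notify_all()
```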
|
|
@@ -1597,150 +1603,6 @@ def _control_update_weights(self, control_request: ControlRequest) -> Optional[d | |
|
|
||
| return responses | ||
|
|
||
| def _control_abort_requests(self, control_req: ControlRequest): | ||
| if not envs.ENABLE_V1_KVCACHE_SCHEDULER: | ||
| raise Exception("abort_requests only supported in ENABLE_V1_KVCACHE_SCHEDULER") | ||
| args = control_req.get_args() | ||
| abort_all = args.get("abort_all", False) | ||
| req_ids = args.get("req_ids", []) | ||
| matched_input_ids = set() | ||
| now_reqs = list(set(self.resource_manager.requests.keys()) | set(self.scheduler.requests.keys())) | ||
|
|
||
| # Step 1: Determine target request list | ||
| if abort_all: | ||
| # all requests in running + waiting | ||
| target_req_ids = now_reqs | ||
| else: | ||
| # filter out requests that actually exist | ||
| target_req_ids = [] | ||
| for rid in req_ids: | ||
| if rid in now_reqs: | ||
| target_req_ids.append(rid) | ||
| matched_input_ids.add(rid) | ||
| elif f"{rid}_0" in now_reqs: | ||
| target_req_ids.append(f"{rid}_0") | ||
| matched_input_ids.add(rid) | ||
|
|
||
| if not target_req_ids: | ||
| return {"aborted": [], "not_found": req_ids if not abort_all else []} | ||
|
|
||
| # Step 2: Collect partial results | ||
| aborted_info = [] | ||
| results = [] | ||
| for req_id in target_req_ids: | ||
| request = self.resource_manager.requests.get(req_id) | ||
| if request is None: | ||
| scheduled_req = self.scheduler.requests.get(req_id) | ||
| if scheduled_req is None: | ||
| continue | ||
| request = scheduled_req.raw | ||
|
|
||
| partial_token_ids = list(request.output_token_ids) | ||
|
|
||
| # Construct finished response with partial results | ||
| now = time.time() | ||
| abort_metrics = RequestMetrics( | ||
| arrival_time=request.metrics.arrival_time if request.metrics else now, | ||
| inference_start_time=request.metrics.inference_start_time if request.metrics else now, | ||
| engine_recv_latest_token_time=now, | ||
| engine_recv_first_token_time=request.metrics.engine_recv_first_token_time if request.metrics else now, | ||
| request_start_time=request.metrics.arrival_time if request.metrics else now, | ||
| ) | ||
| eos_token_ids = getattr(request, "eos_token_ids", [0]) | ||
| result = RequestOutput( | ||
| request_id=req_id, | ||
| finished=True, | ||
| outputs=CompletionOutput( | ||
| index=0, | ||
| send_idx=len(partial_token_ids), | ||
| token_ids=[eos_token_ids[0]], | ||
| ), | ||
| metrics=abort_metrics, | ||
| error_code=200, | ||
| error_msg="Aborted", | ||
| ) | ||
| results.append(result) | ||
| aborted_info.append( | ||
| { | ||
| "request_id": req_id, | ||
| "output_token_count": len(partial_token_ids), | ||
| } | ||
| ) | ||
|
|
||
| # Step 3: Execute abort — add all requests to waiting_abort_req_id_set | ||
| if envs.ENABLE_V1_KVCACHE_SCHEDULER: | ||
| for req_id in target_req_ids: | ||
| self.resource_manager.add_abort_req_ids(req_id) | ||
| time.sleep(0.0001) | ||
| if self.cfg.scheduler_config.splitwise_role != "prefill": | ||
| self._wait_abort_complete(target_req_ids) | ||
|
|
||
| # Add results to scheduler, engine will have a thread calling get_results, | ||
| # then cleanup and call send_response to send to client. | ||
| # When client disconnects, send_response will automatically ignore | ||
| if self.cfg.scheduler_config.splitwise_role != "prefill": | ||
| try: | ||
| # self.send_response_server.send_response(req_id, [result]) | ||
| self.scheduler.put_results(results) | ||
| except Exception: | ||
| pass # client may have disconnected | ||
|
|
||
| not_found = [rid for rid in req_ids if rid not in matched_input_ids] if not abort_all else [] | ||
|
|
||
| return {"aborted": aborted_info, "not_found": not_found} | ||
|
|
||
| def _wait_abort_complete(self, target_req_ids, stall_timeout=1): | ||
| """ | ||
| Wait for all abort requests to complete. | ||
| - Keep monitoring as long as remaining is not empty, which means cleanup is not done yet | ||
| - If no progress within stall_timeout seconds, force cleanup requests stuck in to_be_aborted_req_id_set, | ||
| reset progress state if any, then continue monitoring | ||
| """ | ||
| target_set = set(target_req_ids) | ||
| target_set = target_set & (set(self.resource_manager.requests.keys()) | set(self.scheduler.requests.keys())) | ||
| prev_remaining_count = len(target_set) | ||
| last_progress_time = time.time() | ||
| remaining = target_set & self.resource_manager.get_reqs_in_aborting() | ||
| while remaining: | ||
| alive_reqs = set(self.resource_manager.requests.keys()) | set(self.scheduler.requests.keys()) | ||
| finished_reqs = target_set - alive_reqs | ||
| if finished_reqs: | ||
| self.llm_logger.info(f"abort targets already finished, skip: {finished_reqs}") | ||
| for req_id in finished_reqs: | ||
| self.resource_manager.waiting_abort_req_id_set.discard(req_id) | ||
| self.resource_manager.to_be_aborted_req_id_set.discard(req_id) | ||
| target_set -= finished_reqs | ||
| remaining = target_set & self.resource_manager.get_reqs_in_aborting() | ||
| if not remaining: | ||
| self.llm_logger.info(f"all {len(target_set)} abort reqs cleaned") | ||
| return | ||
| self.llm_logger.debug(f"remaining:{remaining}") | ||
|
|
||
| current_count = len(remaining) | ||
| if current_count < prev_remaining_count: | ||
| # progress made: recycle_abort_task was called | ||
| self.llm_logger.info(f"abort progress: {prev_remaining_count} -> {current_count}") | ||
| last_progress_time = time.time() | ||
| prev_remaining_count = current_count | ||
|
|
||
| if time.time() - last_progress_time > stall_timeout: | ||
| # no progress timeout: only cleanup requests stuck in to_be_aborted (worker hasn't returned -9) | ||
| stuck = remaining & self.resource_manager.to_be_aborted_req_id_set | ||
| if stuck: | ||
| self.llm_logger.warning( | ||
| f"no abort progress for {stall_timeout}s, " | ||
| f"force cleanup {len(stuck)} stuck requests (in to_be_aborted)" | ||
| ) | ||
| for req_id in list(stuck): | ||
| self.llm_logger.warning(f"force cleanup stuck req_id:{req_id}") | ||
| self.resource_manager.recycle_abort_task(req_id) | ||
| # reset progress state | ||
| last_progress_time = time.time() | ||
| prev_remaining_count = current_count - len(stuck) | ||
| # else: remaining are all in waiting_abort_req_id_set, waiting for natural flow | ||
|
|
||
| time.sleep(0.005) | ||
|
|
||
| def _parse_tags(self, control_request: ControlRequest): | ||
| """ | ||
| Parse tags from control request. | ||
|
|
@@ -2766,3 +2628,21 @@ def detect_thread(): | |
| except Exception: | ||
| pass | ||
| return True | ||
|
|
||
| def _resolve_abort_targets(self, abort_all, req_ids): | ||
| """ | ||
| Resolve abort target request IDs. | ||
| """ | ||
| now_reqs = set(self.resource_manager.requests.keys()) | set(self.scheduler.requests.keys()) | ||
| self.llm_logger.debug(f"now_reqs: {now_reqs}") | ||
|
|
||
| if abort_all: | ||
| return list(now_reqs) | ||
|
|
||
| target_req_ids = [] | ||
| for rid in req_ids: | ||
| if rid in now_reqs: | ||
| target_req_ids.append(rid) | ||
| elif f"{rid}_0" in now_reqs: | ||
| target_req_ids.append(f"{rid}_0") | ||
| return target_req_ids | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -496,13 +496,8 @@ async def abort_requests(request: Request): | |
| if not abort_all and not req_ids: | ||
| return JSONResponse(status_code=400, content={"error": "must provide abort_all=true or req_ids"}) | ||
|
|
||
| control_request = ControlRequest( | ||
| request_id=f"control-{uuid.uuid4()}", | ||
| method="abort_requests", | ||
| args={"abort_all": abort_all, "req_ids": req_ids or []}, | ||
| ) | ||
| control_response = await app.state.engine_client.run_control_method(control_request) | ||
| return control_response.to_api_json_response() | ||
| await app.state.engine_client.abort_reqs(req_ids=req_ids or [], abort_all=abort_all) | ||
| return Response(status_code=200) | ||
|
❓ Question

Please confirm:
|
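For reference, the scheduler-side hunk above only checks three fields on an abort message: `status`, `abort_all`, and `req_ids`. Below is a minimal sketch of such a payload; how `engine_client.abort_reqs` actually serializes and transports it is not shown in this diff, so the helper name and wrapper are assumptions.

```python
# Hypothetical helper: builds an abort payload with the fields the
# scheduler loop in this PR reads (status, abort_all, req_ids).
# The real transport used by engine_client.abort_reqs() is not shown
# in this diff; this is only an illustration.
from fastdeploy.engine.request import RequestStatus


def build_abort_message(req_ids=None, abort_all=False):
    return {
        "status": RequestStatus.ABORT.value,  # routes the message to the abort branch
        "abort_all": abort_all,               # abort everything currently tracked
        "req_ids": req_ids or [],             # otherwise, abort only these ids
    }
```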
||
|
|
||
|
|
||
| def wrap_streaming_generator(original_generator: AsyncGenerator): | ||
|
|
||