-
Notifications
You must be signed in to change notification settings - Fork 742
[PD] prepare request in prefill instance by multi threads #7724
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Jiang-Jia-Jun
merged 1 commit into
PaddlePaddle:release/2.6
from
juncaipeng:refine-pd-fetch-req
May 13, 2026
+374
−414
Merged
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -30,7 +30,6 @@ | |
| import time | ||
| import traceback | ||
| import weakref | ||
| from concurrent.futures import ThreadPoolExecutor | ||
| from pathlib import Path | ||
| from typing import Dict, List, Optional, Tuple | ||
|
|
||
|
|
@@ -42,6 +41,7 @@ | |
| import fastdeploy.metrics.trace as tracing | ||
| from fastdeploy.cache_manager.cache_data import CacheStatus | ||
| from fastdeploy.config import FDConfig | ||
| from fastdeploy.engine.common_engine_prepare_mixin import EngineServicePrepareMixin | ||
| from fastdeploy.engine.register_manager import RegisterManager | ||
| from fastdeploy.engine.request import ( | ||
| CompletionOutput, | ||
|
|
@@ -115,7 +115,7 @@ def _format_worker_launch_failure_message(log_dir: str) -> str: | |
| return message | ||
|
|
||
|
|
||
| class EngineService: | ||
| class EngineService(EngineServicePrepareMixin): | ||
| """ | ||
| Base class containing common engine functionality | ||
| """ | ||
|
|
@@ -251,12 +251,13 @@ def start(self, async_llm_pid=None): | |
| self.start_worker_service(async_llm_pid) | ||
|
|
||
| if envs.ENABLE_V1_KVCACHE_SCHEDULER: | ||
| self.insert_task_to_worker_thread = threading.Thread( | ||
| target=self._schedule_request_to_worker_v1, daemon=True | ||
| ) | ||
| self.prepare_request_thread = threading.Thread(target=self._prepare_request_v1, daemon=True) | ||
| self.prepare_request_thread.start() | ||
| self.schedule_request_thread = threading.Thread(target=self._schedule_request_to_worker_v1, daemon=True) | ||
| self.schedule_request_thread.start() | ||
| else: | ||
| self.insert_task_to_worker_thread = threading.Thread(target=self._schedule_request_to_worker, daemon=True) | ||
| self.insert_task_to_worker_thread.start() | ||
| self.schedule_request_thread = threading.Thread(target=self._schedule_request_to_worker, daemon=True) | ||
| self.schedule_request_thread.start() | ||
| self.token_processor.tasks_queue = self.engine_worker_queue | ||
| self.token_processor.run() | ||
| if self.cfg.scheduler_config.splitwise_role == "decode": | ||
|
|
@@ -879,215 +880,19 @@ def _schedule_request_to_worker_v1(self): | |
| Insert tasks to worker with scheduler v1 (ENABLE_V1_KVCACHE_SCHEDULER=1). | ||
| """ | ||
| tracing.trace_set_thread_info("Scheduler Task to Work") | ||
| get_request_pool = ThreadPoolExecutor(max_workers=1) | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 将准备请求相关函数移到单独一个文件中 |
||
| is_fetching = False | ||
|
|
||
| def _fetch_request(): | ||
| try: | ||
| with self._pause_cond: | ||
| self._pause_cond.wait_for(lambda: not self.is_paused) | ||
| nonlocal is_fetching | ||
| num_prefill_batch = min( | ||
| int(self.resource_manager.available_batch()), | ||
| self.cfg.max_prefill_batch, | ||
| ) | ||
|
|
||
| if self.cfg.scheduler_config.splitwise_role != "mixed": | ||
| max_num_batched_tokens = self.cfg.scheduler_config.max_num_batched_tokens | ||
| else: | ||
| max_num_batched_tokens = self.cfg.model_config.max_model_len | ||
|
|
||
| available_blocks = self.cfg.cache_config.max_block_num_per_seq | ||
| tasks = self.scheduler.get_requests( | ||
| available_blocks=available_blocks, | ||
| block_size=self.cfg.cache_config.block_size, | ||
| reserved_output_blocks=0, # self.cfg.cache_config.enc_dec_block_num | ||
| max_num_batched_tokens=max_num_batched_tokens, | ||
| batch=num_prefill_batch, | ||
| ) | ||
| for task in tasks: | ||
| task.metrics.engine_get_req_time = time.time() | ||
| trace_print(LoggingEventName.REQUEST_QUEUE_END, task.request_id, getattr(task, "user", "")) | ||
|
|
||
| if self.cfg.scheduler_config.splitwise_role == "decode": | ||
| # TODO: refine scheduler to remove this limitation | ||
| # Decode will process and schedule the request sent by prefill to engine, | ||
| # so the same request sent by the decode api server will be ignored | ||
| is_fetching = False | ||
| return | ||
|
|
||
| if tasks: | ||
| self.llm_logger.debug( | ||
| f"Engine has fetched tasks from {self.scheduler.__class__.__name__}: {[task.request_id for task in tasks]}" | ||
| ) | ||
|
|
||
| if self.cfg.scheduler_config.splitwise_role == "prefill": | ||
| for task in tasks: | ||
| # start async preprocess | ||
| self.resource_manager.apply_async_preprocess(task) | ||
| need_delete_tasks = [] | ||
| if envs.PREFILL_CONTINUOUS_REQUEST_DECODE_RESOURCES: | ||
| for task in tasks: | ||
| # assure can allocate block ids in P | ||
| while not self.resource_manager.preallocate_resource_in_p(task): | ||
| time.sleep(0.005) | ||
| self.llm_logger.debug( | ||
| f"P has allocated resources and then ask D resource for request: {task.request_id}" | ||
| ) | ||
| trace_print( | ||
| LoggingEventName.ASK_DECODE_RESOURCE_START, task.request_id, getattr(task, "user", "") | ||
| ) | ||
| task.metrics.ask_decode_resource_start_time = time.time() | ||
| while True: | ||
| self.split_connector.send_splitwise_tasks([task], task.idx) | ||
| status, msg = self.split_connector.check_decode_allocated(task) | ||
| if not status: | ||
| self.llm_logger.warning( | ||
| f"D failed to allocate resource for request {task.request_id}, try again." | ||
| ) | ||
| time.sleep(0.05) | ||
| else: | ||
| task.metrics.ask_decode_resource_finish_time = time.time() | ||
| trace_print( | ||
| LoggingEventName.ASK_DECODE_RESOURCE_END, | ||
| task.request_id, | ||
| getattr(task, "user", ""), | ||
| ) | ||
| break | ||
| self.llm_logger.debug(f"D has allocated resource for request: {task.request_id}") | ||
| else: | ||
| for task in tasks: | ||
| # assure can allocate block ids in P | ||
| while not self.resource_manager.preallocate_resource_in_p(task): | ||
| time.sleep(0.005) | ||
|
|
||
| self.llm_logger.debug( | ||
| f"P has allocated resources and then ask D resource for req_id: {task.request_id}" | ||
| ) | ||
| trace_print( | ||
| LoggingEventName.ASK_DECODE_RESOURCE_START, task.request_id, getattr(task, "user", "") | ||
| ) | ||
| task.metrics.ask_decode_resource_start_time = time.time() | ||
| self.split_connector.send_splitwise_tasks([task], task.idx) | ||
|
|
||
| for task in tasks: | ||
| # assure fetch block ids from D | ||
| status, msg = self.split_connector.check_decode_allocated(task) | ||
| task.metrics.ask_decode_resource_finish_time = time.time() | ||
| trace_print( | ||
| LoggingEventName.ASK_DECODE_RESOURCE_END, task.request_id, getattr(task, "user", "") | ||
| ) | ||
| if not status: | ||
| error_msg = ( | ||
| f"PD Error: prefill failed to apply for resource from decode, " | ||
| f"req: {task.request_id}, msg:{msg}." | ||
| ) | ||
| self.llm_logger.error(error_msg) | ||
| self.scheduler.put_results( | ||
| [ | ||
| RequestOutput( | ||
| request_id=task.request_id, | ||
| finished=True, | ||
| error_code=500, | ||
| error_msg=error_msg, | ||
| ) | ||
| ] | ||
| ) | ||
| main_process_metrics.reschedule_req_num.inc() | ||
| need_delete_tasks.append(task) | ||
| continue | ||
| for tmp_task in need_delete_tasks: | ||
| tasks.remove(tmp_task) | ||
| # release resource in P | ||
| self.resource_manager.pre_recycle_resource(tmp_task.request_id) | ||
|
|
||
| # to send cache info to cache messager | ||
| if tasks: | ||
| need_check_req_ids = [task.request_id for task in tasks] | ||
| self.split_connector.send_cache_info_to_messager(tasks, 0) | ||
| # ensure cache tasks has sent to cache_messager | ||
| need_check_req_ids = [task.request_id for task in tasks] | ||
| finished_ids, delete_tasks_list = [], [] | ||
| while need_check_req_ids: | ||
| finished_ids.extend(self.engine_worker_queue.get_finished_add_cache_task_req()) | ||
| self.llm_logger.debug( | ||
| f"P has successfully sent cache infos to cache messager for requests: {finished_ids}" | ||
| ) | ||
| if finished_ids: | ||
| for task in tasks: | ||
| result = self.resource_manager.waiting_async_process(task) | ||
| if result is None: | ||
| self.scheduler.put_results( | ||
| [ | ||
| RequestOutput( | ||
| request_id=task.request_id, | ||
| finished=True, | ||
| error_code=task.error_code, | ||
| error_msg=task.error_message, | ||
| ) | ||
| ] | ||
| ) | ||
| need_check_req_ids.remove(task.request_id) | ||
| delete_tasks_list.append(task) | ||
| elif result is False: | ||
| if task.request_id in finished_ids: | ||
| need_check_req_ids.remove(task.request_id) | ||
| finished_ids.remove(task.request_id) | ||
| else: | ||
| time.sleep(0.001) | ||
|
|
||
| for tmp_task in delete_tasks_list: | ||
| tasks.remove(tmp_task) | ||
| # release resource in P | ||
| self.resource_manager.pre_recycle_resource(tmp_task.request_id) | ||
|
|
||
| # Fetch requests and add them to the scheduling queue | ||
| if tasks: | ||
| for task in tasks: | ||
| task.metrics.add_req_to_resource_manager_time = time.time() | ||
| trace_print( | ||
| LoggingEventName.RESOURCE_ALLOCATE_START, task.request_id, getattr(task, "user", "") | ||
| ) | ||
| if self.cfg.scheduler_config.splitwise_role == "prefill": | ||
| self.resource_manager.add_request_in_p(tasks) | ||
| self.llm_logger.info( | ||
| f"P add requests into running queue: {[task.request_id for task in tasks]}" | ||
| ) | ||
| else: | ||
| for task in tasks: | ||
| self.resource_manager.add_request(task) | ||
| is_fetching = False | ||
| except Exception as e: | ||
| self.llm_logger.error(f"fetching request error {e} {str(traceback.format_exc())}") | ||
| is_fetching = False | ||
|
|
||
| while self.running: | ||
| with self._pause_cond: | ||
| self._pause_cond.wait_for(lambda: not self.is_paused) | ||
|
|
||
| try: | ||
| if self.engine_worker_queue.exist_tasks(): | ||
| time.sleep(0.001) | ||
| continue | ||
| if self.cfg.scheduler_config.splitwise_role != "mixed": | ||
| if not is_fetching: | ||
| is_fetching = True | ||
| get_request_pool.submit(_fetch_request) | ||
|
|
||
| else: | ||
| if len(self.resource_manager.waiting) == 0 and (not is_fetching): | ||
| # Check if the thread pool is still available to avoid submitting tasks to a shutdown thread pool. | ||
| try: | ||
| is_fetching = True | ||
| get_request_pool.submit(_fetch_request) | ||
| except RuntimeError as e: | ||
| if "shutdown" in str(e): | ||
| self.llm_logger.info("Thread pool shutdown detected, exiting scheduler loop") | ||
| break | ||
| else: | ||
| raise | ||
|
|
||
| if hasattr(self.resource_manager, "scheduler_unhandled_request_num"): | ||
| self.resource_manager.scheduler_unhandled_request_num = self._get_scheduler_unhandled_request_num() | ||
|
|
||
| # 2. Schedule requests | ||
| tasks, error_tasks = self.resource_manager.schedule() | ||
|
|
||
|
|
||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
如果引擎收到请求并推理了、但是cache messager还没收到请求,这里就等待收到请求,避免错误。如果万一收不到请求就hang住,避免出现传输cache错误。