From 53b9761b25cb7b4f874b472212140124aff2f50e Mon Sep 17 00:00:00 2001 From: juncaipeng <13006307475@163.com> Date: Tue, 12 May 2026 03:33:02 +0000 Subject: [PATCH] prepare request in prefill instance by multi threads --- fastdeploy/cache_manager/cache_messager.py | 38 ++- fastdeploy/engine/common_engine.py | 215 +------------ .../engine/common_engine_prepare_mixin.py | 282 ++++++++++++++++++ fastdeploy/envs.py | 2 + .../inter_communicator/engine_worker_queue.py | 106 ------- tests/cache_manager/test_cache_messager.py | 24 +- tests/engine/test_common_engine.py | 95 ++---- tests/inter_communicator/test_e2w_queue.py | 26 +- 8 files changed, 374 insertions(+), 414 deletions(-) create mode 100644 fastdeploy/engine/common_engine_prepare_mixin.py diff --git a/fastdeploy/cache_manager/cache_messager.py b/fastdeploy/cache_manager/cache_messager.py index b934c3e74c7..08c8dea003a 100644 --- a/fastdeploy/cache_manager/cache_messager.py +++ b/fastdeploy/cache_manager/cache_messager.py @@ -613,12 +613,16 @@ def __init__( ) self.gpu_id = gpu_id - self.cache_info = dict() + self.cache_info = dict() # {'request_id': cache_info_dict} self.rank_id = self.rank + local_data_parallel_id * self.nranks self.engine_cache_task_thread_lock = threading.Lock() - self.engine_cache_tasks = [dict() for _ in range(512)] - self.idx_cache_task_dict = {} - self.cache_prefilled_engine_ids_queue = queue.Queue() # keep batch slot index for each prefill step + self.engine_cache_tasks = [ + dict() for _ in range(512) + ] # {'layer_id': {'prefilled_layer_idx': xx, 'prefilled_block_num': xx}} + self.idx_cache_task_dict = {} # {'slot_idx': cache_info_dict} + self.cache_prefilled_engine_ids_queue = ( + queue.Queue() + ) # [(slot_idx1, prefilled_token_num1), (slot_idx2, prefilled_token_num2)] if splitwise_role == "prefill": consume_signals_thread = threading.Thread(target=self.consume_signals) consume_signals_thread.daemon = True @@ -638,7 +642,6 @@ def _add_cache_task_thread(self): while True: try: cache_info = self.engine_worker_queue.get_cache_info() - finished_add_cache_task_req_ids = [] if cache_info: logger.debug(f"Get cache info from engine worker queue, {cache_info}") self.engine_worker_queue.cache_info_barrier.wait() @@ -647,7 +650,6 @@ def _add_cache_task_thread(self): self.cache_info[info["request_id"]].update(info) current_info = self.cache_info[info["request_id"]] assert "dest_block_ids" in current_info and "src_block_ids" in current_info - finished_add_cache_task_req_ids.append(info["request_id"]) decode_cached_block_num = len(current_info["src_block_ids"]) - len( current_info["dest_block_ids"] ) @@ -659,17 +661,13 @@ def _add_cache_task_thread(self): current_info["sended_layer_id"] = -1 current_info["sended_block_num"] = current_info["decode_cached_tokens"] // self.block_size current_info["status"] = "init" - logger.info(f"Get cache info from D: finish add cache task: {current_info}") + logger.info(f"Get cache info and finish add cache task: {current_info}") self.cache_info[info["request_id"]] = current_info self.idx_cache_task_dict[current_info["current_id"]] = current_info else: - logger.info(f"Get cache info from P: {info}") + logger.info(f"Get cache info: {info}") self.cache_info[info["request_id"]] = info - if finished_add_cache_task_req_ids: - logger.info(f"Put processed tasks into engine worker queue: {finished_add_cache_task_req_ids}") - self.engine_worker_queue.put_finished_add_cache_task_req(finished_add_cache_task_req_ids) - self.engine_worker_queue.finish_add_cache_task_barrier.wait() else: 
time.sleep(0.001) except Exception as e: @@ -687,10 +685,12 @@ def prefill_layerwise_send_cache_thread(self): block_start_end_list = [] current_prefilled_token_num_list = [] for engine_index, current_step_prefilled_token_num in batch_engine_signals: + self._maybe_wait_for_cache_task(engine_index) assert ( engine_index in self.idx_cache_task_dict ), f"engine_index {engine_index} not in self.idx_cache_task_dict {self.idx_cache_task_dict}" block_id_start = self.idx_cache_task_dict[engine_index]["sended_block_num"] + prefilled_token_num = current_step_prefilled_token_num if ( prefilled_token_num == self.idx_cache_task_dict[engine_index]["need_prefill_tokens"] @@ -917,6 +917,20 @@ def _handle_connect_task(self): except Exception as e: logger.error(f"handle_connect_task has exception: {e}, {traceback.format_exc()}") + def _maybe_wait_for_cache_task(self, engine_index): + # If cache messager does not get cache task from engine, just hang here for now + wait_step = 1 + sleep_seconds = 0.005 + + while engine_index not in self.idx_cache_task_dict: + time.sleep(sleep_seconds) + wait_step += 1 + + if wait_step % 400 == 0: + logger.warning( + f"waiting cache task for engine_index: {engine_index}, cost_time: {wait_step * 0.005:.2f} s" + ) + def main(): device = args.device_id diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py index 50017baf5de..9390ac8260a 100644 --- a/fastdeploy/engine/common_engine.py +++ b/fastdeploy/engine/common_engine.py @@ -30,7 +30,6 @@ import time import traceback import weakref -from concurrent.futures import ThreadPoolExecutor from pathlib import Path from typing import Dict, List, Optional, Tuple @@ -42,6 +41,7 @@ import fastdeploy.metrics.trace as tracing from fastdeploy.cache_manager.cache_data import CacheStatus from fastdeploy.config import FDConfig +from fastdeploy.engine.common_engine_prepare_mixin import EngineServicePrepareMixin from fastdeploy.engine.register_manager import RegisterManager from fastdeploy.engine.request import ( CompletionOutput, @@ -115,7 +115,7 @@ def _format_worker_launch_failure_message(log_dir: str) -> str: return message -class EngineService: +class EngineService(EngineServicePrepareMixin): """ Base class containing common engine functionality """ @@ -251,12 +251,13 @@ def start(self, async_llm_pid=None): self.start_worker_service(async_llm_pid) if envs.ENABLE_V1_KVCACHE_SCHEDULER: - self.insert_task_to_worker_thread = threading.Thread( - target=self._schedule_request_to_worker_v1, daemon=True - ) + self.prepare_request_thread = threading.Thread(target=self._prepare_request_v1, daemon=True) + self.prepare_request_thread.start() + self.schedule_request_thread = threading.Thread(target=self._schedule_request_to_worker_v1, daemon=True) + self.schedule_request_thread.start() else: - self.insert_task_to_worker_thread = threading.Thread(target=self._schedule_request_to_worker, daemon=True) - self.insert_task_to_worker_thread.start() + self.schedule_request_thread = threading.Thread(target=self._schedule_request_to_worker, daemon=True) + self.schedule_request_thread.start() self.token_processor.tasks_queue = self.engine_worker_queue self.token_processor.run() if self.cfg.scheduler_config.splitwise_role == "decode": @@ -879,215 +880,19 @@ def _schedule_request_to_worker_v1(self): Insert tasks to worker with scheduler v1 (ENABLE_V1_KVCACHE_SCHEDULER=1). 
""" tracing.trace_set_thread_info("Scheduler Task to Work") - get_request_pool = ThreadPoolExecutor(max_workers=1) - is_fetching = False - - def _fetch_request(): - try: - with self._pause_cond: - self._pause_cond.wait_for(lambda: not self.is_paused) - nonlocal is_fetching - num_prefill_batch = min( - int(self.resource_manager.available_batch()), - self.cfg.max_prefill_batch, - ) - - if self.cfg.scheduler_config.splitwise_role != "mixed": - max_num_batched_tokens = self.cfg.scheduler_config.max_num_batched_tokens - else: - max_num_batched_tokens = self.cfg.model_config.max_model_len - - available_blocks = self.cfg.cache_config.max_block_num_per_seq - tasks = self.scheduler.get_requests( - available_blocks=available_blocks, - block_size=self.cfg.cache_config.block_size, - reserved_output_blocks=0, # self.cfg.cache_config.enc_dec_block_num - max_num_batched_tokens=max_num_batched_tokens, - batch=num_prefill_batch, - ) - for task in tasks: - task.metrics.engine_get_req_time = time.time() - trace_print(LoggingEventName.REQUEST_QUEUE_END, task.request_id, getattr(task, "user", "")) - - if self.cfg.scheduler_config.splitwise_role == "decode": - # TODO: refine scheduler to remove this limitation - # Decode will process and schedule the request sent by prefill to engine, - # so the same request sent by the decode api server will be ignored - is_fetching = False - return - - if tasks: - self.llm_logger.debug( - f"Engine has fetched tasks from {self.scheduler.__class__.__name__}: {[task.request_id for task in tasks]}" - ) - - if self.cfg.scheduler_config.splitwise_role == "prefill": - for task in tasks: - # start async preprocess - self.resource_manager.apply_async_preprocess(task) - need_delete_tasks = [] - if envs.PREFILL_CONTINUOUS_REQUEST_DECODE_RESOURCES: - for task in tasks: - # assure can allocate block ids in P - while not self.resource_manager.preallocate_resource_in_p(task): - time.sleep(0.005) - self.llm_logger.debug( - f"P has allocated resources and then ask D resource for request: {task.request_id}" - ) - trace_print( - LoggingEventName.ASK_DECODE_RESOURCE_START, task.request_id, getattr(task, "user", "") - ) - task.metrics.ask_decode_resource_start_time = time.time() - while True: - self.split_connector.send_splitwise_tasks([task], task.idx) - status, msg = self.split_connector.check_decode_allocated(task) - if not status: - self.llm_logger.warning( - f"D failed to allocate resource for request {task.request_id}, try again." 
- ) - time.sleep(0.05) - else: - task.metrics.ask_decode_resource_finish_time = time.time() - trace_print( - LoggingEventName.ASK_DECODE_RESOURCE_END, - task.request_id, - getattr(task, "user", ""), - ) - break - self.llm_logger.debug(f"D has allocated resource for request: {task.request_id}") - else: - for task in tasks: - # assure can allocate block ids in P - while not self.resource_manager.preallocate_resource_in_p(task): - time.sleep(0.005) - - self.llm_logger.debug( - f"P has allocated resources and then ask D resource for req_id: {task.request_id}" - ) - trace_print( - LoggingEventName.ASK_DECODE_RESOURCE_START, task.request_id, getattr(task, "user", "") - ) - task.metrics.ask_decode_resource_start_time = time.time() - self.split_connector.send_splitwise_tasks([task], task.idx) - - for task in tasks: - # assure fetch block ids from D - status, msg = self.split_connector.check_decode_allocated(task) - task.metrics.ask_decode_resource_finish_time = time.time() - trace_print( - LoggingEventName.ASK_DECODE_RESOURCE_END, task.request_id, getattr(task, "user", "") - ) - if not status: - error_msg = ( - f"PD Error: prefill failed to apply for resource from decode, " - f"req: {task.request_id}, msg:{msg}." - ) - self.llm_logger.error(error_msg) - self.scheduler.put_results( - [ - RequestOutput( - request_id=task.request_id, - finished=True, - error_code=500, - error_msg=error_msg, - ) - ] - ) - main_process_metrics.reschedule_req_num.inc() - need_delete_tasks.append(task) - continue - for tmp_task in need_delete_tasks: - tasks.remove(tmp_task) - # release resource in P - self.resource_manager.pre_recycle_resource(tmp_task.request_id) - - # to send cache info to cache messager - if tasks: - need_check_req_ids = [task.request_id for task in tasks] - self.split_connector.send_cache_info_to_messager(tasks, 0) - # ensure cache tasks has sent to cache_messager - need_check_req_ids = [task.request_id for task in tasks] - finished_ids, delete_tasks_list = [], [] - while need_check_req_ids: - finished_ids.extend(self.engine_worker_queue.get_finished_add_cache_task_req()) - self.llm_logger.debug( - f"P has successfully sent cache infos to cache messager for requests: {finished_ids}" - ) - if finished_ids: - for task in tasks: - result = self.resource_manager.waiting_async_process(task) - if result is None: - self.scheduler.put_results( - [ - RequestOutput( - request_id=task.request_id, - finished=True, - error_code=task.error_code, - error_msg=task.error_message, - ) - ] - ) - need_check_req_ids.remove(task.request_id) - delete_tasks_list.append(task) - elif result is False: - if task.request_id in finished_ids: - need_check_req_ids.remove(task.request_id) - finished_ids.remove(task.request_id) - else: - time.sleep(0.001) - - for tmp_task in delete_tasks_list: - tasks.remove(tmp_task) - # release resource in P - self.resource_manager.pre_recycle_resource(tmp_task.request_id) - - # Fetch requests and add them to the scheduling queue - if tasks: - for task in tasks: - task.metrics.add_req_to_resource_manager_time = time.time() - trace_print( - LoggingEventName.RESOURCE_ALLOCATE_START, task.request_id, getattr(task, "user", "") - ) - if self.cfg.scheduler_config.splitwise_role == "prefill": - self.resource_manager.add_request_in_p(tasks) - self.llm_logger.info( - f"P add requests into running queue: {[task.request_id for task in tasks]}" - ) - else: - for task in tasks: - self.resource_manager.add_request(task) - is_fetching = False - except Exception as e: - self.llm_logger.error(f"fetching request 
error {e} {str(traceback.format_exc())}") - is_fetching = False while self.running: with self._pause_cond: self._pause_cond.wait_for(lambda: not self.is_paused) + try: if self.engine_worker_queue.exist_tasks(): time.sleep(0.001) continue - if self.cfg.scheduler_config.splitwise_role != "mixed": - if not is_fetching: - is_fetching = True - get_request_pool.submit(_fetch_request) - - else: - if len(self.resource_manager.waiting) == 0 and (not is_fetching): - # Check if the thread pool is still available to avoid submitting tasks to a shutdown thread pool. - try: - is_fetching = True - get_request_pool.submit(_fetch_request) - except RuntimeError as e: - if "shutdown" in str(e): - self.llm_logger.info("Thread pool shutdown detected, exiting scheduler loop") - break - else: - raise if hasattr(self.resource_manager, "scheduler_unhandled_request_num"): self.resource_manager.scheduler_unhandled_request_num = self._get_scheduler_unhandled_request_num() + # 2. Schedule requests tasks, error_tasks = self.resource_manager.schedule() diff --git a/fastdeploy/engine/common_engine_prepare_mixin.py b/fastdeploy/engine/common_engine_prepare_mixin.py new file mode 100644 index 00000000000..71327025458 --- /dev/null +++ b/fastdeploy/engine/common_engine_prepare_mixin.py @@ -0,0 +1,282 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +from __future__ import annotations + +import threading +import time +import traceback + +import fastdeploy.metrics.trace as tracing +from fastdeploy.engine.request import RequestOutput +from fastdeploy.metrics.metrics import main_process_metrics +from fastdeploy.trace.constants import LoggingEventName +from fastdeploy.trace.trace_logger import print as trace_print +from fastdeploy.utils import envs + + +class EngineServicePrepareMixin: + def _fetch_request_mixed(self) -> bool: + """Fetch and prepare requests for a mixed instance. 
Returns True if tasks were fetched.""" + # FIXME: to validate if it's necessary for avoiding error when enable mtp + if len(self.resource_manager.waiting) > 0: + return False + + num_prefill_batch = min( + int(self.resource_manager.available_batch()), + self.cfg.max_prefill_batch, + ) + max_num_batched_tokens = self.cfg.model_config.max_model_len + available_blocks = self.cfg.cache_config.max_block_num_per_seq + + tasks = self.scheduler.get_requests( + available_blocks=available_blocks, + block_size=self.cfg.cache_config.block_size, + reserved_output_blocks=0, + max_num_batched_tokens=max_num_batched_tokens, + batch=num_prefill_batch, + ) + if not tasks: + return False + + for task in tasks: + task.metrics.engine_get_req_time = time.time() + trace_print(LoggingEventName.REQUEST_QUEUE_END, task.request_id, getattr(task, "user", "")) + + self.llm_logger.debug( + f"Engine has fetched tasks from {self.scheduler.__class__.__name__}: {[task.request_id for task in tasks]}" + ) + + for task in tasks: + task.metrics.add_req_to_resource_manager_time = time.time() + trace_print(LoggingEventName.RESOURCE_ALLOCATE_START, task.request_id, getattr(task, "user", "")) + self.resource_manager.add_request(task) + + return True + + def _fetch_request_decode(self) -> bool: + """Consume scheduler queue for decode instance to prevent memory accumulation. + Returns True if tasks were consumed.""" + num_prefill_batch = min( + int(self.resource_manager.available_batch()), + self.cfg.max_prefill_batch, + ) + max_num_batched_tokens = self.cfg.scheduler_config.max_num_batched_tokens + available_blocks = self.cfg.cache_config.max_block_num_per_seq + + tasks = self.scheduler.get_requests( + available_blocks=available_blocks, + block_size=self.cfg.cache_config.block_size, + reserved_output_blocks=0, + max_num_batched_tokens=max_num_batched_tokens, + batch=num_prefill_batch, + ) + # Tasks are intentionally discarded - decode receives requests via _decode_process_splitwise_requests + return len(tasks) > 0 + + def _fetch_request_prefill(self) -> bool: + """Fetch and prepare requests for a prefill instance. 
Returns True if tasks were fetched.""" + num_prefill_batch = min( + int(self.resource_manager.available_batch()), + self.cfg.max_prefill_batch, + ) + max_num_batched_tokens = self.cfg.scheduler_config.max_num_batched_tokens + available_blocks = self.cfg.cache_config.max_block_num_per_seq + + tasks = self.scheduler.get_requests( + available_blocks=available_blocks, + block_size=self.cfg.cache_config.block_size, + reserved_output_blocks=0, + max_num_batched_tokens=max_num_batched_tokens, + batch=num_prefill_batch, + ) + if not tasks: + return False + + for task in tasks: + task.metrics.engine_get_req_time = time.time() + trace_print(LoggingEventName.REQUEST_QUEUE_END, task.request_id, getattr(task, "user", "")) + + self.llm_logger.debug( + f"Engine has fetched tasks from {self.scheduler.__class__.__name__}: {[task.request_id for task in tasks]}" + ) + + # Start async preprocess for all tasks in this batch + for task in tasks: + self.resource_manager.apply_async_preprocess(task) + + # P-side resource preallocation + D-side coordination + failed_tasks = [] + if envs.PREFILL_CONTINUOUS_REQUEST_DECODE_RESOURCES: + for task in tasks: + # assure can allocate block ids in P + while not self.resource_manager.preallocate_resource_in_p(task): + time.sleep(0.005) + self.llm_logger.debug( + f"P has allocated resources and then ask D resource for request: {task.request_id}" + ) + trace_print(LoggingEventName.ASK_DECODE_RESOURCE_START, task.request_id, getattr(task, "user", "")) + task.metrics.ask_decode_resource_start_time = time.time() + while True: + self.split_connector.send_splitwise_tasks([task], task.idx) + status, msg = self.split_connector.check_decode_allocated(task) + if status: + task.metrics.ask_decode_resource_finish_time = time.time() + trace_print( + LoggingEventName.ASK_DECODE_RESOURCE_END, + task.request_id, + getattr(task, "user", ""), + ) + break + else: + self.llm_logger.warning( + f"D failed to allocate resource for request {task.request_id}, try again." + ) + time.sleep(0.05) + + self.llm_logger.debug(f"D has allocated resource for request: {task.request_id}") + else: + for task in tasks: + # assure can allocate block ids in P + while not self.resource_manager.preallocate_resource_in_p(task): + time.sleep(0.005) + + self.llm_logger.debug( + f"P has allocated resources and then ask D resource for req_id: {task.request_id}" + ) + trace_print(LoggingEventName.ASK_DECODE_RESOURCE_START, task.request_id, getattr(task, "user", "")) + task.metrics.ask_decode_resource_start_time = time.time() + self.split_connector.send_splitwise_tasks([task], task.idx) + + for task in tasks: + # assure fetch block ids from D + status, msg = self.split_connector.check_decode_allocated(task) + task.metrics.ask_decode_resource_finish_time = time.time() + trace_print(LoggingEventName.ASK_DECODE_RESOURCE_END, task.request_id, getattr(task, "user", "")) + if not status: + error_msg = ( + f"PD Error: prefill failed to apply for resource from decode, " + f"req: {task.request_id}, msg:{msg}." 
+ ) + self.llm_logger.error(error_msg) + self.scheduler.put_results( + [ + RequestOutput( + request_id=task.request_id, + finished=True, + error_code=500, + error_msg=error_msg, + ) + ] + ) + main_process_metrics.reschedule_req_num.inc() + failed_tasks.append(task) + + for tmp_task in failed_tasks: + tasks.remove(tmp_task) + self.resource_manager.pre_recycle_resource(tmp_task.request_id) + + # Check and wait async preprocess + if tasks: + need_check_req_ids = [task.request_id for task in tasks] + failed_tasks = [] + + while need_check_req_ids: + still_in_progress = False + for task in tasks: + if task.request_id not in need_check_req_ids: + continue + + result = self.resource_manager.waiting_async_process(task) + if result is False: # async preprocess success + need_check_req_ids.remove(task.request_id) + elif result is True: + still_in_progress = True + elif result is None: # async preprocess failed + failed_tasks.append(task) + need_check_req_ids.remove(task.request_id) + self.scheduler.put_results( + [ + RequestOutput( + request_id=task.request_id, + finished=True, + error_code=task.error_code, + error_msg=task.error_message, + ) + ] + ) + + if still_in_progress: + time.sleep(0.005) + + for tmp_task in failed_tasks: + tasks.remove(tmp_task) + self.resource_manager.pre_recycle_resource(tmp_task.request_id) + + # Send cache info to messager + if tasks: + self.split_connector.send_cache_info_to_messager(tasks, 0) + + # Fetch requests and add them to the scheduling queue + if tasks: + for task in tasks: + task.metrics.add_req_to_resource_manager_time = time.time() + trace_print(LoggingEventName.RESOURCE_ALLOCATE_START, task.request_id, getattr(task, "user", "")) + self.resource_manager.add_request_in_p(tasks) + self.llm_logger.info(f"P add requests into running queue: {[task.request_id for task in tasks]}") + + return True + + def _fetch_loop(self, fetch_fn, thread_idx: int): + """Fetch loop run by each worker thread.""" + tracing.trace_set_thread_info(f"Prepare Request for Scheduling - thread {thread_idx}") + while self.running: + try: + with self._pause_cond: + self._pause_cond.wait_for(lambda: not self.is_paused) + fetch_fn() + time.sleep(0.002) + except Exception as e: + self.llm_logger.error(f"fetching request error in worker-{thread_idx}: {e} {traceback.format_exc()}") + time.sleep(0.002) + + def _prepare_request_v1(self): + """Prepare request and send to the queue for scheduling""" + tracing.trace_set_thread_info("Prepare Request for Scheduling") + role = self.cfg.scheduler_config.splitwise_role + num_workers = envs.FD_PREFILL_PREPARE_REQ_THREAD_NUM if role == "prefill" else 1 + self.llm_logger.info(f"prepare request for scheduling, role: {role}, num_workers: {num_workers}") + + fetch_fn = { + "mixed": self._fetch_request_mixed, + "prefill": self._fetch_request_prefill, + "decode": self._fetch_request_decode, + }[role] + + self._fetch_threads = [] + for i in range(num_workers): + t = threading.Thread( + target=self._fetch_loop, + args=(fetch_fn, i), + daemon=True, + name=f"fetch-{i}", + ) + t.start() + self._fetch_threads.append(t) + + # Keep this thread alive for graceful shutdown + while self.running: + time.sleep(1.0) diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py index 7e0f809d5d3..48f84faa5e0 100644 --- a/fastdeploy/envs.py +++ b/fastdeploy/envs.py @@ -191,6 +191,8 @@ def _validate_split_kv_size(value: int) -> int: "FD_HPU_CHUNK_SIZE": lambda: int(os.getenv("FD_HPU_CHUNK_SIZE", "64")), # "Enable FP8 calibration on HPU" "FD_HPU_MEASUREMENT_MODE": lambda: 
os.getenv("FD_HPU_MEASUREMENT_MODE", "0"), + # Number of worker threads for prepare requests in prefill instance + "FD_PREFILL_PREPARE_REQ_THREAD_NUM": lambda: int(os.getenv("FD_PREFILL_PREPARE_REQ_THREAD_NUM", "5")), "FD_PREFILL_WAIT_DECODE_RESOURCE_SECONDS": lambda: int(os.getenv("FD_PREFILL_WAIT_DECODE_RESOURCE_SECONDS", "30")), "FD_ENABLE_REQUEST_DISCONNECT_STOP_INFERENCE": lambda: int( os.getenv("FD_ENABLE_REQUEST_DISCONNECT_STOP_INFERENCE", "1") diff --git a/fastdeploy/inter_communicator/engine_worker_queue.py b/fastdeploy/inter_communicator/engine_worker_queue.py index a7876669f8f..c142fef9d64 100644 --- a/fastdeploy/inter_communicator/engine_worker_queue.py +++ b/fastdeploy/inter_communicator/engine_worker_queue.py @@ -92,7 +92,6 @@ class QueueManager(BaseManager): Value("i", 0) for _ in range(self.local_data_parallel_size) ] self.finished_req_list = [list() for _ in range(self.local_data_parallel_size)] - self.finished_add_cache_task_list = [list() for _ in range(self.local_data_parallel_size)] self.cache_infos_init: List[List[Any]] = [list() for _ in range(self.local_data_parallel_size)] self.connect_rdma_tasks_list = [list() for _ in range(self.local_data_parallel_size)] self.connect_rdma_tasks_response_list = [list() for _ in range(self.local_data_parallel_size)] @@ -110,9 +109,6 @@ class QueueManager(BaseManager): self.connect_task_response_lock_init: List[threading.Lock] = [ threading.Lock() for _ in range(self.local_data_parallel_size) ] # connect rdma task response - self.finish_add_cache_task_lock_init: List[threading.Lock] = [ - threading.Lock() for _ in range(self.local_data_parallel_size) - ] # finish add cache task self.finish_send_cache_lock_init: List[threading.Lock] = [ threading.Lock() for _ in range(self.local_data_parallel_size) ] # finish send cache @@ -124,18 +120,12 @@ class QueueManager(BaseManager): self.client_get_connect_task_response_flag_init: List[List[int]] = [ [0] * self.num_client for _ in range(self.local_data_parallel_size) ] - self.client_get_finished_add_cache_task_flag_init: List[List[int]] = [ - [0] * self.num_client for _ in range(self.local_data_parallel_size) - ] self.client_get_finish_send_cache_flag_init: List[List[int]] = [ [0] * self.num_client for _ in range(self.local_data_parallel_size) ] self.can_put_next_connect_task_response_flag_init: List[Value] = [ Value("i", 1) for _ in range(self.local_data_parallel_size) ] - self.can_put_next_add_task_finished_flag_init: List[Value] = [ - Value("i", 1) for _ in range(self.local_data_parallel_size) - ] self.can_put_next_send_cache_finished_flag_init: List[Value] = [ Value("i", 1) for _ in range(self.local_data_parallel_size) ] @@ -147,9 +137,6 @@ class QueueManager(BaseManager): self.get_connect_task_response_barrier = [ threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size) ] - self.finish_add_cache_task_barrier = [ - threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size) - ] self.begin_send_cache_barrier = [ threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size) ] @@ -188,11 +175,6 @@ class QueueManager(BaseManager): callable=lambda idx: self.client_get_connect_task_response_flag_init[idx], proxytype=ListProxy, ) - QueueManager.register( - "get_client_get_finished_add_cache_task_flag_init", - callable=lambda idx: self.client_get_finished_add_cache_task_flag_init[idx], - proxytype=ListProxy, - ) QueueManager.register( "get_client_get_finish_send_cache_flag_init", callable=lambda idx: 
self.client_get_finish_send_cache_flag_init[idx], @@ -218,11 +200,6 @@ class QueueManager(BaseManager): callable=lambda idx: self.can_put_next_connect_task_response_flag_init[idx], proxytype=ValueProxy, ) - QueueManager.register( - "get_can_put_next_add_task_finished_flag", - callable=lambda idx: self.can_put_next_add_task_finished_flag_init[idx], - proxytype=ValueProxy, - ) QueueManager.register( "get_can_put_next_send_cache_finished_flag", callable=lambda idx: self.can_put_next_send_cache_finished_flag_init[idx], @@ -239,11 +216,6 @@ class QueueManager(BaseManager): callable=lambda idx: self.connect_task_response_lock_init[idx], proxytype=AcquirerProxy, ) - QueueManager.register( - "get_finish_add_cache_task_lock", - callable=lambda idx: self.finish_add_cache_task_lock_init[idx], - proxytype=AcquirerProxy, - ) QueueManager.register( "get_finish_send_cache_lock", callable=lambda idx: self.finish_send_cache_lock_init[idx], @@ -268,12 +240,6 @@ class QueueManager(BaseManager): "get_finish_request_queue", callable=lambda idx: self.finished_req_list[idx], proxytype=ListProxy ) - QueueManager.register( - "get_finish_add_cache_task_queue", - callable=lambda idx: self.finished_add_cache_task_list[idx], - proxytype=ListProxy, - ) - QueueManager.register( "get_cache_infos", callable=lambda idx: self.cache_infos_init[idx], @@ -321,12 +287,6 @@ class QueueManager(BaseManager): "get_cache_info_barrier", callable=lambda idx: self.get_cache_info_barrier[idx], ) - - QueueManager.register( - "get_finish_add_cache_task_barrier", - callable=lambda idx: self.finish_add_cache_task_barrier[idx], - ) - QueueManager.register( "get_worker_process_tp_barrier", callable=lambda idx: self.worker_process_tp_barrier[idx], @@ -351,13 +311,11 @@ class QueueManager(BaseManager): QueueManager.register("get_exist_tasks_inter_signal") QueueManager.register("get_connected_client_counter") QueueManager.register("get_finish_request_queue") - QueueManager.register("get_finish_add_cache_task_queue") QueueManager.register("get_cache_infos") QueueManager.register("get_client_read_info_flag") QueueManager.register("get_lock_info") QueueManager.register("get_disaggregate_requests") QueueManager.register("get_finish_request_barrier") - QueueManager.register("get_finish_add_cache_task_barrier") QueueManager.register("get_connect_task_barrier") QueueManager.register("get_connect_task_response_barrier") QueueManager.register("get_finish_send_cache_barrier") @@ -366,16 +324,13 @@ class QueueManager(BaseManager): QueueManager.register("get_connect_rdma_tasks") QueueManager.register("get_client_get_connect_task_flag") QueueManager.register("get_client_get_connect_task_response_flag") - QueueManager.register("get_client_get_finished_add_cache_task_flag_init") QueueManager.register("get_client_get_finish_send_cache_flag_init") QueueManager.register("get_connect_rdma_tasks_responses") QueueManager.register("get_connect_task_lock") QueueManager.register("get_connect_task_response_lock") - QueueManager.register("get_finish_add_cache_task_lock") QueueManager.register("get_finish_send_cache_lock") QueueManager.register("get_worker_process_tp_barrier") QueueManager.register("get_can_put_next_connect_task_response_flag") - QueueManager.register("get_can_put_next_add_task_finished_flag") QueueManager.register("get_can_put_next_send_cache_finished_flag") self.manager = QueueManager(address=self.address, authkey=self.authkey) self._connect_with_retry() @@ -398,9 +353,6 @@ class QueueManager(BaseManager): # p/d 分离获取 self.disaggregate_requests = 
self.manager.get_disaggregate_requests(self.local_data_parallel_id) self.finish_request_barrier = self.manager.get_finish_request_barrier(self.local_data_parallel_id) - self.finish_add_cache_task_barrier = self.manager.get_finish_add_cache_task_barrier( - self.local_data_parallel_id - ) self.connect_task_barrier = self.manager.get_connect_task_barrier(self.local_data_parallel_id) self.connect_task_response_barrier = self.manager.get_connect_task_response_barrier( self.local_data_parallel_id @@ -410,9 +362,6 @@ class QueueManager(BaseManager): self.begin_send_cache_barrier = self.manager.get_begin_send_cache_barrier(self.local_data_parallel_id) self.worker_process_tp_barrier = self.manager.get_worker_process_tp_barrier(self.local_data_parallel_id) self.finished_send_cache_list = self.manager.get_finish_request_queue(self.local_data_parallel_id) - self.finished_add_cache_task_list = self.manager.get_finish_add_cache_task_queue( - self.local_data_parallel_id - ) # p/d互联 self.connect_rdma_tasks = self.manager.get_connect_rdma_tasks(self.local_data_parallel_id) self.client_get_connect_task_flag = self.manager.get_client_get_connect_task_flag( @@ -421,9 +370,6 @@ class QueueManager(BaseManager): self.client_get_connect_task_response_flag = self.manager.get_client_get_connect_task_response_flag( self.local_data_parallel_id ) - self.client_get_finished_add_cache_task_flag = ( - self.manager.get_client_get_finished_add_cache_task_flag_init(self.local_data_parallel_id) - ) self.client_get_finish_send_cache_flag = self.manager.get_client_get_finish_send_cache_flag_init( self.local_data_parallel_id ) @@ -433,12 +379,8 @@ class QueueManager(BaseManager): ) self.connect_task_lock = self.manager.get_connect_task_lock(self.local_data_parallel_id) self.connect_task_response_lock = self.manager.get_connect_task_response_lock(self.local_data_parallel_id) - self.finish_add_cache_task_lock = self.manager.get_finish_add_cache_task_lock(self.local_data_parallel_id) self.finish_send_cache_lock = self.manager.get_finish_send_cache_lock(self.local_data_parallel_id) - self.can_put_next_add_task_finished_flag = self.manager.get_can_put_next_add_task_finished_flag( - self.local_data_parallel_id - ) self.can_put_next_connect_task_response_flag = self.manager.get_can_put_next_connect_task_response_flag( self.local_data_parallel_id ) @@ -756,54 +698,6 @@ def get_finished_req(self) -> str: self.finish_send_cache_lock.release() return response - def put_finished_add_cache_task_req(self, req_ids) -> None: - """ - Put finished request ID into the queue. - - Args: - req_ids: Request ID to be added to the queue - """ - self.finish_add_cache_task_lock.acquire() - while not self.can_put_next_add_task_finished_flag.get(): - self.finish_add_cache_task_lock.release() - time.sleep(0.001) - self.finish_add_cache_task_lock.acquire() - self.finished_add_cache_task_list.append(req_ids) - self.client_get_finished_add_cache_task_flag[self.client_id] = 1 - all_client_put: bool = np.sum(self.client_get_finished_add_cache_task_flag) == self.num_client - if all_client_put: - self.can_put_next_add_task_finished_flag.set(0) - self.finish_add_cache_task_lock.release() - return all_client_put - - def get_finished_add_cache_task_req(self) -> str: - """ - Get finished request ID from the queue. 
- - Returns: - str: Finished request ID - """ - response = [] - self.finish_add_cache_task_lock.acquire() - if len(self.finished_add_cache_task_list) == 0: - self.finish_add_cache_task_lock.release() - return response - while sum(self.client_get_finished_add_cache_task_flag) < self.num_client: - self.finish_add_cache_task_lock.release() - time.sleep(0.001) - self.finish_add_cache_task_lock.acquire() - if len(self.finished_add_cache_task_list) > 0: - response = self.finished_add_cache_task_list[0] - for tmp_response in self.finished_add_cache_task_list: - assert ( - tmp_response == response - ), f"Inconsistent responses across workers: expected {response}, got {tmp_response}" - self.finished_add_cache_task_list[:] = list() - self.client_get_finished_add_cache_task_flag[:] = [0] * self.num_client - self.can_put_next_add_task_finished_flag.set(1) - self.finish_add_cache_task_lock.release() - return response - def disaggregate_queue_empty(self): """ Check if the disaggregated task queue is empty. diff --git a/tests/cache_manager/test_cache_messager.py b/tests/cache_manager/test_cache_messager.py index c69d27a24fa..3e415ebe9c8 100644 --- a/tests/cache_manager/test_cache_messager.py +++ b/tests/cache_manager/test_cache_messager.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import importlib.util +import os import sys import types @@ -22,7 +24,21 @@ if not hasattr(paddle, "enable_compat"): paddle.enable_compat = lambda *args, **kwargs: None -from fastdeploy.cache_manager import cache_messager +# Import the legacy cache_messager module directly from the .py file, +# because the cache_messager/ package shadows it and the legacy +# fallback (cache_messager_legacy) does not exist locally. 
+_cm_py_path = os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), + "fastdeploy", + "cache_manager", + "cache_messager.py", +) +_spec = importlib.util.spec_from_file_location( + "fastdeploy.cache_manager.cache_messager_py", + _cm_py_path, +) +cache_messager = importlib.util.module_from_spec(_spec) +_spec.loader.exec_module(cache_messager) class _DummyBarrier: @@ -40,12 +56,10 @@ def __init__(self, cache_info_sequence=None, connect_task_sequence=None, **kwarg self.cache_info_calls = 0 self.connect_task_calls = 0 self.cache_info_barrier = _DummyBarrier() - self.finish_add_cache_task_barrier = _DummyBarrier() self.finish_send_cache_barrier = _DummyBarrier() self.connect_task_barrier = _DummyBarrier() self.connect_task_response_barrier = _DummyBarrier() self.begin_send_cache_barrier = _DummyBarrier() - self.finished_add_cache_task_req_ids = [] self.finished_req_payloads = [] self.connect_task_responses = [] @@ -56,9 +70,6 @@ def get_cache_info(self): self.cache_info_calls += 1 return info - def put_finished_add_cache_task_req(self, req_ids): - self.finished_add_cache_task_req_ids.append(req_ids) - def put_finished_req(self, payload): self.finished_req_payloads.append(payload) @@ -376,7 +387,6 @@ def test_cache_messager_v1_add_cache_task_thread(monkeypatch): } with pytest.raises(SystemExit): messager._add_cache_task_thread() - assert dummy_queue.finished_add_cache_task_req_ids == [["req-2"]] assert messager.cache_info["req-2"]["status"] == "init" diff --git a/tests/engine/test_common_engine.py b/tests/engine/test_common_engine.py index 84dbe2ce3c5..833bd5008da 100644 --- a/tests/engine/test_common_engine.py +++ b/tests/engine/test_common_engine.py @@ -579,6 +579,7 @@ def test_start_prefill_branch_cache_manager_and_worker_dead(self): eng._process_splitwise_task = lambda: None eng._schedule_request_to_worker = lambda: None eng._schedule_request_to_worker_v1 = lambda: None + eng._prepare_request_v1 = lambda: None started_cache = {} @@ -624,6 +625,7 @@ def test_start_mixed_branch_cache_after_load_and_zmq(self): eng._process_splitwise_task = lambda: None eng._schedule_request_to_worker = lambda: None eng._schedule_request_to_worker_v1 = lambda: None + eng._prepare_request_v1 = lambda: None started_cache = {} @@ -1379,21 +1381,18 @@ def test_schedule_request_to_worker_v1_mixed_single_iteration(self): task = Request(request_id="v1_r0", prompt_token_ids=[1], prompt_token_ids_len=1) task.metrics.scheduler_recv_req_time = time.time() - eng.scheduler = Mock(get_requests=Mock(return_value=[task]), put_results=Mock()) + eng.scheduler = Mock(put_results=Mock()) eng.engine_worker_queue = Mock(exist_tasks=Mock(return_value=False), put_tasks=Mock()) - eng.resource_manager = self._make_v1_decode_rm(eng, ([], []), with_add_request=True) + eng.resource_manager = self._make_v1_decode_rm(eng, ([task], []), with_add_request=True) try: - with ( - patch("fastdeploy.engine.common_engine.ThreadPoolExecutor", self._make_dummy_executor(eng)), - patch("fastdeploy.engine.common_engine.time.sleep", lambda *_: None), - ): + with patch("fastdeploy.engine.common_engine.time.sleep", lambda *_: None): eng._schedule_request_to_worker_v1() finally: eng.running = False - eng.resource_manager.add_request.assert_called_once_with(task) + eng.engine_worker_queue.put_tasks.assert_called_once() self._detach_finalizer(eng) def test_schedule_request_to_worker_v1_prefill_decode_alloc_error_safe(self): @@ -1413,7 +1412,6 @@ def test_schedule_request_to_worker_v1_prefill_decode_alloc_error_safe(self): 
eng.scheduler = Mock(get_requests=Mock(return_value=[task]), put_results=Mock()) eng.engine_worker_queue = Mock( exist_tasks=Mock(return_value=False), - get_finished_add_cache_task_req=Mock(return_value=[]), ) eng.resource_manager = self._make_v1_prefill_continuous_rm(eng, waiting_async_result=False) @@ -1425,11 +1423,13 @@ def test_schedule_request_to_worker_v1_prefill_decode_alloc_error_safe(self): try: with ( - patch("fastdeploy.engine.common_engine.envs.PREFILL_CONTINUOUS_REQUEST_DECODE_RESOURCES", False), - patch("fastdeploy.engine.common_engine.ThreadPoolExecutor", self._make_dummy_executor(eng)), - patch("fastdeploy.engine.common_engine.time.sleep", lambda *_: None), + patch( + "fastdeploy.engine.common_engine_prepare_mixin.envs.PREFILL_CONTINUOUS_REQUEST_DECODE_RESOURCES", + False, + ), + patch("fastdeploy.engine.common_engine_prepare_mixin.time.sleep", lambda *_: None), ): - eng._schedule_request_to_worker_v1() + eng._fetch_request_prefill() finally: eng.running = False @@ -1450,17 +1450,14 @@ def test_schedule_request_to_worker_v1_decode_preempted_and_errors(self): task.task_type = RequestType.PREEMPTED task.metrics.scheduler_recv_req_time = time.time() - eng.scheduler = Mock(get_requests=Mock(return_value=[]), put_results=Mock()) + eng.scheduler = Mock(put_results=Mock()) eng.engine_worker_queue = Mock(exist_tasks=Mock(return_value=False), put_tasks=Mock()) eng._send_error_response = Mock() eng.resource_manager = self._make_v1_decode_rm(eng, ([task], [("rid_x", None), ("rid_y", "bad")])) try: - with ( - patch("fastdeploy.engine.common_engine.ThreadPoolExecutor", self._make_dummy_executor(eng)), - patch("fastdeploy.engine.common_engine.time.sleep", lambda *_: None), - ): + with patch("fastdeploy.engine.common_engine.time.sleep", lambda *_: None): eng._schedule_request_to_worker_v1() finally: eng.running = False @@ -1484,16 +1481,13 @@ def test_schedule_request_to_worker_v1_decode_prefill_task_path(self): task.trace_carrier = {} task.metrics.scheduler_recv_req_time = time.time() - eng.scheduler = Mock(get_requests=Mock(return_value=[]), put_results=Mock()) + eng.scheduler = Mock(put_results=Mock()) eng.engine_worker_queue = Mock(exist_tasks=Mock(return_value=False), put_tasks=Mock()) eng.resource_manager = self._make_v1_decode_rm(eng, ([task], [])) try: - with ( - patch("fastdeploy.engine.common_engine.ThreadPoolExecutor", self._make_dummy_executor(eng)), - patch("fastdeploy.engine.common_engine.time.sleep", lambda *_: None), - ): + with patch("fastdeploy.engine.common_engine.time.sleep", lambda *_: None): eng._schedule_request_to_worker_v1() finally: eng.running = False @@ -1515,23 +1509,20 @@ def test_schedule_request_to_worker_v1_error_task_none_skips_send(self): task.trace_carrier = {} task.metrics.scheduler_recv_req_time = time.time() - eng.scheduler = Mock(get_requests=Mock(return_value=[]), put_results=Mock()) + eng.scheduler = Mock(put_results=Mock()) eng.engine_worker_queue = Mock(exist_tasks=Mock(return_value=False), put_tasks=Mock()) eng._send_error_response = Mock() eng.resource_manager = self._make_v1_decode_rm(eng, ([task], [("rid_none", None)])) - with ( - patch("fastdeploy.engine.common_engine.ThreadPoolExecutor", self._make_dummy_executor(eng)), - patch("fastdeploy.engine.common_engine.time.sleep", lambda *_: None), - ): + with patch("fastdeploy.engine.common_engine.time.sleep", lambda *_: None): eng._schedule_request_to_worker_v1() eng.engine_worker_queue.put_tasks.assert_called_once() eng._send_error_response.assert_not_called() self._detach_finalizer(eng) - def 
test_schedule_request_to_worker_v1_threadpool_shutdown_breaks(self): + def test_schedule_request_to_worker_v1_no_tasks_sleeps(self): eng = self._make_mixed_engine() self._setup_v1_engine(eng) @@ -1539,17 +1530,7 @@ def test_schedule_request_to_worker_v1_threadpool_shutdown_breaks(self): eng.resource_manager = self._make_v1_decode_rm(eng, ([], [])) - class DummyExecutor: - def __init__(self, max_workers=None): - pass - - def submit(self, fn): - raise RuntimeError("cannot schedule new futures after shutdown") - - with ( - patch("fastdeploy.engine.common_engine.ThreadPoolExecutor", DummyExecutor), - patch("fastdeploy.engine.common_engine.time.sleep", lambda *_: None), - ): + with patch("fastdeploy.engine.common_engine.time.sleep", lambda *_: None): eng._schedule_request_to_worker_v1() self._detach_finalizer(eng) @@ -1572,17 +1553,8 @@ def test_schedule_request_to_worker_v1_prefill_continuous_cache_success(self): eng.resource_manager = self._make_v1_prefill_continuous_rm(eng, waiting_async_result=False) - calls = {"n": 0} - - def get_finished_add_cache_task_req(): - if calls["n"] == 0: - calls["n"] += 1 - return ["pc_ok"] - return [] - eng.engine_worker_queue = Mock( exist_tasks=Mock(return_value=False), - get_finished_add_cache_task_req=Mock(side_effect=get_finished_add_cache_task_req), ) eng.split_connector = Mock( @@ -1592,11 +1564,12 @@ def get_finished_add_cache_task_req(): ) with ( - patch("fastdeploy.engine.common_engine.envs.PREFILL_CONTINUOUS_REQUEST_DECODE_RESOURCES", True), - patch("fastdeploy.engine.common_engine.ThreadPoolExecutor", self._make_dummy_executor(eng)), - patch("fastdeploy.engine.common_engine.time.sleep", lambda *_: None), + patch( + "fastdeploy.engine.common_engine_prepare_mixin.envs.PREFILL_CONTINUOUS_REQUEST_DECODE_RESOURCES", True + ), + patch("fastdeploy.engine.common_engine_prepare_mixin.time.sleep", lambda *_: None), ): - eng._schedule_request_to_worker_v1() + eng._fetch_request_prefill() eng.split_connector.send_splitwise_tasks.assert_called() eng.split_connector.send_cache_info_to_messager.assert_called_once() @@ -1624,17 +1597,8 @@ def test_schedule_request_to_worker_v1_prefill_continuous_wait_async_none(self): eng.resource_manager = self._make_v1_prefill_continuous_rm(eng, waiting_async_result=None) - calls = {"n": 0} - - def get_finished_add_cache_task_req(): - if calls["n"] == 0: - calls["n"] += 1 - return ["pc_fail"] - return [] - eng.engine_worker_queue = Mock( exist_tasks=Mock(return_value=False), - get_finished_add_cache_task_req=Mock(side_effect=get_finished_add_cache_task_req), ) eng.split_connector = Mock( @@ -1644,11 +1608,12 @@ def get_finished_add_cache_task_req(): ) with ( - patch("fastdeploy.engine.common_engine.envs.PREFILL_CONTINUOUS_REQUEST_DECODE_RESOURCES", True), - patch("fastdeploy.engine.common_engine.ThreadPoolExecutor", self._make_dummy_executor(eng)), - patch("fastdeploy.engine.common_engine.time.sleep", lambda *_: None), + patch( + "fastdeploy.engine.common_engine_prepare_mixin.envs.PREFILL_CONTINUOUS_REQUEST_DECODE_RESOURCES", True + ), + patch("fastdeploy.engine.common_engine_prepare_mixin.time.sleep", lambda *_: None), ): - eng._schedule_request_to_worker_v1() + eng._fetch_request_prefill() eng.scheduler.put_results.assert_called_once() eng.resource_manager.pre_recycle_resource.assert_called_once_with("pc_fail") diff --git a/tests/inter_communicator/test_e2w_queue.py b/tests/inter_communicator/test_e2w_queue.py index d3cd657f01a..333249cc66d 100644 --- a/tests/inter_communicator/test_e2w_queue.py +++ 
b/tests/inter_communicator/test_e2w_queue.py @@ -301,15 +301,15 @@ def test_wait_loops_and_tensor_conversion(self): client.get_finished_req() thread.join() - client.can_put_next_add_task_finished_flag.set(0) - thread = self._set_value_after_delay(client.can_put_next_add_task_finished_flag, 1) - client.put_finished_add_cache_task_req(["req-wait"]) + client.can_put_next_send_cache_finished_flag.set(0) + thread = self._set_value_after_delay(client.can_put_next_send_cache_finished_flag, 1) + client.put_finished_req([["req-wait", {"status": "ok"}]]) thread.join() - client.finished_add_cache_task_list.append(["req-wait"]) - client.client_get_finished_add_cache_task_flag[:] = [0] - thread = self._set_list_after_delay(client.client_get_finished_add_cache_task_flag, [1]) - client.get_finished_add_cache_task_req() + client.finished_send_cache_list.append(["req-wait", {"error": "bad"}]) + client.client_get_finish_send_cache_flag[:] = [0] + thread = self._set_list_after_delay(client.client_get_finish_send_cache_flag, [1]) + client.get_finished_req() thread.join() finally: paddle.set_device(previous_device) @@ -361,18 +361,6 @@ def test_finished_req_flow(self): finally: self._cleanup_queue_pair(server) - def test_finished_add_cache_task_req(self): - server, client = self._build_queue_pair() - try: - req_ids = ["req-2"] - self.assertTrue(client.put_finished_add_cache_task_req(req_ids)) - client.finished_add_cache_task_list.append(req_ids) - self.assertEqual(client.get_finished_add_cache_task_req(), req_ids) - self.assertEqual(client.get_finished_add_cache_task_req(), []) - self.assertEqual(client.can_put_next_add_task_finished_flag.get(), 1) - finally: - self._cleanup_queue_pair(server) - def test_disaggregated_queue(self): server, client = self._build_queue_pair() try:
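For readers following the change, below is a minimal standalone Python sketch of the threading pattern this patch introduces: a role-specific fetch function run by one or more daemon worker threads, with the prefill worker count taken from the new FD_PREFILL_PREPARE_REQ_THREAD_NUM variable (default 5). The fetch functions, queue, and stop event are illustrative stand-ins, not FastDeploy APIs; only the environment variable name and the loop structure mirror the patch.

import os
import queue
import threading
import time

# Illustrative stand-in for the engine's scheduling queue (not a FastDeploy object).
work_queue = queue.Queue()


def fetch_prefill() -> bool:
    # Stand-in for _fetch_request_prefill: fetch, prepare, and enqueue one request.
    work_queue.put("prefill-request")
    return True


def fetch_mixed() -> bool:
    # Stand-in for _fetch_request_mixed.
    work_queue.put("mixed-request")
    return True


def fetch_loop(fetch_fn, thread_idx: int, stop: threading.Event) -> None:
    # Mirrors the shape of EngineServicePrepareMixin._fetch_loop: keep fetching
    # until stopped, swallowing transient errors so the worker thread stays alive.
    while not stop.is_set():
        try:
            fetch_fn()
            time.sleep(0.002)
        except Exception as exc:
            print(f"fetch worker-{thread_idx} error: {exc}")
            time.sleep(0.002)


def start_prepare_threads(role: str, stop: threading.Event) -> list:
    # Prefill instances run several prepare threads (FD_PREFILL_PREPARE_REQ_THREAD_NUM,
    # default 5 in this patch); other roles keep a single thread.
    num_workers = int(os.getenv("FD_PREFILL_PREPARE_REQ_THREAD_NUM", "5")) if role == "prefill" else 1
    fetch_fn = {"prefill": fetch_prefill, "mixed": fetch_mixed}[role]
    threads = []
    for i in range(num_workers):
        t = threading.Thread(target=fetch_loop, args=(fetch_fn, i, stop), daemon=True, name=f"fetch-{i}")
        t.start()
        threads.append(t)
    return threads


if __name__ == "__main__":
    stop_event = threading.Event()
    start_prepare_threads("prefill", stop_event)
    time.sleep(0.05)
    stop_event.set()
    print(f"prepared {work_queue.qsize()} requests with multiple worker threads")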