diff --git a/fastdeploy/cache_manager/cache_messager.py b/fastdeploy/cache_manager/cache_messager.py index 08c8dea003a..b934c3e74c7 100644 --- a/fastdeploy/cache_manager/cache_messager.py +++ b/fastdeploy/cache_manager/cache_messager.py @@ -613,16 +613,12 @@ def __init__( ) self.gpu_id = gpu_id - self.cache_info = dict() # {'request_id': cache_info_dict} + self.cache_info = dict() self.rank_id = self.rank + local_data_parallel_id * self.nranks self.engine_cache_task_thread_lock = threading.Lock() - self.engine_cache_tasks = [ - dict() for _ in range(512) - ] # {'layer_id': {'prefilled_layer_idx': xx, 'prefilled_block_num': xx}} - self.idx_cache_task_dict = {} # {'slot_idx': cache_info_dict} - self.cache_prefilled_engine_ids_queue = ( - queue.Queue() - ) # [(slot_idx1, prefilled_token_num1), (slot_idx2, prefilled_token_num2)] + self.engine_cache_tasks = [dict() for _ in range(512)] + self.idx_cache_task_dict = {} + self.cache_prefilled_engine_ids_queue = queue.Queue() # keep batch slot index for each prefill step if splitwise_role == "prefill": consume_signals_thread = threading.Thread(target=self.consume_signals) consume_signals_thread.daemon = True @@ -642,6 +638,7 @@ def _add_cache_task_thread(self): while True: try: cache_info = self.engine_worker_queue.get_cache_info() + finished_add_cache_task_req_ids = [] if cache_info: logger.debug(f"Get cache info from engine worker queue, {cache_info}") self.engine_worker_queue.cache_info_barrier.wait() @@ -650,6 +647,7 @@ def _add_cache_task_thread(self): self.cache_info[info["request_id"]].update(info) current_info = self.cache_info[info["request_id"]] assert "dest_block_ids" in current_info and "src_block_ids" in current_info + finished_add_cache_task_req_ids.append(info["request_id"]) decode_cached_block_num = len(current_info["src_block_ids"]) - len( current_info["dest_block_ids"] ) @@ -661,13 +659,17 @@ def _add_cache_task_thread(self): current_info["sended_layer_id"] = -1 current_info["sended_block_num"] = current_info["decode_cached_tokens"] // self.block_size current_info["status"] = "init" - logger.info(f"Get cache info and finish add cache task: {current_info}") + logger.info(f"Get cache info from D: finish add cache task: {current_info}") self.cache_info[info["request_id"]] = current_info self.idx_cache_task_dict[current_info["current_id"]] = current_info else: - logger.info(f"Get cache info: {info}") + logger.info(f"Get cache info from P: {info}") self.cache_info[info["request_id"]] = info + if finished_add_cache_task_req_ids: + logger.info(f"Put processed tasks into engine worker queue: {finished_add_cache_task_req_ids}") + self.engine_worker_queue.put_finished_add_cache_task_req(finished_add_cache_task_req_ids) + self.engine_worker_queue.finish_add_cache_task_barrier.wait() else: time.sleep(0.001) except Exception as e: @@ -685,12 +687,10 @@ def prefill_layerwise_send_cache_thread(self): block_start_end_list = [] current_prefilled_token_num_list = [] for engine_index, current_step_prefilled_token_num in batch_engine_signals: - self._maybe_wait_for_cache_task(engine_index) assert ( engine_index in self.idx_cache_task_dict ), f"engine_index {engine_index} not in self.idx_cache_task_dict {self.idx_cache_task_dict}" block_id_start = self.idx_cache_task_dict[engine_index]["sended_block_num"] - prefilled_token_num = current_step_prefilled_token_num if ( prefilled_token_num == self.idx_cache_task_dict[engine_index]["need_prefill_tokens"] @@ -917,20 +917,6 @@ def _handle_connect_task(self): except Exception as e: logger.error(f"handle_connect_task has exception: {e}, {traceback.format_exc()}") - def _maybe_wait_for_cache_task(self, engine_index): - # If cache messager does not get cache task from engine, just hang here for now - wait_step = 1 - sleep_seconds = 0.005 - - while engine_index not in self.idx_cache_task_dict: - time.sleep(sleep_seconds) - wait_step += 1 - - if wait_step % 400 == 0: - logger.warning( - f"waiting cache task for engine_index: {engine_index}, cost_time: {wait_step * 0.005:.2f} s" - ) - def main(): device = args.device_id diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py index c31c9039b40..2f96db3d6d4 100644 --- a/fastdeploy/engine/common_engine.py +++ b/fastdeploy/engine/common_engine.py @@ -30,6 +30,7 @@ import time import traceback import weakref +from concurrent.futures import ThreadPoolExecutor from pathlib import Path from typing import Dict, List, Optional, Tuple @@ -41,7 +42,6 @@ import fastdeploy.metrics.trace as tracing from fastdeploy.cache_manager.cache_data import CacheStatus from fastdeploy.config import FDConfig -from fastdeploy.engine.common_engine_prepare_mixin import EngineServicePrepareMixin from fastdeploy.engine.register_manager import RegisterManager from fastdeploy.engine.request import ( ControlRequest, @@ -113,7 +113,7 @@ def _format_worker_launch_failure_message(log_dir: str) -> str: return message -class EngineService(EngineServicePrepareMixin): +class EngineService: """ Base class containing common engine functionality """ @@ -250,13 +250,12 @@ def start(self, async_llm_pid=None): self.start_worker_service(async_llm_pid) if envs.ENABLE_V1_KVCACHE_SCHEDULER: - self.prepare_request_thread = threading.Thread(target=self._prepare_request_v1, daemon=True) - self.prepare_request_thread.start() - self.schedule_request_thread = threading.Thread(target=self._schedule_request_to_worker_v1, daemon=True) - self.schedule_request_thread.start() + self.insert_task_to_worker_thread = threading.Thread( + target=self._schedule_request_to_worker_v1, daemon=True + ) else: - self.schedule_request_thread = threading.Thread(target=self._schedule_request_to_worker, daemon=True) - self.schedule_request_thread.start() + self.insert_task_to_worker_thread = threading.Thread(target=self._schedule_request_to_worker, daemon=True) + self.insert_task_to_worker_thread.start() self.token_processor.tasks_queue = self.engine_worker_queue self.token_processor.run() if self.cfg.scheduler_config.splitwise_role == "decode": @@ -879,19 +878,215 @@ def _schedule_request_to_worker_v1(self): Insert tasks to worker with scheduler v1 (ENABLE_V1_KVCACHE_SCHEDULER=1). """ tracing.trace_set_thread_info("Scheduler Task to Work") + get_request_pool = ThreadPoolExecutor(max_workers=1) + is_fetching = False + + def _fetch_request(): + try: + with self._pause_cond: + self._pause_cond.wait_for(lambda: not self.is_paused) + nonlocal is_fetching + num_prefill_batch = min( + int(self.resource_manager.available_batch()), + self.cfg.max_prefill_batch, + ) + + if self.cfg.scheduler_config.splitwise_role != "mixed": + max_num_batched_tokens = self.cfg.scheduler_config.max_num_batched_tokens + else: + max_num_batched_tokens = self.cfg.model_config.max_model_len + + available_blocks = self.cfg.cache_config.max_block_num_per_seq + tasks = self.scheduler.get_requests( + available_blocks=available_blocks, + block_size=self.cfg.cache_config.block_size, + reserved_output_blocks=0, # self.cfg.cache_config.enc_dec_block_num + max_num_batched_tokens=max_num_batched_tokens, + batch=num_prefill_batch, + ) + for task in tasks: + task.metrics.engine_get_req_time = time.time() + trace_print(LoggingEventName.REQUEST_QUEUE_END, task.request_id, getattr(task, "user", "")) + + if self.cfg.scheduler_config.splitwise_role == "decode": + # TODO: refine scheduler to remove this limitation + # Decode will process and schedule the request sent by prefill to engine, + # so the same request sent by the decode api server will be ignored + is_fetching = False + return + + if tasks: + self.llm_logger.debug( + f"Engine has fetched tasks from {self.scheduler.__class__.__name__}: {[task.request_id for task in tasks]}" + ) + + if self.cfg.scheduler_config.splitwise_role == "prefill": + for task in tasks: + # start async preprocess + self.resource_manager.apply_async_preprocess(task) + need_delete_tasks = [] + if envs.PREFILL_CONTINUOUS_REQUEST_DECODE_RESOURCES: + for task in tasks: + # assure can allocate block ids in P + while not self.resource_manager.preallocate_resource_in_p(task): + time.sleep(0.005) + self.llm_logger.debug( + f"P has allocated resources and then ask D resource for request: {task.request_id}" + ) + trace_print( + LoggingEventName.ASK_DECODE_RESOURCE_START, task.request_id, getattr(task, "user", "") + ) + task.metrics.ask_decode_resource_start_time = time.time() + while True: + self.split_connector.send_splitwise_tasks([task], task.idx) + status, msg = self.split_connector.check_decode_allocated(task) + if not status: + self.llm_logger.warning( + f"D failed to allocate resource for request {task.request_id}, try again." + ) + time.sleep(0.05) + else: + task.metrics.ask_decode_resource_finish_time = time.time() + trace_print( + LoggingEventName.ASK_DECODE_RESOURCE_END, + task.request_id, + getattr(task, "user", ""), + ) + break + self.llm_logger.debug(f"D has allocated resource for request: {task.request_id}") + else: + for task in tasks: + # assure can allocate block ids in P + while not self.resource_manager.preallocate_resource_in_p(task): + time.sleep(0.005) + + self.llm_logger.debug( + f"P has allocated resources and then ask D resource for req_id: {task.request_id}" + ) + trace_print( + LoggingEventName.ASK_DECODE_RESOURCE_START, task.request_id, getattr(task, "user", "") + ) + task.metrics.ask_decode_resource_start_time = time.time() + self.split_connector.send_splitwise_tasks([task], task.idx) + + for task in tasks: + # assure fetch block ids from D + status, msg = self.split_connector.check_decode_allocated(task) + task.metrics.ask_decode_resource_finish_time = time.time() + trace_print( + LoggingEventName.ASK_DECODE_RESOURCE_END, task.request_id, getattr(task, "user", "") + ) + if not status: + error_msg = ( + f"PD Error: prefill failed to apply for resource from decode, " + f"req: {task.request_id}, msg:{msg}." + ) + self.llm_logger.error(error_msg) + self.scheduler.put_results( + [ + RequestOutput( + request_id=task.request_id, + finished=True, + error_code=500, + error_msg=error_msg, + ) + ] + ) + main_process_metrics.reschedule_req_num.inc() + need_delete_tasks.append(task) + continue + for tmp_task in need_delete_tasks: + tasks.remove(tmp_task) + # release resource in P + self.resource_manager.pre_recycle_resource(tmp_task.request_id) + + # to send cache info to cache messager + if tasks: + need_check_req_ids = [task.request_id for task in tasks] + self.split_connector.send_cache_info_to_messager(tasks, 0) + # ensure cache tasks has sent to cache_messager + need_check_req_ids = [task.request_id for task in tasks] + finished_ids, delete_tasks_list = [], [] + while need_check_req_ids: + finished_ids.extend(self.engine_worker_queue.get_finished_add_cache_task_req()) + self.llm_logger.debug( + f"P has successfully sent cache infos to cache messager for requests: {finished_ids}" + ) + if finished_ids: + for task in tasks: + result = self.resource_manager.waiting_async_process(task) + if result is None: + self.scheduler.put_results( + [ + RequestOutput( + request_id=task.request_id, + finished=True, + error_code=task.error_code, + error_msg=task.error_message, + ) + ] + ) + need_check_req_ids.remove(task.request_id) + delete_tasks_list.append(task) + elif result is False: + if task.request_id in finished_ids: + need_check_req_ids.remove(task.request_id) + finished_ids.remove(task.request_id) + else: + time.sleep(0.001) + + for tmp_task in delete_tasks_list: + tasks.remove(tmp_task) + # release resource in P + self.resource_manager.pre_recycle_resource(tmp_task.request_id) + + # Fetch requests and add them to the scheduling queue + if tasks: + for task in tasks: + task.metrics.add_req_to_resource_manager_time = time.time() + trace_print( + LoggingEventName.RESOURCE_ALLOCATE_START, task.request_id, getattr(task, "user", "") + ) + if self.cfg.scheduler_config.splitwise_role == "prefill": + self.resource_manager.add_request_in_p(tasks) + self.llm_logger.info( + f"P add requests into running queue: {[task.request_id for task in tasks]}" + ) + else: + for task in tasks: + self.resource_manager.add_request(task) + is_fetching = False + except Exception as e: + self.llm_logger.error(f"fetching request error {e} {str(traceback.format_exc())}") + is_fetching = False while self.running: with self._pause_cond: self._pause_cond.wait_for(lambda: not self.is_paused) - try: if self.engine_worker_queue.exist_tasks(): time.sleep(0.001) continue + if self.cfg.scheduler_config.splitwise_role != "mixed": + if not is_fetching: + is_fetching = True + get_request_pool.submit(_fetch_request) + + else: + if len(self.resource_manager.waiting) == 0 and (not is_fetching): + # Check if the thread pool is still available to avoid submitting tasks to a shutdown thread pool. + try: + is_fetching = True + get_request_pool.submit(_fetch_request) + except RuntimeError as e: + if "shutdown" in str(e): + self.llm_logger.info("Thread pool shutdown detected, exiting scheduler loop") + break + else: + raise if hasattr(self.resource_manager, "scheduler_unhandled_request_num"): self.resource_manager.scheduler_unhandled_request_num = self._get_scheduler_unhandled_request_num() - # 2. Schedule requests tasks, error_tasks = self.resource_manager.schedule() diff --git a/fastdeploy/engine/common_engine_prepare_mixin.py b/fastdeploy/engine/common_engine_prepare_mixin.py deleted file mode 100644 index 71327025458..00000000000 --- a/fastdeploy/engine/common_engine_prepare_mixin.py +++ /dev/null @@ -1,282 +0,0 @@ -""" -# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License" -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" - -from __future__ import annotations - -import threading -import time -import traceback - -import fastdeploy.metrics.trace as tracing -from fastdeploy.engine.request import RequestOutput -from fastdeploy.metrics.metrics import main_process_metrics -from fastdeploy.trace.constants import LoggingEventName -from fastdeploy.trace.trace_logger import print as trace_print -from fastdeploy.utils import envs - - -class EngineServicePrepareMixin: - def _fetch_request_mixed(self) -> bool: - """Fetch and prepare requests for a mixed instance. Returns True if tasks were fetched.""" - # FIXME: to validate if it's necessary for avoiding error when enable mtp - if len(self.resource_manager.waiting) > 0: - return False - - num_prefill_batch = min( - int(self.resource_manager.available_batch()), - self.cfg.max_prefill_batch, - ) - max_num_batched_tokens = self.cfg.model_config.max_model_len - available_blocks = self.cfg.cache_config.max_block_num_per_seq - - tasks = self.scheduler.get_requests( - available_blocks=available_blocks, - block_size=self.cfg.cache_config.block_size, - reserved_output_blocks=0, - max_num_batched_tokens=max_num_batched_tokens, - batch=num_prefill_batch, - ) - if not tasks: - return False - - for task in tasks: - task.metrics.engine_get_req_time = time.time() - trace_print(LoggingEventName.REQUEST_QUEUE_END, task.request_id, getattr(task, "user", "")) - - self.llm_logger.debug( - f"Engine has fetched tasks from {self.scheduler.__class__.__name__}: {[task.request_id for task in tasks]}" - ) - - for task in tasks: - task.metrics.add_req_to_resource_manager_time = time.time() - trace_print(LoggingEventName.RESOURCE_ALLOCATE_START, task.request_id, getattr(task, "user", "")) - self.resource_manager.add_request(task) - - return True - - def _fetch_request_decode(self) -> bool: - """Consume scheduler queue for decode instance to prevent memory accumulation. - Returns True if tasks were consumed.""" - num_prefill_batch = min( - int(self.resource_manager.available_batch()), - self.cfg.max_prefill_batch, - ) - max_num_batched_tokens = self.cfg.scheduler_config.max_num_batched_tokens - available_blocks = self.cfg.cache_config.max_block_num_per_seq - - tasks = self.scheduler.get_requests( - available_blocks=available_blocks, - block_size=self.cfg.cache_config.block_size, - reserved_output_blocks=0, - max_num_batched_tokens=max_num_batched_tokens, - batch=num_prefill_batch, - ) - # Tasks are intentionally discarded - decode receives requests via _decode_process_splitwise_requests - return len(tasks) > 0 - - def _fetch_request_prefill(self) -> bool: - """Fetch and prepare requests for a prefill instance. Returns True if tasks were fetched.""" - num_prefill_batch = min( - int(self.resource_manager.available_batch()), - self.cfg.max_prefill_batch, - ) - max_num_batched_tokens = self.cfg.scheduler_config.max_num_batched_tokens - available_blocks = self.cfg.cache_config.max_block_num_per_seq - - tasks = self.scheduler.get_requests( - available_blocks=available_blocks, - block_size=self.cfg.cache_config.block_size, - reserved_output_blocks=0, - max_num_batched_tokens=max_num_batched_tokens, - batch=num_prefill_batch, - ) - if not tasks: - return False - - for task in tasks: - task.metrics.engine_get_req_time = time.time() - trace_print(LoggingEventName.REQUEST_QUEUE_END, task.request_id, getattr(task, "user", "")) - - self.llm_logger.debug( - f"Engine has fetched tasks from {self.scheduler.__class__.__name__}: {[task.request_id for task in tasks]}" - ) - - # Start async preprocess for all tasks in this batch - for task in tasks: - self.resource_manager.apply_async_preprocess(task) - - # P-side resource preallocation + D-side coordination - failed_tasks = [] - if envs.PREFILL_CONTINUOUS_REQUEST_DECODE_RESOURCES: - for task in tasks: - # assure can allocate block ids in P - while not self.resource_manager.preallocate_resource_in_p(task): - time.sleep(0.005) - self.llm_logger.debug( - f"P has allocated resources and then ask D resource for request: {task.request_id}" - ) - trace_print(LoggingEventName.ASK_DECODE_RESOURCE_START, task.request_id, getattr(task, "user", "")) - task.metrics.ask_decode_resource_start_time = time.time() - while True: - self.split_connector.send_splitwise_tasks([task], task.idx) - status, msg = self.split_connector.check_decode_allocated(task) - if status: - task.metrics.ask_decode_resource_finish_time = time.time() - trace_print( - LoggingEventName.ASK_DECODE_RESOURCE_END, - task.request_id, - getattr(task, "user", ""), - ) - break - else: - self.llm_logger.warning( - f"D failed to allocate resource for request {task.request_id}, try again." - ) - time.sleep(0.05) - - self.llm_logger.debug(f"D has allocated resource for request: {task.request_id}") - else: - for task in tasks: - # assure can allocate block ids in P - while not self.resource_manager.preallocate_resource_in_p(task): - time.sleep(0.005) - - self.llm_logger.debug( - f"P has allocated resources and then ask D resource for req_id: {task.request_id}" - ) - trace_print(LoggingEventName.ASK_DECODE_RESOURCE_START, task.request_id, getattr(task, "user", "")) - task.metrics.ask_decode_resource_start_time = time.time() - self.split_connector.send_splitwise_tasks([task], task.idx) - - for task in tasks: - # assure fetch block ids from D - status, msg = self.split_connector.check_decode_allocated(task) - task.metrics.ask_decode_resource_finish_time = time.time() - trace_print(LoggingEventName.ASK_DECODE_RESOURCE_END, task.request_id, getattr(task, "user", "")) - if not status: - error_msg = ( - f"PD Error: prefill failed to apply for resource from decode, " - f"req: {task.request_id}, msg:{msg}." - ) - self.llm_logger.error(error_msg) - self.scheduler.put_results( - [ - RequestOutput( - request_id=task.request_id, - finished=True, - error_code=500, - error_msg=error_msg, - ) - ] - ) - main_process_metrics.reschedule_req_num.inc() - failed_tasks.append(task) - - for tmp_task in failed_tasks: - tasks.remove(tmp_task) - self.resource_manager.pre_recycle_resource(tmp_task.request_id) - - # Check and wait async preprocess - if tasks: - need_check_req_ids = [task.request_id for task in tasks] - failed_tasks = [] - - while need_check_req_ids: - still_in_progress = False - for task in tasks: - if task.request_id not in need_check_req_ids: - continue - - result = self.resource_manager.waiting_async_process(task) - if result is False: # async preprocess success - need_check_req_ids.remove(task.request_id) - elif result is True: - still_in_progress = True - elif result is None: # async preprocess failed - failed_tasks.append(task) - need_check_req_ids.remove(task.request_id) - self.scheduler.put_results( - [ - RequestOutput( - request_id=task.request_id, - finished=True, - error_code=task.error_code, - error_msg=task.error_message, - ) - ] - ) - - if still_in_progress: - time.sleep(0.005) - - for tmp_task in failed_tasks: - tasks.remove(tmp_task) - self.resource_manager.pre_recycle_resource(tmp_task.request_id) - - # Send cache info to messager - if tasks: - self.split_connector.send_cache_info_to_messager(tasks, 0) - - # Fetch requests and add them to the scheduling queue - if tasks: - for task in tasks: - task.metrics.add_req_to_resource_manager_time = time.time() - trace_print(LoggingEventName.RESOURCE_ALLOCATE_START, task.request_id, getattr(task, "user", "")) - self.resource_manager.add_request_in_p(tasks) - self.llm_logger.info(f"P add requests into running queue: {[task.request_id for task in tasks]}") - - return True - - def _fetch_loop(self, fetch_fn, thread_idx: int): - """Fetch loop run by each worker thread.""" - tracing.trace_set_thread_info(f"Prepare Request for Scheduling - thread {thread_idx}") - while self.running: - try: - with self._pause_cond: - self._pause_cond.wait_for(lambda: not self.is_paused) - fetch_fn() - time.sleep(0.002) - except Exception as e: - self.llm_logger.error(f"fetching request error in worker-{thread_idx}: {e} {traceback.format_exc()}") - time.sleep(0.002) - - def _prepare_request_v1(self): - """Prepare request and send to the queue for scheduling""" - tracing.trace_set_thread_info("Prepare Request for Scheduling") - role = self.cfg.scheduler_config.splitwise_role - num_workers = envs.FD_PREFILL_PREPARE_REQ_THREAD_NUM if role == "prefill" else 1 - self.llm_logger.info(f"prepare request for scheduling, role: {role}, num_workers: {num_workers}") - - fetch_fn = { - "mixed": self._fetch_request_mixed, - "prefill": self._fetch_request_prefill, - "decode": self._fetch_request_decode, - }[role] - - self._fetch_threads = [] - for i in range(num_workers): - t = threading.Thread( - target=self._fetch_loop, - args=(fetch_fn, i), - daemon=True, - name=f"fetch-{i}", - ) - t.start() - self._fetch_threads.append(t) - - # Keep this thread alive for graceful shutdown - while self.running: - time.sleep(1.0) diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py index 509f9a768d9..cdcfd5b4092 100644 --- a/fastdeploy/envs.py +++ b/fastdeploy/envs.py @@ -191,8 +191,6 @@ def _validate_split_kv_size(value: int) -> int: "FD_HPU_CHUNK_SIZE": lambda: int(os.getenv("FD_HPU_CHUNK_SIZE", "64")), # "Enable FP8 calibration on HPU" "FD_HPU_MEASUREMENT_MODE": lambda: os.getenv("FD_HPU_MEASUREMENT_MODE", "0"), - # Number of worker threads for prepare requests in prefill instance - "FD_PREFILL_PREPARE_REQ_THREAD_NUM": lambda: int(os.getenv("FD_PREFILL_PREPARE_REQ_THREAD_NUM", "5")), "FD_PREFILL_WAIT_DECODE_RESOURCE_SECONDS": lambda: int(os.getenv("FD_PREFILL_WAIT_DECODE_RESOURCE_SECONDS", "30")), "FD_ENABLE_REQUEST_DISCONNECT_STOP_INFERENCE": lambda: int( os.getenv("FD_ENABLE_REQUEST_DISCONNECT_STOP_INFERENCE", "1") diff --git a/fastdeploy/inter_communicator/engine_worker_queue.py b/fastdeploy/inter_communicator/engine_worker_queue.py index 2cb1246aad3..b0fc9bb3385 100644 --- a/fastdeploy/inter_communicator/engine_worker_queue.py +++ b/fastdeploy/inter_communicator/engine_worker_queue.py @@ -92,6 +92,7 @@ class QueueManager(BaseManager): Value("i", 0) for _ in range(self.local_data_parallel_size) ] self.finished_req_list = [list() for _ in range(self.local_data_parallel_size)] + self.finished_add_cache_task_list = [list() for _ in range(self.local_data_parallel_size)] self.cache_infos_init: List[List[Any]] = [list() for _ in range(self.local_data_parallel_size)] self.connect_rdma_tasks_list = [list() for _ in range(self.local_data_parallel_size)] self.connect_rdma_tasks_response_list = [list() for _ in range(self.local_data_parallel_size)] @@ -109,6 +110,9 @@ class QueueManager(BaseManager): self.connect_task_response_lock_init: List[threading.Lock] = [ threading.Lock() for _ in range(self.local_data_parallel_size) ] # connect rdma task response + self.finish_add_cache_task_lock_init: List[threading.Lock] = [ + threading.Lock() for _ in range(self.local_data_parallel_size) + ] # finish add cache task self.finish_send_cache_lock_init: List[threading.Lock] = [ threading.Lock() for _ in range(self.local_data_parallel_size) ] # finish send cache @@ -120,12 +124,18 @@ class QueueManager(BaseManager): self.client_get_connect_task_response_flag_init: List[List[int]] = [ [0] * self.num_client for _ in range(self.local_data_parallel_size) ] + self.client_get_finished_add_cache_task_flag_init: List[List[int]] = [ + [0] * self.num_client for _ in range(self.local_data_parallel_size) + ] self.client_get_finish_send_cache_flag_init: List[List[int]] = [ [0] * self.num_client for _ in range(self.local_data_parallel_size) ] self.can_put_next_connect_task_response_flag_init: List[Value] = [ Value("i", 1) for _ in range(self.local_data_parallel_size) ] + self.can_put_next_add_task_finished_flag_init: List[Value] = [ + Value("i", 1) for _ in range(self.local_data_parallel_size) + ] self.can_put_next_send_cache_finished_flag_init: List[Value] = [ Value("i", 1) for _ in range(self.local_data_parallel_size) ] @@ -137,6 +147,9 @@ class QueueManager(BaseManager): self.get_connect_task_response_barrier = [ threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size) ] + self.finish_add_cache_task_barrier = [ + threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size) + ] self.begin_send_cache_barrier = [ threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size) ] @@ -175,6 +188,11 @@ class QueueManager(BaseManager): callable=lambda idx: self.client_get_connect_task_response_flag_init[idx], proxytype=ListProxy, ) + QueueManager.register( + "get_client_get_finished_add_cache_task_flag_init", + callable=lambda idx: self.client_get_finished_add_cache_task_flag_init[idx], + proxytype=ListProxy, + ) QueueManager.register( "get_client_get_finish_send_cache_flag_init", callable=lambda idx: self.client_get_finish_send_cache_flag_init[idx], @@ -200,6 +218,11 @@ class QueueManager(BaseManager): callable=lambda idx: self.can_put_next_connect_task_response_flag_init[idx], proxytype=ValueProxy, ) + QueueManager.register( + "get_can_put_next_add_task_finished_flag", + callable=lambda idx: self.can_put_next_add_task_finished_flag_init[idx], + proxytype=ValueProxy, + ) QueueManager.register( "get_can_put_next_send_cache_finished_flag", callable=lambda idx: self.can_put_next_send_cache_finished_flag_init[idx], @@ -216,6 +239,11 @@ class QueueManager(BaseManager): callable=lambda idx: self.connect_task_response_lock_init[idx], proxytype=AcquirerProxy, ) + QueueManager.register( + "get_finish_add_cache_task_lock", + callable=lambda idx: self.finish_add_cache_task_lock_init[idx], + proxytype=AcquirerProxy, + ) QueueManager.register( "get_finish_send_cache_lock", callable=lambda idx: self.finish_send_cache_lock_init[idx], @@ -240,6 +268,12 @@ class QueueManager(BaseManager): "get_finish_request_queue", callable=lambda idx: self.finished_req_list[idx], proxytype=ListProxy ) + QueueManager.register( + "get_finish_add_cache_task_queue", + callable=lambda idx: self.finished_add_cache_task_list[idx], + proxytype=ListProxy, + ) + QueueManager.register( "get_cache_infos", callable=lambda idx: self.cache_infos_init[idx], @@ -287,6 +321,12 @@ class QueueManager(BaseManager): "get_cache_info_barrier", callable=lambda idx: self.get_cache_info_barrier[idx], ) + + QueueManager.register( + "get_finish_add_cache_task_barrier", + callable=lambda idx: self.finish_add_cache_task_barrier[idx], + ) + QueueManager.register( "get_worker_process_tp_barrier", callable=lambda idx: self.worker_process_tp_barrier[idx], @@ -311,11 +351,13 @@ class QueueManager(BaseManager): QueueManager.register("get_exist_tasks_inter_signal") QueueManager.register("get_connected_client_counter") QueueManager.register("get_finish_request_queue") + QueueManager.register("get_finish_add_cache_task_queue") QueueManager.register("get_cache_infos") QueueManager.register("get_client_read_info_flag") QueueManager.register("get_lock_info") QueueManager.register("get_disaggregate_requests") QueueManager.register("get_finish_request_barrier") + QueueManager.register("get_finish_add_cache_task_barrier") QueueManager.register("get_connect_task_barrier") QueueManager.register("get_connect_task_response_barrier") QueueManager.register("get_finish_send_cache_barrier") @@ -324,13 +366,16 @@ class QueueManager(BaseManager): QueueManager.register("get_connect_rdma_tasks") QueueManager.register("get_client_get_connect_task_flag") QueueManager.register("get_client_get_connect_task_response_flag") + QueueManager.register("get_client_get_finished_add_cache_task_flag_init") QueueManager.register("get_client_get_finish_send_cache_flag_init") QueueManager.register("get_connect_rdma_tasks_responses") QueueManager.register("get_connect_task_lock") QueueManager.register("get_connect_task_response_lock") + QueueManager.register("get_finish_add_cache_task_lock") QueueManager.register("get_finish_send_cache_lock") QueueManager.register("get_worker_process_tp_barrier") QueueManager.register("get_can_put_next_connect_task_response_flag") + QueueManager.register("get_can_put_next_add_task_finished_flag") QueueManager.register("get_can_put_next_send_cache_finished_flag") self.manager = QueueManager(address=self.address, authkey=self.authkey) self._connect_with_retry() @@ -353,6 +398,9 @@ class QueueManager(BaseManager): # p/d 分离获取 self.disaggregate_requests = self.manager.get_disaggregate_requests(self.local_data_parallel_id) self.finish_request_barrier = self.manager.get_finish_request_barrier(self.local_data_parallel_id) + self.finish_add_cache_task_barrier = self.manager.get_finish_add_cache_task_barrier( + self.local_data_parallel_id + ) self.connect_task_barrier = self.manager.get_connect_task_barrier(self.local_data_parallel_id) self.connect_task_response_barrier = self.manager.get_connect_task_response_barrier( self.local_data_parallel_id @@ -362,6 +410,9 @@ class QueueManager(BaseManager): self.begin_send_cache_barrier = self.manager.get_begin_send_cache_barrier(self.local_data_parallel_id) self.worker_process_tp_barrier = self.manager.get_worker_process_tp_barrier(self.local_data_parallel_id) self.finished_send_cache_list = self.manager.get_finish_request_queue(self.local_data_parallel_id) + self.finished_add_cache_task_list = self.manager.get_finish_add_cache_task_queue( + self.local_data_parallel_id + ) # p/d互联 self.connect_rdma_tasks = self.manager.get_connect_rdma_tasks(self.local_data_parallel_id) self.client_get_connect_task_flag = self.manager.get_client_get_connect_task_flag( @@ -370,6 +421,9 @@ class QueueManager(BaseManager): self.client_get_connect_task_response_flag = self.manager.get_client_get_connect_task_response_flag( self.local_data_parallel_id ) + self.client_get_finished_add_cache_task_flag = ( + self.manager.get_client_get_finished_add_cache_task_flag_init(self.local_data_parallel_id) + ) self.client_get_finish_send_cache_flag = self.manager.get_client_get_finish_send_cache_flag_init( self.local_data_parallel_id ) @@ -379,8 +433,12 @@ class QueueManager(BaseManager): ) self.connect_task_lock = self.manager.get_connect_task_lock(self.local_data_parallel_id) self.connect_task_response_lock = self.manager.get_connect_task_response_lock(self.local_data_parallel_id) + self.finish_add_cache_task_lock = self.manager.get_finish_add_cache_task_lock(self.local_data_parallel_id) self.finish_send_cache_lock = self.manager.get_finish_send_cache_lock(self.local_data_parallel_id) + self.can_put_next_add_task_finished_flag = self.manager.get_can_put_next_add_task_finished_flag( + self.local_data_parallel_id + ) self.can_put_next_connect_task_response_flag = self.manager.get_can_put_next_connect_task_response_flag( self.local_data_parallel_id ) @@ -698,6 +756,54 @@ def get_finished_req(self) -> str: self.finish_send_cache_lock.release() return response + def put_finished_add_cache_task_req(self, req_ids) -> None: + """ + Put finished request ID into the queue. + + Args: + req_ids: Request ID to be added to the queue + """ + self.finish_add_cache_task_lock.acquire() + while not self.can_put_next_add_task_finished_flag.get(): + self.finish_add_cache_task_lock.release() + time.sleep(0.001) + self.finish_add_cache_task_lock.acquire() + self.finished_add_cache_task_list.append(req_ids) + self.client_get_finished_add_cache_task_flag[self.client_id] = 1 + all_client_put: bool = np.sum(self.client_get_finished_add_cache_task_flag) == self.num_client + if all_client_put: + self.can_put_next_add_task_finished_flag.set(0) + self.finish_add_cache_task_lock.release() + return all_client_put + + def get_finished_add_cache_task_req(self) -> str: + """ + Get finished request ID from the queue. + + Returns: + str: Finished request ID + """ + response = [] + self.finish_add_cache_task_lock.acquire() + if len(self.finished_add_cache_task_list) == 0: + self.finish_add_cache_task_lock.release() + return response + while sum(self.client_get_finished_add_cache_task_flag) < self.num_client: + self.finish_add_cache_task_lock.release() + time.sleep(0.001) + self.finish_add_cache_task_lock.acquire() + if len(self.finished_add_cache_task_list) > 0: + response = self.finished_add_cache_task_list[0] + for tmp_response in self.finished_add_cache_task_list: + assert ( + tmp_response == response + ), f"Inconsistent responses across workers: expected {response}, got {tmp_response}" + self.finished_add_cache_task_list[:] = list() + self.client_get_finished_add_cache_task_flag[:] = [0] * self.num_client + self.can_put_next_add_task_finished_flag.set(1) + self.finish_add_cache_task_lock.release() + return response + def disaggregate_queue_empty(self): """ Check if the disaggregated task queue is empty. diff --git a/tests/cache_manager/test_cache_messager.py b/tests/cache_manager/test_cache_messager.py index 3e415ebe9c8..c69d27a24fa 100644 --- a/tests/cache_manager/test_cache_messager.py +++ b/tests/cache_manager/test_cache_messager.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import importlib.util -import os import sys import types @@ -24,21 +22,7 @@ if not hasattr(paddle, "enable_compat"): paddle.enable_compat = lambda *args, **kwargs: None -# Import the legacy cache_messager module directly from the .py file, -# because the cache_messager/ package shadows it and the legacy -# fallback (cache_messager_legacy) does not exist locally. -_cm_py_path = os.path.join( - os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), - "fastdeploy", - "cache_manager", - "cache_messager.py", -) -_spec = importlib.util.spec_from_file_location( - "fastdeploy.cache_manager.cache_messager_py", - _cm_py_path, -) -cache_messager = importlib.util.module_from_spec(_spec) -_spec.loader.exec_module(cache_messager) +from fastdeploy.cache_manager import cache_messager class _DummyBarrier: @@ -56,10 +40,12 @@ def __init__(self, cache_info_sequence=None, connect_task_sequence=None, **kwarg self.cache_info_calls = 0 self.connect_task_calls = 0 self.cache_info_barrier = _DummyBarrier() + self.finish_add_cache_task_barrier = _DummyBarrier() self.finish_send_cache_barrier = _DummyBarrier() self.connect_task_barrier = _DummyBarrier() self.connect_task_response_barrier = _DummyBarrier() self.begin_send_cache_barrier = _DummyBarrier() + self.finished_add_cache_task_req_ids = [] self.finished_req_payloads = [] self.connect_task_responses = [] @@ -70,6 +56,9 @@ def get_cache_info(self): self.cache_info_calls += 1 return info + def put_finished_add_cache_task_req(self, req_ids): + self.finished_add_cache_task_req_ids.append(req_ids) + def put_finished_req(self, payload): self.finished_req_payloads.append(payload) @@ -387,6 +376,7 @@ def test_cache_messager_v1_add_cache_task_thread(monkeypatch): } with pytest.raises(SystemExit): messager._add_cache_task_thread() + assert dummy_queue.finished_add_cache_task_req_ids == [["req-2"]] assert messager.cache_info["req-2"]["status"] == "init" diff --git a/tests/engine/test_common_engine.py b/tests/engine/test_common_engine.py index 7e7c660964b..63c8c5165ae 100644 --- a/tests/engine/test_common_engine.py +++ b/tests/engine/test_common_engine.py @@ -579,7 +579,6 @@ def test_start_prefill_branch_cache_manager_and_worker_dead(self): eng._process_splitwise_task = lambda: None eng._schedule_request_to_worker = lambda: None eng._schedule_request_to_worker_v1 = lambda: None - eng._prepare_request_v1 = lambda: None started_cache = {} @@ -625,7 +624,6 @@ def test_start_mixed_branch_cache_after_load_and_zmq(self): eng._process_splitwise_task = lambda: None eng._schedule_request_to_worker = lambda: None eng._schedule_request_to_worker_v1 = lambda: None - eng._prepare_request_v1 = lambda: None started_cache = {} @@ -1388,18 +1386,21 @@ def test_schedule_request_to_worker_v1_mixed_single_iteration(self): task = Request(request_id="v1_r0", prompt_token_ids=[1], prompt_token_ids_len=1) task.metrics.scheduler_recv_req_time = time.time() - eng.scheduler = Mock(put_results=Mock()) + eng.scheduler = Mock(get_requests=Mock(return_value=[task]), put_results=Mock()) eng.engine_worker_queue = Mock(exist_tasks=Mock(return_value=False), put_tasks=Mock()) - eng.resource_manager = self._make_v1_decode_rm(eng, ([task], []), with_add_request=True) + eng.resource_manager = self._make_v1_decode_rm(eng, ([], []), with_add_request=True) try: - with patch("fastdeploy.engine.common_engine.time.sleep", lambda *_: None): + with ( + patch("fastdeploy.engine.common_engine.ThreadPoolExecutor", self._make_dummy_executor(eng)), + patch("fastdeploy.engine.common_engine.time.sleep", lambda *_: None), + ): eng._schedule_request_to_worker_v1() finally: eng.running = False - eng.engine_worker_queue.put_tasks.assert_called_once() + eng.resource_manager.add_request.assert_called_once_with(task) self._detach_finalizer(eng) def test_schedule_request_to_worker_v1_prefill_decode_alloc_error_safe(self): @@ -1419,6 +1420,7 @@ def test_schedule_request_to_worker_v1_prefill_decode_alloc_error_safe(self): eng.scheduler = Mock(get_requests=Mock(return_value=[task]), put_results=Mock()) eng.engine_worker_queue = Mock( exist_tasks=Mock(return_value=False), + get_finished_add_cache_task_req=Mock(return_value=[]), ) eng.resource_manager = self._make_v1_prefill_continuous_rm(eng, waiting_async_result=False) @@ -1430,13 +1432,11 @@ def test_schedule_request_to_worker_v1_prefill_decode_alloc_error_safe(self): try: with ( - patch( - "fastdeploy.engine.common_engine_prepare_mixin.envs.PREFILL_CONTINUOUS_REQUEST_DECODE_RESOURCES", - False, - ), - patch("fastdeploy.engine.common_engine_prepare_mixin.time.sleep", lambda *_: None), + patch("fastdeploy.engine.common_engine.envs.PREFILL_CONTINUOUS_REQUEST_DECODE_RESOURCES", False), + patch("fastdeploy.engine.common_engine.ThreadPoolExecutor", self._make_dummy_executor(eng)), + patch("fastdeploy.engine.common_engine.time.sleep", lambda *_: None), ): - eng._fetch_request_prefill() + eng._schedule_request_to_worker_v1() finally: eng.running = False @@ -1457,14 +1457,17 @@ def test_schedule_request_to_worker_v1_decode_preempted_and_errors(self): task.task_type = RequestType.PREEMPTED task.metrics.scheduler_recv_req_time = time.time() - eng.scheduler = Mock(put_results=Mock()) + eng.scheduler = Mock(get_requests=Mock(return_value=[]), put_results=Mock()) eng.engine_worker_queue = Mock(exist_tasks=Mock(return_value=False), put_tasks=Mock()) eng._send_error_response = Mock() eng.resource_manager = self._make_v1_decode_rm(eng, ([task], [("rid_x", None), ("rid_y", "bad")])) try: - with patch("fastdeploy.engine.common_engine.time.sleep", lambda *_: None): + with ( + patch("fastdeploy.engine.common_engine.ThreadPoolExecutor", self._make_dummy_executor(eng)), + patch("fastdeploy.engine.common_engine.time.sleep", lambda *_: None), + ): eng._schedule_request_to_worker_v1() finally: eng.running = False @@ -1488,13 +1491,16 @@ def test_schedule_request_to_worker_v1_decode_prefill_task_path(self): task.trace_carrier = {} task.metrics.scheduler_recv_req_time = time.time() - eng.scheduler = Mock(put_results=Mock()) + eng.scheduler = Mock(get_requests=Mock(return_value=[]), put_results=Mock()) eng.engine_worker_queue = Mock(exist_tasks=Mock(return_value=False), put_tasks=Mock()) eng.resource_manager = self._make_v1_decode_rm(eng, ([task], [])) try: - with patch("fastdeploy.engine.common_engine.time.sleep", lambda *_: None): + with ( + patch("fastdeploy.engine.common_engine.ThreadPoolExecutor", self._make_dummy_executor(eng)), + patch("fastdeploy.engine.common_engine.time.sleep", lambda *_: None), + ): eng._schedule_request_to_worker_v1() finally: eng.running = False @@ -1516,20 +1522,23 @@ def test_schedule_request_to_worker_v1_error_task_none_skips_send(self): task.trace_carrier = {} task.metrics.scheduler_recv_req_time = time.time() - eng.scheduler = Mock(put_results=Mock()) + eng.scheduler = Mock(get_requests=Mock(return_value=[]), put_results=Mock()) eng.engine_worker_queue = Mock(exist_tasks=Mock(return_value=False), put_tasks=Mock()) eng._send_error_response = Mock() eng.resource_manager = self._make_v1_decode_rm(eng, ([task], [("rid_none", None)])) - with patch("fastdeploy.engine.common_engine.time.sleep", lambda *_: None): + with ( + patch("fastdeploy.engine.common_engine.ThreadPoolExecutor", self._make_dummy_executor(eng)), + patch("fastdeploy.engine.common_engine.time.sleep", lambda *_: None), + ): eng._schedule_request_to_worker_v1() eng.engine_worker_queue.put_tasks.assert_called_once() eng._send_error_response.assert_not_called() self._detach_finalizer(eng) - def test_schedule_request_to_worker_v1_no_tasks_sleeps(self): + def test_schedule_request_to_worker_v1_threadpool_shutdown_breaks(self): eng = self._make_mixed_engine() self._setup_v1_engine(eng) @@ -1537,7 +1546,17 @@ def test_schedule_request_to_worker_v1_no_tasks_sleeps(self): eng.resource_manager = self._make_v1_decode_rm(eng, ([], [])) - with patch("fastdeploy.engine.common_engine.time.sleep", lambda *_: None): + class DummyExecutor: + def __init__(self, max_workers=None): + pass + + def submit(self, fn): + raise RuntimeError("cannot schedule new futures after shutdown") + + with ( + patch("fastdeploy.engine.common_engine.ThreadPoolExecutor", DummyExecutor), + patch("fastdeploy.engine.common_engine.time.sleep", lambda *_: None), + ): eng._schedule_request_to_worker_v1() self._detach_finalizer(eng) @@ -1560,8 +1579,17 @@ def test_schedule_request_to_worker_v1_prefill_continuous_cache_success(self): eng.resource_manager = self._make_v1_prefill_continuous_rm(eng, waiting_async_result=False) + calls = {"n": 0} + + def get_finished_add_cache_task_req(): + if calls["n"] == 0: + calls["n"] += 1 + return ["pc_ok"] + return [] + eng.engine_worker_queue = Mock( exist_tasks=Mock(return_value=False), + get_finished_add_cache_task_req=Mock(side_effect=get_finished_add_cache_task_req), ) eng.split_connector = Mock( @@ -1571,12 +1599,11 @@ def test_schedule_request_to_worker_v1_prefill_continuous_cache_success(self): ) with ( - patch( - "fastdeploy.engine.common_engine_prepare_mixin.envs.PREFILL_CONTINUOUS_REQUEST_DECODE_RESOURCES", True - ), - patch("fastdeploy.engine.common_engine_prepare_mixin.time.sleep", lambda *_: None), + patch("fastdeploy.engine.common_engine.envs.PREFILL_CONTINUOUS_REQUEST_DECODE_RESOURCES", True), + patch("fastdeploy.engine.common_engine.ThreadPoolExecutor", self._make_dummy_executor(eng)), + patch("fastdeploy.engine.common_engine.time.sleep", lambda *_: None), ): - eng._fetch_request_prefill() + eng._schedule_request_to_worker_v1() eng.split_connector.send_splitwise_tasks.assert_called() eng.split_connector.send_cache_info_to_messager.assert_called_once() @@ -1604,8 +1631,17 @@ def test_schedule_request_to_worker_v1_prefill_continuous_wait_async_none(self): eng.resource_manager = self._make_v1_prefill_continuous_rm(eng, waiting_async_result=None) + calls = {"n": 0} + + def get_finished_add_cache_task_req(): + if calls["n"] == 0: + calls["n"] += 1 + return ["pc_fail"] + return [] + eng.engine_worker_queue = Mock( exist_tasks=Mock(return_value=False), + get_finished_add_cache_task_req=Mock(side_effect=get_finished_add_cache_task_req), ) eng.split_connector = Mock( @@ -1615,12 +1651,11 @@ def test_schedule_request_to_worker_v1_prefill_continuous_wait_async_none(self): ) with ( - patch( - "fastdeploy.engine.common_engine_prepare_mixin.envs.PREFILL_CONTINUOUS_REQUEST_DECODE_RESOURCES", True - ), - patch("fastdeploy.engine.common_engine_prepare_mixin.time.sleep", lambda *_: None), + patch("fastdeploy.engine.common_engine.envs.PREFILL_CONTINUOUS_REQUEST_DECODE_RESOURCES", True), + patch("fastdeploy.engine.common_engine.ThreadPoolExecutor", self._make_dummy_executor(eng)), + patch("fastdeploy.engine.common_engine.time.sleep", lambda *_: None), ): - eng._fetch_request_prefill() + eng._schedule_request_to_worker_v1() eng.scheduler.put_results.assert_called_once() eng.resource_manager.pre_recycle_resource.assert_called_once_with("pc_fail") diff --git a/tests/inter_communicator/test_e2w_queue.py b/tests/inter_communicator/test_e2w_queue.py index 333249cc66d..d3cd657f01a 100644 --- a/tests/inter_communicator/test_e2w_queue.py +++ b/tests/inter_communicator/test_e2w_queue.py @@ -301,15 +301,15 @@ def test_wait_loops_and_tensor_conversion(self): client.get_finished_req() thread.join() - client.can_put_next_send_cache_finished_flag.set(0) - thread = self._set_value_after_delay(client.can_put_next_send_cache_finished_flag, 1) - client.put_finished_req([["req-wait", {"status": "ok"}]]) + client.can_put_next_add_task_finished_flag.set(0) + thread = self._set_value_after_delay(client.can_put_next_add_task_finished_flag, 1) + client.put_finished_add_cache_task_req(["req-wait"]) thread.join() - client.finished_send_cache_list.append(["req-wait", {"error": "bad"}]) - client.client_get_finish_send_cache_flag[:] = [0] - thread = self._set_list_after_delay(client.client_get_finish_send_cache_flag, [1]) - client.get_finished_req() + client.finished_add_cache_task_list.append(["req-wait"]) + client.client_get_finished_add_cache_task_flag[:] = [0] + thread = self._set_list_after_delay(client.client_get_finished_add_cache_task_flag, [1]) + client.get_finished_add_cache_task_req() thread.join() finally: paddle.set_device(previous_device) @@ -361,6 +361,18 @@ def test_finished_req_flow(self): finally: self._cleanup_queue_pair(server) + def test_finished_add_cache_task_req(self): + server, client = self._build_queue_pair() + try: + req_ids = ["req-2"] + self.assertTrue(client.put_finished_add_cache_task_req(req_ids)) + client.finished_add_cache_task_list.append(req_ids) + self.assertEqual(client.get_finished_add_cache_task_req(), req_ids) + self.assertEqual(client.get_finished_add_cache_task_req(), []) + self.assertEqual(client.can_put_next_add_task_finished_flag.get(), 1) + finally: + self._cleanup_queue_pair(server) + def test_disaggregated_queue(self): server, client = self._build_queue_pair() try: