diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py index 50017baf5de..e0d316d356d 100644 --- a/fastdeploy/engine/common_engine.py +++ b/fastdeploy/engine/common_engine.py @@ -1166,6 +1166,10 @@ def _fetch_request(): except Exception as e: err_msg = "Error happened while insert task to engine: {}, {}.".format(e, str(traceback.format_exc())) self.llm_logger.error(err_msg) + # Failed to connect to engine worker queue, retry after 5 seconds + if self.engine_worker_queue.is_broken(): + self.llm_logger.error("Failed to connect to engine worker queue, retry after 5 seconds") + time.sleep(5) def _get_scheduler_unhandled_request_num(self) -> int: """ diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py index 7e0f809d5d3..08cef5849e1 100644 --- a/fastdeploy/envs.py +++ b/fastdeploy/envs.py @@ -178,8 +178,8 @@ def _validate_split_kv_size(value: int) -> int: "PREFILL_CONTINUOUS_REQUEST_DECODE_RESOURCES": lambda: int( os.getenv("PREFILL_CONTINUOUS_REQUEST_DECODE_RESOURCES", "1") ), - "FD_ENABLE_E2W_TENSOR_CONVERT": lambda: int(os.getenv("FD_ENABLE_E2W_TENSOR_CONVERT", "0")), - "FD_ENGINE_TASK_QUEUE_WITH_SHM": lambda: int(os.getenv("FD_ENGINE_TASK_QUEUE_WITH_SHM", "0")), + "FD_ENABLE_E2W_TENSOR_CONVERT": lambda: int(os.getenv("FD_ENABLE_E2W_TENSOR_CONVERT", "1")), + "FD_ENGINE_TASK_QUEUE_WITH_SHM": lambda: int(os.getenv("FD_ENGINE_TASK_QUEUE_WITH_SHM", "1")), "FD_FILL_BITMASK_BATCH": lambda: int(os.getenv("FD_FILL_BITMASK_BATCH", "4")), "FD_ENABLE_PDL": lambda: int(os.getenv("FD_ENABLE_PDL", "1")), "FD_ENABLE_ASYNC_LLM": lambda: int(os.getenv("FD_ENABLE_ASYNC_LLM", "0")), diff --git a/fastdeploy/inter_communicator/engine_worker_queue.py b/fastdeploy/inter_communicator/engine_worker_queue.py index a7876669f8f..b0fc9bb3385 100644 --- a/fastdeploy/inter_communicator/engine_worker_queue.py +++ b/fastdeploy/inter_communicator/engine_worker_queue.py @@ -848,3 +848,13 @@ def cleanup(self): """ if self.manager is not None and self.is_server: 
self.manager.shutdown() + + def is_broken(self): + try: + self.manager.connect() + return False + except (ConnectionRefusedError, ConnectionResetError, BrokenPipeError, EOFError, OSError): + llm_logger.error("Failed to connect to engine worker queue") + return True + except Exception: + return False diff --git a/fastdeploy/utils.py b/fastdeploy/utils.py index 6c0b72ae8ae..965fca6d96c 100644 --- a/fastdeploy/utils.py +++ b/fastdeploy/utils.py @@ -626,6 +626,12 @@ def is_port_available(host, port): import errno import socket + # If FD_ENGINE_TASK_QUEUE_WITH_SHM is enabled, then check that the file socket is available + if envs.FD_ENGINE_TASK_QUEUE_WITH_SHM: + socket_path = f"/dev/shm/fd_task_queue_{port}.sock" + if not is_file_socket_available(socket_path): + return False + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: try: s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) @@ -637,6 +643,35 @@ def is_port_available(host, port): return True +def is_file_socket_available(socket_path): + """ + Check whether the Unix domain socket (file socket) is available. + + Args: + socket_path: Path to the socket file, e.g. /dev/shm/fd_task_queue_8000.sock + + Returns: + True if the socket is available (not in use), False otherwise.
+ """ + import errno + import os + import socket + + if not os.path.exists(socket_path): + return True + + # File exists, try to connect to see if someone is listening + with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as s: + try: + s.connect(socket_path) + return False + except OSError as e: + if e.errno in (errno.ECONNREFUSED, errno.ENOENT): + # Stale socket file: exists but nobody is listening + return True + return False + + def find_free_ports( port_range: tuple[int, int] = (8000, 65535), num_ports: int = 1, diff --git a/tests/ci_use/EB_Lite_with_adapter/test_eblite_serving.py b/tests/ci_use/EB_Lite_with_adapter/test_eblite_serving.py index 6d8dfac53fd..7c3c6657434 100644 --- a/tests/ci_use/EB_Lite_with_adapter/test_eblite_serving.py +++ b/tests/ci_use/EB_Lite_with_adapter/test_eblite_serving.py @@ -16,7 +16,6 @@ import queue import shutil import signal -import socket import subprocess import sys import time @@ -30,6 +29,7 @@ sys.path.insert(0, project_root) from ci_use.EB_Lite_with_adapter.zmq_client import LLMControlClient, LLMReqClient +from e2e.utils.serving_utils import clean_ports, is_port_open env = os.environ.copy() @@ -79,88 +79,6 @@ def zmq_control_client(): return client -def is_port_open(host: str, port: int, timeout=1.0): - """ - Check if a TCP port is open on the given host. - Returns True if connection succeeds, False otherwise. - """ - try: - with socket.create_connection((host, port), timeout): - return True - except Exception: - return False - - -def kill_process_on_port(port: int): - """ - Kill processes that are listening on the given port. - Uses multiple methods to ensure thorough cleanup. 
- """ - current_pid = os.getpid() - parent_pid = os.getppid() - - # Method 1: Use lsof to find processes - try: - output = subprocess.check_output(f"lsof -i:{port} -t", shell=True).decode().strip() - for pid in output.splitlines(): - pid = int(pid) - if pid in (current_pid, parent_pid): - print(f"Skip killing current process (pid={pid}) on port {port}") - continue - try: - # First try SIGTERM for graceful shutdown - os.kill(pid, signal.SIGTERM) - time.sleep(1) - # Then SIGKILL if still running - os.kill(pid, signal.SIGKILL) - print(f"Killed process on port {port}, pid={pid}") - except ProcessLookupError: - pass # Process already terminated - except subprocess.CalledProcessError: - pass - - # Method 2: Use netstat and fuser as backup - try: - # Find processes using netstat and awk - cmd = f"netstat -tulpn 2>/dev/null | grep :{port} | awk '{{print $7}}' | cut -d'/' -f1" - output = subprocess.check_output(cmd, shell=True).decode().strip() - for pid in output.splitlines(): - if pid and pid.isdigit(): - pid = int(pid) - if pid in (current_pid, parent_pid): - continue - try: - os.kill(pid, signal.SIGKILL) - print(f"Killed process (netstat) on port {port}, pid={pid}") - except ProcessLookupError: - pass - except (subprocess.CalledProcessError, FileNotFoundError): - pass - - # Method 3: Use fuser if available - try: - subprocess.run(f"fuser -k {port}/tcp", shell=True, timeout=5) - except (subprocess.TimeoutExpired, subprocess.CalledProcessError, FileNotFoundError): - pass - - -def clean_ports(): - """ - Kill all processes occupying the ports listed in PORTS_TO_CLEAN. 
- """ - print(f"Cleaning ports: {PORTS_TO_CLEAN}") - for port in PORTS_TO_CLEAN: - kill_process_on_port(port) - - # Double check and retry if ports are still in use - time.sleep(2) - for port in PORTS_TO_CLEAN: - if is_port_open("127.0.0.1", port, timeout=0.1): - print(f"Port {port} still in use, retrying cleanup...") - kill_process_on_port(port) - time.sleep(1) - - @pytest.fixture(scope="session", autouse=True) def setup_and_run_server(): """ @@ -170,8 +88,15 @@ def setup_and_run_server(): - Waits for server port to open (up to 30 seconds) - Tears down server after all tests finish """ + # 清理/dev/shm中的临时文件 + try: + subprocess.run("rm -rf /dev/shm/*", shell=True) + print("Successfully cleaned up /dev/shm.") + except Exception as e: + print(f"Failed to cleanup /dev/shm: {e}") + print("Pre-test port cleanup...") - clean_ports() + clean_ports(PORTS_TO_CLEAN) base_path = os.getenv("MODEL_PATH") if base_path: @@ -236,7 +161,7 @@ def setup_and_run_server(): print("\n===== Post-test server cleanup... =====") try: os.killpg(process.pid, signal.SIGTERM) - clean_ports() + clean_ports(PORTS_TO_CLEAN) print(f"API server (pid={process.pid}) terminated") except Exception as e: print(f"Failed to terminate API server: {e}") diff --git a/tests/ci_use/Qwen2-7B-Instruct_offline/test_Qwen2-7B-Instruct_offline.py b/tests/ci_use/Qwen2-7B-Instruct_offline/test_Qwen2-7B-Instruct_offline.py index fde03d70ee1..b42799ce066 100644 --- a/tests/ci_use/Qwen2-7B-Instruct_offline/test_Qwen2-7B-Instruct_offline.py +++ b/tests/ci_use/Qwen2-7B-Instruct_offline/test_Qwen2-7B-Instruct_offline.py @@ -13,9 +13,7 @@ # limitations under the License. 
import os -import signal -import socket -import subprocess +import sys import time import traceback @@ -23,21 +21,17 @@ from fastdeploy import LLM, SamplingParams -FD_ENGINE_QUEUE_PORT = int(os.getenv("FD_ENGINE_QUEUE_PORT", 8313)) -FD_CACHE_QUEUE_PORT = int(os.getenv("FD_CACHE_QUEUE_PORT", 8333)) -MAX_WAIT_SECONDS = 60 - +current_dir = os.path.dirname(os.path.abspath(__file__)) +project_root = os.path.abspath(os.path.join(current_dir, "..", "..")) +sys.path.insert(0, project_root) +from e2e.utils.serving_utils import ( + FD_API_PORT, + FD_CACHE_QUEUE_PORT, + FD_ENGINE_QUEUE_PORT, + clean_ports, +) -def is_port_open(host: str, port: int, timeout=1.0): - """ - Check if a TCP port is open on the given host. - Returns True if connection succeeds, False otherwise. - """ - try: - with socket.create_connection((host, port), timeout): - return True - except Exception: - return False +MAX_WAIT_SECONDS = 60 def format_chat_prompt(messages): @@ -74,19 +68,15 @@ def llm(model_path): """ Fixture to initialize the LLM model with a given model path """ - try: - output = subprocess.check_output(f"lsof -i:{FD_ENGINE_QUEUE_PORT} -t", shell=True).decode().strip() - for pid in output.splitlines(): - os.kill(int(pid), signal.SIGKILL) - print(f"Killed process on port {FD_ENGINE_QUEUE_PORT}, pid={pid}") - except subprocess.CalledProcessError: - pass + # Clean ports before starting the test + clean_ports() try: start = time.time() llm = LLM( model=model_path, tensor_parallel_size=1, + port=FD_API_PORT, engine_worker_queue_port=FD_ENGINE_QUEUE_PORT, cache_queue_port=FD_CACHE_QUEUE_PORT, max_model_len=32768, @@ -94,15 +84,7 @@ def llm(model_path): logits_processors=["LogitBiasLogitsProcessor"], ) - # Wait for the port to be open - wait_start = time.time() - while not is_port_open("127.0.0.1", FD_ENGINE_QUEUE_PORT): - if time.time() - wait_start > MAX_WAIT_SECONDS: - pytest.fail( - f"Model engine did not start within {MAX_WAIT_SECONDS} seconds on port {FD_ENGINE_QUEUE_PORT}" - ) - 
time.sleep(1) - + time.sleep(2) print(f"Model loaded successfully from {model_path} in {time.time() - start:.2f}s.") yield llm except Exception: diff --git a/tests/e2e/utils/serving_utils.py b/tests/e2e/utils/serving_utils.py index 6dd5e77c9b7..9e47ca177e7 100644 --- a/tests/e2e/utils/serving_utils.py +++ b/tests/e2e/utils/serving_utils.py @@ -98,6 +98,60 @@ def kill_process_on_port(port: int): pass +def kill_process_by_unix_socket( + socket_path: str, + force: bool = True, +): + """ + 根据 unix socket 文件路径杀掉对应进程 + cmd: ss -xlpn | grep /dev/shm/fd_task_queue_8664.sock + Args: + socket_path: 例如 /dev/shm/fd_task_queue_8664.sock + force: + True -> SIGKILL + False -> SIGTERM + Returns: + pid 或 None + """ + try: + output = subprocess.check_output( + ["ss", "-xlpn"], + text=True, + ) + for line in output.splitlines(): + if socket_path not in line: + continue + m = re.search(r"pid=(\d+)", line) + if not m: + continue + pid = int(m.group(1)) + os.kill( + pid, + signal.SIGKILL if force else signal.SIGTERM, + ) + return pid + except Exception: + pass + return None + + +def cleanup_unix_socket(socket_path: str): + if not os.path.exists(socket_path): + return + try: + pid = kill_process_by_unix_socket(socket_path) + print(f"Killed process by unix socket: {socket_path}, pid={pid}") + except Exception as e: + print(f"Failed to kill process by unix socket: {socket_path}, error={e}") + finally: + try: + if os.path.exists(socket_path): + os.remove(socket_path) + print(f"Cleaned unix socket: {socket_path}") + except Exception: + pass + + def clean_ports(ports=None): """ Kill all processes occupying the ports @@ -117,6 +171,11 @@ def clean_ports(ports=None): kill_process_on_port(port) time.sleep(1) + # Clean unix socket, fd_task_queue_*.sock, for FD_ENGINE_TASK_QUEUE_WITH_SHM = 1 + print("Cleaning unix socket") + for port in ports: + cleanup_unix_socket(f"/dev/shm/fd_task_queue_{port}.sock") + def clean(ports=None): """ diff --git a/tests/utils/test_find_free_ports.py 
b/tests/utils/test_find_free_ports.py new file mode 100644 index 00000000000..3ffe272443e --- /dev/null +++ b/tests/utils/test_find_free_ports.py @@ -0,0 +1,212 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +from unittest.mock import patch + +import pytest + +from fastdeploy.utils import find_free_ports + + +class TestFindFreePorts: + """Unit tests for find_free_ports function.""" + + def test_find_single_free_port_success(self): + """Test finding a single free port successfully.""" + with patch("fastdeploy.utils.is_port_available", return_value=True): + ports = find_free_ports(port_range=(20000, 20100), num_ports=1) + assert len(ports) == 1 + assert 20000 <= ports[0] <= 20100 + + def test_find_multiple_free_ports_success(self): + """Test finding multiple free ports successfully.""" + with patch("fastdeploy.utils.is_port_available", return_value=True): + ports = find_free_ports(port_range=(20000, 20100), num_ports=5) + assert len(ports) == 5 + for port in ports: + assert 20000 <= port <= 20100 + + def test_find_ports_with_custom_host(self): + """Test finding ports with a custom host.""" + with patch("fastdeploy.utils.is_port_available", return_value=True) as mock_avail: + ports = find_free_ports(port_range=(30000, 30010), num_ports=2, host="127.0.0.1") + assert len(ports) == 2 + # Verify is_port_available was called with the custom host + for call in mock_avail.call_args_list: + assert
call[0][0] == "127.0.0.1" + + def test_find_all_ports_in_range(self): + """Test finding all ports in a small range.""" + with patch("fastdeploy.utils.is_port_available", return_value=True): + ports = find_free_ports(port_range=(40000, 40002), num_ports=3) + assert len(ports) == 3 + # All ports should be from the range + expected_ports = {40000, 40001, 40002} + assert set(ports) == expected_ports + + def test_invalid_port_range_start_negative(self): + """Test ValueError when port range start is negative.""" + with pytest.raises(ValueError, match="Invalid port range"): + find_free_ports(port_range=(-1, 1000)) + + def test_invalid_port_range_end_exceeds_max(self): + """Test ValueError when port range end exceeds 65535.""" + with pytest.raises(ValueError, match="Invalid port range"): + find_free_ports(port_range=(1000, 65536)) + + def test_invalid_port_range_start_greater_than_end(self): + """Test ValueError when port range start is greater than end.""" + with pytest.raises(ValueError, match="Invalid port range"): + find_free_ports(port_range=(10000, 9000)) + + def test_invalid_port_range_boundary_values(self): + """Test port range boundary at exactly 0 and 65535.""" + # Valid: start = 0 + with patch("fastdeploy.utils.is_port_available", return_value=True): + ports = find_free_ports(port_range=(0, 100), num_ports=1) + assert len(ports) == 1 + + # Valid: end = 65535 + with patch("fastdeploy.utils.is_port_available", return_value=True): + ports = find_free_ports(port_range=(65530, 65535), num_ports=1) + assert len(ports) == 1 + + def test_num_ports_zero_raises_error(self): + """Test ValueError when num_ports is zero.""" + with pytest.raises(ValueError, match="num_ports must be a positive integer"): + find_free_ports(port_range=(20000, 30000), num_ports=0) + + def test_num_ports_negative_raises_error(self): + """Test ValueError when num_ports is negative.""" + with pytest.raises(ValueError, match="num_ports must be a positive integer"): + 
find_free_ports(port_range=(20000, 30000), num_ports=-1) + + def test_num_ports_larger_than_range_size(self): + """Test ValueError when num_ports exceeds the range size.""" + # Range has only 5 ports (100-104), but requesting 6 + with pytest.raises(ValueError, match="num_ports is larger than range size"): + find_free_ports(port_range=(100, 104), num_ports=6) + + def test_not_enough_free_ports_raises_runtime_error(self): + """Test RuntimeError when not enough free ports are available.""" + # Mock to return False for all ports + with patch("fastdeploy.utils.is_port_available", return_value=False): + with pytest.raises(RuntimeError, match="Only found 0 free ports"): + find_free_ports(port_range=(20000, 20010), num_ports=3) + + def test_partial_free_ports_raises_runtime_error(self): + """Test RuntimeError when only some ports are free.""" + call_count = [0] + + def mock_availability(host, port): + # Only first 2 ports are available + call_count[0] += 1 + return call_count[0] <= 2 + + with patch("fastdeploy.utils.is_port_available", side_effect=mock_availability): + with pytest.raises(RuntimeError, match="Only found 2 free ports"): + find_free_ports(port_range=(20000, 20005), num_ports=5) + + def test_random_start_offset(self): + """Test that port scanning starts from a random offset.""" + # Track the order of ports checked + checked_ports = [] + + def mock_availability(host, port): + checked_ports.append(port) + return True + + with patch("fastdeploy.utils.is_port_available", side_effect=mock_availability): + with patch("fastdeploy.utils.random.randint", return_value=0): + ports = find_free_ports(port_range=(100, 105), num_ports=3) + + # With offset 0, ports should be checked in order + assert checked_ports[:3] == [100, 101, 102] + assert ports == [100, 101, 102] + + def test_random_start_offset_non_zero(self): + """Test port scanning with non-zero random offset.""" + checked_ports = [] + + def mock_availability(host, port): + checked_ports.append(port) + return True + 
+ with patch("fastdeploy.utils.is_port_available", side_effect=mock_availability): + # With offset 2, scanning starts from port 102 + with patch("fastdeploy.utils.random.randint", return_value=2): + ports = find_free_ports(port_range=(100, 105), num_ports=3) + + # With offset 2, ports are rotated: [102, 103, 104, 105, 100, 101] + assert checked_ports[:3] == [102, 103, 104] + assert ports == [102, 103, 104] + + def test_single_port_range(self): + """Test finding port from a single-port range.""" + with patch("fastdeploy.utils.is_port_available", return_value=True): + ports = find_free_ports(port_range=(12345, 12345), num_ports=1) + assert ports == [12345] + + def test_single_port_range_not_available(self): + """Test RuntimeError when the single port in range is not available.""" + with patch("fastdeploy.utils.is_port_available", return_value=False): + with pytest.raises(RuntimeError, match="Only found 0 free ports"): + find_free_ports(port_range=(12345, 12345), num_ports=1) + + def test_default_parameters(self): + """Test function with default parameters.""" + with patch("fastdeploy.utils.is_port_available", return_value=True): + ports = find_free_ports() + assert len(ports) == 1 + assert 8000 <= ports[0] <= 65535 + + def test_stops_early_when_enough_ports_found(self): + """Test that scanning stops as soon as enough ports are found.""" + checked_ports = [] + + def mock_availability(host, port): + checked_ports.append(port) + return True + + with patch("fastdeploy.utils.is_port_available", side_effect=mock_availability): + with patch("fastdeploy.utils.random.randint", return_value=0): + # Range has 100 ports but we only need 2 + ports = find_free_ports(port_range=(20000, 20099), num_ports=2) + + # Should only check 2 ports, not all 100 + assert len(checked_ports) == 2 + assert len(ports) == 2 + + def test_skips_unavailable_ports(self): + """Test that unavailable ports are skipped.""" + checked_ports = [] + + def mock_availability(host, port): + 
checked_ports.append(port) + # Only odd ports are available + return port % 2 == 1 + + with patch("fastdeploy.utils.is_port_available", side_effect=mock_availability): + with patch("fastdeploy.utils.random.randint", return_value=0): + ports = find_free_ports(port_range=(100, 110), num_ports=3) + + # Should find 3 odd ports: 101, 103, 105 + assert len(ports) == 3 + assert all(p % 2 == 1 for p in ports) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/xpu_ci/conftest.py b/tests/xpu_ci/conftest.py index ae0c95d727a..dc6e4d30262 100644 --- a/tests/xpu_ci/conftest.py +++ b/tests/xpu_ci/conftest.py @@ -101,6 +101,13 @@ def safe_kill_cmd(cmd): for cmd in commands: safe_kill_cmd(cmd) + try: + # 清理/dev/shm下的所有文件 + subprocess.run("rm -rf /dev/shm/*", shell=True, check=True) + except subprocess.CalledProcessError: + print("Failed to remove files from /dev/shm") + pass + def cleanup_resources(): """