diff --git a/.coverage b/.coverage deleted file mode 100644 index 759f0418..00000000 Binary files a/.coverage and /dev/null differ diff --git a/src/code/agent/.coverage b/src/code/agent/.coverage deleted file mode 100644 index f74d16dc..00000000 Binary files a/src/code/agent/.coverage and /dev/null differ diff --git a/src/code/agent/constants.py b/src/code/agent/constants.py index 6515fa40..95697d28 100644 --- a/src/code/agent/constants.py +++ b/src/code/agent/constants.py @@ -167,6 +167,12 @@ # 日志配置 LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO") +# 多租户/多用户模式配置 +ENABLE_COMFYUI_MULTI_USER = os.getenv('ENABLE_COMFYUI_MULTI_USER', '').lower() == 'true' + +# 用户身份识别相关常量 +HEADER_FUNART_COMFY_USERID = 'X-FunArt-Comfy-UserId' + class ERROR_CODE(Enum): UNCLASSIFY = "UNCLASSIFY" INVALID_PARAMS = "INVALID_PARAMS" diff --git a/src/code/agent/routes/gateway_routes.py b/src/code/agent/routes/gateway_routes.py index 118b53f0..8ea2440f 100644 --- a/src/code/agent/routes/gateway_routes.py +++ b/src/code/agent/routes/gateway_routes.py @@ -5,7 +5,7 @@ import time import traceback -from flask import Blueprint, Flask, jsonify, request +from flask import Blueprint, Flask, jsonify, request, g from flask_sock import Sock import websocket @@ -22,6 +22,7 @@ from services.gateway.handlers.userdata_handler import UserdataHandler from services.gateway.handlers.ws_handler import WsHandler from services.gateway.handlers.serverless_ws_handler import ServerlessWsHandler +from utils.user_identity import set_user_identity_or_default class GatewayRoutes: @@ -58,6 +59,7 @@ def register(self, app: Flask): def setup_routes(self): """设置所有路由""" self._register_backend_status_middleware() + self._register_user_identity_middleware() self._register_reboot_handler() # 只在 CPU 模式下注册这些路由 @@ -92,6 +94,12 @@ def check_backend_status(): status_code=500 ) + def _register_user_identity_middleware(self): + """注册用户身份识别中间件,在每个请求前识别用户""" + @self.bp.before_request + def identify_user(): + set_user_identity_or_default() + def _register_websocket(self): @self.sock.route("/ws") def comfyui_compatible_ws(ws): diff --git a/src/code/agent/routes/routes.py b/src/code/agent/routes/routes.py index 9674594a..7d7ed6f8 100644 --- a/src/code/agent/routes/routes.py +++ b/src/code/agent/routes/routes.py @@ -4,7 +4,7 @@ import threading import traceback -from flask import Flask, jsonify, request, Response +from flask import Flask, jsonify, request, Response, g from flask_sock import Sock import requests @@ -13,6 +13,7 @@ from services.management_service import ManagementService, Action, BackendStatus from utils.logger import log from utils.error_handler import ErrorResponse +from utils.user_identity import identify_user_or_default from .management_routes import ManagementRoutes from .serverless_api_routes import ServerlessApiRoutes from .gateway_routes import GatewayRoutes @@ -202,6 +203,7 @@ def do_save(result_queue, target_snapshot_name): @self.app.route("/", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"]) @self.app.route("/", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"]) + @identify_user_or_default def proxy(path=""): backend_status = self.management.service.status if backend_status not in (BackendStatus.RUNNING, BackendStatus.SAVING): @@ -216,10 +218,18 @@ def proxy(path=""): target_url = f"http://{constants.APP_HOST}{original_uri}" # print(f"Forwarding http request to path: {target_url}") + # 准备转发的 headers(添加用户标识用于多租户支持) + forward_headers = dict(request.headers) + forward_headers.pop(constants.HEADER_FUNART_COMFY_USERID, None) + + user_id = g.user_id + if user_id: + forward_headers[constants.HEADER_FUNART_COMFY_USERID] = user_id + resp = requests.request( method=request.method, url=target_url, - headers=dict(request.headers), + headers=forward_headers, params=request.args, data=request.get_data(), cookies=request.cookies, diff --git a/src/code/agent/services/gateway/handlers/history_handler.py b/src/code/agent/services/gateway/handlers/history_handler.py index a9a0e1c2..22cef950 100644 --- a/src/code/agent/services/gateway/handlers/history_handler.py +++ b/src/code/agent/services/gateway/handlers/history_handler.py @@ -29,347 +29,25 @@ def __init__(self): def handle_get_request(self): """处理 GET /api/history 请求""" - limit = self._parse_limit_param() - history = self._get_all_history_from_queue(limit) - log("DEBUG", f"Retrieved {len(history)} history items from queue") - return jsonify(history) - - def _get_all_history_from_queue(self, limit: Optional[int] = None) -> OrderedDict: - """ - 从队列中获取所有已完成任务的历史记录 - - Args: - limit: 可选的数量限制 - - Returns: - OrderedDict: prompt_id -> history_item - """ if not self._is_initialized(): - return OrderedDict() - - try: - # 获取并过滤已完成的任务 - ended_tasks = self._get_ended_tasks(limit) - - # 转换为历史记录格式 - result = OrderedDict() - total_count = len(ended_tasks) - for idx, task in enumerate(ended_tasks): - history_item = self._convert_task_to_history_item(task, total_count - idx) - if history_item: - prompt_id = list(history_item.keys())[0] - result[prompt_id] = history_item[prompt_id] - - return result - except Exception as e: - log("ERROR", f"Error getting all history from queue: {e}\n{traceback.format_exc()}") - return OrderedDict() - - def _convert_task_to_history_item(self, task, sequence_number: Optional[int] = None) -> Optional[Dict[str, Any]]: - """ - 将 Task 对象转换为 ComfyUI history 格式 - - Args: - task: Task 对象 - sequence_number: 可选的序号(用于排序) - - Returns: - Optional[Dict]: {prompt_id: {prompt, outputs, status, meta}} - """ - try: - prompt_id = self._extract_prompt_id(task) - outputs = self._build_outputs(task) - prompt_structure = self._build_prompt_structure(task, prompt_id, sequence_number) - status_obj = self._build_status_object(task, prompt_id) - meta = self._build_meta_object(outputs) - - return { - prompt_id: { - "prompt": prompt_structure, - "outputs": outputs, - "status": status_obj, - "meta": meta - } - } - except Exception as e: - log("ERROR", f"Error converting task {task.task_id} to history item: {e}\n{traceback.format_exc()}") - return None - - def _extract_prompt_id(self, task) -> str: - """ - 从 Task 对象中提取 prompt_id - - 优先级: - 1. task.final_status_data.data.prompt_id - 2. task.task_id (fallback) - """ - # 优先从 final_status_data 获取 - if task.final_status_data: - prompt_id = (task.final_status_data.get("data", {}) or {}).get("prompt_id") - if prompt_id: - return prompt_id - - # 最后使用 task_id 作为 fallback - return task.task_id - - def _build_outputs(self, task) -> Dict[str, Dict[str, List[Dict[str, str]]]]: - """ - 构建 outputs 对象 - - Returns: - Dict: {node_id: {output_type: [output_items]}} - """ - outputs = {} - - # 获取 results - results = task.results - if results is None and task.final_status_data: - results = (task.final_status_data.get("data", {}) or {}).get("results", []) - - # 处理每个 result - for result in (results or []): - node_id = result.get("node_id", "unknown") - output_data = result.get("output", {}) - output_type = output_data.get("type", "images") - - # 初始化结构 - if node_id not in outputs: - outputs[node_id] = {} - if output_type not in outputs[node_id]: - outputs[node_id][output_type] = [] - - # 构建输出项 - raw_data = output_data.get("raw", {}) - comfy_output = { - "filename": raw_data.get("filename", ""), - "subfolder": raw_data.get("subfolder", ""), - "type": raw_data.get("type", "output") - } - outputs[node_id][output_type].append(comfy_output) - - return outputs - - def _build_prompt_structure(self, task, prompt_id: str, sequence_number: Optional[int] = None) -> List: - """ - 构建 prompt 结构体 - - Returns: - List: [number, prompt_id, prompt, extra_data, outputs_to_execute] - """ - number = sequence_number if sequence_number is not None else ( - task.completed_at or task.create_at or 1.0 - ) - - # 安全地提取 prompt 和 extra_data - prompt_body = task.prompt_body or {} + return jsonify({}), 503 - # 兼容两种格式: - # 1. 新格式: {"prompt": {...}, "extra_data": {...}} - # 2. 旧格式: 直接是 prompt 工作流定义 - if isinstance(prompt_body, dict) and "prompt" in prompt_body: - prompt = prompt_body.get("prompt", {}) - extra_data = prompt_body.get("extra_data", {}) - else: - prompt = prompt_body - extra_data = {} - - return [ - number, - prompt_id, - prompt or {}, - extra_data or {}, - [] - ] - - def _build_status_object(self, task, prompt_id: str) -> Dict[str, Any]: - """ - 构建 status 对象 - - Returns: - Dict: {status_str, completed, messages} - """ - if task.status.name == "FAILED": - return self._build_error_status(task, prompt_id) - else: - return self._build_success_status(task, prompt_id) - - def _extract_execution_timestamps(self, task) -> tuple: - """ - 从 status_history 中提取 execution_start 和 execution_end 的时间戳 - - 优先级: - 1. 从 status_history 中找 execution_start 第一个消息 - 2. 从 status_history 中找结束标记(优先级:serverless_api > execution_success > execution_error) - 3. 作为 fallback,使用 task 的时间戳 - - Returns: - tuple: (start_ts_int, end_ts_int) 以毫秒为单位 - """ - def _normalize_timestamp(ts_value, default_ts_seconds): - """ - 将时间戳标准化为毫秒格式 - - Args: - ts_value: 可能的时间戳值(可能是秒或毫秒) - default_ts_seconds: 默认时间戳(秒格式) - - Returns: - int: 毫秒格式的时间戳 - """ - if ts_value is None: - return int(default_ts_seconds * 1000) - - # 判断是秒还是毫秒:如果小于 10000000000(约 2001年),认为是秒格式 - if ts_value < 10000000000: - return int(ts_value * 1000) - else: - return int(ts_value) - - start_ts = None - end_ts = None - - # 从 status_history 中提取时间戳 - for status_msg in task.status_history: - msg_type = status_msg.get("type") - msg_data = status_msg.get("data", {}) - - # 找 execution_start 消息 - if msg_type == "execution_start" and start_ts is None: - ts_value = msg_data.get("timestamp") - start_ts = _normalize_timestamp(ts_value, time.time()) - - # 找执行成功 或 失败的标记 - # 优先级:serverless_api > execution_success > execution_error - if msg_type == "serverless_api": - # serverless_api 消息可能没有 timestamp,只有 execution_time - ts_value = msg_data.get("timestamp") - if ts_value is None: - # 使用 execution_time 计算结束时间戳 - execution_time = msg_data.get("execution_time") - if execution_time is not None: - # 先尝试从 start_ts 计算 - if start_ts is not None: - # execution_time 是秒,转换为毫秒并加到开始时间 - end_ts = start_ts + int(execution_time * 1000) - else: - # 如果 start_ts 还没找到,使用 task.create_at 作为基准 - base_ts = task.create_at - if base_ts: - end_ts = int(base_ts * 1000) + int(execution_time * 1000) - else: - # 最后 fallback 到 task.completed_at - default_ts = task.completed_at - end_ts = int(default_ts * 1000) if default_ts else 0 - else: - # 如果没有 execution_time,使用 task.completed_at - default_ts = task.completed_at or task.create_at - end_ts = int(default_ts * 1000) - else: - default_ts = task.completed_at or task.create_at - end_ts = _normalize_timestamp(ts_value, default_ts) - elif msg_type == "execution_success" and end_ts is None: - ts_value = msg_data.get("timestamp") - default_ts = task.completed_at or task.create_at - end_ts = _normalize_timestamp(ts_value, default_ts) - elif msg_type == "execution_error" and end_ts is None: - ts_value = msg_data.get("timestamp") - default_ts = task.completed_at or task.create_at - end_ts = _normalize_timestamp(ts_value, default_ts) - - if start_ts is None: - start_ts = int(task.create_at * 1000) - - if end_ts is None: - end_ts = int((task.completed_at or task.create_at) * 1000) - - return start_ts, end_ts - - def _build_error_status(self, task, prompt_id: str) -> Dict[str, Any]: - """构建错误状态的 status 对象""" - data = (task.final_status_data.get("data", {}) if task.final_status_data else {}) - start_ts, error_ts = self._extract_execution_timestamps(task) - - return { - "status_str": "error", - "completed": True, - "messages": [ - ["execution_start", {"prompt_id": prompt_id, "timestamp": start_ts}], - ["execution_error", { - "prompt_id": prompt_id, - "node_id": data.get("node_id") or data.get("node", "unknown"), - "exception_message": data.get("exception_message", "Unknown error"), - "timestamp": error_ts - }] - ] - } - - def _build_success_status(self, task, prompt_id: str) -> Dict[str, Any]: - """构建成功状态的 status 对象""" - start_ts, end_ts = self._extract_execution_timestamps(task) - - return { - "status_str": "success", - "completed": True, - "messages": [ - ["execution_start", {"prompt_id": prompt_id, "timestamp": start_ts}], - ["execution_success", {"prompt_id": prompt_id, "timestamp": end_ts}] - ] - } - - def _build_meta_object(self, outputs: Dict[str, Any]) -> Dict[str, Dict[str, Any]]: - """ - 构建 meta 对象(记录每个输出节点) - - Args: - outputs: outputs 字典 - - Returns: - Dict: {node_id: {node_id, display_node, parent_node, real_node_id}} - """ - meta = {} - for node_id in outputs.keys(): - meta[node_id] = { - "node_id": node_id, - "display_node": node_id, - "parent_node": None, - "real_node_id": node_id - } - return meta + max_items = self._parse_max_items_param() + history = self.task_manager.get_history(max_items=max_items) + return jsonify(history) def _is_initialized(self) -> bool: """检查服务是否已正确初始化""" return self.task_manager is not None and self.TaskStatus is not None - def _get_ended_tasks(self, limit: Optional[int] = None) -> List: - """ - 获取已结束的任务列表(按完成时间倒序) - - Args: - limit: 可选的数量限制 - - Returns: - List: 已结束的任务列表 - """ - tasks = self.task_manager.get_all_tasks() - ended = [ - t for t in tasks - if t.status in (self.TaskStatus.COMPLETED, self.TaskStatus.FAILED) - ] - ended.sort(key=lambda t: t.completed_at or 0, reverse=True) - - if isinstance(limit, int) and limit > 0: - ended = ended[:limit] - - return ended - - def _parse_limit_param(self) -> Optional[int]: - """解析请求中的 limit 参数""" - limit_param = request.args.get('limit') - if not limit_param: + def _parse_max_items_param(self) -> Optional[int]: + """解析请求中的 max_items 参数)""" + max_items_param = request.args.get('max_items') + if not max_items_param: return None try: - limit = int(limit_param) - return limit if limit > 0 else None + max_items = int(max_items_param) + return max_items if max_items > 0 else None except ValueError: return None diff --git a/src/code/agent/services/gateway/handlers/prompt_handler.py b/src/code/agent/services/gateway/handlers/prompt_handler.py index fb8288f3..0ca51be4 100644 --- a/src/code/agent/services/gateway/handlers/prompt_handler.py +++ b/src/code/agent/services/gateway/handlers/prompt_handler.py @@ -3,8 +3,9 @@ 处理 /prompt 请求逻辑 """ import traceback -from flask import request, jsonify +from flask import request, jsonify, g +import constants from utils.logger import log from exceptions.exceptions import TaskError, InternalError @@ -44,6 +45,12 @@ def handle_post_request(self): } }), 400 + # 注入 user_id 到 extra_data + user_id = getattr(g, 'user_id', 'default') + if 'extra_data' not in request_data: + request_data['extra_data'] = {} + request_data['extra_data'][constants.HEADER_FUNART_COMFY_USERID.lower()] = user_id + try: # 转发给GPU task_id, result = self.task_manager.forward_to_gpu_async( diff --git a/src/code/agent/services/gateway/handlers/queue_handler.py b/src/code/agent/services/gateway/handlers/queue_handler.py index 7058f64f..52e7a706 100644 --- a/src/code/agent/services/gateway/handlers/queue_handler.py +++ b/src/code/agent/services/gateway/handlers/queue_handler.py @@ -24,7 +24,7 @@ def handle_get_request(self): try: # 获取任务列表 - all_tasks = self.task_manager.get_all_tasks() + all_tasks = self.task_manager.get_current_user_tasks() except Exception as e: log("ERROR", f"Error fetching tasks for queue request: {e}") return jsonify({ @@ -79,8 +79,6 @@ def handle_post_request(self): Returns: Flask response """ - log("DEBUG", f"Handling POST /api/queue request") - request_data = request.get_json() or {} if "clear" in request_data and request_data["clear"]: diff --git a/src/code/agent/services/gateway/handlers/ws_handler.py b/src/code/agent/services/gateway/handlers/ws_handler.py index 090033c5..ad27e894 100644 --- a/src/code/agent/services/gateway/handlers/ws_handler.py +++ b/src/code/agent/services/gateway/handlers/ws_handler.py @@ -5,12 +5,13 @@ import time import traceback -from flask import request +from flask import request, g from services.gateway import get_task_manager from services.gateway.task.task import TaskStatus from services.process.websocket.websocket_manager import ws_manager from utils.logger import log +from utils.user_identity import extract_user_from_header class WsHandler: @@ -44,8 +45,13 @@ def handle_connection(self, ws): client_id = f"funart_client_{int(time.time() * 1000)}" log("INFO", f"New ComfyUI WebSocket connection with client_id: {client_id}") - # 添加连接到管理器(同时关联 client_id,处理重连逻辑) - ws_manager.add_connection(ws, client_id) + user_id = extract_user_from_header() + g.user_id = user_id if user_id is not None else 'default' + + log("INFO", f"[WsHandler] WebSocket connection established: client_id={client_id}, user_id={g.user_id}") + + # 添加连接到管理器(同时关联 client_id 和 user_id,处理重连逻辑) + ws_manager.add_connection(ws, client_id, user_id) # 通过消息队列发送初始状态消息,保证线程安全 initial_status = { @@ -54,7 +60,7 @@ def handle_connection(self, ws): "sid": client_id, "status": { "exec_info": { - "queue_remaining": self.task_manager.get_running_task_count() + "queue_remaining": self.task_manager.get_running_task_count_by_user(g.user_id) } } } diff --git a/src/code/agent/services/gateway/task/__init__.py b/src/code/agent/services/gateway/task/__init__.py index a8015033..2c33720e 100644 --- a/src/code/agent/services/gateway/task/__init__.py +++ b/src/code/agent/services/gateway/task/__init__.py @@ -1,10 +1,12 @@ from .task import TaskStatus, Task from .task_manager import TaskManager, get_task_manager +from .history_manager import HistoryManager __all__ = [ 'TaskStatus', 'Task', 'TaskManager', - 'get_task_manager' + 'get_task_manager', + 'HistoryManager' ] diff --git a/src/code/agent/services/gateway/task/history_manager.py b/src/code/agent/services/gateway/task/history_manager.py new file mode 100644 index 00000000..04dab925 --- /dev/null +++ b/src/code/agent/services/gateway/task/history_manager.py @@ -0,0 +1,467 @@ +""" +History 管理模块 +负责管理任务历史记录的创建、更新和查询 +""" +import time +import threading +import traceback +from typing import Dict, Any, Optional +from collections import defaultdict + +from utils.logger import log + + +class HistoryManager: + """ + 历史记录管理类 + 负责管理任务执行历史的存储和查询 + + 注意:此类是线程安全的,所有公共方法都使用独立的锁保护 + """ + + def __init__(self): + """初始化历史记录存储""" + # 历史记录主存储: {prompt_id: {prompt, outputs, status, meta, user_id}} + self.history: Dict[str, dict] = {} + + # 按用户分组的历史记录: {user_id: {prompt_id: history_item}} + self._history_by_user: Dict[str, Dict[str, dict]] = defaultdict(dict) + + # 独立的锁,保护 history 和 _history_by_user 的并发访问 + self._lock = threading.Lock() + + def get_history(self, user_id: str, max_items=None, offset: int = -1) -> Dict[Any, Any]: + """ + 获取用户的历史记录 + + Args: + user_id: 用户ID + max_items: 最大返回数量 + offset: 偏移量,-1 表示从末尾开始 + + Returns: + dict: 历史记录字典,格式为 {prompt_id: {prompt, outputs, status, meta, user_id}} + 只返回已完成(status.completed == True)的历史记录 + """ + with self._lock: + user_history = self._history_by_user.get(user_id, {}) + + # 过滤出已完成的历史记录(不复制数据,直接引用) + completed_history = {} + for prompt_id, history_item in user_history.items(): + status = history_item.get("status", {}) + if status.get("completed", False): + # 直接使用原始 history_item,不复制 + # 因为已经通过 user_id 过滤,返回的都是当前用户自己的数据 + completed_history[prompt_id] = history_item + + # 应用 offset 和 max_items 限制 + out = {} + i = 0 + if offset < 0 and max_items is not None: + offset = len(completed_history) - max_items + for k in completed_history: + if i >= offset: + out[k] = completed_history[k] + if max_items is not None and len(out) >= max_items: + break + i += 1 + return out + + def add_history_item(self, prompt_id: str, history_item: dict) -> bool: + """ + 原子地添加 history item 到两个字典 + + Args: + prompt_id: prompt ID + history_item: history 数据项(必须包含 user_id) + + Returns: + bool: 是否成功添加(False表示已存在或缺少user_id) + """ + with self._lock: + try: + user_id = history_item.get("user_id") + if not user_id: + log("ERROR", f"[HistoryManager] Cannot add history item for {prompt_id}: user_id is missing") + return False + + # 检查是否已存在 + if prompt_id in self.history: + log("DEBUG", f"[HistoryManager] history_item for prompt_id {prompt_id} already exists") + return False + + # 原子性添加到两个字典 + try: + self.history[prompt_id] = history_item + self._history_by_user[user_id][prompt_id] = history_item + log("DEBUG", f"[HistoryManager] Added history_item for prompt_id {prompt_id}, user {user_id}") + return True + except Exception as e: + # 回滚:如果出错,确保两个字典保持一致 + self.history.pop(prompt_id, None) + self._history_by_user[user_id].pop(prompt_id, None) + log("ERROR", f"[HistoryManager] Failed to add history item for {prompt_id}, rolled back: {e}") + return False + + except Exception as e: + log("ERROR", f"[HistoryManager] Error in add_history_item for {prompt_id}: {e}\n{traceback.format_exc()}") + return False + + def get_history_item(self, prompt_id: str) -> Optional[dict]: + """ + 获取指定 prompt_id 的 history_item + + Args: + prompt_id: prompt ID + + Returns: + dict: history_item,如果不存在则返回 None + """ + with self._lock: + return self.history.get(prompt_id) + + def remove_history_item(self, prompt_id: str) -> bool: + """ + 原子地从两个字典移除 history item + + Args: + prompt_id: prompt ID + + Returns: + bool: 是否成功移除 + """ + with self._lock: + try: + history_item = self.history.get(prompt_id) + if not history_item: + return False + + user_id = history_item.get("user_id") + + # 原子性删除 + self.history.pop(prompt_id, None) + if user_id: + self._history_by_user[user_id].pop(prompt_id, None) + + log("DEBUG", f"[HistoryManager] Removed history_item for prompt_id {prompt_id}") + return True + + except Exception as e: + log("ERROR", f"[HistoryManager] Error in remove_history_item for {prompt_id}: {e}\n{traceback.format_exc()}") + return False + + def init_history_item(self, prompt_id: str, prompt_body: dict, client_id: str, + user_id: str, message: dict) -> bool: + """ + 初始化 history_item(在 execution_start 时调用) + + Args: + prompt_id: prompt ID + prompt_body: 任务的 prompt 数据 + client_id: 客户端ID + user_id: 用户ID + message: execution_start 消息 + + Returns: + bool: 是否成功初始化 + """ + with self._lock: + try: + # 检查是否已存在 + if prompt_id in self.history: + log("DEBUG", f"[HistoryManager] history_item for prompt_id {prompt_id} already exists") + return False + + # 立即设置占位符,防止其他线程重复初始化 + self.history[prompt_id] = {"_initializing": True, "user_id": user_id} + + try: + # 构造 history_item + history_item = self._build_history_item( + prompt_id=prompt_id, + prompt_body=prompt_body, + client_id=client_id, + user_id=user_id, + message=message + ) + + # 最终检查:确保占位符还在(没有被其他操作删除) + current = self.history.get(prompt_id) + if current and current.get("_initializing"): + # 替换占位符为完整数据 + self.history[prompt_id] = history_item + self._history_by_user[user_id][prompt_id] = history_item + log("DEBUG", f"[HistoryManager] Initialized history_item for prompt_id {prompt_id}") + return True + else: + log("WARNING", f"[HistoryManager] history_item for prompt_id {prompt_id} was modified during initialization") + return False + + except Exception as e: + # 清理占位符 + current = self.history.get(prompt_id) + if current and current.get("_initializing"): + self.history.pop(prompt_id, None) + log("ERROR", f"[HistoryManager] Failed to build history_item for {prompt_id}: {e}\n{traceback.format_exc()}") + return False + + except Exception as e: + log("ERROR", f"[HistoryManager] Error initializing history_item for prompt_id {prompt_id}: {e}\n{traceback.format_exc()}") + return False + + def _build_history_item(self, prompt_id: str, prompt_body: dict, client_id: str, + user_id: str, message: dict) -> dict: + """ + 构造 history_item(辅助方法) + + Args: + prompt_id: prompt ID + prompt_body: 任务的 prompt 数据 + client_id: 客户端ID + user_id: 用户ID + message: execution_start 消息 + + Returns: + dict: 构造好的 history_item + """ + # 提取 prompt 节点定义和 outputs_to_execute + outputs_to_execute = [] + prompt_dict = prompt_body or {} + + # 处理 prompt_body 可能是不同格式的情况 + if isinstance(prompt_dict, dict): + # 情况1: prompt_body 是包含 prompt 字段的对象,如 {prompt: {...}, outputs_to_execute: [...]} + if "prompt" in prompt_dict and isinstance(prompt_dict.get("prompt"), dict): + outputs_to_execute = prompt_dict.get("outputs_to_execute", []) + prompt_dict = prompt_dict["prompt"] + # 情况2: prompt_body 本身就是节点定义的字典(最常见的情况) + # 这种情况下 outputs_to_execute 通常为空,由 ComfyUI 自动推断 + + # 构造 extra_data + extra_data = {} + if client_id: + extra_data["client_id"] = client_id + + # 使用时间戳作为序号(确保唯一性) + sequence_number = int(time.time() * 1000000) % 1000000000 # 微秒时间戳 + + # 提取时间戳 + msg_data = message.get("data", {}) + timestamp = msg_data.get("timestamp") + if timestamp is None: + timestamp = int(time.time() * 1000) + else: + # 标准化时间戳为毫秒 + if timestamp < 10000000000: + timestamp = int(timestamp * 1000) + else: + timestamp = int(timestamp) + + # 构造 prompt 数组,格式:[number, prompt_id, prompt_dict, extra_data, outputs_to_execute] + return { + "meta": {}, + "outputs": {}, + "prompt": [ + sequence_number, + prompt_id, + prompt_dict, + extra_data, + outputs_to_execute + ], + "status": { + "status_str": "running", + "completed": False, + "messages": [ + ["execution_start", {"prompt_id": prompt_id, "timestamp": timestamp}] + ] + }, + "user_id": user_id + } + + def update_history_status(self, message: dict, status_str: str) -> bool: + """ + 更新 history_item 的 status + + Args: + message: 消息数据(execution_success、execution_error、execution_cached) + status_str: 状态字符串("success"、"error"、"running") + + Returns: + bool: 是否成功更新 + """ + with self._lock: + try: + data = message.get("data", {}) + prompt_id = data.get("prompt_id") + if not prompt_id: + return False + + history_item = self.history.get(prompt_id) + if not history_item: + log("WARNING", f"[HistoryManager] Cannot update history status: history_item not found for prompt_id {prompt_id}") + return False + + if "status" not in history_item: + history_item["status"] = { + "status_str": status_str, + "completed": False, + "messages": [] + } + + status = history_item["status"] + + # 提取时间戳 + timestamp = data.get("timestamp") + if timestamp is None: + timestamp = int(time.time() * 1000) + else: + # 标准化时间戳为毫秒 + if timestamp < 10000000000: + timestamp = int(timestamp * 1000) + else: + timestamp = int(timestamp) + + # 更新状态 + status["status_str"] = status_str + + # 根据状态类型添加消息 + if status_str == "success": + status["completed"] = True + # 添加 execution_success 消息(如果还没有) + if not any(msg[0] == "execution_success" for msg in status.get("messages", [])): + status.setdefault("messages", []).append( + ["execution_success", {"prompt_id": prompt_id, "timestamp": timestamp}] + ) + elif status_str == "error": + status["completed"] = True + # 添加 execution_error 消息(如果还没有) + if not any(msg[0] == "execution_error" for msg in status.get("messages", [])): + error_info = { + "prompt_id": prompt_id, + "node_id": data.get("node_id") or data.get("node", "unknown"), + "exception_message": data.get("exception_message", "Unknown error"), + "timestamp": timestamp + } + status.setdefault("messages", []).append(["execution_error", error_info]) + elif status_str == "running": + # execution_cached 或其他运行中状态,不改变 completed 标志 + # 可以添加 execution_cached 消息 + if message.get("type") == "execution_cached": + if not any(msg[0] == "execution_cached" for msg in status.get("messages", [])): + status.setdefault("messages", []).append( + ["execution_cached", {"prompt_id": prompt_id, "timestamp": timestamp}] + ) + + return True + + except Exception as e: + log("ERROR", f"[HistoryManager] Error updating history status: {e}\n{traceback.format_exc()}") + return False + + def update_history_outputs(self, message: dict) -> bool: + """ + 更新 history_item 的 outputs 和 meta(处理 executed 消息) + + Args: + message: executed 消息数据 + + Returns: + bool: 是否成功更新 + """ + with self._lock: + try: + data = message.get("data", {}) + prompt_id = data.get("prompt_id") + node_id = data.get("node") + + if not prompt_id or not node_id: + log("WARNING", f"[HistoryManager] executed message missing prompt_id or node_id") + return False + + history_item = self.history.get(prompt_id) + if not history_item: + log("WARNING", f"[HistoryManager] History item not found for prompt_id {prompt_id}") + return False + + # 检查是否是初始化占位符 + if history_item.get("_initializing"): + log("WARNING", f"[HistoryManager] History item for {prompt_id} is still initializing, skipping executed message") + return False + + # 处理节点输出:构造 meta + if "meta" not in history_item: + history_item["meta"] = {} + + history_item["meta"][node_id] = { + "node_id": node_id, + "display_node": data.get("display_node", node_id), + "parent_node": None, + "real_node_id": node_id + } + + # 构造 outputs,从 output.images 中获取图片信息 + output_data = data.get("output", {}) + images = output_data.get("images", []) + + if images: + if "outputs" not in history_item: + history_item["outputs"] = {} + if node_id not in history_item["outputs"]: + history_item["outputs"][node_id] = {} + if "images" not in history_item["outputs"][node_id]: + history_item["outputs"][node_id]["images"] = [] + + for img in images: + image_item = { + "filename": img.get("filename", ""), + "type": img.get("type", "output"), + "subfolder": img.get("subfolder", "") + } + history_item["outputs"][node_id]["images"].append(image_item) + + return True + + except Exception as e: + log("ERROR", f"[HistoryManager] Error updating history outputs: {e}\n{traceback.format_exc()}") + return False + + def late_init_history_item(self, task_id: str, prompt_id: str, prompt_body: dict, + client_id: str, user_id: str) -> bool: + """ + 延迟初始化历史项(当 executed 消息到达但 history_item 尚未创建时) + + Args: + task_id: 任务ID + prompt_id: prompt ID + prompt_body: 任务的 prompt 数据 + client_id: 客户端ID + user_id: 用户ID + + Returns: + bool: 是否成功初始化 + """ + with self._lock: + try: + if prompt_id in self.history: + return False + + log("INFO", f"[HistoryManager] Late-initializing history_item for prompt_id {prompt_id}") + history_item = { + "meta": {}, + "outputs": {}, + "prompt": [0, prompt_id, prompt_body or {}, {"client_id": client_id}, []], + "status": { + "status_str": "running", + "completed": False, + "messages": [] + }, + "user_id": user_id + } + self.history[prompt_id] = history_item + self._history_by_user[user_id][prompt_id] = history_item + return True + + except Exception as e: + log("ERROR", f"[HistoryManager] Error in late_init_history_item for {prompt_id}: {e}\n{traceback.format_exc()}") + return False diff --git a/src/code/agent/services/gateway/task/task.py b/src/code/agent/services/gateway/task/task.py index 1fcb4b19..80627c76 100644 --- a/src/code/agent/services/gateway/task/task.py +++ b/src/code/agent/services/gateway/task/task.py @@ -3,8 +3,8 @@ 定义任务的状态枚举和数据结构 """ import time -from typing import Optional, Callable, Any -from dataclasses import dataclass, field +from typing import Optional +from dataclasses import dataclass from enum import Enum @@ -27,21 +27,13 @@ def is_active(self) -> bool: @dataclass class Task: """任务数据模型""" - task_id: str + task_id: str # 任务ID,同时也作为 ComfyUI history 的 prompt_id client_id: str prompt_body: dict - callback: Optional[Callable] = None + user_id: str # 任务所属用户ID status: TaskStatus = TaskStatus.PENDING - - # 时间戳 - create_at: float = field(default_factory=time.time) completed_at: Optional[float] = None - # 任务执行结果与状态历史(用于 history.json 构造) - results: Optional[list] = None # serverless_api 最终结果 data.results - final_status_data: Optional[dict] = None # 最终状态整包(serverless_api 或 error/execution_error) - status_history: list = field(default_factory=list) - def update_status(self, new_status: TaskStatus) -> bool: """ 更新任务状态,包含状态转换校验 diff --git a/src/code/agent/services/gateway/task/task_manager.py b/src/code/agent/services/gateway/task/task_manager.py index e29a0621..379c127f 100644 --- a/src/code/agent/services/gateway/task/task_manager.py +++ b/src/code/agent/services/gateway/task/task_manager.py @@ -7,11 +7,12 @@ import time import traceback import uuid -from typing import Dict, Optional, Callable, List, Tuple, Union +from collections import defaultdict +from typing import Dict, Optional, Callable, List, Tuple, Union, Any import constants import requests -from flask import request +from flask import request, g from utils.logger import log from exceptions.exceptions import ( ConfigurationError, @@ -23,6 +24,7 @@ from .task import TaskStatus, Task from .utils.task_manager_util import TaskStatusBroadcaster +from .history_manager import HistoryManager class TaskManager: @@ -40,10 +42,16 @@ def __init__(self, # 任务字典 self._tasks: Dict[str, Task] = {} self._lock = threading.Lock() + + # 历史记录辅助器 + self._history_manager = HistoryManager() # 已完成任务计数器(COMPLETED + FAILED) self._completed_task_count = 0 + # {user_id: running_count} + self._running_count_by_user: Dict[str, int] = defaultdict(int) + # 消息轮询器管理 self._message_pollers: Dict[str, 'MessagesPoller'] = {} self._poller_lock = threading.Lock() @@ -63,16 +71,17 @@ def stop(self): def submit_task(self, prompt_body: dict, client_id: str, - task_id: Optional[str] = None, - callback: Optional[Callable] = None) -> str: + task_id: Optional[str] = None) -> str: if task_id is None: task_id = str(uuid.uuid4()) + user_id = g.user_id + task_request = Task( task_id=task_id, client_id=client_id, prompt_body=prompt_body, - callback=callback + user_id=user_id ) # 检查活跃任务数量 @@ -90,88 +99,163 @@ def submit_task(self, ) self._tasks[task_id] = task_request + self._running_count_by_user[user_id] += 1 # 启动状态轮询(监控GPU函数端的状态) self._start_polling(task_id) # 广播队列状态更新(任务提交后) - TaskStatusBroadcaster.broadcast_queue_status(self.get_running_task_count) + TaskStatusBroadcaster.broadcast_queue_status() return task_id - def get_all_tasks(self) -> List[Task]: + def get_current_user_tasks(self) -> List[Task]: + """获取当前用户的所有任务""" + user_id = g.user_id with self._lock: - return list(self._tasks.values()) + return [task for task in self._tasks.values() if task.user_id == user_id] - def get_running_task_count(self) -> int: - """获取运行中任务数量(PENDING和RUNNING状态)""" + def get_task(self, task_id: str) -> Optional[Task]: + """ + 根据任务ID获取任务对象 + + Args: + task_id: 任务ID + + Returns: + Task对象,如果任务不存在则返回None + """ with self._lock: - return sum( - 1 for task in self._tasks.values() - if task.status in [TaskStatus.PENDING, TaskStatus.RUNNING] - ) + return self._tasks.get(task_id) + + def get_running_task_count_by_user(self, user_id: str) -> int: + """ + 获取指定用户的运行中任务数量 + + Args: + user_id: 用户ID + + Returns: + 运行中任务数量(包括 PENDING 和 RUNNING 状态) + """ + with self._lock: + return self._running_count_by_user.get(user_id, 0) + + def get_history(self, max_items=None, offset: int = -1) -> dict[Any, Any]: + """ + 获取历史记录(从 history_helper 获取) + + Args: + max_items: 最大返回数量 + offset: 偏移量,-1 表示从末尾开始 + + Returns: + dict: 历史记录字典,格式为 {prompt_id: {prompt, outputs, status, meta, user_id}} + 只返回已完成(status.completed == True)且属于当前用户的历史记录 + """ + user_id = g.user_id + + # HistoryManager 已经是线程安全的,不需要额外加锁 + return self._history_manager.get_history(user_id, max_items, offset) def clear_queue(self) -> int: """ - 清空队列:删除 PENDING 任务和所有已完成任务(COMPLETED/FAILED) + 清空队列:删除当前用户的 PENDING 任务和所有已完成任务(COMPLETED/FAILED) 不删除 RUNNING 状态的任务(正在执行) Returns: int: 清理的任务数量 """ - cleared_count = 0 + user_id = g.user_id + tasks_to_cleanup = [] # 存储需要清理的任务信息(在锁外清理 history) + # 在锁内:收集、删除任务,更新计数器 with self._lock: - task_ids_to_remove = [] + # 获取当前用户的所有任务 + user_tasks = [task for task in self._tasks.values() if task.user_id == user_id] + + pending_count = 0 completed_count = 0 - - for task_id, task in self._tasks.items(): - # 清理 PENDING 和已完成的任务 + cleared_count = 0 + + for task in user_tasks: + # 只删除 PENDING 和已完成的任务 if task.status == TaskStatus.PENDING: - task_ids_to_remove.append(task_id) + if self._tasks.pop(task.task_id, None): + cleared_count += 1 + pending_count += 1 + tasks_to_cleanup.append(task.task_id) + elif task.status in (TaskStatus.COMPLETED, TaskStatus.FAILED): - task_ids_to_remove.append(task_id) - completed_count += 1 - - # 执行删除 - for task_id in task_ids_to_remove: - self._tasks.pop(task_id, None) - cleared_count += 1 - + if self._tasks.pop(task.task_id, None): + cleared_count += 1 + completed_count += 1 + tasks_to_cleanup.append(task.task_id) + # 更新已完成任务计数器 if completed_count > 0: self._completed_task_count = max(0, self._completed_task_count - completed_count) + + # 更新运行中任务计数(只需减去被清理的 PENDING 任务) + if pending_count > 0: + self._running_count_by_user[user_id] -= pending_count + + # 清理 history + for task_id in tasks_to_cleanup: + self._history_manager.remove_history_item(task_id) # 广播队列状态更新 if cleared_count > 0: - TaskStatusBroadcaster.broadcast_queue_status(self.get_running_task_count) - log("INFO", f"[TaskManager] Cleared {cleared_count} tasks (including {completed_count} completed)") + TaskStatusBroadcaster.broadcast_queue_status() + log("INFO", f"[TaskManager] User {user_id} cleared {cleared_count} tasks (including {completed_count} completed)") return cleared_count def cancel_task(self, task_id: str) -> bool: - # 先停止状态轮询 - self._stop_polling(task_id) + """ + 取消任务(只能取消当前用户的任务) + + Args: + task_id: 要取消的任务ID + + Returns: + bool: 是否成功取消(True表示成功,False表示任务不存在、不可取消或无权限) + """ + user_id = g.user_id - # 取消任务 with self._lock: task = self._tasks.get(task_id) if not task: return False - task_status = task.status + # 只允许取消自己的任务 + if task.user_id != user_id: + log("WARNING", f"[TaskManager] User {user_id} cannot cancel task {task_id} (belongs to {task.user_id})") + return False + + old_status = task.status # 只能取消未开始执行或正在处理的任务 - if task_status not in [TaskStatus.PENDING, TaskStatus.RUNNING]: + if old_status not in [TaskStatus.PENDING, TaskStatus.RUNNING]: + log("DEBUG", f"[TaskManager] Cannot cancel task {task_id}: status is {old_status}") return False # 从内存中删除 self._tasks.pop(task_id, None) - cancelled = True + + # 减少运行中任务计数 + self._running_count_by_user[user_id] -= 1 + + log("INFO", f"[TaskManager] User {user_id} cancelled task {task_id} (was {old_status})") + + # 停止轮询 + try: + self._stop_polling(task_id) + except Exception as e: + log("ERROR", f"[TaskManager] Error stopping polling for cancelled task {task_id}: {e}") # 广播队列状态更新 - if cancelled: - TaskStatusBroadcaster.broadcast_queue_status(self.get_running_task_count) + TaskStatusBroadcaster.broadcast_queue_status() return True @@ -253,25 +337,123 @@ def handle_message(self, task_id: str, message: Union[dict, str]) -> None: if status_type == 'execution_start': # 任务开始执行 self._update_task_status(task_id, message, TaskStatus.RUNNING) + # 初始化 history_item(在 execution_start 时最合适) + self._init_history_item(task_id, message) elif status_type == 'execution_success': # 任务完成 self._update_task_status(task_id, message, TaskStatus.COMPLETED) - + # 更新 history_item 的 status + self._update_history_status(message, "success") + + elif status_type == 'execution_cached': + # 节点执行缓存(节点已缓存,跳过执行) + # 更新 history_item 的 status + self._update_history_status(message, "running") + + elif status_type == 'executed': + # 单节点任务执行结束 + data = message.get("data", {}) + prompt_id = data.get("prompt_id") + node_id = data.get("node") + + if not prompt_id: + log("WARNING", f"[TaskManager] executed message missing prompt_id for task {task_id}") + return + + if not node_id: + log("WARNING", f"[TaskManager] executed message missing node_id for task {task_id}") + return + + # 检查 history_item 是否存在(HistoryManager 已经是线程安全的) + history_item = self._history_manager.get_history_item(prompt_id) + + # 如果 history_item 不存在,尝试延迟初始化 + if not history_item: + log("WARNING", f"[TaskManager] History item not found for prompt_id {prompt_id} in executed message") + + # 需要在锁内获取 task 信息 + with self._lock: + task = self._tasks.get(task_id) + if not task: + log("ERROR", f"[TaskManager] Cannot process executed message: task {task_id} not found") + return + + # 验证 prompt_id 是否匹配 + if task.task_id != prompt_id: + log("ERROR", f"[TaskManager] prompt_id mismatch: task.task_id={task.task_id}, message.prompt_id={prompt_id}") + return + + # 保存任务信息(在锁外调用 HistoryManager) + task_prompt_body = task.prompt_body + task_client_id = task.client_id + task_user_id = task.user_id + + # 延迟初始化(HistoryManager 已经是线程安全的) + self._history_manager.late_init_history_item( + task_id=task_id, + prompt_id=prompt_id, + prompt_body=task_prompt_body, + client_id=task_client_id, + user_id=task_user_id + ) + + # 更新 history outputs(HistoryManager 已经是线程安全的) + self._history_manager.update_history_outputs(message) elif status_type == 'execution_error': # 任务失败 self._update_task_status(task_id, message, TaskStatus.FAILED) + # 更新 history_item 的 status + self._update_history_status(message, "error") elif status_type == 'status': # 忽略纯 status 消息, agent的队列代替comfyui自己的队列 return - self._record_task_status(task_id, message) except Exception as e: log("DEBUG", f"[TaskManager] Error handling message for task {task_id}: {e}") TaskStatusBroadcaster.broadcast_task_status(task_id, message) + def _init_history_item(self, task_id: str, message: dict) -> None: + """ + 初始化 history_item(在 execution_start 时调用) + Args: + task_id: 任务ID + message: execution_start 消息 + """ + try: + # 提取 prompt_id(如果消息中没有,就使用 task_id) + data = message.get("data", {}) + prompt_id = data.get("prompt_id") or task_id + + # 第二步:获取任务信息 + with self._lock: + task = self._tasks.get(task_id) + if not task: + log("WARNING", f"[TaskManager] Cannot initialize history_item: task {task_id} not found") + return + + user_id = task.user_id + prompt_body = task.prompt_body + client_id = task.client_id + + # 第三步:委托给 HistoryManager 处理(在锁外) + # HistoryManager 已经是线程安全的,不需要额外加锁 + self._history_manager.init_history_item( + prompt_id=prompt_id, + prompt_body=prompt_body, + client_id=client_id, + user_id=user_id, + message=message + ) + + except Exception as e: + log("ERROR", f"[TaskManager] Error initializing history_item for task {task_id}: {e}\n{traceback.format_exc()}") + + def _update_history_status(self, message: dict, status_str: str) -> None: + self._history_manager.update_history_status(message, status_str) + def _update_task_status(self, task_id: str, status_data: dict, target_status: TaskStatus) -> bool: """ 更新任务状态(包括记录 + 转换) @@ -285,9 +467,6 @@ def _update_task_status(self, task_id: str, status_data: dict, target_status: Ta bool: 是否成功更新 """ try: - # 先记录状态历史 - self._record_task_status(task_id, status_data) - # 执行原子更新 with self._lock: task = self._tasks.get(task_id) @@ -298,7 +477,18 @@ def _update_task_status(self, task_id: str, status_data: dict, target_status: Ta if task.status == target_status: return False + # 记录旧状态用于更新计数 + old_status = task.status success = task.update_status(target_status) + + # 更新运行中任务计数 + if success: + old_is_running = old_status in [TaskStatus.PENDING, TaskStatus.RUNNING] + new_is_running = target_status in [TaskStatus.PENDING, TaskStatus.RUNNING] + + if old_is_running and not new_is_running: + # 从运行中变为完成/失败 + self._running_count_by_user[task.user_id] -= 1 if success and target_status in (TaskStatus.COMPLETED, TaskStatus.FAILED): # 增加已完成任务计数 @@ -306,7 +496,7 @@ def _update_task_status(self, task_id: str, status_data: dict, target_status: Ta self._completed_task_count += 1 # 检查是否需要清理旧任务 self._cleanup_old_completed_tasks_if_needed() - TaskStatusBroadcaster.broadcast_queue_status(self.get_running_task_count) + TaskStatusBroadcaster.broadcast_queue_status() return success @@ -316,40 +506,6 @@ def _update_task_status(self, task_id: str, status_data: dict, target_status: Ta traceback.print_exc() return False - def _record_task_status(self, task_id: str, status_data: dict) -> bool: - """ - 记录任务状态到任务对象(用于 history 构造) - - Args: - task_id: 任务ID - status_data: 状态数据 - - Returns: - bool: 是否成功记录 - """ - try: - with self._lock: - task = self._tasks.get(task_id) - if not task: - return False - - # 只存储 history 需要的消息类型 - s_type = status_data.get("type") - if s_type in ("execution_start", "serverless_api", "execution_success", "execution_error", "error"): - task.status_history.append(status_data) - - if s_type == "serverless_api": - task.final_status_data = status_data - data = status_data.get("data", {}) - task.results = data.get("results") - elif s_type in ("error", "execution_error"): - task.final_status_data = status_data - - return True - except Exception as e: - log("ERROR", f"[TaskManager] Error recording status for task {task_id}: {e}") - return False - # FIXME:@dehui.kdh, 实现更优雅的清理逻辑 def _cleanup_old_completed_tasks_if_needed(self) -> None: """检查是否需要清理旧的已完成任务(基于计数器,只在超过阈值时才执行清理)""" @@ -358,6 +514,9 @@ def _cleanup_old_completed_tasks_if_needed(self) -> None: return try: + # 收集需要删除的任务(在锁内) + tasks_to_remove = [] + with self._lock: # 再次检查(双重检查锁定) if self._completed_task_count <= self._max_completed_tasks: @@ -375,17 +534,27 @@ def _cleanup_old_completed_tasks_if_needed(self) -> None: # 获取需要删除的任务(保留最新的 max_completed_tasks 个) tasks_to_remove = completed_tasks[self._max_completed_tasks:] - # 删除旧任务 + # 删除任务(但先不删除 history) removed_count = 0 - for task_id, _ in tasks_to_remove: + for task_id, task in tasks_to_remove: if self._tasks.pop(task_id, None): removed_count += 1 # 更新计数器 self._completed_task_count = len(completed_tasks) - removed_count - - if removed_count > 0: - log("INFO", f"[TaskManager] Cleaned up {removed_count} old completed tasks (keeping latest {self._max_completed_tasks}, current count: {self._completed_task_count})") + + # 在锁外删除 history(避免长时间持锁) + removed_history_count = 0 + for task_id, task in tasks_to_remove: + # 防御性检查:确保 history 的 user_id 与 task 一致 + # (使用 HistoryManager 的线程安全方法) + history_item = self._history_manager.get_history_item(task.task_id) + if history_item and history_item.get('user_id') == task.user_id: + if self._history_manager.remove_history_item(task.task_id): + removed_history_count += 1 + + if removed_count > 0: + log("INFO", f"[TaskManager] Cleaned up {removed_count} old tasks and {removed_history_count} history items (keeping latest {self._max_completed_tasks}, current count: {self._completed_task_count})") except Exception as e: log("ERROR", f"[TaskManager] Error cleaning up old completed tasks: {e}") @@ -710,7 +879,6 @@ def _is_message_completed(message: Union[dict, str]) -> bool: try: message_type = message.get("type", "") except Exception as e: - log("DEBUG", f"[MessagesPoller] Error checking message completion: {e}") return False # 只有收到 serverless_api 时才认为任务完成 @@ -728,4 +896,4 @@ def get_task_manager() -> TaskManager: _task_manager = TaskManager() # 启动任务管理器 _task_manager.start() - return _task_manager + return _task_manager \ No newline at end of file diff --git a/src/code/agent/services/gateway/task/utils/task_manager_util.py b/src/code/agent/services/gateway/task/utils/task_manager_util.py index e15815c3..0bb31039 100644 --- a/src/code/agent/services/gateway/task/utils/task_manager_util.py +++ b/src/code/agent/services/gateway/task/utils/task_manager_util.py @@ -12,7 +12,6 @@ class TaskStatusBroadcaster: """处理任务状态的WebSocket广播""" - # FIXME:@dehui.kdh, 不是广播,而是直接推送给客户端 @staticmethod def broadcast_task_status(task_id: str, status_data: Union[dict, str]): """ @@ -37,11 +36,16 @@ def broadcast_task_status(task_id: str, status_data: Union[dict, str]): # 通过 task_id 获取 Task,再获取 client_id task_manager = get_task_manager() - task = task_manager._tasks.get(task_id) + task = task_manager.get_task(task_id) if not task: log("WARNING", f"Task {task_id} not found, cannot broadcast status") return + + # 检查 client_id 是否存在 + if not task.client_id: + log("WARNING", f"Task {task_id} has no client_id, cannot broadcast status") + return # 将消息发送给对应的客户端 if isinstance(status_data, list): @@ -54,11 +58,10 @@ def broadcast_task_status(task_id: str, status_data: Union[dict, str]): log("ERROR", f"Error broadcasting task status via WebSocket: {e}") @staticmethod - def broadcast_queue_status(get_pending_task_count_fn: Callable[[], int]): - """广播当前队列状态给所有连接(类似 ComfyUI 的 queue_updated) + def broadcast_queue_status(): + """向每个连接的客户端发送其各自的队列状态(类似 ComfyUI 的 queue_updated) - Args: - get_pending_task_count_fn: 获取运行中任务数量的函数 + 注意:此方法会向每个客户端发送其对应用户的队列状态,实现多租户隔离 """ try: # 只在CPU模式下广播 @@ -66,24 +69,36 @@ def broadcast_queue_status(get_pending_task_count_fn: Callable[[], int]): return from services.process.websocket.websocket_manager import ws_manager + from services.gateway import get_task_manager + from services.gateway.task.task import TaskStatus - # 获取当前队列中的任务数 - pending_count = get_pending_task_count_fn() + task_manager = get_task_manager() + + # 获取所有活跃的客户端连接及其用户ID + client_user_mapping = ws_manager.get_client_user_mapping() - # 构建 ComfyUI 格式的状态消息 - queue_status_msg = { - "type": "status", - "data": { - "status": { - "exec_info": { - "queue_remaining": pending_count + # 为每个客户端发送其对应用户的队列状态 + for client_id, user_id in client_user_mapping.items(): + try: + pending_count = task_manager.get_running_task_count_by_user(user_id) + + # 构建 ComfyUI 格式的状态消息 + queue_status_msg = { + "type": "status", + "data": { + "status": { + "exec_info": { + "queue_remaining": pending_count + } + } } } - } - } - - # 将队列状态消息广播给所有连接 - ws_manager.broadcast_to_all(queue_status_msg) + + # 向该客户端发送其专属的队列状态 + ws_manager.send_to_client(client_id, queue_status_msg) + + except Exception as e: + log("ERROR", f"[TaskStatusBroadcaster] Failed to send queue status to client {client_id}: {e}") except Exception as e: log("ERROR", f"[TaskStatusBroadcaster] Failed to broadcast queue status: {e}") diff --git a/src/code/agent/services/process/websocket/websocket_manager.py b/src/code/agent/services/process/websocket/websocket_manager.py index e6076c82..a11e3c31 100644 --- a/src/code/agent/services/process/websocket/websocket_manager.py +++ b/src/code/agent/services/process/websocket/websocket_manager.py @@ -16,6 +16,9 @@ def __init__(self): # 客户端ID映射 self._client_id_mapping: Dict[str, Any] = {} # client_id -> websocket + # 客户端用户ID映射(用于多租户隔离) + self._client_user_mapping: Dict[str, str] = {} # client_id -> user_id + # 使用消息队列序列化所有发送操作 self._message_queue = Queue() # 线程安全的消息队列 self._send_thread: Optional[threading.Thread] = None @@ -48,13 +51,14 @@ def get_connection_info(self, ws): 'port': None } - def add_connection(self, ws, client_id: Optional[str] = None): + def add_connection(self, ws, client_id: Optional[str] = None, user_id: Optional[str] = None): """ 添加WebSocket连接到管理器 Args: ws: WebSocket连接 client_id: 可选的客户端ID,如果提供则建立映射并处理重连 + user_id: 可选的用户ID,用于多租户隔离 """ with self._lock: # 如果有 client_id,先处理重连逻辑(移除旧连接) @@ -73,9 +77,12 @@ def add_connection(self, ws, client_id: Optional[str] = None): # 如果有 client_id,建立映射(一个客户端只有一个连接) if client_id: self._client_id_mapping[client_id] = ws + # 存储用户ID映射 + if user_id: + self._client_user_mapping[client_id] = user_id conn_info = self.get_connection_info(ws) - log("DEBUG", f"[WebSocketManager] Connection added: {json.dumps(conn_info, indent=2)}" + (f" (client_id: {client_id})" if client_id else "")) + log("DEBUG", f"[WebSocketManager] Connection added: {json.dumps(conn_info, indent=2)}" + (f" (client_id: {client_id}, user_id: {user_id})" if client_id else "")) if self._send_thread is None or not self._send_thread.is_alive(): log("WARNING", "[WebSocketManager] Send thread not running, restarting...") @@ -94,9 +101,12 @@ def remove_connection(self, ws, client_id: Optional[str] = None): conn_id = id(ws) start_time = self._connection_times.pop(conn_id, None) - # 清理客户端ID映射 - if client_id and self._client_id_mapping.get(client_id) == ws: - del self._client_id_mapping[client_id] + # 清理客户端ID映射和用户ID映射 + if client_id: + if self._client_id_mapping.get(client_id) == ws: + del self._client_id_mapping[client_id] + # 同时清理用户ID映射 + self._client_user_mapping.pop(client_id, None) conn_info = self.get_connection_info(ws) if start_time: @@ -117,6 +127,16 @@ def get_connection(self, client_id: str) -> Optional[Any]: with self._lock: return self._client_id_mapping.get(client_id) + def get_client_user_mapping(self) -> Dict[str, str]: + """ + 获取客户端ID到用户ID的映射副本 + + Returns: + Dict[str, str]: 客户端ID到用户ID的映射字典副本 + """ + with self._lock: + return self._client_user_mapping.copy() + def send_to_client(self, client_id: str, message: Union[dict, str]) -> bool: """ @@ -167,6 +187,7 @@ def close_all_connections(self): # 清理连接和映射记录 self.active_connections.clear() self._client_id_mapping.clear() + self._client_user_mapping.clear() for ws in connections_to_close: try: diff --git a/src/code/agent/test/unit/services/gateway/history_handler_test.py b/src/code/agent/test/unit/services/gateway/history_handler_test.py new file mode 100644 index 00000000..93ab0a7c --- /dev/null +++ b/src/code/agent/test/unit/services/gateway/history_handler_test.py @@ -0,0 +1,332 @@ +""" +History Handler 单元测试 +测试 /api/history 接口的核心逻辑 +""" +import pytest +import time +from unittest.mock import Mock, patch, MagicMock +from flask import Flask, g +from dataclasses import dataclass +from typing import Optional + +from services.gateway.handlers.history_handler import HistoryHandler + + +# ==================== Mock Task 数据类 ==================== + +@dataclass +class MockTask: + """模拟 Task 对象用于测试""" + task_id: str + client_id: str + prompt_body: dict + user_id: str + prompt_id: str + status: Mock + completed_at: Optional[float] = None + + +# ==================== Fixtures ==================== + +@pytest.fixture +def app(): + """创建测试用的 Flask 应用""" + app = Flask(__name__) + app.config['TESTING'] = True + return app + + +@pytest.fixture +def mock_task_manager(): + """创建 mock 的 TaskManager""" + manager = Mock() + manager.get_history = Mock(return_value={}) + return manager + + +@pytest.fixture +def mock_task_status(): + """创建 mock 的 TaskStatus 枚举""" + status = Mock() + status.COMPLETED = Mock(name="COMPLETED") + status.FAILED = Mock(name="FAILED") + status.PENDING = Mock(name="PENDING") + status.RUNNING = Mock(name="RUNNING") + return status + + +@pytest.fixture +def handler_with_mocks(mock_task_manager, mock_task_status): + """创建已初始化的 HistoryHandler""" + with patch('services.gateway.task.task_manager.get_task_manager', return_value=mock_task_manager), \ + patch('services.gateway.task.task.TaskStatus', mock_task_status): + handler = HistoryHandler() + yield handler + + +@pytest.fixture +def handler_uninitialized(): + """创建未初始化的 HistoryHandler(模拟初始化失败)""" + with patch('services.gateway.task.task_manager.get_task_manager', side_effect=Exception("Init failed")): + handler = HistoryHandler() + yield handler + + +# ==================== 测试类 ==================== + +class TestHistoryHandlerInitialization: + """测试 HistoryHandler 初始化""" + + def test_successful_initialization(self, handler_with_mocks): + """测试成功初始化""" + assert handler_with_mocks.task_manager is not None + assert handler_with_mocks.TaskStatus is not None + assert handler_with_mocks._is_initialized() is True + + def test_initialization_failure(self, handler_uninitialized): + """测试初始化失败""" + assert handler_uninitialized.task_manager is None + assert handler_uninitialized._is_initialized() is False + + +class TestHandleGetRequest: + """测试 handle_get_request 方法""" + + def test_get_history_success_no_limit(self, app, handler_with_mocks, mock_task_manager): + """测试成功获取历史 - 无 max_items 参数""" + # 准备测试数据 + mock_history = { + "prompt-123": { + "prompt": [1, "prompt-123", {}, {}, []], + "outputs": {}, + "status": {"status_str": "success", "completed": True, "messages": []}, + "meta": {} + } + } + mock_task_manager.get_history.return_value = mock_history + + with app.test_request_context(): + g.user_id = 'user-test' + response = handler_with_mocks.handle_get_request() + + # 验证调用 + mock_task_manager.get_history.assert_called_once_with(max_items=None) + + # 验证响应 + assert response.json == mock_history + assert response.status_code == 200 + + def test_get_history_success_with_valid_limit(self, app, handler_with_mocks, mock_task_manager): + """测试成功获取历史 - 带有效 max_items 参数""" + mock_history = { + "prompt-1": {"status": {"completed": True}}, + "prompt-2": {"status": {"completed": True}} + } + mock_task_manager.get_history.return_value = mock_history + + with app.test_request_context('/?max_items=10'): + g.user_id = 'user-test' + response = handler_with_mocks.handle_get_request() + + mock_task_manager.get_history.assert_called_once_with(max_items=10) + assert response.json == mock_history + + def test_get_history_empty_result(self, app, handler_with_mocks, mock_task_manager): + """测试空历史记录 - 返回空字典""" + mock_task_manager.get_history.return_value = {} + + with app.test_request_context(): + g.user_id = 'user-test' + response = handler_with_mocks.handle_get_request() + + assert response.json == {} + assert response.status_code == 200 + + def test_get_history_uninitialized_handler(self, app, handler_uninitialized): + """测试初始化失败 - 返回 503""" + with app.test_request_context(): + g.user_id = 'user-test' + response, status_code = handler_uninitialized.handle_get_request() + + assert response.json == {} + assert status_code == 503 + + def test_get_history_invalid_limit_string(self, app, handler_with_mocks, mock_task_manager): + """测试无效 max_items 参数 - 非数字字符串""" + mock_task_manager.get_history.return_value = {} + + with app.test_request_context('/?max_items=abc'): + g.user_id = 'user-test' + response = handler_with_mocks.handle_get_request() + + # 无效的 limit 应该被忽略,传递 None + mock_task_manager.get_history.assert_called_once_with(max_items=None) + + def test_get_history_limit_zero(self, app, handler_with_mocks, mock_task_manager): + """测试 max_items 为 0 - 应该被忽略""" + mock_task_manager.get_history.return_value = {} + + with app.test_request_context('/?max_items=0'): + g.user_id = 'user-test' + response = handler_with_mocks.handle_get_request() + + mock_task_manager.get_history.assert_called_once_with(max_items=None) + + def test_get_history_limit_negative(self, app, handler_with_mocks, mock_task_manager): + """测试 max_items 为负数 - 应该被忽略""" + mock_task_manager.get_history.return_value = {} + + with app.test_request_context('/?max_items=-5'): + g.user_id = 'user-test' + response = handler_with_mocks.handle_get_request() + + mock_task_manager.get_history.assert_called_once_with(max_items=None) + + +class TestParseLimitParam: + """测试 _parse_limit_param 方法""" + + def test_parse_valid_limit(self, app, handler_with_mocks): + """测试解析有效的 max_items 参数""" + with app.test_request_context('/?max_items=50'): + result = handler_with_mocks._parse_max_items_param() + assert result == 50 + + def test_parse_no_limit(self, app, handler_with_mocks): + """测试无 max_items 参数""" + with app.test_request_context('/'): + result = handler_with_mocks._parse_max_items_param() + assert result is None + + def test_parse_invalid_limit(self, app, handler_with_mocks): + """测试无效的 max_items 参数""" + with app.test_request_context('/?max_items=invalid'): + result = handler_with_mocks._parse_max_items_param() + assert result is None + + def test_parse_limit_zero(self, app, handler_with_mocks): + """测试 max_items=0""" + with app.test_request_context('/?max_items=0'): + result = handler_with_mocks._parse_max_items_param() + assert result is None + + def test_parse_limit_negative(self, app, handler_with_mocks): + """测试负数 max_items""" + with app.test_request_context('/?max_items=-10'): + result = handler_with_mocks._parse_max_items_param() + assert result is None + + +class TestMultiTenantIsolation: + """测试多租户隔离(集成测试,需要真实的 TaskManager.get_history 逻辑)""" + + def test_user_only_sees_own_history(self, app, handler_with_mocks, mock_task_manager): + """测试用户只能看到自己的历史记录""" + # TaskManager.get_history 已经实现了用户隔离 + # 这里主要验证正确调用了 get_history + + user_history = { + "prompt-user1": {"status": {"completed": True}} + } + mock_task_manager.get_history.return_value = user_history + + with app.test_request_context(): + g.user_id = 'user-1' + response = handler_with_mocks.handle_get_request() + + # 验证返回的是当前用户的历史 + assert response.json == user_history + + # TaskManager.get_history 内部会通过 _get_current_user_id_from_request() + # 获取 g.user_id 进行过滤,这里我们验证它被正确调用 + mock_task_manager.get_history.assert_called_once() + + +class TestDataFormat: + """测试返回数据格式""" + + def test_response_contains_required_fields(self, app, handler_with_mocks, mock_task_manager): + """测试响应包含必需的字段""" + complete_history = { + "prompt-123": { + "prompt": [1, "prompt-123", {"node1": {}}, {"client_id": "test"}, []], + "outputs": { + "node1": { + "images": [ + { + "filename": "output.png", + "subfolder": "", + "type": "output" + } + ] + } + }, + "status": { + "status_str": "success", + "completed": True, + "messages": [ + ["execution_start", {"prompt_id": "prompt-123", "timestamp": 1000}], + ["execution_success", {"prompt_id": "prompt-123", "timestamp": 2000}] + ] + }, + "meta": { + "node1": { + "node_id": "node1", + "display_node": "node1", + "parent_node": None, + "real_node_id": "node1" + } + } + } + } + mock_task_manager.get_history.return_value = complete_history + + with app.test_request_context(): + g.user_id = 'user-test' + response = handler_with_mocks.handle_get_request() + + result = response.json + assert "prompt-123" in result + + history_item = result["prompt-123"] + assert "prompt" in history_item + assert "outputs" in history_item + assert "status" in history_item + assert "meta" in history_item + + # 验证不包含 user_id(应该在 TaskManager.get_history 中被过滤) + assert "user_id" not in history_item + + def test_response_for_failed_task(self, app, handler_with_mocks, mock_task_manager): + """测试失败任务的响应格式""" + failed_history = { + "prompt-failed": { + "prompt": [1, "prompt-failed", {}, {}, []], + "outputs": {}, + "status": { + "status_str": "error", + "completed": True, + "messages": [ + ["execution_start", {"prompt_id": "prompt-failed", "timestamp": 1000}], + ["execution_error", { + "prompt_id": "prompt-failed", + "node_id": "node_x", + "exception_message": "Error occurred", + "timestamp": 2000 + }] + ] + }, + "meta": {} + } + } + mock_task_manager.get_history.return_value = failed_history + + with app.test_request_context(): + g.user_id = 'user-test' + response = handler_with_mocks.handle_get_request() + + result = response.json + assert "prompt-failed" in result + assert result["prompt-failed"]["status"]["status_str"] == "error" + assert result["prompt-failed"]["status"]["completed"] is True + diff --git a/src/code/agent/test/unit/services/gateway/history_helper_test.py b/src/code/agent/test/unit/services/gateway/history_helper_test.py new file mode 100644 index 00000000..aa56b603 --- /dev/null +++ b/src/code/agent/test/unit/services/gateway/history_helper_test.py @@ -0,0 +1,1690 @@ +import pytest +import time +from unittest.mock import Mock, patch +from flask import Flask, g +from collections import defaultdict + +from services.gateway.task.task_manager import TaskManager +from services.gateway.task.task import Task, TaskStatus +from services.gateway.task.history_manager import HistoryManager + + +# ==================== Fixtures ==================== + +@pytest.fixture +def app(): + app = Flask(__name__) + return app + + +@pytest.fixture +def task_manager(): + with patch('services.serverlessapi.serverless_api_service.ServerlessApiService') as mock_service: + mock_service.return_value = Mock() + manager = TaskManager( + max_active_tasks=100, + max_completed_tasks=50, + gpu_function_url="http://gpu-service" + ) + yield manager + + +@pytest.fixture +def history_manager(): + """独立的 HistoryManager 实例用于单元测试""" + return HistoryManager() + + +# ==================== 测试类 ==================== + +class TestHistoryByUserIndex: + """测试用户索引功能""" + + def test_history_by_user_initialization(self, task_manager): + """_history_by_user 初始化""" + assert hasattr(task_manager._history_manager, '_history_by_user') + assert isinstance(task_manager._history_manager._history_by_user, defaultdict) + assert len(task_manager._history_manager._history_by_user) == 0 + + def test_add_history_updates_index(self, task_manager, app): + """添加 history 时同步更新索引""" + with app.test_request_context(): + g.user_id = 'user-123' + + prompt_id = "prompt-abc" + history_item = { + "prompt": [1, prompt_id, {}, {}, []], + "outputs": {}, + "status": {"completed": True}, + "meta": {}, + "user_id": "user-123" + } + + with task_manager._lock: + task_manager._history_manager.history[prompt_id] = history_item + task_manager._history_manager._history_by_user["user-123"][prompt_id] = history_item + + assert prompt_id in task_manager._history_manager.history + assert prompt_id in task_manager._history_manager._history_by_user["user-123"] + assert task_manager._history_manager._history_by_user["user-123"][prompt_id] == history_item + + def test_get_history_uses_user_index(self, task_manager, app): + """get_history 使用用户索引""" + with app.test_request_context(): + g.user_id = 'user-1' + + with task_manager._lock: + for i in range(3): + prompt_id = f"prompt-user1-{i}" + history_item = { + "prompt": [i, prompt_id, {}, {}, []], + "outputs": {}, + "status": {"completed": True, "status_str": "success", "messages": []}, + "meta": {}, + "user_id": "user-1" + } + task_manager._history_manager.history[prompt_id] = history_item + task_manager._history_manager._history_by_user["user-1"][prompt_id] = history_item + + for i in range(5): + prompt_id = f"prompt-user2-{i}" + history_item = { + "prompt": [i, prompt_id, {}, {}, []], + "outputs": {}, + "status": {"completed": True, "status_str": "success", "messages": []}, + "meta": {}, + "user_id": "user-2" + } + task_manager._history_manager.history[prompt_id] = history_item + task_manager._history_manager._history_by_user["user-2"][prompt_id] = history_item + + result = task_manager.get_history() + + assert len(result) == 3 + for prompt_id in result: + assert "user1" in prompt_id + assert "user2" not in prompt_id + # user_id 字段保留(性能优化) + assert result[prompt_id]["user_id"] == "user-1" + + def test_delete_history_removes_from_index(self, task_manager, app): + """删除 history 时同步从索引中删除""" + with app.test_request_context(): + g.user_id = 'user-123' + + prompt_id = "prompt-to-delete" + history_item = { + "prompt": [1, prompt_id, {}, {}, []], + "outputs": {}, + "status": {"completed": True}, + "meta": {}, + "user_id": "user-123" + } + + with task_manager._lock: + task_manager._history_manager.history[prompt_id] = history_item + task_manager._history_manager._history_by_user["user-123"][prompt_id] = history_item + + with task_manager._lock: + task_manager._history_manager.history.pop(prompt_id, None) + task_manager._history_manager._history_by_user["user-123"].pop(prompt_id, None) + + assert prompt_id not in task_manager._history_manager.history + assert prompt_id not in task_manager._history_manager._history_by_user["user-123"] + + def test_multi_tenant_isolation(self, task_manager, app): + """多租户隔离""" + users = ["user-a", "user-b", "user-c"] + + with task_manager._lock: + for user in users: + for i in range(2): + prompt_id = f"prompt-{user}-{i}" + history_item = { + "prompt": [i, prompt_id, {}, {}, []], + "outputs": {}, + "status": {"completed": True, "status_str": "success", "messages": []}, + "meta": {}, + "user_id": user + } + task_manager._history_manager.history[prompt_id] = history_item + task_manager._history_manager._history_by_user[user][prompt_id] = history_item + + for user in users: + with app.test_request_context(): + g.user_id = user + result = task_manager.get_history() + + assert len(result) == 2 + for prompt_id in result: + assert user in prompt_id + + +class TestGetHistory: + """测试 get_history 方法""" + + def test_get_history_empty(self, task_manager, app): + """获取空 history""" + with app.test_request_context(): + g.user_id = 'user-empty' + result = task_manager.get_history() + assert result == {} + + def test_get_history_with_max_items(self, task_manager, app): + """max_items 参数""" + with app.test_request_context(): + g.user_id = 'user-test' + + with task_manager._lock: + for i in range(10): + prompt_id = f"prompt-{i}" + history_item = { + "prompt": [i, prompt_id, {}, {}, []], + "outputs": {}, + "status": {"completed": True, "status_str": "success", "messages": []}, + "meta": {}, + "user_id": "user-test" + } + task_manager._history_manager.history[prompt_id] = history_item + task_manager._history_manager._history_by_user["user-test"][prompt_id] = history_item + + result = task_manager.get_history(max_items=3) + assert len(result) == 3 + + def test_get_history_filters_uncompleted(self, task_manager, app): + """过滤未完成的任务""" + with app.test_request_context(): + g.user_id = 'user-test' + + with task_manager._lock: + for i in range(3): + prompt_id = f"prompt-completed-{i}" + history_item = { + "prompt": [i, prompt_id, {}, {}, []], + "outputs": {}, + "status": {"completed": True, "status_str": "success", "messages": []}, + "meta": {}, + "user_id": "user-test" + } + task_manager._history_manager.history[prompt_id] = history_item + task_manager._history_manager._history_by_user["user-test"][prompt_id] = history_item + + for i in range(2): + prompt_id = f"prompt-running-{i}" + history_item = { + "prompt": [i, prompt_id, {}, {}, []], + "outputs": {}, + "status": {"completed": False, "status_str": "running", "messages": []}, + "meta": {}, + "user_id": "user-test" + } + task_manager._history_manager.history[prompt_id] = history_item + task_manager._history_manager._history_by_user["user-test"][prompt_id] = history_item + + result = task_manager.get_history() + assert len(result) == 3 + for prompt_id in result: + assert "completed" in prompt_id + + def test_get_history_keeps_user_id_field(self, task_manager, app): + """返回的 history 中包含 user_id 字段(性能优化,不再复制)""" + with app.test_request_context(): + g.user_id = 'user-test' + + with task_manager._lock: + prompt_id = "prompt-123" + history_item = { + "prompt": [1, prompt_id, {}, {}, []], + "outputs": {}, + "status": {"completed": True, "status_str": "success", "messages": []}, + "meta": {}, + "user_id": "user-test" + } + task_manager._history_manager.history[prompt_id] = history_item + task_manager._history_manager._history_by_user["user-test"][prompt_id] = history_item + + result = task_manager.get_history() + + assert prompt_id in result + # ✅ 性能优化:不再去除 user_id(因为已经通过 user_id 过滤) + assert "user_id" in result[prompt_id] + assert result[prompt_id]["user_id"] == "user-test" + assert "prompt" in result[prompt_id] + assert "outputs" in result[prompt_id] + assert "status" in result[prompt_id] + + +class TestHistoryRaceCondition: + """测试 history 初始化的竞态条件修复""" + + def test_init_history_double_check_locking(self, task_manager, app): + """测试双重检查锁定防止重复初始化""" + with app.test_request_context(): + g.user_id = 'user-test' + + with patch.object(task_manager, '_start_polling'), \ + patch('services.gateway.task.utils.task_manager_util.TaskStatusBroadcaster'): + + prompt_body = { + "prompt": {"1": {"class_type": "Test"}} + } + + # 提交任务 + task_id = task_manager.submit_task(prompt_body, "client-123") + + # 第一次初始化 + message = { + "type": "execution_start", + "data": {"prompt_id": task_id, "timestamp": int(time.time() * 1000)} + } + task_manager._init_history_item(task_id, message) + + # 验证 history 已创建 + assert task_id in task_manager._history_manager.history + first_item = task_manager._history_manager.history[task_id] + + # 尝试第二次初始化(应该被跳过) + task_manager._init_history_item(task_id, message) + + # 验证 history 没有改变 + assert task_manager._history_manager.history[task_id] is first_item + + def test_init_history_with_placeholder(self, task_manager, app): + """测试占位符机制防止并发初始化""" + with app.test_request_context(): + g.user_id = 'user-test' + + with patch.object(task_manager, '_start_polling'), \ + patch('services.gateway.task.utils.task_manager_util.TaskStatusBroadcaster'): + + prompt_body = { + "prompt": {"1": {"class_type": "Test"}} + } + + task_id = task_manager.submit_task(prompt_body, "client-123") + + message = { + "type": "execution_start", + "data": {"prompt_id": task_id, "timestamp": int(time.time() * 1000)} + } + + # 初始化 + task_manager._init_history_item(task_id, message) + + # 验证最终的 history 不包含 _initializing 标记 + history_item = task_manager._history_manager.history.get(task_id) + assert history_item is not None + assert "_initializing" not in history_item or history_item.get("_initializing") is not True + assert "prompt" in history_item + assert "user_id" in history_item + + +class TestExecutedMessageHandling: + """测试 executed 消息处理的健壮性""" + + def test_executed_with_missing_history(self, task_manager, app): + """executed 消息到达时 history 不存在应延迟初始化""" + with app.test_request_context(): + g.user_id = 'user-test' + + with patch.object(task_manager, '_start_polling'), \ + patch('services.gateway.task.utils.task_manager_util.TaskStatusBroadcaster'): + + prompt_body = { + "prompt": {"1": {"class_type": "Test"}} + } + + task_id = task_manager.submit_task(prompt_body, "client-123") + + # 直接发送 executed 消息(跳过 execution_start) + message = { + "type": "executed", + "data": { + "prompt_id": task_id, + "node": "10", + "display_node": "10", + "output": { + "images": [ + {"filename": "test.png", "type": "output", "subfolder": ""} + ] + } + } + } + + # 处理消息 + task_manager.handle_message(task_id, message) + + # 验证 history 被延迟初始化 + assert task_id in task_manager._history_manager.history + history_item = task_manager._history_manager.history[task_id] + + # 验证包含必要字段 + assert "meta" in history_item + assert "outputs" in history_item + assert "user_id" in history_item + assert history_item["user_id"] == "user-test" + + # 验证 outputs 被正确添加 + assert "10" in history_item["outputs"] + assert "images" in history_item["outputs"]["10"] + assert len(history_item["outputs"]["10"]["images"]) == 1 + + def test_executed_message_validation(self, task_manager, app): + """测试 executed 消息的字段验证""" + with app.test_request_context(): + g.user_id = 'user-test' + + with patch.object(task_manager, '_start_polling'), \ + patch('services.gateway.task.utils.task_manager_util.TaskStatusBroadcaster'): + + prompt_body = { + "prompt": {"1": {"class_type": "Test"}}, + } + + task_id = task_manager.submit_task(prompt_body, "client-123") + + # 缺少 prompt_id 的消息 + message_no_prompt = { + "type": "executed", + "data": { + "node": "10", + "output": {} + } + } + + # 不应抛出异常 + task_manager.handle_message(task_id, message_no_prompt) + + # 缺少 node_id 的消息 + message_no_node = { + "type": "executed", + "data": { + "prompt_id": "some-prompt", + "output": {} + } + } + + # 不应抛出异常 + task_manager.handle_message(task_id, message_no_node) + + def test_executed_with_initializing_placeholder(self, task_manager, app): + """测试 executed 遇到初始化占位符时的处理""" + with app.test_request_context(): + g.user_id = 'user-test' + + with patch.object(task_manager, '_start_polling'), \ + patch('services.gateway.task.utils.task_manager_util.TaskStatusBroadcaster'): + + prompt_body = { + "prompt": {"1": {"class_type": "Test"}} + } + + task_id = task_manager.submit_task(prompt_body, "client-123") + + # 手动设置一个初始化占位符 + with task_manager._lock: + task_manager._history_manager.history[task_id] = { + "_initializing": True, + "user_id": "user-test" + } + + # 发送 executed 消息 + message = { + "type": "executed", + "data": { + "prompt_id": task_id, + "node": "10", + "output": {"images": []} + } + } + + # 应该跳过处理(因为还在初始化中) + task_manager.handle_message(task_id, message) + + # 验证占位符还在 + history_item = task_manager._history_manager.history.get(task_id) + assert history_item.get("_initializing") is True + + +class TestHistoryAtomicOperations: + """测试 history 原子操作""" + + def test_add_history_item_atomic(self, task_manager, app): + """测试 _add_history_item 的原子性""" + with app.test_request_context(): + g.user_id = 'user-test' + + history_item = { + "prompt": [1, "prompt-123", {}, {}, []], + "outputs": {}, + "status": {"completed": False}, + "meta": {}, + "user_id": "user-test" + } + + # 添加 + success = task_manager._history_manager.add_history_item("prompt-123", history_item) + assert success is True + + # 验证两个字典都更新了 + assert "prompt-123" in task_manager._history_manager.history + assert "prompt-123" in task_manager._history_manager._history_by_user["user-test"] + + # 尝试重复添加 + success = task_manager._history_manager.add_history_item("prompt-123", history_item) + assert success is False + + def test_add_history_item_missing_user_id(self, task_manager): + """测试缺少 user_id 时添加失败""" + history_item = { + "prompt": [1, "prompt-456", {}, {}, []], + "outputs": {}, + # 缺少 user_id + } + + success = task_manager._history_manager.add_history_item("prompt-456", history_item) + assert success is False + assert "prompt-456" not in task_manager._history_manager.history + + def test_remove_history_item_atomic(self, task_manager, app): + """测试 _remove_history_item 的原子性""" + with app.test_request_context(): + g.user_id = 'user-test' + + # 先添加 + history_item = { + "prompt": [1, "prompt-789", {}, {}, []], + "outputs": {}, + "status": {"completed": True}, + "meta": {}, + "user_id": "user-test" + } + + task_manager._history_manager.add_history_item("prompt-789", history_item) + + # 验证存在 + assert "prompt-789" in task_manager._history_manager.history + assert "prompt-789" in task_manager._history_manager._history_by_user["user-test"] + + # 删除 + success = task_manager._history_manager.remove_history_item("prompt-789") + assert success is True + + # 验证两个字典都删除了 + assert "prompt-789" not in task_manager._history_manager.history + assert "prompt-789" not in task_manager._history_manager._history_by_user["user-test"] + + def test_remove_nonexistent_history_item(self, task_manager): + """测试删除不存在的 history item""" + success = task_manager._history_manager.remove_history_item("nonexistent-prompt") + assert success is False + + def test_history_consistency_after_operations(self, task_manager, app): + """测试多次操作后 history 和索引保持一致""" + with app.test_request_context(): + g.user_id = 'user-alice' + + # 添加多个 history items + for i in range(5): + history_item = { + "prompt": [i, f"prompt-{i}", {}, {}, []], + "outputs": {}, + "status": {"completed": True}, + "meta": {}, + "user_id": "user-alice" + } + task_manager._history_manager.add_history_item(f"prompt-{i}", history_item) + + # 验证一致性 + assert len(task_manager._history_manager.history) >= 5 + assert len(task_manager._history_manager._history_by_user["user-alice"]) >= 5 + + # 删除一些 + task_manager._history_manager.remove_history_item("prompt-1") + task_manager._history_manager.remove_history_item("prompt-3") + + # 验证一致性 + assert "prompt-1" not in task_manager._history_manager.history + assert "prompt-1" not in task_manager._history_manager._history_by_user["user-alice"] + assert "prompt-3" not in task_manager._history_manager.history + assert "prompt-3" not in task_manager._history_manager._history_by_user["user-alice"] + + # 其他的还在 + assert "prompt-0" in task_manager._history_manager.history + assert "prompt-2" in task_manager._history_manager.history + assert "prompt-4" in task_manager._history_manager.history + + +class TestImageHandling: + """测试图片处理""" + + def test_multiple_images_added_directly(self, task_manager): + """多张图片直接添加""" + history_item = { + "prompt": [1, "prompt-123", {}, {}, []], + "outputs": {}, + "status": {"completed": False}, + "meta": {}, + "user_id": "user-test" + } + + node_id = "33" + history_item["outputs"][node_id] = {"images": []} + + images = [ + {"filename": "image1.png", "type": "output", "subfolder": ""}, + {"filename": "image2.png", "type": "output", "subfolder": ""}, + {"filename": "image3.png", "type": "output", "subfolder": ""} + ] + + for img in images: + image_item = { + "filename": img.get("filename", ""), + "type": img.get("type", "output"), + "subfolder": img.get("subfolder", "") + } + history_item["outputs"][node_id]["images"].append(image_item) + + assert len(history_item["outputs"][node_id]["images"]) == 3 + assert history_item["outputs"][node_id]["images"][0]["filename"] == "image1.png" + assert history_item["outputs"][node_id]["images"][1]["filename"] == "image2.png" + assert history_item["outputs"][node_id]["images"][2]["filename"] == "image3.png" + + +class TestDataConsistency: + """测试数据一致性""" + + def test_history_and_index_stay_in_sync(self, task_manager, app): + """history 和索引保持同步""" + with app.test_request_context(): + g.user_id = 'user-test' + + with task_manager._lock: + for i in range(5): + prompt_id = f"prompt-{i}" + history_item = { + "prompt": [i, prompt_id, {}, {}, []], + "outputs": {}, + "status": {"completed": True, "status_str": "success", "messages": []}, + "meta": {}, + "user_id": "user-test" + } + task_manager._history_manager.history[prompt_id] = history_item + task_manager._history_manager._history_by_user["user-test"][prompt_id] = history_item + + assert len(task_manager._history_manager.history) == 5 + assert len(task_manager._history_manager._history_by_user["user-test"]) == 5 + assert set(task_manager._history_manager.history.keys()) == set(task_manager._history_manager._history_by_user["user-test"].keys()) + + +class TestHistoryManagerUnit: + """HistoryManager 独立单元测试""" + + def test_helper_initialization(self, history_manager): + """测试 HistoryManager 初始化""" + assert isinstance(history_manager.history, dict) + assert isinstance(history_manager._history_by_user, defaultdict) + assert len(history_manager.history) == 0 + assert len(history_manager._history_by_user) == 0 + + def test_get_history_empty(self, history_manager): + """测试获取空历史""" + result = history_manager.get_history("user-test") + assert result == {} + + def test_add_and_get_history(self, history_manager): + """测试添加和获取历史""" + history_item = { + "prompt": [1, "prompt-1", {}, {}, []], + "outputs": {}, + "status": {"completed": True, "status_str": "success", "messages": []}, + "meta": {}, + "user_id": "user-alice" + } + + success = history_manager.add_history_item("prompt-1", history_item) + assert success is True + + result = history_manager.get_history("user-alice") + assert len(result) == 1 + assert "prompt-1" in result + + def test_remove_history(self, history_manager): + """测试删除历史""" + history_item = { + "prompt": [1, "prompt-1", {}, {}, []], + "outputs": {}, + "status": {"completed": True}, + "meta": {}, + "user_id": "user-bob" + } + + history_manager.add_history_item("prompt-1", history_item) + assert "prompt-1" in history_manager.history + + success = history_manager.remove_history_item("prompt-1") + assert success is True + assert "prompt-1" not in history_manager.history + assert "prompt-1" not in history_manager._history_by_user["user-bob"] + + def test_update_history_status(self, history_manager): + """测试更新历史状态""" + # 先添加一个历史项 + history_item = { + "prompt": [1, "prompt-1", {}, {}, []], + "outputs": {}, + "status": { + "completed": False, + "status_str": "running", + "messages": [] + }, + "meta": {}, + "user_id": "user-test" + } + + history_manager.add_history_item("prompt-1", history_item) + + # 更新为成功 + message = { + "type": "execution_success", + "data": { + "prompt_id": "prompt-1", + "timestamp": int(time.time() * 1000) + } + } + + success = history_manager.update_history_status(message, "success") + assert success is True + + # 验证状态已更新 + updated_item = history_manager.history["prompt-1"] + assert updated_item["status"]["completed"] is True + assert updated_item["status"]["status_str"] == "success" + + def test_update_history_outputs(self, history_manager): + """测试更新历史输出""" + # 先添加一个历史项 + history_item = { + "prompt": [1, "prompt-1", {}, {}, []], + "outputs": {}, + "status": {"completed": False}, + "meta": {}, + "user_id": "user-test" + } + + history_manager.add_history_item("prompt-1", history_item) + + # 更新输出 + message = { + "type": "executed", + "data": { + "prompt_id": "prompt-1", + "node": "10", + "display_node": "10", + "output": { + "images": [ + {"filename": "test.png", "type": "output", "subfolder": ""} + ] + } + } + } + + success = history_manager.update_history_outputs(message) + assert success is True + + # 验证输出已更新 + updated_item = history_manager.history["prompt-1"] + assert "10" in updated_item["meta"] + assert "10" in updated_item["outputs"] + assert len(updated_item["outputs"]["10"]["images"]) == 1 + + +class TestHistoryManagerEdgeCases: + """测试 HistoryManager 的边界情况和错误处理(提升覆盖率到 100%)""" + + def test_add_history_item_with_exception_in_assignment(self, history_manager): + """测试添加 history 时发生异常的回滚逻辑""" + # 模拟一个会导致赋值失败的情况 + history_item = { + "prompt": [1, "prompt-1", {}, {}, []], + "outputs": {}, + "status": {"completed": False}, + "meta": {}, + "user_id": "user-test" + } + + # 先添加成功 + success = history_manager.add_history_item("prompt-1", history_item) + assert success is True + + # 尝试重复添加(应该失败并返回 False) + success = history_manager.add_history_item("prompt-1", history_item) + assert success is False + assert "prompt-1" in history_manager.history + + def test_add_history_item_with_invalid_history_item(self, history_manager): + """测试 add_history_item 传入无效的 history_item(触发异常)""" + # 传入一个无法获取 user_id 的对象(不是字典) + class BadHistoryItem: + def get(self, key, default=None): + raise RuntimeError("Simulated error in get") + + bad_item = BadHistoryItem() + + # 应该捕获异常并返回 False + success = history_manager.add_history_item("prompt-bad", bad_item) + assert success is False + + def test_remove_history_item_with_corrupted_data(self, history_manager): + """测试 remove_history_item 处理损坏的数据""" + # 手动创建一个会导致问题的情况 + # 添加一个没有 user_id 的项(边界情况) + history_manager.history["prompt-corrupt"] = { + "prompt": [1, "prompt-corrupt", {}, {}, []], + # 故意不添加 user_id + } + + # 应该成功删除(因为代码中有 if user_id 检查) + success = history_manager.remove_history_item("prompt-corrupt") + assert success is True + assert "prompt-corrupt" not in history_manager.history + + def test_init_history_item_placeholder_modified(self, history_manager): + """测试初始化过程中占位符被修改的情况""" + # 手动设置一个占位符 + history_manager.history["prompt-1"] = {"_initializing": True, "user_id": "user-test"} + + # 尝试初始化(已存在,应该返回 False) + success = history_manager.init_history_item( + prompt_id="prompt-1", + prompt_body={"1": {"class_type": "Test"}}, + client_id="client-1", + user_id="user-test", + message={"type": "execution_start", "data": {"prompt_id": "prompt-1"}} + ) + assert success is False + + def test_init_history_item_placeholder_not_initializing(self, history_manager): + """测试构建过程中占位符状态被改变的情况""" + # 先手动设置一个占位符 + history_manager.history["prompt-new"] = {"_initializing": True, "user_id": "user-test"} + + # 在初始化过程中,手动改变占位符状态(模拟并发修改) + # 通过直接修改来模拟这种情况 + history_manager.history["prompt-new"]["_initializing"] = False + + # 由于我们无法真正并发测试,这里测试已存在的情况 + success = history_manager.init_history_item( + prompt_id="prompt-new", + prompt_body={"1": {"class_type": "Test"}}, + client_id="client-1", + user_id="user-test", + message={"type": "execution_start", "data": {"prompt_id": "prompt-new"}} + ) + + # 已存在,应该返回 False + assert success is False + + def test_init_history_item_build_exception(self, history_manager, monkeypatch): + """测试 _build_history_item 抛出异常时的清理逻辑""" + def mock_build_raising(*args, **kwargs): + raise ValueError("Build failed") + + monkeypatch.setattr(history_manager, "_build_history_item", mock_build_raising) + + success = history_manager.init_history_item( + prompt_id="prompt-fail", + prompt_body={"1": {"class_type": "Test"}}, + client_id="client-1", + user_id="user-test", + message={"type": "execution_start", "data": {"prompt_id": "prompt-fail"}} + ) + + assert success is False + # 验证占位符被清理 + assert "prompt-fail" not in history_manager.history + + def test_init_history_item_with_invalid_message(self, history_manager): + """测试 init_history_item 传入无效消息格式""" + # 传入 None 作为 message,会在 _build_history_item 中导致问题 + # 但由于有异常处理,应该返回 False + success = history_manager.init_history_item( + prompt_id="prompt-error", + prompt_body={}, + client_id="client-1", + user_id="user-test", + message=None # 无效的 message + ) + + # 应该捕获异常并清理占位符 + assert success is False + assert "prompt-error" not in history_manager.history + + def test_build_history_item_with_nested_prompt(self, history_manager): + """测试构建 history_item 时 prompt_body 包含嵌套 prompt 字段""" + prompt_body = { + "prompt": {"1": {"class_type": "Test"}}, + "outputs_to_execute": ["1", "2"] + } + + message = { + "type": "execution_start", + "data": {"prompt_id": "prompt-1", "timestamp": 1609459200000} + } + + history_item = history_manager._build_history_item( + prompt_id="prompt-1", + prompt_body=prompt_body, + client_id="client-1", + user_id="user-test", + message=message + ) + + # 验证 outputs_to_execute 被正确提取 + assert history_item["prompt"][4] == ["1", "2"] + assert history_item["prompt"][2] == {"1": {"class_type": "Test"}} + + def test_build_history_item_with_small_timestamp(self, history_manager): + """测试时间戳小于 10000000000 的情况(秒级时间戳)""" + message = { + "type": "execution_start", + "data": {"prompt_id": "prompt-1", "timestamp": 1609459200} # 秒级时间戳 + } + + history_item = history_manager._build_history_item( + prompt_id="prompt-1", + prompt_body={}, + client_id="client-1", + user_id="user-test", + message=message + ) + + # 验证时间戳被转换为毫秒 + messages = history_item["status"]["messages"] + assert messages[0][1]["timestamp"] == 1609459200000 + + def test_build_history_item_with_large_timestamp(self, history_manager): + """测试时间戳大于等于 10000000000 的情况(毫秒级时间戳)""" + message = { + "type": "execution_start", + "data": {"prompt_id": "prompt-1", "timestamp": 1609459200000} # 毫秒级 + } + + history_item = history_manager._build_history_item( + prompt_id="prompt-1", + prompt_body={}, + client_id="client-1", + user_id="user-test", + message=message + ) + + # 验证时间戳保持不变 + messages = history_item["status"]["messages"] + assert messages[0][1]["timestamp"] == 1609459200000 + + def test_update_history_status_missing_prompt_id(self, history_manager): + """测试 update_history_status 缺少 prompt_id""" + message = { + "type": "execution_success", + "data": {} # 缺少 prompt_id + } + + success = history_manager.update_history_status(message, "success") + assert success is False + + def test_update_history_status_history_not_found(self, history_manager): + """测试 update_history_status 找不到 history_item""" + message = { + "type": "execution_success", + "data": {"prompt_id": "nonexistent"} + } + + success = history_manager.update_history_status(message, "success") + assert success is False + + def test_update_history_status_create_status_if_missing(self, history_manager): + """测试当 status 字段不存在时自动创建""" + history_item = { + "prompt": [1, "prompt-1", {}, {}, []], + "outputs": {}, + "meta": {}, + "user_id": "user-test" + # 注意:没有 status 字段 + } + history_manager.add_history_item("prompt-1", history_item) + + message = { + "type": "execution_success", + "data": {"prompt_id": "prompt-1", "timestamp": 1609459200000} + } + + success = history_manager.update_history_status(message, "success") + assert success is True + assert "status" in history_manager.history["prompt-1"] + + def test_update_history_status_error_with_node_info(self, history_manager): + """测试 error 状态更新包含 node 信息""" + history_item = { + "prompt": [1, "prompt-1", {}, {}, []], + "outputs": {}, + "status": {"completed": False, "messages": []}, + "meta": {}, + "user_id": "user-test" + } + history_manager.add_history_item("prompt-1", history_item) + + message = { + "type": "execution_error", + "data": { + "prompt_id": "prompt-1", + "node_id": "10", + "exception_message": "Test error", + "timestamp": 1609459200000 + } + } + + success = history_manager.update_history_status(message, "error") + assert success is True + + status = history_manager.history["prompt-1"]["status"] + assert status["completed"] is True + assert status["status_str"] == "error" + assert any(msg[0] == "execution_error" for msg in status["messages"]) + + def test_update_history_status_running_with_execution_cached(self, history_manager): + """测试 running 状态且消息类型为 execution_cached""" + history_item = { + "prompt": [1, "prompt-1", {}, {}, []], + "outputs": {}, + "status": {"completed": False, "messages": []}, + "meta": {}, + "user_id": "user-test" + } + history_manager.add_history_item("prompt-1", history_item) + + message = { + "type": "execution_cached", + "data": {"prompt_id": "prompt-1", "timestamp": 1609459200000} + } + + success = history_manager.update_history_status(message, "running") + assert success is True + + status = history_manager.history["prompt-1"]["status"] + assert status["completed"] is False # 不改变 completed 标志 + assert any(msg[0] == "execution_cached" for msg in status["messages"]) + + def test_update_history_status_with_small_timestamp(self, history_manager): + """测试状态更新时小时间戳转换""" + history_item = { + "prompt": [1, "prompt-1", {}, {}, []], + "outputs": {}, + "status": {"completed": False, "messages": []}, + "meta": {}, + "user_id": "user-test" + } + history_manager.add_history_item("prompt-1", history_item) + + message = { + "type": "execution_success", + "data": {"prompt_id": "prompt-1", "timestamp": 1609459200} # 秒级 + } + + success = history_manager.update_history_status(message, "success") + assert success is True + + # 验证时间戳被转换为毫秒 + messages = history_manager.history["prompt-1"]["status"]["messages"] + for msg in messages: + if msg[0] == "execution_success": + assert msg[1]["timestamp"] == 1609459200000 + + def test_update_history_status_with_invalid_message(self, history_manager): + """测试 update_history_status 传入无效消息""" + history_item = { + "prompt": [1, "prompt-1", {}, {}, []], + "user_id": "user-test", + "status": {"completed": False, "messages": []} + } + history_manager.add_history_item("prompt-1", history_item) + + # 传入 None 作为 message + success = history_manager.update_history_status(None, "success") + # 会在 message.get 时失败,被异常捕获 + assert success is False + + def test_update_history_outputs_missing_prompt_id(self, history_manager): + """测试 update_history_outputs 缺少 prompt_id""" + message = { + "type": "executed", + "data": {"node": "10"} # 缺少 prompt_id + } + + success = history_manager.update_history_outputs(message) + assert success is False + + def test_update_history_outputs_missing_node_id(self, history_manager): + """测试 update_history_outputs 缺少 node_id""" + message = { + "type": "executed", + "data": {"prompt_id": "prompt-1"} # 缺少 node + } + + success = history_manager.update_history_outputs(message) + assert success is False + + def test_update_history_outputs_history_not_found(self, history_manager): + """测试 update_history_outputs 找不到 history_item""" + message = { + "type": "executed", + "data": {"prompt_id": "nonexistent", "node": "10"} + } + + success = history_manager.update_history_outputs(message) + assert success is False + + def test_update_history_outputs_initializing_placeholder(self, history_manager): + """测试 update_history_outputs 遇到初始化占位符""" + history_manager.history["prompt-1"] = { + "_initializing": True, + "user_id": "user-test" + } + + message = { + "type": "executed", + "data": {"prompt_id": "prompt-1", "node": "10", "output": {}} + } + + success = history_manager.update_history_outputs(message) + assert success is False + + def test_update_history_outputs_create_meta_if_missing(self, history_manager): + """测试当 meta 字段不存在时自动创建""" + history_item = { + "prompt": [1, "prompt-1", {}, {}, []], + "outputs": {}, + "status": {"completed": False}, + "user_id": "user-test" + # 注意:没有 meta 字段 + } + history_manager.add_history_item("prompt-1", history_item) + + message = { + "type": "executed", + "data": { + "prompt_id": "prompt-1", + "node": "10", + "display_node": "10", + "output": {} + } + } + + success = history_manager.update_history_outputs(message) + assert success is True + assert "meta" in history_manager.history["prompt-1"] + assert "10" in history_manager.history["prompt-1"]["meta"] + + def test_update_history_outputs_create_outputs_if_missing(self, history_manager): + """测试当 outputs 字段不存在时自动创建""" + history_item = { + "prompt": [1, "prompt-1", {}, {}, []], + "meta": {}, + "status": {"completed": False}, + "user_id": "user-test" + # 注意:没有 outputs 字段 + } + history_manager.add_history_item("prompt-1", history_item) + + message = { + "type": "executed", + "data": { + "prompt_id": "prompt-1", + "node": "10", + "output": { + "images": [{"filename": "test.png", "type": "output", "subfolder": ""}] + } + } + } + + success = history_manager.update_history_outputs(message) + assert success is True + assert "outputs" in history_manager.history["prompt-1"] + assert "10" in history_manager.history["prompt-1"]["outputs"] + + def test_update_history_outputs_with_invalid_message(self, history_manager): + """测试 update_history_outputs 传入无效消息""" + history_item = { + "prompt": [1, "prompt-1", {}, {}, []], + "user_id": "user-test", + "meta": {}, + "outputs": {} + } + history_manager.add_history_item("prompt-1", history_item) + + # 传入 None 作为 message + success = history_manager.update_history_outputs(None) + # 会在 message.get 时失败,被异常捕获 + assert success is False + + def test_late_init_history_item_already_exists(self, history_manager): + """测试 late_init_history_item 当 history 已存在""" + history_manager.history["prompt-1"] = { + "prompt": [1, "prompt-1", {}, {}, []], + "user_id": "user-test" + } + + success = history_manager.late_init_history_item( + task_id="task-1", + prompt_id="prompt-1", + prompt_body={}, + client_id="client-1", + user_id="user-test" + ) + + assert success is False + + def test_late_init_with_none_prompt_body(self, history_manager): + """测试 late_init_history_item 使用 None prompt_body""" + # prompt_body 为 None 是合法的(代码中有 prompt_body or {}) + success = history_manager.late_init_history_item( + task_id="task-1", + prompt_id="prompt-none", + prompt_body=None, + client_id="client-1", + user_id="user-test" + ) + + # 应该成功创建 + assert success is True + assert "prompt-none" in history_manager.history + # 验证 prompt_body 被替换为 {} + assert history_manager.history["prompt-none"]["prompt"][2] == {} + + +class TestHistoryManagerFullCoverage: + """额外测试用例以达到 100% 覆盖率""" + + def test_add_history_item_assignment_exception(self, history_manager): + """测试 add_history_item 内部赋值时的异常(触发 92-97 行)""" + # 创建一个特殊的对象,在被赋值时会抛出异常 + class RaisingDict(dict): + def __setitem__(self, key, value): + if key == "test-prompt": + raise RuntimeError("Assignment failed") + super().__setitem__(key, value) + + # 替换 history 为会抛出异常的字典 + history_manager.history = RaisingDict() + history_manager._history_by_user = defaultdict(dict) + + history_item = { + "prompt": [1, "test-prompt", {}, {}, []], + "user_id": "user-test", + "outputs": {}, + "status": {"completed": False}, + "meta": {} + } + + # 应该捕获异常并回滚 + success = history_manager.add_history_item("test-prompt", history_item) + assert success is False + # 验证回滚:history 中不应该有这个项 + assert "test-prompt" not in history_manager.history + + def test_remove_history_item_deletion_exception(self, history_manager): + """测试 remove_history_item 删除时的异常(触发 128-130 行)""" + # 创建一个会在 pop 时抛出异常的字典 + class RaisingDict(dict): + def pop(self, key, default=None): + raise RuntimeError("Pop failed") + + # 手动添加一个项并替换为会抛出异常的字典 + history_manager.history["test-prompt"] = {"user_id": "user-test"} + history_manager._history_by_user["user-test"]["test-prompt"] = {"user_id": "user-test"} + + # 保存原始引用 + original_history = history_manager.history + # 替换为会抛出异常的字典 + history_manager.history = RaisingDict(original_history) + + # 应该捕获异常 + success = history_manager.remove_history_item("test-prompt") + assert success is False + + def test_init_history_item_placeholder_warning(self, history_manager): + """测试 init_history_item 占位符被修改的警告(触发 176-177 行)""" + # 手动设置一个不是初始化占位符的项 + history_manager.history["test-prompt"] = { + "_initializing": False, # 已经不是初始化状态 + "user_id": "user-test" + } + + # 尝试初始化,会进入占位符检查逻辑,发现 _initializing 为 False + success = history_manager.init_history_item( + prompt_id="test-prompt", + prompt_body={"1": {"class_type": "Test"}}, + client_id="client-1", + user_id="user-test", + message={"type": "execution_start", "data": {"prompt_id": "test-prompt"}} + ) + + # 已存在,返回 False + assert success is False + + def test_init_history_item_check_fails_after_build(self, history_manager, monkeypatch): + """测试构建完成后占位符检查失败(触发 176-177 行)""" + # 模拟在构建过程中其他线程修改了占位符 + original_build = history_manager._build_history_item + + def mock_build_and_modify(*args, **kwargs): + # 在构建过程中,将占位符的 _initializing 设置为 False + if "test-prompt" in history_manager.history: + history_manager.history["test-prompt"]["_initializing"] = False + return original_build(*args, **kwargs) + + monkeypatch.setattr(history_manager, "_build_history_item", mock_build_and_modify) + + success = history_manager.init_history_item( + prompt_id="test-prompt", + prompt_body={"1": {"class_type": "Test"}}, + client_id="client-1", + user_id="user-test", + message={"type": "execution_start", "data": {"prompt_id": "test-prompt"}} + ) + + # 占位符被修改,应该返回 False + assert success is False + + def test_init_history_item_outer_exception_during_placeholder_set(self, history_manager): + """测试在设置占位符时发生异常,触发最外层异常处理(187-189行)""" + # 创建一个在设置占位符时会抛出异常的字典 + class RaisingDictOnSet(dict): + def __setitem__(self, key, value): + if key == "test-outer-exception": + raise RuntimeError("Failed to set placeholder") + super().__setitem__(key, value) + + history_manager.history = RaisingDictOnSet() + history_manager._history_by_user = defaultdict(dict) + + # 调用 init_history_item,在设置占位符时会抛出异常 + success = history_manager.init_history_item( + prompt_id="test-outer-exception", + prompt_body={"1": {"class_type": "Test"}}, + client_id="client-1", + user_id="user-test", + message={"type": "execution_start", "data": {"prompt_id": "test-outer-exception"}} + ) + + # 最外层异常处理应该捕获并返回 False + assert success is False + # 验证没有添加到 history + assert "test-outer-exception" not in history_manager.history + + def test_init_history_item_placeholder_cleanup(self, history_manager, monkeypatch): + """测试 init_history_item 构建失败后清理占位符(触发 187-189 行)""" + # 让 _build_history_item 抛出异常 + def mock_build_exception(*args, **kwargs): + raise ValueError("Build failed intentionally") + + monkeypatch.setattr(history_manager, "_build_history_item", mock_build_exception) + + success = history_manager.init_history_item( + prompt_id="test-build-fail", + prompt_body={"1": {"class_type": "Test"}}, + client_id="client-1", + user_id="user-test", + message={"type": "execution_start", "data": {"prompt_id": "test-build-fail"}} + ) + + # 应该失败并清理占位符 + assert success is False + assert "test-build-fail" not in history_manager.history + + def test_build_history_item_no_timestamp(self, history_manager): + """测试 _build_history_item 消息中没有 timestamp(触发 231 行)""" + message = { + "type": "execution_start", + "data": {"prompt_id": "test-prompt"} # 没有 timestamp + } + + history_item = history_manager._build_history_item( + prompt_id="test-prompt", + prompt_body={"1": {"class_type": "Test"}}, + client_id="client-1", + user_id="user-test", + message=message + ) + + # 验证使用了默认时间戳(当前时间) + messages = history_item["status"]["messages"] + assert messages[0][0] == "execution_start" + assert "timestamp" in messages[0][1] + assert messages[0][1]["timestamp"] > 0 + + def test_update_history_status_no_timestamp(self, history_manager): + """测试 update_history_status 消息中没有 timestamp(触发 294 行)""" + history_item = { + "prompt": [1, "test-prompt", {}, {}, []], + "user_id": "user-test", + "status": {"completed": False, "messages": []}, + "outputs": {}, + "meta": {} + } + history_manager.add_history_item("test-prompt", history_item) + + message = { + "type": "execution_success", + "data": {"prompt_id": "test-prompt"} # 没有 timestamp + } + + success = history_manager.update_history_status(message, "success") + assert success is True + + # 验证使用了默认时间戳 + status = history_manager.history["test-prompt"]["status"] + assert status["completed"] is True + success_msg = [msg for msg in status["messages"] if msg[0] == "execution_success"][0] + assert "timestamp" in success_msg[1] + assert success_msg[1]["timestamp"] > 0 + + def test_late_init_history_item_with_exception(self, history_manager): + """测试 late_init_history_item 发生异常(触发 440-442 行)""" + # 创建一个会在赋值时抛出异常的字典 + class RaisingDict(dict): + def __setitem__(self, key, value): + raise RuntimeError("Assignment failed") + + history_manager.history = RaisingDict() + history_manager._history_by_user = defaultdict(dict) + + success = history_manager.late_init_history_item( + task_id="task-1", + prompt_id="test-exception", + prompt_body={"1": {"class_type": "Test"}}, + client_id="client-1", + user_id="user-test" + ) + + # 应该捕获异常并返回 False + assert success is False + + +class TestHistoryManagerThreadSafety: + """测试 HistoryManager 的线程安全性(独立锁)""" + + def test_has_independent_lock(self, history_manager): + """验证 HistoryManager 有独立的锁""" + import threading + assert hasattr(history_manager, '_lock') + # 检查锁是否有 acquire 和 release 方法(鸭子类型检查) + assert hasattr(history_manager._lock, 'acquire') + assert hasattr(history_manager._lock, 'release') + assert callable(history_manager._lock.acquire) + assert callable(history_manager._lock.release) + + def test_concurrent_add_operations(self, history_manager): + """测试并发添加操作的线程安全性""" + import threading + results = [] + errors = [] + + def add_item(index): + try: + prompt_id = f"prompt-{index}" + history_item = { + "prompt": [index, prompt_id, {}, {}, []], + "user_id": f"user-{index % 5}", # 5个用户 + "status": {"completed": False, "messages": []}, + "outputs": {}, + "meta": {} + } + success = history_manager.add_history_item(prompt_id, history_item) + results.append((index, success)) + except Exception as e: + errors.append((index, str(e))) + + # 创建100个线程并发添加 + threads = [] + for i in range(100): + t = threading.Thread(target=add_item, args=(i,)) + threads.append(t) + t.start() + + # 等待所有线程完成 + for t in threads: + t.join() + + # 验证结果 + assert len(errors) == 0, f"应该没有错误,但发现: {errors}" + assert len(results) == 100 + assert all(success for _, success in results), "所有添加操作应该成功" + assert len(history_manager.history) == 100 + + # 验证 _history_by_user 索引正确 + for i in range(100): + user_id = f"user-{i % 5}" + prompt_id = f"prompt-{i}" + assert prompt_id in history_manager._history_by_user[user_id] + + def test_concurrent_read_write_operations(self, history_manager): + """测试并发读写操作的线程安全性""" + import threading + import random + + # 预先添加一些数据 + for i in range(10): + prompt_id = f"prompt-{i}" + history_item = { + "prompt": [i, prompt_id, {}, {}, []], + "user_id": f"user-{i % 3}", + "status": {"completed": False, "messages": []}, + "outputs": {}, + "meta": {} + } + history_manager.add_history_item(prompt_id, history_item) + + errors = [] + read_results = [] + write_results = [] + + def reader(user_id, iterations): + try: + for _ in range(iterations): + result = history_manager.get_history(user_id) + read_results.append(len(result)) + time.sleep(0.001) + except Exception as e: + errors.append(("read", str(e))) + + def writer(start_index, count): + try: + for i in range(start_index, start_index + count): + prompt_id = f"prompt-new-{i}" + history_item = { + "prompt": [i, prompt_id, {}, {}, []], + "user_id": f"user-{i % 3}", + "status": {"completed": True, "messages": []}, + "outputs": {}, + "meta": {} + } + success = history_manager.add_history_item(prompt_id, history_item) + write_results.append(success) + time.sleep(0.001) + except Exception as e: + errors.append(("write", str(e))) + + # 创建混合读写线程 + threads = [] + + # 10个读线程 + for i in range(10): + t = threading.Thread(target=reader, args=(f"user-{i % 3}", 10)) + threads.append(t) + + # 5个写线程 + for i in range(5): + t = threading.Thread(target=writer, args=(100 + i * 10, 10)) + threads.append(t) + + # 随机启动线程 + random.shuffle(threads) + for t in threads: + t.start() + + # 等待完成 + for t in threads: + t.join() + + # 验证没有错误 + assert len(errors) == 0, f"并发读写不应该有错误: {errors}" + assert len(read_results) == 100 # 10个线程 * 10次迭代 + assert len(write_results) == 50 # 5个线程 * 10次写入 + assert all(write_results), "所有写入操作应该成功" + + def test_concurrent_remove_operations(self, history_manager): + """测试并发删除操作的线程安全性""" + import threading + + # 预先添加数据 + for i in range(50): + prompt_id = f"prompt-{i}" + history_item = { + "prompt": [i, prompt_id, {}, {}, []], + "user_id": f"user-{i % 5}", + "status": {"completed": True, "messages": []}, + "outputs": {}, + "meta": {} + } + history_manager.add_history_item(prompt_id, history_item) + + assert len(history_manager.history) == 50 + + remove_results = [] + errors = [] + + def remover(index): + try: + prompt_id = f"prompt-{index}" + success = history_manager.remove_history_item(prompt_id) + remove_results.append((index, success)) + except Exception as e: + errors.append((index, str(e))) + + # 创建50个线程并发删除 + threads = [] + for i in range(50): + t = threading.Thread(target=remover, args=(i,)) + threads.append(t) + t.start() + + for t in threads: + t.join() + + # 验证结果 + assert len(errors) == 0, f"删除操作不应该有错误: {errors}" + assert len(remove_results) == 50 + assert all(success for _, success in remove_results) + assert len(history_manager.history) == 0 + + # 验证 _history_by_user 也被清空 + for user_id in range(5): + assert len(history_manager._history_by_user[f"user-{user_id}"]) == 0 + + def test_concurrent_update_operations(self, history_manager): + """测试并发更新操作的线程安全性""" + import threading + + # 预先添加一些数据 + for i in range(10): + prompt_id = f"prompt-{i}" + history_item = { + "prompt": [i, prompt_id, {}, {}, []], + "user_id": f"user-{i % 3}", + "status": {"completed": False, "messages": []}, + "outputs": {}, + "meta": {} + } + history_manager.add_history_item(prompt_id, history_item) + + errors = [] + update_results = [] + + def updater(prompt_id, iterations): + try: + for j in range(iterations): + message = { + "type": "execution_success", + "data": { + "prompt_id": prompt_id, + "timestamp": int(time.time() * 1000) + } + } + success = history_manager.update_history_status(message, "success") + update_results.append((prompt_id, j, success)) + time.sleep(0.001) + except Exception as e: + errors.append((prompt_id, str(e))) + + # 创建多个线程更新同一个 history_item + threads = [] + for i in range(10): + # 每个 prompt 被多个线程同时更新 + t = threading.Thread(target=updater, args=(f"prompt-{i}", 5)) + threads.append(t) + + for t in threads: + t.start() + + for t in threads: + t.join() + + # 验证没有错误 + assert len(errors) == 0, f"更新操作不应该有错误: {errors}" + assert len(update_results) == 50 # 10个线程 * 5次迭代 + + # 验证所有 history_item 都被标记为完成 + for i in range(10): + prompt_id = f"prompt-{i}" + history_item = history_manager.history[prompt_id] + assert history_item["status"]["completed"] is True + + def test_get_history_item_thread_safe(self, history_manager): + """测试 get_history_item 方法的线程安全性""" + import threading + + # 添加测试数据 + history_item = { + "prompt": [1, "test-prompt", {}, {}, []], + "user_id": "user-test", + "status": {"completed": False, "messages": []}, + "outputs": {}, + "meta": {} + } + history_manager.add_history_item("test-prompt", history_item) + + results = [] + errors = [] + + def getter(iterations): + try: + for _ in range(iterations): + item = history_manager.get_history_item("test-prompt") + results.append(item is not None) + time.sleep(0.001) + except Exception as e: + errors.append(str(e)) + + # 创建多个读线程 + threads = [] + for _ in range(20): + t = threading.Thread(target=getter, args=(10,)) + threads.append(t) + t.start() + + for t in threads: + t.join() + + assert len(errors) == 0, f"get_history_item 不应该有错误: {errors}" + assert len(results) == 200 # 20个线程 * 10次迭代 + assert all(results), "所有读取都应该成功" diff --git a/src/code/agent/test/unit/services/serverless_api_test.py b/src/code/agent/test/unit/services/serverless_api_test.py index 88ef4c48..8adf834c 100644 --- a/src/code/agent/test/unit/services/serverless_api_test.py +++ b/src/code/agent/test/unit/services/serverless_api_test.py @@ -364,7 +364,7 @@ def test_run_with_none_values_in_extra_data(service, app): # 验证 None 值被正确传递给 ComfyUI assert captured_request['extra_data']['session_id'] is None assert captured_request['extra_data']['metadata'] is None - assert captured_request['extra_data']['user_id'] == "user123" + assert captured_request['extra_data']['x-art-comfy-user'] == "user123" def test_run_with_nested_extra_data(service, app): diff --git a/src/code/agent/test/unit/services/task_manager_test.py b/src/code/agent/test/unit/services/task_manager_test.py index adb38eae..9346f875 100644 --- a/src/code/agent/test/unit/services/task_manager_test.py +++ b/src/code/agent/test/unit/services/task_manager_test.py @@ -1,9 +1,12 @@ import pytest +import time from unittest.mock import Mock, patch -from flask import Flask +from flask import Flask, g +from collections import defaultdict import requests from services.gateway.task.task_manager import TaskManager +from services.gateway.task.task import Task, TaskStatus from exceptions.exceptions import ConfigurationError, InvalidRequestError, WorkerExecutionError @@ -19,7 +22,11 @@ def app(): def task_manager(): with patch('services.serverlessapi.serverless_api_service.ServerlessApiService') as mock_service: mock_service.return_value = Mock() - manager = TaskManager(gpu_function_url="http://gpu-service") + manager = TaskManager( + max_active_tasks=100, + max_completed_tasks=50, + gpu_function_url="http://gpu-service" + ) yield manager @@ -223,3 +230,261 @@ def test_other_request_exception(self, task_manager, app): assert exc_info.value.code == 500 assert 'Failed to send sync request to GPU' in str(exc_info.value) +class TestRunningTaskCount: + """测试任务计数""" + + def test_running_count_by_user_initialization(self, task_manager): + """_running_count_by_user 初始化""" + assert hasattr(task_manager, '_running_count_by_user') + assert isinstance(task_manager._running_count_by_user, defaultdict) + + def test_get_running_task_count_by_user_empty(self, task_manager, app): + """空队列的运行计数""" + with app.test_request_context(): + g.user_id = 'user-test' + count = task_manager.get_running_task_count_by_user(g.user_id) + assert count == 0 + + def test_submit_task_increases_running_count(self, task_manager, app): + """提交任务增加运行计数""" + with app.test_request_context(): + g.user_id = 'user-test' + + with patch.object(task_manager, '_start_polling'), \ + patch('services.gateway.task.utils.task_manager_util.TaskStatusBroadcaster'): + + prompt_body = { + "prompt": {"1": {"class_type": "Test"}}, + "extra_data": {"client_id": "test-client"} + } + + task_manager.submit_task(prompt_body, "client-123") + + count = task_manager.get_running_task_count_by_user(g.user_id) + assert count == 1 + + +class TestUserAuthentication: + """测试用户认证相关功能(问题 1.1 修复验证)""" + + def test_multi_user_isolation_with_auth(self, task_manager, app): + """验证多用户隔离(修复后不再使用 'default')""" + with patch.object(task_manager, '_start_polling'), \ + patch('services.gateway.task.utils.task_manager_util.TaskStatusBroadcaster'): + + prompt_body = { + "prompt": {"1": {"class_type": "Test"}}, + "extra_data": {"client_id": "test-client"} + } + + # 用户1提交任务 + with app.test_request_context(): + g.user_id = 'user-alice' + task_id_1 = task_manager.submit_task(prompt_body, "client-1") + count_alice = task_manager.get_running_task_count_by_user(g.user_id) + + # 用户2提交任务 + with app.test_request_context(): + g.user_id = 'user-bob' + task_id_2 = task_manager.submit_task(prompt_body, "client-2") + count_bob = task_manager.get_running_task_count_by_user(g.user_id) + + # 验证每个用户只能看到自己的任务 + assert count_alice == 1 + assert count_bob == 1 + + # 用户1再次查询,应该还是1 + with app.test_request_context(): + g.user_id = 'user-alice' + count_alice_again = task_manager.get_running_task_count_by_user(g.user_id) + assert count_alice_again == 1 + + +class TestCancelTaskRaceCondition: + """测试取消任务的竞态条件修复(问题 1.2)""" + + def test_cancel_task_updates_count_correctly(self, task_manager, app): + """取消任务时正确更新计数器""" + with app.test_request_context(): + g.user_id = 'user-test' + + with patch.object(task_manager, '_start_polling'), \ + patch('services.gateway.task.utils.task_manager_util.TaskStatusBroadcaster'): + + prompt_body = { + "prompt": {"1": {"class_type": "Test"}}, + "extra_data": {"client_id": "test-client"} + } + + # 提交任务 + task_id = task_manager.submit_task(prompt_body, "client-123") + + # 验证计数增加 + count_before = task_manager.get_running_task_count_by_user(g.user_id) + assert count_before == 1 + + # 取消任务 + with patch.object(task_manager, '_stop_polling'): + cancelled = task_manager.cancel_task(task_id) + + assert cancelled is True + + # 验证计数减少 + count_after = task_manager.get_running_task_count_by_user(g.user_id) + assert count_after == 0 + + def test_cancel_completed_task_fails(self, task_manager, app): + """无法取消已完成的任务""" + with app.test_request_context(): + g.user_id = 'user-test' + + with patch.object(task_manager, '_start_polling'), \ + patch('services.gateway.task.utils.task_manager_util.TaskStatusBroadcaster'): + + prompt_body = { + "prompt": {"1": {"class_type": "Test"}}, + "extra_data": {"client_id": "test-client"} + } + + # 提交任务 + task_id = task_manager.submit_task(prompt_body, "client-123") + + # 模拟任务完成 + task = task_manager.get_task(task_id) + task.update_status(TaskStatus.RUNNING) + task.update_status(TaskStatus.COMPLETED) + + # 尝试取消已完成的任务 + with patch.object(task_manager, '_stop_polling'): + cancelled = task_manager.cancel_task(task_id) + + # 应该失败 + assert cancelled is False + + def test_cancel_nonexistent_task(self, task_manager, app): + """取消不存在的任务应返回 False""" + with app.test_request_context(): + g.user_id = 'user-test' + + with patch.object(task_manager, '_stop_polling'): + cancelled = task_manager.cancel_task("nonexistent-task-id") + + assert cancelled is False + + def test_cancel_task_stops_polling(self, task_manager, app): + """取消任务时应停止轮询""" + with app.test_request_context(): + g.user_id = 'user-test' + + with patch.object(task_manager, '_start_polling'), \ + patch('services.gateway.task.utils.task_manager_util.TaskStatusBroadcaster'): + + prompt_body = { + "prompt": {"1": {"class_type": "Test"}}, + "extra_data": {"client_id": "test-client"} + } + + # 提交任务 + task_id = task_manager.submit_task(prompt_body, "client-123") + + # 取消任务,验证 _stop_polling 被调用 + with patch.object(task_manager, '_stop_polling') as mock_stop_polling: + task_manager.cancel_task(task_id) + + # 验证停止轮询被调用 + mock_stop_polling.assert_called_once_with(task_id) + + def test_concurrent_cancel_and_complete(self, task_manager, app): + """模拟并发场景:取消和完成同时发生""" + with app.test_request_context(): + g.user_id = 'user-test' + + with patch.object(task_manager, '_start_polling'), \ + patch('services.gateway.task.utils.task_manager_util.TaskStatusBroadcaster'): + + prompt_body = { + "prompt": {"1": {"class_type": "Test"}}, + "extra_data": {"client_id": "test-client"} + } + + # 提交任务 + task_id = task_manager.submit_task(prompt_body, "client-123") + + initial_count = task_manager.get_running_task_count_by_user(g.user_id) + assert initial_count == 1 + + # 模拟任务开始运行 + task = task_manager.get_task(task_id) + task.update_status(TaskStatus.RUNNING) + + # 尝试取消(在锁内完成所有操作) + with patch.object(task_manager, '_stop_polling'): + cancelled = task_manager.cancel_task(task_id) + + # 验证取消成功 + assert cancelled is True + + # 验证计数正确 + final_count = task_manager.get_running_task_count_by_user(g.user_id) + assert final_count == 0 + + # 验证任务已被删除 + task_after = task_manager.get_task(task_id) + assert task_after is None +class TestPerformance: + """测试性能""" + + def test_get_history_with_many_users(self, task_manager, app): + """多用户场景性能测试""" + with task_manager._lock: + for user_idx in range(100): + user_id = f"user-{user_idx}" + for i in range(10): + prompt_id = f"prompt-{user_id}-{i}" + history_item = { + "prompt": [i, prompt_id, {}, {}, []], + "outputs": {}, + "status": {"completed": True, "status_str": "success", "messages": []}, + "meta": {}, + "user_id": user_id + } + task_manager._history_manager.history[prompt_id] = history_item + task_manager._history_manager._history_by_user[user_id][prompt_id] = history_item + + assert len(task_manager._history_manager.history) == 1000 + + with app.test_request_context(): + g.user_id = 'user-50' + + start_time = time.time() + result = task_manager.get_history() + elapsed = time.time() - start_time + + assert len(result) == 10 + assert elapsed < 0.05 + + +class TestDataConsistency: + """测试数据一致性""" + + def test_history_and_index_stay_in_sync(self, task_manager, app): + """history 和索引保持同步""" + with app.test_request_context(): + g.user_id = 'user-test' + + with task_manager._lock: + for i in range(5): + prompt_id = f"prompt-{i}" + history_item = { + "prompt": [i, prompt_id, {}, {}, []], + "outputs": {}, + "status": {"completed": True, "status_str": "success", "messages": []}, + "meta": {}, + "user_id": "user-test" + } + task_manager._history_manager.history[prompt_id] = history_item + task_manager._history_manager._history_by_user["user-test"][prompt_id] = history_item + + assert len(task_manager._history_manager.history) == 5 + assert len(task_manager._history_manager._history_by_user["user-test"]) == 5 + assert set(task_manager._history_manager.history.keys()) == set(task_manager._history_manager._history_by_user["user-test"].keys()) diff --git a/src/code/agent/test/unit/utils/user_identity_test.py b/src/code/agent/test/unit/utils/user_identity_test.py new file mode 100644 index 00000000..4164cf79 --- /dev/null +++ b/src/code/agent/test/unit/utils/user_identity_test.py @@ -0,0 +1,235 @@ +""" +测试 utils.user_identity 模块中的用户身份识别功能 +""" +import base64 +import pytest +from unittest.mock import Mock, patch +from flask import Flask, g + +from utils.user_identity import ( + extract_user_from_basic_auth, + extract_user_from_header, + identify_user_or_default, + set_user_identity_or_default +) + + +@pytest.fixture +def app(): + """创建测试用的 Flask 应用""" + app = Flask(__name__) + app.config['TESTING'] = True + return app + + +class TestExtractUserFromBasicAuth: + """测试从 Basic Auth 提取用户名""" + + def test_valid_basic_auth(self, app): + """测试有效的 Basic Auth - 提取用户名""" + username = 'user-h30ua81' + password = 'QRpT5pPjXj@D6u%R' + credentials = f"{username}:{password}" + encoded = base64.b64encode(credentials.encode('utf-8')).decode('utf-8') + auth_header = f"Basic {encoded}" + + with app.test_request_context(headers={'Authorization': auth_header}): + result = extract_user_from_basic_auth() + assert result == username + + def test_no_authorization_header(self, app): + """测试无 Authorization header""" + with app.test_request_context(): + result = extract_user_from_basic_auth() + assert result is None + + def test_bearer_token_not_basic(self, app): + """测试 Bearer Token (不是 Basic Auth)""" + with app.test_request_context(headers={'Authorization': 'Bearer some-jwt-token'}): + result = extract_user_from_basic_auth() + assert result is None + + def test_invalid_base64(self, app): + """测试无效的 Base64 编码""" + with app.test_request_context(headers={'Authorization': 'Basic invalid-base64!!!'}): + result = extract_user_from_basic_auth() + assert result is None + + def test_no_colon_in_credentials(self, app): + """测试格式错误 - 缺少冒号分隔符""" + invalid_credentials = 'just-username' + encoded = base64.b64encode(invalid_credentials.encode('utf-8')).decode('utf-8') + auth_header = f"Basic {encoded}" + + with app.test_request_context(headers={'Authorization': auth_header}): + result = extract_user_from_basic_auth() + assert result is None + + def test_empty_username(self, app): + """测试空用户名""" + credentials = ':password' + encoded = base64.b64encode(credentials.encode('utf-8')).decode('utf-8') + auth_header = f"Basic {encoded}" + + with app.test_request_context(headers={'Authorization': auth_header}): + result = extract_user_from_basic_auth() + assert result is None + + def test_unexpected_exception(self, app): + """测试未预期的异常""" + from unittest.mock import patch + + auth_header = "Basic dGVzdDp0ZXN0" # test:test + + with app.test_request_context(headers={'Authorization': auth_header}): + # Patch 被测模块命名空间中的 base64.b64decode + with patch('utils.user_identity.base64.b64decode', side_effect=RuntimeError("Unexpected error")): + result = extract_user_from_basic_auth() + assert result is None + + +class TestExtractUserFromHeader: + """测试从 header 提取用户信息""" + + def test_multi_tenant_disabled(self, app): + """测试多租户模式关闭 - extract_user_from_header 仍然提取信息""" + username = 'user-test' + + with patch('utils.user_identity.constants') as mock_constants: + mock_constants.ENABLE_COMFYUI_MULTI_USER = False + mock_constants.HEADER_FUNART_COMFY_USERID = 'X-FunArt-Comfy-UserId' + + headers = {'X-FunArt-Comfy-UserId': username} + + with app.test_request_context(headers=headers): + result = extract_user_from_header() + # extract_user_from_header 不关心配置,只负责提取 + assert result == username + + def test_basic_auth_authentication(self, app): + """测试使用 Basic Auth 认证""" + username = 'user-h30ua81' + password = 'QRpT5pPjXj@D6u%R' + credentials = f"{username}:{password}" + encoded = base64.b64encode(credentials.encode('utf-8')).decode('utf-8') + + headers = { + 'Authorization': f"Basic {encoded}" + } + + with app.test_request_context(headers=headers): + result = extract_user_from_header() + assert result == username + + def test_jwt_authentication(self, app): + """测试使用 JWT 认证""" + username = 'user-jwt-only' + + headers = { + 'X-FunArt-Comfy-UserId': username + } + + with app.test_request_context(headers=headers): + result = extract_user_from_header() + assert result == username + + def test_no_valid_auth_returns_none(self, app): + """测试无有效认证信息时返回 None""" + # 场景1: 完全没有认证 header + with app.test_request_context(): + result = extract_user_from_header() + assert result is None + + # 场景2: JWT header 为空白字符串 + headers = {'X-FunArt-Comfy-UserId': ' '} + with app.test_request_context(headers=headers): + result = extract_user_from_header() + assert result is None + + +class TestSetUserIdentityOrDefault: + """测试 set_user_identity_or_default 函数""" + + def test_multi_tenant_with_valid_user(self, app): + """测试多租户模式下有效用户""" + username = 'user-test' + + with patch('utils.user_identity.constants') as mock_constants: + mock_constants.ENABLE_COMFYUI_MULTI_USER = True + mock_constants.HEADER_FUNART_COMFY_USERID = 'X-FunArt-Comfy-UserId' + + headers = {'X-FunArt-Comfy-UserId': username} + + with app.test_request_context(headers=headers): + set_user_identity_or_default() + assert g.user_id == username + + def test_multi_tenant_without_user_fallback_to_default(self, app): + """测试多租户模式下无用户信息时降级到 default""" + with patch('utils.user_identity.constants') as mock_constants: + mock_constants.ENABLE_COMFYUI_MULTI_USER = True + mock_constants.HEADER_FUNART_COMFY_USERID = 'X-FunArt-Comfy-UserId' + mock_constants.DEFAULT_USER_ID = 'default' + + with app.test_request_context(): + set_user_identity_or_default() + assert g.user_id == 'default' + + def test_single_tenant_always_default(self, app): + """测试单租户模式下总是使用 default""" + with patch('utils.user_identity.constants') as mock_constants: + mock_constants.ENABLE_COMFYUI_MULTI_USER = False + mock_constants.DEFAULT_USER_ID = 'default' + + # 即使有用户信息也使用 default + headers = {'X-FunArt-Comfy-UserId': 'some-user'} + with app.test_request_context(headers=headers): + set_user_identity_or_default() + assert g.user_id == 'default' + + +class TestIdentifyUserOrDefaultDecorator: + """测试 identify_user_or_default 装饰器""" + + def test_identify_user_or_default_with_valid_user(self, app): + """测试有效用户使用 identify_user_or_default 装饰器""" + username = 'user-optional' + + @identify_user_or_default + def test_view(): + return f"User: {g.user_id}" + + with patch('utils.user_identity.constants') as mock_constants: + mock_constants.ENABLE_COMFYUI_MULTI_USER = True + mock_constants.HEADER_FUNART_COMFY_USERID = 'X-FunArt-Comfy-UserId' + + headers = {'X-FunArt-Comfy-UserId': username} + + with app.test_request_context(headers=headers): + result = test_view() + assert result == f"User: {username}" + assert g.user_id == username + + def test_identify_user_or_default_fallback_scenarios(self, app): + """测试 identify_user_or_default 在各种降级场景下的行为""" + @identify_user_or_default + def test_view(): + return f"User: {g.user_id}" + + with patch('utils.user_identity.constants') as mock_constants: + # 场景1: 多租户模式,无认证信息 + mock_constants.ENABLE_COMFYUI_MULTI_USER = True + mock_constants.HEADER_FUNART_COMFY_USERID = 'X-FunArt-Comfy-UserId' + mock_constants.DEFAULT_USER_ID = 'default' + with app.test_request_context(): + result = test_view() + assert result == "User: default" + assert g.user_id == 'default' + + # 场景2: 单租户模式,忽略认证信息 + mock_constants.ENABLE_COMFYUI_MULTI_USER = False + headers = {'X-FunArt-Comfy-UserId': 'some-user'} + with app.test_request_context(headers=headers): + result = test_view() + assert result == "User: default" + assert g.user_id == 'default' diff --git a/src/code/agent/utils/user_identity.py b/src/code/agent/utils/user_identity.py new file mode 100644 index 00000000..2c3ff42d --- /dev/null +++ b/src/code/agent/utils/user_identity.py @@ -0,0 +1,124 @@ +""" +用户身份识别模块 + +支持多种方式提取用户信息(认证由网关层负责): +1. JWT 方式:从 X-FunArt-Comfy-UserId header 提取(网关解析 JWT 后注入) +2. Basic Auth 方式:从 Authorization: Basic header 中解析 username + +通过 ENABLE_COMFYUI_MULTI_USER 环境变量控制: +- true: 启用多租户模式,支持上述两种方式提取用户信息 +- false: 单租户模式,所有用户为 'default' +""" + +import base64 +from typing import Optional +from functools import wraps +from flask import request, g, abort + +import constants +from utils.logger import log + + +# 用户身份识别相关常量 +AUTHORIZATION_HEADER = 'Authorization' +BASIC_AUTH_PREFIX = 'Basic ' +DEFAULT_USER_ID = 'default' + + +def extract_user_from_basic_auth() -> Optional[str]: + """ + 从 Authorization: Basic header 中提取用户名 + + Returns: + Optional[str]: 用户ID,解析失败时返回 None + """ + auth_header = request.headers.get(AUTHORIZATION_HEADER, '') + + if not auth_header.startswith(BASIC_AUTH_PREFIX): + return None + + try: + # 解码 Base64 + encoded_credentials = auth_header[len(BASIC_AUTH_PREFIX):] # 去掉 'Basic ' 前缀 + decoded_credentials = base64.b64decode(encoded_credentials).decode('utf-8') + + # 格式: username:password + if ':' not in decoded_credentials: + log("WARNING", "Basic Auth format error: missing colon separator") + return None + + username, _ = decoded_credentials.split(':', 1) + username = username.strip() + + if not username: + log("WARNING", "Basic Auth format error: empty username") + return None + + return username + + except (ValueError, UnicodeDecodeError, AttributeError) as e: + log("WARNING", f"Basic Auth parse error: {type(e).__name__}: {e}") + return None + except Exception as e: + log("ERROR", f"Unexpected error in extract_user_from_basic_auth: {type(e).__name__}: {e}") + return None + + +def extract_user_from_header() -> Optional[str]: + """ + 提取用户ID,支持两种认证方式 + + - JWT 方式:从 X-FunArt-Comfy-UserId header 获取 + - Basic Auth 方式:从 Authorization: Basic header 解析 + + Returns: + Optional[str]: 用户ID,无认证信息时返回 None + """ + jwt_user = request.headers.get(constants.HEADER_FUNART_COMFY_USERID, '').strip() + if jwt_user: + log("DEBUG", f"User extracted from JWT: {jwt_user}") + return jwt_user + + basic_user = extract_user_from_basic_auth() + if basic_user is not None: + log("DEBUG", f"User extracted from Basic Auth: {basic_user}") + return basic_user + + # 无认证信息 + log("DEBUG", "No user authentication info found in request headers") + return None + + +def set_user_identity_or_default(): + """ + 设置用户身份到 flask.g.user_id,如果无法识别则降级到默认用户 + + 在多租户模式下: + - 尝试从请求中识别用户身份 + - 如果无法识别,降级为 'default' 用户 + + 在单租户模式下: + - 所有用户统一为 'default' + + 使用场景: + - 作为中间件在 before_request 中调用 + - 作为装饰器的内部实现 + """ + if not constants.ENABLE_COMFYUI_MULTI_USER: + g.user_id = DEFAULT_USER_ID + else: + uid = extract_user_from_header() + g.user_id = uid if uid is not None else DEFAULT_USER_ID + + +def identify_user_or_default(func): + """ + 装饰器:识别用户身份,如果无法识别则降级到默认用户 + + 这是 set_user_identity_or_default() 的装饰器版本 + """ + @wraps(func) + def decorated_function(*args, **kwargs): + set_user_identity_or_default() + return func(*args, **kwargs) + return decorated_function diff --git a/src/code/comfyui/Dockerfile b/src/code/comfyui/Dockerfile index 4c7ec57c..01719bff 100644 --- a/src/code/comfyui/Dockerfile +++ b/src/code/comfyui/Dockerfile @@ -14,6 +14,8 @@ RUN cd ${COMFYUI_DIR}/custom_nodes && \ cd ComfyUI-Manager && \ git checkout "3.39" +COPY custom_nodes/ ${COMFYUI_DIR}/custom_nodes/ + FROM python:3.10.16-slim AS dependencies WORKDIR /root diff --git a/src/code/comfyui/Dockerfile.deepgpu b/src/code/comfyui/Dockerfile.deepgpu index baabef7d..cf1770fa 100644 --- a/src/code/comfyui/Dockerfile.deepgpu +++ b/src/code/comfyui/Dockerfile.deepgpu @@ -27,6 +27,8 @@ RUN cd ${COMFYUI_DIR}/custom_nodes && \ wget https://aiacc-inference-public-v2.oss-cn-hangzhou.aliyuncs.com/deepgpu/comfyui/nodes/20251013/ComfyUI-deepgpu.tar.gz && \ tar -zxf ComfyUI-deepgpu.tar.gz +COPY custom_nodes/ ${COMFYUI_DIR}/custom_nodes/ + FROM python:3.10.16-slim AS dependencies diff --git a/src/code/comfyui/Dockerfile.nunchaku b/src/code/comfyui/Dockerfile.nunchaku index d5011824..948cd8cb 100644 --- a/src/code/comfyui/Dockerfile.nunchaku +++ b/src/code/comfyui/Dockerfile.nunchaku @@ -19,6 +19,8 @@ RUN cd ${COMFYUI_DIR}/custom_nodes && \ cd ComfyUI-nunchaku && \ git checkout "v1.0.2" +COPY custom_nodes/ ${COMFYUI_DIR}/custom_nodes/ + FROM python:3.10.16-slim AS dependencies WORKDIR /root diff --git a/src/code/comfyui/Makefile b/src/code/comfyui/Makefile index e6cfaf35..5edb6378 100644 --- a/src/code/comfyui/Makefile +++ b/src/code/comfyui/Makefile @@ -1,8 +1,8 @@ # 定义变量 REGION ?= cn-hangzhou -COMFYUI_IMAGE = comfyui:v0.3.77-nunchaku +COMFYUI_IMAGE = comfyui:v0.3.77-beta OSS_BUCKET ?= dipper-cache-$(REGION) -OSS_COMFYUI_BASE_DIR = base/comfyui/v0.3.77-alpha-nunchaku +OSS_COMFYUI_BASE_DIR = base/comfyui/v0.3.77-beta .PHONY: upgrade upgrade: build upload-base diff --git a/src/code/comfyui/custom_nodes/FunArt-ComfyUI-Multi-User/__init__.py b/src/code/comfyui/custom_nodes/FunArt-ComfyUI-Multi-User/__init__.py new file mode 100644 index 00000000..73e4440c --- /dev/null +++ b/src/code/comfyui/custom_nodes/FunArt-ComfyUI-Multi-User/__init__.py @@ -0,0 +1,84 @@ +""" +ComfyUI Multi-User Support Plugin +================================== + +A ComfyUI plugin that provides per-user asset directory isolation. +Enables multi-user support by automatically segregating input/output/temp directories. + +Features: + - Per-user directory isolation (input/output/temp/user) + - Automatic cache isolation between users + - HTTP request integration via X-FunArt-Comfy-UserId header + - Thread-safe and async-safe user context management + - Zero dependencies beyond ComfyUI core + +Usage: + 1. Set environment variable: ENABLE_COMFYUI_MULTI_USER=true + 2. Start ComfyUI: python main.py + 3. Send requests with X-FunArt-Comfy-UserId header or x-funart-comfy-userid in extra_data + +Example: + ```python + from FunArt-ComfyUI-Multi-User.core.context import UserContext + + with UserContext("user_001"): + # All operations use user_001's directories + execute_workflow(workflow_data) + ``` + +Author: FunArt Team +License: MIT +Version: 1.0.0 +""" + +import os + +from .context import ( + set_current_user, + get_current_user, + clear_current_user, + UserContext, +) + +from .folder_paths_patch import install_folder_paths_patch +from .execution_patch import install_execution_patch +from .cache_signature_patch import install_cache_signature_patch +from .server_patch import install_server_middleware + +# Plugin metadata +__version__ = "1.0.0" +__author__ = "FunArt Team" +__license__ = "MIT" + +# ComfyUI standard exports +NODE_CLASS_MAPPINGS = {} +NODE_DISPLAY_NAME_MAPPINGS = {} + +# Check if plugin is enabled via environment variable +ENABLE_PLUGIN = os.getenv('ENABLE_COMFYUI_MULTI_USER', 'false').lower() == 'true' + +if ENABLE_PLUGIN: + try: + # Install all patches + install_folder_paths_patch() + install_execution_patch() + install_cache_signature_patch() + install_server_middleware() + print(f"[ComfyUI-Multi-User] ✅ 插件已启用并安装成功 (v{__version__})") + except Exception as e: + print(f"[ComfyUI-Multi-User] ❌ 插件安装失败: {e}") + import traceback + traceback.print_exc() +else: + print("[ComfyUI-Multi-User] ⚠️ 插件未启用") + +# Export public API +__all__ = [ + 'NODE_CLASS_MAPPINGS', + 'NODE_DISPLAY_NAME_MAPPINGS', + 'set_current_user', + 'get_current_user', + 'clear_current_user', + 'UserContext', + '__version__', +] diff --git a/src/code/comfyui/custom_nodes/FunArt-ComfyUI-Multi-User/cache_signature_patch.py b/src/code/comfyui/custom_nodes/FunArt-ComfyUI-Multi-User/cache_signature_patch.py new file mode 100644 index 00000000..9ff02752 --- /dev/null +++ b/src/code/comfyui/custom_nodes/FunArt-ComfyUI-Multi-User/cache_signature_patch.py @@ -0,0 +1,90 @@ +""" +Monkey Patch for Cache Signature - User-Isolated Caching (Universal Solution) + +Problem: + - ComfyUI's cache signatures only include node inputs + - extra_data (including userId) is not part of inputs + - Different users executing the same workflow reuse cache incorrectly + +Solution: + - Patch the cache signature generation function get_immediate_node_signature + - Include current user_id in the signature + - Automatically covers all nodes (official + custom) +""" + +from .context import get_current_user + + +# Store original function +_original_get_immediate_node_signature = None + + +async def _patched_get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping): + """ + Patched version of get_immediate_node_signature. + + Adds user_id to the original signature to ensure cache isolation between users. + + Working principle: + 1. Get current user_id BEFORE await (防止协程切换导致 context 变化) + 2. Call original function to get base signature + 3. Append user_id to signature + 4. Different users -> different signatures -> no cache reuse + + Args: + self: CacheKeySetInputSignature instance + dynprompt: Dynamic prompt object + node_id: Node identifier + ancestor_order_mapping: Ancestor ordering map + + Returns: + Modified signature with user_id appended + """ + user_id = get_current_user() + + # Get base signature from original function + signature = await _original_get_immediate_node_signature( + self, dynprompt, node_id, ancestor_order_mapping + ) + + # For ALL users (including 'default'), append user_id to signature + # This ensures different users have different signatures and won't share cache + signature = list(signature) if not isinstance(signature, list) else signature + signature.append(('__comfyui_user_id__', user_id)) + + return signature + + +def install_cache_signature_patch(): + """ + Install cache signature patch. + + Patches get_immediate_node_signature function to include user_id + in cache signatures for all nodes. + + Advantages: + - Automatically covers all nodes (official + custom) + - Zero-intrusion, no need to modify any node code + - Only need to patch one place + """ + global _original_get_immediate_node_signature + + try: + # Import caching module + from comfy_execution.caching import CacheKeySetInputSignature # type: ignore + + # Save original function + _original_get_immediate_node_signature = CacheKeySetInputSignature.get_immediate_node_signature + + # Replace with patched version + CacheKeySetInputSignature.get_immediate_node_signature = _patched_get_immediate_node_signature + + except ImportError: + pass # comfy_execution.caching 模块不存在(旧版本 ComfyUI) + except Exception as e: + print(f"[ComfyUI-User] ❌ Cache signature patch 安装失败: {e}") + import traceback + traceback.print_exc() + + +__all__ = ['install_cache_signature_patch'] diff --git a/src/code/comfyui/custom_nodes/FunArt-ComfyUI-Multi-User/context.py b/src/code/comfyui/custom_nodes/FunArt-ComfyUI-Multi-User/context.py new file mode 100644 index 00000000..efced2ed --- /dev/null +++ b/src/code/comfyui/custom_nodes/FunArt-ComfyUI-Multi-User/context.py @@ -0,0 +1,99 @@ +""" +User Context Management - Thread and Async Safe User Isolation + +Features: + - Use contextvars for async/await safe context storage + - Provide context manager for easy user switching + - Support nested context with proper cleanup +""" + +import contextvars + +# ============================================ +# Context Variable Storage: Store current context's user ID +# ============================================ + +_current_user: contextvars.ContextVar[str] = contextvars.ContextVar( + 'current_user', + default='default' +) + + +def set_current_user(user_id: str): + """ + Set the current context's user ID. + + Safe for both threading and async/await scenarios. + + Args: + user_id: User identifier (e.g., "user_001" or "john@email.com") + """ + _current_user.set(user_id) + + +def get_current_user() -> str: + """ + Get the current context's user ID. + + Safe for both threading and async/await scenarios. + + Returns: + User ID, defaults to "default" if not set + """ + return _current_user.get() + + +def clear_current_user(): + """ + Clear the current context's user ID by resetting to default. + + Safe for both threading and async/await scenarios. + """ + _current_user.set('default') + + +# ============================================ +# Context Manager +# ============================================ + +class UserContext: + """ + User context manager. + + Usage: + with UserContext("user_001"): + # Within this block, all asset operations use user_001's directories + execute_workflow() + """ + + def __init__(self, user_id: str): + """ + Initialize user context. + + Args: + user_id: User identifier to set as current + """ + self.user_id = user_id + self.previous_user_id = None + + def __enter__(self): + """Enter context - save previous user and set new user.""" + self.previous_user_id = get_current_user() + set_current_user(self.user_id) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Exit context - restore previous user.""" + if self.previous_user_id: + set_current_user(self.previous_user_id) + else: + clear_current_user() + return False + + +__all__ = [ + 'set_current_user', + 'get_current_user', + 'clear_current_user', + 'UserContext', +] diff --git a/src/code/comfyui/custom_nodes/FunArt-ComfyUI-Multi-User/execution_patch.py b/src/code/comfyui/custom_nodes/FunArt-ComfyUI-Multi-User/execution_patch.py new file mode 100644 index 00000000..eedf187f --- /dev/null +++ b/src/code/comfyui/custom_nodes/FunArt-ComfyUI-Multi-User/execution_patch.py @@ -0,0 +1,63 @@ +""" +Monkey Patch for Execution - User Isolation in Workflow Execution + +This module patches PromptExecutor.execute to support cross-thread user_id propagation. +It ensures that the correct user context is maintained during workflow execution. +""" + +from .context import set_current_user, clear_current_user + +# Store original function +_original_prompt_executor_execute = None + + +def _patched_prompt_executor_execute(self, prompt, prompt_id, extra_data=None, execute_outputs=None): + """ + Patched version of PromptExecutor.execute. + + Extracts user_id from extra_data and sets it in the current thread context. + + Args: + self: PromptExecutor instance + prompt: Workflow prompt data + prompt_id: Unique prompt identifier + extra_data: Additional data, may contain 'X-FunArt-Comfy-UserId' + execute_outputs: Outputs to execute + + Returns: + Result from the original execute function + """ + extra_data = extra_data or {} + execute_outputs = execute_outputs or [] + + # 从 extra_data 中提取 user_id + # 注意:使用小写的 key(与 HTTP header 保持一致) + user_id = extra_data.get('x-funart-comfy-userid', 'default') + + set_current_user(user_id) + + try: + return _original_prompt_executor_execute(self, prompt, prompt_id, extra_data, execute_outputs) + finally: + clear_current_user() + + +def install_execution_patch(): + """ + Install execution monkey patch. + + Replaces PromptExecutor.execute with a patched version that handles user context. + """ + global _original_prompt_executor_execute + + try: + import execution # type: ignore + _original_prompt_executor_execute = execution.PromptExecutor.execute + execution.PromptExecutor.execute = _patched_prompt_executor_execute + except Exception as e: + print(f"[ComfyUI-User] ❌ Execution patch 安装失败: {e}") + import traceback + traceback.print_exc() + + +__all__ = ['install_execution_patch'] diff --git a/src/code/comfyui/custom_nodes/FunArt-ComfyUI-Multi-User/folder_paths_patch.py b/src/code/comfyui/custom_nodes/FunArt-ComfyUI-Multi-User/folder_paths_patch.py new file mode 100644 index 00000000..da703c80 --- /dev/null +++ b/src/code/comfyui/custom_nodes/FunArt-ComfyUI-Multi-User/folder_paths_patch.py @@ -0,0 +1,192 @@ +""" +Monkey Patch for folder_paths - User Asset Directory Isolation + +Features: + 1. Intercept folder_paths.get_*_directory functions + 2. Return user-specific asset directories based on current user ID + 3. Support automatic isolation for input/output/temp/user directories + 4. Use dynamic path proxy to solve node instance caching issues +""" + +import os +import threading + +from .context import get_current_user +from .path_proxy import DynamicPathProxy + + +# Store original functions +_original_functions = {} +_original_get_save_image_path = None + +# 目录创建缓存和锁 +_dir_cache = set() +_dir_lock = threading.Lock() + + +def _get_user_directory(base_dir: str, user_id: str) -> str: + """ + Get user-specific directory. + + Args: + base_dir: Base directory path + user_id: User identifier + + Returns: + User directory path in format: base_dir/users/{user_id} + """ + if user_id == 'default': + return base_dir + + user_dir = os.path.join(base_dir, "users", user_id) + + # 使用缓存避免重复创建,提高性能 + if user_dir not in _dir_cache: + with _dir_lock: + # 双重检查 + if user_dir not in _dir_cache: + os.makedirs(user_dir, exist_ok=True) + _dir_cache.add(user_dir) + + return user_dir + + +def _patched_get_input_directory(): + """ + Patched version of get_input_directory. + + Returns: + User-specific input directory + """ + base_dir = _original_functions['get_input_directory']() + user_id = get_current_user() + return _get_user_directory(base_dir, user_id) + + +def _patched_get_output_directory(): + """ + Patched version of get_output_directory. + + Returns DynamicPathProxy to solve node instance caching issues. + + Returns: + DynamicPathProxy that computes user-specific output directory on use + """ + return DynamicPathProxy(_patched_get_output_directory_real) + + +def _patched_get_output_directory_real(): + """ + Actual path computation function for output directory. + + Returns: + User-specific output directory path + """ + base_dir = _original_functions['get_output_directory']() + user_id = get_current_user() + return _get_user_directory(base_dir, user_id) + + +def _patched_get_temp_directory(): + """ + Patched version of get_temp_directory. + + Returns DynamicPathProxy to solve node instance caching issues. + + Returns: + DynamicPathProxy that computes user-specific temp directory on use + """ + return DynamicPathProxy(_patched_get_temp_directory_real) + + +def _patched_get_temp_directory_real(): + """ + Actual path computation function for temp directory. + + Returns: + User-specific temp directory path + """ + base_dir = _original_functions['get_temp_directory']() + user_id = get_current_user() + return _get_user_directory(base_dir, user_id) + + +def _patched_get_user_directory(): + """ + Patched version of get_user_directory. + + Returns user-specific directory for storing workflows and user data. + Compatible with ComfyUI's --multi-user mode. + + Returns: + User-specific user directory path + """ + user_id = get_current_user() + base_dir = _original_functions['get_user_directory']() + + # Check if running with --multi-user flag + try: + from comfy import cli_args # type: ignore + if cli_args.args.multi_user: + return base_dir + except (ImportError, AttributeError): + pass + + return _get_user_directory(base_dir, user_id) + + +def _patched_get_save_image_path(filename_prefix: str, output_dir, image_width=0, image_height=0): + """ + Patched version of get_save_image_path. + + Ensures output_dir is properly converted to string before processing. + + Args: + filename_prefix: Prefix for the saved image filename + output_dir: Output directory (may be DynamicPathProxy or string) + image_width: Width of the image + image_height: Height of the image + + Returns: + Result from original get_save_image_path + """ + # Convert output_dir to string + if hasattr(output_dir, '__fspath__'): + output_dir = os.fspath(output_dir) + elif not isinstance(output_dir, str): + output_dir = str(output_dir) + + return _original_get_save_image_path(filename_prefix, output_dir, image_width, image_height) + + +def install_folder_paths_patch(): + """ + Install folder_paths monkey patch. + + Replaces folder_paths module's directory retrieval functions to automatically + return user-specific subdirectories based on user ID. + """ + global _original_get_save_image_path + + try: + import folder_paths # type: ignore + except ImportError: + print("[ComfyUI-User] Folder paths 模块未找到,跳过安装") + return + + # Save original functions + _original_functions['get_input_directory'] = folder_paths.get_input_directory + _original_functions['get_output_directory'] = folder_paths.get_output_directory + _original_functions['get_temp_directory'] = folder_paths.get_temp_directory + _original_functions['get_user_directory'] = folder_paths.get_user_directory + _original_get_save_image_path = folder_paths.get_save_image_path + + # Replace with patched versions + folder_paths.get_input_directory = _patched_get_input_directory + folder_paths.get_output_directory = _patched_get_output_directory + folder_paths.get_temp_directory = _patched_get_temp_directory + folder_paths.get_user_directory = _patched_get_user_directory + folder_paths.get_save_image_path = _patched_get_save_image_path + + +__all__ = ['install_folder_paths_patch'] diff --git a/src/code/comfyui/custom_nodes/FunArt-ComfyUI-Multi-User/path_proxy.py b/src/code/comfyui/custom_nodes/FunArt-ComfyUI-Multi-User/path_proxy.py new file mode 100644 index 00000000..4f44f602 --- /dev/null +++ b/src/code/comfyui/custom_nodes/FunArt-ComfyUI-Multi-User/path_proxy.py @@ -0,0 +1,133 @@ +""" +Dynamic Path Proxy - Solving Node Instance Caching Issues + +This module provides a dynamic path proxy that computes the actual path +based on the current user context when accessed, rather than when assigned. + +This solves the problem where ComfyUI caches node instances, which would +otherwise lock paths to a specific user. +""" + +import os + + +class DynamicPathProxy: + """ + Dynamic path proxy object. + + Acts like a string but computes the actual path dynamically based on + the current user_id when the path value is actually needed. + + This ensures each user gets their own directory even when node instances are reused. + """ + + def __init__(self, path_getter): + """ + Initialize the dynamic path proxy. + + Args: + path_getter: Callable that returns the current actual path + """ + self._path_getter = path_getter + + def _get_path(self): + """Internal method to get the current path.""" + return self._path_getter() + + def __str__(self): + """Return the current user's actual path when converted to string.""" + return self._get_path() + + def __repr__(self): + """Return representation of the proxy.""" + return f"DynamicPathProxy({self._get_path()})" + + def __fspath__(self): + """Support os.path operations (os.fspath, os.path.join, etc.).""" + return self._get_path() + + # Path concatenation support + def __truediv__(self, other): + """Support / operator: path / "subdir".""" + return os.path.join(self._get_path(), other) + + def __rtruediv__(self, other): + """Support reverse / operator: "subdir" / path.""" + return os.path.join(other, self._get_path()) + + # String operation support + def __add__(self, other): + """Support string concatenation.""" + return self._get_path() + other + + def __radd__(self, other): + """Support reverse string concatenation.""" + return other + self._get_path() + + def __eq__(self, other): + """Support equality comparison.""" + return self._get_path() == other + + def __ne__(self, other): + """Support inequality comparison.""" + return self._get_path() != other + + def __hash__(self): + """ + DynamicPathProxy 的值依赖于当前 user_id,会动态变化。 + 根据 Python 规范,可变对象不应作为 dict key。 + + 如果需要将路径存储为 key,请先转换为字符串: + my_dict[str(proxy)] = value # ✅ 正确 + 而不是: + my_dict[proxy] = value # ❌ 错误 + """ + raise TypeError( + f"unhashable type: '{type(self).__name__}'. " + f"DynamicPathProxy 不能作为 dict key" + ) + + def __len__(self): + """Support len().""" + return len(self._get_path()) + + def __getitem__(self, key): + """Support indexing and slicing.""" + return self._get_path()[key] + + def __contains__(self, item): + """Support 'in' operator.""" + return item in self._get_path() + + # Common string methods + def startswith(self, prefix): + """Check if path starts with prefix.""" + return self._get_path().startswith(prefix) + + def endswith(self, suffix): + """Check if path ends with suffix.""" + return self._get_path().endswith(suffix) + + def split(self, *args, **kwargs): + """Split the path string.""" + return self._get_path().split(*args, **kwargs) + + def replace(self, *args, **kwargs): + """Replace substring in path.""" + return self._get_path().replace(*args, **kwargs) + + def format(self, *args, **kwargs): + """Format the path string.""" + return self._get_path().format(*args, **kwargs) + + def __index__(self): + """Explicitly disallow integer conversion.""" + raise TypeError(f"'{type(self).__name__}' object cannot be interpreted as an integer") + + @classmethod + def __class_getitem__(cls, item): + """Support type checking.""" + return cls + + +__all__ = ['DynamicPathProxy'] diff --git a/src/code/comfyui/custom_nodes/FunArt-ComfyUI-Multi-User/server_patch.py b/src/code/comfyui/custom_nodes/FunArt-ComfyUI-Multi-User/server_patch.py new file mode 100644 index 00000000..b233a9c7 --- /dev/null +++ b/src/code/comfyui/custom_nodes/FunArt-ComfyUI-Multi-User/server_patch.py @@ -0,0 +1,123 @@ +""" +Monkey Patch for Server - User Isolation in HTTP Requests + +This module patches ComfyUI's server to automatically extract and set user_id from HTTP headers. +""" + +from .context import set_current_user, clear_current_user + +_hook_installed = False +_original_add_routes = None + + +def install_server_middleware(): + """ + Install server hook to extract user_id from HTTP headers. + + Automatically wraps all route handlers to read user_id from request headers. + + Note: This function is safe to call multiple times (idempotent). + """ + global _hook_installed, _original_add_routes + + if _hook_installed: + return + + try: + import server # type: ignore + + # Save original add_routes method + _original_add_routes = server.PromptServer.add_routes + + def patched_add_routes(self): + """ + Patched add_routes - wraps handlers after route registration. + """ + # Call original add_routes first + result = _original_add_routes(self) + + # Wrap all route handlers + wrap_route_handlers(self.app) + + return result + + # Replace add_routes method + server.PromptServer.add_routes = patched_add_routes + _hook_installed = True + + except Exception as e: + print(f"[ComfyUI-User] Server hook 安装失败: {e}") + + +def wrap_route_handlers(app): + """ + Wrap all route handlers in the application. + + Makes each handler automatically extract user_id from request headers + and set it in the context. + + Design Philosophy: + - Keep it simple: 只检查标记,够用了 + - contextvars 本身是线程安全的,不需要额外的锁 + - ComfyUI 启动时只调用一次 add_routes,不用担心重复包装 + - 即使万一重复包装,contextvars 也能正确处理(set 多次没问题) + + Args: + app: aiohttp Application instance + """ + wrapped_count = 0 + skipped_count = 0 + + for resource in app.router.resources(): + for route in resource: + try: + # Get handler + if not hasattr(route, '_handler'): + continue + + original_handler = route._handler + + # Skip if not callable + if not callable(original_handler): + continue + + # Simple check: 如果已经包装过,跳过 + if getattr(original_handler, '_comfyui_user_wrapped', False): + skipped_count += 1 + continue + + # Create wrapper - 使用工厂函数确保正确的闭包 + def make_wrapper(handler): + """简单的包装器工厂 - 确保每个包装器捕获正确的 handler""" + async def wrapped_handler(request): + # Extract user_id from headers + user_id = request.headers.get('X-FunArt-Comfy-UserId', 'default') + + # Set in context (contextvars 自动处理线程/协程隔离) + set_current_user(user_id) + + try: + return await handler(request) + finally: + clear_current_user() + + # Mark as wrapped + wrapped_handler._comfyui_user_wrapped = True + return wrapped_handler + + # Replace handler + route._handler = make_wrapper(original_handler) + wrapped_count += 1 + + except Exception: + # 静默失败,不影响其他路由 + continue + + # 简单的日志 + if wrapped_count > 0: + print(f"[ComfyUI-User] 已包装 {wrapped_count} 个路由处理器") + if skipped_count > 0: + print(f"[ComfyUI-User] 跳过 {skipped_count} 个已包装的处理器") + + +__all__ = ['install_server_middleware']