diff --git a/cookbook/client/server/megatron/run.sh b/cookbook/client/server/megatron/run.sh
index 38befef2..14966ce9 100644
--- a/cookbook/client/server/megatron/run.sh
+++ b/cookbook/client/server/megatron/run.sh
@@ -1,6 +1,341 @@
+#!/bin/bash
+
+# ============================================
+# Twinkle Megatron service launch script
+# ============================================
+# Purpose: start a Ray cluster (multiple GPU/CPU nodes), Prometheus
+# monitoring, and the Twinkle server.
+#
+# Usage: ./run.sh [options]
+#
+# Options:
+#   --head NODE          Head-node GPU device list, comma-separated (default: 0,1,2,3)
+#   --gpu-workers LIST   GPU worker device lists; separate nodes with semicolons (default: 4,5,6,7)
+#   --cpu-workers N      Number of CPU workers (default: 1)
+#   --temp-dir DIR       Ray temp directory (default: /dashscope/caches/application/ray_logs)
+#   --help               Show this help message
+#
+# Examples:
+#   ./run.sh                              # default configuration
+#   ./run.sh --head "0,1,2,3" --gpu-workers "4,5,6,7" --cpu-workers 1
+#   ./run.sh --head "0,1,2,3" --gpu-workers "" --cpu-workers 0
+#   ./run.sh --head "" --cpu-workers 4    # CPU-only mode
+#   ./run.sh --temp-dir /tmp/my_ray_logs  # custom temp directory
+# ============================================
+
+set -e  # Exit immediately on error
+
+# ============================================
+# Configuration (adjust for your environment)
+# ============================================
+
+# --- Ray cluster configuration ---
+# Head node (must be started first)
+# Format: comma-separated GPU device list, e.g. "0,1,2,3";
+# the GPU count is derived automatically from the list.
+# Use the empty string "" for a CPU-only head node.
+# Can be overridden with --head.
+
+# GPU worker node list (multiple nodes allowed)
+# Format: semicolon-separated device lists,
+# e.g. "4,5,6,7" or "4,5,6,7;8,9,10,11"
+# Can be overridden with --gpu-workers.
+
+# Number of CPU workers
+# Can be overridden with --cpu-workers.
+
+# --- Network configuration ---
+RAY_PORT=6379
+RAY_ADDRESS="127.0.0.1:$RAY_PORT"
+
+# --- Path configuration ---
+DEFAULT_TEMP_DIR="/dashscope/caches/application/ray_logs"
+LOG_FILE="run.log"
+
+# --- Prometheus monitoring configuration ---
+PROMETHEUS_BIN="/dashscope/caches/application/monitor/prometheus-3.10.0.linux-amd64/prometheus"
+PROMETHEUS_CONFIG_SUFFIX="session_latest/metrics/prometheus/prometheus.yml"
+
+# --- Ray log rotation configuration ---
 export RAY_ROTATION_MAX_BYTES=1024
 export RAY_ROTATION_BACKUP_COUNT=1
-CUDA_VISIBLE_DEVICES=0,1,2,3 ray start --head --port=6379 --num-gpus=4 --disable-usage-stats --include-dashboard=false
-CUDA_VISIBLE_DEVICES=4,5,6,7 ray start --address=127.0.0.1:6379 --num-gpus=4
-CUDA_VISIBLE_DEVICES="" ray start --address=127.0.0.1:6379 --num-gpus=0
-python server.py
+
+# ============================================
+# Argument parsing (supports --key=value and --key value)
+# ============================================
+
+# Defaults
+HEAD_NODE="0,1,2,3"
+GPU_WORKERS_INPUT="4,5,6,7"
+CPU_WORKER_COUNT="1"
+TEMP_DIR="$DEFAULT_TEMP_DIR"
+
+# Parse named arguments
+while [[ $# -gt 0 ]]; do
+    case $1 in
+        --head)
+            HEAD_NODE="$2"
+            shift 2
+            ;;
+        --head=*)
+            HEAD_NODE="${1#*=}"
+            shift
+            ;;
+        --gpu-workers)
+            GPU_WORKERS_INPUT="$2"
+            shift 2
+            ;;
+        --gpu-workers=*)
+            GPU_WORKERS_INPUT="${1#*=}"
+            shift
+            ;;
+        --cpu-workers)
+            CPU_WORKER_COUNT="$2"
+            shift 2
+            ;;
+        --cpu-workers=*)
+            CPU_WORKER_COUNT="${1#*=}"
+            shift
+            ;;
+        --temp-dir)
+            TEMP_DIR="$2"
+            shift 2
+            ;;
+        --temp-dir=*)
+            TEMP_DIR="${1#*=}"
+            shift
+            ;;
+        --help|-h)
+            echo "Usage: ./run.sh [options]"
+            echo ""
+            echo "Options:"
+            echo "  --head NODE          Head-node GPU device list, comma-separated (default: 0,1,2,3)"
+            echo "  --gpu-workers LIST   GPU worker device lists; separate nodes with semicolons (default: 4,5,6,7)"
+            echo "  --cpu-workers N      Number of CPU workers (default: 1)"
+            echo "  --temp-dir DIR       Ray temp directory"
+            echo "  --help, -h           Show this help message"
+            echo ""
+            echo "Examples:"
+            echo "  ./run.sh                                    # default configuration"
+            echo "  ./run.sh --head '0,1,2,3' --gpu-workers '4,5,6,7'"
+            echo "  ./run.sh --head '0,1,2,3,4,5,6,7'           # single node, 8 GPUs"
+            echo "  ./run.sh --gpu-workers '4,5,6,7;8,9,10,11'  # multiple GPU workers"
+            echo "  ./run.sh --cpu-workers 4 --head ''          # CPU-only mode"
+            exit 0
+            ;;
+        *)
+            # Use plain echo here: the print_* helpers are defined further
+            # down and are not yet available while arguments are parsed.
+            echo "[ERROR] Unknown argument: $1" >&2
+            echo "Use --help to see usage information"
+            exit 1
+            ;;
+    esac
+done
+
+# Convert the semicolon-separated string into an array
+if [ -z "$GPU_WORKERS_INPUT" ]; then
+    GPU_WORKERS=()
+else
+    IFS=';' read -ra GPU_WORKERS <<< "$GPU_WORKERS_INPUT"
+fi
+
+PROMETHEUS_CONFIG="${TEMP_DIR}/${PROMETHEUS_CONFIG_SUFFIX}"
+
+# ============================================
+# Helper functions
+# ============================================
+print_info() {
+    echo -e "\033[36m[INFO]\033[0m $1"
+}
+
+print_success() {
+    echo -e "\033[32m[SUCCESS]\033[0m $1"
+}
+
+print_warning() {
+    echo -e "\033[33m[WARNING]\033[0m $1"
+}
+
+print_error() {
+    echo -e "\033[31m[ERROR]\033[0m $1"
+}
+
+print_separator() {
+    echo "============================================"
+}
+
+print_header() {
+    echo ""
+    print_separator
+    echo -e "\033[1;34m $1 \033[0m"
+    print_separator
+}
+
+# Parse a node config string: sets _gpu_devices and derives _gpu_count
+# Example: "0,1,2,3" -> _gpu_devices="0,1,2,3", _gpu_count=4
+parse_node_config() {
+    local config="$1"
+    if [ -z "$config" ]; then
+        _gpu_devices=""
+        _gpu_count=0
+        return
+    fi
+    _gpu_devices="$config"
+    # GPU count = number of commas + 1
+    local comma_count=$(echo "$config" | tr -cd ',' | wc -c)
+    _gpu_count=$((comma_count + 1))
+}
+
+# ============================================
+# Startup
+# ============================================
+print_header "Twinkle Megatron Service Launcher"
+
+# Print the configuration
+print_info "Cluster configuration:"
+echo ""
+
+# Parse and display the head node
+parse_node_config "$HEAD_NODE"
+if [ -n "$_gpu_devices" ]; then
+    echo "  [Head node]"
+    echo "    - GPU devices: $_gpu_devices"
+    echo "    - GPU count:   $_gpu_count"
+else
+    echo "  [Head node] CPU only"
+fi
+
+# Display GPU worker nodes
+if [ ${#GPU_WORKERS[@]} -gt 0 ]; then
+    echo ""
+    echo "  [GPU worker nodes] ${#GPU_WORKERS[@]} total"
+    for i in "${!GPU_WORKERS[@]}"; do
+        parse_node_config "${GPU_WORKERS[$i]}"
+        echo "    Worker $((i+1)): GPU=$_gpu_devices, Count=$_gpu_count"
+    done
+fi
+
+# Display CPU workers
+if [ "$CPU_WORKER_COUNT" -gt 0 ]; then
+    echo ""
+    echo "  [CPU worker nodes] $CPU_WORKER_COUNT"
+fi
+
+echo ""
+print_info "Runtime parameters:"
+echo "  - Ray address:    $RAY_ADDRESS"
+echo "  - Temp directory: $TEMP_DIR"
+echo "  - Log file:       $LOG_FILE"
+echo ""
+
+# Ensure the temp directory exists
+if [ ! -d "$TEMP_DIR" ]; then
+    print_info "Creating temp directory: $TEMP_DIR"
+    mkdir -p "$TEMP_DIR"
+fi
+
+# ============================================
+# Stop any existing Ray cluster and Prometheus
+# ============================================
+print_header "Cleaning up the environment"
+print_info "Stopping any existing Ray cluster..."
+ray stop --force 2>/dev/null || true
+
+print_info "Stopping any existing Prometheus..."
+pkill prometheus 2>/dev/null || true
+
+# ============================================
+# Start the Ray head node
+# ============================================
+print_header "Starting the Ray cluster"
+
+parse_node_config "$HEAD_NODE"
+if [ -n "$_gpu_devices" ]; then
+    print_info "Starting head node (GPU: $_gpu_devices)..."
+    CUDA_VISIBLE_DEVICES="$_gpu_devices" ray start --head \
+        --port=$RAY_PORT \
+        --num-gpus=$_gpu_count \
+        --disable-usage-stats \
+        --include-dashboard=true \
+        --temp-dir="$TEMP_DIR"
+else
+    print_info "Starting head node (CPU only)..."
+    CUDA_VISIBLE_DEVICES="" ray start --head \
+        --port=$RAY_PORT \
+        --num-gpus=0 \
+        --disable-usage-stats \
+        --include-dashboard=true \
+        --temp-dir="$TEMP_DIR"
+fi
+print_success "Head node started."
+
+# ============================================
+# Start GPU worker nodes
+# ============================================
+for i in "${!GPU_WORKERS[@]}"; do
+    parse_node_config "${GPU_WORKERS[$i]}"
+    print_info "Starting GPU worker $((i+1)) (GPU: $_gpu_devices)..."
+    CUDA_VISIBLE_DEVICES="$_gpu_devices" ray start \
+        --address=$RAY_ADDRESS \
+        --num-gpus=$_gpu_count
+    print_success "GPU worker $((i+1)) started."
+done
+
+# ============================================
+# Start CPU worker nodes
+# ============================================
+if [ "$CPU_WORKER_COUNT" -gt 0 ]; then
+    print_info "Starting $CPU_WORKER_COUNT CPU worker(s)..."
+    for ((i=1; i<=CPU_WORKER_COUNT; i++)); do
+        CUDA_VISIBLE_DEVICES="" ray start \
+            --address=$RAY_ADDRESS \
+            --num-gpus=0
+    done
+    print_success "CPU workers started."
+fi
+
+# ============================================
+# Show cluster status
+# ============================================
+echo ""
+print_info "Cluster status:"
+ray status 2>/dev/null || true
+
+# ============================================
+# Start Prometheus monitoring (optional)
+# ============================================
+print_header "Starting monitoring (optional)"
+
+PROMETHEUS_PID=""
+if [ -f "$PROMETHEUS_BIN" ]; then
+    print_info "Prometheus binary found; starting the monitoring service..."
+
+    # Wait for Ray to generate the Prometheus config
+    sleep 2
+
+    if [ -f "$PROMETHEUS_CONFIG" ]; then
+        nohup "$PROMETHEUS_BIN" --config.file="$PROMETHEUS_CONFIG" > prometheus.log 2>&1 &
+        PROMETHEUS_PID=$!
+        print_success "Prometheus monitoring started (PID: $PROMETHEUS_PID)"
+        echo "  - Monitoring log: prometheus.log"
+        echo "  - Config file:    $PROMETHEUS_CONFIG"
+    else
+        print_warning "Prometheus config file not found; skipping monitoring"
+        echo "  - Expected path: $PROMETHEUS_CONFIG"
+    fi
+else
+    print_warning "Prometheus binary not found; skipping monitoring"
+    echo "  - Expected path: $PROMETHEUS_BIN"
+fi
+
+# ============================================
+# Start the Twinkle server
+# ============================================
+print_header "Starting the Twinkle server"
+
+print_info "Logs are written to: $LOG_FILE"
+echo ""
+
+# Start the server in the background
+nohup python server.py > "$LOG_FILE" 2>&1 &
+SERVER_PID=$!
+
+# Stream the log in real time
+tail -f "$LOG_FILE"
diff --git a/cookbook/client/server/megatron/server_config.yaml b/cookbook/client/server/megatron/server_config.yaml
index 6b5cce0e..82aa9b2a 100644
--- a/cookbook/client/server/megatron/server_config.yaml
+++ b/cookbook/client/server/megatron/server_config.yaml
@@ -84,7 +84,7 @@ applications:
       nproc_per_node: 4  # Number of GPU processes per node
       device_group:
         name: model
-        ranks: 4  # GPU rank indices
+        ranks: 4
        device_type: cuda
       device_mesh:
         device_type: cuda
diff --git a/cookbook/client/server/megatron/server_config_4b.yaml b/cookbook/client/server/megatron/server_config_4b.yaml
index e191b981..5dd8a696 100644
--- a/cookbook/client/server/megatron/server_config_4b.yaml
+++ b/cookbook/client/server/megatron/server_config_4b.yaml
@@ -7,7 +7,7 @@ proxy_location: EveryNode
 # HTTP listener settings
 http_options:
   host: 0.0.0.0  # Listen on all network interfaces
-  port: 8000  # Port number for the server
+  port: 9000  # Port number for the server
 
 # Applications: each entry defines a service component deployed on the server
 applications:
@@ -39,25 +39,24 @@ applications:
     import_path: model
     args:
       use_megatron: true
-      model_cls: Qwen3_5ForConditionalGeneration
       model_id: "ms://Qwen/Qwen3.5-4B"  # ModelScope model identifier
       max_length: 10240
       nproc_per_node: 2  # Number of GPU processes per node
       device_group:
         name: model
-        ranks: 2  # GPU rank indices
+        ranks: 2
         device_type: cuda
       device_mesh:
         device_type: cuda
         dp_size: 2
       queue_config:
         rps_limit: 100  # Max requests per second
-        tps_limit: 10000  # Max tokens per second for a single user
-        max_input_tokens: 10000  # Maximum input tokens per request
+        tps_limit: 100000  # Max tokens per second for a single user
+        max_input_tokens: 60000  # Maximum input tokens per request
       adapter_config:
         adapter_timeout: 30  # Seconds before idle adapter unload
         adapter_max_lifetime: 36000  # Maximum lifetime of an adapter in seconds (e.g., 10 hours)
-        max_loras: 1  # Maximum number of LoRA adapters per model
+        max_loras: 5  # Maximum number of LoRA adapters per model
     deployments:
     - name: ModelManagement
       autoscaling_config:
@@ -80,8 +79,8 @@ applications:
       nproc_per_node: 2  # Number of GPU processes per node
       sampler_type: vllm  # Inference engine: 'vllm' (fast) or 'torch' (TorchSampler)
       engine_args:  # vLLM engine-specific settings
-        max_model_len: 4096  # Maximum sequence length the engine supports
-        gpu_memory_utilization: 0.5  # Fraction of GPU memory to use (0.0-1.0)
+        max_model_len: 16000  # Maximum sequence length the engine supports
+        gpu_memory_utilization: 0.7  # Fraction of GPU memory to use (0.0-1.0)
         enable_lora: true  # Allow loading LoRA adapters during inference
         logprobs_mode: processed_logprobs  # Logprobs mode for sampling results
       device_group:  # Logical device group for the sampler
diff --git a/cookbook/client/tinker/self_host/dpo.py b/cookbook/client/tinker/self_host/dpo.py
new file mode 100644
index 00000000..e69de29b
diff --git a/cookbook/client/twinkle/self_host/dpo.py b/cookbook/client/twinkle/self_host/dpo.py
new file mode 100644
index 00000000..5b996f06
--- /dev/null
+++ b/cookbook/client/twinkle/self_host/dpo.py
@@ -0,0 +1,207 @@
+# Twinkle Client - DPO (Direct Preference Optimization) Training with LoRA
+#
+# This script demonstrates how to fine-tune a language model using DPO
+# through the Twinkle client-server architecture.
+# The server must be running first (see server.py and server_config.yaml).
+
+# Step 1: Load environment variables from a .env file (e.g., API tokens)
+import dotenv
+import os
+from typing import Any, Dict, List
+
+dotenv.load_dotenv('.env')
+import numpy as np
+import torch
+from peft import LoraConfig
+
+from twinkle import get_logger
+from twinkle.dataset import Dataset, DatasetMeta
+from twinkle_client import init_twinkle_client
+from twinkle.dataloader import DataLoader
+from twinkle_client.model import MultiLoraTransformersModel
+from twinkle.loss import DPOLoss
+from twinkle.metric import DPOMetric
+from twinkle.preprocessor import EmojiDPOProcessor
+from twinkle.processor import InputProcessor
+
+logger = get_logger()
+
+# Configuration (direct values, not from env)
+base_model = 'Qwen/Qwen3.5-4B'
+base_url = 'http://localhost:8000'
+dataset_id = 'ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji'
+
+batch_size = 4
+gradient_accumulation_steps = 2
+learning_rate = 1e-4
+dpo_beta = 0.1
+sft_weight = 1.0
+loss_type = 'sigmoid'
+max_length = 2048
+adapter_name = 'default'
+system_prompt = 'You are a helpful assistant.'
+
+# Step 2: Initialize the Twinkle client to communicate with the remote server.
+# - base_url: the address of the running Twinkle server
+# - api_key: authentication token (loaded from environment variable)
+client = init_twinkle_client(base_url=base_url, api_key=os.environ.get('MODELSCOPE_TOKEN'))
+
+# Step 3: Query the server for existing training runs and their checkpoints.
+# This is useful for resuming a previous training session.
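+# A minimal resume sketch (hypothetical; it assumes at least one run with a
+# saved checkpoint already exists) would pick a checkpoint directly:
+#
+#     runs = client.list_training_runs()
+#     if runs:
+#         checkpoints = client.list_checkpoints(runs[-1].training_run_id)
+#         if checkpoints:
+#             resume_path = checkpoints[-1].twinkle_path
+#
+# The loop below instead logs every run and checkpoint, and leaves
+# resume_path unset unless you uncomment the assignment inside it.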
+runs = client.list_training_runs() + +resume_path = None +for run in runs: + logger.info(run.model_dump_json(indent=2)) + # List all saved checkpoints for this training run + checkpoints = client.list_checkpoints(run.training_run_id) + + for checkpoint in checkpoints: + logger.info(checkpoint.model_dump_json(indent=2)) + # Uncomment the line below to resume from a specific checkpoint: + # resume_path = checkpoint.twinkle_path + + +def create_dpo_dataset(): + """Create DPO dataset with positive/negative format.""" + dataset = Dataset(DatasetMeta(dataset_id, data_slice=range(600))) + dataset.set_template('Qwen3_5Template', model_id=f'ms://{base_model}', max_length=max_length) + dataset.map( + EmojiDPOProcessor, + init_args={ + 'system': system_prompt, + } + ) + # DPO preprocessor returns {'positive': [...], 'negative': [...]} + # batch_encode handles this format automatically + dataset.encode() + return dataset + + +def prepare_dpo_batch(batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Prepare DPO batch: reorganize batch for training with DP-safe interleaving. + + Args: + batch: List of rows, each with 'positive' and 'negative' InputFeatures + and other fields (question, etc.) + + Returns: + List interleaved as [pos_1, neg_1, pos_2, neg_2, ...] to ensure each DP + worker gets complete positive/negative pairs after slicing. + Each item contains all original fields plus the InputFeature fields. + """ + result = [] + + for row in batch: + # Get base fields (excluding positive/negative) + base_fields = {k: v for k, v in row.items() if k not in ('positive', 'negative')} + + # Positive sample: merge base fields with positive InputFeature + pos_sample = {**base_fields, **row['positive']} + # Negative sample: merge base fields with negative InputFeature + neg_sample = {**base_fields, **row['negative']} + + # Interleave: [pos, neg] per pair for DP-safe slicing + result.append(pos_sample) + result.append(neg_sample) + + return result + + +def train(): + # Step 4: Prepare the dataset + + # Load the DPO dataset from ModelScope + dataset = create_dpo_dataset() + + # Wrap the dataset into a DataLoader that yields batches + dataloader = DataLoader(dataset=dataset, batch_size=batch_size) + + # Step 5: Configure the model + + # Create a multi-LoRA Transformers model pointing to the base model on ModelScope + model = MultiLoraTransformersModel(model_id=f'ms://{base_model}') + + # Define LoRA configuration: apply low-rank adapters to all linear layers + lora_config = LoraConfig( + target_modules='all-linear', + r=8, + lora_alpha=32, + lora_dropout=0.05, + ) + + # Attach the LoRA adapter named 'default' to the model. + # gradient_accumulation_steps means gradients are accumulated over micro-batches + # before an optimizer step, effectively increasing the batch size. 
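+    # Worked example with the values above: batch_size = 4 and
+    # gradient_accumulation_steps = 2, so one optimizer step effectively
+    # covers 4 * 2 = 8 preference pairs (16 rows after prepare_dpo_batch
+    # interleaves each row into a positive and a negative sample).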
+    model.add_adapter_to_model(adapter_name, lora_config, gradient_accumulation_steps=gradient_accumulation_steps)
+
+    # Set the same chat template used during data preprocessing
+    model.set_template('Qwen3_5Template')
+
+    # Set the input processor (pads sequences on the right side)
+    model.set_processor('InputProcessor', padding_side='right')
+
+    # Use DPO loss for preference optimization
+    model.set_loss('DPOLoss', beta=dpo_beta, loss_type=loss_type, reference_free=False, sft_weight=sft_weight)
+
+    # Add DPO metric for logging
+    model.add_metric('DPOMetric', beta=dpo_beta)
+
+    # Use Adam optimizer with a learning rate of 1e-4
+    model.set_optimizer('Adam', lr=learning_rate)
+
+    # Step 6: Optionally resume from a previous checkpoint
+    if resume_path:
+        logger.info(f'Resuming training from {resume_path}')
+        model.load(resume_path, load_optimizer=True)
+
+    # Step 7: Run the training loop
+    logger.info(model.get_train_configs().model_dump())
+
+    optim_step = 0
+    # Each loop iteration is a micro-batch; one logged optimizer step spans
+    # gradient_accumulation_steps micro-batches.
+    max_steps = len(dataloader) // gradient_accumulation_steps
+    logger.info(f'Starting LoRA DPO training: loss_type={loss_type}, beta={dpo_beta}, lr={learning_rate}')
+    logger.info('Using base model (disable_lora=True) as reference model')
+
+    for batch in dataloader:
+        # batch is List[Dict] with 'positive' and 'negative' keys
+        # Convert numpy/torch tensors to lists for serialization
+        for row in batch:
+            for key in row:
+                if isinstance(row[key], np.ndarray):
+                    row[key] = row[key].tolist()
+                elif isinstance(row[key], torch.Tensor):
+                    row[key] = row[key].cpu().numpy().tolist()
+
+        dpo_batch = prepare_dpo_batch(batch)
+
+        # Get reference outputs using base model (without LoRA adapter)
+        # disable_lora=True tells the model to skip LoRA and use base weights
+        ref_outputs = model.forward_only(inputs=dpo_batch, disable_lora=True)
+        model.forward_backward(inputs=dpo_batch, ref_outputs=ref_outputs.result)
+        model.clip_grad_and_step()
+
+        optim_step += 1
+
+        # Logging
+        if optim_step % gradient_accumulation_steps == 0:
+            metrics = model.calculate_metric(is_training=True)
+            logger.info(f'[Step {optim_step // gradient_accumulation_steps}/{max_steps}] {metrics}')
+
+    # Step 8: Save the trained checkpoint
+    twinkle_path = model.save(name='dpo-lora-final', save_optimizer=True)
+    logger.info(f'Saved checkpoint: {twinkle_path}')
+
+    # Step 9: Upload the checkpoint to ModelScope Hub
+    # YOUR_USER_NAME = "your_username"
+    # hub_model_id = f'{YOUR_USER_NAME}/twinkle-dpo-lora'
+    # model.upload_to_hub(
+    #     checkpoint_dir=twinkle_path,
+    #     hub_model_id=hub_model_id,
+    #     async_upload=False
+    # )
+    # logger.info(f"Uploaded checkpoint to hub: {hub_model_id}")
+
+
+if __name__ == '__main__':
+    train()
diff --git a/src/twinkle/loss/dpo.py b/src/twinkle/loss/dpo.py
index 44d81f82..d50837a0 100644
--- a/src/twinkle/loss/dpo.py
+++ b/src/twinkle/loss/dpo.py
@@ -310,6 +310,8 @@ def __call__(
             reference_rejected_logps = ref_rejected_logps.to(device=device, dtype=dtype)
         elif ref_logps is not None:
             # Per-token reference log probs provided, need to align and sum
+            if not torch.is_tensor(ref_logps):
+                ref_logps = torch.as_tensor(ref_logps)
             ref_logps_aligned = self._align_logps(ref_logps, labels.shape, device, dtype)
             ref_chosen, ref_rejected = self._split_chosen_rejected(ref_logps_aligned)
             reference_chosen_logps = self._compute_sequence_logps(ref_chosen, chosen_labels)
diff --git a/src/twinkle/metric/dpo.py b/src/twinkle/metric/dpo.py
index 5ce61410..8a1d4d6c 100644
--- a/src/twinkle/metric/dpo.py
+++ b/src/twinkle/metric/dpo.py
@@ -50,6 +50,9 @@ def _align_logps(self, logps, target_shape, device, dtype):
             Aligned tensor with shape matching target_shape
         """
         import torch
+
+        if not torch.is_tensor(logps):
+            logps = torch.as_tensor(logps)
         logps = logps.to(device=device, dtype=dtype)
         batch_size, src_len = logps.shape
         _, target_len = target_shape
diff --git a/src/twinkle/server/gateway/server.py b/src/twinkle/server/gateway/server.py
index cd942e61..4d99c9c5 100644
--- a/src/twinkle/server/gateway/server.py
+++ b/src/twinkle/server/gateway/server.py
@@ -13,6 +13,7 @@
 from typing import Any
 
 import twinkle_client.types as types
+from twinkle.server.utils.metrics import create_metrics_middleware
 from twinkle.server.utils.state import get_server_state
 from twinkle.server.utils.validation import verify_request_token
 from twinkle.utils.logger import get_logger
@@ -93,6 +94,8 @@ def build_server_app(deploy_options: dict[str, Any],
     async def verify_token(request: Request, call_next):
         return await verify_request_token(request=request, call_next=call_next)
 
+    app.middleware('http')(create_metrics_middleware('Gateway'))
+
     def get_self() -> GatewayServer:
         return serve.get_replica_context().servable_object
 
diff --git a/src/twinkle/server/gateway/twinkle_gateway_handlers.py b/src/twinkle/server/gateway/twinkle_gateway_handlers.py
index 9271b681..9a2c2f27 100644
--- a/src/twinkle/server/gateway/twinkle_gateway_handlers.py
+++ b/src/twinkle/server/gateway/twinkle_gateway_handlers.py
@@ -125,3 +125,17 @@ async def get_checkpoint_path(request: Request, run_id: str, checkpoint_id: str)
 
         ckpt_dir = checkpoint_manager.get_ckpt_dir(run_id, checkpoint_id)
         return types.CheckpointPathResponse(path=str(ckpt_dir), twinkle_path=checkpoint.twinkle_path)
+
+    @app.get('/twinkle/status')
+    async def status(
+            request: Request,
+            self: GatewayServer = Depends(self_fn),
+    ) -> dict:
+        cleanup_stats = await self.state.get_cleanup_stats()
+        return {
+            'resources': cleanup_stats['resource_counts'],
+            'cleanup': {
+                'running': cleanup_stats['cleanup_running'],
+                'expiration_timeout': cleanup_stats['expiration_timeout'],
+            },
+        }
diff --git a/src/twinkle/server/launcher.py b/src/twinkle/server/launcher.py
index cf084606..9d28e52b 100644
--- a/src/twinkle/server/launcher.py
+++ b/src/twinkle/server/launcher.py
@@ -128,7 +128,12 @@ def _init_ray(self) -> None:
         # Use runtime_env to apply patches in worker processes
         # This is required because Ray Serve's ProxyActor runs in separate processes
         runtime_env = get_runtime_env_for_patches()
-        ray.init(namespace=namespace, runtime_env=runtime_env)
+        # Connect to the Ray cluster started beforehand (e.g. by run.sh); address='auto' raises if no cluster is running
+        ray.init(
+            address='auto',
+            namespace=namespace,
+            runtime_env=runtime_env,
+        )
         logger.info(f'Ray initialized with namespace={namespace}')
         self._ray_initialized = True
 
diff --git a/src/twinkle/server/model/app.py b/src/twinkle/server/model/app.py
index 8f0c6f77..243fb4a4 100644
--- a/src/twinkle/server/model/app.py
+++ b/src/twinkle/server/model/app.py
@@ -15,6 +15,7 @@
 import twinkle
 from twinkle import DeviceGroup, DeviceMesh
 from twinkle.server.utils.lifecycle import AdapterManagerMixin
+from twinkle.server.utils.metrics import create_metrics_middleware
 from twinkle.server.utils.state import ServerStateProxy, get_server_state
 from twinkle.server.utils.task_queue import TaskQueueConfig, TaskQueueMixin
 from twinkle.server.utils.validation import get_token_from_request, verify_request_token
@@ -81,7 +82,7 @@ def __init__(self,
         self._replica_registered = False
 
         # Initialize mixins
-        self._init_task_queue(TaskQueueConfig.from_dict(queue_config))
+        self._init_task_queue(TaskQueueConfig.from_dict(queue_config), deployment_name='Model')
         self._init_adapter_manager(**adapter_config)
 
         # Note: countdown task is started lazily in _ensure_sticky()
@@ -164,6 +165,8 @@ def build_model_app(model_id: str,
     async def verify_token(request: Request, call_next):
         return await verify_request_token(request=request, call_next=call_next)
 
+    app.middleware('http')(create_metrics_middleware('Model'))
+
    def get_self() -> ModelManagement:
         return serve.get_replica_context().servable_object
 
diff --git a/src/twinkle/server/model/backends/megatron_model.py b/src/twinkle/server/model/backends/megatron_model.py
index 831b9468..6503cf6d 100644
--- a/src/twinkle/server/model/backends/megatron_model.py
+++ b/src/twinkle/server/model/backends/megatron_model.py
@@ -8,6 +8,7 @@
 
 from twinkle import remote_class, remote_function
 from twinkle.data_format import InputFeature, Trajectory
+from twinkle.infra import collect_tensor_dict
 from twinkle.model.megatron import MultiLoraMegatronModel
 from twinkle.server.common.datum import datum_to_input_feature, extract_rl_feature
 from twinkle.server.model.backends.common import (TwinkleCompatModelBase, clean_metrics,
@@ -119,7 +120,13 @@ def tinker_load(self, checkpoint_dir: str, **kwargs):
     # Twinkle-native methods (InputFeature/Trajectory-based I/O)
     # ------------------------------------------------------------------
 
-    @remote_function(dispatch='slice_dp', collect='mean')
+    @remote_function(dispatch='slice_dp', collect=collect_tensor_dict)
+    def forward_only(self, *, inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]], **kwargs):
+        """Forward-only for twinkle-native clients (InputFeature/Trajectory I/O)."""
+        output = super().forward_only(inputs=inputs, **kwargs)
+        return to_cpu_safe_output(output)
+
+    @remote_function(dispatch='slice_dp', collect=collect_tensor_dict)
     def forward_backward(self, *, inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]], **kwargs):
         """Forward+backward for twinkle-native clients (InputFeature/Trajectory I/O)."""
diff --git a/src/twinkle/server/model/backends/transformers_model.py b/src/twinkle/server/model/backends/transformers_model.py
index fe30f616..8a0716af 100644
--- a/src/twinkle/server/model/backends/transformers_model.py
+++ b/src/twinkle/server/model/backends/transformers_model.py
@@ -11,6 +11,7 @@
 
 from twinkle import remote_class, remote_function
 from twinkle.data_format import InputFeature, Trajectory
+from twinkle.infra import collect_tensor_dict
 from twinkle.model import MultiLoraTransformersModel
 from twinkle.server.common.datum import datum_to_input_feature, extract_rl_feature
 from twinkle.server.model.backends.common import (TwinkleCompatModelBase, clean_metrics,
@@ -106,7 +107,13 @@ def tinker_load(self, checkpoint_dir: str, **kwargs):
     # Twinkle-native methods (InputFeature/Trajectory-based I/O)
     # ------------------------------------------------------------------
 
-    @remote_function(dispatch='slice_dp', collect='mean')
+    @remote_function(dispatch='slice_dp', collect=collect_tensor_dict)
+    def forward_only(self, *, inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]], **kwargs):
+        """Forward-only for twinkle-native clients (InputFeature/Trajectory I/O)."""
+        output = super().forward_only(inputs=inputs, **kwargs)
+        return to_cpu_safe_output(output)
+
+    @remote_function(dispatch='slice_dp', collect=collect_tensor_dict)
     def forward_backward(self, *, inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]], **kwargs):
         """Forward+backward for twinkle-native clients (InputFeature/Trajectory I/O)."""
diff --git a/src/twinkle/server/processor/app.py b/src/twinkle/server/processor/app.py
index 92cdd62d..40fdadbe 100644
--- a/src/twinkle/server/processor/app.py
+++ b/src/twinkle/server/processor/app.py
@@ -21,6 +21,7 @@
 import twinkle
 from twinkle import DeviceGroup, DeviceMesh, get_logger
 from twinkle.server.utils.lifecycle import ProcessorManagerMixin
+from twinkle.server.utils.metrics import create_metrics_middleware
 from twinkle.server.utils.state import ServerStateProxy, get_server_state
 from twinkle.server.utils.validation import verify_request_token
 from .twinkle_handlers import _register_processor_routes
@@ -124,6 +125,8 @@ def build_processor_app(ncpu_proc_per_node: int,
     async def verify_token(request: Request, call_next):
         return await verify_request_token(request=request, call_next=call_next)
 
+    app.middleware('http')(create_metrics_middleware('Processor'))
+
     def get_self() -> ProcessorManagement:
         return serve.get_replica_context().servable_object
 
diff --git a/src/twinkle/server/sampler/app.py b/src/twinkle/server/sampler/app.py
index dc54e4f6..56458b6e 100644
--- a/src/twinkle/server/sampler/app.py
+++ b/src/twinkle/server/sampler/app.py
@@ -13,6 +13,7 @@
 
 import twinkle
 from twinkle import DeviceGroup, DeviceMesh
+from twinkle.server.utils.metrics import create_metrics_middleware
 from twinkle.server.utils.state import ServerStateProxy, get_server_state
 from twinkle.server.utils.task_queue import TaskQueueConfig, TaskQueueMixin
 from twinkle.server.utils.validation import get_token_from_request, verify_request_token
@@ -80,7 +81,7 @@ def __init__(self,
         self.state: ServerStateProxy = get_server_state()
 
         # Initialize task queue mixin
-        self._init_task_queue(TaskQueueConfig.from_dict(queue_config))
+        self._init_task_queue(TaskQueueConfig.from_dict(queue_config), deployment_name='Sampler')
 
     @serve.multiplexed(max_num_models_per_replica=5)
     async def _sticky_entry(self, sticky_key: str):
@@ -135,6 +136,8 @@ def build_sampler_app(model_id: str,
     async def verify_token(request: Request, call_next):
         return await verify_request_token(request=request, call_next=call_next)
 
+    app.middleware('http')(create_metrics_middleware('Sampler'))
+
     def get_self() -> SamplerManagement:
         return serve.get_replica_context().servable_object
 
diff --git a/src/twinkle/server/utils/metrics.py b/src/twinkle/server/utils/metrics.py
new file mode 100644
index 00000000..eee915d7
--- /dev/null
+++ b/src/twinkle/server/utils/metrics.py
@@ -0,0 +1,267 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+"""
+Central metrics module for Twinkle server observability.
+
+Provides ray.util.metrics instruments that feed both the Ray Dashboard
+(port 8265) and Prometheus (via /api/prometheus).
+
+All metric names use the ``twinkle_`` prefix. Metric instances are
+cached per deployment to avoid duplicate registration.
+
+Public entry-points:
+
+* ``create_metrics_middleware(deployment)`` – FastAPI HTTP middleware
+* ``get_task_metrics(deployment)`` – task-queue / rate-limit gauges
+* ``get_resource_metrics()`` – ServerState resource gauges
+"""
+from __future__ import annotations
+
+import time
+from pydantic import BaseModel, ConfigDict
+from ray.util.metrics import Counter, Gauge, Histogram
+from typing import Any, Callable
+
+from twinkle.utils.logger import get_logger
+
+logger = get_logger()
+
+# ---------------------------------------------------------------------------
+# Histogram bucket boundaries (seconds) – shared by all histograms
+# ---------------------------------------------------------------------------
+_HISTOGRAM_BOUNDARIES = [
+    0.01,
+    0.05,
+    0.1,
+    0.25,
+    0.5,
+    1.0,
+    2.5,
+    5.0,
+    10.0,
+    30.0,
+    60.0,
+    120.0,
+    300.0,
+]
+
+# ---------------------------------------------------------------------------
+# Lazy caches – populated on first call per deployment / globally
+# ---------------------------------------------------------------------------
+_task_metrics_cache: dict[str, TaskMetrics] = {}
+_resource_metrics_cache: ResourceMetrics | None = None
+_request_metrics_cache: dict[str, _RequestMetrics] = {}
+
+# ---------------------------------------------------------------------------
+# Pydantic models for structured metric access
+# ---------------------------------------------------------------------------
+
+
+class TaskMetrics(BaseModel):
+    """Task queue metrics container.
+
+    Attributes:
+        queue_depth: Current number of queued tasks.
+        tasks_total: Total task completions.
+        execution_seconds: Pure task execution time in seconds.
+        queue_wait_seconds: Time from enqueue to execution start.
+        rate_limit_rejections: Total rate-limit rejections.
+        rate_limiter_active_tokens: Tokens tracked by rate limiter.
+    """
+
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
+    queue_depth: Gauge
+    tasks_total: Counter
+    execution_seconds: Histogram
+    queue_wait_seconds: Histogram
+    rate_limit_rejections: Counter
+    rate_limiter_active_tokens: Gauge
+
+
+class ResourceMetrics(BaseModel):
+    """Resource gauge metrics container.
+
+    Attributes:
+        active_sessions: Current active session count.
+        active_models: Current registered model count.
+        active_sampling_sessions: Current sampling session count.
+        active_futures: Current future/request count.
+    """
+
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
+    active_sessions: Gauge
+    active_models: Gauge
+    active_sampling_sessions: Gauge
+    active_futures: Gauge
+
+
+class _RequestMetrics(BaseModel):
+    """HTTP request metrics container (internal)."""
+
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
+    requests_total: Counter
+    request_duration_seconds: Histogram
+
+
+# ---------------------------------------------------------------------------
+# A. Request-level metrics (FastAPI middleware)
+# ---------------------------------------------------------------------------
+
+
+def _get_request_metrics(deployment: str) -> _RequestMetrics:
+    """Return (or create) per-deployment HTTP request metrics."""
+    if deployment in _request_metrics_cache:
+        return _request_metrics_cache[deployment]
+
+    metrics = _RequestMetrics(
+        requests_total=Counter(
+            'twinkle_requests_total',
+            description='Total HTTP requests.',
+            tag_keys=('deployment', 'method', 'status'),
+        ),
+        request_duration_seconds=Histogram(
+            'twinkle_request_duration_seconds',
+            description='End-to-end HTTP request latency in seconds.',
+            boundaries=_HISTOGRAM_BOUNDARIES,
+            tag_keys=('deployment', 'method'),
+        ),
+    )
+    _request_metrics_cache[deployment] = metrics
+    return metrics
+
+
+def create_metrics_middleware(deployment: str) -> Callable:
+    """Return a FastAPI ``http`` middleware that records request metrics.
+
+    Usage inside a ``build_*_app()`` function::
+
+        from twinkle.server.utils.metrics import create_metrics_middleware
+        metrics_mw = create_metrics_middleware("Model")
+        app.middleware('http')(metrics_mw)
+
+    Because FastAPI executes middleware in LIFO order, registering this
+    **after** ``verify_token`` means it wraps the outermost layer and
+    captures full end-to-end latency including auth.
+    """
+
+    async def metrics_middleware(request: Any, call_next: Callable) -> Any:
+        start = time.monotonic()
+        response = await call_next(request)
+        elapsed = time.monotonic() - start
+        status = str(response.status_code)
+        method = request.scope['route'].path if 'route' in request.scope else request.url.path
+        m = _get_request_metrics(deployment)
+        m.requests_total.inc(tags={
+            'deployment': deployment,
+            'method': method,
+            'status': status,
+        })
+        m.request_duration_seconds.observe(
+            elapsed, tags={
+                'deployment': deployment,
+                'method': method,
+            })
+        return response
+
+    return metrics_middleware
+
+
+# ---------------------------------------------------------------------------
+# B. Task-queue metrics
+# ---------------------------------------------------------------------------
+
+
+def get_task_metrics(deployment: str) -> TaskMetrics:
+    """Return (or create) per-deployment task-queue metrics.
+ + Returns a :class:`TaskMetrics` Pydantic model with: + + - ``queue_depth`` – Gauge + - ``tasks_total`` – Counter + - ``execution_seconds`` – Histogram + - ``queue_wait_seconds`` – Histogram + - ``rate_limit_rejections`` – Counter + - ``rate_limiter_active_tokens`` – Gauge + """ + if deployment in _task_metrics_cache: + return _task_metrics_cache[deployment] + + metrics = TaskMetrics( + queue_depth=Gauge( + 'twinkle_task_queue_depth', + description='Current number of queued tasks.', + tag_keys=('deployment', ), + ), + tasks_total=Counter( + 'twinkle_tasks_total', + description='Total task completions.', + tag_keys=('deployment', 'task_type', 'status'), + ), + execution_seconds=Histogram( + 'twinkle_task_execution_seconds', + description='Pure task execution time in seconds.', + boundaries=_HISTOGRAM_BOUNDARIES, + tag_keys=('deployment', 'task_type'), + ), + queue_wait_seconds=Histogram( + 'twinkle_task_queue_wait_seconds', + description='Time from enqueue to execution start in seconds.', + boundaries=_HISTOGRAM_BOUNDARIES, + tag_keys=('deployment', 'task_type'), + ), + rate_limit_rejections=Counter( + 'twinkle_rate_limit_rejections_total', + description='Total rate-limit rejections.', + tag_keys=('deployment', ), + ), + rate_limiter_active_tokens=Gauge( + 'twinkle_rate_limiter_active_tokens', + description='Number of tokens tracked by the rate limiter.', + tag_keys=('deployment', ), + ), + ) + _task_metrics_cache[deployment] = metrics + return metrics + + +# --------------------------------------------------------------------------- +# D. Resource gauges (ServerState actor, updated every 15 s) +# --------------------------------------------------------------------------- + + +def get_resource_metrics() -> ResourceMetrics: + """Return (or create) global resource gauge metrics. + + Returns a :class:`ResourceMetrics` Pydantic model with: + + - ``active_sessions`` – Gauge + - ``active_models`` – Gauge + - ``active_sampling_sessions`` – Gauge + - ``active_futures`` – Gauge + """ + global _resource_metrics_cache + if _resource_metrics_cache is not None: + return _resource_metrics_cache + + metrics = ResourceMetrics( + active_sessions=Gauge( + 'twinkle_active_sessions', + description='Current active session count.', + ), + active_models=Gauge( + 'twinkle_active_models', + description='Current registered model count.', + ), + active_sampling_sessions=Gauge( + 'twinkle_active_sampling_sessions', + description='Current sampling session count.', + ), + active_futures=Gauge( + 'twinkle_active_futures', + description='Current future/request count.', + ), + ) + _resource_metrics_cache = metrics + return metrics diff --git a/src/twinkle/server/utils/rate_limiter.py b/src/twinkle/server/utils/rate_limiter.py index beefaa83..845cf246 100644 --- a/src/twinkle/server/utils/rate_limiter.py +++ b/src/twinkle/server/utils/rate_limiter.py @@ -41,6 +41,8 @@ def __init__( window_seconds: float = 1.0, token_cleanup_multiplier: float = 10.0, token_cleanup_interval: float = 60.0, + active_tokens_gauge=None, + deployment_name: str = '', ): """Initialize the rate limiter. @@ -53,6 +55,8 @@ def __init__( will be removed. Default is 10.0 (10x the window). token_cleanup_interval: How often to run the cleanup task in seconds. Default is 60.0 (every minute). + active_tokens_gauge: Optional ray.util.metrics Gauge for tracking active token count. + deployment_name: Deployment name for metrics labels. 
""" self.rps_limit = rps_limit self.tps_limit = tps_limit @@ -72,6 +76,10 @@ def __init__( self._cleanup_task: asyncio.Task | None = None self._cleanup_started = False + # Metrics gauge for active token count + self._active_tokens_gauge = active_tokens_gauge + self._deployment_name = deployment_name + def _cleanup_old_requests(self, token: str, current_time: float) -> None: """Remove requests outside the sliding window. @@ -122,6 +130,10 @@ async def _cleanup_inactive_tokens(self) -> None: logger.debug(f'[RateLimiter] Cleaned up {len(tokens_to_remove)} inactive tokens. ' f'Active tokens remaining: {len(self._token_requests)}') + if self._active_tokens_gauge is not None: + tags = {'deployment': self._deployment_name} if self._deployment_name else {} + self._active_tokens_gauge.set(len(self._token_requests), tags=tags) + except asyncio.CancelledError: logger.debug('[RateLimiter] Cleanup task cancelled') break @@ -193,6 +205,9 @@ async def check_and_record(self, token: str, input_tokens: int) -> tuple[bool, s # Record this request self._token_requests[token].append((current_time, input_tokens)) + if self._active_tokens_gauge is not None: + tags = {'deployment': self._deployment_name} if self._deployment_name else {} + self._active_tokens_gauge.set(len(self._token_requests), tags=tags) return True, None def get_stats(self, token: str) -> dict[str, Any]: diff --git a/src/twinkle/server/utils/state/server_state.py b/src/twinkle/server/utils/state/server_state.py index a42aa9a4..8e7689b2 100644 --- a/src/twinkle/server/utils/state/server_state.py +++ b/src/twinkle/server/utils/state/server_state.py @@ -9,6 +9,7 @@ from datetime import datetime from typing import Any +from twinkle.server.utils.metrics import get_resource_metrics from twinkle.utils.logger import get_logger from .config_manager import ConfigManager from .future_manager import FutureManager @@ -51,6 +52,11 @@ def __init__( self._cleanup_task: asyncio.Task | None = None self._cleanup_running = False + # Metrics loop state + self._metrics_task: asyncio.Task | None = None + self._metrics_running = False + self._metrics_update_interval: float = float(kwargs.get('metrics_update_interval', 15.0)) + # ----- Session Management ----- async def create_session(self, payload: dict[str, Any]) -> str: @@ -284,6 +290,22 @@ async def _cleanup_loop(self) -> None: logger.warning(f'[ServerState Cleanup] Error during cleanup: {e}') continue + async def _metrics_loop(self) -> None: + """Background task that updates resource gauge metrics every N seconds.""" + resource_metrics = get_resource_metrics() + while self._metrics_running: + try: + await asyncio.sleep(self._metrics_update_interval) + resource_metrics.active_sessions.set(self._session_mgr.count()) + resource_metrics.active_models.set(self._model_mgr.count()) + resource_metrics.active_sampling_sessions.set(self._sampling_mgr.count()) + resource_metrics.active_futures.set(self._future_mgr.count()) + except asyncio.CancelledError: + break + except Exception as e: + logger.debug(f'[ServerState] Error updating metrics: {e}') + continue + async def start_cleanup_task(self) -> bool: """Start the background cleanup task. 
@@ -294,6 +316,9 @@ async def start_cleanup_task(self) -> bool:
             return False
         self._cleanup_running = True
         self._cleanup_task = asyncio.create_task(self._cleanup_loop())
+        if not self._metrics_running:
+            self._metrics_running = True
+            self._metrics_task = asyncio.create_task(self._metrics_loop())
         return True
 
     async def stop_cleanup_task(self) -> bool:
@@ -308,6 +333,10 @@ async def stop_cleanup_task(self) -> bool:
         if self._cleanup_task:
             self._cleanup_task.cancel()
             self._cleanup_task = None
+        self._metrics_running = False
+        if self._metrics_task:
+            self._metrics_task.cancel()
+            self._metrics_task = None
         return True
 
     async def get_cleanup_stats(self) -> dict[str, Any]:
diff --git a/src/twinkle/server/utils/task_queue.py b/src/twinkle/server/utils/task_queue.py
index d0985c15..9cd99253 100644
--- a/src/twinkle/server/utils/task_queue.py
+++ b/src/twinkle/server/utils/task_queue.py
@@ -18,6 +18,7 @@
 from enum import Enum
 from typing import TYPE_CHECKING, Any, Callable, Coroutine, Deque, Dict, Optional
 
+from twinkle.server.utils.metrics import get_task_metrics
 from twinkle.utils.logger import get_logger
 from .rate_limiter import RateLimiter
 
@@ -157,11 +158,12 @@ async def _do_work():
     # Type hint for state attribute that inheriting classes must provide
     state: ServerStateProxy
 
-    def _init_task_queue(self, config: TaskQueueConfig | None = None) -> None:
+    def _init_task_queue(self, config: TaskQueueConfig | None = None, deployment_name: str = '') -> None:
         """Initialize the task queue system.
 
         Args:
             config: Optional TaskQueueConfig. If None, uses default config.
+            deployment_name: Deployment name for metrics labels (e.g. 'Model', 'Sampler').
         """
         self._task_queue_config = config or TaskQueueConfig()
         # Per-key queues, but executed by a single global worker.
@@ -169,6 +171,10 @@ def _init_task_queue(self, config: TaskQueueConfig | None = None) -> None:
         self._queue_order: Deque[str] = deque()
         self._new_task_event: asyncio.Event = asyncio.Event()
 
+        # Metrics initialization
+        self._deployment_name = deployment_name
+        self._task_metrics = get_task_metrics(deployment_name) if deployment_name else None
+
         # Initialize rate limiter for RPS/TPS control
         self._rate_limiter = RateLimiter(
             rps_limit=self._task_queue_config.rps_limit,
@@ -176,6 +182,8 @@ def _init_task_queue(self, config: TaskQueueConfig | None = None) -> None:
             window_seconds=self._task_queue_config.window_seconds,
             token_cleanup_multiplier=self._task_queue_config.token_cleanup_multiplier,
             token_cleanup_interval=self._task_queue_config.token_cleanup_interval,
+            active_tokens_gauge=self._task_metrics.rate_limiter_active_tokens if self._task_metrics else None,
+            deployment_name=deployment_name,
         )
         # Start the rate limiter cleanup task
         self._rate_limiter.start_cleanup_task()
@@ -247,6 +255,18 @@ async def _queue_worker(self) -> None:
                 except asyncio.QueueEmpty:
                     continue
 
+                # Record queue wait time and update depth gauge
+                if self._task_metrics:
+                    queue_wait = time.monotonic() - task.created_at
+                    task_type_label = task.task_type or 'unknown'
+                    self._task_metrics.queue_wait_seconds.observe(
+                        queue_wait, tags={
+                            'deployment': self._deployment_name,
+                            'task_type': task_type_label
+                        })
+                    total_depth = sum(qq.qsize() for qq in self._task_queues.values())
+                    self._task_metrics.queue_depth.set(total_depth, tags={'deployment': self._deployment_name})
+
                 now = time.monotonic()
 
                 # Global queue timeout
@@ -263,6 +283,13 @@ async def _queue_worker(self) -> None:
                         queue_state=QueueState.PAUSED_CAPACITY.value,
                         queue_state_reason=error_payload['error'],
                     )
+                    if self._task_metrics:
+                        self._task_metrics.tasks_total.inc(
+                            tags={
+                                'deployment': self._deployment_name,
+                                'task_type': task.task_type or 'unknown',
+                                'status': 'timeout'
+                            })
                     q.task_done()
                     continue
 
@@ -273,6 +300,8 @@ async def _queue_worker(self) -> None:
                 await self.state.store_future_status(
                     task.request_id, TaskStatus.RUNNING.value, task.model_id, queue_state=QueueState.ACTIVE.value)
 
+                exec_start = time.monotonic()
+                task_status = 'completed'
                 try:
                     coro = task.coro_factory()
                     result = await coro
@@ -283,6 +312,7 @@ async def _queue_worker(self) -> None:
                         result=result,
                         queue_state=QueueState.ACTIVE.value)
                 except Exception:
+                    task_status = 'failed'
                     error_payload = {'error': traceback.format_exc(), 'category': 'Server'}
                     await self.state.store_future_status(
                         task.request_id,
@@ -292,6 +322,20 @@ async def _queue_worker(self) -> None:
                         queue_state=QueueState.ACTIVE.value)
                 finally:
                     q.task_done()
+                    if self._task_metrics:
+                        exec_time = time.monotonic() - exec_start
+                        self._task_metrics.execution_seconds.observe(
+                            exec_time,
+                            tags={
+                                'deployment': self._deployment_name,
+                                'task_type': task.task_type or 'unknown'
+                            })
+                        self._task_metrics.tasks_total.inc(
+                            tags={
+                                'deployment': self._deployment_name,
+                                'task_type': task.task_type or 'unknown',
+                                'status': task_status
+                            })
 
                 # Keep serial semantics: execute at most one runnable task per loop
                 break
@@ -409,6 +453,8 @@ async def _perform_preflight_checks(
         # Check rate limits
         allowed, reason = await self._rate_limiter.check_and_record(token, input_tokens)
         if not allowed:
+            if self._task_metrics:
+                self._task_metrics.rate_limit_rejections.inc(tags={'deployment': self._deployment_name})
             error_msg = f'Rate limit exceeded: {reason}'
             error_payload = {'error': error_msg, 'category': 'User'}
             await self.state.store_future_status(
@@ -506,6 +552,10 @@ async def schedule_task(
 
         self._new_task_event.set()
 
+        if self._task_metrics:
+            total_depth = sum(q.qsize() for q in self._task_queues.values())
+            self._task_metrics.queue_depth.set(total_depth, tags={'deployment': self._deployment_name})
+
         return {'request_id': request_id, 'model_id': model_id}
 
     def get_queue_stats(self) -> dict[str, Any]:
diff --git a/src/twinkle_client/common/serialize.py b/src/twinkle_client/common/serialize.py
index de3ca4bb..b2d1720c 100644
--- a/src/twinkle_client/common/serialize.py
+++ b/src/twinkle_client/common/serialize.py
@@ -2,6 +2,7 @@
 import json
 from numbers import Number
 from peft import LoraConfig
+from pydantic import BaseModel
 from typing import Any, Mapping
 
 from twinkle.dataset import DatasetMeta
@@ -56,6 +57,9 @@ def serialize_object(obj) -> str:
         }
         filtered_dict['_TWINKLE_TYPE_'] = 'LoraConfig'
         return json.dumps(filtered_dict, ensure_ascii=False)
+    elif isinstance(obj, BaseModel):
+        # Pydantic models: convert to dict for JSON serialization by requests
+        return obj.model_dump(mode='json')
     elif isinstance(obj, Mapping):
         return json.dumps(obj, ensure_ascii=False)
     elif isinstance(obj, basic_types):