MoE/FSDP2 training fixes, Anthropic Messages API, and robustness improvements#1298
ashutoshuiuc wants to merge 1 commit into NovaSky-AI:main
Conversation
Code Review
This pull request introduces a significant set of changes, including fixes for MoE/FSDP2 training stability, the addition of the Anthropic Messages API, and several robustness improvements. The FSDP2 fixes, such as batched broadcasts and NUMA affinity rewrite, are crucial for stable and efficient training of large models. The introduction of the Anthropic Messages API is well-implemented across the stack. The robustness improvements in transitions_to_training_data and LoRA swapping logic are also valuable. My review found one minor opportunity for code simplification by removing a redundant check. Overall, these are high-quality changes that enhance the capabilities and stability of the platform.
```python
if not response_tokens:
    logger.warning("Response tokens are empty, skipping datum")
    return None
```
The check if not response_tokens: appears to be redundant. The response_tokens list can only be empty if first_nonzero == len(full_sequence). Since len(full_sequence) == len(mask), this is equivalent to first_nonzero == len(mask). This case is already handled by the check at lines 177-179, which would cause the function to return None earlier. Therefore, this block is unreachable and can be removed to simplify the code.
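The redundancy argument can be checked with a small sketch. The `first_nonzero` computation below is an assumption inferred from this review thread, not the repo's actual code; values are illustrative.

```python
# Illustrative values: `full_sequence` holds token ids, `mask` marks response tokens.
full_sequence = [101, 102, 5, 7]
mask = [0, 0, 1, 1]
assert len(full_sequence) == len(mask)

# Assumed definition: index of the first non-zero mask entry, or len(mask) if none.
first_nonzero = next((i for i, m in enumerate(mask) if m != 0), len(mask))
response_tokens = full_sequence[first_nonzero:]

# `response_tokens` is empty exactly when first_nonzero == len(mask),
# which the earlier check already handles by returning None.
assert bool(response_tokens) == (first_nonzero != len(mask))
```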
245274b to df7352c
Pull request overview
This PR improves MoE/FSDP2 training stability and observability (avoiding known NCCL deadlocks and slow init paths), and extends the inference stack with an Anthropic-compatible /v1/messages endpoint plus LoRA swap robustness and agent data conversion hardening.
Changes:
- Reworks training dispatch/metric reduction timing to avoid FSDP2+MoE NCCL deadlocks and improve per-micro-batch observability.
- Adds `/v1/messages` support across the inference engine stack (client routing, HTTP endpoint, vLLM async implementation).
- Improves robustness/performance in initialization paths (FSDP2 state-dict load coalescing, NUMA affinity rewrite, Ulysses SP position_ids caching) and agent transition conversion.
Reviewed changes
Copilot reviewed 15 out of 15 changed files in this pull request and generated 12 comments.
| File | Description |
|---|---|
| skyrl/train/trainer.py | Dispatches per-micro-batch slices for better step progress visibility and timing logs. |
| skyrl/backends/skyrl_train/workers/worker_dispatch.py | Adds get_dp_size() and timing logs around Ray calls/snapshots. |
| skyrl/backends/skyrl_train/workers/worker.py | Moves metric all-reduce to after all micro-batches; rewrites NUMA affinity logic. |
| skyrl/backends/skyrl_train/workers/model_wrapper.py | Caches Ulysses SP position_ids outside checkpointed regions; adds MoE router-logits guard. |
| skyrl/backends/skyrl_train/workers/fsdp/fsdp_worker.py | Adds MoE expert patching for FSDP2+checkpointing; adjusts meta init logic; adds checkpoint logging. |
| skyrl/backends/skyrl_train/distributed/fsdp_strategy.py | Ensures tied embeddings are linked before state_dict load; avoids non-finite grad deadlock by stepping on all ranks; adds timing logs. |
| skyrl/backends/skyrl_train/distributed/fsdp_utils.py | Coalesces broadcasts to speed up FSDP2 full-state load for large models. |
| skyrl/backends/skyrl_train/distributed/ulysses/monkey_patch.py | Pre-gathers position_ids to avoid checkpoint-backward all_gather deadlocks in Ulysses SP. |
| skyrl/backends/skyrl_train/inference_engines/base.py | Extends interface with anthropic_messages() method. |
| skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py | Adds sticky routing for /v1/messages requests via session_id. |
| skyrl/backends/skyrl_train/inference_engines/inference_engine_client_http_endpoint.py | Adds FastAPI POST /v1/messages endpoint with validation and status mapping. |
| skyrl/backends/skyrl_train/inference_engines/ray_wrapped_inference_engine.py | Forwards anthropic_messages() to the Ray actor implementation. |
| skyrl/backends/skyrl_train/inference_engines/remote_inference_engine.py | Adds remote forwarding to /v1/messages. |
| skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py | Implements Anthropic↔OpenAI conversion in async vLLM; improves LoRA adapter swap and routing; propagates RAY_ADDRESS. |
| skyrl-agent/skyrl_agent/functional/utils.py | Hardens transitions_to_training_data() to handle missing/mismatched tokens/logprobs robustly. |
```diff
 model_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
+use_meta = True if self.cfg.strategy == "fsdp2" else (not model_config.tie_word_embeddings)
 init_context = get_init_weight_context_manager(
-    use_meta_tensor=not model_config.tie_word_embeddings, mesh=self.strategy.device_mesh
+    use_meta_tensor=use_meta, mesh=self.strategy.device_mesh
 )
```
Same issue as policy init: direct access to model_config.tie_word_embeddings can raise AttributeError for configs that don’t define it. Prefer getattr(model_config, "tie_word_embeddings", False) when computing use_meta.
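A minimal sketch of the suggested guard, using stand-in classes for `self.cfg` and the HF config (hypothetical names; only the `getattr` pattern is the point):

```python
class _Cfg:  # stand-in for self.cfg in the worker (hypothetical)
    strategy = "fsdp"

class _ModelConfig:  # a config that omits tie_word_embeddings entirely
    pass

cfg = _Cfg()
model_config = _ModelConfig()

# Direct attribute access would raise AttributeError here; getattr with a
# default of False is safe for configs that don't define the field.
use_meta = True if cfg.strategy == "fsdp2" else (
    not getattr(model_config, "tie_word_embeddings", False)
)
print(use_meta)  # → True: missing attribute is treated as untied embeddings
```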
```python
real_gpu_id = local_rank_to_real_gpu_id(self._local_rank)
num_gpus_per_numa = max(1, 8 // real_numa_nodes)  # e.g. 8 // 2 = 4
# Clamp to [0, max_node] — guaranteed safe
target_nid = min(max_node, real_gpu_id // num_gpus_per_numa)
```
num_gpus_per_numa is still computed with a hard-coded 8 // real_numa_nodes. On hosts with fewer (or more) visible GPUs, this will skew the GPU→NUMA mapping even though the preceding change fixed the IndexError. Consider deriving the GPU count from CUDA_VISIBLE_DEVICES (or torch.cuda.device_count()/len(ray.get_gpu_ids())) instead of assuming 8.
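One way to derive the per-host GPU count instead of assuming 8 (a sketch; the function name and fallback order are assumptions, adapt to the worker's environment):

```python
import os

def visible_gpu_count(default=8):
    """Best-effort GPU count: CUDA_VISIBLE_DEVICES, then torch, then a default."""
    cvd = os.environ.get("CUDA_VISIBLE_DEVICES")
    if cvd:
        # CUDA_VISIBLE_DEVICES is a comma-separated list of device ids/UUIDs.
        return len([d for d in cvd.split(",") if d.strip() != ""])
    try:
        import torch
        if torch.cuda.is_available():
            return torch.cuda.device_count()
    except ImportError:
        pass
    return default

# num_gpus_per_numa = max(1, visible_gpu_count() // real_numa_nodes)
```

In a Ray worker, `len(ray.get_gpu_ids())` gives the GPUs assigned to that worker rather than the host total, so the host-level count above is likely the right denominator for NUMA mapping.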
```python
await self.reset_prefix_cache()

logger.info(f"Loaded new LoRA {new_id} to {self.llm.list_loras()}")
return {"status": "ok", "lora_id": new_id}
```
self.llm.list_loras() is logged without awaiting. In this file, LoRA listing is treated as synchronous for the sync engine (self.llm.llm_engine.list_loras()), but on the async engine list_loras() is typically an async API. As written, this may log a coroutine object or never execute. Consider awaiting (or calling the correct sync accessor) before logging.
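To illustrate why logging an un-awaited coroutine is wrong, a self-contained asyncio sketch (the engine class here is a hypothetical stand-in, not the vLLM API):

```python
import asyncio

class FakeAsyncEngine:  # hypothetical stand-in for an async engine
    async def list_loras(self):
        return {1, 2}

async def main():
    llm = FakeAsyncEngine()
    unawaited = llm.list_loras()            # a coroutine object, not the result
    assert asyncio.iscoroutine(unawaited)   # logging this would print the coroutine repr
    loras = await unawaited                 # awaiting yields the actual set
    return loras

print(asyncio.run(main()))
```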
```python
if has_valid_logprobs:
    response_logprobs = sampled_logprobs[first_nonzero:]

    if len(response_logprobs) != len(response_tokens):
        logger.error(
            f"response_logprobs length ({len(response_logprobs)}) "
            f"!= response_tokens length ({len(response_tokens)})"
        )
        return None
else:
    response_logprobs = None
```
finalize_datum() sets response_logprobs = None when logprobs are missing/mismatched. This is a good semantic, but it currently conflicts with the TrainingDatum type annotation (response_logprobs: List[float]). Please update the TrainingDatum dataclass (and any consumers) to make response_logprobs optional so the new behavior is type-correct and intentional.
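A sketch of the suggested annotation change (other fields of the real dataclass are omitted):

```python
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class TrainingDatum:  # sketch: only the fields relevant to this comment
    response_tokens: List[int]
    # None when logprobs are missing or length-mismatched, per the new behavior.
    response_logprobs: Optional[List[float]] = None

d = TrainingDatum(response_tokens=[1, 2, 3])
print(d.response_logprobs)  # → None
```

Consumers would then need an explicit `if datum.response_logprobs is not None:` guard rather than assuming a list is always present.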
```python
is_moe = getattr(model_config, "num_experts", None) is not None or \
         getattr(model_config, "num_local_experts", None) is not None

if is_moe:
    _patch_moe_experts_for_fsdp2(wrapped_model.model)
    # _patch_checkpoint_for_moe()  # uncomment if expert patch alone doesn't fix tensor count mismatch
```
_patch_moe_experts_for_fsdp2() is applied whenever the config looks MoE (is_moe), regardless of whether the run is actually using FSDP2 + non-reentrant checkpointing. This patch forces iterating all experts and can be a major performance regression for MoE training in other modes (e.g., FSDP1 or re-entrant/no checkpointing). Suggest gating this to the specific failure mode (e.g., self.cfg.strategy == "fsdp2" and self.cfg.gradient_checkpointing and not use_reentrant) or making it opt-in via config.
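A sketch of the suggested gating as a pure predicate (the config field names are assumptions, not the repo's actual schema):

```python
from types import SimpleNamespace

def should_patch_moe(cfg, is_moe):
    """Gate the expert patch to FSDP2 + non-reentrant checkpointing (field names assumed)."""
    return (
        is_moe
        and cfg.strategy == "fsdp2"
        and cfg.gradient_checkpointing
        and not cfg.gradient_checkpointing_use_reentrant
    )

cfg = SimpleNamespace(
    strategy="fsdp2",
    gradient_checkpointing=True,
    gradient_checkpointing_use_reentrant=False,
)
print(should_patch_moe(cfg, is_moe=True))   # → True
cfg.strategy = "fsdp"
print(should_patch_moe(cfg, is_moe=True))   # → False: FSDP1 keeps the fast routed path
```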
```python
@app.post("/v1/messages")
async def anthropic_messages(raw_request: Request):
    """Anthropic-compatible Messages API endpoint."""
    try:
        request_json = await raw_request.json()

        if _global_inference_engine_client is None:
            return JSONResponse(
                content={"error": {"message": "Inference engine client not initialized", "type": "internal_error"}},
                status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
            )
        if "model" not in request_json:
            return JSONResponse(
                content={"error": {"message": "The field `model` is required", "type": "invalid_request_error"}},
                status_code=HTTPStatus.BAD_REQUEST.value,
            )
        if "messages" not in request_json or not request_json["messages"]:
            return JSONResponse(
                content={"error": {"message": "The field `messages` is required and cannot be empty", "type": "invalid_request_error"}},
                status_code=HTTPStatus.BAD_REQUEST.value,
            )

        payload = {
            "json": request_json,
            "headers": dict(raw_request.headers) if hasattr(raw_request, "headers") else {},
        }
        anthropic_response = await _global_inference_engine_client.anthropic_messages(payload)

        if "error" in anthropic_response:
            error_type = anthropic_response["error"].get("type", "internal_error")
            status_code = HTTPStatus.BAD_REQUEST.value if error_type == "invalid_request_error" else HTTPStatus.INTERNAL_SERVER_ERROR.value
            return JSONResponse(content=anthropic_response, status_code=status_code)

        return JSONResponse(content=anthropic_response)
```
This PR adds a new public HTTP endpoint (POST /v1/messages) with request validation and error mapping, but there doesn’t appear to be any test coverage for it (there are extensive tests for /v1/chat/completions and /v1/completions). Adding a basic integration test (success + missing/empty messages + missing model) would help prevent regressions.
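The validation logic is easy to unit-test even without spinning up the FastAPI app. A sketch that mirrors the endpoint's checks as a standalone function (the function name is hypothetical; a full integration test would use FastAPI's TestClient and mock the inference engine client):

```python
from http import HTTPStatus

def validate_messages_request(request_json):
    """Mirror of the endpoint's request validation, for unit testing."""
    if "model" not in request_json:
        return HTTPStatus.BAD_REQUEST.value, "The field `model` is required"
    if "messages" not in request_json or not request_json["messages"]:
        return HTTPStatus.BAD_REQUEST.value, "The field `messages` is required and cannot be empty"
    return HTTPStatus.OK.value, None

# Missing model, empty messages, and a valid request.
assert validate_messages_request({"messages": [{"role": "user", "content": "hi"}]})[0] == 400
assert validate_messages_request({"model": "m", "messages": []})[0] == 400
assert validate_messages_request({"model": "m", "messages": [{"role": "user", "content": "hi"}]}) == (200, None)
print("validation checks pass")
```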
```python
# Store position_ids in thread-local for model architectures (e.g., GraniteMoeHybrid)
# that don't propagate position_ids through decoder layers to flash attention.
# Must be set here (outside the model) to survive gradient checkpointing backward re-execution.
if self.sequence_parallel_size > 1:
    set_ulysses_position_ids(position_ids_fwd)
```
The comment says position_ids is stored in “thread-local”, but set_ulysses_position_ids() stores it in module-level globals (see ulysses/monkey_patch.py). To avoid confusion (especially if this code is later used in multi-threaded contexts), consider updating the comment to reflect the actual storage mechanism and its assumptions (single-threaded Ray worker process).
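For contrast, a sketch of module-level global storage of the kind this comment describes (names modeled on the review discussion, not copied from `ulysses/monkey_patch.py`): the value is shared by all threads in the process, unlike `threading.local()`.

```python
# Module-level global: one value per process, visible to every thread.
# A true thread-local would instead use threading.local().
_ULYSSES_POSITION_IDS = None

def set_ulysses_position_ids(position_ids):
    global _ULYSSES_POSITION_IDS
    _ULYSSES_POSITION_IDS = position_ids

def get_ulysses_position_ids():
    return _ULYSSES_POSITION_IDS

set_ulysses_position_ids([0, 1, 2, 3])
print(get_ulysses_position_ids())  # → [0, 1, 2, 3]
```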
```python
for expert_idx in range(mod.num_experts):
    top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
    if token_idx.numel() > 0:
        current_state = hidden_states[token_idx]
    else:
        current_state = hidden_states[:0]

    gate, up = nn.functional.linear(
        current_state, mod.gate_up_proj[expert_idx]
    ).chunk(2, dim=-1)
    current_hidden_states = mod.act_fn(gate) * up
    current_hidden_states = nn.functional.linear(
        current_hidden_states, mod.down_proj[expert_idx]
    )

    if token_idx.numel() > 0:
        current_hidden_states = (
            current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
        )
    final_hidden_states.index_add_(
        0, token_idx,
        current_hidden_states.to(final_hidden_states.dtype),
    )
```
In the patched Qwen3MoeExperts.forward, the if token_idx.numel() > 0: branches (around scaling/index_add) are still data-dependent on routing. That can reintroduce nondeterministic saved-tensor counts under non-reentrant checkpoint recompute (the original problem this patch is trying to fix). Consider executing the same ops unconditionally (indexing/multiplying with empty indices is a no-op) so the code path is identical for empty vs non-empty experts.
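A small demonstration that the scaling and `index_add_` are safe no-ops on empty indices, so those guards can be dropped (tensor shapes here are illustrative, not the model's real ones):

```python
import torch

# An expert that received no tokens: empty index tensors.
hidden_states = torch.randn(4, 8)
final_hidden_states = torch.zeros(4, 8)
token_idx = torch.empty(0, dtype=torch.long)
top_k_pos = torch.empty(0, dtype=torch.long)
top_k_weights = torch.rand(4, 2)

current_state = hidden_states[token_idx]  # shape (0, 8): empty but valid
# Advanced indexing with empty indices yields (0, 1); the product stays (0, 8).
scaled = current_state * top_k_weights[token_idx, top_k_pos, None]
before = final_hidden_states.clone()
final_hidden_states.index_add_(0, token_idx, scaled)  # no-op on empty indices
assert torch.equal(before, final_hidden_states)
print("empty-expert path is a no-op")
```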
```diff
 model_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
+# FSDP2 handles tied embeddings correctly via broadcast + tie_weights(),
+# so meta tensor init is always safe. FSDP1 needs CPU init for tied embeddings.
+use_meta = True if self.cfg.strategy == "fsdp2" else (not model_config.tie_word_embeddings)
 init_context = get_init_weight_context_manager(
-    use_meta_tensor=not model_config.tie_word_embeddings, mesh=self.strategy.device_mesh
+    use_meta_tensor=use_meta, mesh=self.strategy.device_mesh
 )
```
model_config.tie_word_embeddings is accessed directly when deciding use_meta. Some custom/older HF configs in this repo are handled with getattr(..., "tie_word_embeddings", False) (e.g., LoRA utilities), so this direct access can raise AttributeError and break model init. Consider switching to getattr(model_config, "tie_word_embeddings", False) here (and in the critic/ref init blocks).
```diff
 model_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
+use_meta = True if self.cfg.strategy == "fsdp2" else (not model_config.tie_word_embeddings)
 init_context = get_init_weight_context_manager(
-    use_meta_tensor=not model_config.tie_word_embeddings, mesh=self.strategy.device_mesh
+    use_meta_tensor=use_meta, mesh=self.strategy.device_mesh
 )
```
Same issue as policy init: direct access to model_config.tie_word_embeddings can raise AttributeError for configs that don’t define it. Prefer getattr(model_config, "tie_word_embeddings", False) when computing use_meta.
…ovements
Merges two sets of changes:
1. MoE/FSDP2 training stability:
- Patch MoE expert modules (Qwen3MoeSparseMoeBlock/Qwen3MoeExperts) to iterate
all experts unconditionally, ensuring deterministic computation graph for
FSDP2 + non-reentrant gradient checkpointing
- FSDP2 meta tensor init always enabled (handles tied embeddings via tie_weights)
- Batched/coalesced broadcast for FSDP2 state dict loading (500MB batches,
reduces MoE init from minutes to seconds for 18K+ params)
- Fix NCCL deadlock: non-finite grad_norm no longer skips optimizer.step()
- Move all-reduce metrics after all micro-batches to avoid NCCL deadlock
with FSDP2 backward gradient reductions
- Pre-gather Ulysses position_ids outside checkpointed region to avoid
NCCL all_gather during gradient checkpointing backward recompute
- Guard output_router_logits on actual MoE models (num_local_experts > 0)
- Rewrite NUMA affinity to use integer API (fixes segfaults from bitmask
pointer corruption) and numa_max_node (fixes NVLink virtual NUMA IDs)
- Per-micro-batch dispatch with progress logging in trainer
- Fix RAY_ADDRESS propagation for vLLM EngineCore subprocess
2. Anthropic Messages API and agent improvements:
- Add /v1/messages endpoint across inference engine stack (base, client,
HTTP endpoint, ray wrapped, remote, vLLM)
- Full Anthropic-to-OpenAI format conversion in AsyncVLLMInferenceEngine
- Improved LoRA weight swap: abort generation, remove old adapter, add new,
reset prefix cache, track active_lora_id with monkey-patched adapter lookup
- Robust transitions_to_training_data: validate None/empty observations and
actions, track logprobs validity per-datum, explicit length-mismatch checks
df7352c to 606edb1
```diff
+# NOTE: With FSDP, optimizer.step() involves NCCL collectives. ALL ranks
+# must call it even if grad_norm is non-finite, otherwise NCCL deadlocks.
+# We zero_grad before stepping so the non-finite update is harmless.
+non_finite = grad_norm is not None and not torch.isfinite(grad_norm)
+if non_finite:
+    logger.warning(f"rank {rank} grad_norm is not finite: {grad_norm}, zeroing grads before step")
     optimizer.zero_grad()
-    return grad_norm
 
+t0 = _time.time()
 optimizer.step()
-if scheduler is not None:
+logger.info(f"[rank {rank}] optimizer.step() done in {_time.time() - t0:.1f}s")
+# Only advance LR schedule when gradients were finite (non-finite steps are no-ops)
+if scheduler is not None and not non_finite:
     scheduler.step()
 optimizer.zero_grad()
```
🔴 optimizer.step() with zeroed gradients is NOT a no-op: weight decay and momentum still modify parameters
When grad_norm is non-finite, the new code zeroes gradients then calls optimizer.step(). The comment says "We zero_grad before stepping so the non-finite update is harmless" — but this is incorrect for AdamW. optimizer.step() with zero gradients still: (1) applies weight decay (param *= (1 - lr * wd)), shrinking all parameters, (2) decays momentum buffers (exp_avg *= beta1, exp_avg_sq *= beta2) then uses decayed momentum to compute a non-zero parameter update, and (3) increments the internal step counter, affecting Adam bias correction. The old code returned early and truly skipped the update. The new code silently modifies model weights when the intent was a no-op.
Old vs new behavior on non-finite grad_norm

Old code returned early — true no-op:

```python
optimizer.zero_grad()
return grad_norm  # skip step entirely
```

New code still updates parameters:

```python
optimizer.zero_grad()  # zeros grads
optimizer.step()       # applies weight_decay + momentum-based update
```

A correct fix would be to also set all exp_avg and exp_avg_sq to zero, or temporarily set lr=0 and weight_decay=0 before stepping.
Prompt for agents
In skyrl/backends/skyrl_train/distributed/fsdp_strategy.py, the optimizer_step method (lines 155-197) calls optimizer.step() with zeroed gradients when grad_norm is non-finite. This is NOT a no-op for AdamW because weight decay is applied independently of gradients, and momentum buffers are decayed and then used for a non-zero parameter update.
To truly make the non-finite case a no-op while still calling optimizer.step() (to avoid NCCL deadlock), you should temporarily set lr=0 and weight_decay=0 before stepping, then restore them afterward. For example:
```python
if non_finite:
    optimizer.zero_grad()
    saved_lrs = []
    saved_wds = []
    for pg in optimizer.param_groups:
        saved_lrs.append(pg['lr'])
        saved_wds.append(pg['weight_decay'])
        pg['lr'] = 0.0
        pg['weight_decay'] = 0.0
optimizer.step()
if non_finite:
    for pg, lr, wd in zip(optimizer.param_groups, saved_lrs, saved_wds):
        pg['lr'] = lr
        pg['weight_decay'] = wd
```

This ensures all ranks participate in the NCCL collectives inside optimizer.step() while leaving the parameters unchanged (note that Adam's momentum buffers and step counter still advance).
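The lr=0/weight_decay=0 trick can be verified directly with a toy AdamW parameter (a sketch; the optimizer hyperparameters are arbitrary):

```python
import torch

# Verify that AdamW.step() with lr=0 and weight_decay=0 leaves parameters
# untouched even when momentum buffers are non-zero.
p = torch.nn.Parameter(torch.ones(3))
opt = torch.optim.AdamW([p], lr=0.1, weight_decay=0.01)

# Prime momentum with one real step so exp_avg is non-zero.
p.grad = torch.full((3,), 0.5)
opt.step()

# Simulate the non-finite-grad path: zero grads, zero lr/wd, step, restore.
opt.zero_grad()
p.grad = torch.zeros(3)
before = p.detach().clone()
saved = [(pg['lr'], pg['weight_decay']) for pg in opt.param_groups]
for pg in opt.param_groups:
    pg['lr'], pg['weight_decay'] = 0.0, 0.0
opt.step()  # collectives would run here under FSDP, but params stay fixed
for pg, (lr, wd) in zip(opt.param_groups, saved):
    pg['lr'], pg['weight_decay'] = lr, wd
assert torch.equal(p.detach(), before)
print("no-op step confirmed")
```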
Please let me know how I can get these changes merged. The CPU/GPU tests are failing due to some checks tied to Anyscale accounts; are contributors supposed to have Anyscale accounts/subscriptions? What's the preferred method of contributing to SkyRL?
Summary
This PR combines two sets of changes:
- MoE/FSDP2 training stability — fixes for training Qwen3 MoE models under FSDP2 with gradient checkpointing, including NCCL deadlock prevention, batched broadcasts, NUMA affinity rewrite, and Ulysses SP fixes. Details in #1297 (MoE (Qwen3) + FSDP2 training fixes: expert patching, NCCL deadlocks, batched broadcast, NUMA, Ulysses SP).
- Anthropic Messages API, LoRA swap, and agent robustness — `/v1/messages` endpoint across the inference engine stack, improved LoRA weight swap, and robust transitions_to_training_data. Details in #1222 ([train][agent] Add Anthropic Messages API endpoint, fix LoRA weight swap, and improve transition-to-training-data logic).

Resolves #1297
Resolves #1222
Test plan
- `/v1/messages` endpoint with Anthropic-compatible format

CC: @SumanthRH @CharlieFRuan