
MoE/FSDP2 training fixes, Anthropic Messages API, and robustness improvements#1298

Open

ashutoshuiuc wants to merge 1 commit into NovaSky-AI:main from ashutoshuiuc:feature/moe-fsdp2-anthropic-api-fixes

Conversation


@ashutoshuiuc ashutoshuiuc commented Mar 9, 2026

Summary

This PR combines two sets of changes:

  1. MoE/FSDP2 training stability — fixes for training Qwen3 MoE models under FSDP2 with gradient checkpointing,
    including NCCL deadlock prevention, batched broadcasts, NUMA affinity rewrite, and Ulysses SP fixes. Details in
    MoE (Qwen3) + FSDP2 training fixes: expert patching, NCCL deadlocks, batched broadcast, NUMA, Ulysses SP #1297.

  2. Anthropic Messages API, LoRA swap, and agent robustness — adds a /v1/messages endpoint across the inference engine
    stack, improved LoRA weight swap, and robust transitions_to_training_data. Details in [train][agent] Add Anthropic Messages API endpoint, fix LoRA weight swap, and improve transition-to-training-data logic #1222.

Resolves #1297
Resolves #1222

Test plan

  • Tested MoE model (Qwen3-30B-A3B) training under FSDP2 — no NCCL deadlocks
  • Tested /v1/messages endpoint with Anthropic-compatible format
  • Tested LoRA weight swap with HTTP inference endpoint
  • Verified non-MoE models unaffected

CC: @SumanthRH @CharlieFRuan



Copilot AI review requested due to automatic review settings March 9, 2026 21:52
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a significant set of changes, including fixes for MoE/FSDP2 training stability, the addition of the Anthropic Messages API, and several robustness improvements. The FSDP2 fixes, such as batched broadcasts and NUMA affinity rewrite, are crucial for stable and efficient training of large models. The introduction of the Anthropic Messages API is well-implemented across the stack. The robustness improvements in transitions_to_training_data and LoRA swapping logic are also valuable. My review found one minor opportunity for code simplification by removing a redundant check. Overall, these are high-quality changes that enhance the capabilities and stability of the platform.

Comment on lines +197 to +199
if not response_tokens:
    logger.warning("Response tokens are empty, skipping datum")
    return None
Contributor


medium

The check if not response_tokens: appears to be redundant. The response_tokens list can only be empty if first_nonzero == len(full_sequence). Since len(full_sequence) == len(mask), this is equivalent to first_nonzero == len(mask). This case is already handled by the check at lines 177-179, which would cause the function to return None earlier. Therefore, this block is unreachable and can be removed to simplify the code.
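The unreachability argument can be checked with a small standalone example (illustrative only; the variable names mirror the snippet, not the PR's exact code):

```python
# When the mask has no nonzero entry, first_nonzero lands at len(mask),
# so the slice past the end is empty -- exactly the case the earlier
# first_nonzero == len(mask) check already rejects.
full_sequence = [10, 11, 12]
mask = [0, 0, 0]
first_nonzero = next((i for i, m in enumerate(mask) if m), len(mask))
response_tokens = full_sequence[first_nonzero:]
assert first_nonzero == len(mask)
assert response_tokens == []
```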

devin-ai-integration[bot]

This comment was marked as resolved.

@ashutoshuiuc ashutoshuiuc force-pushed the feature/moe-fsdp2-anthropic-api-fixes branch 2 times, most recently from 245274b to df7352c Compare March 9, 2026 21:59

Copilot AI left a comment


Pull request overview

This PR improves MoE/FSDP2 training stability and observability (avoiding known NCCL deadlocks and slow init paths), and extends the inference stack with an Anthropic-compatible /v1/messages endpoint plus LoRA swap robustness and agent data conversion hardening.

Changes:

  • Reworks training dispatch/metric reduction timing to avoid FSDP2+MoE NCCL deadlocks and improve per-micro-batch observability.
  • Adds /v1/messages support across the inference engine stack (client routing, HTTP endpoint, vLLM async implementation).
  • Improves robustness/performance in initialization paths (FSDP2 state-dict load coalescing, NUMA affinity rewrite, Ulysses SP position_ids caching) and agent transition conversion.
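The Anthropic-to-OpenAI request conversion could look roughly like this (a minimal sketch, not the PR's actual implementation; it covers only the system prompt, text content blocks, and max_tokens):

```python
def anthropic_to_openai(request: dict) -> dict:
    """Map an Anthropic Messages request to an OpenAI chat-completions request.

    Minimal sketch: handles the system prompt, plain-text and text-block
    message content, and max_tokens. A real conversion must also cover
    tool use, non-text content blocks, and streaming.
    """
    messages = []
    if "system" in request:
        messages.append({"role": "system", "content": request["system"]})
    for msg in request["messages"]:
        content = msg["content"]
        # Anthropic allows a list of content blocks; flatten the text blocks.
        if isinstance(content, list):
            content = "".join(
                block["text"] for block in content if block.get("type") == "text"
            )
        messages.append({"role": msg["role"], "content": content})
    return {
        "model": request["model"],
        "messages": messages,
        "max_tokens": request.get("max_tokens", 1024),
    }
```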

Reviewed changes

Copilot reviewed 15 out of 15 changed files in this pull request and generated 12 comments.

Summary per file:

  • skyrl/train/trainer.py: Dispatches per-micro-batch slices for better step progress visibility and timing logs.
  • skyrl/backends/skyrl_train/workers/worker_dispatch.py: Adds get_dp_size() and timing logs around Ray calls/snapshots.
  • skyrl/backends/skyrl_train/workers/worker.py: Moves metric all-reduce to after all micro-batches; rewrites NUMA affinity logic.
  • skyrl/backends/skyrl_train/workers/model_wrapper.py: Caches Ulysses SP position_ids outside checkpointed regions; adds MoE router-logits guard.
  • skyrl/backends/skyrl_train/workers/fsdp/fsdp_worker.py: Adds MoE expert patching for FSDP2+checkpointing; adjusts meta init logic; adds checkpoint logging.
  • skyrl/backends/skyrl_train/distributed/fsdp_strategy.py: Ensures tied embeddings are linked before state_dict load; avoids non-finite grad deadlock by stepping on all ranks; adds timing logs.
  • skyrl/backends/skyrl_train/distributed/fsdp_utils.py: Coalesces broadcasts to speed up FSDP2 full-state load for large models.
  • skyrl/backends/skyrl_train/distributed/ulysses/monkey_patch.py: Pre-gathers position_ids to avoid checkpoint-backward all_gather deadlocks in Ulysses SP.
  • skyrl/backends/skyrl_train/inference_engines/base.py: Extends interface with anthropic_messages() method.
  • skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py: Adds sticky routing for /v1/messages requests via session_id.
  • skyrl/backends/skyrl_train/inference_engines/inference_engine_client_http_endpoint.py: Adds FastAPI POST /v1/messages endpoint with validation and status mapping.
  • skyrl/backends/skyrl_train/inference_engines/ray_wrapped_inference_engine.py: Forwards anthropic_messages() to the Ray actor implementation.
  • skyrl/backends/skyrl_train/inference_engines/remote_inference_engine.py: Adds remote forwarding to /v1/messages.
  • skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py: Implements Anthropic↔OpenAI conversion in async vLLM; improves LoRA adapter swap and routing; propagates RAY_ADDRESS.
  • skyrl-agent/skyrl_agent/functional/utils.py: Hardens transitions_to_training_data() to handle missing/mismatched tokens/logprobs robustly.


Comment on lines 603 to 607
  model_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
  use_meta = True if self.cfg.strategy == "fsdp2" else (not model_config.tie_word_embeddings)
  init_context = get_init_weight_context_manager(
-     use_meta_tensor=not model_config.tie_word_embeddings, mesh=self.strategy.device_mesh
+     use_meta_tensor=use_meta, mesh=self.strategy.device_mesh
  )

Copilot AI Mar 9, 2026


Same issue as policy init: direct access to model_config.tie_word_embeddings can raise AttributeError for configs that don’t define it. Prefer getattr(model_config, "tie_word_embeddings", False) when computing use_meta.

Comment on lines +216 to +219
real_gpu_id = local_rank_to_real_gpu_id(self._local_rank)
num_gpus_per_numa = max(1, 8 // real_numa_nodes) # e.g. 8//2 = 4
# Clamp to [0, max_node] — guaranteed safe
target_nid = min(max_node, real_gpu_id // num_gpus_per_numa)

Copilot AI Mar 9, 2026


num_gpus_per_numa is still computed with a hard-coded 8 // real_numa_nodes. On hosts with fewer (or more) visible GPUs, this will skew the GPU→NUMA mapping even though the preceding change fixed the IndexError. Consider deriving the GPU count from CUDA_VISIBLE_DEVICES (or torch.cuda.device_count()/len(ray.get_gpu_ids())) instead of assuming 8.
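Deriving the per-host GPU count could be sketched like this (the helper name is hypothetical, not code from the PR):

```python
import os


def visible_gpu_count(default: int = 8) -> int:
    """Derive the number of GPUs visible to this host instead of assuming 8.

    Sketch only: prefers CUDA_VISIBLE_DEVICES, falls back to
    torch.cuda.device_count(), and finally to a configured default.
    """
    cvd = os.environ.get("CUDA_VISIBLE_DEVICES")
    if cvd:
        return len([d for d in cvd.split(",") if d.strip()])
    try:
        import torch

        if torch.cuda.is_available():
            return torch.cuda.device_count()
    except ImportError:
        pass
    return default
```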

Comment on lines +464 to +467
await self.reset_prefix_cache()

logger.info(f"Loaded new LoRA {new_id} to {self.llm.list_loras()}")
return {"status": "ok", "lora_id": new_id}

Copilot AI Mar 9, 2026


self.llm.list_loras() is logged without awaiting. In this file, LoRA listing is treated as synchronous for the sync engine (self.llm.llm_engine.list_loras()), but on the async engine list_loras() is typically an async API. As written, this may log a coroutine object or never execute. Consider awaiting (or calling the correct sync accessor) before logging.
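One way to handle both engine variants is to await only when the result is awaitable (a sketch; the real fix may simply call the correct sync accessor on the sync engine):

```python
import asyncio
import inspect


async def log_loras(llm) -> list:
    """Call list_loras() correctly whether the engine exposes it as sync or async.

    Sketch: awaits the result only when it is awaitable, so the async
    engine's coroutine is never logged as a bare coroutine object.
    """
    result = llm.list_loras()
    if inspect.isawaitable(result):
        result = await result
    return list(result)
```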

Comment on lines +185 to +196
if has_valid_logprobs:
    response_logprobs = sampled_logprobs[first_nonzero:]

    if len(response_logprobs) != len(response_tokens):
        logger.error(
            f"response_logprobs length ({len(response_logprobs)}) "
            f"!= response_tokens length ({len(response_tokens)})"
        )
        return None
else:
    response_logprobs = None


Copilot AI Mar 9, 2026


finalize_datum() sets response_logprobs = None when logprobs are missing/mismatched. This is a good semantic, but it currently conflicts with the TrainingDatum type annotation (response_logprobs: List[float]). Please update the TrainingDatum dataclass (and any consumers) to make response_logprobs optional so the new behavior is type-correct and intentional.

Comment on lines +379 to +384
is_moe = getattr(model_config, "num_experts", None) is not None or \
    getattr(model_config, "num_local_experts", None) is not None

if is_moe:
    _patch_moe_experts_for_fsdp2(wrapped_model.model)
    # _patch_checkpoint_for_moe()  # uncomment if expert patch alone doesn't fix tensor count mismatch

Copilot AI Mar 9, 2026


_patch_moe_experts_for_fsdp2() is applied whenever the config looks MoE (is_moe), regardless of whether the run is actually using FSDP2 + non-reentrant checkpointing. This patch forces iterating all experts and can be a major performance regression for MoE training in other modes (e.g., FSDP1 or re-entrant/no checkpointing). Suggest gating this to the specific failure mode (e.g., self.cfg.strategy == "fsdp2" and self.cfg.gradient_checkpointing and not use_reentrant) or making it opt-in via config.
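The suggested gate could be expressed as a small helper (the cfg attribute names here are assumptions based on the review comment, not verified against the repo):

```python
def should_patch_moe_experts(cfg, is_moe: bool) -> bool:
    """Gate the MoE expert patch to the failure mode it targets:
    FSDP2 with non-reentrant gradient checkpointing on an MoE model.

    Sketch only; attribute names (strategy, gradient_checkpointing,
    gradient_checkpointing_use_reentrant) are assumed, not confirmed.
    """
    return (
        is_moe
        and getattr(cfg, "strategy", None) == "fsdp2"
        and getattr(cfg, "gradient_checkpointing", False)
        and not getattr(cfg, "gradient_checkpointing_use_reentrant", False)
    )
```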

Comment on lines +302 to +335
@app.post("/v1/messages")
async def anthropic_messages(raw_request: Request):
    """Anthropic-compatible Messages API endpoint."""
    try:
        request_json = await raw_request.json()

        if _global_inference_engine_client is None:
            return JSONResponse(
                content={"error": {"message": "Inference engine client not initialized", "type": "internal_error"}},
                status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
            )
        if "model" not in request_json:
            return JSONResponse(
                content={"error": {"message": "The field `model` is required", "type": "invalid_request_error"}},
                status_code=HTTPStatus.BAD_REQUEST.value,
            )
        if "messages" not in request_json or not request_json["messages"]:
            return JSONResponse(
                content={"error": {"message": "The field `messages` is required and cannot be empty", "type": "invalid_request_error"}},
                status_code=HTTPStatus.BAD_REQUEST.value,
            )

        payload = {
            "json": request_json,
            "headers": dict(raw_request.headers) if hasattr(raw_request, "headers") else {},
        }
        anthropic_response = await _global_inference_engine_client.anthropic_messages(payload)

        if "error" in anthropic_response:
            error_type = anthropic_response["error"].get("type", "internal_error")
            status_code = HTTPStatus.BAD_REQUEST.value if error_type == "invalid_request_error" else HTTPStatus.INTERNAL_SERVER_ERROR.value
            return JSONResponse(content=anthropic_response, status_code=status_code)

        return JSONResponse(content=anthropic_response)

Copilot AI Mar 9, 2026


This PR adds a new public HTTP endpoint (POST /v1/messages) with request validation and error mapping, but there doesn’t appear to be any test coverage for it (there are extensive tests for /v1/chat/completions and /v1/completions). Adding a basic integration test (success + missing/empty messages + missing model) would help prevent regressions.
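Factoring the validation out of the endpoint would make it unit-testable without spinning up FastAPI; for example (a sketch, not the PR's actual code layout):

```python
from http import HTTPStatus


def validate_messages_request(request_json: dict):
    """Validation rules from the /v1/messages endpoint, factored out so the
    missing-model and missing/empty-messages cases can be tested directly.
    Returns (status_code, error_body); error_body is None for a valid request.
    Sketch only -- the endpoint itself is not restructured this way in the PR.
    """
    if "model" not in request_json:
        return HTTPStatus.BAD_REQUEST.value, {
            "error": {"message": "The field `model` is required", "type": "invalid_request_error"}
        }
    if not request_json.get("messages"):
        return HTTPStatus.BAD_REQUEST.value, {
            "error": {"message": "The field `messages` is required and cannot be empty", "type": "invalid_request_error"}
        }
    return HTTPStatus.OK.value, None
```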

Comment on lines +307 to +311
# Store position_ids in thread-local for model architectures (e.g., GraniteMoeHybrid)
# that don't propagate position_ids through decoder layers to flash attention.
# Must be set here (outside the model) to survive gradient checkpointing backward re-execution.
if self.sequence_parallel_size > 1:
    set_ulysses_position_ids(position_ids_fwd)

Copilot AI Mar 9, 2026


The comment says position_ids is stored in “thread-local”, but set_ulysses_position_ids() stores it in module-level globals (see ulysses/monkey_patch.py). To avoid confusion (especially if this code is later used in multi-threaded contexts), consider updating the comment to reflect the actual storage mechanism and its assumptions (single-threaded Ray worker process).

Comment on lines +113 to +135
for expert_idx in range(mod.num_experts):
    top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
    if token_idx.numel() > 0:
        current_state = hidden_states[token_idx]
    else:
        current_state = hidden_states[:0]

    gate, up = nn.functional.linear(
        current_state, mod.gate_up_proj[expert_idx]
    ).chunk(2, dim=-1)
    current_hidden_states = mod.act_fn(gate) * up
    current_hidden_states = nn.functional.linear(
        current_hidden_states, mod.down_proj[expert_idx]
    )

    if token_idx.numel() > 0:
        current_hidden_states = (
            current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
        )
        final_hidden_states.index_add_(
            0, token_idx,
            current_hidden_states.to(final_hidden_states.dtype),
        )

Copilot AI Mar 9, 2026


In the patched Qwen3MoeExperts.forward, the if token_idx.numel() > 0: branches (around scaling/index_add) are still data-dependent on routing. That can reintroduce nondeterministic saved-tensor counts under non-reentrant checkpoint recompute (the original problem this patch is trying to fix). Consider executing the same ops unconditionally (indexing/multiplying with empty indices is a no-op) so the code path is identical for empty vs non-empty experts.
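The claim that these ops are safe to run unconditionally can be verified directly: indexing with an empty index tensor yields an empty tensor, and index_add_ with empty indices leaves the output untouched.

```python
import torch

# Empty-index gather and scatter-add are no-ops, so the per-expert scaling
# and index_add_ can run unconditionally; the computation graph then looks
# the same whether or not an expert received any tokens.
hidden = torch.randn(4, 8)
out = torch.zeros(4, 8)
empty_idx = torch.empty(0, dtype=torch.long)

selected = hidden[empty_idx]            # shape (0, 8): valid, just empty
out.index_add_(0, empty_idx, selected)  # leaves `out` unchanged
```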

Comment on lines 338 to 344
  model_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
  # FSDP2 handles tied embeddings correctly via broadcast + tie_weights(),
  # so meta tensor init is always safe. FSDP1 needs CPU init for tied embeddings.
  use_meta = True if self.cfg.strategy == "fsdp2" else (not model_config.tie_word_embeddings)
  init_context = get_init_weight_context_manager(
-     use_meta_tensor=not model_config.tie_word_embeddings, mesh=self.strategy.device_mesh
+     use_meta_tensor=use_meta, mesh=self.strategy.device_mesh
  )

Copilot AI Mar 9, 2026


model_config.tie_word_embeddings is accessed directly when deciding use_meta. Some custom/older HF configs in this repo are handled with getattr(..., "tie_word_embeddings", False) (e.g., LoRA utilities), so this direct access can raise AttributeError and break model init. Consider switching to getattr(model_config, "tie_word_embeddings", False) here (and in the critic/ref init blocks).

Comment on lines 520 to 524
  model_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
  use_meta = True if self.cfg.strategy == "fsdp2" else (not model_config.tie_word_embeddings)
  init_context = get_init_weight_context_manager(
-     use_meta_tensor=not model_config.tie_word_embeddings, mesh=self.strategy.device_mesh
+     use_meta_tensor=use_meta, mesh=self.strategy.device_mesh
  )

Copilot AI Mar 9, 2026


Same issue as policy init: direct access to model_config.tie_word_embeddings can raise AttributeError for configs that don’t define it. Prefer getattr(model_config, "tie_word_embeddings", False) when computing use_meta.

devin-ai-integration[bot]

This comment was marked as resolved.

…ovements

Merges two sets of changes:

1. MoE/FSDP2 training stability:
   - Patch MoE expert modules (Qwen3MoeSparseMoeBlock/Qwen3MoeExperts) to iterate
     all experts unconditionally, ensuring deterministic computation graph for
     FSDP2 + non-reentrant gradient checkpointing
   - FSDP2 meta tensor init always enabled (handles tied embeddings via tie_weights)
   - Batched/coalesced broadcast for FSDP2 state dict loading (500MB batches,
     reduces MoE init from minutes to seconds for 18K+ params)
   - Fix NCCL deadlock: non-finite grad_norm no longer skips optimizer.step()
   - Move all-reduce metrics after all micro-batches to avoid NCCL deadlock
     with FSDP2 backward gradient reductions
   - Pre-gather Ulysses position_ids outside checkpointed region to avoid
     NCCL all_gather during gradient checkpointing backward recompute
   - Guard output_router_logits on actual MoE models (num_local_experts > 0)
   - Rewrite NUMA affinity to use integer API (fixes segfaults from bitmask
     pointer corruption) and numa_max_node (fixes NVLink virtual NUMA IDs)
   - Per-micro-batch dispatch with progress logging in trainer
   - Fix RAY_ADDRESS propagation for vLLM EngineCore subprocess

2. Anthropic Messages API and agent improvements:
   - Add /v1/messages endpoint across inference engine stack (base, client,
     HTTP endpoint, ray wrapped, remote, vLLM)
   - Full Anthropic-to-OpenAI format conversion in AsyncVLLMInferenceEngine
   - Improved LoRA weight swap: abort generation, remove old adapter, add new,
     reset prefix cache, track active_lora_id with monkey-patched adapter lookup
   - Robust transitions_to_training_data: validate None/empty observations and
     actions, track logprobs validity per-datum, explicit length-mismatch checks
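The batching logic behind the coalesced broadcast described above could be sketched like this (illustrative only; the real implementation must also flatten each batch into a contiguous buffer and issue one dist.broadcast per batch):

```python
def coalesce_into_batches(named_tensors, max_bytes=500 * 1024 * 1024):
    """Group (name, num_bytes) pairs into batches of at most max_bytes each,
    so one broadcast per batch replaces one broadcast per parameter.

    Sketch of the batching logic only; a single tensor larger than
    max_bytes still gets a batch of its own.
    """
    batches, current, current_bytes = [], [], 0
    for name, nbytes in named_tensors:
        if current and current_bytes + nbytes > max_bytes:
            batches.append(current)
            current, current_bytes = [], 0
        current.append(name)
        current_bytes += nbytes
    if current:
        batches.append(current)
    return batches
```

With 500 MB batches, tens of thousands of per-parameter broadcasts collapse into a handful of collective calls, which is where the minutes-to-seconds init speedup comes from.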
@ashutoshuiuc ashutoshuiuc force-pushed the feature/moe-fsdp2-anthropic-api-fixes branch from df7352c to 606edb1 Compare March 9, 2026 22:10
Contributor

@devin-ai-integration devin-ai-integration bot left a comment


Devin Review found 1 new potential issue.

View 15 additional findings in Devin Review.


Comment on lines +182 to 196
# NOTE: With FSDP, optimizer.step() involves NCCL collectives. ALL ranks
# must call it even if grad_norm is non-finite, otherwise NCCL deadlocks.
# We zero_grad before stepping so the non-finite update is harmless.
non_finite = grad_norm is not None and not torch.isfinite(grad_norm)
if non_finite:
logger.warning(f"rank {rank} grad_norm is not finite: {grad_norm}, zeroing grads before step")
optimizer.zero_grad()
return grad_norm

t0 = _time.time()
optimizer.step()
if scheduler is not None:
logger.info(f"[rank {rank}] optimizer.step() done in {_time.time() - t0:.1f}s")
# Only advance LR schedule when gradients were finite (non-finite steps are no-ops)
if scheduler is not None and not non_finite:
scheduler.step()
optimizer.zero_grad()
Contributor


🔴 optimizer.step() with zeroed gradients is NOT a no-op: weight decay and momentum still modify parameters

When grad_norm is non-finite, the new code zeroes gradients then calls optimizer.step(). The comment says "We zero_grad before stepping so the non-finite update is harmless" — but this is incorrect for AdamW. optimizer.step() with zero gradients still: (1) applies weight decay (param *= (1 - lr * wd)), shrinking all parameters, (2) decays momentum buffers (exp_avg *= beta1, exp_avg_sq *= beta2) then uses decayed momentum to compute a non-zero parameter update, and (3) increments the internal step counter, affecting Adam bias correction. The old code returned early and truly skipped the update. The new code silently modifies model weights when the intent was a no-op.

Old vs new behavior on non-finite grad_norm

Old code returned early — true no-op:

optimizer.zero_grad()
return grad_norm  # skip step entirely

New code still updates parameters:

optimizer.zero_grad()  # zeros grads
optimizer.step()       # applies weight_decay + momentum-based update

A correct fix would be to also set all exp_avg and exp_avg_sq to zero, or temporarily set lr=0 and weight_decay=0 before stepping.

Prompt for agents
In skyrl/backends/skyrl_train/distributed/fsdp_strategy.py, the optimizer_step method (lines 155-197) calls optimizer.step() with zeroed gradients when grad_norm is non-finite. This is NOT a no-op for AdamW because weight decay is applied independently of gradients, and momentum buffers are decayed and then used for a non-zero parameter update.

To truly make the non-finite case a no-op while still calling optimizer.step() (to avoid NCCL deadlock), you should temporarily set lr=0 and weight_decay=0 before stepping, then restore them afterward. For example:

if non_finite:
    optimizer.zero_grad()
    saved_lrs = []
    saved_wds = []
    for pg in optimizer.param_groups:
        saved_lrs.append(pg['lr'])
        saved_wds.append(pg['weight_decay'])
        pg['lr'] = 0.0
        pg['weight_decay'] = 0.0

optimizer.step()

if non_finite:
    for pg, lr, wd in zip(optimizer.param_groups, saved_lrs, saved_wds):
        pg['lr'] = lr
        pg['weight_decay'] = wd

This ensures all ranks participate in the NCCL collectives inside optimizer.step() but no parameters or optimizer state are actually modified.


@ashutoshuiuc
Author

Please let me know how I can get these changes merged. The CPU/GPU tests are failing due to some checks tied to Anyscale accounts; are contributors supposed to have Anyscale accounts/subscriptions? What's the preferred method of contributing to SkyRL?
