Binary file added docs/gsm8k_qwen3.5_2B_metrics.png
124 changes: 124 additions & 0 deletions docs/qwen3_5_fast_path_setup.md
# Qwen3.5 Fast Path Setup for FSDP Training

## Background

Qwen3.5 is a hybrid architecture with both standard attention layers and **linear attention** (gated delta rule) layers. The linear attention layers require two specialized libraries for fast forward and backward passes:

1. **`flash-linear-attention` (fla)** - Triton kernels for `chunk_gated_delta_rule` and `fused_recurrent_gated_delta_rule`
2. **`causal-conv1d`** - CUDA kernels for causal 1D convolution (`causal_conv1d_fn`, `causal_conv1d_bwd`)

Without both libraries, HuggingFace transformers falls back to pure PyTorch implementations (`torch_chunk_gated_delta_rule` and `F.silu(self.conv1d(...))`) which are **significantly slower** for both forward and backward passes. The check is at:

```python
# transformers/models/qwen3_5/modeling_qwen3_5.py:295
is_fast_path_available = all(
(causal_conv1d_fn, causal_conv1d_update, chunk_gated_delta_rule, fused_recurrent_gated_delta_rule)
)
```
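The same availability check can be reproduced outside transformers when debugging an environment. The sketch below mirrors the import guards transformers uses; treat the exact import paths as assumptions based on the snippet above:

```python
# Hedged sketch: reproduce the fast-path availability check in isolation.
# Imports that fail are treated as missing (None), so all(...) returns False.
try:
    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
except ImportError:
    causal_conv1d_fn = causal_conv1d_update = None

try:
    from fla.ops.gated_delta_rule import chunk_gated_delta_rule, fused_recurrent_gated_delta_rule
except ImportError:
    chunk_gated_delta_rule = fused_recurrent_gated_delta_rule = None

is_fast_path_available = all(
    (causal_conv1d_fn, causal_conv1d_update, chunk_gated_delta_rule, fused_recurrent_gated_delta_rule)
)
print("is_fast_path_available:", is_fast_path_available)
```

If this prints `False`, one of the two libraries is missing or failed to import in the current venv.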

## Current State

- `flash-linear-attention` is vendored at `./flash-linear-attention/` and properly wired as a uv source in `pyproject.toml`. It installs cleanly (pure Python/Triton, no CUDA compilation).
- `causal-conv1d` is vendored at `./causal-conv1d/` but is **blocked** from installation by an override in `pyproject.toml`:
```toml
"causal-conv1d; sys_platform == 'never'",
```
This was originally added to suppress transitive resolution from `flash-linear-attention[conv1d]`.

## The Problem with `causal-conv1d`

`causal-conv1d` requires CUDA compilation (it has `.cu` files in `csrc/`). This creates issues with uv + Ray:

1. **uv build isolation**: `causal-conv1d`'s `pyproject.toml` declares only `setuptools`, `wheel`, and `torch` as build requirements. With `match-runtime = true` set for torch, uv cannot resolve the build environment because `causal-conv1d` does not declare static metadata.

2. **`no-build-isolation`**: Works locally but fails on Ray workers because they create fresh venvs that lack setuptools.

3. **Local wheel path**: Ray's working directory packaging doesn't include the 160MB wheel file (excluded by `.gitignore` or size limits), so `path = "./dist/..."` fails on workers.

## Solution: Pre-built Wheel via URL

The correct approach (matching how `flash-attn` is handled) is to host a pre-built wheel and reference it by URL.

### Step 1: Build the wheel locally

```bash
# From repo root, with the venv that has torch installed:
uv build ./causal-conv1d --no-build-isolation --python .venv/bin/python --wheel --out-dir ./dist/
```

This produces `dist/causal_conv1d-1.6.1-cp312-cp312-linux_x86_64.whl` (~160MB).

### Step 2: Host the wheel

Upload to a URL accessible by Ray workers. Options:
- A GitHub release (as flash-attn does, via `mjun0812/flash-attention-prebuild-wheels`)
- S3/GCS bucket
- Any HTTP server accessible from the cluster

For example, suppose the wheel is uploaded to `https://github.com/YOUR_ORG/prebuild-wheels/releases/download/v1.0/causal_conv1d-1.6.1-cp312-cp312-linux_x86_64.whl`.

### Step 3: Update `pyproject.toml`

```toml
# In [project.optional-dependencies]
fsdp = [
# ... existing deps ...
"causal-conv1d>=1.4.0; sys_platform == 'linux'",
"flash-linear-attention>=0.4.2; sys_platform == 'linux'",
]

# In override-dependencies: REMOVE the sys_platform == 'never' line for causal-conv1d
# Replace with a real constraint:
"causal-conv1d>=1.4.0; sys_platform == 'linux'",

# In [tool.uv.sources]
causal-conv1d = { url = "https://YOUR_URL/causal_conv1d-1.6.1-cp312-cp312-linux_x86_64.whl", marker = "sys_platform == 'linux'" }
flash-linear-attention = { path = "./flash-linear-attention", editable = true }
```

### Step 4: Verify

```bash
python -c "
from transformers.models.qwen3_5.modeling_qwen3_5 import is_fast_path_available
print('is_fast_path_available:', is_fast_path_available)
"
# Should print: is_fast_path_available: True
```

The model_wrapper.py logging (added in this branch) will also confirm:
```
INFO | Qwen3.5 fast path is ENABLED (causal-conv1d + flash-linear-attention)
```

## Alternative: Docker Image

For production, the cleanest approach is to include `causal-conv1d` in the Docker image:

```dockerfile
# In docker/Dockerfile, after CUDA toolkit is installed:
COPY causal-conv1d /tmp/causal-conv1d
RUN cd /tmp/causal-conv1d && pip install --no-build-isolation . && rm -rf /tmp/causal-conv1d
```

This avoids wheel hosting entirely since all Ray workers share the same Docker image.

## Quick Local-Only Test (no Ray workers)

If you just want to verify the fast path works locally (single-node, no Ray worker venvs):

```bash
# Install directly into the venv
uv pip install -e ./causal-conv1d --python .venv/bin/python --no-build-isolation

# Verify
.venv/bin/python -c "from transformers.models.qwen3_5.modeling_qwen3_5 import is_fast_path_available; print(is_fast_path_available)"
# True
```

## Notes

- `causal-conv1d` v1.6.1 does have a `CachedWheelsCommand` that auto-downloads prebuilt wheels from [Dao-AILab/causal-conv1d releases](https://github.com/Dao-AILab/causal-conv1d/releases), but as of March 2026, no wheels exist for torch 2.10.0 + CUDA 12.8.
- The `flash-linear-attention` fla package lists `causal-conv1d>=1.4.0` as an optional dependency under the `conv1d` extra. The `sys_platform == 'never'` override in SkyRL's pyproject.toml suppresses this.
- Prime-RL does not install either `fla` or `causal-conv1d` — they don't have Qwen3.5 linear attention support.
- The torch fallback `torch_chunk_gated_delta_rule` has a known issue with CPU tensor creation during gradient checkpointing recomputation (see commented-out Patch 2 in `model_wrapper.py`). The fla fast path avoids this entirely.
204 changes: 204 additions & 0 deletions docs/qwen3_5_gsm8k_guide.md
# Running GSM8K with Qwen3.5 on SkyRL

## Status

**Working**: Model loading, vLLM inference/generation, weight sync, training forward pass, training backward pass.

**Known limitations**: The pure Python `torch_chunk_gated_delta_rule` (Qwen3.5 linear attention) is very slow during training. A fused kernel (e.g., from the `fla` library) is needed for practical training speeds.

## Quick Start

```bash
# 1. Install dependencies (from repo root)
uv sync --extra fsdp

# 2. Prepare dataset
uv run examples/train/gsm8k/gsm8k_dataset.py --output_dir $HOME/data/gsm8k

# 3. Run training (Qwen3.5-2B on 4 GPUs)
NUM_GPUS=4 bash examples/train/gsm8k/run_gsm8k_qwen3_5.sh \
trainer.policy.model.path="Qwen/Qwen3.5-2B"
```

## Required Changes from Upstream SkyRL

### 1. `pyproject.toml` - Version Bumps

| Package | Before | After | Reason |
|---------|--------|-------|--------|
| vllm | 0.16.0 | 0.17.0 | Native Qwen3.5 model support |
| torch | 2.9.1 | 2.10.0 | Required by vllm 0.17.0 |
| transformers | >=4.56.1,<5 | >=5.3.0 | Native `qwen3_5` model type for FSDP |
| flashinfer-python | 0.6.3 | 0.6.4 | Required by vllm 0.17.0 |
| flashinfer-jit-cache | 0.6.3 | 0.6.4 | Matches flashinfer-python |
| accelerate | (unpinned) | >=1.13.0 | Fixes `_is_hf_initialized` TypeError with transformers 5.x |
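Taken together, the bumps above look roughly like the following in `pyproject.toml` (an illustrative sketch — the exact extras, markers, and pin styles in the actual file may differ):

```toml
# Sketch of the version bumps (illustrative, not the exact pyproject.toml contents)
dependencies = [
    "vllm==0.17.0",
    "torch==2.10.0",
    "transformers>=5.3.0",
    "flashinfer-python==0.6.4",
    "flashinfer-jit-cache==0.6.4",
    "accelerate>=1.13.0",
]
```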

### 2. flash-attn Pre-built Wheel

flash-attn 2.8.3 has no pre-built wheel for torch 2.10.0 on PyPI or in the official GitHub releases.
Use the community pre-built wheel from [mjun0812/flash-attention-prebuild-wheels](https://github.com/mjun0812/flash-attention-prebuild-wheels):

```toml
# In [tool.uv.sources]:
flash-attn = { url = "https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.7.16/flash_attn-2.8.3+cu128torch2.10-cp312-cp312-linux_x86_64.whl", marker = "sys_platform == 'linux'" }
```

### 3. Override Dependencies

vllm 0.17.0 caps `transformers<5`, but FSDP training needs `>=5.3.0` for native Qwen3.5:

```toml
# In override-dependencies:
"transformers>=5.3.0",
```

Also update megatron-bridge to commit `0034ddaa` (supports transformers 5.x).

### 4. transformers 5.x `return_dict=False`

transformers 5.x changed `apply_chat_template(tokenize=True)` to return a `BatchEncoding` instead of a `list[int]`. Add `return_dict=False` at every call site in:
- `skyrl/train/generators/skyrl_gym_generator.py`
- `skyrl/train/generators/utils.py`
- `skyrl/train/dataset/dataset.py`
- `skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py`
- `examples/train/mini_swe_agent/mini_swe_generator.py`
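The behavior difference can be illustrated with a small normalization helper; `to_token_ids` is a hypothetical name, and the dict here stands in for a `BatchEncoding` (which is mapping-like):

```python
# Hedged sketch: transformers 4.x returned list[int]; 5.x returns a
# BatchEncoding (mapping-like) unless return_dict=False is passed.
def to_token_ids(out):
    """Normalize apply_chat_template output across transformers 4.x/5.x."""
    if isinstance(out, list):
        return out  # 4.x / return_dict=False behavior
    return out["input_ids"]  # BatchEncoding behaves like a mapping

# 4.x-style output passes through unchanged
assert to_token_ids([1, 2, 3]) == [1, 2, 3]
# 5.x-style output is unwrapped
assert to_token_ids({"input_ids": [1, 2, 3]}) == [1, 2, 3]
```

In this branch the fix is simply passing `return_dict=False` at each call site, rather than normalizing the return value afterwards.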

### 5. `vllm_worker.py` - Weight Sync Layer Naming

Qwen3.5 uses `ForConditionalGeneration` in vLLM (VL wrapper) but `ForCausalLM` in FSDP training. Weight names differ by `language_model.` prefix:

```python
model_cls = type(self.model_runner.model).__name__
needs_lm_prefix = "ForConditionalGeneration" in model_cls
for name, tensor in self._weight_receiver.receive_weights(request):
if needs_lm_prefix and not (name.startswith("language_model.") or name.startswith("visual.")):
name = f"language_model.{name}"
```
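The renaming rule can be exercised in isolation. This is a self-contained sketch of the logic above; `maybe_prefix` is a hypothetical helper name, not a function in `vllm_worker.py`:

```python
# Self-contained sketch of the weight-name prefixing rule above.
def maybe_prefix(name: str, model_cls: str) -> str:
    needs_lm_prefix = "ForConditionalGeneration" in model_cls
    if needs_lm_prefix and not (name.startswith("language_model.") or name.startswith("visual.")):
        return f"language_model.{name}"
    return name

# vLLM loaded the VL wrapper: plain names get the prefix
assert maybe_prefix("model.layers.0.self_attn.q_proj.weight", "Qwen3_5ForConditionalGeneration") \
    == "language_model.model.layers.0.self_attn.q_proj.weight"
# Already-prefixed and visual weights are left alone
assert maybe_prefix("visual.patch_embed.proj.weight", "Qwen3_5ForConditionalGeneration") \
    == "visual.patch_embed.proj.weight"
# Plain causal-LM class: no renaming needed
assert maybe_prefix("model.embed_tokens.weight", "Qwen3_5ForCausalLM") == "model.embed_tokens.weight"
```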

### 6. Qwen3.5 Monkey-Patches (in `model_wrapper.py`)

**Patch 1: 3D position_ids fix** (from [prime-rl](https://github.com/PrimeIntellect-ai/prime-rl/commit/2767dea))
Qwen3.5 passes 3D MRoPE position_ids to decoder layers, breaking flash attention:

```python
if position_ids is not None and position_ids.ndim == 3:
position_ids = position_ids[0]
```

Upstream fix: huggingface/transformers#44399

**Patch 2: CPU tensor creation in `chunk_gated_delta_rule`**

Qwen3.5's hybrid architecture uses **Gated Delta Rule** linear attention layers (`Qwen3_5GatedDeltaNet`)
alongside standard full-attention layers. The linear attention is implemented by two pure-Python
reference functions in transformers 5.3.0:

- `torch_chunk_gated_delta_rule()` — chunked version (used during training)
- `torch_recurrent_gated_delta_rule()` — token-by-token version (used during generation)

Both share a bug in how the `initial_state` tensor is created (three occurrences, at L376, L420, and L422):

```python
# Original (transformers 5.3.0, modeling_qwen3_5.py L376):
last_recurrent_state = (
torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value) # BUG
if initial_state is None
else initial_state.to(value)
)
```

**What happens:**
1. `torch.zeros(...)` creates a tensor on **CPU** (default device)
2. `.to(value)` copies it to the GPU where `value` lives

**Why it fails during training:**
SkyRL uses `torch.utils.checkpoint` (gradient checkpointing) to reduce memory usage during
FSDP training. During the backward pass, the checkpointing mechanism:
1. Frees activations saved during the forward pass
2. Re-runs the forward pass to recompute them
3. During this recomputation, CUDA memory has been freed and reallocated

The `.to(value)` call during recomputation encounters a CUDA context where the source CPU tensor
allocation interacts with a modified GPU memory layout, triggering:
```
torch.AcceleratorError: CUDA error: an illegal memory access was encountered
Search for `cudaErrorIllegalAddress'
```

The error trace always points to `modeling_qwen3_5.py L376` (`torch_chunk_gated_delta_rule`)
inside the backward pass recomputation:
```
torch/utils/checkpoint.py:1173, in unpack_hook
_run_fn_with_dynamo_disabled(frame.recompute_fn, *args)
...
modeling_qwen3_5.py:376, in torch_chunk_gated_delta_rule
torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value)
```

**The fix** (in `model_wrapper.py`) wraps `torch_chunk_gated_delta_rule` to pre-create
`initial_state` on the correct device before the function is called:

```python
@functools.wraps(orig_fn)
def _patched_delta_rule(*args, **kwargs):
sig = inspect.signature(orig_fn)
bound = sig.bind(*args, **kwargs)
bound.apply_defaults()
if bound.arguments.get("initial_state") is None:
key = bound.arguments["key"] # shape: (batch, seq_len, num_heads, k_head_dim)
value = bound.arguments["value"] # shape: (batch, seq_len, num_heads, v_head_dim)
# Create directly on GPU — key is pre-transpose so dims are [0]=batch, [2]=heads, [3]=head_dim
b, num_heads, kd, vd = key.shape[0], key.shape[2], key.shape[3], value.shape[3]
bound.arguments["initial_state"] = torch.zeros(
b, num_heads, kd, vd, device=value.device, dtype=value.dtype
)
return orig_fn(*bound.args, **bound.kwargs)
```

Note the shape subtlety: the function receives tensors in `(batch, seq_len, num_heads, head_dim)`
layout, then transposes them internally to `(batch, num_heads, seq_len, head_dim)` at L339-340.
The wrapper must use the pre-transpose shapes (`key.shape[2]` for num_heads, `key.shape[3]` for k_head_dim).
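The index arithmetic can be sanity-checked without torch, using plain tuples as stand-ins for tensor shapes (the dimension values are illustrative):

```python
# Pre-transpose layout, as received by the wrapper: (batch, seq_len, num_heads, head_dim)
key_shape = (2, 1024, 16, 128)    # k_head_dim = 128
value_shape = (2, 1024, 16, 256)  # v_head_dim = 256

# Same indexing as the wrapper: [0]=batch, [2]=num_heads, [3]=head_dim
b, num_heads, kd, vd = key_shape[0], key_shape[2], key_shape[3], value_shape[3]

# initial_state must be (batch, num_heads, k_head_dim, v_head_dim)
assert (b, num_heads, kd, vd) == (2, 16, 128, 256)
```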

**Alternative**: SLIME avoids this entirely by using the `fla` library's fused Triton kernels
(`fla.ops.gated_delta_rule.chunk_gated_delta_rule`) instead of transformers' Python reference
implementation. prime-rl does not have this patch and may not have encountered it (possibly
they don't use gradient checkpointing on linear attention layers, or their transformers pin
doesn't trigger this path).

**Note on performance**: The pure-Python `torch_chunk_gated_delta_rule` is a nested for-loop
over chunks and is extremely slow (~10x slower than the fused `fla` kernel). For practical
training, integrating the `fla` library's CUDA/Triton kernels is recommended.

### 7. Training Script Key Config

```bash
trainer.policy.model.path="Qwen/Qwen3.5-2B"
trainer.policy.fsdp_config.wrap_policy.transformer_layer_cls_to_wrap="['Qwen3_5DecoderLayer']"
trainer.ref.fsdp_config.wrap_policy.transformer_layer_cls_to_wrap="['Qwen3_5DecoderLayer']"
generator.inference_engine.engine_init_kwargs.language_model_only=true
```

## Training Results

### Qwen3.5-2B GSM8K GRPO (37 steps, 1x H100, batch_size=8)

![Qwen3.5-2B Training Metrics](gsm8k_qwen3.5_2B_metrics.png)

**Key observations:**
- **Rewards**: Fluctuating 0.2-0.9 (small batch, high variance), avg ~0.45
- **Pass@2**: Generally 0.6-0.9
- **Policy KL**: Steadily increasing (expected with GRPO)
- **Grad Norm**: Growing, indicating training is active
- **Step Time**: ~35s/step (generation ~16s, policy train ~10s, weight sync ~4s)
- **Response Length**: 280-480 tokens

### Wandb Runs
- **Qwen3.5-2B** (16 steps): https://wandb.ai/sky-posttraining-uc-berkeley/gsm8k-qwen3.5/runs/krvj0f5s
- **Qwen3.5-0.8B** (ongoing): https://wandb.ai/sky-posttraining-uc-berkeley/gsm8k-qwen3.5/runs/lr53x6z9

## References

- [SkyRL Issue #1254](https://github.com/NovaSky-AI/SkyRL/issues/1254) - Qwen3.5 support tracking
- [prime-rl commit 2767dea](https://github.com/PrimeIntellect-ai/prime-rl/commit/2767dea) - Qwen3.5 patches, flash-attn wheel
- [prime-rl PR #1980](https://github.com/PrimeIntellect-ai/prime-rl/pull/1980) - vllm 0.17.0 official release
- Frontier-CS-Evolve/SkyRL branch `runnable-for-non-qmang` - weight sync layer naming fix
67 changes: 67 additions & 0 deletions examples/train/gsm8k/run_gsm8k_qwen3_5.sh
set -x

# Colocated GRPO training+generation for Qwen3.5-2B on GSM8K.

# uv run examples/train/gsm8k/gsm8k_dataset.py --output_dir $HOME/data/gsm8k
# export WANDB_API_KEY=<your_key_here>
# bash examples/train/gsm8k/run_gsm8k_qwen3_5.sh

# You can override the default values with e.g.: `NUM_GPUS=4 bash examples/train/gsm8k/run_gsm8k_qwen3_5.sh`.

export TRITON_PRINT_AUTOTUNING=1

: "${DATA_DIR:="$HOME/data/gsm8k"}"
: "${NUM_GPUS:=8}"
: "${LOGGER:=wandb}" # change to "console" to print to stdout

: "${INFERENCE_BACKEND:=vllm}"

TIS_IMP_RATIO_CAP=2.0
USE_TIS=true

uv run --extra fsdp -m skyrl.train.entrypoints.main_base \
data.train_data="['$DATA_DIR/train.parquet']" \
data.val_data="['$DATA_DIR/validation.parquet']" \
trainer.algorithm.advantage_estimator="grpo" \
trainer.policy.model.path="Qwen/Qwen3.5-2B" \
trainer.placement.colocate_all=true \
trainer.strategy=fsdp2 \
trainer.placement.policy_num_gpus_per_node=$NUM_GPUS \
trainer.placement.critic_num_gpus_per_node=$NUM_GPUS \
trainer.placement.ref_num_gpus_per_node=$NUM_GPUS \
trainer.policy.fsdp_config.wrap_policy.transformer_layer_cls_to_wrap="['Qwen3_5DecoderLayer']" \
trainer.ref.fsdp_config.wrap_policy.transformer_layer_cls_to_wrap="['Qwen3_5DecoderLayer']" \
generator.inference_engine.num_engines=$NUM_GPUS \
generator.inference_engine.tensor_parallel_size=1 \
generator.inference_engine.engine_init_kwargs.language_model_only=true \
trainer.algorithm.use_tis=$USE_TIS \
trainer.algorithm.tis_imp_ratio_cap=$TIS_IMP_RATIO_CAP \
trainer.epochs=20 \
trainer.eval_batch_size=1024 \
trainer.eval_before_train=false \
trainer.eval_interval=5 \
trainer.update_epochs_per_batch=1 \
trainer.train_batch_size=1024 \
trainer.policy_mini_batch_size=256 \
trainer.micro_forward_batch_size_per_gpu=64 \
trainer.micro_train_batch_size_per_gpu=64 \
trainer.ckpt_interval=10 \
trainer.max_prompt_length=512 \
generator.sampling_params.max_generate_length=1024 \
trainer.policy.optimizer_config.lr=1.0e-6 \
trainer.algorithm.use_kl_loss=true \
generator.inference_engine.backend=$INFERENCE_BACKEND \
generator.inference_engine.run_engines_locally=true \
generator.inference_engine.weight_sync_backend=nccl \
generator.inference_engine.async_engine=true \
generator.batched=true \
environment.env_class=gsm8k \
generator.n_samples_per_prompt=6 \
generator.inference_engine.gpu_memory_utilization=0.8 \
trainer.logger="$LOGGER" \
trainer.project_name="gsm8k-qwen3.5" \
trainer.run_name="gsm8k_qwen3.5_2B" \
trainer.resume_mode=null \
trainer.log_path="/tmp/skyrl-logs" \
trainer.ckpt_path="$HOME/ckpts/gsm8k_qwen3.5_2B_ckpt" \
    "$@"