[Bug] IndexError in log_rollout_data when --log-correct-samples enabled with DP > 1 #1784

@zchuz

Bug Description

Summary

log_rollout_data crashes with an IndexError when --log-correct-samples is enabled in multi-node / multi-GPU training (DP > 1). The root cause is that raw_reward is not partitioned by DP rank in process_rollout_data, while response_lengths is.

Error

File "slime/backends/megatron_utils/data.py", line 520, in log_rollout_data
    correct_response_lengths.append(response_lengths[i])
                                    ~~~~~~~~~~~~~~~~^^^
IndexError: list index out of range

Root Cause

In slime/ray/rollout.py:749-778, data is split into two groups before sending to the training side:

  • Partitioned by DP rank (line 753-770): tokens, response_lengths, rewards, loss_masks, etc.
  • Kept as full/global (line 771-778): raw_reward, total_lengths — commented as "keys that need to be split at train side"

On the training side, slime/utils/data.py:process_rollout_data (line 299-310) correctly splits total_lengths by partition, but does not split raw_reward.

When log_rollout_data (slime/backends/megatron_utils/data.py:512-523) iterates over raw_rewards (global, length = total samples) and indexes into response_lengths (partitioned, length = samples for this DP rank), the index goes out of bounds.
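The mismatch can be reproduced with plain lists. This is a simplified sketch using made-up sample counts for a hypothetical DP = 2 run, not slime's actual data structures:

```python
# Global (unsplit) rewards for all 8 samples, as sent from the rollout side.
raw_rewards = [1, 0, 1, 1, 0, 1, 0, 1]
# Response lengths already partitioned for this DP rank: only 4 entries.
response_lengths = [120, 95, 210, 64]

correct_response_lengths = []
try:
    # The loop iterates over the GLOBAL list but indexes the LOCAL one.
    for i, reward in enumerate(raw_rewards):
        if reward == 1:
            correct_response_lengths.append(response_lengths[i])
except IndexError as e:
    # Fails as soon as i exceeds the per-rank length (here at i = 5).
    print(f"IndexError: {e}")
```

With DP = 1 both lists have the same length, which is why the bug stays hidden in single-rank runs.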

Reproduction

  • Use DP > 1 (e.g., multi-node training)
  • Enable --log-correct-samples
  • Run at least one training step

The bug does not trigger with DP = 1 because the partition covers all samples.

Suggested Fix

Add raw_reward partitioning in slime/utils/data.py:process_rollout_data:

def process_rollout_data(args, rollout_data_ref, dp_rank, dp_size):
    assert len(rollout_data_ref) == dp_size
    rollout_data = ray.get(rollout_data_ref[dp_rank].inner)

    partition = rollout_data.pop("partition")
    total_lengths = rollout_data["total_lengths"]

    Timer().seq_lens = total_lengths
    rollout_data["total_lengths"] = [total_lengths[i] for i in partition]

    # raw_reward is passed unsplit from the rollout side, so partition it here as well
    if "raw_reward" in rollout_data:
        raw_reward = rollout_data["raw_reward"]
        rollout_data["raw_reward"] = [raw_reward[i] for i in partition]

    return rollout_data
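On toy data, the added list comprehension behaves as follows (a standalone sketch; the partition indices and reward values are made up for illustration):

```python
# Global (unsplit) per-sample rewards as sent from the rollout side.
raw_reward = [1, 0, 1, 1, 0, 1, 0, 1]

# Global sample indices assigned to this DP rank (the "partition" list
# popped from rollout_data in process_rollout_data).
partition = [1, 3, 5, 7]

# Same list comprehension as in the suggested fix.
local_raw_reward = [raw_reward[i] for i in partition]
print(local_raw_reward)  # [0, 1, 1, 1]
```

After this step, raw_reward has the same per-rank length as response_lengths, so the indexing in log_rollout_data stays in bounds.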

Workaround

Remove --log-correct-samples from the training script.

Steps to Reproduce

See the Reproduction section in the description above.

Expected Behavior

When --log-correct-samples is enabled in multi-node training (DP > 1), log_rollout_data should correctly log metrics (response length, entropy, etc.) for reward=1 samples on each DP rank, without crashing.
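Once both lists are partitioned per rank, filtering for reward = 1 samples is safe. A minimal sketch of the expected behavior (toy values, not slime's actual logging code):

```python
# Both lists now have the per-rank length after the fix.
raw_rewards = [0, 1, 1, 0]
response_lengths = [120, 95, 210, 64]
assert len(raw_rewards) == len(response_lengths)

# Collect response lengths of correct (reward == 1) samples on this rank.
correct_response_lengths = [
    length for reward, length in zip(raw_rewards, response_lengths) if reward == 1
]
print(correct_response_lengths)  # [95, 210]
```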

Actual Behavior

Training crashes with an IndexError at the first training step after the rollout completes:

Traceback (most recent call last):
  File "train.py", line 106, in <module>
    train(args)
  File "train.py", line 84, in train
    ray.get(actor_model.async_train(rollout_id, rollout_data_ref))
  ...
  File "slime/backends/megatron_utils/actor.py", line 465, in train_actor
    log_rollout_data(
  File "slime/backends/megatron_utils/data.py", line 520, in log_rollout_data
    correct_response_lengths.append(response_lengths[i])
                                    ~~~~~~~~~~~~~~~~^^^
IndexError: list index out of range

raw_reward has length = total samples (global), while response_lengths has length = samples for this DP rank (partitioned). Iterating raw_reward and indexing into response_lengths goes out of bounds.

Environment

  • slime version: 0.2.3
  • Python version: Python 3.12.13
  • PyTorch version:
  • CUDA/ROCm version:
  • GPU type and count:
  • OS:
  • SGLang version (if relevant):
  • Megatron-LM version (if relevant):

Logs

File "./slime_v023/slime/slime/backends/megatron_utils/actor.py", line 368, in train
    return self.train_actor(rollout_id, rollout_data)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "./slime_v023/slime/slime/backends/megatron_utils/actor.py", line 465, in train_actor
    log_rollout_data(
  File "./slime_v023/slime/slime/backends/megatron_utils/data.py", line 520, in log_rollout_data
    correct_response_lengths.append(response_lengths[i])

Additional Context

No response

Pre-submission Checklist

  • I have read the CONTRIBUTING.md and understand the collaboration scope.
  • I have read the documentation and my issue is not addressed there.
  • I have searched for existing issues and this is not a duplicate.
  • I have provided a minimal, reproducible example.
