
Add rollout sampling-mask support #1795

Open

yitianlian wants to merge 7 commits into THUDM:main from yitianlian:topp-mask

Conversation

@yitianlian (Collaborator)

No description provided.

@chatgpt-codex-connector bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 93353a6821

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

if "sampling_token_ids" in rollout_data:
    key = "sampling_token_ids"
    rollout_data[key] = [
        slice_log_prob_with_cp(

P1 Badge Skip zigzag CP slicing for sampling ids in allgather mode

This unconditionally applies slice_log_prob_with_cp to sampling_token_ids, but in --allgather_cp mode the training path first consumes contiguous CP chunks and only later redistributes to zigzag layout. That means the per-position candidate lists are misaligned with get_responses(..., allgather_cp=True) output, which can produce out-of-bounds indexing or incorrect token masks when use_topp_mask/use_topk_mask is enabled with CP>1. Guard this branch for allgather_cp (or defer slicing until after redistribution) so mask positions match the logits chunk layout.


max_seq_lens=max_seq_lens,
)

masked_old_log_probs = [olp - tlse for olp, tlse in zip(old_log_probs, mask_logprob_sum)]

P2 Badge Normalize old log-probs over the same masked support

This subtracts sampling_logprob_sum directly, but that sum is built from output_top_logprobs candidates only; when rollout_top_logprobs_num is too small, the sampled token can be missing from that set. The new-policy side explicitly keeps sampled tokens in the mask, so the two sides can end up normalized over different supports, biasing PPO ratios/KL instead of preserving rollout-training consistency. Include sampled-token probability in the stored sum (or adjust it here) before computing masked_old_log_probs.
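The fix the reviewer suggests amounts to folding the sampled token's own log-prob into the stored log-sum-exp. A minimal sketch in plain Python (the function name and the membership check are illustrative, not the PR's actual code):

```python
import math

def masked_old_log_prob(sampled_logprob: float, candidate_logprobs: list[float]) -> float:
    """Renormalize a sampled token's log-prob over the candidate support.

    The sampled token's log-prob is added to the support if it is missing
    (e.g. because rollout_top_logprobs_num was too small), so rollout-side
    and training-side normalization agree on the same token set.
    """
    support = list(candidate_logprobs)
    if sampled_logprob not in support:
        support.append(sampled_logprob)
    # Numerically stable log-sum-exp over the support.
    m = max(support)
    log_sum = m + math.log(sum(math.exp(lp - m) for lp in support))
    return sampled_logprob - log_sum
```

With candidates covering probability mass 0.8 and a sampled token of probability 0.5 inside them, this yields log(0.5/0.8); a sampled token missing from the candidates is still normalized over a support that includes it, which is the consistency property the comment asks for.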


Copilot AI (Contributor) left a comment

Pull request overview

Adds infrastructure to record a rollout-time sampling “support” (top‑p / top‑k candidate token sets) and reuse it during Megatron training to align log-prob normalization between rollout and training.

Changes:

  • Extend Sample / rollout→train data plumbing to carry per-position sampling candidates (sampling_token_ids) and normalization mass (sampling_logprob_sum).
  • Add rollout-side extraction of sampling candidates from SGLang output_top_logprobs, plus CLI flags to enable/validate this behavior.
  • Add training-side utilities to mask logits to the recorded candidate set and recompute masked log-probs for PPO loss.

Reviewed changes

Copilot reviewed 9 out of 9 changed files in this pull request and generated 6 comments.

Show a summary per file
File Description
slime/utils/types.py Adds sampling-mask-related fields to Sample.
slime/utils/ppo_utils.py Introduces mask_logits_for_token_ids utility used during training.
slime/utils/arguments.py Adds --use-topp-mask/--use-topk-mask and validation for sampling-mask settings.
slime/rollout/sglang_rollout.py Extracts per-position candidate sets from output_top_logprobs and stores them on samples; requests top_logprobs_num when enabled.
slime/ray/rollout.py Ships sampling-mask fields to training and logs related metrics.
slime/backends/megatron_utils/model.py Plumbs new batch keys into Megatron get_batch(...) call sites.
slime/backends/megatron_utils/loss.py Applies sampling-mask normalization to train-time log-probs and old log-probs.
slime/backends/megatron_utils/data.py Adjusts rollout logging behavior; adds special handling for sampling_token_ids.
slime/backends/megatron_utils/actor.py CP-slices and stages sampling-mask fields onto GPU for training.


top_p=args.rollout_top_p,
use_topk_mask=getattr(args, "use_topk_mask", False),
top_k=args.rollout_top_k,
)
Copilot AI Apr 2, 2026

append_sampling_mask_to_sample builds sampling_token_ids/sampling_logprob_sum solely from output_top_logprobs. If the actually generated token is missing from output_top_logprobs (possible when top_logprobs_num is too small, as noted elsewhere), training later force-keeps that token in the mask but sampling_logprob_sum will not include its probability mass, so old log-prob renormalization becomes inconsistent. Consider explicitly inserting the generated token ID per position into sampling_token_ids (and updating the corresponding log-sum-exp) or failing fast when it is missing.

Suggested change

)
)
# Optional consistency check: ensure the actually generated token at each
# position is present in the corresponding sampling mask when we have
# access to the generated token ids in meta_info. If it is missing, we
# fail fast instead of silently constructing an inconsistent mask.
generated_token_ids = meta_info.get("output_token_ids")
if generated_token_ids is None:
    generated_token_ids = meta_info.get("output_ids")
if isinstance(generated_token_ids, list) and len(generated_token_ids) == len(new_sampling_mask_ids):
    for pos, (gen_id, candidate_ids) in enumerate(zip(generated_token_ids, new_sampling_mask_ids)):
        if gen_id not in candidate_ids:
            raise ValueError(
                "Generated token id is missing from output_top_logprobs at position "
                f"{pos}. This can happen if top_logprobs_num is too small; "
                "increase it or disable use_topp_mask/use_topk_mask."
            )

Comment on lines +182 to +194
mask = torch.zeros_like(logits, dtype=torch.bool)
for t, ids in enumerate(sampling_token_ids):
    local_ids = [gid - vocab_start for gid in ids if vocab_start <= gid < vocab_start + vocab_shard_size]
    if local_ids:
        idx = torch.tensor(local_ids, dtype=torch.long, device=logits.device)
        mask[t].scatter_(0, idx, True)
# Ensure the actually generated tokens are always kept in the mask.
if tokens is not None:
    local_tokens = tokens.long() - vocab_start
    valid = (local_tokens >= 0) & (local_tokens < vocab_shard_size)
    positions = torch.arange(min(tokens.size(0), mask.size(0)), device=tokens.device)
    mask[positions[valid], local_tokens[valid]] = True
return torch.where(mask, logits, torch.tensor(float("-inf"), device=logits.device, dtype=logits.dtype))
Copilot AI Apr 2, 2026

mask_logits_for_token_ids constructs a dense boolean mask of shape [seq_len, vocab_shard_size] and fills it in a Python loop. For typical RLHF vocab sizes this can add significant memory overhead and runtime (allocating a full mask per forward). Consider a sparse approach (e.g., compute per-position logsumexp over the selected indices and subtract it from the gathered token logits, with TP all-reduce), or at least avoid per-step tensor allocations in the loop.

Suggested change

mask = torch.zeros_like(logits, dtype=torch.bool)
for t, ids in enumerate(sampling_token_ids):
    local_ids = [gid - vocab_start for gid in ids if vocab_start <= gid < vocab_start + vocab_shard_size]
    if local_ids:
        idx = torch.tensor(local_ids, dtype=torch.long, device=logits.device)
        mask[t].scatter_(0, idx, True)
# Ensure the actually generated tokens are always kept in the mask.
if tokens is not None:
    local_tokens = tokens.long() - vocab_start
    valid = (local_tokens >= 0) & (local_tokens < vocab_shard_size)
    positions = torch.arange(min(tokens.size(0), mask.size(0)), device=tokens.device)
    mask[positions[valid], local_tokens[valid]] = True
return torch.where(mask, logits, torch.tensor(float("-inf"), device=logits.device, dtype=logits.dtype))
# Initialize all logits to -inf, then selectively copy over allowed positions.
masked_logits = logits.new_full(logits.shape, float("-inf"))
for t, ids in enumerate(sampling_token_ids):
    local_ids = [gid - vocab_start for gid in ids if vocab_start <= gid < vocab_start + vocab_shard_size]
    if local_ids:
        idx = torch.tensor(local_ids, dtype=torch.long, device=logits.device)
        masked_logits[t, idx] = logits[t, idx]
# Ensure the actually generated tokens are always kept.
if tokens is not None:
    local_tokens = tokens.long() - vocab_start
    valid = (local_tokens >= 0) & (local_tokens < vocab_shard_size)
    max_pos = min(tokens.size(0), masked_logits.size(0))
    positions = torch.arange(max_pos, device=tokens.device)
    valid_positions = positions[valid[:max_pos]]
    valid_local_tokens = local_tokens[:max_pos][valid[:max_pos]]
    masked_logits[valid_positions, valid_local_tokens] = logits[valid_positions, valid_local_tokens]
return masked_logits
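The sparse alternative mentioned in the comment can be sketched on a single rank (NumPy stand-in; the TP all-reduce over partial sums is omitted, and all names here are illustrative, not the PR's code):

```python
import numpy as np

def masked_logsumexp_sparse(logits: np.ndarray, sampling_token_ids: list[list[int]]) -> np.ndarray:
    """Per-position log-sum-exp over only the candidate token ids.

    Avoids materializing a dense [seq_len, vocab] mask: each position
    gathers just its candidate logits and reduces over them.
    """
    out = np.empty(len(sampling_token_ids))
    for t, ids in enumerate(sampling_token_ids):
        sel = logits[t, ids]
        m = float(np.max(sel))  # stability shift before exponentiation
        out[t] = m + np.log(np.sum(np.exp(sel - m)))
    return out
```

The masked log-prob of the sampled token is then `logits[t, token] - out[t]`, which matches what the dense `-inf` mask followed by a softmax would produce over the same candidate set, without allocating a full mask per forward.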

Comment on lines +151 to +180
def mask_logits_for_token_ids(
    logits: torch.Tensor,
    sampling_token_ids: list[list[int]],
    vocab_shard_size: int,
    tp_rank: int,
    tokens: torch.Tensor | None = None,
) -> torch.Tensor:
    """Mask logits to keep only the sampling token subset, setting others to -inf.

    During training, this restricts the softmax normalization domain to an
    externally provided sampling token subset for each position.

    Uses ``torch.where`` (not in-place ``masked_fill_``) so the returned tensor
    has a clean autograd graph that is safe for downstream custom autograd
    functions (e.g. ``fused_vocab_parallel_cross_entropy``) which modify their
    input in-place.

    Args:
        logits: Logits tensor of shape ``[seq_len, vocab_shard_size]``.
        sampling_token_ids: Per-position list of *global* token IDs to keep.
        vocab_shard_size: Size of the local vocabulary shard on this TP rank.
        tp_rank: Tensor-parallel rank (to map global IDs to local indices).
        tokens: Optional ``[seq_len]`` tensor of generated token IDs. When
            provided, the generated token at each position is always kept in
            the mask as a safety net (it must be in the sampling set, but may
            be missing if ``rollout_top_logprobs_num`` was too small).

    Returns:
        A **new** tensor with non-selected entries replaced by ``-inf``.
    """
Copilot AI Apr 2, 2026

New sampling-mask functionality isn’t covered by tests. Since slime/utils/ppo_utils.py already has unit test coverage for other utilities (e.g., chunked_gae), it would be good to add a focused unit test for mask_logits_for_token_ids (e.g., verifies global→local ID mapping across TP ranks, generated-token fallback behavior, and that masked entries become -inf).
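A focused test along those lines could use a small vocab and a NumPy stand-in for the sharded masking contract. The reference implementation below mirrors the docstring above but is not the PR's code; the shard math assumes `vocab_start = tp_rank * vocab_shard_size`:

```python
import numpy as np

NEG_INF = float("-inf")

def mask_logits_ref(logits, sampling_token_ids, vocab_shard_size, tp_rank, tokens=None):
    """Reference: keep only candidate (and generated) tokens; others become -inf."""
    vocab_start = tp_rank * vocab_shard_size
    masked = np.full_like(logits, NEG_INF)
    for t, ids in enumerate(sampling_token_ids):
        for gid in ids:
            # Map global token IDs to this shard's local indices.
            if vocab_start <= gid < vocab_start + vocab_shard_size:
                masked[t, gid - vocab_start] = logits[t, gid - vocab_start]
    if tokens is not None:  # generated tokens are always kept as a safety net
        for t, gid in enumerate(tokens):
            local = gid - vocab_start
            if 0 <= local < vocab_shard_size:
                masked[t, local] = logits[t, local]
    return masked

# TP rank 1 of a 2-way split over vocab size 8: this shard owns global IDs 4..7.
logits = np.arange(8.0).reshape(2, 4)
out = mask_logits_ref(
    logits,
    sampling_token_ids=[[4, 6], [1]],  # position 1 has no candidate on this shard
    vocab_shard_size=4,
    tp_rank=1,
    tokens=[6, 5],  # generated tokens: kept even when absent from candidates
)
```

This exercises the global-to-local ID mapping, the generated-token fallback, and the `-inf` fill that the review comment asks to verify.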

    not getattr(args, "use_topp_mask", False) and not getattr(args, "use_topk_mask", False)
) or sampling_token_ids is None:
    return log_probs, old_log_probs

Copilot AI Apr 2, 2026

apply_sampling_mask_to_log_probs reads sampling_logprob_sum but only guards on sampling_token_ids being present. If sampling_logprob_sum is missing (or not aligned), the later zip(old_log_probs, mask_logprob_sum) will throw or silently truncate. Add validation that sampling_logprob_sum exists and matches the per-sample/per-token structure whenever sampling masks are enabled.

Suggested change

# When sampling masks are enabled, ensure that sampling_logprob_sum exists
# and is aligned with old_log_probs so that zip() does not silently truncate
# or iterate over a None value.
if mask_logprob_sum is None:
    raise ValueError(
        "batch['sampling_logprob_sum'] must be provided when sampling masks are enabled "
        "and 'sampling_token_ids' is present."
    )
if len(mask_logprob_sum) != len(old_log_probs):
    raise ValueError(
        "batch['sampling_logprob_sum'] must have the same number of elements as "
        "'old_log_probs' when sampling masks are enabled: "
        f"{len(mask_logprob_sum)} != {len(old_log_probs)}"
    )
for idx, (olp, tlse) in enumerate(zip(old_log_probs, mask_logprob_sum)):
    if olp.shape != tlse.shape:
        raise ValueError(
            "Shape mismatch between old_log_probs and sampling_logprob_sum at index "
            f"{idx}: {olp.shape} != {tlse.shape}"
        )

Comment on lines +256 to +266
# sampling_token_ids: variable-length nested lists, apply CP slicing but keep as lists
if "sampling_token_ids" in rollout_data:
    key = "sampling_token_ids"
    rollout_data[key] = [
        slice_log_prob_with_cp(
            token_ids,
            total_length,
            response_length,
            self.args.qkv_format,
            rollout_data["max_seq_lens"][i] if self.args.qkv_format == "bshd" else None,
        )
Copilot AI Apr 2, 2026

sampling_token_ids is CP-sliced with slice_log_prob_with_cp, which produces the zigzag CP layout. However, when --allgather-cp is enabled, get_responses() (used by sampling-mask logprob computation) operates on contiguous per-rank chunks before _allgather_cp_redistribute. This mismatch means the per-position sampling_token_ids won’t line up with the logits positions being masked. Either (a) add a dedicated slicing path for sampling_token_ids under allgather_cp that matches the contiguous layout, or (b) explicitly disallow sampling masks with allgather_cp via argument validation.
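To make the layout mismatch concrete: zigzag CP commonly splits a sequence into 2*cp_size chunks and gives rank r chunks r and 2*cp_size-1-r (balancing causal-attention load), while allgather-CP first consumes a single contiguous per-rank chunk. A toy sketch (the chunk assignment here is the common convention, not necessarily slime's exact code):

```python
def zigzag_slice(seq: list, cp_size: int, rank: int) -> list:
    """Zigzag CP layout: rank r owns chunks r and 2*cp_size-1-r."""
    n = 2 * cp_size
    step = len(seq) // n
    chunks = [seq[i * step:(i + 1) * step] for i in range(n)]
    return chunks[rank] + chunks[n - 1 - rank]

def contiguous_slice(seq: list, cp_size: int, rank: int) -> list:
    """Contiguous per-rank chunk, as consumed before allgather-CP redistribution."""
    step = len(seq) // cp_size
    return seq[rank * step:(rank + 1) * step]

# For 8 positions and cp_size=2, rank 0 sees positions [0, 1, 6, 7] under
# zigzag but [0, 1, 2, 3] contiguously, so candidate lists sliced one way
# are misaligned with logits chunked the other way.
```

Under this convention, per-position `sampling_token_ids` sliced into zigzag order would mask the wrong logits on ranks whose contiguous and zigzag ownership differ, which is exactly the hazard both reviewers flag.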

Comment on lines 450 to 452
else:
    val = torch.cat(val).clone().detach()
    val = val.mean() * cp_size
Copilot AI Apr 2, 2026

sampling_logprob_sum (introduced for sampling-mask support) will hit the generic tensor aggregation path here (torch.cat(val).mean() * cp_size) because it isn’t handled in the token-level metrics branch. Since it’s per-token and CP-sliced, aggregating it like a scalar can produce misleading values under CP. Consider treating it like other token-level tensors (use get_sum_of_sample_mean(...)) or explicitly skipping it in this logger.

yitianlian and others added 5 commits April 2, 2026 14:41
- Update sglang patch (latest + v0.5.9) with top-p logprobs feature
  from sgl-project/sglang#22244:
  * New top_logprobs_p request parameter for variable-length top-p logprobs
  * New return_logprobs_in_base64 for ~50% response size reduction
  * Full pipeline support (logprob computation, scheduler, tokenizer)
- Update slime rollout to use native sglang top-p instead of
  requesting large top_logprobs_num (256) and filtering client-side:
  * Sends top_logprobs_p + return_logprobs_in_base64 to sglang
  * Decodes base64-encoded top-p logprobs efficiently via numpy
  * Falls back to list-based format when base64 not available
  * Keeps backward compat with old top-K based filtering
- Relax --rollout-top-logprobs-num validation: only required for
  --use-topk-mask, not --use-topp-mask (which uses native API)

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
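The numpy decode path described in this commit can be sketched as follows; the float32 dtype and the round-trip shown are assumptions to be checked against SGLang's actual wire format in sgl-project/sglang#22244:

```python
import base64

import numpy as np

def decode_b64_logprobs(payload: str, dtype=np.float32) -> np.ndarray:
    """Decode a base64-encoded flat array of logprobs into a NumPy array."""
    return np.frombuffer(base64.b64decode(payload), dtype=dtype)

def logsumexp(logprobs: np.ndarray) -> float:
    """Numerically stable log-sum-exp over a candidate set's logprobs."""
    m = float(np.max(logprobs))
    return m + float(np.log(np.sum(np.exp(logprobs - m))))

# Round-trip: encode a known candidate set, decode it, and compute its mass.
raw = np.log(np.array([0.6, 0.3], dtype=np.float32))
payload = base64.b64encode(raw.tobytes()).decode("ascii")
decoded = decode_b64_logprobs(payload)
```

Decoding via `np.frombuffer` avoids per-element Python parsing, which is where the ~50% response-size and decode-time savings mentioned above would come from.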
- Replace 5 functions (_select_sampling_mask_candidates, extract_sampling_mask_candidates,
  _decode_base64_top_p_logprobs, _extract_native_top_p_candidates, append_sampling_mask_to_sample)
  with 4 simpler ones (_logsumexp, _extract_top_p_candidates, _extract_topk_candidates,
  append_sampling_mask_to_sample)
- Use numpy for logsumexp instead of manual math.exp/math.log loops
- For top-p: sglang returns exact candidates natively, just decode and compute logsumexp
- For top-k: sort and take top-K from output_top_logprobs
- Add if-guard around apply_sampling_mask_to_log_probs call in policy_loss_function
  so readers can skip the block when the feature is not enabled

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
- Extract functions now return per-token (token_id, logprob) pairs
  instead of pre-aggregated logsumexp, enabling set operations
- _extract_topk_candidates now handles base64 format (needed when
  return_logprobs_in_base64=True is set alongside top-k)
- When both use_topp_mask and use_topk_mask are on, take the
  intersection of the two candidate sets per position
- logsumexp computed once on the final candidate set

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
- Fix ScheduleBatch.top_logprobs_nums missing default (= None)
- Remove duplicate top_logprobs_ps field in ScheduleBatch
- Add top_logprobs_ps field to ModelWorkerBatch dataclass
- Fix process_batch_result_decode: append top-p logprobs per decode step
- Fix _create_batch_output: include top-p output in batch response
- Enable return_logprobs_in_base64 for both top-p and top-k modes

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>