Conversation
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 93353a6821
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
```python
if "sampling_token_ids" in rollout_data:
    key = "sampling_token_ids"
    rollout_data[key] = [
        slice_log_prob_with_cp(
```
Skip zigzag CP slicing for sampling ids in allgather mode
This unconditionally applies slice_log_prob_with_cp to sampling_token_ids, but in --allgather_cp mode the training path first consumes contiguous CP chunks and only later redistributes to zigzag layout. That means the per-position candidate lists are misaligned with get_responses(..., allgather_cp=True) output, which can produce out-of-bounds indexing or incorrect token masks when use_topp_mask/use_topk_mask is enabled with CP>1. Guard this branch for allgather_cp (or defer slicing until after redistribution) so mask positions match the logits chunk layout.
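The guard described above could be sketched as follows. Everything here is hypothetical: `slice_candidates_for_cp`, its arguments, and the chunking arithmetic are illustrative stand-ins for the repo's `slice_log_prob_with_cp`, and the zigzag layout is assumed to split the sequence into `2 * cp_size` chunks with rank `r` owning chunks `r` and `2*cp_size - 1 - r`:

```python
# Hypothetical sketch of option (a): pick the slicing layout to match how the
# logits are consumed on this CP rank. Names and layout are assumptions; the
# real code uses slice_log_prob_with_cp.
def slice_candidates_for_cp(candidates, cp_rank, cp_size, allgather_cp):
    if cp_size == 1:
        return candidates
    if allgather_cp:
        # Contiguous layout: rank r consumes positions [r*chunk, (r+1)*chunk)
        # before the later zigzag redistribution.
        chunk = len(candidates) // cp_size
        return candidates[cp_rank * chunk : (cp_rank + 1) * chunk]
    # Assumed zigzag layout: split into 2*cp_size chunks; rank r owns
    # chunks r and 2*cp_size - 1 - r.
    n = 2 * cp_size
    chunk = len(candidates) // n
    chunks = [candidates[i * chunk : (i + 1) * chunk] for i in range(n)]
    return chunks[cp_rank] + chunks[n - 1 - cp_rank]
```

With 8 positions and `cp_size=2`, the two layouts give different slices for the same rank, which is exactly the mismatch the comment warns about.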
```python
            max_seq_lens=max_seq_lens,
        )
```
```python
masked_old_log_probs = [olp - tlse for olp, tlse in zip(old_log_probs, mask_logprob_sum)]
```
Normalize old log-probs over the same masked support
This subtracts sampling_logprob_sum directly, but that sum is built from output_top_logprobs candidates only; when rollout_top_logprobs_num is too small, the sampled token can be missing from that set. The new-policy side explicitly keeps sampled tokens in the mask, so the two sides can end up normalized over different supports, biasing PPO ratios/KL instead of preserving rollout-training consistency. Include sampled-token probability in the stored sum (or adjust it here) before computing masked_old_log_probs.
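One way to fold the sampled token's mass into the stored sum, sketched here with hypothetical names (`adjusted_logprob_sum` and its arguments are illustrative, not the repo's API):

```python
import math

# Hypothetical sketch: extend the stored log-sum-exp with the sampled token's
# probability mass when that token was absent from output_top_logprobs, so
# old- and new-policy sides normalize over the same support.
def adjusted_logprob_sum(candidate_logprobs, sampled_logprob, sampled_in_candidates):
    # Stable log-sum-exp over the recorded candidate set
    m = max(candidate_logprobs)
    lse = m + math.log(sum(math.exp(lp - m) for lp in candidate_logprobs))
    if not sampled_in_candidates:
        # log(exp(lse) + exp(sampled_logprob)), computed stably
        hi, lo = max(lse, sampled_logprob), min(lse, sampled_logprob)
        lse = hi + math.log1p(math.exp(lo - hi))
    return lse
```

For example, with candidate mass 0.5 + 0.25 and a missing sampled token of mass 0.25, the adjusted sum covers the full 1.0 of probability the training side will keep in its mask.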
Pull request overview
Adds infrastructure to record a rollout-time sampling “support” (top‑p / top‑k candidate token sets) and reuse it during Megatron training to align log-prob normalization between rollout and training.
Changes:
- Extend `Sample` / rollout→train data plumbing to carry per-position sampling candidates (`sampling_token_ids`) and normalization mass (`sampling_logprob_sum`).
- Add rollout-side extraction of sampling candidates from SGLang `output_top_logprobs`, plus CLI flags to enable/validate this behavior.
- Add training-side utilities to mask logits to the recorded candidate set and recompute masked log-probs for PPO loss.
Reviewed changes
Copilot reviewed 9 out of 9 changed files in this pull request and generated 6 comments.
| File | Description |
|---|---|
| `slime/utils/types.py` | Adds sampling-mask-related fields to `Sample`. |
| `slime/utils/ppo_utils.py` | Introduces `mask_logits_for_token_ids` utility used during training. |
| `slime/utils/arguments.py` | Adds `--use-topp-mask`/`--use-topk-mask` and validation for sampling-mask settings. |
| `slime/rollout/sglang_rollout.py` | Extracts per-position candidate sets from `output_top_logprobs` and stores them on samples; requests `top_logprobs_num` when enabled. |
| `slime/ray/rollout.py` | Ships sampling-mask fields to training and logs related metrics. |
| `slime/backends/megatron_utils/model.py` | Plumbs new batch keys into Megatron `get_batch(...)` call sites. |
| `slime/backends/megatron_utils/loss.py` | Applies sampling-mask normalization to train-time log-probs and old log-probs. |
| `slime/backends/megatron_utils/data.py` | Adjusts rollout logging behavior; adds special handling for `sampling_token_ids`. |
| `slime/backends/megatron_utils/actor.py` | CP-slices and stages sampling-mask fields onto GPU for training. |
`slime/rollout/sglang_rollout.py` (outdated diff)
```python
    top_p=args.rollout_top_p,
    use_topk_mask=getattr(args, "use_topk_mask", False),
    top_k=args.rollout_top_k,
)
```
append_sampling_mask_to_sample builds sampling_token_ids/sampling_logprob_sum solely from output_top_logprobs. If the actually generated token is missing from output_top_logprobs (possible when top_logprobs_num is too small, as noted elsewhere), training later force-keeps that token in the mask but sampling_logprob_sum will not include its probability mass, so old log-prob renormalization becomes inconsistent. Consider explicitly inserting the generated token ID per position into sampling_token_ids (and updating the corresponding log-sum-exp) or failing fast when it is missing.
Suggested change:

```python
    )
)
# Optional consistency check: ensure the actually generated token at each
# position is present in the corresponding sampling mask when we have
# access to the generated token ids in meta_info. If it is missing, we
# fail fast instead of silently constructing an inconsistent mask.
generated_token_ids = meta_info.get("output_token_ids")
if generated_token_ids is None:
    generated_token_ids = meta_info.get("output_ids")
if isinstance(generated_token_ids, list) and len(generated_token_ids) == len(new_sampling_mask_ids):
    for pos, (gen_id, candidate_ids) in enumerate(zip(generated_token_ids, new_sampling_mask_ids)):
        if gen_id not in candidate_ids:
            raise ValueError(
                "Generated token id is missing from output_top_logprobs at position "
                f"{pos}. This can happen if top_logprobs_num is too small; "
                "increase it or disable use_topp_mask/use_topk_mask."
            )
```
```python
    mask = torch.zeros_like(logits, dtype=torch.bool)
    for t, ids in enumerate(sampling_token_ids):
        local_ids = [gid - vocab_start for gid in ids if vocab_start <= gid < vocab_start + vocab_shard_size]
        if local_ids:
            idx = torch.tensor(local_ids, dtype=torch.long, device=logits.device)
            mask[t].scatter_(0, idx, True)
    # Ensure the actually generated tokens are always kept in the mask.
    if tokens is not None:
        local_tokens = tokens.long() - vocab_start
        valid = (local_tokens >= 0) & (local_tokens < vocab_shard_size)
        positions = torch.arange(min(tokens.size(0), mask.size(0)), device=tokens.device)
        mask[positions[valid], local_tokens[valid]] = True
    return torch.where(mask, logits, torch.tensor(float("-inf"), device=logits.device, dtype=logits.dtype))
```
mask_logits_for_token_ids constructs a dense boolean mask of shape [seq_len, vocab_shard_size] and fills it in a Python loop. For typical RLHF vocab sizes this can add significant memory overhead and runtime (allocating a full mask per forward). Consider a sparse approach (e.g., compute per-position logsumexp over the selected indices and subtract it from the gathered token logits, with TP all-reduce), or at least avoid per-step tensor allocations in the loop.
Suggested change:

```python
    # Initialize all logits to -inf, then selectively copy over allowed positions.
    masked_logits = logits.new_full(logits.shape, float("-inf"))
    for t, ids in enumerate(sampling_token_ids):
        local_ids = [gid - vocab_start for gid in ids if vocab_start <= gid < vocab_start + vocab_shard_size]
        if local_ids:
            idx = torch.tensor(local_ids, dtype=torch.long, device=logits.device)
            masked_logits[t, idx] = logits[t, idx]
    # Ensure the actually generated tokens are always kept.
    if tokens is not None:
        local_tokens = tokens.long() - vocab_start
        valid = (local_tokens >= 0) & (local_tokens < vocab_shard_size)
        max_pos = min(tokens.size(0), masked_logits.size(0))
        positions = torch.arange(max_pos, device=tokens.device)
        valid_positions = positions[valid[:max_pos]]
        valid_local_tokens = local_tokens[:max_pos][valid[:max_pos]]
        masked_logits[valid_positions, valid_local_tokens] = logits[valid_positions, valid_local_tokens]
    return masked_logits
```
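The sparse alternative mentioned in the comment could look roughly like this single-rank sketch. `masked_logprob_sparse` is a hypothetical name; under tensor parallelism the candidate logits live on different vocab shards, so the per-rank partial exponential sums would still need an all-reduce before taking the log:

```python
import torch

# Hypothetical single-rank sketch of the sparse approach: gather only the
# candidate logits per position and normalize over them directly, without
# materializing a [seq_len, vocab] boolean mask.
def masked_logprob_sparse(logits, sampling_token_ids, token_ids):
    # logits: [seq_len, vocab]; token_ids: [seq_len] generated tokens
    out = torch.empty(len(sampling_token_ids))
    for t, ids in enumerate(sampling_token_ids):
        idx = torch.tensor(ids, dtype=torch.long)
        # log-sum-exp over the candidate set is the masked normalizer
        lse = torch.logsumexp(logits[t, idx], dim=0)
        out[t] = logits[t, token_ids[t]] - lse
    return out
```

This trades the dense `[seq_len, vocab_shard_size]` allocation for per-position gathers, at the cost of keeping a Python loop (which could itself be batched with padded index tensors).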
```python
def mask_logits_for_token_ids(
    logits: torch.Tensor,
    sampling_token_ids: list[list[int]],
    vocab_shard_size: int,
    tp_rank: int,
    tokens: torch.Tensor | None = None,
) -> torch.Tensor:
    """Mask logits to keep only the sampling token subset, setting others to -inf.

    During training, this restricts the softmax normalization domain to an
    externally provided sampling token subset for each position.

    Uses ``torch.where`` (not in-place ``masked_fill_``) so the returned tensor
    has a clean autograd graph that is safe for downstream custom autograd
    functions (e.g. ``fused_vocab_parallel_cross_entropy``) which modify their
    input in-place.

    Args:
        logits: Logits tensor of shape ``[seq_len, vocab_shard_size]``.
        sampling_token_ids: Per-position list of *global* token IDs to keep.
        vocab_shard_size: Size of the local vocabulary shard on this TP rank.
        tp_rank: Tensor-parallel rank (to map global IDs to local indices).
        tokens: Optional ``[seq_len]`` tensor of generated token IDs. When
            provided, the generated token at each position is always kept in
            the mask as a safety net (it must be in the sampling set, but may
            be missing if ``rollout_top_logprobs_num`` was too small).

    Returns:
        A **new** tensor with non-selected entries replaced by ``-inf``.
    """
```
New sampling-mask functionality isn’t covered by tests. Since slime/utils/ppo_utils.py already has unit test coverage for other utilities (e.g., chunked_gae), it would be good to add a focused unit test for mask_logits_for_token_ids (e.g., verifies global→local ID mapping across TP ranks, generated-token fallback behavior, and that masked entries become -inf).
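A focused test along those lines might look like this sketch. The function body is reproduced from the dense-mask variant shown in this review so the test is self-contained, and `vocab_start = tp_rank * vocab_shard_size` is an assumption about the omitted part of the function:

```python
import torch

def mask_logits_for_token_ids(logits, sampling_token_ids, vocab_shard_size, tp_rank, tokens=None):
    # Reproduced from the diff under review; vocab_start mapping is assumed.
    vocab_start = tp_rank * vocab_shard_size
    mask = torch.zeros_like(logits, dtype=torch.bool)
    for t, ids in enumerate(sampling_token_ids):
        local_ids = [gid - vocab_start for gid in ids if vocab_start <= gid < vocab_start + vocab_shard_size]
        if local_ids:
            idx = torch.tensor(local_ids, dtype=torch.long, device=logits.device)
            mask[t].scatter_(0, idx, True)
    if tokens is not None:
        local_tokens = tokens.long() - vocab_start
        valid = (local_tokens >= 0) & (local_tokens < vocab_shard_size)
        positions = torch.arange(min(tokens.size(0), mask.size(0)), device=tokens.device)
        mask[positions[valid], local_tokens[valid]] = True
    return torch.where(mask, logits, torch.tensor(float("-inf"), device=logits.device, dtype=logits.dtype))

def test_mapping_and_generated_token_fallback():
    # tp_rank=1 with shard size 4 covers global IDs 4..7
    logits = torch.zeros(2, 4)
    out = mask_logits_for_token_ids(
        logits, [[5], [0]], vocab_shard_size=4, tp_rank=1, tokens=torch.tensor([5, 6])
    )
    assert out[0, 1].item() == 0.0            # global ID 5 -> local index 1 kept
    assert out[0, 0].item() == float("-inf")  # non-candidate entry masked
    assert out[1, 2].item() == 0.0            # generated token 6 force-kept
    assert out[1, 0].item() == float("-inf")  # candidate 0 is off-shard, so masked
```

The second position exercises the fallback: its only candidate (global ID 0) lies outside this rank's shard, yet the generated token is still kept.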
```python
    not getattr(args, "use_topp_mask", False) and not getattr(args, "use_topk_mask", False)
) or sampling_token_ids is None:
    return log_probs, old_log_probs
```
apply_sampling_mask_to_log_probs reads sampling_logprob_sum but only guards on sampling_token_ids being present. If sampling_logprob_sum is missing (or not aligned), the later zip(old_log_probs, mask_logprob_sum) will throw or silently truncate. Add validation that sampling_logprob_sum exists and matches the per-sample/per-token structure whenever sampling masks are enabled.
Suggested change:

```python
# When sampling masks are enabled, ensure that sampling_logprob_sum exists
# and is aligned with old_log_probs so that zip() does not silently truncate
# or iterate over a None value.
if mask_logprob_sum is None:
    raise ValueError(
        "batch['sampling_logprob_sum'] must be provided when sampling masks are enabled "
        "and 'sampling_token_ids' is present."
    )
if len(mask_logprob_sum) != len(old_log_probs):
    raise ValueError(
        "batch['sampling_logprob_sum'] must have the same number of elements as "
        "'old_log_probs' when sampling masks are enabled: "
        f"{len(mask_logprob_sum)} != {len(old_log_probs)}"
    )
for idx, (olp, tlse) in enumerate(zip(old_log_probs, mask_logprob_sum)):
    if olp.shape != tlse.shape:
        raise ValueError(
            "Shape mismatch between old_log_probs and sampling_logprob_sum at index "
            f"{idx}: {olp.shape} != {tlse.shape}"
        )
```
```python
# sampling_token_ids: variable-length nested lists, apply CP slicing but keep as lists
if "sampling_token_ids" in rollout_data:
    key = "sampling_token_ids"
    rollout_data[key] = [
        slice_log_prob_with_cp(
            token_ids,
            total_length,
            response_length,
            self.args.qkv_format,
            rollout_data["max_seq_lens"][i] if self.args.qkv_format == "bshd" else None,
        )
```
sampling_token_ids is CP-sliced with slice_log_prob_with_cp, which produces the zigzag CP layout. However, when --allgather-cp is enabled, get_responses() (used by sampling-mask logprob computation) operates on contiguous per-rank chunks before _allgather_cp_redistribute. This mismatch means the per-position sampling_token_ids won’t line up with the logits positions being masked. Either (a) add a dedicated slicing path for sampling_token_ids under allgather_cp that matches the contiguous layout, or (b) explicitly disallow sampling masks with allgather_cp via argument validation.
```python
else:
    val = torch.cat(val).clone().detach()
    val = val.mean() * cp_size
```
sampling_logprob_sum (introduced for sampling-mask support) will hit the generic tensor aggregation path here (torch.cat(val).mean() * cp_size) because it isn’t handled in the token-level metrics branch. Since it’s per-token and CP-sliced, aggregating it like a scalar can produce misleading values under CP. Consider treating it like other token-level tensors (use get_sum_of_sample_mean(...)) or explicitly skipping it in this logger.
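The difference between the two aggregation paths can be shown with a toy example; the helper names below are illustrative stand-ins (the repo's helper is `get_sum_of_sample_mean`):

```python
import torch

# Illustrative comparison: flattening per-token values and scaling by cp_size
# weights each sample by its token count, while a per-sample mean followed by
# a sum treats every sample equally.
def generic_scalar_agg(vals, cp_size):
    return (torch.cat(vals).mean() * cp_size).item()

def sum_of_sample_mean(vals):
    # One mean per sample, then sum across samples (hypothetical stand-in).
    return sum(v.mean().item() for v in vals)

short = torch.tensor([1.0, 1.0])           # 2-token sample, per-token value 1
long = torch.tensor([3.0, 3.0, 3.0, 3.0])  # 4-token sample, per-token value 3
```

Here `sum_of_sample_mean` gives 1 + 3 = 4 regardless of lengths, while the generic path gives 2 × 14/6 ≈ 4.67 because the longer sample dominates the flat mean.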
- Update sglang patch (latest + v0.5.9) with top-p logprobs feature from sgl-project/sglang#22244:
  * New top_logprobs_p request parameter for variable-length top-p logprobs
  * New return_logprobs_in_base64 for ~50% response size reduction
  * Full pipeline support (logprob computation, scheduler, tokenizer)
- Update slime rollout to use native sglang top-p instead of requesting a large top_logprobs_num (256) and filtering client-side:
  * Sends top_logprobs_p + return_logprobs_in_base64 to sglang
  * Decodes base64-encoded top-p logprobs efficiently via numpy
  * Falls back to list-based format when base64 is not available
  * Keeps backward compat with old top-K based filtering
- Relax --rollout-top-logprobs-num validation: only required for --use-topk-mask, not --use-topp-mask (which uses the native API)

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
- Replace 5 functions (_select_sampling_mask_candidates, extract_sampling_mask_candidates, _decode_base64_top_p_logprobs, _extract_native_top_p_candidates, append_sampling_mask_to_sample) with 4 simpler ones (_logsumexp, _extract_top_p_candidates, _extract_topk_candidates, append_sampling_mask_to_sample)
- Use numpy for logsumexp instead of manual math.exp/math.log loops
- For top-p: sglang returns exact candidates natively, so just decode and compute logsumexp
- For top-k: sort and take top-K from output_top_logprobs
- Add if-guard around the apply_sampling_mask_to_log_probs call in policy_loss_function so readers can skip the block when the feature is not enabled

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
- Extract functions now return per-token (token_id, logprob) pairs instead of a pre-aggregated logsumexp, enabling set operations
- _extract_topk_candidates now handles the base64 format (needed when return_logprobs_in_base64=True is set alongside top-k)
- When both use_topp_mask and use_topk_mask are on, take the intersection of the two candidate sets per position
- logsumexp is computed once on the final candidate set

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
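The intersection step described in that commit can be sketched as follows; the pair format is taken from the commit message, while `combine_candidates` and the non-empty-intersection assumption are hypothetical:

```python
import math

# Hypothetical sketch: intersect per-position top-p and top-k candidate sets,
# then compute one log-sum-exp over the surviving (token_id, logprob) pairs.
def combine_candidates(topp_pairs, topk_pairs):
    topk_ids = {tid for tid, _ in topk_pairs}
    merged = [(tid, lp) for tid, lp in topp_pairs if tid in topk_ids]
    # Assumption: the sampled token survives both filters, so the
    # intersection is never empty.
    assert merged, "empty intersection of top-p and top-k candidate sets"
    m = max(lp for _, lp in merged)
    lse = m + math.log(sum(math.exp(lp - m) for _, lp in merged))
    return merged, lse
```

For example, intersecting a three-token top-p set (masses 0.5, 0.25, 0.125) with a two-token top-k set keeps the shared two tokens and yields a normalization mass of log(0.75).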
- Fix ScheduleBatch.top_logprobs_nums missing default (= None)
- Remove duplicate top_logprobs_ps field in ScheduleBatch
- Add top_logprobs_ps field to the ModelWorkerBatch dataclass
- Fix process_batch_result_decode: append top-p logprobs per decode step
- Fix _create_batch_output: include top-p output in the batch response
- Enable return_logprobs_in_base64 for both top-p and top-k modes

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>