Conversation
Pull request overview
This PR targets loss-time OOMs by reducing autograd graph duplication when computing PPO log-probs/entropy, and by adjusting checkpointing behavior in the Megatron loss path.
Changes:
- Reworks `get_log_probs_and_entropy()` to compute log-probs/entropy over the full packed logits tensor once, then slice per-sample response segments.
- Adds an option to compute entropy under `torch.no_grad()` when entropy gradients aren't needed (e.g., `entropy_coef == 0`).
- Switches loss-recomputation checkpointing to `use_reentrant=False`.
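The first bullet can be sketched as below. The names (`log_probs_once_then_slice`, `response_slices`) are illustrative only; the actual helpers additionally handle sequence packing, context parallelism, and the TP-sharded vocab:

```python
import torch

def log_probs_once_then_slice(logits, tokens, response_slices):
    """Compute token log-probs over the full packed [T, V] logits in one
    pass, then return per-sample response views, instead of rebuilding a
    separate autograd graph for every sample."""
    logp_full = torch.log_softmax(logits, dim=-1)  # [T, V], single graph
    token_logp = logp_full.gather(-1, tokens.unsqueeze(-1)).squeeze(-1)  # [T]
    return [token_logp[start:end] for start, end in response_slices]
```

Because softmax is applied row-wise, slicing after one full-tensor pass is numerically identical to computing each segment separately, while duplicating far less of the graph.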
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| `slime/utils/ppo_utils.py` | Adds `need_entropy_grad` to avoid cloning / graph retention for entropy when gradients aren't needed. |
| `slime/backends/megatron_utils/loss.py` | Refactors log-prob/entropy computation to a single full-tensor pass with new token-building/slicing helpers; updates CP redistribution API; adjusts checkpoint configuration. |
```python
tp_group = mpu.get_tensor_model_parallel_group()
chunk_size = args.log_probs_chunk_size
need_entropy_grad = with_entropy and args.entropy_coef != 0

# --- build full shifted-token target tensor ---
full_tokens = _build_shifted_tokens(
    T, device, unconcat_tokens, total_lengths, response_lengths, qkv_format, max_seq_lens, args.allgather_cp
)

# --- compute on full [T,V] logits at once via calculate_log_probs_and_entropy ---
log_prob_full, entropy_full = calculate_log_probs_and_entropy(
    logits,
    full_tokens,
    tp_group,
    with_entropy=with_entropy,
    chunk_size=chunk_size,
    need_entropy_grad=need_entropy_grad,
)
log_prob_full = log_prob_full.squeeze(-1)  # [T, 1] -> [T]

# --- extract per-sample response portions ---
log_probs_list, entropy_list = _extract_per_sample(
    log_prob_full,
    entropy_full,
    unconcat_tokens,
    total_lengths,
    response_lengths,
    qkv_format,
    max_seq_lens,
    args.allgather_cp,
)

res = {"log_probs": log_probs_list}
if with_entropy:
    res["entropy"] = entropy_list

# we need to turn the all gather kv into zigzag ring attn kv
if args.allgather_cp:
    _allgather_cp_redistribute(
        res,
        logits=logits,
        logits_local_len=T,
        args=args,
        total_lengths=total_lengths,
        response_lengths=response_lengths,
        max_seq_lens=max_seq_lens,
    )
```
`need_entropy_grad` correctly computes entropy under `torch.no_grad()` when `entropy_coef == 0`, but if `args.allgather_cp` is enabled, `_allgather_cp_redistribute()` currently materializes missing segments with `requires_grad=True` and uses the differentiable `dist.nn.all_reduce`, which can cause the redistributed entropy tensors to become grad-tracked again and negate the intended memory reduction. Consider ensuring `_allgather_cp_redistribute()` preserves `requires_grad` from the incoming tensors (or using a non-autograd all-reduce for no-grad metrics).
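One way to apply the suggested fix, sketched with a hypothetical helper (`materialize_segment` is not the actual function name in the PR), is to inherit `requires_grad` from the tensor being redistributed instead of hard-coding `True`:

```python
import torch

def materialize_segment(like: torch.Tensor, length: int) -> torch.Tensor:
    # Hypothetical helper: allocate the placeholder segment with the same
    # requires_grad as the incoming tensor, so entropy computed under
    # torch.no_grad() stays graph-free after CP redistribution.
    return torch.zeros(
        (length,) + tuple(like.shape[1:]),
        dtype=like.dtype,
        device=like.device,
        requires_grad=like.requires_grad,
    )
```

The same predicate (`like.requires_grad`) could then select between the differentiable `torch.distributed.nn.all_reduce` and the plain `torch.distributed.all_reduce` for no-grad metrics.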
```python
assert non_loss_data
log_probs_list = []
entropy_list = []
for logits_chunk, tokens_chunk in get_responses(
    logits,
    args=args,
    unconcat_tokens=unconcat_tokens,
    total_lengths=total_lengths,
    response_lengths=response_lengths,
    max_seq_lens=max_seq_lens,
):
    log_prob, entropy = calculate_log_probs_and_entropy(
        logits_chunk,
        tokens_chunk,
        mpu.get_tensor_model_parallel_group(),
        with_entropy=with_entropy,
        chunk_size=args.log_probs_chunk_size,
    )
    qkv_format = args.qkv_format

    log_probs_list.append(log_prob.squeeze(-1))
    entropy_list.append(entropy)
assert logits.dtype == torch.float32, f"{logits.dtype}"
assert len(logits.shape) == 3, f"{logits.shape}"

res = {
    "log_probs": log_probs_list,
}
if qkv_format == "thd":
    assert logits.size(0) == 1, f"{logits.shape}"
    logits = logits.squeeze(0)
else:
    assert max_seq_lens is not None
    logits = logits.view(-1, logits.size(-1))
```
`args.allgather_cp` appears to be implemented only for `qkv_format == "thd"` in `megatron_utils/data.py` (the bshd path never does the global-concat + chunk split). Here, `get_log_probs_and_entropy()` will still run the allgather-CP code paths even when `qkv_format != "thd"`, which would produce incorrect token/logit alignment under misconfiguration. Consider adding an explicit `assert not args.allgather_cp or qkv_format == "thd"` near the top of this function to fail fast.
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
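The suggested guard could look like the sketch below; `check_allgather_cp` is a hypothetical free-standing version of the assert for illustration:

```python
def check_allgather_cp(allgather_cp: bool, qkv_format: str) -> None:
    # Fail fast: the allgather-CP path only supports the packed "thd"
    # layout; running it under "bshd" would silently misalign tokens
    # and logits.
    assert not allgather_cp or qkv_format == "thd", (
        f"allgather_cp requires qkv_format == 'thd', got {qkv_format!r}"
    )
```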
```python
entropy_input = logits.clone() if need_entropy_grad else logits
entropy = compute_entropy_from_logits(entropy_input, tp_group)

log_prob = compute_log_probs(logits.clone(), tokens, tp_group)
```
I wonder if some of the code movement (moving `log_prob` down here) is necessary. And also, if we don't need to calculate the grad of entropy, do we still need `logits.clone()`?
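On the second question: if the entropy graph isn't kept, the clone and the graph can plausibly be skipped together. A minimal sketch, with the entropy math simplified (the real `compute_entropy_from_logits` also handles the TP-sharded vocab):

```python
import torch

def entropy_from_logits(logits: torch.Tensor, need_entropy_grad: bool) -> torch.Tensor:
    # Sketch: clone only when we keep the graph, so later in-place ops on
    # logits can't corrupt it; under no_grad there is no graph to protect.
    if need_entropy_grad:
        logp = torch.log_softmax(logits.clone(), dim=-1)
        return -(logp.exp() * logp).sum(dim=-1)
    with torch.no_grad():
        logp = torch.log_softmax(logits, dim=-1)
        return -(logp.exp() * logp).sum(dim=-1)
```

Both branches produce the same values; only the no-grad branch drops the clone and the autograd bookkeeping.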
before optimize: *(memory-usage screenshots)*

after optimize: grad norm *(screenshot)*