
[WIP] fix loss oom#1788

Merged
zhuzilin merged 2 commits into THUDM:main from lilei199908:fix/loss_oom
Apr 4, 2026

Conversation

Collaborator

@lilei199908 lilei199908 commented Mar 31, 2026

before optimization: [image]
after optimization: [image]
grad norm: [image]

Copilot AI review requested due to automatic review settings March 31, 2026 09:17
@lilei199908 lilei199908 changed the title fix loss oom [WIP] fix loss oom Mar 31, 2026
Contributor

Copilot AI left a comment

Pull request overview

This PR targets loss-time OOMs by reducing autograd graph duplication when computing PPO log-probs/entropy, and by adjusting checkpointing behavior in the Megatron loss path.

Changes:

  • Reworks get_log_probs_and_entropy() to compute log-probs/entropy over the full packed logits tensor once, then slice per-sample response segments.
  • Adds an option to compute entropy under torch.no_grad() when entropy gradients aren’t needed (e.g., entropy_coef == 0).
  • Switches loss recomputation checkpointing to use_reentrant=False.
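In toy form (single GPU, illustrative names only; slime's real helpers additionally handle TP groups, chunking, and CP layouts), the combination of one full-tensor pass plus optional no-grad entropy might look like:

```python
import torch

def log_probs_and_entropy(logits, tokens, response_slices, need_entropy_grad):
    # One pass over the full packed [T, V] logits instead of per-sample passes.
    logp_full = torch.log_softmax(logits, dim=-1)
    log_probs = logp_full.gather(-1, tokens.unsqueeze(-1)).squeeze(-1)  # [T]

    # When the entropy gradient is unused (entropy_coef == 0), compute it
    # under no_grad so no extra activations are saved for backward.
    with torch.enable_grad() if need_entropy_grad else torch.no_grad():
        entropy_full = -(logp_full.exp() * logp_full).sum(-1)  # [T]

    # Slice per-sample response segments from the single full result.
    return (
        [log_probs[s:e] for s, e in response_slices],
        [entropy_full[s:e] for s, e in response_slices],
    )

T, V = 10, 6
logits = torch.randn(T, V, requires_grad=True)
tokens = torch.randint(V, (T,))
lps, ents = log_probs_and_entropy(logits, tokens, [(0, 4), (4, 10)],
                                  need_entropy_grad=False)
assert lps[0].requires_grad and not ents[0].requires_grad
```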

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.

Changed files:

  • slime/utils/ppo_utils.py: Adds need_entropy_grad to avoid cloning / graph retention for entropy when gradients aren't needed.
  • slime/backends/megatron_utils/loss.py: Refactors log-prob/entropy computation to a single full-tensor pass with new token-building/slicing helpers; updates CP redistribution API; adjusts checkpoint configuration.
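The use_reentrant=False change follows PyTorch's recommended non-reentrant checkpointing path; a minimal sketch (loss_fn and the shapes are illustrative):

```python
import torch
from torch.utils.checkpoint import checkpoint

def loss_fn(logits):
    return torch.log_softmax(logits, dim=-1).sum()

x = torch.randn(4, 8, requires_grad=True)
# Non-reentrant checkpointing recomputes loss_fn during backward; unlike
# the reentrant path it works with torch.autograd.grad, inputs that don't
# all require grad, and keyword arguments.
y = checkpoint(loss_fn, x, use_reentrant=False)
y.backward()
assert x.grad is not None
```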


Comment on lines +418 to 463
```python
tp_group = mpu.get_tensor_model_parallel_group()
chunk_size = args.log_probs_chunk_size
need_entropy_grad = with_entropy and args.entropy_coef != 0

# --- build full shifted-token target tensor ---
full_tokens = _build_shifted_tokens(
    T, device, unconcat_tokens, total_lengths, response_lengths, qkv_format, max_seq_lens, args.allgather_cp
)

# --- compute on full [T,V] logits at once via calculate_log_probs_and_entropy ---
log_prob_full, entropy_full = calculate_log_probs_and_entropy(
    logits,
    full_tokens,
    tp_group,
    with_entropy=with_entropy,
    chunk_size=chunk_size,
    need_entropy_grad=need_entropy_grad,
)
log_prob_full = log_prob_full.squeeze(-1)  # [T, 1] -> [T]

# --- extract per-sample response portions ---
log_probs_list, entropy_list = _extract_per_sample(
    log_prob_full,
    entropy_full,
    unconcat_tokens,
    total_lengths,
    response_lengths,
    qkv_format,
    max_seq_lens,
    args.allgather_cp,
)

res = {"log_probs": log_probs_list}
if with_entropy:
    res["entropy"] = entropy_list

# we need to turn the all gather kv into zigzag ring attn kv
if args.allgather_cp:
    _allgather_cp_redistribute(
        res,
        logits=logits,
        logits_local_len=T,
        args=args,
        total_lengths=total_lengths,
        response_lengths=response_lengths,
        max_seq_lens=max_seq_lens,
    )
```

Copilot AI Mar 31, 2026


need_entropy_grad correctly computes entropy under torch.no_grad() when entropy_coef == 0, but if args.allgather_cp is enabled, _allgather_cp_redistribute() currently materializes missing segments with requires_grad=True and uses differentiable dist.nn.all_reduce, which can cause the redistributed entropy tensors to become grad-tracked again and negate the intended memory reduction. Consider ensuring _allgather_cp_redistribute() preserves requires_grad from the incoming tensors (or using a non-autograd all-reduce for no-grad metrics).
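The pitfall described here can be reproduced in miniature without torch.distributed: any differentiable op that mixes a no-grad tensor with a freshly materialized requires_grad=True buffer re-enables gradient tracking. In the sketch below (names illustrative), a plain add stands in for the differentiable dist.nn.all_reduce:

```python
import torch

logits = torch.randn(4, 8, requires_grad=True)

# Entropy computed under no_grad carries no graph, as intended:
with torch.no_grad():
    logp = torch.log_softmax(logits, dim=-1)
    entropy = -(logp.exp() * logp).sum(-1)
assert not entropy.requires_grad

# But combining it with a materialized requires_grad=True segment (the add
# standing in for differentiable dist.nn.all_reduce) re-enables tracking,
# so the graph is retained again and the memory win is lost:
missing_segment = torch.zeros(4, requires_grad=True)
redistributed = entropy + missing_segment
assert redistributed.requires_grad
```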

Comment on lines 402 to +413
```python
assert non_loss_data
log_probs_list = []
entropy_list = []
for logits_chunk, tokens_chunk in get_responses(
    logits,
    args=args,
    unconcat_tokens=unconcat_tokens,
    total_lengths=total_lengths,
    response_lengths=response_lengths,
    max_seq_lens=max_seq_lens,
):
    log_prob, entropy = calculate_log_probs_and_entropy(
        logits_chunk,
        tokens_chunk,
        mpu.get_tensor_model_parallel_group(),
        with_entropy=with_entropy,
        chunk_size=args.log_probs_chunk_size,
    )
    qkv_format = args.qkv_format

    log_probs_list.append(log_prob.squeeze(-1))
    entropy_list.append(entropy)
assert logits.dtype == torch.float32, f"{logits.dtype}"
assert len(logits.shape) == 3, f"{logits.shape}"

res = {
    "log_probs": log_probs_list,
}
if qkv_format == "thd":
    assert logits.size(0) == 1, f"{logits.shape}"
    logits = logits.squeeze(0)
else:
    assert max_seq_lens is not None
    logits = logits.view(-1, logits.size(-1))
```

Copilot AI Mar 31, 2026


args.allgather_cp appears to be implemented only for qkv_format == "thd" in megatron_utils/data.py (the bshd path never does the global-concat + chunk split). Here, get_log_probs_and_entropy() will still run the allgather-CP code paths even when qkv_format != "thd", which would produce incorrect token/logit alignment under misconfiguration. Consider adding an explicit assert not args.allgather_cp or qkv_format == "thd" near the top of this function to fail fast.
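The suggested guard is a one-liner; a minimal sketch, using argparse.Namespace as a stand-in for slime's args object and a hypothetical helper name:

```python
from argparse import Namespace

def check_cp_config(args):
    # Fail fast: allgather-CP is only implemented for the packed "thd" layout.
    assert not args.allgather_cp or args.qkv_format == "thd", (
        f"allgather_cp requires qkv_format='thd', got {args.qkv_format!r}"
    )

check_cp_config(Namespace(allgather_cp=True, qkv_format="thd"))    # ok
check_cp_config(Namespace(allgather_cp=False, qkv_format="bshd"))  # ok

try:
    check_cp_config(Namespace(allgather_cp=True, qkv_format="bshd"))
    raised = False
except AssertionError:
    raised = True
assert raised  # misconfiguration rejected before any tensor work
```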

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
```python
entropy_input = logits.clone() if need_entropy_grad else logits
entropy = compute_entropy_from_logits(entropy_input, tp_group)

log_prob = compute_log_probs(logits.clone(), tokens, tp_group)
```
Contributor

I wonder if some of the code movement (moving log_prob down here) is necessary. Also, if we don't need to calculate the grad of entropy, do we still need to do logits.clone()?
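For reference, a small experiment (illustrative only, not slime's actual kernels) suggesting the clone is not needed for autograd correctness once entropy runs under no_grad; it would still matter only if the entropy kernel wrote into logits in place:

```python
import torch

x = torch.randn(3, 5, requires_grad=True)
logits = x * 2.0  # non-leaf tensor with autograd history

# Under no_grad, entropy builds no graph, so sharing storage with logits
# is safe for autograd as long as nothing mutates logits in place:
with torch.no_grad():
    p = torch.softmax(logits, dim=-1)
    entropy = -(p * torch.log(p.clamp_min(1e-9))).sum(-1)
assert not entropy.requires_grad

# logits is untouched and its graph still works for the log-prob backward:
torch.log_softmax(logits, dim=-1)[:, 0].sum().backward()
assert x.grad is not None
```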

@zhuzilin zhuzilin merged commit 2a89c16 into THUDM:main Apr 4, 2026
23 of 24 checks passed
