Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion skyrl/backends/skyrl_train/utils/ppo_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -983,7 +983,7 @@ def importance_sampling_loss(
def reduce_loss(
loss: torch.Tensor,
loss_mask: Optional[torch.Tensor],
loss_reduction: Literal["token_mean", "sequence_mean", "seq_mean_token_sum_norm"],
loss_reduction: Literal["token_mean", "sequence_mean", "seq_mean_token_sum_norm", "sum"],
max_seq_len: Optional[int] = None,
) -> torch.Tensor:
if loss_reduction == "token_mean":
Expand All @@ -1004,6 +1004,11 @@ def reduce_loss(
# If no mask, assume all tokens are valid
seq_losses = torch.sum(loss, dim=-1) / max_seq_len
loss = torch.mean(seq_losses)
elif loss_reduction == "sum":
if loss_mask is not None:
loss = (loss * loss_mask).sum()
else:
loss = loss.sum()
else:
raise ValueError(f"Invalid loss reduction type: {loss_reduction}")
return loss
Expand Down
2 changes: 1 addition & 1 deletion skyrl/train/config/ppo_base_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ trainer:
advantage_batch_normalize: false
value_head_prefix: "value_head"
policy_loss_type: "regular" # "regular", "dual_clip", "gspo", "clip_cov", "kl_cov", or customizable with PolicyLossRegistry
loss_reduction: "token_mean" # "token_mean", "sequence_mean", "seq_mean_token_sum_norm"
loss_reduction: "token_mean" # "token_mean", "sequence_mean", "seq_mean_token_sum_norm", "sum"
grpo_norm_by_std: true # set to false to disable normalization by std in GRPO
zero_variance_filter: false # set to true to loss mask out prompts with zero variance rewards. only applicable when rewards are response-level.
# GAE parameters
Expand Down