[bug][train] Fix max_seq_len calculation#1303
tamoghnokandar wants to merge 3 commits into NovaSky-AI:main from
Conversation
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
    if cfg.trainer.algorithm.loss_reduction == "seq_mean_token_sum_norm":
        if cfg.trainer.algorithm.max_seq_len is None:
            raise ValueError(
                "`trainer.algorithm.max_seq_len` must be set explicitly when "
                "`trainer.algorithm.loss_reduction='seq_mean_token_sum_norm'`. "
                "Choose the total sequence-length normalization constant for your setup; "
                "this often matches the model context window / vLLM `max_model_len` when appropriate."
            )
🔴 Breaking change: Dr. GRPO example script fails because auto-calculated max_seq_len fallback was removed
The PR removes the max_seq_len auto-calculation from SkyRLTrainConfig.__post_init__ (skyrl/train/config/config.py:713-722 on LEFT) and adds a hard assertion requiring it to be set explicitly when loss_reduction='seq_mean_token_sum_norm'. However, the official Dr. GRPO example script at examples/train/algorithms/drgrpo/run_drgrpo_gsm8k.sh:15,23 uses LOSS_REDUCTION="seq_mean_token_sum_norm" but never passes trainer.algorithm.max_seq_len. This script previously worked because __post_init__ auto-computed max_seq_len = max_input_length + max_generate_length. Now it will crash with an AssertionError at validation time.
Same issue in skyrl-agent example
skyrl-agent/examples/run_skyrl/run_skyrl_swe.sh:67 also sets trainer.algorithm.loss_reduction="seq_mean_token_sum_norm" without setting max_seq_len, so it will also fail.
Prompt for agents
Two example scripts need to be updated to explicitly pass trainer.algorithm.max_seq_len now that the auto-calculation fallback has been removed:
1. examples/train/algorithms/drgrpo/run_drgrpo_gsm8k.sh: Add a line like trainer.algorithm.max_seq_len=1536 (512 + 1024, matching max_prompt_length + max_generate_length from the script) to the uv run command.
2. skyrl-agent/examples/run_skyrl/run_skyrl_swe.sh: Add a line like trainer.algorithm.max_seq_len=40768 (8000 + 32768, matching max_prompt_length + max_generate_length from the script) to the uv run command.
Both scripts use loss_reduction=seq_mean_token_sum_norm and will now fail the new assertion at skyrl/train/utils/utils.py:279-285 without this fix.
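As a quick sanity check of the values suggested above, the removed auto-calculation simply added the prompt budget to the generation budget. The sketch below reproduces that arithmetic; the specific lengths (512/1024 and 8000/32768) are the ones quoted in this review comment and should be verified against the actual scripts before applying the fix.

```python
# Reproduce the removed auto-calculation: max_seq_len = prompt budget + generation budget.
# The lengths below are the ones quoted in the review comment, not re-read from the scripts.

def suggested_max_seq_len(max_prompt_length: int, max_generate_length: int) -> int:
    """Mirror the deleted __post_init__ fallback."""
    return max_prompt_length + max_generate_length

# examples/train/algorithms/drgrpo/run_drgrpo_gsm8k.sh
print(suggested_max_seq_len(512, 1024))    # 1536

# skyrl-agent/examples/run_skyrl/run_skyrl_swe.sh
print(suggested_max_seq_len(8000, 32768))  # 40768
```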
    """Used for ``seq_mean_token_sum_norm`` loss reduction.
    Must be set explicitly for that reduction mode; otherwise can remain ``None``."""
🟡 Stale comment in reduce_loss claims max_seq_len has a default fallback that no longer exists
At skyrl/backends/skyrl_train/utils/ppo_utils.py:999-1000, the comment says "NOTE: max_seq_len can be set explicitly via algorithm.max_seq_len, otherwise defaults to cfg.generator.max_input_length + cfg.generator.sampling_params.max_generate_length". This auto-calculation default was removed by this PR (deleted from skyrl/train/config/config.py:713-722), and max_seq_len must now always be set explicitly when using seq_mean_token_sum_norm. The stale comment will mislead developers into thinking a fallback still exists.
Prompt for agents
Update the stale comment in skyrl/backends/skyrl_train/utils/ppo_utils.py at lines 999-1000. The comment currently reads:

    # NOTE: max_seq_len can be set explicitly via algorithm.max_seq_len, otherwise defaults to
    # cfg.generator.max_input_length + cfg.generator.sampling_params.max_generate_length

It should be updated to something like:

    # NOTE: max_seq_len must be set explicitly via algorithm.max_seq_len when using seq_mean_token_sum_norm loss reduction.
This aligns with the new docstring at skyrl/train/config/config.py:374-375 and the validation at skyrl/train/utils/utils.py:279-285.
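For context on why this reduction mode needs an explicit constant, here is an illustrative sketch (not the actual `reduce_loss` implementation in ppo_utils.py) of Dr. GRPO-style normalization: each sequence's token losses are summed and divided by a fixed `max_seq_len` rather than the sequence's own length, then averaged over the batch.

```python
# Illustrative sketch of seq_mean_token_sum_norm (not the real ppo_utils code):
# per-sequence token-loss sums are divided by a fixed constant, then averaged,
# so there is no sensible behavior when max_seq_len is unset.

def seq_mean_token_sum_norm(per_token_losses: list[list[float]], max_seq_len: int) -> float:
    if max_seq_len is None:
        raise ValueError("max_seq_len must be set for seq_mean_token_sum_norm")
    # Every sequence shares the same normalization constant, regardless of its length.
    per_seq = [sum(tokens) / max_seq_len for tokens in per_token_losses]
    return sum(per_seq) / len(per_seq)

# Sequences of different lengths are normalized by the same constant,
# so shorter sequences are not up-weighted.
batch = [[0.5, 0.5], [0.25, 0.25, 0.25, 0.25]]
print(seq_mean_token_sum_norm(batch, max_seq_len=4))  # 0.25
```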
Fixes #1154
Summary
This PR removes the implicit max_seq_len heuristic calculation and requires users to set it explicitly when using trainer.algorithm.loss_reduction=seq_mean_token_sum_norm.
Changes
- Removed the trainer.algorithm.max_seq_len default from SkyRLTrainConfig.__post_init__
- Require trainer.algorithm.max_seq_len to be explicitly set when loss_reduction == "seq_mean_token_sum_norm"
- max_seq_len must now be chosen based on the user's intended sequence-length normalization budget
- max_seq_len remains None by default
- Explicit max_seq_len values are preserved
- validate_cfg() fails when seq_mean_token_sum_norm is used without max_seq_len
- validate_cfg() continues to allow token_mean and sequence_mean without max_seq_len
- validate_cfg() passes when seq_mean_token_sum_norm is used with an explicit max_seq_len
Testing
- Added tests in tests/train/test_config.py for the new behavior