feat: Add GDPO Support#3353
Conversation
|
Important Review skippedAuto incremental reviews are disabled on this repository. Please check the settings in the CodeRabbit UI or the You can disable this status message by setting the 📝 WalkthroughWalkthroughThis pull request introduces GDPO (Group Reward-Decoupled Normalization Policy Optimization) support to Axolotl's RLHF training framework. Changes include adding RLType.GDPO enum, creating GDPOStrategy for centralized configuration, defining AxolotlGDPOConfig with per-reward normalization parameters, implementing GDPO trainer classes with decoupled advantage computation, extending the builder to route GDPO configuration, updating data utilities, and adding comprehensive end-to-end tests covering multi-GPU scenarios with varied reward configurations and parallelism strategies. Changes
Estimated code review effort🎯 4 (Complex) | ⏱️ ~75 minutes Suggested reviewers
🚥 Pre-merge checks | ✅ 2 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Actionable comments posted: 8
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
src/axolotl/core/builders/rl.py (1)
209-210: Add GDPO to the peft_config exclusion.Line 209 excludes
RLType.GRPOfrom passingpeft_configto the trainer, butRLType.GDPOis not similarly excluded. Since GDPO inherits from GRPO's trainer and is tested with adapters, it should be excluded to prevent unexpected behavior.- if self.cfg.adapter and self.peft_config and self.cfg.rl is not RLType.GRPO: + if self.cfg.adapter and self.peft_config and self.cfg.rl not in {RLType.GRPO, RLType.GDPO}: trainer_kwargs["peft_config"] = self.peft_config
🤖 Fix all issues with AI agents
In @src/axolotl/core/trainers/gdpo/__init__.py:
- Around line 161-169: TRLConfig is missing the reward_processing_classes field
referenced by set_trainer_kwargs; add a typed optional field to the TRL config
schema (e.g., reward_processing_classes: list[str] | None = None) in the
TRLConfig/DictDefault definition so the attribute exists for validation and
docs, import typing primitives if needed, and include a short docstring/default;
also apply the same addition to the equivalent schema used by GRPO so
src/axolotl/core/trainers/grpo/__init__.py can safely rely on
cfg.trl.reward_processing_classes.
- Around line 48-69: Move the vllm_cfg = cfg.vllm assignment inside the if
trl.use_vllm: block and replace any uses of trl.vllm.host and trl.vllm.port with
vllm_cfg.host and vllm_cfg.port in the GDPO initializer (references:
gdpo_args_kwargs, trl.use_vllm, trl.vllm_mode, trl.vllm_enable_sleep_mode) so we
don't access trl.vllm when vllm is disabled; apply the same change in the
GRPOStrategy implementation to avoid the AttributeError.
In @src/axolotl/core/trainers/gdpo/trainer.py:
- Line 173: Ruff RUF012 flags the module-level/class variable _tag_names;
annotate it as a ClassVar to indicate it's not an instance attribute: add "from
typing import ClassVar" (or "from typing import ClassVar, List" if you prefer
explicit List) and change "_tag_names = [\"trl\", \"gdpo\", \"axolotl\"]" to
"_tag_names: ClassVar[list[str]] = [\"trl\", \"gdpo\", \"axolotl\"]" (apply the
same change to the other occurrence of _tag_names at the second location).
- Around line 440-457: The std() call used to compute std_grouped_rewards can
return NaN when group size self.num_generations == 1 because torch.std defaults
to unbiased=True; update the std computation to use an unbiased=False (i.e.,
std(dim=1, unbiased=False)) to yield 0 instead of NaN, and apply this same
change in both the GRPO fallback block (where std_grouped_rewards is computed
from rewards.view(-1, self.num_generations).std(...)) and the sequence-parallel
fallback block so that advantages scaling (self.args.scale_rewards) is safe when
num_generations == 1.
- Around line 53-154: compute_gdpo_advantages assumes num_samples %
num_generations == 0 and that tensors are already on the requested device; fix
by (1) at the top of compute_gdpo_advantages, if device is provided call
rewards_per_func = rewards_per_func.to(device) and reward_weights =
reward_weights.to(device) so all ops use the same device, and create
combined_advantages with torch.zeros(num_samples, device=device); (2) before
doing reward_grouped = reward_i.view(-1, num_generations) validate that
num_samples % num_generations == 0 (or compute num_prompts = num_samples //
num_generations) and raise a clear ValueError if not divisible, then reshape
with reward_grouped = reward_i.view(num_prompts, num_generations) to avoid
silent crashes; these changes touch symbols compute_gdpo_advantages, reward_i,
reward_grouped, reward_weights, and combined_advantages.
In @tests/e2e/multigpu/solo/test_gdpo.py:
- Around line 37-67: The test writes the reward module to the CWD using
open(f"rewards_gdpo_{suffix}.py", ...) which can leak files; change the write
target to the provided temp_dir (e.g., os.path.join(temp_dir,
f"rewards_gdpo_{suffix}.py")) and ensure the test can import the module by
either adding temp_dir to sys.path before import or by loading it with importlib
(so the functions like format_reward, correctness_reward, safety_reward,
single_reward and oai_gsm8k_transform remain importable); update the open call
and any subsequent import logic to use that temp_dir path.
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (8)
src/axolotl/core/builders/rl.pysrc/axolotl/core/trainers/gdpo/__init__.pysrc/axolotl/core/trainers/gdpo/args.pysrc/axolotl/core/trainers/gdpo/trainer.pysrc/axolotl/utils/data/rl.pysrc/axolotl/utils/schemas/enums.pysrc/axolotl/utils/schemas/trl.pytests/e2e/multigpu/solo/test_gdpo.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-07-22T08:27:00.129Z
Learnt from: NanoCode012
Repo: axolotl-ai-cloud/axolotl PR: 2952
File: src/axolotl/common/datasets.py:125-125
Timestamp: 2025-07-22T08:27:00.129Z
Learning: GRPO (cfg.rl == "grpo" or cfg.rl is RLType.GRPO) should be excluded from dataset label checking during preprocessing, while other RL methods should continue to have this functionality available.
Applied to files:
src/axolotl/utils/data/rl.pysrc/axolotl/utils/schemas/enums.pysrc/axolotl/core/builders/rl.py
🧬 Code graph analysis (4)
src/axolotl/utils/data/rl.py (1)
src/axolotl/utils/schemas/enums.py (1)
RLType(25-34)
src/axolotl/core/trainers/gdpo/args.py (1)
src/axolotl/core/training_args_base.py (1)
AxolotlTrainingMixins(12-265)
src/axolotl/core/trainers/gdpo/__init__.py (5)
src/axolotl/core/trainers/gdpo/args.py (1)
AxolotlGDPOConfig(16-58)src/axolotl/core/trainers/gdpo/trainer.py (2)
AxolotlGDPOSequenceParallelTrainer(536-974)AxolotlGDPOTrainer(156-533)src/axolotl/utils/dict.py (1)
DictDefault(6-38)src/axolotl/utils/schemas/trl.py (1)
TRLConfig(8-211)src/axolotl/utils/schemas/vllm.py (1)
VllmConfig(8-59)
tests/e2e/multigpu/solo/test_gdpo.py (4)
src/axolotl/utils/dict.py (1)
DictDefault(6-38)tests/e2e/multigpu/solo/test_grpo.py (2)
recursive_kill(130-141)start_vllm(25-127)tests/e2e/utils.py (1)
require_vllm(117-127)src/axolotl/telemetry/runtime_metrics.py (1)
to_dict(75-109)
🪛 GitHub Actions: lint
src/axolotl/core/trainers/gdpo/trainer.py
[error] 212-212: Mypy: Cannot determine type of "_last_loaded_step" [has-type]
[error] 536-536: Mypy: Cannot determine type of "_last_loaded_step" in base class (misc)
[error] 536-536: Mypy: Cannot determine type of "_last_loaded_step" in base class "AxolotlGDPOTrainer" [misc]
🪛 Ruff (0.14.10)
src/axolotl/core/trainers/gdpo/trainer.py
173-173: Mutable class attributes should be annotated with typing.ClassVar
(RUF012)
546-546: Mutable class attributes should be annotated with typing.ClassVar
(RUF012)
src/axolotl/core/trainers/gdpo/__init__.py
199-201: Avoid specifying long messages outside the exception class
(TRY003)
202-202: Consider moving this statement to an else block
(TRY300)
213-213: Consider moving this statement to an else block
(TRY300)
215-217: Avoid specifying long messages outside the exception class
(TRY003)
239-241: Prefer TypeError exception for invalid type
(TRY004)
239-241: Avoid specifying long messages outside the exception class
(TRY003)
243-243: Consider moving this statement to an else block
(TRY300)
246-246: Avoid specifying long messages outside the exception class
(TRY003)
tests/e2e/multigpu/solo/test_gdpo.py
73-73: Standard pseudo-random generators are not suitable for cryptographic purposes
(S311)
150-150: Possible binding to all interfaces
(S104)
177-177: Standard pseudo-random generators are not suitable for cryptographic purposes
(S311)
250-250: Possible binding to all interfaces
(S104)
277-277: Standard pseudo-random generators are not suitable for cryptographic purposes
(S311)
347-347: Possible binding to all interfaces
(S104)
374-374: Standard pseudo-random generators are not suitable for cryptographic purposes
(S311)
440-440: Possible binding to all interfaces
(S104)
467-467: Standard pseudo-random generators are not suitable for cryptographic purposes
(S311)
539-539: Possible binding to all interfaces
(S104)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (10)
- GitHub Check: PyTest from Source Dist (3.11, 2.9.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.9.1)
- GitHub Check: PyTest (3.11, 2.9.1)
- GitHub Check: PyTest (3.11, 2.9.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.8.0)
- GitHub Check: PyTest (3.11, 2.8.0)
- GitHub Check: test-axolotl-multigpu (128, 12.8.1, 3.11, 2.9.1, fbgemm-gpu, 2, true)
- GitHub Check: test-axolotl-multigpu (128, 12.8.1, 3.11, 2.8.0, fbgemm-gpu, 2, true)
- GitHub Check: test-axolotl-multigpu (130, 13.0.0, 3.11, 2.9.1, fbgemm-gpu, 2, true)
- GitHub Check: preview
🔇 Additional comments (17)
src/axolotl/utils/data/rl.py (1)
176-177: LGTM!The change correctly extends the GRPO behavior to GDPO. Both RL types handle sequence length filtering differently (during generation), so returning
Trueto keep all samples at this stage is appropriate. This aligns with learned behavior that GRPO should be excluded from dataset label checking during preprocessing, and GDPO should follow the same pattern.src/axolotl/utils/schemas/enums.py (1)
29-29: LGTM!The new
GDPOenum member is correctly added and follows the established naming and ordering conventions.src/axolotl/core/builders/rl.py (3)
14-14: LGTM!Import is correctly added for the new GDPOStrategy.
64-70: LGTM!The GDPO trainer selection correctly mirrors the GRPO implementation pattern, using the strategy class to get trainer class, args, and kwargs.
172-175: LGTM!The GDPO training arguments configuration correctly mirrors the GRPO pattern.
src/axolotl/utils/schemas/trl.py (1)
182-210: LGTM!The GDPO-specific configuration fields are well-defined with clear descriptions and sensible defaults. The fields align correctly with
AxolotlGDPOConfigin the args module.tests/e2e/multigpu/solo/test_gdpo.py (6)
23-23: Entire test class is skipped.The class-level skip decorator means none of these tests will run. If GDPO support is being merged, consider enabling at least some basic tests or documenting when these will be un-skipped.
69-172: Test structure looks good.The
test_gdpo_multi_reward_loratest properly:
- Uses parameterization for GPU counts
- Sets up vLLM server with appropriate cleanup in
finallyblock- Tests GDPO-specific configuration options (
gdpo_decoupled_norm,gdpo_batch_norm, etc.)- Uses multiple reward functions as GDPO is designed for
The static analysis warnings about
random.randintand0.0.0.0binding are acceptable in this test context (non-cryptographic use and local test server).
174-272: Test coverage for three reward functions with batch norm.Good coverage of the
gdpo_batch_norm: Trueconfiguration path.
274-369: Single reward fallback test.Good test for verifying GDPO gracefully handles single-reward scenarios (falling back to GRPO-like behavior).
371-462: Full fine-tuning test.Good coverage of the non-adapter GDPO path.
464-561: Sequence parallel test.Good coverage of
context_parallel_sizewith GDPO. However, the vLLM server is started on GPU 1 (CUDA_VISIBLE_DEVICES": "1") while training runs with 2 processes. Ensure this GPU allocation is intentional and doesn't cause resource conflicts.src/axolotl/core/trainers/gdpo/__init__.py (3)
25-38: LGTM!The trainer class and training args class selectors are correctly implemented, following the established strategy pattern.
176-217: Reward function loading logic is sound but could be clearer.The fallback from importable function → local directory → HuggingFace model is reasonable. Consider the static analysis suggestions for cleaner exception handling, though the current implementation is functional.
Minor: Lines 198-201 check parameter count but don't validate the expected signature (
prompts,completions). This is acceptable as runtime will catch mismatches.
219-246: LGTM!The rollout function loading follows a clean pattern with proper error handling.
src/axolotl/core/trainers/gdpo/args.py (1)
1-58: LGTM!The
AxolotlGDPOConfigdataclass is well-structured with:
- Clear module docstring explaining GDPO's purpose
- Comprehensive class docstring with attribute descriptions
- Field defaults matching the
TRLConfigschema- Helpful metadata for each GDPO-specific parameter
The multiple inheritance order (
AxolotlTrainingMixins, GRPOConfig) is consistent across all Axolotl training configs in the codebase (ORPO, KTO, CPO, DPO, GRPO, PRM, Reward).src/axolotl/core/trainers/gdpo/trainer.py (1)
624-679: The context-parallel vLLM indexing and slicing in GDPO exactly matches GRPO's sequence-parallel behavior. Both implementations use identical logic for:
ordered_set_of_promptsconstruction withnum_generationsandcontext_parallel_sizeinterleavingcompletion_idsslicing viasp_group_idandsp_group_startcalculationadvantagesslicing with the same SP group offset logicThe code is correct and requires no changes. Consider adding inline comments (similar to GRPO) to clarify the intent of the slicing operations for future maintainers.
Codecov Report❌ Patch coverage is 📢 Thoughts on this report? Let us know! |
| """Extend the base GRPOTrainer for axolotl helpers""" | ||
|
|
||
| _tag_names = ["trl", "grpo", "axolotl"] | ||
| _last_loaded_step: int |
| type: rewards.oai_gsm8k_transform | ||
| ``` | ||
|
|
||
| You can also use GRPO with explicit aggregation control: |
There was a problem hiding this comment.
| You can also use GRPO with explicit aggregation control: | |
| You can also use GDPO with explicit aggregation control: |
typo?
Edit: Wait, I noticed this part uses grpo, so unsure if intended.
winglian
left a comment
There was a problem hiding this comment.
requirements.txt needs to be updated to trl==0.27.0 in order for the test to actually work I think.
| "description": "Path to custom rollout function. Must be importable from current dir." | ||
| }, | ||
| ) | ||
| multi_objective_aggregation: ( |
There was a problem hiding this comment.
in RLValidationMixin we should add a check that if rl: gdpo is set and multi_objective_aggregation is set incorrectly, it should error
| from tests.e2e.utils import require_vllm | ||
|
|
||
|
|
||
| @pytest.mark.skip(reason="flaky vllm tests in modal") |
There was a problem hiding this comment.
did we test that these are actually flaky in modal? or is this just copy paste from the GRPO test suite?
winglian
left a comment
There was a problem hiding this comment.
good to go. We can re-enable the vllm+g[rd]po tests after the vllm upgrade pr lands
used trl
multi_objective_aggregation: ( Literal["sum_then_normalize", "normalize_then_sum"]huggingface/trl#4785✏️ Tip: You can customize this high-level summary in your review settings.