Skip to content

feat: Add GDPO Support#3353

Merged
winglian merged 22 commits into
axolotl-ai-cloud:mainfrom
ved1beta:gdpo-ved
Jan 21, 2026
Merged

feat: Add GDPO Support#3353
winglian merged 22 commits into
axolotl-ai-cloud:mainfrom
ved1beta:gdpo-ved

Conversation

@ved1beta

@ved1beta ved1beta commented Jan 11, 2026

Copy link
Copy Markdown
Member

used trl multi_objective_aggregation: ( Literal["sum_then_normalize", "normalize_then_sum"] huggingface/trl#4785

  • Tests
    • Added comprehensive end-to-end test suite for GDPO across multi-GPU setups with diverse configurations.

✏️ Tip: You can customize this high-level summary in your review settings.

@coderabbitai

coderabbitai Bot commented Jan 11, 2026

Copy link
Copy Markdown
Contributor

Important

Review skipped

Auto incremental reviews are disabled on this repository.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

📝 Walkthrough

Walkthrough

This pull request introduces GDPO (Group Reward-Decoupled Normalization Policy Optimization) support to Axolotl's RLHF training framework. Changes include adding RLType.GDPO enum, creating GDPOStrategy for centralized configuration, defining AxolotlGDPOConfig with per-reward normalization parameters, implementing GDPO trainer classes with decoupled advantage computation, extending the builder to route GDPO configuration, updating data utilities, and adding comprehensive end-to-end tests covering multi-GPU scenarios with varied reward configurations and parallelism strategies.

Changes

Cohort / File(s) Summary
Configuration & Enums
src/axolotl/utils/schemas/enums.py, src/axolotl/utils/schemas/trl.py
Added RLType.GDPO enum member; extended TRLConfig with four GDPO-specific fields: gdpo_decoupled_norm, gdpo_batch_norm, gdpo_epsilon, and gdpo_per_reward_scale with default values and help text.
GDPO Args & Strategy
src/axolotl/core/trainers/gdpo/args.py, src/axolotl/core/trainers/gdpo/__init__.py
Introduced AxolotlGDPOConfig dataclass extending AxolotlTrainingMixins and GRPOConfig with per-reward normalization and context parallelism parameters. Created GDPOStrategy class with methods for trainer class selection, training args assembly, reward function resolution with dynamic imports, and blocklist configuration.
GDPO Trainer Implementation
src/axolotl/core/trainers/gdpo/trainer.py
Implemented compute_gdpo_advantages function for decoupled per-reward normalization; added AxolotlGDPOTrainer extending AxolotlGRPOTrainer and AxolotlGDPOSequenceParallelTrainer extending both AxolotlGRPOSequenceParallelTrainer and AxolotlGDPOTrainer. Both handle multi-reward generation and scoring with per-reward computation, optional batch normalization, and extensive logging.
Builder Integration
src/axolotl/core/builders/rl.py
Extended RLHFTrainerBuilder._get_trainer_cls and _build_training_arguments to detect RLType.GDPO and delegate trainer class selection, argument extension, and kwargs updates to GDPOStrategy.
Data Utilities
src/axolotl/utils/data/rl.py
Updated _drop_long_sequences to treat GDPO identically to GRPO via membership check.
End-to-End Tests
tests/e2e/multigpu/solo/test_gdpo.py
Added TestGDPO class with five test scenarios: multi-reward LoRA across 1–2 GPUs, three-reward with batch normalization toggle, single-reward fallback to GRPO behavior, full fine-tuning variant, and sequence parallelism with context_parallel_size. Tests spawn VLLM servers and execute training via axolotl CLI.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~75 minutes

Suggested reviewers

  • winglian
🚥 Pre-merge checks | ✅ 2 | ❌ 1
❌ Failed checks (1 warning)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 66.67% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
Check name Status Explanation
Title check ✅ Passed The title 'feat: Add GDPO Support' directly and clearly summarizes the main change—introducing GDPO (Group Reward-Decoupled Normalization Policy Optimization) support to the codebase, which is confirmed by the file changes adding GDPOStrategy, trainers, config, and tests.
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 8

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
src/axolotl/core/builders/rl.py (1)

209-210: Add GDPO to the peft_config exclusion.

Line 209 excludes RLType.GRPO from passing peft_config to the trainer, but RLType.GDPO is not similarly excluded. Since GDPO inherits from GRPO's trainer and is tested with adapters, it should be excluded to prevent unexpected behavior.

-        if self.cfg.adapter and self.peft_config and self.cfg.rl is not RLType.GRPO:
+        if self.cfg.adapter and self.peft_config and self.cfg.rl not in {RLType.GRPO, RLType.GDPO}:
             trainer_kwargs["peft_config"] = self.peft_config
🤖 Fix all issues with AI agents
In @src/axolotl/core/trainers/gdpo/__init__.py:
- Around line 161-169: TRLConfig is missing the reward_processing_classes field
referenced by set_trainer_kwargs; add a typed optional field to the TRL config
schema (e.g., reward_processing_classes: list[str] | None = None) in the
TRLConfig/DictDefault definition so the attribute exists for validation and
docs, import typing primitives if needed, and include a short docstring/default;
also apply the same addition to the equivalent schema used by GRPO so
src/axolotl/core/trainers/grpo/__init__.py can safely rely on
cfg.trl.reward_processing_classes.
- Around line 48-69: Move the vllm_cfg = cfg.vllm assignment inside the if
trl.use_vllm: block and replace any uses of trl.vllm.host and trl.vllm.port with
vllm_cfg.host and vllm_cfg.port in the GDPO initializer (references:
gdpo_args_kwargs, trl.use_vllm, trl.vllm_mode, trl.vllm_enable_sleep_mode) so we
don't access trl.vllm when vllm is disabled; apply the same change in the
GRPOStrategy implementation to avoid the AttributeError.

In @src/axolotl/core/trainers/gdpo/trainer.py:
- Line 173: Ruff RUF012 flags the module-level/class variable _tag_names;
annotate it as a ClassVar to indicate it's not an instance attribute: add "from
typing import ClassVar" (or "from typing import ClassVar, List" if you prefer
explicit List) and change "_tag_names = [\"trl\", \"gdpo\", \"axolotl\"]" to
"_tag_names: ClassVar[list[str]] = [\"trl\", \"gdpo\", \"axolotl\"]" (apply the
same change to the other occurrence of _tag_names at the second location).
- Around line 440-457: The std() call used to compute std_grouped_rewards can
return NaN when group size self.num_generations == 1 because torch.std defaults
to unbiased=True; update the std computation to use an unbiased=False (i.e.,
std(dim=1, unbiased=False)) to yield 0 instead of NaN, and apply this same
change in both the GRPO fallback block (where std_grouped_rewards is computed
from rewards.view(-1, self.num_generations).std(...)) and the sequence-parallel
fallback block so that advantages scaling (self.args.scale_rewards) is safe when
num_generations == 1.
- Around line 53-154: compute_gdpo_advantages assumes num_samples %
num_generations == 0 and that tensors are already on the requested device; fix
by (1) at the top of compute_gdpo_advantages, if device is provided call
rewards_per_func = rewards_per_func.to(device) and reward_weights =
reward_weights.to(device) so all ops use the same device, and create
combined_advantages with torch.zeros(num_samples, device=device); (2) before
doing reward_grouped = reward_i.view(-1, num_generations) validate that
num_samples % num_generations == 0 (or compute num_prompts = num_samples //
num_generations) and raise a clear ValueError if not divisible, then reshape
with reward_grouped = reward_i.view(num_prompts, num_generations) to avoid
silent crashes; these changes touch symbols compute_gdpo_advantages, reward_i,
reward_grouped, reward_weights, and combined_advantages.

In @tests/e2e/multigpu/solo/test_gdpo.py:
- Around line 37-67: The test writes the reward module to the CWD using
open(f"rewards_gdpo_{suffix}.py", ...) which can leak files; change the write
target to the provided temp_dir (e.g., os.path.join(temp_dir,
f"rewards_gdpo_{suffix}.py")) and ensure the test can import the module by
either adding temp_dir to sys.path before import or by loading it with importlib
(so the functions like format_reward, correctness_reward, safety_reward,
single_reward and oai_gsm8k_transform remain importable); update the open call
and any subsequent import logic to use that temp_dir path.
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between e7f0d4b and 64f33bd.

📒 Files selected for processing (8)
  • src/axolotl/core/builders/rl.py
  • src/axolotl/core/trainers/gdpo/__init__.py
  • src/axolotl/core/trainers/gdpo/args.py
  • src/axolotl/core/trainers/gdpo/trainer.py
  • src/axolotl/utils/data/rl.py
  • src/axolotl/utils/schemas/enums.py
  • src/axolotl/utils/schemas/trl.py
  • tests/e2e/multigpu/solo/test_gdpo.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-07-22T08:27:00.129Z
Learnt from: NanoCode012
Repo: axolotl-ai-cloud/axolotl PR: 2952
File: src/axolotl/common/datasets.py:125-125
Timestamp: 2025-07-22T08:27:00.129Z
Learning: GRPO (cfg.rl == "grpo" or cfg.rl is RLType.GRPO) should be excluded from dataset label checking during preprocessing, while other RL methods should continue to have this functionality available.

Applied to files:

  • src/axolotl/utils/data/rl.py
  • src/axolotl/utils/schemas/enums.py
  • src/axolotl/core/builders/rl.py
🧬 Code graph analysis (4)
src/axolotl/utils/data/rl.py (1)
src/axolotl/utils/schemas/enums.py (1)
  • RLType (25-34)
src/axolotl/core/trainers/gdpo/args.py (1)
src/axolotl/core/training_args_base.py (1)
  • AxolotlTrainingMixins (12-265)
src/axolotl/core/trainers/gdpo/__init__.py (5)
src/axolotl/core/trainers/gdpo/args.py (1)
  • AxolotlGDPOConfig (16-58)
src/axolotl/core/trainers/gdpo/trainer.py (2)
  • AxolotlGDPOSequenceParallelTrainer (536-974)
  • AxolotlGDPOTrainer (156-533)
src/axolotl/utils/dict.py (1)
  • DictDefault (6-38)
src/axolotl/utils/schemas/trl.py (1)
  • TRLConfig (8-211)
src/axolotl/utils/schemas/vllm.py (1)
  • VllmConfig (8-59)
tests/e2e/multigpu/solo/test_gdpo.py (4)
src/axolotl/utils/dict.py (1)
  • DictDefault (6-38)
tests/e2e/multigpu/solo/test_grpo.py (2)
  • recursive_kill (130-141)
  • start_vllm (25-127)
tests/e2e/utils.py (1)
  • require_vllm (117-127)
src/axolotl/telemetry/runtime_metrics.py (1)
  • to_dict (75-109)
🪛 GitHub Actions: lint
src/axolotl/core/trainers/gdpo/trainer.py

[error] 212-212: Mypy: Cannot determine type of "_last_loaded_step" [has-type]


[error] 536-536: Mypy: Cannot determine type of "_last_loaded_step" in base class (misc)


[error] 536-536: Mypy: Cannot determine type of "_last_loaded_step" in base class "AxolotlGDPOTrainer" [misc]

🪛 Ruff (0.14.10)
src/axolotl/core/trainers/gdpo/trainer.py

173-173: Mutable class attributes should be annotated with typing.ClassVar

(RUF012)


546-546: Mutable class attributes should be annotated with typing.ClassVar

(RUF012)

src/axolotl/core/trainers/gdpo/__init__.py

199-201: Avoid specifying long messages outside the exception class

(TRY003)


202-202: Consider moving this statement to an else block

(TRY300)


213-213: Consider moving this statement to an else block

(TRY300)


215-217: Avoid specifying long messages outside the exception class

(TRY003)


239-241: Prefer TypeError exception for invalid type

(TRY004)


239-241: Avoid specifying long messages outside the exception class

(TRY003)


243-243: Consider moving this statement to an else block

(TRY300)


246-246: Avoid specifying long messages outside the exception class

(TRY003)

tests/e2e/multigpu/solo/test_gdpo.py

73-73: Standard pseudo-random generators are not suitable for cryptographic purposes

(S311)


150-150: Possible binding to all interfaces

(S104)


177-177: Standard pseudo-random generators are not suitable for cryptographic purposes

(S311)


250-250: Possible binding to all interfaces

(S104)


277-277: Standard pseudo-random generators are not suitable for cryptographic purposes

(S311)


347-347: Possible binding to all interfaces

(S104)


374-374: Standard pseudo-random generators are not suitable for cryptographic purposes

(S311)


440-440: Possible binding to all interfaces

(S104)


467-467: Standard pseudo-random generators are not suitable for cryptographic purposes

(S311)


539-539: Possible binding to all interfaces

(S104)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (10)
  • GitHub Check: PyTest from Source Dist (3.11, 2.9.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.9.1)
  • GitHub Check: PyTest (3.11, 2.9.1)
  • GitHub Check: PyTest (3.11, 2.9.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.8.0)
  • GitHub Check: PyTest (3.11, 2.8.0)
  • GitHub Check: test-axolotl-multigpu (128, 12.8.1, 3.11, 2.9.1, fbgemm-gpu, 2, true)
  • GitHub Check: test-axolotl-multigpu (128, 12.8.1, 3.11, 2.8.0, fbgemm-gpu, 2, true)
  • GitHub Check: test-axolotl-multigpu (130, 13.0.0, 3.11, 2.9.1, fbgemm-gpu, 2, true)
  • GitHub Check: preview
🔇 Additional comments (17)
src/axolotl/utils/data/rl.py (1)

176-177: LGTM!

The change correctly extends the GRPO behavior to GDPO. Both RL types handle sequence length filtering differently (during generation), so returning True to keep all samples at this stage is appropriate. This aligns with learned behavior that GRPO should be excluded from dataset label checking during preprocessing, and GDPO should follow the same pattern.

src/axolotl/utils/schemas/enums.py (1)

29-29: LGTM!

The new GDPO enum member is correctly added and follows the established naming and ordering conventions.

src/axolotl/core/builders/rl.py (3)

14-14: LGTM!

Import is correctly added for the new GDPOStrategy.


64-70: LGTM!

The GDPO trainer selection correctly mirrors the GRPO implementation pattern, using the strategy class to get trainer class, args, and kwargs.


172-175: LGTM!

The GDPO training arguments configuration correctly mirrors the GRPO pattern.

src/axolotl/utils/schemas/trl.py (1)

182-210: LGTM!

The GDPO-specific configuration fields are well-defined with clear descriptions and sensible defaults. The fields align correctly with AxolotlGDPOConfig in the args module.

tests/e2e/multigpu/solo/test_gdpo.py (6)

23-23: Entire test class is skipped.

The class-level skip decorator means none of these tests will run. If GDPO support is being merged, consider enabling at least some basic tests or documenting when these will be un-skipped.


69-172: Test structure looks good.

The test_gdpo_multi_reward_lora test properly:

  • Uses parameterization for GPU counts
  • Sets up vLLM server with appropriate cleanup in finally block
  • Tests GDPO-specific configuration options (gdpo_decoupled_norm, gdpo_batch_norm, etc.)
  • Uses multiple reward functions as GDPO is designed for

The static analysis warnings about random.randint and 0.0.0.0 binding are acceptable in this test context (non-cryptographic use and local test server).


174-272: Test coverage for three reward functions with batch norm.

Good coverage of the gdpo_batch_norm: True configuration path.


274-369: Single reward fallback test.

Good test for verifying GDPO gracefully handles single-reward scenarios (falling back to GRPO-like behavior).


371-462: Full fine-tuning test.

Good coverage of the non-adapter GDPO path.


464-561: Sequence parallel test.

Good coverage of context_parallel_size with GDPO. However, the vLLM server is started on GPU 1 (CUDA_VISIBLE_DEVICES": "1") while training runs with 2 processes. Ensure this GPU allocation is intentional and doesn't cause resource conflicts.

src/axolotl/core/trainers/gdpo/__init__.py (3)

25-38: LGTM!

The trainer class and training args class selectors are correctly implemented, following the established strategy pattern.


176-217: Reward function loading logic is sound but could be clearer.

The fallback from importable function → local directory → HuggingFace model is reasonable. Consider the static analysis suggestions for cleaner exception handling, though the current implementation is functional.

Minor: Lines 198-201 check parameter count but don't validate the expected signature (prompts, completions). This is acceptable as runtime will catch mismatches.


219-246: LGTM!

The rollout function loading follows a clean pattern with proper error handling.

src/axolotl/core/trainers/gdpo/args.py (1)

1-58: LGTM!

The AxolotlGDPOConfig dataclass is well-structured with:

  • Clear module docstring explaining GDPO's purpose
  • Comprehensive class docstring with attribute descriptions
  • Field defaults matching the TRLConfig schema
  • Helpful metadata for each GDPO-specific parameter

The multiple inheritance order (AxolotlTrainingMixins, GRPOConfig) is consistent across all Axolotl training configs in the codebase (ORPO, KTO, CPO, DPO, GRPO, PRM, Reward).

src/axolotl/core/trainers/gdpo/trainer.py (1)

624-679: The context-parallel vLLM indexing and slicing in GDPO exactly matches GRPO's sequence-parallel behavior. Both implementations use identical logic for:

  • ordered_set_of_prompts construction with num_generations and context_parallel_size interleaving
  • completion_ids slicing via sp_group_id and sp_group_start calculation
  • advantages slicing with the same SP group offset logic

The code is correct and requires no changes. Consider adding inline comments (similar to GRPO) to clarify the intent of the slicing operations for future maintainers.

Comment thread src/axolotl/core/trainers/gdpo/__init__.py Outdated
Comment thread src/axolotl/core/trainers/gdpo/__init__.py Outdated
Comment thread src/axolotl/core/trainers/gdpo/trainer.py Outdated
Comment thread src/axolotl/core/trainers/gdpo/trainer.py Outdated
Comment thread src/axolotl/core/trainers/gdpo/trainer.py Outdated
Comment thread src/axolotl/core/trainers/gdpo/trainer.py Outdated
Comment thread src/axolotl/core/trainers/gdpo/trainer.py Outdated
Comment thread tests/e2e/multigpu/solo/test_gdpo.py
@codecov

codecov Bot commented Jan 11, 2026

Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 75.00000% with 4 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
src/axolotl/core/builders/rl.py 80.00% 1 Missing ⚠️
src/axolotl/core/trainers/grpo/__init__.py 50.00% 1 Missing ⚠️
src/axolotl/utils/data/rl.py 0.00% 1 Missing ⚠️
src/axolotl/utils/schemas/validation.py 83.33% 1 Missing ⚠️

📢 Thoughts on this report? Let us know!

@ved1beta ved1beta marked this pull request as draft January 12, 2026 11:07
@ved1beta ved1beta marked this pull request as ready for review January 13, 2026 05:50
"""Extend the base GRPOTrainer for axolotl helpers"""

_tag_names = ["trl", "grpo", "axolotl"]
_last_loaded_step: int

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's this for?

Comment thread docs/rlhf.qmd
type: rewards.oai_gsm8k_transform
```

You can also use GRPO with explicit aggregation control:

@NanoCode012 NanoCode012 Jan 19, 2026

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
You can also use GRPO with explicit aggregation control:
You can also use GDPO with explicit aggregation control:

typo?

Edit: Wait, I noticed this part uses grpo, so unsure if intended.

@winglian winglian left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

requirements.txt needs to be updated to trl==0.27.0 in order for the test to actually work I think.

"description": "Path to custom rollout function. Must be importable from current dir."
},
)
multi_objective_aggregation: (

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in RLValidationMixin we should add a check that if rl: gdpo is set and multi_objective_aggregation is set incorrectly, it should error

from tests.e2e.utils import require_vllm


@pytest.mark.skip(reason="flaky vllm tests in modal")

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

did we test that these are actually flaky in modal? or is this just copy paste from the GRPO test suite?

@winglian winglian left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good to go. We can re-enable the vllm+g[rd]po tests after the vllm upgrade pr lands

@winglian winglian merged commit d0d26d5 into axolotl-ai-cloud:main Jan 21, 2026
22 of 23 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants