[RLlib] Fix DQN RLModule forward methods such that they can handle dict spaces#60451

Merged
ArturNiederfahrenhorst merged 5 commits into
ray-project:masterfrom
ArturNiederfahrenhorst:fixdqntorchrlm
Jan 27, 2026
@ArturNiederfahrenhorst

Contributor

Description

RLlib does not natively build encoders for dict observation spaces, so the forward methods of the DQN RLModule do not account for them.
This is a problem because users may still want to use encoder configs for dict spaces, or they may override methods such as DQNRLModule.build_encoder to supply their own encoder.

This PR fixes that and adds tests covering different types of forward passes, observation spaces, and configurations for the DQN RLModule.

@ArturNiederfahrenhorst ArturNiederfahrenhorst requested a review from a team as a code owner January 23, 2026 11:01

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request addresses an issue in the DQN RLModule where dict spaces were not properly handled in the forward method. The changes include modifications to the forward_train method in default_dqn_torch_rl_module.py to correctly process dict observation spaces and the addition of a new test file test_dqn_rl_module.py to verify the fix and ensure compatibility with different observation spaces and configurations.

@pseudo-rnd-thoughts pseudo-rnd-thoughts left a comment

Member

Overall looks good. Minor changes requested, plus a question about working with Discrete observations.

class DictFlattenEncoder(nn.Module):
    def __init__(self, obs_space, output_dim=64):
        super().__init__()
        total_dim = sum(

Member

This won't work for nested composite spaces. Gymnasium has flatdim for this, but I'm worried that because we don't use Gymnasium's flatten function, the resulting spaces might be mismatched.
So this is more of a note for the future.

@ArturNiederfahrenhorst ArturNiederfahrenhorst Jan 25, 2026

Contributor Author

Yea. DictFlattenEncoder is just for testing here. Flattening all elements of the dict can be counterproductive if these are, for example, images. I think for any more involved obs space we should not just autogenerate some model under the hood.

Member

I agree that for models that are too complex, we shouldn't allow autogeneration

Contributor Author

Makes sense. Let's not do it for this test file though.
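
To illustrate the nested-space concern raised in this thread, here is a minimal sketch of a recursive flat-dimension count. It does not use Gymnasium or RLlib: the `SpaceSpec` type, the `flat_dim` helper, and the example space are hypothetical stand-ins, with a leaf represented by its shape tuple and a composite by a plain mapping. A single-level sum over leaf shapes would fail on the `"robot"` entry below, whereas the recursive version descends into it. In real code, Gymnasium's `flatdim` utility mentioned above plays this role.

```python
from math import prod
from typing import Mapping, Tuple, Union

# Hypothetical stand-in for an observation space: a leaf is a shape tuple,
# a composite is a (possibly nested) mapping of names to sub-spaces.
SpaceSpec = Union[Tuple[int, ...], Mapping[str, "SpaceSpec"]]


def flat_dim(space: SpaceSpec) -> int:
    """Return the number of scalars in a (possibly nested) space spec."""
    if isinstance(space, Mapping):
        # Recurse into composite sub-spaces instead of assuming leaves.
        return sum(flat_dim(sub) for sub in space.values())
    # prod(()) == 1, so a scalar leaf with shape () counts as one element.
    return prod(space)


nested_space = {
    "camera": (3, 4, 4),
    "robot": {"joints": (7,), "gripper": (1,)},
}
total = flat_dim(nested_space)  # 48 + 7 + 1
```

Whatever computes the flat size has to traverse the space in the same way as whatever flattens the observations; otherwise the encoder's input size and the actual flattened observation can disagree, which is the mismatch the reviewer is worried about.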

    def forward(self, inputs):
        obs = inputs[Columns.OBS]
        flat_obs = torch.cat(
            [obs[k].reshape(obs[k].shape[0], -1) for k in sorted(obs.keys())],

Member

Does this work for discrete inputs?

Contributor Author

Why not? They should be one-hot encoded, no?

Member

Yes, that's a good point: discrete inputs should be one-hot encoded, in which case we wouldn't hit this problem. But what happens if a raw discrete value is used? Can we provide a helpful error?
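
One possible shape for such a check, as a rough sketch only: the code below is plain Python on a single unbatched observation, not the PR's torch-based test encoder, and the `Spec` tuples, `encode_leaf`, and `flatten_obs` are hypothetical names invented for illustration. It one-hot encodes discrete entries while flattening a dict observation in sorted-key order, and raises an error that names the offending key when a value does not match its declared kind.

```python
from typing import Any, List, Mapping, Tuple

# Hypothetical per-key description: ("box", size) or ("discrete", n).
Spec = Tuple[str, int]


def encode_leaf(key: str, spec: Spec, value: Any) -> List[float]:
    """Turn one dict entry into a flat list of floats, or fail loudly."""
    kind, size = spec
    if kind == "discrete":
        if isinstance(value, bool) or not isinstance(value, int):
            raise TypeError(
                f"Key {key!r} is discrete with {size} values and expects an "
                f"integer index, got {type(value).__name__}."
            )
        if not 0 <= value < size:
            raise ValueError(
                f"Key {key!r}: index {value} is outside [0, {size})."
            )
        # One-hot encode so the index is not treated as a magnitude.
        return [1.0 if i == value else 0.0 for i in range(size)]
    if kind == "box":
        values = [float(v) for v in value]
        if len(values) != size:
            raise ValueError(
                f"Key {key!r}: expected {size} values, got {len(values)}."
            )
        return values
    raise ValueError(f"Key {key!r}: unsupported kind {kind!r}.")


def flatten_obs(specs: Mapping[str, Spec], obs: Mapping[str, Any]) -> List[float]:
    """Concatenate all entries in sorted-key order, as in the snippet above."""
    flat: List[float] = []
    for key in sorted(specs):
        if key not in obs:
            raise KeyError(f"Observation is missing key {key!r}.")
        flat.extend(encode_leaf(key, specs[key], obs[key]))
    return flat


specs = {"pos": ("box", 2), "mode": ("discrete", 3)}
flat = flatten_obs(specs, {"pos": [0.5, -1.0], "mode": 2})
```

Because keys are visited in sorted order, `"mode"` is encoded before `"pos"`, so `flat` starts with the three one-hot entries followed by the two position values.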

Comment thread rllib/algorithms/dqn/tests/test_dqn_rl_module.py
Comment thread rllib/algorithms/dqn/tests/test_dqn_rl_module.py Outdated
@ray-gardener ray-gardener Bot added the rllib RLlib related issues label Jan 23, 2026
Signed-off-by: Artur Niederfahrenhorst <attaismyname@googlemail.com>

@pseudo-rnd-thoughts pseudo-rnd-thoughts left a comment

Member

LGTM

@pseudo-rnd-thoughts pseudo-rnd-thoughts added rllib-models An issue related to RLlib (default or custom) Models. go add ONLY when ready to merge, run all tests labels Jan 26, 2026
@ArturNiederfahrenhorst ArturNiederfahrenhorst merged commit d406677 into ray-project:master Jan 27, 2026
7 checks passed
ans9868 pushed a commit to ans9868/ray that referenced this pull request Feb 18, 2026
…ct spaces (ray-project#60451)

Signed-off-by: Artur Niederfahrenhorst <attaismyname@googlemail.com>
Signed-off-by: Adel Nour <ans9868@nyu.edu>