[RLlib] Fix DQN RLModule forward methods such that they can handle dict spaces#60451
Conversation
There was a problem hiding this comment.
Code Review
This pull request addresses an issue in the DQN RLModule where dict spaces were not properly handled in the forward method. The changes include modifications to the forward_train method in default_dqn_torch_rl_module.py to correctly process dict observation spaces and the addition of a new test file test_dqn_rl_module.py to verify the fix and ensure compatibility with different observation spaces and configurations.
| class DictFlattenEncoder(nn.Module): | ||
| def __init__(self, obs_space, output_dim=64): | ||
| super().__init__() | ||
| total_dim = sum( |
There was a problem hiding this comment.
For nested composite spaces this won't work, Gymnasium has flatdim for this but I'm worried that we don't use Gymnasium's flatten function so the resulting spaces might be mismatched.
Therefore, this is more a note for the future
There was a problem hiding this comment.
Yea. DictFlattenEncoder is just for testing here. Flattening all elements of the dict can be counterproductive if these are, for example, images. I think for any more involved obs space we should not just autogenerate some model under the hood.
There was a problem hiding this comment.
I agree that for models that are too complex, we shouldn't allow autogeneration
There was a problem hiding this comment.
Makes sense. Let's not do it for this test file though.
| def forward(self, inputs): | ||
| obs = inputs[Columns.OBS] | ||
| flat_obs = torch.cat( | ||
| [obs[k].reshape(obs[k].shape[0], -1) for k in sorted(obs.keys())], |
There was a problem hiding this comment.
Does this work for discrete inputs?
There was a problem hiding this comment.
Why not? Should be 1 hot encoded, mhh?
There was a problem hiding this comment.
Yes, thats a good point that rationally for discrete inputs they should be one-hotted and we wouldn't get this problem but what happens if a discrete is used? Can we provide a helpful error?
Signed-off-by: Artur Niederfahrenhorst <attaismyname@googlemail.com>
…ct spaces (ray-project#60451) ## Description We don't natively build encoders for dict spaces and so we don't account for them in the forward method of the DQN rlm. This is an issue because users may still want to use encoder configs for dictionaries or they may want to override DQNRLModule.build_encoder etc. This PR makes a fix and introduces testing for different types of forward passes, observations spaces and configurations for the DQN RL Module. --------- Signed-off-by: Artur Niederfahrenhorst <attaismyname@googlemail.com> Signed-off-by: Adel Nour <ans9868@nyu.edu>
Description
We don't natively build encoders for dict spaces and so we don't account for them in the forward method of the DQN rlm.
This is an issue because users may still want to use encoder configs for dictionaries or they may want to override DQNRLModule.build_encoder etc.
This PR makes a fix and introduces testing for different types of forward passes, observations spaces and configurations for the DQN RL Module.