Problem Description:
There appears to be a mismatch between the declared observation space and the observations actually returned by the Hanabi environment. After calling the `reset()` method, I received observations of shape `(658,)`, but the observation space is defined as `Discrete(658)`, which describes a single integer in the range [0, 658) and therefore has the scalar shape `()`.

Given this discrepancy, would the `MultiDiscrete` space class be a better fit for the observation space, since each observation consists of multiple discrete elements?
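To make the shape difference concrete, here is a minimal, self-contained sketch. `DiscreteSketch` and `MultiDiscreteSketch` are hypothetical stand-ins written for illustration, not JaxMARL's own `Discrete`/`MultiDiscrete` classes, and the sketch assumes each of the 658 observation entries is a binary (0/1) feature.

```python
import jax
import jax.numpy as jnp


class DiscreteSketch:
    """Hypothetical stand-in for Discrete(n): a single integer in [0, n)."""

    def __init__(self, n: int):
        self.n = n
        self.shape = ()  # one scalar, regardless of n

    def sample(self, rng):
        return jax.random.randint(rng, self.shape, 0, self.n)

    def contains(self, x) -> bool:
        x = jnp.asarray(x)
        return x.shape == self.shape and bool((x >= 0) & (x < self.n))


class MultiDiscreteSketch:
    """Hypothetical stand-in for MultiDiscrete: component i lies in [0, nvec[i])."""

    def __init__(self, nvec):
        self.nvec = jnp.asarray(nvec)
        self.shape = self.nvec.shape  # one entry per component

    def sample(self, rng):
        # maxval broadcasts against shape, so each component gets its own bound
        return jax.random.randint(rng, self.shape, 0, self.nvec)

    def contains(self, x) -> bool:
        x = jnp.asarray(x)
        return x.shape == self.shape and bool(jnp.all((x >= 0) & (x < self.nvec)))


# Stand-in for a binary Hanabi observation (assumption: entries are 0/1).
obs = jnp.zeros((658,), dtype=jnp.int32)
print(DiscreteSketch(658).contains(obs))             # False: shape (658,) != ()
print(MultiDiscreteSketch([2] * 658).contains(obs))  # True
```

With `nvec = [2] * 658`, every component is constrained to {0, 1}, so the declared shape `(658,)` matches what `reset()` returns.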
Code to reproduce:

```python
import jax
import jaxmarl

rng = jax.random.PRNGKey(0)
env = jaxmarl.make("hanabi")
obs, state = env.reset(rng)

# Discrete(658)
obs_space = env.observation_space("agent_0")

# Output: {'agent_0': (658,), 'agent_1': (658,)}
print(jax.tree.map(lambda x: x.shape, obs))
# Output: ()
print(obs_space.shape)
```
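A small check like the following could guard against this kind of mismatch in a test. `find_shape_mismatches` is a hypothetical helper, not part of JaxMARL; to keep the sketch self-contained it runs on stand-in arrays that mirror the shapes reported above instead of calling `jaxmarl`.

```python
import jax.numpy as jnp


def find_shape_mismatches(obs, space_shapes):
    """Return {agent: (obs_shape, space_shape)} for every agent whose
    observation shape differs from its declared observation-space shape."""
    mismatches = {}
    for agent, x in obs.items():
        actual = tuple(jnp.shape(x))
        expected = tuple(space_shapes[agent])
        if actual != expected:
            mismatches[agent] = (actual, expected)
    return mismatches


# Stand-in data mirroring the reported output (no jaxmarl call is made here).
# With a real env, space_shapes could be built as
# {a: env.observation_space(a).shape for a in obs}.
obs = {"agent_0": jnp.zeros((658,)), "agent_1": jnp.zeros((658,))}
print(find_shape_mismatches(obs, {"agent_0": (), "agent_1": ()}))
# {'agent_0': ((658,), ()), 'agent_1': ((658,), ())}
```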