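The `weak_type` flag of `jaxmarl.environments.smax.heuristic_enemy.HeuristicPolicyState.last_attacked_enemy` changes after `step`. It is `True` in the state returned by `reset` and `False` in the state returned by `step`. Because `jax.jit` caches traces by the abstract value of each input, and `weak_type` is part of that abstract value, any jitted function that takes the state as input is re-traced.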
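The underlying JAX behavior can be reproduced without jaxmarl. The minimal sketch below uses a Python-side counter to observe tracing, since Python side effects inside a jitted function run only while it is being traced. It assumes JAX's default 32-bit mode, so `jnp.asarray(0)` is `int32`. The names `f`, `weak`, `strong` and `trace_count` are illustrative only.

```python
import jax
import jax.numpy as jnp

trace_count = 0


@jax.jit
def f(x):
    # Python side effects run only while JAX traces f, not on cached calls.
    global trace_count
    trace_count += 1
    return x + 1


weak = jnp.asarray(0)                     # int32, weak_type=True (dtype inferred from a Python int)
strong = jnp.asarray(0, dtype=jnp.int32)  # int32, weak_type=False (explicit dtype)

assert weak.dtype == strong.dtype and weak.shape == strong.shape
print(weak.aval.weak_type, strong.aval.weak_type)  # True False

f(weak)    # first trace
f(strong)  # same shape and dtype, different weak_type: traced again
f(weak)    # cache hit
f(strong)  # cache hit
print(trace_count)  # 2
```

Shape and dtype match, yet `f` is traced twice. This is the same pattern the environment state shows below.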
Here is a minimal example of the bug.
```python
import jax
import jax.numpy as jnp
from jaxmarl import make
from jaxmarl.environments.smax import map_name_to_scenario

key = jax.random.PRNGKey(0)
env = make(
    "HeuristicEnemySMAX",
    scenario=map_name_to_scenario('3m'),
    see_enemy_actions=False,
)
obs, state_init = env.reset(key)
# Take one step with action 0 for every agent.
obs, state, reward, done, info = env.step(key, state_init, {
    agent: jnp.array(0, dtype=jnp.int32)
    for agent in env.agents
})

# Compare the abstract values (shape, dtype, weak_type) of both states.
print(jax.tree.map(lambda x: x.aval, state_init))
print(jax.tree.map(lambda x: x.aval, state))

@jax.jit
def f(x):
    # Runs only while JAX traces f, so each "tracing" line marks a (re)trace.
    print('tracing')

f(state_init)
f(state)
```
The output is:

```
State(state=State(unit_positions=ShapedArray(float32[6,2]), unit_alive=ShapedArray(bool[6]), unit_teams=ShapedArray(float32[6]), unit_health=ShapedArray(float32[6]), unit_types=ShapedArray(uint8[6]), unit_weapon_cooldowns=ShapedArray(float32[6]), prev_movement_actions=ShapedArray(float32[6,2]), prev_attack_actions=ShapedArray(int32[6]), time=ShapedArray(int32[], weak_type=True), terminal=ShapedArray(bool[])), enemy_policy_state=HeuristicPolicyState(default_target=ShapedArray(float32[3,2]), last_attacked_enemy=ShapedArray(int32[3], weak_type=True)))
State(state=State(unit_positions=ShapedArray(float32[6,2]), unit_alive=ShapedArray(bool[6]), unit_teams=ShapedArray(float32[6]), unit_health=ShapedArray(float32[6]), unit_types=ShapedArray(uint8[6]), unit_weapon_cooldowns=ShapedArray(float32[6]), prev_movement_actions=ShapedArray(float32[6,2]), prev_attack_actions=ShapedArray(int32[6]), time=ShapedArray(int32[], weak_type=True), terminal=ShapedArray(bool[])), enemy_policy_state=HeuristicPolicyState(default_target=ShapedArray(float32[3,2]), last_attacked_enemy=ShapedArray(int32[3])))
tracing
tracing
```

The only difference between the two states is `last_attacked_enemy`: it is `int32[3]` with `weak_type=True` after `reset`, but strongly typed after `step`. As a result, `f` is traced twice.
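Until the field is created with an explicit dtype, one possible user-side workaround is to make every leaf strongly typed before passing the state to a jitted function. The sketch below rests on several assumptions. `PolicyState` is a hypothetical NamedTuple standing in for `HeuristicPolicyState`, and `strip_weak_types` is a hypothetical helper. The approach relies on `jax.lax.convert_element_type` always returning a strongly typed array. It also assumes JAX's default 32-bit mode, so both `-1` fills are `int32`.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class PolicyState(NamedTuple):
    # Hypothetical stand-in for HeuristicPolicyState; NamedTuples are pytrees.
    default_target: jax.Array
    last_attacked_enemy: jax.Array


def strip_weak_types(tree):
    """Return a copy of `tree` whose array leaves are strongly typed.

    Converting each leaf to its own dtype keeps shape and dtype unchanged,
    but lax.convert_element_type always produces a strongly typed result.
    """
    return jax.tree.map(lambda x: jax.lax.convert_element_type(x, x.dtype), tree)


trace_count = 0


@jax.jit
def g(state):
    global trace_count
    trace_count += 1  # runs only while tracing
    return state.last_attacked_enemy + 1


# Weakly typed: the fill value is a Python int and no dtype is given.
init = PolicyState(jnp.zeros((3, 2)), jnp.full((3,), -1))
# Strongly typed, like the field after env.step in the report.
stepped = PolicyState(jnp.zeros((3, 2)), jnp.full((3,), -1, dtype=jnp.int32))

g(strip_weak_types(init))
g(strip_weak_types(stepped))
print(trace_count)  # 1
```

This adds one cheap conversion per leaf per call. The cleaner fix would be for the environment to build `last_attacked_enemy` with an explicit `dtype=jnp.int32` wherever the initial policy state is created (that code is not shown here), so that `reset` and `step` return identical abstract values.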