
[Bug] Re-tracing after step caused by a change in the weak_type field #144

@HeavyCrab

Description


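The `weak_type` flag of `jaxmarl.environments.smax.heuristic_enemy.HeuristicPolicyState.last_attacked_enemy` changes after `step`: it is `True` in the state returned by `env.reset` but `False` in the state returned by `env.step`. Because `weak_type` is part of the abstract value that `jax.jit` uses to look up compiled functions, any jitted function that takes the state as input is traced again the first time it receives a post-step state.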
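The mechanism can be reproduced without jaxmarl. The sketch below assumes JAX's default 32-bit mode; `g` and `trace_count` are hypothetical names used only for illustration. Two `int32[3]` arrays that differ only in their weak-type flag are compiled separately.

```python
import jax
import jax.numpy as jnp

trace_count = 0


@jax.jit
def g(x):
    global trace_count
    trace_count += 1  # Python side effect: runs only while tracing
    return x + 1


weak = jnp.full((3,), -1)                     # Python int fill, no dtype -> weakly typed
strong = jnp.full((3,), -1, dtype=jnp.int32)  # explicit dtype -> strongly typed

print(weak.dtype, weak.aval.weak_type)      # int32 True
print(strong.dtype, strong.aval.weak_type)  # int32 False

g(weak)
g(strong)
print(trace_count)  # 2: same shape and dtype, different abstract value
```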

Here is a minimal example of the bug.

```python
import jax
import jax.numpy as jnp
from jaxmarl import make
from jaxmarl.environments.smax import map_name_to_scenario

key = jax.random.PRNGKey(0)

env = make(
    "HeuristicEnemySMAX",
    scenario=map_name_to_scenario('3m'),
    see_enemy_actions=False,
)
obs, state_init = env.reset(key)
# Take one step with action 0 for every agent.
obs, state, reward, done, info = env.step(key, state_init, {
    agent: jnp.array(0, dtype=jnp.int32)
    for agent in env.agents
})

# Print the abstract value (shape, dtype, weak_type) of every leaf.
print(jax.tree.map(lambda x: x.aval, state_init))
print(jax.tree.map(lambda x: x.aval, state))

@jax.jit
def f(x):
    # Python side effect: runs only while tracing, not on cached calls.
    print('tracing')

f(state_init)
f(state)  # prints 'tracing' again because the avals differ
```

The output is:

```text
State(state=State(unit_positions=ShapedArray(float32[6,2]), unit_alive=ShapedArray(bool[6]), unit_teams=ShapedArray(float32[6]), unit_health=ShapedArray(float32[6]), unit_types=ShapedArray(uint8[6]), unit_weapon_cooldowns=ShapedArray(float32[6]), prev_movement_actions=ShapedArray(float32[6,2]), prev_attack_actions=ShapedArray(int32[6]), time=ShapedArray(int32[], weak_type=True), terminal=ShapedArray(bool[])), enemy_policy_state=HeuristicPolicyState(default_target=ShapedArray(float32[3,2]), last_attacked_enemy=ShapedArray(int32[3], weak_type=True)))
State(state=State(unit_positions=ShapedArray(float32[6,2]), unit_alive=ShapedArray(bool[6]), unit_teams=ShapedArray(float32[6]), unit_health=ShapedArray(float32[6]), unit_types=ShapedArray(uint8[6]), unit_weapon_cooldowns=ShapedArray(float32[6]), prev_movement_actions=ShapedArray(float32[6,2]), prev_attack_actions=ShapedArray(int32[6]), time=ShapedArray(int32[], weak_type=True), terminal=ShapedArray(bool[])), enemy_policy_state=HeuristicPolicyState(default_target=ShapedArray(float32[3,2]), last_attacked_enemy=ShapedArray(int32[3])))
tracing
tracing
```
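Spotting the differing leaf in the long aval dumps is tedious, so the sketch below locates it programmatically and shows one possible user-side workaround. This is not the library's fix: `weak_type_mismatches`, `strengthen_weak_types`, `policy_fn` and `ToyPolicyState` are hypothetical names, and `ToyPolicyState` is a stand-in for `HeuristicPolicyState` so the example runs without jaxmarl. It assumes every leaf of the state is a JAX array. With the real environment, the normalization would presumably have to be applied to the states returned by both `env.reset` and `env.step`; a fix inside jaxmarl would more likely create `last_attacked_enemy` with an explicit dtype so that both paths agree.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp
from jax import lax


class ToyPolicyState(NamedTuple):
    # Hypothetical stand-in for HeuristicPolicyState (not the jaxmarl class).
    default_target: jax.Array
    last_attacked_enemy: jax.Array


def weak_type_mismatches(a, b):
    """Return key paths of leaves whose weak_type flag differs between a and b."""
    leaves_a, treedef_a = jax.tree_util.tree_flatten_with_path(a)
    leaves_b, treedef_b = jax.tree_util.tree_flatten_with_path(b)
    if treedef_a != treedef_b:
        raise ValueError("pytrees must have the same structure")
    return [
        jax.tree_util.keystr(path)
        for (path, x), (_, y) in zip(leaves_a, leaves_b)
        if x.aval.weak_type != y.aval.weak_type
    ]


def strengthen_weak_types(tree):
    """Return a copy of `tree` in which every array leaf is strongly typed.

    lax.convert_element_type always yields weak_type=False, so converting a
    leaf to its own dtype clears only the weak-type flag.
    """
    return jax.tree.map(lambda x: lax.convert_element_type(x, x.dtype), tree)


# Mimic the reported mismatch: weakly typed after "reset", strong after "step".
after_reset = ToyPolicyState(jnp.zeros((3, 2)), jnp.full((3,), -1))
after_step = ToyPolicyState(jnp.zeros((3, 2)), jnp.full((3,), 2, dtype=jnp.int32))

mismatches = weak_type_mismatches(after_reset, after_step)
print(mismatches)  # should name only the last_attacked_enemy leaf

trace_count = 0


@jax.jit
def policy_fn(s):
    global trace_count
    trace_count += 1  # Python side effect: counts traces, not calls
    return s.last_attacked_enemy + 1


policy_fn(strengthen_weak_types(after_reset))
policy_fn(strengthen_weak_types(after_step))
print(trace_count)  # 1: both inputs now have identical avals
```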
