Skip to content

Commit 838ba6c

Browse files
author
Vincent Moens
committed
[Feature,Refactor] Chess improvements: fen, pgn, pixels, san
ghstack-source-id: 850ef45 Pull Request resolved: #2702
1 parent 289b2da commit 838ba6c

File tree

6 files changed

+29825
-101
lines changed

6 files changed

+29825
-101
lines changed

test/test_env.py

Lines changed: 142 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3346,6 +3346,10 @@ def test_batched_dynamic(self, break_when_any_done):
33463346
)
33473347
del env_no_buffers
33483348
gc.collect()
3349+
# print(dummy_rollouts)
3350+
# print(rollout_no_buffers_serial)
3351+
# # for a, b in zip(dummy_rollouts.exclude("action").unbind(0), rollout_no_buffers_serial.exclude("action").unbind(0)):
3352+
# assert_allclose_td(a, b)
33493353
assert_allclose_td(
33503354
dummy_rollouts.exclude("action"),
33513355
rollout_no_buffers_serial.exclude("action"),
@@ -3441,35 +3445,146 @@ def test_partial_rest(self, batched):
34413445

34423446
# fen strings for board positions generated with:
34433447
# https://lichess.org/editor
3444-
@pytest.mark.parametrize("stateful", [False, True])
34453448
@pytest.mark.skipif(not _has_chess, reason="chess not found")
34463449
class TestChessEnv:
3447-
def test_env(self, stateful):
3448-
env = ChessEnv(stateful=stateful)
3449-
check_env_specs(env)
3450+
@pytest.mark.parametrize("include_pgn", [False, True])
3451+
@pytest.mark.parametrize("include_fen", [False, True])
3452+
@pytest.mark.parametrize("stateful", [False, True])
3453+
@pytest.mark.parametrize("include_hash", [False, True])
3454+
@pytest.mark.parametrize("include_san", [False, True])
3455+
def test_env(self, stateful, include_pgn, include_fen, include_hash, include_san):
3456+
with pytest.raises(
3457+
RuntimeError, match="At least one state representation"
3458+
) if not stateful and not include_pgn and not include_fen else contextlib.nullcontext():
3459+
env = ChessEnv(
3460+
stateful=stateful,
3461+
include_pgn=include_pgn,
3462+
include_fen=include_fen,
3463+
include_hash=include_hash,
3464+
include_san=include_san,
3465+
)
3466+
check_env_specs(env)
3467+
if include_hash:
3468+
if include_fen:
3469+
assert "fen_hash" in env.observation_spec.keys()
3470+
if include_pgn:
3471+
assert "pgn_hash" in env.observation_spec.keys()
3472+
if include_san:
3473+
assert "san_hash" in env.observation_spec.keys()
3474+
3475+
def test_pgn_bijectivity(self):
3476+
np.random.seed(0)
3477+
pgn = ChessEnv._PGN_RESTART
3478+
board = ChessEnv._pgn_to_board(pgn)
3479+
pgn_prev = pgn
3480+
for _ in range(10):
3481+
moves = list(board.legal_moves)
3482+
move = np.random.choice(moves)
3483+
board.push(move)
3484+
pgn_move = ChessEnv._board_to_pgn(board)
3485+
assert pgn_move != pgn_prev
3486+
assert pgn_move == ChessEnv._board_to_pgn(ChessEnv._pgn_to_board(pgn_move))
3487+
assert pgn_move == ChessEnv._add_move_to_pgn(pgn_prev, move)
3488+
pgn_prev = pgn_move
3489+
3490+
def test_consistency(self):
3491+
env0_stateful = ChessEnv(stateful=True, include_pgn=True, include_fen=True)
3492+
env1_stateful = ChessEnv(stateful=True, include_pgn=False, include_fen=True)
3493+
env2_stateful = ChessEnv(stateful=True, include_pgn=True, include_fen=False)
3494+
env0_stateless = ChessEnv(stateful=False, include_pgn=True, include_fen=True)
3495+
env1_stateless = ChessEnv(stateful=False, include_pgn=False, include_fen=True)
3496+
env2_stateless = ChessEnv(stateful=False, include_pgn=True, include_fen=False)
3497+
torch.manual_seed(0)
3498+
r1_stateless = env1_stateless.rollout(50, break_when_any_done=False)
3499+
torch.manual_seed(0)
3500+
r1_stateful = env1_stateful.rollout(50, break_when_any_done=False)
3501+
torch.manual_seed(0)
3502+
r2_stateless = env2_stateless.rollout(50, break_when_any_done=False)
3503+
torch.manual_seed(0)
3504+
r2_stateful = env2_stateful.rollout(50, break_when_any_done=False)
3505+
torch.manual_seed(0)
3506+
r0_stateless = env0_stateless.rollout(50, break_when_any_done=False)
3507+
torch.manual_seed(0)
3508+
r0_stateful = env0_stateful.rollout(50, break_when_any_done=False)
3509+
assert (r0_stateless["action"] == r1_stateless["action"]).all()
3510+
assert (r0_stateless["action"] == r2_stateless["action"]).all()
3511+
assert (r0_stateless["action"] == r0_stateful["action"]).all()
3512+
assert (r1_stateless["action"] == r1_stateful["action"]).all()
3513+
assert (r2_stateless["action"] == r2_stateful["action"]).all()
3514+
3515+
@pytest.mark.parametrize(
3516+
"include_fen,include_pgn", [[True, False], [False, True], [True, True]]
3517+
)
3518+
@pytest.mark.parametrize("stateful", [False, True])
3519+
def test_san(self, stateful, include_fen, include_pgn):
3520+
torch.manual_seed(0)
3521+
env = ChessEnv(
3522+
stateful=stateful,
3523+
include_pgn=include_pgn,
3524+
include_fen=include_fen,
3525+
include_san=True,
3526+
)
3527+
r = env.rollout(100, break_when_any_done=False)
3528+
sans = r["next", "san"]
3529+
actions = [env.san_moves.index(san) for san in sans]
3530+
i = 0
3531+
3532+
def policy(td):
3533+
nonlocal i
3534+
td["action"] = actions[i]
3535+
i += 1
3536+
return td
34503537

3451-
def test_rollout(self, stateful):
3452-
env = ChessEnv(stateful=stateful)
3453-
env.rollout(5000)
3538+
r2 = env.rollout(100, policy=policy, break_when_any_done=False)
3539+
assert_allclose_td(r, r2)
34543540

3455-
def test_reset_white_to_move(self, stateful):
3456-
env = ChessEnv(stateful=stateful)
3541+
@pytest.mark.parametrize(
3542+
"include_fen,include_pgn", [[True, False], [False, True], [True, True]]
3543+
)
3544+
@pytest.mark.parametrize("stateful", [False, True])
3545+
def test_rollout(self, stateful, include_pgn, include_fen):
3546+
torch.manual_seed(0)
3547+
env = ChessEnv(
3548+
stateful=stateful, include_pgn=include_pgn, include_fen=include_fen
3549+
)
3550+
r = env.rollout(500, break_when_any_done=False)
3551+
assert r.shape == (500,)
3552+
3553+
@pytest.mark.parametrize(
3554+
"include_fen,include_pgn", [[True, False], [False, True], [True, True]]
3555+
)
3556+
@pytest.mark.parametrize("stateful", [False, True])
3557+
def test_reset_white_to_move(self, stateful, include_pgn, include_fen):
3558+
env = ChessEnv(
3559+
stateful=stateful, include_pgn=include_pgn, include_fen=include_fen
3560+
)
34573561
fen = "5k2/4r3/8/8/8/1Q6/2K5/8 w - - 0 1"
34583562
td = env.reset(TensorDict({"fen": fen}))
34593563
assert td["fen"] == fen
3564+
if include_fen:
3565+
assert env.board.fen() == fen
34603566
assert td["turn"] == env.lib.WHITE
34613567
assert not td["done"]
34623568

3463-
def test_reset_black_to_move(self, stateful):
3464-
env = ChessEnv(stateful=stateful)
3569+
@pytest.mark.parametrize("include_fen,include_pgn", [[True, False], [True, True]])
3570+
@pytest.mark.parametrize("stateful", [False, True])
3571+
def test_reset_black_to_move(self, stateful, include_pgn, include_fen):
3572+
env = ChessEnv(
3573+
stateful=stateful, include_pgn=include_pgn, include_fen=include_fen
3574+
)
34653575
fen = "5k2/4r3/8/8/8/1Q6/2K5/8 b - - 0 1"
34663576
td = env.reset(TensorDict({"fen": fen}))
34673577
assert td["fen"] == fen
3578+
assert env.board.fen() == fen
34683579
assert td["turn"] == env.lib.BLACK
34693580
assert not td["done"]
34703581

3471-
def test_reset_done_error(self, stateful):
3472-
env = ChessEnv(stateful=stateful)
3582+
@pytest.mark.parametrize("include_fen,include_pgn", [[True, False], [True, True]])
3583+
@pytest.mark.parametrize("stateful", [False, True])
3584+
def test_reset_done_error(self, stateful, include_pgn, include_fen):
3585+
env = ChessEnv(
3586+
stateful=stateful, include_pgn=include_pgn, include_fen=include_fen
3587+
)
34733588
fen = "1R3k2/2R5/8/8/8/8/2K5/8 b - - 0 1"
34743589
with pytest.raises(ValueError) as e_info:
34753590
env.reset(TensorDict({"fen": fen}))
@@ -3480,12 +3595,19 @@ def test_reset_done_error(self, stateful):
34803595
@pytest.mark.parametrize(
34813596
"endstate", ["white win", "black win", "stalemate", "50 move", "insufficient"]
34823597
)
3483-
def test_reward(self, stateful, reset_without_fen, endstate):
3598+
@pytest.mark.parametrize("include_pgn", [False, True])
3599+
@pytest.mark.parametrize("include_fen", [True])
3600+
@pytest.mark.parametrize("stateful", [False, True])
3601+
def test_reward(
3602+
self, stateful, reset_without_fen, endstate, include_pgn, include_fen
3603+
):
34843604
if stateful and reset_without_fen:
34853605
# reset_without_fen is only used for stateless env
34863606
return
34873607

3488-
env = ChessEnv(stateful=stateful)
3608+
env = ChessEnv(
3609+
stateful=stateful, include_pgn=include_pgn, include_fen=include_fen
3610+
)
34893611

34903612
if endstate == "white win":
34913613
fen = "5k2/2R5/8/8/8/1R6/2K5/8 w - - 0 1"
@@ -3498,28 +3620,28 @@ def test_reward(self, stateful, reset_without_fen, endstate):
34983620
fen = "5k2/6r1/8/8/8/8/7r/1K6 b - - 0 1"
34993621
expected_turn = env.lib.BLACK
35003622
move = "Rg1#"
3501-
expected_reward = -1
3623+
expected_reward = 1
35023624
expected_done = True
35033625

35043626
elif endstate == "stalemate":
35053627
fen = "5k2/6r1/8/8/8/8/7r/K7 b - - 0 1"
35063628
expected_turn = env.lib.BLACK
35073629
move = "Rb7"
3508-
expected_reward = 0
3630+
expected_reward = 0.5
35093631
expected_done = True
35103632

35113633
elif endstate == "insufficient":
35123634
fen = "5k2/8/8/8/3r4/2K5/8/8 w - - 0 1"
35133635
expected_turn = env.lib.WHITE
35143636
move = "Kxd4"
3515-
expected_reward = 0
3637+
expected_reward = 0.5
35163638
expected_done = True
35173639

35183640
elif endstate == "50 move":
35193641
fen = "5k2/8/1R6/8/6r1/2K5/8/8 b - - 99 123"
35203642
expected_turn = env.lib.BLACK
35213643
move = "Kf7"
3522-
expected_reward = 0
3644+
expected_reward = 0.5
35233645
expected_done = True
35243646

35253647
elif endstate == "not_done":
@@ -3538,8 +3660,7 @@ def test_reward(self, stateful, reset_without_fen, endstate):
35383660
td = env.reset(TensorDict({"fen": fen}))
35393661
assert td["turn"] == expected_turn
35403662

3541-
moves = env.get_legal_moves(None if stateful else td)
3542-
td["action"] = moves.index(move)
3663+
td["action"] = env._san_moves.index(move)
35433664
td = env.step(td)["next"]
35443665
assert td["done"] == expected_done
35453666
assert td["reward"] == expected_reward

torchrl/envs/batched_envs.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -718,17 +718,13 @@ def _create_td(self) -> None:
718718
env_output_keys = set()
719719
env_obs_keys = set()
720720
for meta_data in self.meta_data:
721-
env_obs_keys = env_obs_keys.union(
722-
key
723-
for key in meta_data.specs["output_spec"][
724-
"full_observation_spec"
725-
].keys(True, True)
726-
)
727-
env_output_keys = env_output_keys.union(
728-
meta_data.specs["output_spec"]["full_observation_spec"].keys(
729-
True, True
730-
)
721+
keys = meta_data.specs["output_spec"]["full_observation_spec"].keys(
722+
True, True
731723
)
724+
keys = list(keys)
725+
env_obs_keys = env_obs_keys.union(keys)
726+
727+
env_output_keys = env_output_keys.union(keys)
732728
env_output_keys = env_output_keys.union(self.reward_keys + self.done_keys)
733729
self._env_obs_keys = sorted(env_obs_keys, key=_sort_keys)
734730
self._env_input_keys = sorted(env_input_keys, key=_sort_keys)
@@ -1003,7 +999,12 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
1003999
for i, _env in enumerate(self._envs):
10041000
if not needs_resetting[i]:
10051001
if out_tds is not None and tensordict is not None:
1006-
out_tds[i] = tensordict[i].exclude(*self._envs[i].reset_keys)
1002+
ftd = _env.observation_spec.zero()
1003+
if self.device is None:
1004+
ftd.clear_device_()
1005+
else:
1006+
ftd = ftd.to(self.device)
1007+
out_tds[i] = ftd
10071008
continue
10081009
if tensordict is not None:
10091010
tensordict_ = tensordict[i]

torchrl/envs/common.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2505,11 +2505,26 @@ def reset(
25052505
Returns:
25062506
a tensordict (or the input tensordict, if any), modified in place with the resulting observations.
25072507
2508+
.. note:: `reset` should not be overwritten by :class:`~torchrl.envs.EnvBase` subclasses. The method to
2509+
modify is :meth:`~torchrl.envs.EnvBase._reset`.
2510+
25082511
"""
25092512
if tensordict is not None:
25102513
self._assert_tensordict_shape(tensordict)
25112514

2512-
tensordict_reset = self._reset(tensordict, **kwargs)
2515+
select_reset_only = kwargs.pop("select_reset_only", False)
2516+
if select_reset_only and tensordict is not None:
2517+
# When making rollouts with step_and_maybe_reset, it can happen that a tensordict has
2518+
# keys that are used by reset to optionally set the reset state (eg, the fen in chess). If that's the
2519+
# case and we don't throw them away here, reset will just be a no-op (put the env in the state reached
2520+
# during the previous step).
2521+
# Therefore, maybe_reset tells reset to temporarily hide the non-reset keys.
2522+
# To make step_and_maybe_reset handle custom reset states, some version of TensorDictPrimer should be used.
2523+
tensordict_reset = self._reset(
2524+
tensordict.select(*self.reset_keys, strict=False), **kwargs
2525+
)
2526+
else:
2527+
tensordict_reset = self._reset(tensordict, **kwargs)
25132528
# We assume that this is done properly
25142529
# if reset.device != self.device:
25152530
# reset = reset.to(self.device, non_blocking=True)
@@ -3293,7 +3308,7 @@ def maybe_reset(self, tensordict: TensorDictBase) -> TensorDictBase:
32933308
else:
32943309
any_done = False
32953310
if any_done:
3296-
tensordict._set_str(
3311+
tensordict = tensordict._set_str(
32973312
"_reset",
32983313
done.clone(),
32993314
validated=True,
@@ -3307,7 +3322,7 @@ def maybe_reset(self, tensordict: TensorDictBase) -> TensorDictBase:
33073322
key="_reset",
33083323
)
33093324
if any_done:
3310-
tensordict = self.reset(tensordict)
3325+
return self.reset(tensordict, select_reset_only=True)
33113326
return tensordict
33123327

33133328
def empty_cache(self):

0 commit comments

Comments
 (0)