@@ -3346,6 +3346,10 @@ def test_batched_dynamic(self, break_when_any_done):
3346
3346
)
3347
3347
del env_no_buffers
3348
3348
gc .collect ()
3349
+ # print(dummy_rollouts)
3350
+ # print(rollout_no_buffers_serial)
3351
+ # # for a, b in zip(dummy_rollouts.exclude("action").unbind(0), rollout_no_buffers_serial.exclude("action").unbind(0)):
3352
+ # assert_allclose_td(a, b)
3349
3353
assert_allclose_td (
3350
3354
dummy_rollouts .exclude ("action" ),
3351
3355
rollout_no_buffers_serial .exclude ("action" ),
@@ -3441,35 +3445,146 @@ def test_partial_rest(self, batched):
3441
3445
3442
3446
# fen strings for board positions generated with:
3443
3447
# https://lichess.org/editor
3444
- @pytest .mark .parametrize ("stateful" , [False , True ])
3445
3448
@pytest .mark .skipif (not _has_chess , reason = "chess not found" )
3446
3449
class TestChessEnv :
3447
- def test_env (self , stateful ):
3448
- env = ChessEnv (stateful = stateful )
3449
- check_env_specs (env )
3450
@pytest.mark.parametrize("include_pgn", [False, True])
@pytest.mark.parametrize("include_fen", [False, True])
@pytest.mark.parametrize("stateful", [False, True])
@pytest.mark.parametrize("include_hash", [False, True])
@pytest.mark.parametrize("include_san", [False, True])
def test_env(self, stateful, include_pgn, include_fen, include_hash, include_san):
    """Build a ChessEnv for every flag combination and validate its specs.

    A stateless env created without PGN and without FEN is expected to raise
    a RuntimeError mentioning "At least one state representation"; every
    other combination must pass ``check_env_specs`` and, when hashing is
    requested, expose one ``*_hash`` observation key per enabled
    representation.
    """
    no_state_representation = not (stateful or include_pgn or include_fen)
    if no_state_representation:
        expectation = pytest.raises(
            RuntimeError, match="At least one state representation"
        )
    else:
        expectation = contextlib.nullcontext()

    with expectation:
        env = ChessEnv(
            stateful=stateful,
            include_pgn=include_pgn,
            include_fen=include_fen,
            include_hash=include_hash,
            include_san=include_san,
        )
        check_env_specs(env)
        if include_hash:
            # Each enabled representation must come with its hash entry.
            hash_keys_by_flag = (
                (include_fen, "fen_hash"),
                (include_pgn, "pgn_hash"),
                (include_san, "san_hash"),
            )
            for enabled, hash_key in hash_keys_by_flag:
                if enabled:
                    assert hash_key in env.observation_spec.keys()
3474
+
3475
def test_pgn_bijectivity(self):
    """Check that board <-> PGN conversion round-trips along a random game.

    Starting from ``ChessEnv._PGN_RESTART``, ten randomly chosen legal moves
    are pushed on the board. After each move the PGN must differ from the
    previous one, survive a PGN -> board -> PGN round trip, and equal the
    result of appending the move to the previous PGN.
    """
    np.random.seed(0)
    previous_pgn = ChessEnv._PGN_RESTART
    board = ChessEnv._pgn_to_board(previous_pgn)
    for _ in range(10):
        legal_moves = list(board.legal_moves)
        chosen_move = np.random.choice(legal_moves)
        board.push(chosen_move)
        current_pgn = ChessEnv._board_to_pgn(board)
        assert current_pgn != previous_pgn
        round_tripped = ChessEnv._board_to_pgn(ChessEnv._pgn_to_board(current_pgn))
        assert current_pgn == round_tripped
        assert current_pgn == ChessEnv._add_move_to_pgn(previous_pgn, chosen_move)
        previous_pgn = current_pgn
3489
+
3490
def test_consistency(self):
    """Check that identically seeded rollouts pick the same actions.

    Three state-representation variants (0: PGN+FEN, 1: FEN only,
    2: PGN only) are built in both stateful and stateless flavours. Each env
    is rolled out for 50 steps after ``torch.manual_seed(0)``, and the
    resulting ``"action"`` entries are compared across variants and flavours.
    """
    # variant index -> (include_pgn, include_fen)
    variants = {0: (True, True), 1: (False, True), 2: (True, False)}
    envs = {
        (stateful, index): ChessEnv(
            stateful=stateful, include_pgn=with_pgn, include_fen=with_fen
        )
        for stateful in (True, False)
        for index, (with_pgn, with_fen) in variants.items()
    }

    rollouts = {}
    for index in (1, 2, 0):
        for stateful in (False, True):
            # Re-seed before every rollout so all of them draw the same numbers.
            torch.manual_seed(0)
            rollouts[stateful, index] = envs[stateful, index].rollout(
                50, break_when_any_done=False
            )

    stateless, stateful = False, True
    compared_pairs = (
        ((stateless, 0), (stateless, 1)),
        ((stateless, 0), (stateless, 2)),
        ((stateless, 0), (stateful, 0)),
        ((stateless, 1), (stateful, 1)),
        ((stateless, 2), (stateful, 2)),
    )
    for left, right in compared_pairs:
        assert (rollouts[left]["action"] == rollouts[right]["action"]).all()
3514
+
3515
@pytest.mark.parametrize(
    "include_fen,include_pgn", [[True, False], [False, True], [True, True]]
)
@pytest.mark.parametrize("stateful", [False, True])
def test_san(self, stateful, include_fen, include_pgn):
    """Replay a rollout from its recorded SAN moves and compare both runs.

    A 100-step rollout is collected with ``include_san=True``. The SAN
    strings stored under ``("next", "san")`` are mapped back to their index
    in ``env.san_moves`` and fed, one per call, through a policy for a second
    rollout, which must match the first.
    """
    torch.manual_seed(0)
    env = ChessEnv(
        stateful=stateful,
        include_pgn=include_pgn,
        include_fen=include_fen,
        include_san=True,
    )
    reference = env.rollout(100, break_when_any_done=False)
    recorded_sans = reference["next", "san"]
    recorded_actions = [env.san_moves.index(san) for san in recorded_sans]
    cursor = 0

    def replay_policy(td):
        # Hand out the recorded actions in order, one per call.
        nonlocal cursor
        td["action"] = recorded_actions[cursor]
        cursor += 1
        return td

    replayed = env.rollout(100, policy=replay_policy, break_when_any_done=False)
    assert_allclose_td(reference, replayed)
3454
3540
3455
- def test_reset_white_to_move (self , stateful ):
3456
- env = ChessEnv (stateful = stateful )
3541
@pytest.mark.parametrize(
    "include_fen,include_pgn", [[True, False], [False, True], [True, True]]
)
@pytest.mark.parametrize("stateful", [False, True])
def test_rollout(self, stateful, include_pgn, include_fen):
    """Run a 500-step rollout that continues past ``done`` and check its shape."""
    num_steps = 500
    torch.manual_seed(0)
    env = ChessEnv(
        stateful=stateful,
        include_pgn=include_pgn,
        include_fen=include_fen,
    )
    trajectory = env.rollout(num_steps, break_when_any_done=False)
    assert trajectory.shape == (num_steps,)
3552
+
3553
@pytest.mark.parametrize(
    "include_fen,include_pgn", [[True, False], [False, True], [True, True]]
)
@pytest.mark.parametrize("stateful", [False, True])
def test_reset_white_to_move(self, stateful, include_pgn, include_fen):
    """Reset from a FEN with White to move and check the returned state."""
    white_to_move_fen = "5k2/4r3/8/8/8/1Q6/2K5/8 w - - 0 1"
    env = ChessEnv(
        stateful=stateful,
        include_pgn=include_pgn,
        include_fen=include_fen,
    )
    reset_td = env.reset(TensorDict({"fen": white_to_move_fen}))
    assert reset_td["fen"] == white_to_move_fen
    if include_fen:
        # The env's board is only compared when FEN is a tracked representation.
        assert env.board.fen() == white_to_move_fen
    assert reset_td["turn"] == env.lib.WHITE
    assert not reset_td["done"]
3462
3568
3463
- def test_reset_black_to_move (self , stateful ):
3464
- env = ChessEnv (stateful = stateful )
3569
@pytest.mark.parametrize("include_fen,include_pgn", [[True, False], [True, True]])
@pytest.mark.parametrize("stateful", [False, True])
def test_reset_black_to_move(self, stateful, include_pgn, include_fen):
    """Reset from a FEN with Black to move and check the returned state."""
    black_to_move_fen = "5k2/4r3/8/8/8/1Q6/2K5/8 b - - 0 1"
    env = ChessEnv(
        stateful=stateful,
        include_pgn=include_pgn,
        include_fen=include_fen,
    )
    reset_td = env.reset(TensorDict({"fen": black_to_move_fen}))
    assert reset_td["fen"] == black_to_move_fen
    assert env.board.fen() == black_to_move_fen
    assert reset_td["turn"] == env.lib.BLACK
    assert not reset_td["done"]
3470
3581
3471
- def test_reset_done_error (self , stateful ):
3472
- env = ChessEnv (stateful = stateful )
3582
+ @pytest .mark .parametrize ("include_fen,include_pgn" , [[True , False ], [True , True ]])
3583
+ @pytest .mark .parametrize ("stateful" , [False , True ])
3584
+ def test_reset_done_error (self , stateful , include_pgn , include_fen ):
3585
+ env = ChessEnv (
3586
+ stateful = stateful , include_pgn = include_pgn , include_fen = include_fen
3587
+ )
3473
3588
fen = "1R3k2/2R5/8/8/8/8/2K5/8 b - - 0 1"
3474
3589
with pytest .raises (ValueError ) as e_info :
3475
3590
env .reset (TensorDict ({"fen" : fen }))
@@ -3480,12 +3595,19 @@ def test_reset_done_error(self, stateful):
3480
3595
@pytest .mark .parametrize (
3481
3596
"endstate" , ["white win" , "black win" , "stalemate" , "50 move" , "insufficient" ]
3482
3597
)
3483
- def test_reward (self , stateful , reset_without_fen , endstate ):
3598
+ @pytest .mark .parametrize ("include_pgn" , [False , True ])
3599
+ @pytest .mark .parametrize ("include_fen" , [True ])
3600
+ @pytest .mark .parametrize ("stateful" , [False , True ])
3601
+ def test_reward (
3602
+ self , stateful , reset_without_fen , endstate , include_pgn , include_fen
3603
+ ):
3484
3604
if stateful and reset_without_fen :
3485
3605
# reset_without_fen is only used for stateless env
3486
3606
return
3487
3607
3488
- env = ChessEnv (stateful = stateful )
3608
+ env = ChessEnv (
3609
+ stateful = stateful , include_pgn = include_pgn , include_fen = include_fen
3610
+ )
3489
3611
3490
3612
if endstate == "white win" :
3491
3613
fen = "5k2/2R5/8/8/8/1R6/2K5/8 w - - 0 1"
@@ -3498,28 +3620,28 @@ def test_reward(self, stateful, reset_without_fen, endstate):
3498
3620
fen = "5k2/6r1/8/8/8/8/7r/1K6 b - - 0 1"
3499
3621
expected_turn = env .lib .BLACK
3500
3622
move = "Rg1#"
3501
- expected_reward = - 1
3623
+ expected_reward = 1
3502
3624
expected_done = True
3503
3625
3504
3626
elif endstate == "stalemate" :
3505
3627
fen = "5k2/6r1/8/8/8/8/7r/K7 b - - 0 1"
3506
3628
expected_turn = env .lib .BLACK
3507
3629
move = "Rb7"
3508
- expected_reward = 0
3630
+ expected_reward = 0.5
3509
3631
expected_done = True
3510
3632
3511
3633
elif endstate == "insufficient" :
3512
3634
fen = "5k2/8/8/8/3r4/2K5/8/8 w - - 0 1"
3513
3635
expected_turn = env .lib .WHITE
3514
3636
move = "Kxd4"
3515
- expected_reward = 0
3637
+ expected_reward = 0.5
3516
3638
expected_done = True
3517
3639
3518
3640
elif endstate == "50 move" :
3519
3641
fen = "5k2/8/1R6/8/6r1/2K5/8/8 b - - 99 123"
3520
3642
expected_turn = env .lib .BLACK
3521
3643
move = "Kf7"
3522
- expected_reward = 0
3644
+ expected_reward = 0.5
3523
3645
expected_done = True
3524
3646
3525
3647
elif endstate == "not_done" :
@@ -3538,8 +3660,7 @@ def test_reward(self, stateful, reset_without_fen, endstate):
3538
3660
td = env .reset (TensorDict ({"fen" : fen }))
3539
3661
assert td ["turn" ] == expected_turn
3540
3662
3541
- moves = env .get_legal_moves (None if stateful else td )
3542
- td ["action" ] = moves .index (move )
3663
+ td ["action" ] = env ._san_moves .index (move )
3543
3664
td = env .step (td )["next" ]
3544
3665
assert td ["done" ] == expected_done
3545
3666
assert td ["reward" ] == expected_reward
0 commit comments