@@ -4096,30 +4096,55 @@ def test_transform_inverse(self):
4096
4096
4097
4097
class TestRewardSum (TransformBase ):
4098
4098
def test_single_trans_env_check (self ):
4099
- env = TransformedEnv (ContinuousActionVecMockEnv (), RewardSum ())
4099
+ env = TransformedEnv (
4100
+ ContinuousActionVecMockEnv (),
4101
+ Compose (RewardScaling (loc = - 1 , scale = 1 ), RewardSum ()),
4102
+ )
4100
4103
check_env_specs (env )
4104
+ r = env .rollout (4 )
4105
+ assert r ["next" , "episode_reward" ].unique ().numel () > 1
4101
4106
4102
4107
def test_serial_trans_env_check (self ):
4103
4108
def make_env ():
4104
- return TransformedEnv (ContinuousActionVecMockEnv (), RewardSum ())
4109
+ return TransformedEnv (
4110
+ ContinuousActionVecMockEnv (),
4111
+ Compose (RewardScaling (loc = - 1 , scale = 1 ), RewardSum ()),
4112
+ )
4105
4113
4106
4114
env = SerialEnv (2 , make_env )
4107
4115
check_env_specs (env )
4116
+ r = env .rollout (4 )
4117
+ assert r ["next" , "episode_reward" ].unique ().numel () > 1
4108
4118
4109
4119
def test_parallel_trans_env_check (self ):
4110
4120
def make_env ():
4111
- return TransformedEnv (ContinuousActionVecMockEnv (), RewardSum ())
4121
+ return TransformedEnv (
4122
+ ContinuousActionVecMockEnv (),
4123
+ Compose (RewardScaling (loc = - 1 , scale = 1 ), RewardSum ()),
4124
+ )
4112
4125
4113
4126
env = ParallelEnv (2 , make_env )
4114
4127
check_env_specs (env )
4128
+ r = env .rollout (4 )
4129
+ assert r ["next" , "episode_reward" ].unique ().numel () > 1
4115
4130
4116
4131
def test_trans_serial_env_check (self ):
4117
- env = TransformedEnv (SerialEnv (2 , ContinuousActionVecMockEnv ), RewardSum ())
4132
+ env = TransformedEnv (
4133
+ SerialEnv (2 , ContinuousActionVecMockEnv ),
4134
+ Compose (RewardScaling (loc = - 1 , scale = 1 ), RewardSum ()),
4135
+ )
4118
4136
check_env_specs (env )
4137
+ r = env .rollout (4 )
4138
+ assert r ["next" , "episode_reward" ].unique ().numel () > 1
4119
4139
4120
4140
def test_trans_parallel_env_check (self ):
4121
- env = TransformedEnv (ParallelEnv (2 , ContinuousActionVecMockEnv ), RewardSum ())
4141
+ env = TransformedEnv (
4142
+ ParallelEnv (2 , ContinuousActionVecMockEnv ),
4143
+ Compose (RewardScaling (loc = - 1 , scale = 1 ), RewardSum ()),
4144
+ )
4122
4145
check_env_specs (env )
4146
+ r = env .rollout (4 )
4147
+ assert r ["next" , "episode_reward" ].unique ().numel () > 1
4123
4148
4124
4149
@pytest .mark .parametrize ("in_key" , ["reward" , ("some" , "nested" )])
4125
4150
def test_transform_no_env (self , in_key ):
0 commit comments