Commit d1ece8f

Vincent Moens authored and osalpekar committed

[BugFix] Fix reward sum within parallel envs (#1454)

1 parent c4c93ba · commit d1ece8f
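
For context: RewardSum accumulates the per-step reward into an "episode_reward" entry of the rollout data, and this commit fixes cases where that running sum did not propagate correctly when the transform was used with parallel/batched environments. Below is a minimal usage sketch, assuming gym and torchrl are installed; GymEnv("Pendulum-v1") is only an illustrative stand-in for the mock environments used in the tests.

    from torchrl.envs import TransformedEnv
    from torchrl.envs.libs.gym import GymEnv
    from torchrl.envs.transforms import RewardSum

    # RewardSum writes the running total of the reward under "episode_reward".
    env = TransformedEnv(GymEnv("Pendulum-v1"), RewardSum())
    rollout = env.rollout(4)
    # After each step, the accumulated value is stored in the "next" sub-tensordict.
    print(rollout["next", "episode_reward"])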

File tree: 2 files changed, +46 −5 lines

  test/test_transforms.py
  torchrl/envs/transforms/transforms.py

test/test_transforms.py

Lines changed: 30 additions & 5 deletions
@@ -4096,30 +4096,55 @@ def test_transform_inverse(self):
 
 class TestRewardSum(TransformBase):
     def test_single_trans_env_check(self):
-        env = TransformedEnv(ContinuousActionVecMockEnv(), RewardSum())
+        env = TransformedEnv(
+            ContinuousActionVecMockEnv(),
+            Compose(RewardScaling(loc=-1, scale=1), RewardSum()),
+        )
         check_env_specs(env)
+        r = env.rollout(4)
+        assert r["next", "episode_reward"].unique().numel() > 1
 
     def test_serial_trans_env_check(self):
         def make_env():
-            return TransformedEnv(ContinuousActionVecMockEnv(), RewardSum())
+            return TransformedEnv(
+                ContinuousActionVecMockEnv(),
+                Compose(RewardScaling(loc=-1, scale=1), RewardSum()),
+            )
 
         env = SerialEnv(2, make_env)
         check_env_specs(env)
+        r = env.rollout(4)
+        assert r["next", "episode_reward"].unique().numel() > 1
 
     def test_parallel_trans_env_check(self):
         def make_env():
-            return TransformedEnv(ContinuousActionVecMockEnv(), RewardSum())
+            return TransformedEnv(
+                ContinuousActionVecMockEnv(),
+                Compose(RewardScaling(loc=-1, scale=1), RewardSum()),
+            )
 
         env = ParallelEnv(2, make_env)
         check_env_specs(env)
+        r = env.rollout(4)
+        assert r["next", "episode_reward"].unique().numel() > 1
 
     def test_trans_serial_env_check(self):
-        env = TransformedEnv(SerialEnv(2, ContinuousActionVecMockEnv), RewardSum())
+        env = TransformedEnv(
+            SerialEnv(2, ContinuousActionVecMockEnv),
+            Compose(RewardScaling(loc=-1, scale=1), RewardSum()),
+        )
         check_env_specs(env)
+        r = env.rollout(4)
+        assert r["next", "episode_reward"].unique().numel() > 1
 
     def test_trans_parallel_env_check(self):
-        env = TransformedEnv(ParallelEnv(2, ContinuousActionVecMockEnv), RewardSum())
+        env = TransformedEnv(
+            ParallelEnv(2, ContinuousActionVecMockEnv),
+            Compose(RewardScaling(loc=-1, scale=1), RewardSum()),
+        )
         check_env_specs(env)
+        r = env.rollout(4)
+        assert r["next", "episode_reward"].unique().numel() > 1
 
     @pytest.mark.parametrize("in_key", ["reward", ("some", "nested")])
     def test_transform_no_env(self, in_key):
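
The tests now compose RewardScaling(loc=-1, scale=1) with RewardSum and assert that ("next", "episode_reward") takes more than one unique value over a 4-step rollout; a running sum that stays constant, as happened with parallel envs before the fix, now fails the assertion. A standalone sketch of the parallel-env check, assuming GymEnv("Pendulum-v1") as a stand-in for the test suite's ContinuousActionVecMockEnv:

    from torchrl.envs import ParallelEnv, TransformedEnv
    from torchrl.envs.libs.gym import GymEnv
    from torchrl.envs.transforms import Compose, RewardScaling, RewardSum
    from torchrl.envs.utils import check_env_specs

    if __name__ == "__main__":
        # Transform applied on top of the batched env, as in test_trans_parallel_env_check.
        env = TransformedEnv(
            ParallelEnv(2, lambda: GymEnv("Pendulum-v1")),
            Compose(RewardScaling(loc=-1, scale=1), RewardSum()),
        )
        check_env_specs(env)
        r = env.rollout(4)
        # A single unique value would mean the running sum is stuck.
        assert r["next", "episode_reward"].unique().numel() > 1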

torchrl/envs/transforms/transforms.py

Lines changed: 16 additions & 0 deletions
@@ -3818,6 +3818,22 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
         tensordict.set("next", next_tensordict)
         return tensordict
 
+    def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec:
+        state_spec = input_spec["full_state_spec"]
+        if state_spec is None:
+            state_spec = CompositeSpec(shape=input_spec.shape, device=input_spec.device)
+        reward_spec = self.parent.reward_spec
+        # Define episode specs for all out_keys
+        for out_key in self.out_keys:
+            episode_spec = UnboundedContinuousTensorSpec(
+                shape=reward_spec.shape,
+                device=reward_spec.device,
+                dtype=reward_spec.dtype,
+            )
+            state_spec[out_key] = episode_spec
+        input_spec["full_state_spec"] = state_spec
+        return input_spec
+
     def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
         """Transforms the observation spec, adding the new keys generated by RewardSum."""
         # Retrieve parent reward spec
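
On the transform side, the new transform_input_spec registers an entry for each out_key (by default "episode_reward") in the env's "full_state_spec", using the parent reward spec's shape, device, and dtype; advertising the accumulator in the state spec is what helps batched/parallel environments allocate it and carry it between steps. A small sketch to inspect the result, with the same GymEnv stand-in as above and assuming the input-spec layout used in this diff:

    from torchrl.envs import TransformedEnv
    from torchrl.envs.libs.gym import GymEnv
    from torchrl.envs.transforms import RewardSum

    env = TransformedEnv(GymEnv("Pendulum-v1"), RewardSum())
    # transform_input_spec runs when the transformed env assembles its input spec,
    # so the state spec should now contain an "episode_reward" entry.
    print(env.input_spec["full_state_spec"])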

0 commit comments