Skip to content

Commit d1ecc37

Browse files
patrickhaoyooctipuskellyguo11
authored
Fixes ObservationManager history buffer corrupted by external calls to ObservationManager.compute (#2885)
# Description When observation group has history length greater than zero, calling `ObservationManager.compute` modifies history state by appending current observation to history. This creates history corruption when non-`ManagerBasedEnv` classes invoke `ObservationManager.compute`. This PR introduces `update_history` flag (default to `False`) and only `ManagerBasedEnv` has the privilege to run `ObservationManager.compute` with `update_history=True`. If `update_history=False` and the history buffer is `None`, a copy of history is returned instead of the original. I have added test cases to verify this fix is effective. Fixes #2884 ## Type of change - Bug fix (non-breaking change which fixes an issue) ## Checklist - [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with `./isaaclab.sh --format` - [ ] I have made corresponding changes to the documentation - [x] My changes generate no new warnings - [x] I have added tests that prove my fix is effective or that my feature works - [x] I have updated the changelog and the corresponding version in the extension's `config/extension.toml` file - [x] I have added my name to the `CONTRIBUTORS.md` or my name already exists there --------- Signed-off-by: ooctipus <zhengyuz@nvidia.com> Signed-off-by: Kelly Guo <kellyg@nvidia.com> Co-authored-by: ooctipus <zhengyuz@nvidia.com> Co-authored-by: Kelly Guo <kellyg@nvidia.com>
1 parent 8e57a3a commit d1ecc37

File tree

7 files changed

+157
-13
lines changed

7 files changed

+157
-13
lines changed

CONTRIBUTORS.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ Guidelines for modifications:
9797
* Ori Gadot
9898
* Oyindamola Omotuyi
9999
* Özhan Özen
100+
* Patrick Yin
100101
* Peter Du
101102
* Pulkit Goyal
102103
* Qian Wan

source/isaaclab/config/extension.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[package]
22

33
# Note: Semantic Versioning is used: https://semver.org/
4-
version = "0.40.17"
4+
version = "0.40.18"
55

66
# Description
77
title = "Isaac Lab framework for Robot Learning"

source/isaaclab/docs/CHANGELOG.rst

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,23 @@
11
Changelog
22
---------
33

4+
0.40.18 (2025-07-09)
5+
~~~~~~~~~~~~~~~~~~~~
6+
7+
Added
8+
^^^^^
9+
10+
* Added input param ``update_history`` to :meth:`~isaaclab.managers.ObservationManager.compute`
11+
to control whether the history buffer should be updated.
12+
* Added unit test for :class:`~isaaclab.envs.ManagerBasedEnv`.
13+
14+
Fixed
15+
^^^^^
16+
17+
* Fixed :class:`~isaaclab.envs.ManagerBasedEnv` and :class:`~isaaclab.envs.ManagerBasedRLEnv` to not update the history
18+
buffer on recording.
19+
20+
421
0.40.17 (2025-07-10)
522
~~~~~~~~~~~~~~~~~~~~
623

source/isaaclab/isaaclab/envs/manager_based_env.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,7 @@ def reset(
305305
self.recorder_manager.record_post_reset(env_ids)
306306

307307
# compute observations
308-
self.obs_buf = self.observation_manager.compute()
308+
self.obs_buf = self.observation_manager.compute(update_history=True)
309309

310310
if self.cfg.wait_for_textures and self.sim.has_rtx_sensors():
311311
while SimulationManager.assets_loading():
@@ -365,7 +365,7 @@ def reset_to(
365365
self.recorder_manager.record_post_reset(env_ids)
366366

367367
# compute observations
368-
self.obs_buf = self.observation_manager.compute()
368+
self.obs_buf = self.observation_manager.compute(update_history=True)
369369

370370
# return observations
371371
return self.obs_buf, self.extras
@@ -416,7 +416,7 @@ def step(self, action: torch.Tensor) -> tuple[VecEnvObs, dict]:
416416
self.event_manager.apply(mode="interval", dt=self.step_dt)
417417

418418
# -- compute observations
419-
self.obs_buf = self.observation_manager.compute()
419+
self.obs_buf = self.observation_manager.compute(update_history=True)
420420
self.recorder_manager.record_post_step()
421421

422422
# return observations and extras

source/isaaclab/isaaclab/envs/manager_based_rl_env.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ def step(self, action: torch.Tensor) -> VecEnvStepReturn:
237237
self.event_manager.apply(mode="interval", dt=self.step_dt)
238238
# -- compute observations
239239
# note: done after reset to get the correct observations for reset envs
240-
self.obs_buf = self.observation_manager.compute()
240+
self.obs_buf = self.observation_manager.compute(update_history=True)
241241

242242
# return observations, rewards, resets and extras
243243
return self.obs_buf, self.reward_buf, self.reset_terminated, self.reset_time_outs, self.extras

source/isaaclab/isaaclab/managers/observation_manager.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -245,12 +245,17 @@ def reset(self, env_ids: Sequence[int] | None = None) -> dict[str, float]:
245245
# nothing to log here
246246
return {}
247247

248-
def compute(self) -> dict[str, torch.Tensor | dict[str, torch.Tensor]]:
248+
def compute(self, update_history: bool = False) -> dict[str, torch.Tensor | dict[str, torch.Tensor]]:
249249
"""Compute the observations per group for all groups.
250250
251251
The method computes the observations for all the groups handled by the observation manager.
252252
Please check the :meth:`compute_group` on the processing of observations per group.
253253
254+
Args:
255+
update_history: The boolean indicator without return obs should be appended to observation history.
256+
Default to False, in which case calling compute_group does not modify history. This input is no-ops
257+
if the group's history_length == 0.
258+
254259
Returns:
255260
A dictionary with keys as the group names and values as the computed observations.
256261
The observations are either concatenated into a single tensor or returned as a dictionary
@@ -260,14 +265,14 @@ def compute(self) -> dict[str, torch.Tensor | dict[str, torch.Tensor]]:
260265
obs_buffer = dict()
261266
# iterate over all the terms in each group
262267
for group_name in self._group_obs_term_names:
263-
obs_buffer[group_name] = self.compute_group(group_name)
268+
obs_buffer[group_name] = self.compute_group(group_name, update_history=update_history)
264269
# otherwise return a dict with observations of all groups
265270

266271
# Cache the observations.
267272
self._obs_buffer = obs_buffer
268273
return obs_buffer
269274

270-
def compute_group(self, group_name: str) -> torch.Tensor | dict[str, torch.Tensor]:
275+
def compute_group(self, group_name: str, update_history: bool = False) -> torch.Tensor | dict[str, torch.Tensor]:
271276
"""Computes the observations for a given group.
272277
273278
The observations for a given group are computed by calling the registered functions for each
@@ -290,6 +295,9 @@ def compute_group(self, group_name: str) -> torch.Tensor | dict[str, torch.Tenso
290295
Args:
291296
group_name: The name of the group for which to compute the observations. Defaults to None,
292297
in which case observations for all the groups are computed and returned.
298+
update_history: The boolean indicator without return obs should be appended to observation group's history.
299+
Default to False, in which case calling compute_group does not modify history. This input is no-ops
300+
if the group's history_length == 0.
293301
294302
Returns:
295303
Depending on the group's configuration, the tensors for individual observation terms are
@@ -330,13 +338,23 @@ def compute_group(self, group_name: str) -> torch.Tensor | dict[str, torch.Tenso
330338
obs = obs.mul_(term_cfg.scale)
331339
# Update the history buffer if observation term has history enabled
332340
if term_cfg.history_length > 0:
333-
self._group_obs_term_history_buffer[group_name][term_name].append(obs)
334-
if term_cfg.flatten_history_dim:
335-
group_obs[term_name] = self._group_obs_term_history_buffer[group_name][term_name].buffer.reshape(
336-
self._env.num_envs, -1
341+
circular_buffer = self._group_obs_term_history_buffer[group_name][term_name]
342+
if update_history:
343+
circular_buffer.append(obs)
344+
elif circular_buffer._buffer is None:
345+
# because circular buffer only exits after the simulation steps,
346+
# this guards history buffer from corruption by external calls before simulation start
347+
circular_buffer = CircularBuffer(
348+
max_len=circular_buffer.max_length,
349+
batch_size=circular_buffer.batch_size,
350+
device=circular_buffer.device,
337351
)
352+
circular_buffer.append(obs)
353+
354+
if term_cfg.flatten_history_dim:
355+
group_obs[term_name] = circular_buffer.buffer.reshape(self._env.num_envs, -1)
338356
else:
339-
group_obs[term_name] = self._group_obs_term_history_buffer[group_name][term_name].buffer
357+
group_obs[term_name] = circular_buffer.buffer
340358
else:
341359
group_obs[term_name] = obs
342360

source/isaaclab/test/envs/test_manager_based_env.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
import pytest
2424

2525
from isaaclab.envs import ManagerBasedEnv, ManagerBasedEnvCfg
26+
from isaaclab.managers import ObservationGroupCfg as ObsGroup
27+
from isaaclab.managers import ObservationTermCfg as ObsTerm
2628
from isaaclab.scene import InteractiveSceneCfg
2729
from isaaclab.utils import configclass
2830

@@ -34,6 +36,22 @@ class EmptyManagerCfg:
3436
pass
3537

3638

39+
@configclass
40+
class EmptyObservationWithHistoryCfg:
41+
"""Empty observation with history specifications for the environment."""
42+
43+
@configclass
44+
class EmptyObservationGroupWithHistoryCfg(ObsGroup):
45+
"""Empty observation with history specifications for the environment."""
46+
47+
dummy_term: ObsTerm = ObsTerm(func=lambda env: torch.randn(env.num_envs, 1, device=env.device))
48+
49+
def __post_init__(self):
50+
self.history_length = 5
51+
52+
empty_observation: EmptyObservationGroupWithHistoryCfg = EmptyObservationGroupWithHistoryCfg()
53+
54+
3755
@configclass
3856
class EmptySceneCfg(InteractiveSceneCfg):
3957
"""Configuration for an empty scene."""
@@ -67,6 +85,32 @@ def __post_init__(self):
6785
return EmptyEnvCfg()
6886

6987

88+
def get_empty_base_env_cfg_with_history(device: str = "cuda:0", num_envs: int = 1, env_spacing: float = 1.0):
89+
"""Generate base environment config based on device"""
90+
91+
@configclass
92+
class EmptyEnvWithHistoryCfg(ManagerBasedEnvCfg):
93+
"""Configuration for the empty test environment."""
94+
95+
# Scene settings
96+
scene: EmptySceneCfg = EmptySceneCfg(num_envs=num_envs, env_spacing=env_spacing)
97+
# Basic settings
98+
actions: EmptyManagerCfg = EmptyManagerCfg()
99+
observations: EmptyObservationWithHistoryCfg = EmptyObservationWithHistoryCfg()
100+
101+
def __post_init__(self):
102+
"""Post initialization."""
103+
# step settings
104+
self.decimation = 4 # env step every 4 sim steps: 200Hz / 4 = 50Hz
105+
# simulation settings
106+
self.sim.dt = 0.005 # sim step every 5ms: 200Hz
107+
self.sim.render_interval = self.decimation # render every 4 sim steps
108+
# pass device down from test
109+
self.sim.device = device
110+
111+
return EmptyEnvWithHistoryCfg()
112+
113+
70114
@pytest.mark.parametrize("device", ["cuda:0", "cpu"])
71115
def test_initialization(device):
72116
"""Test initialization of ManagerBasedEnv."""
@@ -90,3 +134,67 @@ def test_initialization(device):
90134
obs, ext = env.step(action=act)
91135
# close the environment
92136
env.close()
137+
138+
139+
@pytest.mark.parametrize("device", ["cuda:0", "cpu"])
140+
def test_observation_history_changes_only_after_step(device):
141+
"""Test observation history of ManagerBasedEnv.
142+
143+
The history buffer should only change after a step is taken.
144+
"""
145+
# create a new stage
146+
omni.usd.get_context().new_stage()
147+
# create environment with history length of 5
148+
env = ManagerBasedEnv(cfg=get_empty_base_env_cfg_with_history(device=device))
149+
150+
# check if history buffer is empty
151+
for group_name in env.observation_manager._group_obs_term_names:
152+
group_term_names = env.observation_manager._group_obs_term_names[group_name]
153+
for term_name in group_term_names:
154+
torch.testing.assert_close(
155+
env.observation_manager._group_obs_term_history_buffer[group_name][term_name].current_length,
156+
torch.zeros((env.num_envs,), device=device, dtype=torch.int64),
157+
)
158+
159+
# check if history buffer is empty after compute
160+
env.observation_manager.compute()
161+
for group_name in env.observation_manager._group_obs_term_names:
162+
group_term_names = env.observation_manager._group_obs_term_names[group_name]
163+
for term_name in group_term_names:
164+
torch.testing.assert_close(
165+
env.observation_manager._group_obs_term_history_buffer[group_name][term_name].current_length,
166+
torch.zeros((env.num_envs,), device=device, dtype=torch.int64),
167+
)
168+
169+
# check if history buffer is not empty after step
170+
act = torch.randn_like(env.action_manager.action)
171+
env.step(act)
172+
group_obs = dict()
173+
for group_name in env.observation_manager._group_obs_term_names:
174+
group_term_names = env.observation_manager._group_obs_term_names[group_name]
175+
group_obs[group_name] = dict()
176+
for term_name in group_term_names:
177+
torch.testing.assert_close(
178+
env.observation_manager._group_obs_term_history_buffer[group_name][term_name].current_length,
179+
torch.ones((env.num_envs,), device=device, dtype=torch.int64),
180+
)
181+
group_obs[group_name][term_name] = env.observation_manager._group_obs_term_history_buffer[group_name][
182+
term_name
183+
].buffer
184+
185+
# check if history buffer is not empty after compute and is the same as the buffer after step
186+
env.observation_manager.compute()
187+
for group_name in env.observation_manager._group_obs_term_names:
188+
group_term_names = env.observation_manager._group_obs_term_names[group_name]
189+
for term_name in group_term_names:
190+
torch.testing.assert_close(
191+
env.observation_manager._group_obs_term_history_buffer[group_name][term_name].current_length,
192+
torch.ones((env.num_envs,), device=device, dtype=torch.int64),
193+
)
194+
assert torch.allclose(
195+
group_obs[group_name][term_name],
196+
env.observation_manager._group_obs_term_history_buffer[group_name][term_name].buffer,
197+
)
198+
199+
# close the environment
200+
env.close()

0 commit comments

Comments
 (0)