Skip to content

Commit f7b59b3

Browse files
jtigue-bdaifyu-bdaikellyguo11
authored
Adds observation term history support to Observation Manager (#1439)
# Description <!-- Thank you for your interest in sending a pull request. Please make sure to check the contribution guidelines. Link: https://isaac-sim.github.io/IsaacLab/source/refs/contributing.html --> This PR adds observation history by adding configuration parameters to ObservationTerms and having the ObservationManager handling the collection and storage of the histories via CircularBuffers. Fixes #1208 <!-- As a practice, it is recommended to open an issue to have discussions on the proposed pull request. This makes it easier for the community to keep track of what is being developed or added, and if a given feature is demanded by more than one party. --> ## Type of change <!-- As you go through the list, delete the ones that are not applicable. --> - New feature (non-breaking change which adds functionality) ## Checklist - [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with `./isaaclab.sh --format` - [x] I have made corresponding changes to the documentation - [x] My changes generate no new warnings - [x] I have added tests that prove my fix is effective or that my feature works - [x] I have updated the changelog and the corresponding version in the extension's `config/extension.toml` file - [x] I have added my name to the `CONTRIBUTORS.md` or my name already exists there <!-- As you go through the checklist above, you can mark something as done by putting an x character in it For example, - [x] I have done this task - [ ] I have not done this task --> --------- Signed-off-by: Kelly Guo <kellyg@nvidia.com> Co-authored-by: Fangzhou Yu <156015326+fyu-bdai@users.noreply.github.com> Co-authored-by: Kelly Guo <kellyg@nvidia.com>
1 parent ee3f022 commit f7b59b3

File tree

7 files changed

+352
-13
lines changed

7 files changed

+352
-13
lines changed

source/extensions/omni.isaac.lab/config/extension.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[package]
22

33
# Note: Semantic Versioning is used: https://semver.org/
4-
version = "0.27.29"
4+
version = "0.28.0"
55

66
# Description
77
title = "Isaac Lab framework for Robot Learning"

source/extensions/omni.isaac.lab/docs/CHANGELOG.rst

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,18 @@
11
Changelog
22
---------
33

4+
0.28.0 (2024-12-15)
5+
~~~~~~~~~~~~~~~~~~~
6+
7+
Added
8+
^^^^^
9+
10+
* Added observation history computation to :class:`omni.isaac.lab.manager.observation_manager.ObservationManager`.
11+
* Added ``history_length`` and ``flatten_history_dim`` configuration parameters to :class:`omni.isaac.lab.manager.manager_term_cfg.ObservationTermCfg`
12+
* Added ``history_length`` and ``flatten_history_dim`` configuration parameters to :class:`omni.isaac.lab.manager.manager_term_cfg.ObservationGroupCfg`
13+
* Added full buffer property to :class:`omni.isaac.lab.utils.buffers.circular_buffer.CircularBuffer`
14+
15+
416
0.27.29 (2024-12-15)
517
~~~~~~~~~~~~~~~~~~~~
618

source/extensions/omni.isaac.lab/omni/isaac/lab/managers/manager_term_cfg.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,19 @@ class ObservationTermCfg(ManagerTermBaseCfg):
180180
please make sure the length of the tuple matches the dimensions of the tensor outputted from the term.
181181
"""
182182

183+
history_length: int = 0
184+
"""Number of past observations to store in the observation buffers. Defaults to 0, meaning no history.
185+
186+
Observation history initializes to empty, but is filled with the first append after reset or initialization. Subsequent history
187+
only adds a single entry to the history buffer. If flatten_history_dim is set to True, the source data of shape
188+
(N, H, D, ...) where N is the batch dimension and H is the history length will be reshaped to a 2D tensor of shape
189+
(N, H*D*...). Otherwise, the data will be returned as is.
190+
"""
191+
192+
flatten_history_dim: bool = True
193+
"""Whether or not the observation manager should flatten history-based observation terms to a 2D (N, D) tensor.
194+
Defaults to True."""
195+
183196

184197
@configclass
185198
class ObservationGroupCfg:
@@ -201,6 +214,22 @@ class ObservationGroupCfg:
201214
Otherwise, no corruption is applied.
202215
"""
203216

217+
history_length: int | None = None
218+
"""Number of past observation to store in the observation buffers for all observation terms in group.
219+
220+
This parameter will override :attr:`ObservationTermCfg.history_length` if set. Defaults to None. If None, each
221+
terms history will be controlled on a per term basis. See :class:`ObservationTermCfg` for details on history_length
222+
implementation.
223+
"""
224+
225+
flatten_history_dim: bool = True
226+
"""Flag to flatten history-based observation terms to a 2D (num_env, D) tensor for all observation terms in group.
227+
Defaults to True.
228+
229+
This parameter will override all :attr:`ObservationTermCfg.flatten_history_dim` in the group if
230+
ObservationGroupCfg.history_length is set.
231+
"""
232+
204233

205234
##
206235
# Event manager

source/extensions/omni.isaac.lab/omni/isaac/lab/managers/observation_manager.py

Lines changed: 45 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,14 @@
88
from __future__ import annotations
99

1010
import inspect
11+
import numpy as np
1112
import torch
1213
from collections.abc import Sequence
1314
from prettytable import PrettyTable
1415
from typing import TYPE_CHECKING
1516

1617
from omni.isaac.lab.utils import modifiers
18+
from omni.isaac.lab.utils.buffers import CircularBuffer
1719

1820
from .manager_base import ManagerBase, ManagerTermBase
1921
from .manager_term_cfg import ObservationGroupCfg, ObservationTermCfg
@@ -45,6 +47,11 @@ class ObservationManager(ManagerBase):
4547
concatenated. In this case, please set the :attr:`ObservationGroupCfg.concatenate_terms` attribute in the
4648
group configuration to False.
4749
50+
Observations can also have history. This means a running history is updated per sim step. History can be controlled
51+
per :class:`ObservationTermCfg` (See the :attr:`ObservationTermCfg.history_length` and
52+
:attr:`ObservationTermCfg.flatten_history_dim`). History can also be controlled via :class:`ObservationGroupCfg`
53+
where group configuration overwrites per term configuration if set. History follows an oldest to newest ordering.
54+
4855
The observation manager can be used to compute observations for all the groups or for a specific group. The
4956
observations are computed by calling the registered functions for each term in the group. The functions are
5057
called in the order of the terms in the group. The functions are expected to return a tensor with shape
@@ -174,12 +181,17 @@ def group_obs_concatenate(self) -> dict[str, bool]:
174181

175182
def reset(self, env_ids: Sequence[int] | None = None) -> dict[str, float]:
176183
# call all terms that are classes
177-
for group_cfg in self._group_obs_class_term_cfgs.values():
184+
for group_name, group_cfg in self._group_obs_class_term_cfgs.items():
178185
for term_cfg in group_cfg:
179186
term_cfg.func.reset(env_ids=env_ids)
187+
# reset terms with history
188+
for term_name in self._group_obs_term_names[group_name]:
189+
if term_name in self._group_obs_term_history_buffer[group_name]:
190+
self._group_obs_term_history_buffer[group_name][term_name].reset(batch_ids=env_ids)
180191
# call all modifiers that are classes
181192
for mod in self._group_obs_class_modifiers:
182193
mod.reset(env_ids=env_ids)
194+
183195
# nothing to log here
184196
return {}
185197

@@ -248,7 +260,7 @@ def compute_group(self, group_name: str) -> torch.Tensor | dict[str, torch.Tenso
248260
obs_terms = zip(group_term_names, self._group_obs_term_cfgs[group_name])
249261

250262
# evaluate terms: compute, add noise, clip, scale, custom modifiers
251-
for name, term_cfg in obs_terms:
263+
for term_name, term_cfg in obs_terms:
252264
# compute term's value
253265
obs: torch.Tensor = term_cfg.func(self._env, **term_cfg.params).clone()
254266
# apply post-processing
@@ -261,8 +273,17 @@ def compute_group(self, group_name: str) -> torch.Tensor | dict[str, torch.Tenso
261273
obs = obs.clip_(min=term_cfg.clip[0], max=term_cfg.clip[1])
262274
if term_cfg.scale is not None:
263275
obs = obs.mul_(term_cfg.scale)
264-
# add value to list
265-
group_obs[name] = obs
276+
# Update the history buffer if observation term has history enabled
277+
if term_cfg.history_length > 0:
278+
self._group_obs_term_history_buffer[group_name][term_name].append(obs)
279+
if term_cfg.flatten_history_dim:
280+
group_obs[term_name] = self._group_obs_term_history_buffer[group_name][term_name].buffer.reshape(
281+
self._env.num_envs, -1
282+
)
283+
else:
284+
group_obs[term_name] = self._group_obs_term_history_buffer[group_name][term_name].buffer
285+
else:
286+
group_obs[term_name] = obs
266287

267288
# concatenate all observations in the group together
268289
if self._group_obs_concatenate[group_name]:
@@ -283,7 +304,7 @@ def _prepare_terms(self):
283304
self._group_obs_term_cfgs: dict[str, list[ObservationTermCfg]] = dict()
284305
self._group_obs_class_term_cfgs: dict[str, list[ObservationTermCfg]] = dict()
285306
self._group_obs_concatenate: dict[str, bool] = dict()
286-
307+
self._group_obs_term_history_buffer: dict[str, dict] = dict()
287308
# create a list to store modifiers that are classes
288309
# we store it as a separate list to only call reset on them and prevent unnecessary calls
289310
self._group_obs_class_modifiers: list[modifiers.ModifierBase] = list()
@@ -309,6 +330,7 @@ def _prepare_terms(self):
309330
self._group_obs_term_dim[group_name] = list()
310331
self._group_obs_term_cfgs[group_name] = list()
311332
self._group_obs_class_term_cfgs[group_name] = list()
333+
group_entry_history_buffer: dict[str, CircularBuffer] = dict()
312334
# read common config for the group
313335
self._group_obs_concatenate[group_name] = group_cfg.concatenate_terms
314336
# check if config is dict already
@@ -319,7 +341,7 @@ def _prepare_terms(self):
319341
# iterate over all the terms in each group
320342
for term_name, term_cfg in group_cfg_items:
321343
# skip non-obs settings
322-
if term_name in ["enable_corruption", "concatenate_terms"]:
344+
if term_name in ["enable_corruption", "concatenate_terms", "history_length", "flatten_history_dim"]:
323345
continue
324346
# check for non config
325347
if term_cfg is None:
@@ -335,12 +357,26 @@ def _prepare_terms(self):
335357
# check noise settings
336358
if not group_cfg.enable_corruption:
337359
term_cfg.noise = None
360+
# check group history params and override terms
361+
if group_cfg.history_length is not None:
362+
term_cfg.history_length = group_cfg.history_length
363+
term_cfg.flatten_history_dim = group_cfg.flatten_history_dim
338364
# add term config to list to list
339365
self._group_obs_term_names[group_name].append(term_name)
340366
self._group_obs_term_cfgs[group_name].append(term_cfg)
341-
342367
# call function the first time to fill up dimensions
343368
obs_dims = tuple(term_cfg.func(self._env, **term_cfg.params).shape)
369+
# create history buffers and calculate history term dimensions
370+
if term_cfg.history_length > 0:
371+
group_entry_history_buffer[term_name] = CircularBuffer(
372+
max_len=term_cfg.history_length, batch_size=self._env.num_envs, device=self._env.device
373+
)
374+
old_dims = list(obs_dims)
375+
old_dims.insert(1, term_cfg.history_length)
376+
obs_dims = tuple(old_dims)
377+
if term_cfg.flatten_history_dim:
378+
obs_dims = (obs_dims[0], np.prod(obs_dims[1:]))
379+
344380
self._group_obs_term_dim[group_name].append(obs_dims[1:])
345381

346382
# if scale is set, check if single float or tuple
@@ -411,3 +447,5 @@ def _prepare_terms(self):
411447
self._group_obs_class_term_cfgs[group_name].append(term_cfg)
412448
# call reset (in-case above call to get obs dims changed the state)
413449
term_cfg.func.reset()
450+
# add history buffers for each group
451+
self._group_obs_term_history_buffer[group_name] = group_entry_history_buffer

source/extensions/omni.isaac.lab/omni/isaac/lab/utils/buffers/circular_buffer.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,16 @@ def current_length(self) -> torch.Tensor:
7575
"""
7676
return torch.minimum(self._num_pushes, self._max_len)
7777

78+
@property
79+
def buffer(self) -> torch.Tensor:
80+
"""Complete circular buffer with most recent entry at the end and oldest entry at the beginning.
81+
Returns:
82+
Complete circular buffer with most recent entry at the end and oldest entry at the beginning of dimension 1. The shape is [batch_size, max_length, data.shape[1:]].
83+
"""
84+
buf = self._buffer.clone()
85+
buf = torch.roll(buf, shifts=self.max_length - self._pointer - 1, dims=0)
86+
return torch.transpose(buf, dim0=0, dim1=1)
87+
7888
"""
7989
Operations.
8090
"""
@@ -89,8 +99,10 @@ def reset(self, batch_ids: Sequence[int] | None = None):
8999
if batch_ids is None:
90100
batch_ids = slice(None)
91101
# reset the number of pushes for the specified batch indices
92-
# note: we don't need to reset the buffer since it will be overwritten. The pointer handles this.
93102
self._num_pushes[batch_ids] = 0
103+
if self._buffer is not None:
104+
# set buffer at batch_id reset indices to 0.0 so that the buffer() getter returns the cleared circular buffer after reset.
105+
self._buffer[:, batch_ids, :] = 0.0
94106

95107
def append(self, data: torch.Tensor):
96108
"""Append the data to the circular buffer.
@@ -106,15 +118,20 @@ def append(self, data: torch.Tensor):
106118
if data.shape[0] != self.batch_size:
107119
raise ValueError(f"The input data has {data.shape[0]} environments while expecting {self.batch_size}")
108120

109-
# at the fist call, initialize the buffer
121+
# at the first call, initialize the buffer size
110122
if self._buffer is None:
111123
self._pointer = -1
112124
self._buffer = torch.empty((self.max_length, *data.shape), dtype=data.dtype, device=self._device)
113125
# move the head to the next slot
114126
self._pointer = (self._pointer + 1) % self.max_length
115127
# add the new data to the last layer
116128
self._buffer[self._pointer] = data.to(self._device)
117-
# increment number of number of pushes
129+
# Check for batches with zero pushes and initialize all values in batch to first append
130+
if 0 in self._num_pushes.tolist():
131+
fill_ids = [i for i, x in enumerate(self._num_pushes.tolist()) if x == 0]
132+
self._num_pushes.tolist().index(0) if 0 in self._num_pushes.tolist() else None
133+
self._buffer[:, fill_ids, :] = data.to(self._device)[fill_ids]
134+
# increment number of number of pushes for all batches
118135
self._num_pushes += 1
119136

120137
def __getitem__(self, key: torch.Tensor) -> torch.Tensor:

0 commit comments

Comments
 (0)