Skip to content

Commit 8ddc483

Browse files
kellyguo11DorsaRoh
andauthored
Adds position threshold check for state transitions (#1544)
# Description Adds a position threshold check, resolving 3 TODO error comments, to ensure the robot's end effector is within a specified distance from the target position before transitioning between states in the pick and lift state machine. Improves the precision of state transitions and helps prevent premature actions during object manipulation. I.e, the threshold ensures the robot is "close enough" to the target position before proceeding, reducing the likelihood of failed grasps or incorrect movements. PR adapted from #1273 by @DorsaRoh. ## Type of change <!-- As you go through the list, delete the ones that are not applicable. --> - Bug fix (non-breaking change which fixes an issue) - New feature (non-breaking change which adds functionality) ## Checklist - [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with `./isaaclab.sh --format` - [ ] I have made corresponding changes to the documentation - [x] My changes generate no new warnings - [ ] I have added tests that prove my fix is effective or that my feature works - [x] I have updated the changelog and the corresponding version in the extension's `config/extension.toml` file - [ ] I have added my name to the `CONTRIBUTORS.md` or my name already exists there <!-- As you go through the checklist above, you can mark something as done by putting an x character in it For example, - [x] I have done this task - [ ] I have not done this task --> --------- Signed-off-by: Kelly Guo <kellyg@nvidia.com> Co-authored-by: DorsaRoh <dorsa.rohani@gmail.com>
1 parent f01c6f9 commit 8ddc483

File tree

5 files changed

+120
-54
lines changed

5 files changed

+120
-54
lines changed

source/extensions/omni.isaac.lab/config/extension.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[package]
22

33
# Note: Semantic Versioning is used: https://semver.org/
4-
version = "0.27.27"
4+
version = "0.27.28"
55

66
# Description
77
title = "Isaac Lab framework for Robot Learning"

source/extensions/omni.isaac.lab/docs/CHANGELOG.rst

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,15 @@
11
Changelog
22
---------
33

4+
0.27.28 (2024-12-14)
5+
~~~~~~~~~~~~~~~~~~~~
6+
7+
Changed
8+
^^^^^^^
9+
10+
* Added check for error below threshold in state machines to ensure the state has been reached.
11+
12+
413
0.27.27 (2024-12-13)
514
~~~~~~~~~~~~~~~~~~~~
615

source/standalone/environments/state_machine/lift_cube_sm.py

Lines changed: 42 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,11 @@ class PickSmWaitTime:
8181
LIFT_OBJECT = wp.constant(1.0)
8282

8383

84+
@wp.func
85+
def distance_below_threshold(current_pos: wp.vec3, desired_pos: wp.vec3, threshold: float) -> bool:
86+
return wp.length(current_pos - desired_pos) < threshold
87+
88+
8489
@wp.kernel
8590
def infer_state_machine(
8691
dt: wp.array(dtype=float),
@@ -92,6 +97,7 @@ def infer_state_machine(
9297
des_ee_pose: wp.array(dtype=wp.transform),
9398
gripper_state: wp.array(dtype=float),
9499
offset: wp.array(dtype=wp.transform),
100+
position_threshold: float,
95101
):
96102
# retrieve thread id
97103
tid = wp.tid()
@@ -109,21 +115,28 @@ def infer_state_machine(
109115
elif state == PickSmState.APPROACH_ABOVE_OBJECT:
110116
des_ee_pose[tid] = wp.transform_multiply(offset[tid], object_pose[tid])
111117
gripper_state[tid] = GripperState.OPEN
112-
# TODO: error between current and desired ee pose below threshold
113-
# wait for a while
114-
if sm_wait_time[tid] >= PickSmWaitTime.APPROACH_OBJECT:
115-
# move to next state and reset wait time
116-
sm_state[tid] = PickSmState.APPROACH_OBJECT
117-
sm_wait_time[tid] = 0.0
118+
if distance_below_threshold(
119+
wp.transform_get_translation(ee_pose[tid]),
120+
wp.transform_get_translation(des_ee_pose[tid]),
121+
position_threshold,
122+
):
123+
# wait for a while
124+
if sm_wait_time[tid] >= PickSmWaitTime.APPROACH_OBJECT:
125+
# move to next state and reset wait time
126+
sm_state[tid] = PickSmState.APPROACH_OBJECT
127+
sm_wait_time[tid] = 0.0
118128
elif state == PickSmState.APPROACH_OBJECT:
119129
des_ee_pose[tid] = object_pose[tid]
120130
gripper_state[tid] = GripperState.OPEN
121-
# TODO: error between current and desired ee pose below threshold
122-
# wait for a while
123-
if sm_wait_time[tid] >= PickSmWaitTime.APPROACH_OBJECT:
124-
# move to next state and reset wait time
125-
sm_state[tid] = PickSmState.GRASP_OBJECT
126-
sm_wait_time[tid] = 0.0
131+
if distance_below_threshold(
132+
wp.transform_get_translation(ee_pose[tid]),
133+
wp.transform_get_translation(des_ee_pose[tid]),
134+
position_threshold,
135+
):
136+
if sm_wait_time[tid] >= PickSmWaitTime.APPROACH_OBJECT:
137+
# move to next state and reset wait time
138+
sm_state[tid] = PickSmState.GRASP_OBJECT
139+
sm_wait_time[tid] = 0.0
127140
elif state == PickSmState.GRASP_OBJECT:
128141
des_ee_pose[tid] = object_pose[tid]
129142
gripper_state[tid] = GripperState.CLOSE
@@ -135,12 +148,16 @@ def infer_state_machine(
135148
elif state == PickSmState.LIFT_OBJECT:
136149
des_ee_pose[tid] = des_object_pose[tid]
137150
gripper_state[tid] = GripperState.CLOSE
138-
# TODO: error between current and desired ee pose below threshold
139-
# wait for a while
140-
if sm_wait_time[tid] >= PickSmWaitTime.LIFT_OBJECT:
141-
# move to next state and reset wait time
142-
sm_state[tid] = PickSmState.LIFT_OBJECT
143-
sm_wait_time[tid] = 0.0
151+
if distance_below_threshold(
152+
wp.transform_get_translation(ee_pose[tid]),
153+
wp.transform_get_translation(des_ee_pose[tid]),
154+
position_threshold,
155+
):
156+
# wait for a while
157+
if sm_wait_time[tid] >= PickSmWaitTime.LIFT_OBJECT:
158+
# move to next state and reset wait time
159+
sm_state[tid] = PickSmState.LIFT_OBJECT
160+
sm_wait_time[tid] = 0.0
144161
# increment wait time
145162
sm_wait_time[tid] = sm_wait_time[tid] + dt[tid]
146163

@@ -160,7 +177,7 @@ class PickAndLiftSm:
160177
5. LIFT_OBJECT: The robot lifts the object to the desired pose. This is the final state.
161178
"""
162179

163-
def __init__(self, dt: float, num_envs: int, device: torch.device | str = "cpu"):
180+
def __init__(self, dt: float, num_envs: int, device: torch.device | str = "cpu", position_threshold=0.01):
164181
"""Initialize the state machine.
165182
166183
Args:
@@ -172,6 +189,7 @@ def __init__(self, dt: float, num_envs: int, device: torch.device | str = "cpu")
172189
self.dt = float(dt)
173190
self.num_envs = num_envs
174191
self.device = device
192+
self.position_threshold = position_threshold
175193
# initialize state machine
176194
self.sm_dt = torch.full((self.num_envs,), self.dt, device=self.device)
177195
self.sm_state = torch.full((self.num_envs,), 0, dtype=torch.int32, device=self.device)
@@ -201,7 +219,7 @@ def reset_idx(self, env_ids: Sequence[int] = None):
201219
self.sm_state[env_ids] = 0
202220
self.sm_wait_time[env_ids] = 0.0
203221

204-
def compute(self, ee_pose: torch.Tensor, object_pose: torch.Tensor, des_object_pose: torch.Tensor):
222+
def compute(self, ee_pose: torch.Tensor, object_pose: torch.Tensor, des_object_pose: torch.Tensor) -> torch.Tensor:
205223
"""Compute the desired state of the robot's end-effector and the gripper."""
206224
# convert all transformations from (w, x, y, z) to (x, y, z, w)
207225
ee_pose = ee_pose[:, [0, 1, 2, 4, 5, 6, 3]]
@@ -227,6 +245,7 @@ def compute(self, ee_pose: torch.Tensor, object_pose: torch.Tensor, des_object_p
227245
self.des_ee_pose_wp,
228246
self.des_gripper_state_wp,
229247
self.offset_wp,
248+
self.position_threshold,
230249
],
231250
device=self.device,
232251
)
@@ -257,7 +276,9 @@ def main():
257276
desired_orientation = torch.zeros((env.unwrapped.num_envs, 4), device=env.unwrapped.device)
258277
desired_orientation[:, 1] = 1.0
259278
# create state machine
260-
pick_sm = PickAndLiftSm(env_cfg.sim.dt * env_cfg.decimation, env.unwrapped.num_envs, env.unwrapped.device)
279+
pick_sm = PickAndLiftSm(
280+
env_cfg.sim.dt * env_cfg.decimation, env.unwrapped.num_envs, env.unwrapped.device, position_threshold=0.01
281+
)
261282

262283
while simulation_app.is_running():
263284
# run everything in inference mode

source/standalone/environments/state_machine/lift_teddy_bear.py

Lines changed: 39 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,11 @@ class PickSmWaitTime:
8080
OPEN_GRIPPER = wp.constant(0.0)
8181

8282

83+
@wp.func
84+
def distance_below_threshold(current_pos: wp.vec3, desired_pos: wp.vec3, threshold: float) -> bool:
85+
return wp.length(current_pos - desired_pos) < threshold
86+
87+
8388
@wp.kernel
8489
def infer_state_machine(
8590
dt: wp.array(dtype=float),
@@ -91,6 +96,7 @@ def infer_state_machine(
9196
des_ee_pose: wp.array(dtype=wp.transform),
9297
gripper_state: wp.array(dtype=float),
9398
offset: wp.array(dtype=wp.transform),
99+
position_threshold: float,
94100
):
95101
# retrieve thread id
96102
tid = wp.tid()
@@ -108,21 +114,29 @@ def infer_state_machine(
108114
elif state == PickSmState.APPROACH_ABOVE_OBJECT:
109115
des_ee_pose[tid] = wp.transform_multiply(offset[tid], object_pose[tid])
110116
gripper_state[tid] = GripperState.OPEN
111-
# TODO: error between current and desired ee pose below threshold
112-
# wait for a while
113-
if sm_wait_time[tid] >= PickSmWaitTime.APPROACH_OBJECT:
114-
# move to next state and reset wait time
115-
sm_state[tid] = PickSmState.APPROACH_OBJECT
116-
sm_wait_time[tid] = 0.0
117+
if distance_below_threshold(
118+
wp.transform_get_translation(ee_pose[tid]),
119+
wp.transform_get_translation(des_ee_pose[tid]),
120+
position_threshold,
121+
):
122+
# wait for a while
123+
if sm_wait_time[tid] >= PickSmWaitTime.APPROACH_OBJECT:
124+
# move to next state and reset wait time
125+
sm_state[tid] = PickSmState.APPROACH_OBJECT
126+
sm_wait_time[tid] = 0.0
117127
elif state == PickSmState.APPROACH_OBJECT:
118128
des_ee_pose[tid] = object_pose[tid]
119129
gripper_state[tid] = GripperState.OPEN
120-
# TODO: error between current and desired ee pose below threshold
121-
# wait for a while
122-
if sm_wait_time[tid] >= PickSmWaitTime.APPROACH_OBJECT:
123-
# move to next state and reset wait time
124-
sm_state[tid] = PickSmState.GRASP_OBJECT
125-
sm_wait_time[tid] = 0.0
130+
if distance_below_threshold(
131+
wp.transform_get_translation(ee_pose[tid]),
132+
wp.transform_get_translation(des_ee_pose[tid]),
133+
position_threshold,
134+
):
135+
# wait for a while
136+
if sm_wait_time[tid] >= PickSmWaitTime.APPROACH_OBJECT:
137+
# move to next state and reset wait time
138+
sm_state[tid] = PickSmState.GRASP_OBJECT
139+
sm_wait_time[tid] = 0.0
126140
elif state == PickSmState.GRASP_OBJECT:
127141
des_ee_pose[tid] = object_pose[tid]
128142
gripper_state[tid] = GripperState.CLOSE
@@ -134,12 +148,16 @@ def infer_state_machine(
134148
elif state == PickSmState.LIFT_OBJECT:
135149
des_ee_pose[tid] = des_object_pose[tid]
136150
gripper_state[tid] = GripperState.CLOSE
137-
# TODO: error between current and desired ee pose below threshold
138-
# wait for a while
139-
if sm_wait_time[tid] >= PickSmWaitTime.LIFT_OBJECT:
140-
# move to next state and reset wait time
141-
sm_state[tid] = PickSmState.OPEN_GRIPPER
142-
sm_wait_time[tid] = 0.0
151+
if distance_below_threshold(
152+
wp.transform_get_translation(ee_pose[tid]),
153+
wp.transform_get_translation(des_ee_pose[tid]),
154+
position_threshold,
155+
):
156+
# wait for a while
157+
if sm_wait_time[tid] >= PickSmWaitTime.LIFT_OBJECT:
158+
# move to next state and reset wait time
159+
sm_state[tid] = PickSmState.OPEN_GRIPPER
160+
sm_wait_time[tid] = 0.0
143161
elif state == PickSmState.OPEN_GRIPPER:
144162
# des_ee_pose[tid] = object_pose[tid]
145163
gripper_state[tid] = GripperState.OPEN
@@ -167,7 +185,7 @@ class PickAndLiftSm:
167185
5. LIFT_OBJECT: The robot lifts the object to the desired pose. This is the final state.
168186
"""
169187

170-
def __init__(self, dt: float, num_envs: int, device: torch.device | str = "cpu"):
188+
def __init__(self, dt: float, num_envs: int, device: torch.device | str = "cpu", position_threshold=0.01):
171189
"""Initialize the state machine.
172190
173191
Args:
@@ -179,6 +197,7 @@ def __init__(self, dt: float, num_envs: int, device: torch.device | str = "cpu")
179197
self.dt = float(dt)
180198
self.num_envs = num_envs
181199
self.device = device
200+
self.position_threshold = position_threshold
182201
# initialize state machine
183202
self.sm_dt = torch.full((self.num_envs,), self.dt, device=self.device)
184203
self.sm_state = torch.full((self.num_envs,), 0, dtype=torch.int32, device=self.device)
@@ -234,6 +253,7 @@ def compute(self, ee_pose: torch.Tensor, object_pose: torch.Tensor, des_object_p
234253
self.des_ee_pose_wp,
235254
self.des_gripper_state_wp,
236255
self.offset_wp,
256+
self.position_threshold,
237257
],
238258
device=self.device,
239259
)

source/standalone/environments/state_machine/open_cabinet_sm.py

Lines changed: 29 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,11 @@ class OpenDrawerSmWaitTime:
8383
RELEASE_HANDLE = wp.constant(0.2)
8484

8585

86+
@wp.func
87+
def distance_below_threshold(current_pos: wp.vec3, desired_pos: wp.vec3, threshold: float) -> bool:
88+
return wp.length(current_pos - desired_pos) < threshold
89+
90+
8691
@wp.kernel
8792
def infer_state_machine(
8893
dt: wp.array(dtype=float),
@@ -95,6 +100,7 @@ def infer_state_machine(
95100
handle_approach_offset: wp.array(dtype=wp.transform),
96101
handle_grasp_offset: wp.array(dtype=wp.transform),
97102
drawer_opening_rate: wp.array(dtype=wp.transform),
103+
position_threshold: float,
98104
):
99105
# retrieve thread id
100106
tid = wp.tid()
@@ -112,21 +118,29 @@ def infer_state_machine(
112118
elif state == OpenDrawerSmState.APPROACH_INFRONT_HANDLE:
113119
des_ee_pose[tid] = wp.transform_multiply(handle_approach_offset[tid], handle_pose[tid])
114120
gripper_state[tid] = GripperState.OPEN
115-
# TODO: error between current and desired ee pose below threshold
116-
# wait for a while
117-
if sm_wait_time[tid] >= OpenDrawerSmWaitTime.APPROACH_INFRONT_HANDLE:
118-
# move to next state and reset wait time
119-
sm_state[tid] = OpenDrawerSmState.APPROACH_HANDLE
120-
sm_wait_time[tid] = 0.0
121+
if distance_below_threshold(
122+
wp.transform_get_translation(ee_pose[tid]),
123+
wp.transform_get_translation(des_ee_pose[tid]),
124+
position_threshold,
125+
):
126+
# wait for a while
127+
if sm_wait_time[tid] >= OpenDrawerSmWaitTime.APPROACH_INFRONT_HANDLE:
128+
# move to next state and reset wait time
129+
sm_state[tid] = OpenDrawerSmState.APPROACH_HANDLE
130+
sm_wait_time[tid] = 0.0
121131
elif state == OpenDrawerSmState.APPROACH_HANDLE:
122132
des_ee_pose[tid] = handle_pose[tid]
123133
gripper_state[tid] = GripperState.OPEN
124-
# TODO: error between current and desired ee pose below threshold
125-
# wait for a while
126-
if sm_wait_time[tid] >= OpenDrawerSmWaitTime.APPROACH_HANDLE:
127-
# move to next state and reset wait time
128-
sm_state[tid] = OpenDrawerSmState.GRASP_HANDLE
129-
sm_wait_time[tid] = 0.0
134+
if distance_below_threshold(
135+
wp.transform_get_translation(ee_pose[tid]),
136+
wp.transform_get_translation(des_ee_pose[tid]),
137+
position_threshold,
138+
):
139+
# wait for a while
140+
if sm_wait_time[tid] >= OpenDrawerSmWaitTime.APPROACH_HANDLE:
141+
# move to next state and reset wait time
142+
sm_state[tid] = OpenDrawerSmState.GRASP_HANDLE
143+
sm_wait_time[tid] = 0.0
130144
elif state == OpenDrawerSmState.GRASP_HANDLE:
131145
des_ee_pose[tid] = wp.transform_multiply(handle_grasp_offset[tid], handle_pose[tid])
132146
gripper_state[tid] = GripperState.CLOSE
@@ -170,7 +184,7 @@ class OpenDrawerSm:
170184
5. RELEASE_HANDLE: The robot releases the handle of the drawer. This is the final state.
171185
"""
172186

173-
def __init__(self, dt: float, num_envs: int, device: torch.device | str = "cpu"):
187+
def __init__(self, dt: float, num_envs: int, device: torch.device | str = "cpu", position_threshold=0.01):
174188
"""Initialize the state machine.
175189
176190
Args:
@@ -182,6 +196,7 @@ def __init__(self, dt: float, num_envs: int, device: torch.device | str = "cpu")
182196
self.dt = float(dt)
183197
self.num_envs = num_envs
184198
self.device = device
199+
self.position_threshold = position_threshold
185200
# initialize state machine
186201
self.sm_dt = torch.full((self.num_envs,), self.dt, device=self.device)
187202
self.sm_state = torch.full((self.num_envs,), 0, dtype=torch.int32, device=self.device)
@@ -248,6 +263,7 @@ def compute(self, ee_pose: torch.Tensor, handle_pose: torch.Tensor):
248263
self.handle_approach_offset_wp,
249264
self.handle_grasp_offset_wp,
250265
self.drawer_opening_rate_wp,
266+
self.position_threshold,
251267
],
252268
device=self.device,
253269
)

0 commit comments

Comments
 (0)