From 7b20de0c702046c50c0b51bb95484b0240e99863 Mon Sep 17 00:00:00 2001 From: Giulio Romualdi Date: Wed, 7 May 2025 20:27:26 +0200 Subject: [PATCH] Fix joint_pos_out_limit and joint_pos_out_of_manual_limit terminations --- .../isaaclab/envs/mdp/terminations.py | 37 ++++++++++++------- 1 file changed, 24 insertions(+), 13 deletions(-) diff --git a/source/isaaclab/isaaclab/envs/mdp/terminations.py b/source/isaaclab/isaaclab/envs/mdp/terminations.py index 20eadf6417a..dc0bd892128 100644 --- a/source/isaaclab/isaaclab/envs/mdp/terminations.py +++ b/source/isaaclab/isaaclab/envs/mdp/terminations.py @@ -78,13 +78,23 @@ def root_height_below_minimum( def joint_pos_out_of_limit(env: ManagerBasedRLEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> torch.Tensor: - """Terminate when the asset's joint positions are outside of the soft joint limits.""" - # extract the used quantities (to enable type-hinting) + """ + Return a boolean tensor of shape `[num_envs]`: + `True` if *any* of the (optionally selected) joints in an env are + outside their soft joint‑position limits. + """ asset: Articulation = env.scene[asset_cfg.name] - # compute any violations - out_of_upper_limits = torch.any(asset.data.joint_pos > asset.data.soft_joint_pos_limits[..., 1], dim=1) - out_of_lower_limits = torch.any(asset.data.joint_pos < asset.data.soft_joint_pos_limits[..., 0], dim=1) - return torch.logical_or(out_of_upper_limits[:, asset_cfg.joint_ids], out_of_lower_limits[:, asset_cfg.joint_ids]) + + # Per‑env‑per‑joint mask of violations + joint_pos = asset.data.joint_pos # [N_envs, N_joints] + joint_lower = asset.data.soft_joint_pos_limits[..., 0] + joint_upper = asset.data.soft_joint_pos_limits[..., 1] + violations = (joint_pos < joint_lower) | (joint_pos > joint_upper) + + joint_ids = asset_cfg.joint_ids if asset_cfg.joint_ids is not None else slice(None) + + # Reduce over *selected* joints → [N_envs] + return torch.any(violations[:, joint_ids], dim=1) def joint_pos_out_of_manual_limit( @@ -95,14 +105,15 @@ def joint_pos_out_of_manual_limit( Note: This function is similar to :func:`joint_pos_out_of_limit` but allows the user to specify the bounds manually. """ - # extract the used quantities (to enable type-hinting) asset: Articulation = env.scene[asset_cfg.name] - if asset_cfg.joint_ids is None: - asset_cfg.joint_ids = slice(None) - # compute any violations - out_of_upper_limits = torch.any(asset.data.joint_pos[:, asset_cfg.joint_ids] > bounds[1], dim=1) - out_of_lower_limits = torch.any(asset.data.joint_pos[:, asset_cfg.joint_ids] < bounds[0], dim=1) - return torch.logical_or(out_of_upper_limits, out_of_lower_limits) + + joint_ids = asset_cfg.joint_ids if asset_cfg.joint_ids is not None else slice(None) + joint_pos = asset.data.joint_pos[:, joint_ids] # [N_envs, N_selected] + + violations = (joint_pos < bounds[0]) | (joint_pos > bounds[1]) + + # Reduce over joints → [N_envs] + return torch.any(violations, dim=1) def joint_vel_out_of_limit(env: ManagerBasedRLEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> torch.Tensor: