@@ -172,7 +172,7 @@ def __init__(self, cfg: actions_cfg.DifferentialInverseKinematicsActionCfg, env:
172
172
out = torch .tensor (self .cfg .clip ["position" ][0 ], device = self .device )
173
173
print (self ._clip .shape , out .shape )
174
174
print (out )
175
-
175
+
176
176
self ._clip [:, 0 ] = torch .tensor (self .cfg .clip ["position" ][0 ], device = self .device )
177
177
self ._clip [:, 1 ] = torch .tensor (self .cfg .clip ["position" ][1 ], device = self .device )
178
178
self ._clip [:, 2 ] = torch .tensor (self .cfg .clip ["position" ][2 ], device = self .device )
@@ -245,7 +245,9 @@ def process_actions(self, actions: torch.Tensor):
245
245
ee_pos_curr , ee_quat_curr , self ._processed_actions
246
246
)
247
247
# Cast the target_quat_w to euler angles
248
- target_euler_angles_w = torch .transpose (torch .stack (math_utils .euler_xyz_from_quat (target_quat_w )), 0 , 1 )
248
+ target_euler_angles_w = torch .transpose (
249
+ torch .stack (math_utils .euler_xyz_from_quat (target_quat_w )), 0 , 1
250
+ )
249
251
# Clip the pose
250
252
clamped_target_position_w = torch .clamp (
251
253
target_position_w , min = self ._clip [:, :3 , 0 ], max = self ._clip [:, :3 , 1 ]
@@ -272,7 +274,9 @@ def process_actions(self, actions: torch.Tensor):
272
274
self ._processed_actions [:, :3 ], min = self ._clip [:, :3 , 0 ], max = self ._clip [:, :3 , 1 ]
273
275
)
274
276
# Cast the target quaternion to euler angles
275
- target_euler_angles_w = torch .transpose (torch .stack (math_utils .euler_xyz_from_quat (self ._processed_actions [:, 3 :7 ])), 0 , 1 )
277
+ target_euler_angles_w = torch .transpose (
278
+ torch .stack (math_utils .euler_xyz_from_quat (self ._processed_actions [:, 3 :7 ])), 0 , 1
279
+ )
276
280
# Clip the euler angles
277
281
clamped_target_euler_angles_w = torch .clamp (
278
282
target_euler_angles_w , min = self ._clip [:, 3 :, 0 ], max = self ._clip [:, 3 :, 1 ]
@@ -598,7 +602,7 @@ def reset(self, env_ids: Sequence[int] | None = None) -> None:
598
602
self ._contact_sensor .reset (env_ids )
599
603
if self ._task_frame_transformer is not None :
600
604
self ._task_frame_transformer .reset (env_ids )
601
-
605
+
602
606
"""
603
607
Parameter modification functions.
604
608
"""
@@ -615,18 +619,20 @@ def set_clipping_values(
615
619
for wrench. The setter performs a direct assignment so no copy is made. If a list is provided, it must be a list
616
620
of tuples, each containing two values. The setter will convert the list to a tensor and assign it to the clipping
617
621
values.
618
-
622
+
619
623
Args:
620
624
position_clip: The clipping values for the position command.
621
625
orientation_clip: The clipping values for the orientation command.
622
626
wrench_clip: The clipping values for the wrench command.
623
627
"""
624
-
628
+
625
629
if position_clip is not None :
626
630
position_clip = self ._validate_clipping_values (self ._clip_position .shape , position_clip , "position_clip" )
627
631
self ._clip_position = position_clip
628
632
if orientation_clip is not None :
629
- orientation_clip = self ._validate_clipping_values (self ._clip_orientation .shape , orientation_clip , "orientation_clip" )
633
+ orientation_clip = self ._validate_clipping_values (
634
+ self ._clip_orientation .shape , orientation_clip , "orientation_clip"
635
+ )
630
636
self ._clip_orientation = orientation_clip
631
637
if wrench_clip is not None :
632
638
wrench_clip = self ._validate_clipping_values (self ._clip_wrench .shape , wrench_clip , "wrench_clip" )
@@ -881,7 +887,9 @@ def _preprocess_actions(self, actions: torch.Tensor):
881
887
else :
882
888
self ._processed_actions [:, self ._pose_abs_idx : self ._pose_abs_idx + 3 ] *= self ._position_scale
883
889
if self ._clip_orientation is not None :
884
- normed_quat = math_utils .normalize (self .processed_actions [:, self ._pose_abs_idx + 3 : self ._pose_abs_idx + 7 ] * self ._orientation_scale )
890
+ normed_quat = math_utils .normalize (
891
+ self .processed_actions [:, self ._pose_abs_idx + 3 : self ._pose_abs_idx + 7 ] * self ._orientation_scale
892
+ )
885
893
rpy = torch .transpose (torch .stack (math_utils .euler_xyz_from_quat (normed_quat )), 0 , 1 )
886
894
rpy_clamped = torch .clamp (rpy , min = self ._clip_orientation [:, :, 0 ], max = self ._clip_orientation [:, :, 1 ])
887
895
self .processed_actions [:, self ._pose_abs_idx + 3 : self ._pose_abs_idx + 7 ] = (
@@ -899,7 +907,9 @@ def _preprocess_actions(self, actions: torch.Tensor):
899
907
else :
900
908
self ._processed_actions [:, self ._pose_rel_idx : self ._pose_rel_idx + 3 ] *= self ._position_scale
901
909
if self ._clip_orientation is not None :
902
- rpy = self .processed_actions [:, self ._pose_rel_idx + 3 : self ._pose_rel_idx + 6 ] * self ._orientation_scale
910
+ rpy = (
911
+ self .processed_actions [:, self ._pose_rel_idx + 3 : self ._pose_rel_idx + 6 ] * self ._orientation_scale
912
+ )
903
913
rpy_clamped = torch .clamp (rpy , min = self ._clip_orientation [:, :, 0 ], max = self ._clip_orientation [:, :, 1 ])
904
914
self .processed_actions [:, self ._pose_rel_idx + 3 : self ._pose_rel_idx + 6 ] = rpy_clamped
905
915
else :
@@ -926,7 +936,7 @@ def _preprocess_actions(self, actions: torch.Tensor):
926
936
@staticmethod
927
937
def _gen_clip (control_flags , clip_cfg : list [tuple [float , float ]], name : str ) -> list [tuple [float , float ]]:
928
938
"""Generates the clipping configuration for the operational space controller.
929
-
939
+
930
940
The expected format is a list of tuples, each containing two values. Note that the order in which the tuples are provided
931
941
must match the order of the active axes in motion_control_axes_task.
932
942
@@ -938,15 +948,19 @@ def _gen_clip(control_flags, clip_cfg: list[tuple[float, float]], name: str) ->
938
948
allowed_names = ["clip_position" , "clip_orientation" , "clip_wrench" ]
939
949
if name not in allowed_names :
940
950
raise ValueError (f"Expected { name } to be one of { allowed_names } but got { name } " )
941
-
951
+
942
952
# Iterate over the control flags and add the corresponding clip to the list
943
953
if name == "clip_position" :
944
954
control_flags = control_flags [:3 ]
945
955
elif name == "clip_orientation" :
946
956
control_flags = control_flags [3 :]
947
957
# Ensure the length of the clip_cfg is the same as the number of active axes
948
958
if len (clip_cfg ) != sum (control_flags ):
949
- raise ValueError (f"{ name } must be a list of tuples of the same length as there are active axes in motion_control_axes_task. There are { sum (control_flags )} active axes and { len (clip_cfg )} tuples in { name } ." )
959
+ raise ValueError (
960
+ f"{ name } must be a list of tuples of the same length as there are active axes in"
961
+ f" motion_control_axes_task. There are { sum (control_flags )} active axes and { len (clip_cfg )} tuples in"
962
+ f" { name } ."
963
+ )
950
964
clip_pose_abs_new = []
951
965
for i , flag in enumerate (control_flags ):
952
966
# If the axis is active, add the corresponding clip
@@ -959,48 +973,56 @@ def _gen_clip(control_flags, clip_cfg: list[tuple[float, float]], name: str) ->
959
973
clip_pose_abs_new .append (clip )
960
974
else :
961
975
# If the axis is not active, add a clip of (-inf, inf). (Don't clip)
962
- clip_pose_abs_new .append ((- float (' inf' ), float (' inf' )))
963
- return clip_pose_abs_new
976
+ clip_pose_abs_new .append ((- float (" inf" ), float (" inf" )))
977
+ return clip_pose_abs_new
964
978
965
979
def _parse_clipping_cfg (self , cfg : actions_cfg .OperationalSpaceControllerActionCfg ) -> None :
966
980
"""Parses the clipping configuration for the operational space controller.
967
-
981
+
968
982
Args:
969
983
cfg: The configuration of the action term.
970
984
"""
971
985
972
986
# Parse clip_position
973
987
if cfg .clip_position is not None :
974
- clip_position = self ._gen_clip (self .cfg .controller_cfg .motion_control_axes_task , cfg .clip_position , "clip_position" )
988
+ clip_position = self ._gen_clip (
989
+ self .cfg .controller_cfg .motion_control_axes_task , cfg .clip_position , "clip_position"
990
+ )
975
991
self ._clip_position = torch .zeros ((self .num_envs , 3 , 2 ), device = self .device )
976
992
self ._clip_position [:] = torch .tensor (clip_position , device = self .device )
977
993
else :
978
994
self ._clip_position = None
979
995
980
996
# Parse clip_orientation
981
997
if cfg .clip_orientation is not None :
982
- clip_orientation = self ._gen_clip (self .cfg .controller_cfg .motion_control_axes_task , cfg .clip_orientation , "clip_orientation" )
998
+ clip_orientation = self ._gen_clip (
999
+ self .cfg .controller_cfg .motion_control_axes_task , cfg .clip_orientation , "clip_orientation"
1000
+ )
983
1001
self ._clip_orientation = torch .zeros ((self .num_envs , 3 , 2 ), device = self .device )
984
1002
self ._clip_orientation [:] = torch .tensor (clip_orientation , device = self .device )
985
1003
else :
986
1004
self ._clip_orientation = None
987
1005
988
- # Parse clip_wrench
1006
+ # Parse clip_wrench
989
1007
if cfg .clip_wrench is not None :
990
- clip_wrench = self ._gen_clip (self .cfg .controller_cfg .contact_wrench_control_axes_task , cfg .clip_wrench , "clip_wrench" )
1008
+ clip_wrench = self ._gen_clip (
1009
+ self .cfg .controller_cfg .contact_wrench_control_axes_task , cfg .clip_wrench , "clip_wrench"
1010
+ )
991
1011
self ._clip_wrench = torch .zeros ((self .num_envs , 6 , 2 ), device = self .device )
992
1012
self ._clip_wrench [:] = torch .tensor (clip_wrench , device = self .device )
993
1013
else :
994
1014
self ._clip_wrench = None
995
1015
996
- def _validate_clipping_values (self , target_shape : torch .Tensor , value : list [tuple [float , float ]] | torch .Tensor , name : str ) -> torch .Tensor :
1016
+ def _validate_clipping_values (
1017
+ self , target_shape : torch .Tensor , value : list [tuple [float , float ]] | torch .Tensor , name : str
1018
+ ) -> torch .Tensor :
997
1019
"""Validates the clipping values for the operational space controller.
998
1020
999
1021
Args:
1000
1022
target_shape: The shape of the target tensor.
1001
1023
value: The clipping values to validate.
1002
1024
name: The name of the clipping configuration.
1003
-
1025
+
1004
1026
Returns:
1005
1027
The validated clipping values.
1006
1028
@@ -1012,7 +1034,7 @@ def _validate_clipping_values(self, target_shape: torch.Tensor, value: list[tupl
1012
1034
allowed_names = ["position_clip" , "orientation_clip" , "wrench_clip" ]
1013
1035
if name not in allowed_names :
1014
1036
raise ValueError (f"Expected { name } to be one of { allowed_names } but got { name } " )
1015
-
1037
+
1016
1038
if isinstance (value , torch .Tensor ):
1017
1039
if value .shape != target_shape :
1018
1040
raise ValueError (f"Expected { name } to be a tensor of shape { target_shape } but got { value .shape } " )
@@ -1030,7 +1052,7 @@ def _validate_clipping_values(self, target_shape: torch.Tensor, value: list[tupl
1030
1052
else :
1031
1053
raise ValueError (f"Expected { name } to be a tensor or a list but got { type (value )} " )
1032
1054
return tensor_clip
1033
-
1055
+
1034
1056
def _validate_scale_values (self , value : float | torch .Tensor , name : str ) -> torch .Tensor :
1035
1057
"""Validates the scale values for the operational space controller.
1036
1058
@@ -1055,4 +1077,3 @@ def _validate_scale_values(self, value: float | torch.Tensor, name: str) -> torc
1055
1077
else :
1056
1078
raise ValueError (f"Expected { name } to be a tensor or a float but got { type (value )} " )
1057
1079
return value
1058
-
0 commit comments