@@ -133,8 +133,8 @@ def copysign(mag: float, other: torch.Tensor) -> torch.Tensor:
133
133
Returns:
134
134
The output tensor.
135
135
"""
136
- mag_torch = abs (mag ) * torch .ones_like (other )
137
- return torch .copysign (mag_torch , other )
136
+ mag_torch = torch . tensor (mag , device = other . device , dtype = torch .float ). repeat (other . shape [ 0 ] )
137
+ return torch .abs (mag_torch ) * torch . sign ( other )
138
138
139
139
140
140
"""
@@ -250,7 +250,7 @@ def quat_conjugate(q: torch.Tensor) -> torch.Tensor:
250
250
"""
251
251
shape = q .shape
252
252
q = q .reshape (- 1 , 4 )
253
- return torch .cat ((q [... , 0 :1 ], - q [... , 1 :]), dim = - 1 ).view (shape )
253
+ return torch .cat ((q [: , 0 :1 ], - q [: , 1 :]), dim = - 1 ).view (shape )
254
254
255
255
256
256
@torch .jit .script
@@ -401,7 +401,7 @@ def _axis_angle_rotation(axis: Literal["X", "Y", "Z"], angle: torch.Tensor) -> t
401
401
402
402
def matrix_from_euler (euler_angles : torch .Tensor , convention : str ) -> torch .Tensor :
403
403
"""
404
- Convert rotations given as Euler angles (intrinsic) in radians to rotation matrices.
404
+ Convert rotations given as Euler angles in radians to rotation matrices.
405
405
406
406
Args:
407
407
euler_angles: Euler angles in radians. Shape is (..., 3).
@@ -436,7 +436,7 @@ def euler_xyz_from_quat(
436
436
"""Convert rotations given as quaternions to Euler angles in radians.
437
437
438
438
Note:
439
- The euler angles are assumed in XYZ extrinsic convention.
439
+ The euler angles are assumed in XYZ convention.
440
440
441
441
Args:
442
442
quat: The quaternion orientation in (w, x, y, z). Shape is (N, 4).
@@ -928,8 +928,14 @@ def compute_pose_error(
928
928
Raises:
929
929
ValueError: Invalid rotation error type.
930
930
"""
931
- # Compute quaternion error (i.e., quat_box_minus)
932
- quat_error = quat_mul (q01 , quat_conjugate (q02 ))
931
+ # Compute quaternion error (i.e., difference quaternion)
932
+ # Reference: https://personal.utdallas.edu/~sxb027100/dock/quaternion.html
933
+ # q_current_norm = q_current * q_current_conj
934
+ source_quat_norm = quat_mul (q01 , quat_conjugate (q01 ))[:, 0 ]
935
+ # q_current_inv = q_current_conj / q_current_norm
936
+ source_quat_inv = quat_conjugate (q01 ) / source_quat_norm .unsqueeze (- 1 )
937
+ # q_error = q_target * q_current_inv
938
+ quat_error = quat_mul (q02 , source_quat_inv )
933
939
934
940
# Compute position error
935
941
pos_error = t02 - t01
0 commit comments