Skip to content

Commit 1e930f5

Browse files
committed
Improve batching of axis_angle_to_quaternion implementation
1 parent 6c85cfd commit 1e930f5

File tree

2 files changed

+7
-23
lines changed

2 files changed

+7
-23
lines changed

diffdrr/pose.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -637,17 +637,10 @@ def axis_angle_to_quaternion(axis_angle: torch.Tensor) -> torch.Tensor:
637637
"""
638638
angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)
639639
half_angles = angles * 0.5
640-
eps = 1e-6
641640
small_angles = angles.abs() < eps
642-
sin_half_angles_over_angles = torch.empty_like(angles)
643-
sin_half_angles_over_angles[~small_angles] = (
644-
torch.sin(half_angles[~small_angles]) / angles[~small_angles]
645-
)
646-
# for x small, sin(x/2) is about x/2 - (x/2)^3/6
647-
# so sin(x/2)/x is about 1/2 - (x*x)/48
648-
sin_half_angles_over_angles[small_angles] = (
649-
0.5 - (angles[small_angles] * angles[small_angles]) / 48
650-
)
641+
large = torch.sin(half_angles) / angles
642+
small = 0.5 - (angles * angles) / 48
643+
sin_half_angles_over_angles = torch.where(small_angles, small, large)
651644
quaternions = torch.cat(
652645
[torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1
653646
)

notebooks/api/06_pose.ipynb

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -837,20 +837,11 @@
837837
" \"\"\"\n",
838838
" angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)\n",
839839
" half_angles = angles * 0.5\n",
840-
" eps = 1e-6\n",
841840
" small_angles = angles.abs() < eps\n",
842-
" sin_half_angles_over_angles = torch.empty_like(angles)\n",
843-
" sin_half_angles_over_angles[~small_angles] = (\n",
844-
" torch.sin(half_angles[~small_angles]) / angles[~small_angles]\n",
845-
" )\n",
846-
" # for x small, sin(x/2) is about x/2 - (x/2)^3/6\n",
847-
" # so sin(x/2)/x is about 1/2 - (x*x)/48\n",
848-
" sin_half_angles_over_angles[small_angles] = (\n",
849-
" 0.5 - (angles[small_angles] * angles[small_angles]) / 48\n",
850-
" )\n",
851-
" quaternions = torch.cat(\n",
852-
" [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1\n",
853-
" )\n",
841+
" large = torch.sin(half_angles) / angles\n",
842+
" small = 0.5 - (angles * angles) / 48\n",
843+
" sin_half_angles_over_angles = torch.where(small_angles, small, large)\n",
844+
" quaternions = torch.cat([torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1)\n",
854845
" return quaternions\n",
855846
"\n",
856847
"\n",

0 commit comments

Comments
 (0)