|
837 | 837 | " \"\"\"\n",
|
838 | 838 | " angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)\n",
|
839 | 839 | " half_angles = angles * 0.5\n",
|
840 |
| - " eps = 1e-6\n", |
841 | 840 | " small_angles = angles.abs() < eps\n",
|
842 |
| - " sin_half_angles_over_angles = torch.empty_like(angles)\n", |
843 |
| - " sin_half_angles_over_angles[~small_angles] = (\n", |
844 |
| - " torch.sin(half_angles[~small_angles]) / angles[~small_angles]\n", |
845 |
| - " )\n", |
846 |
| - " # for x small, sin(x/2) is about x/2 - (x/2)^3/6\n", |
847 |
| - " # so sin(x/2)/x is about 1/2 - (x*x)/48\n", |
848 |
| - " sin_half_angles_over_angles[small_angles] = (\n", |
849 |
| - " 0.5 - (angles[small_angles] * angles[small_angles]) / 48\n", |
850 |
| - " )\n", |
851 |
| - " quaternions = torch.cat(\n", |
852 |
| - " [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1\n", |
853 |
| - " )\n", |
| 841 | + " large = torch.sin(half_angles) / angles\n", |
| 842 | + " small = 0.5 - (angles * angles) / 48\n", |
| 843 | + " sin_half_angles_over_angles = torch.where(small_angles, small, large)\n", |
| 844 | + " quaternions = torch.cat([torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1)\n", |
854 | 845 | " return quaternions\n",
|
855 | 846 | "\n",
|
856 | 847 | "\n",
|
|
0 commit comments