Skip to content

Commit 04c5948

Browse files
authored
feat: compute_kinetic_energy for per-individual KE decomposition (#228) (#623)
* fix: update kinetic_energy.py to comply with Ruff linting and formatting * updated code as per niks suggestions * updated imports & examples
1 parent 7e83580 commit 04c5948

File tree

3 files changed

+337
-0
lines changed

3 files changed

+337
-0
lines changed

movement/kinematics/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
compute_forward_vector_angle,
1515
compute_head_direction_vector,
1616
)
17+
from movement.kinematics.kinetic_energy import compute_kinetic_energy
1718

1819
__all__ = [
1920
"compute_displacement",
@@ -26,4 +27,5 @@
2627
"compute_forward_vector",
2728
"compute_head_direction_vector",
2829
"compute_forward_vector_angle",
30+
"compute_kinetic_energy",
2931
]
Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,194 @@
1+
"""Functions for computing kinetic energy."""
2+
3+
import numpy as np
4+
import xarray as xr
5+
6+
from movement.kinematics.kinematics import compute_velocity
7+
from movement.utils.logging import logger
8+
from movement.utils.vector import compute_norm
9+
from movement.validators.arrays import validate_dims_coords
10+
11+
12+
def compute_kinetic_energy(
13+
position: xr.DataArray,
14+
keypoints: list | None = None,
15+
masses: dict | None = None,
16+
decompose: bool = False,
17+
) -> xr.DataArray:
18+
r"""Compute kinetic energy per individual.
19+
20+
We consider each individual's set of keypoints (pose) as a classical
21+
system of particles in physics (see Notes).
22+
23+
Parameters
24+
----------
25+
position : xr.DataArray
26+
The input data containing position information, with ``time``,
27+
``space`` and ``keypoints`` as required dimensions.
28+
keypoints : list, optional
29+
A list of keypoint names to include in the computation.
30+
By default, all are used.
31+
masses : dict, optional
32+
A dictionary mapping keypoint names to masses, e.g.
33+
{"snout": 1.2, "tail": 0.8}.
34+
By default, unit mass is assumed for all keypoints.
35+
decompose : bool, optional
36+
If True, the kinetic energy is decomposed into "translational" and
37+
"internal" components (see Notes). This requires at least two keypoints
38+
per individual, but more would be desirable for a meaningful
39+
decomposition. The default is False, meaning the total kinetic energy
40+
is returned.
41+
42+
Returns
43+
-------
44+
xr.DataArray
45+
A data array containing the kinetic energy per individual, for every
46+
time point. Note that the output array lacks ``space`` and
47+
``keypoints`` dimensions.
48+
If ``decompose=True`` an extra ``energy`` dimension is added,
49+
with coordinates ``translational`` and ``internal``.
50+
51+
Notes
52+
-----
53+
Considering a given individual at time point :math:`t` as a system of
54+
keypoint particles, its total kinetic energy :math:`T_{total}` is given by:
55+
56+
.. math:: T_{total} = \sum_{i} \frac{1}{2} m_i \| \mathbf{v}_i(t) \|^2
57+
58+
where :math:`m_i` is the mass of the :math:`i`-th keypoint and
59+
:math:`\mathbf{v}_i(t)` is its velocity at time :math:`t`.
60+
61+
From Samuel König's second theorem, we can decompose :math:`T_{total}`
62+
into:
63+
64+
- Translational kinetic energy: the kinetic energy of the individual's
65+
total mass :math:`M` moving with the centre of mass velocity;
66+
- Internal kinetic energy: the kinetic energy of the keypoints moving
67+
relative to the individual's centre of mass.
68+
69+
We compute translational kinetic energy :math:`T_{trans}` as follows:
70+
71+
.. math:: T_{trans} = \frac{1}{2} M \| \mathbf{v}_{cm}(t) \|^2
72+
73+
where :math:`M = \sum_{i} m_i` is the total mass of the individual
74+
and :math:`\mathbf{v}_{cm}(t) = \frac{1}{M} \sum_{i} m_i \mathbf{v}_i(t)`
75+
is the velocity of the centre of mass at time :math:`t`
76+
(computed as the weighted mean of keypoint velocities).
77+
78+
Internal kinetic energy :math:`T_{int}` is derived as the difference
79+
between the total and translational components:
80+
81+
.. math:: T_{int} = T_{total} - T_{trans}
82+
83+
Examples
84+
--------
85+
>>> from movement.kinematics import compute_kinetic_energy
86+
>>> import numpy as np
87+
>>> import xarray as xr
88+
89+
Compute total kinetic energy:
90+
91+
>>> position = xr.DataArray(
92+
... np.random.rand(3, 2, 4, 2),
93+
... coords={
94+
... "time": np.arange(3),
95+
... "individuals": ["id0", "id1"],
96+
... "keypoints": ["snout", "spine", "tail_base", "tail_tip"],
97+
... "space": ["x", "y"],
98+
... },
99+
... dims=["time", "individuals", "keypoints", "space"],
100+
... )
101+
102+
>>> kinetic_energy_total = compute_kinetic_energy(position)
103+
104+
>>> kinetic_energy_total
105+
<xarray.DataArray (time: 3, individuals: 2)> Size: 48B
106+
0.6579 0.7394 0.1304 0.05152 0.2436 0.5719
107+
Coordinates:
108+
* time (time) int64 24B 0 1 2
109+
* individuals (individuals) <U3 24B 'id0' 'id1'
110+
111+
Compute kinetic energy decomposed into translational
112+
and internal components:
113+
114+
>>> kinetic_energy = compute_kinetic_energy(position, decompose=True)
115+
116+
>>> kinetic_energy
117+
<xarray.DataArray (time: 3, individuals: 2, energy: 2)> Size: 96B
118+
0.0172 1.318 0.02069 0.6498 0.02933 ... 0.1716 0.07829 0.7942 0.06901 0.857
119+
Coordinates:
120+
* time (time) int64 24B 0 1 2
121+
* individuals (individuals) <U3 24B 'id0' 'id1'
122+
* energy (energy) <U13 104B 'translational' 'internal'
123+
124+
Select the 'internal' component:
125+
126+
>>> kinetic_energy_internal = kinetic_energy.sel(energy="internal")
127+
128+
Use unequal keypoint masses and exclude an unreliable keypoint
129+
(e.g. "tail_tip"):
130+
131+
>>> masses = {"snout": 1.2, "spine": 0.8, "tail_base": 1.0}
132+
133+
>>> kinetic_energy = compute_kinetic_energy(
134+
... position,
135+
... keypoints=["snout", "spine", "tail_base"],
136+
... masses=masses,
137+
... decompose=True,
138+
... )
139+
140+
"""
141+
# Validate required dimensions and coordinate labels
142+
validate_dims_coords(
143+
position, {"time": [], "space": ["x", "y"], "keypoints": []}
144+
)
145+
146+
# Subset keypoints if requested
147+
if keypoints is not None:
148+
position = position.sel(keypoints=keypoints)
149+
150+
# Validate that at least 2 keypoints exist for decomposition
151+
if decompose and position.sizes["keypoints"] < 2:
152+
raise logger.error(
153+
ValueError(
154+
"At least 2 keypoints are required to decompose "
155+
"kinetic energy into translational and internal components."
156+
)
157+
)
158+
159+
# Compute velocity from position
160+
velocity = compute_velocity(position)
161+
162+
# Initialise unit weights
163+
weights = xr.DataArray(
164+
np.ones(position.sizes["keypoints"]),
165+
dims=["keypoints"],
166+
coords={"keypoints": position.coords["keypoints"]},
167+
)
168+
169+
# Update weights with keypoint masses, if provided
170+
if masses:
171+
for keypoint, mass in masses.items():
172+
weights.loc[keypoint] = mass
173+
174+
# Compute total KE
175+
weighted_ke = 0.5 * weights * (compute_norm(velocity) ** 2)
176+
ke_total = weighted_ke.sum(dim="keypoints")
177+
178+
if not decompose:
179+
return ke_total
180+
else:
181+
# Compute translational KE based on centre of mass velocity
182+
v_cm = (velocity * weights.expand_dims(space=["x", "y"])).sum(
183+
dim="keypoints"
184+
) / weights.sum()
185+
ke_trans = 0.5 * weights.sum() * compute_norm(v_cm) ** 2
186+
187+
# Internal KE
188+
ke_int = ke_total - ke_trans
189+
190+
# Format output
191+
ke = xr.concat([ke_trans, ke_int], dim="energy")
192+
ke = ke.assign_coords(energy=["translational", "internal"])
193+
ke = ke.transpose("time", ..., "energy")
194+
return ke
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
from contextlib import nullcontext as does_not_raise
2+
3+
import numpy as np
4+
import pytest
5+
import xarray as xr
6+
7+
from movement.kinematics.kinetic_energy import compute_kinetic_energy
8+
9+
10+
@pytest.mark.parametrize("decompose", [True, False])
11+
def test_basic_shape_and_values(decompose):
12+
"""Basic sanity check with simple data."""
13+
data = np.array(
14+
[[[[1, 0], [0, 1], [1, 1]]], [[[2, 0], [0, 2], [2, 2]]]]
15+
) # shape: (time, individuals, keypoints, space)
16+
position = xr.DataArray(
17+
data,
18+
dims=["time", "individuals", "keypoints", "space"],
19+
coords={
20+
"time": [0, 1],
21+
"individuals": [0],
22+
"keypoints": [0, 1, 2],
23+
"space": ["x", "y"],
24+
},
25+
)
26+
result = compute_kinetic_energy(position, decompose=decompose)
27+
if decompose:
28+
assert set(result.dims) == {"time", "individuals", "energy"}
29+
assert list(result.coords["energy"].values) == [
30+
"translational",
31+
"internal",
32+
]
33+
assert result.shape == (2, 1, 2)
34+
else:
35+
assert set(result.dims) == {"time", "individuals"}
36+
assert result.shape == (2, 1)
37+
assert (result >= 0).all()
38+
39+
40+
def test_uniform_linear_motion(valid_poses_dataset):
41+
"""Uniform rigid motion:
42+
expect translational energy > 0, internal ≈ 0.
43+
"""
44+
ds = valid_poses_dataset.copy(deep=True)
45+
energy = compute_kinetic_energy(ds["position"], decompose=True)
46+
trans = energy.sel(energy="translational")
47+
internal = energy.sel(energy="internal")
48+
assert np.allclose(trans, 3)
49+
assert np.allclose(internal, 0)
50+
51+
52+
@pytest.fixture
53+
def spinning_dataset():
54+
"""Create synthetic rotational-only dataset."""
55+
time = 10
56+
keypoints = 4
57+
angles = np.linspace(0, 2 * np.pi, time)
58+
radius = 1.0
59+
60+
positions = []
61+
for theta in angles:
62+
snapshot = []
63+
for k in range(keypoints):
64+
angle = theta + k * np.pi / 2
65+
snapshot.append([radius * np.cos(angle), radius * np.sin(angle)])
66+
positions.append([snapshot]) # 1 individual
67+
68+
return xr.DataArray(
69+
np.array(positions),
70+
dims=["time", "individuals", "keypoints", "space"],
71+
coords={
72+
"time": np.arange(time),
73+
"individuals": ["id0"],
74+
"keypoints": [f"k{i}" for i in range(keypoints)],
75+
"space": ["x", "y"],
76+
},
77+
)
78+
79+
80+
def test_pure_rotation(spinning_dataset):
81+
"""In pure rotational motion, translational energy ≈ 0."""
82+
energy = compute_kinetic_energy(spinning_dataset, decompose=True)
83+
trans = energy.sel(energy="translational")
84+
internal = energy.sel(energy="internal")
85+
assert np.allclose(trans, 0)
86+
assert (internal > 0).all()
87+
88+
89+
@pytest.mark.parametrize(
90+
"masses",
91+
[
92+
{"centroid": 2.0, "left": 2.0, "right": 2.0},
93+
{"centroid": 0.4, "left": 0.3, "right": 0.3},
94+
],
95+
)
96+
def test_weighted_kinetic_energy(valid_poses_dataset, masses):
97+
"""Kinetic energy should scale linearly with individual's total mass
98+
if velocity is constant.
99+
"""
100+
ds = valid_poses_dataset.copy(deep=True)
101+
position = ds["position"]
102+
unweighted = compute_kinetic_energy(position)
103+
weighted = compute_kinetic_energy(position, masses=masses)
104+
factor = sum(masses.values()) / position.sizes["keypoints"]
105+
xr.testing.assert_allclose(weighted, unweighted * factor)
106+
107+
108+
@pytest.mark.parametrize(
109+
"valid_poses_dataset, keypoints, expected_exception",
110+
[
111+
pytest.param(
112+
"multi_individual_array",
113+
None,
114+
does_not_raise(),
115+
id="3-keypoints (sufficient)",
116+
),
117+
pytest.param(
118+
"multi_individual_array",
119+
["centroid"],
120+
pytest.raises(ValueError, match="At least 2 keypoints"),
121+
id="3-keypoints 1-selected (insufficient)",
122+
),
123+
pytest.param(
124+
"single_keypoint_array",
125+
None,
126+
pytest.raises(ValueError, match="At least 2 keypoints"),
127+
id="1-keypoint (insufficient)",
128+
),
129+
],
130+
indirect=["valid_poses_dataset"],
131+
)
132+
def test_insufficient_keypoints(
133+
valid_poses_dataset, keypoints, expected_exception
134+
):
135+
"""Function should raise error if fewer than 2 keypoints."""
136+
with expected_exception:
137+
compute_kinetic_energy(
138+
valid_poses_dataset["position"],
139+
keypoints=keypoints,
140+
decompose=True,
141+
)

0 commit comments

Comments
 (0)