Skip to content

Commit da35462

Browse files
committed
Added tests for whole-series ROCKAD
1 parent c77fe1e commit da35462

File tree

1 file changed

+67
-0
lines changed

1 file changed

+67
-0
lines changed
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
"""Tests for the ROCKAD anomaly detector."""
2+
3+
import numpy as np
4+
import pytest
5+
from sklearn.utils import check_random_state
6+
7+
from aeon.anomaly_detection.whole_series import ROCKAD
8+
9+
10+
def test_rockad_univariate():
11+
"""Test ROCKAD univariate output."""
12+
rng = check_random_state(seed=2)
13+
train_series = rng.normal(loc=0.0, scale=1.0, size=(10, 100))
14+
test_series = rng.normal(loc=0.0, scale=1.0, size=(5, 100))
15+
16+
test_series[0][50:58] -= 5
17+
18+
ad = ROCKAD(n_estimators=100, n_kernels=10, n_neighbors=9)
19+
20+
ad.fit(train_series)
21+
pred = ad.predict(test_series)
22+
23+
assert pred.shape == (5,)
24+
assert pred.dtype == np.float64
25+
assert 0 <= np.argmax(pred) <= 1
26+
27+
28+
def test_rockad_multivariate():
29+
"""Test ROCKAD multivariate output."""
30+
rng = check_random_state(seed=2)
31+
train_series = rng.normal(loc=0.0, scale=1.0, size=(10, 3, 100))
32+
test_series = rng.normal(loc=0.0, scale=1.0, size=(5, 3, 100))
33+
34+
test_series[0][0][50:58] -= 5
35+
36+
ad = ROCKAD(n_estimators=1000, n_kernels=100, n_neighbors=9)
37+
38+
ad.fit(train_series)
39+
pred = ad.predict(test_series)
40+
41+
assert pred.shape == (5,)
42+
assert pred.dtype == np.float64
43+
assert 0 <= np.argmax(pred) <= 1
44+
45+
46+
def test_rockad_incorrect_input():
47+
"""Test ROCKAD with invalid inputs."""
48+
rng = check_random_state(seed=2)
49+
series = rng.normal(size=(10, 5))
50+
51+
with pytest.warns(
52+
UserWarning, match=r"Power Transform failed and thus has been disabled."
53+
):
54+
ad = ROCKAD()
55+
ad.fit(series)
56+
57+
train_series = rng.normal(loc=0.0, scale=1.0, size=(10, 100))
58+
test_series = rng.normal(loc=0.0, scale=1.0, size=(3, 100))
59+
60+
with pytest.raises(
61+
ValueError,
62+
match="""Expected n_neighbors <= n_samples_fit, but n_neighbors = 100,
63+
n_samples_fit = 10, n_samples = 3""",
64+
):
65+
ad = ROCKAD(n_estimators=100, n_kernels=10, n_neighbors=100)
66+
ad.fit(train_series)
67+
ad.predict(test_series)

0 commit comments

Comments
 (0)