Skip to content

[ENH] Add whole-series ROCKAD anomaly detector #2871

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions aeon/anomaly_detection/collection/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
"BaseCollectionAnomalyDetector",
"ClassificationAdapter",
"OutlierDetectionAdapter",
"ROCKAD",
]

from aeon.anomaly_detection.collection._classification import ClassificationAdapter
from aeon.anomaly_detection.collection._outlier_detection import OutlierDetectionAdapter
from aeon.anomaly_detection.collection._rockad import ROCKAD
from aeon.anomaly_detection.collection.base import BaseCollectionAnomalyDetector
223 changes: 223 additions & 0 deletions aeon/anomaly_detection/collection/_rockad.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
"""ROCKAD anomaly detector."""

__all__ = ["ROCKAD"]

import warnings
from typing import Optional

import numpy as np
from sklearn.neighbors import NearestNeighbors
from sklearn.preprocessing import PowerTransformer
from sklearn.utils import resample

from aeon.anomaly_detection.collection.base import BaseCollectionAnomalyDetector
from aeon.transformations.collection.convolution_based import Rocket


class ROCKAD(BaseCollectionAnomalyDetector):
"""
ROCKET-based whole-series Anomaly Detector (ROCKAD).

ROCKAD [1]_ leverages the ROCKET transformation for feature extraction from
time series data and applies the scikit learn k-nearest neighbors (k-NN)
approach with bootstrap aggregation for robust semi-supervised anomaly detection.
The data gets transformed into the ROCKET feature space.
Then the whole-series are compared based on the feature space by
finding the nearest neighbours. The time-point based ROCKAD anomaly detector
can be found at aeon/anomaly_detection/series/distance_based/_rockad.py

This class supports both univariate and multivariate time series and
provides options for normalizing features, applying power transformations,
and customizing the distance metric.

Parameters
----------
n_estimators : int, default=10
Number of k-NN estimators to use in the bootstrap aggregation.
n_kernels : int, default=100
Number of kernels to use in the ROCKET transformation.
normalise : bool, default=False
Whether to normalize the ROCKET-transformed features.
n_neighbors : int, default=5
Number of neighbors to use for the k-NN algorithm.
n_jobs : int, default=1
Number of parallel jobs to use for the k-NN algorithm and ROCKET transformation.
metric : str, default="euclidean"
Distance metric to use for the k-NN algorithm.
power_transform : bool, default=True
Whether to apply a power transformation (Yeo-Johnson) to the features.
random_state : int, default=42
Random seed for reproducibility.

Attributes
----------
rocket_transformer_ : Optional[Rocket]
Instance of the ROCKET transformer used to extract features, set after fitting.
list_baggers_ : Optional[list[NearestNeighbors]]
List containing k-NN estimators used for anomaly scoring, set after fitting.
power_transformer_ : PowerTransformer
Transformer used to apply power transformation to the features.

References
----------
.. [1] Theissler, A., Wengert, M., Gerschner, F. (2023).
ROCKAD: Transferring ROCKET to Whole Time Series Anomaly Detection.
In: Crémilleux, B., Hess, S., Nijssen, S. (eds) Advances in Intelligent
Data Analysis XXI. IDA 2023. Lecture Notes in Computer Science,
vol 13876. Springer, Cham. https://doi.org/10.1007/978-3-031-30047-9_33

Examples
--------
>>> import numpy as np
>>> from aeon.anomaly_detection.collection import ROCKAD
>>> rng = np.random.default_rng(seed=42)
>>> X_train = rng.normal(loc=0.0, scale=1.0, size=(10, 100))
>>> X_test = rng.normal(loc=0.0, scale=1.0, size=(5, 100))
>>> X_test[4][50:58] -= 5
>>> detector = ROCKAD() # doctest: +SKIP
>>> detector.fit(X_train) # doctest: +SKIP
>>> detector.predict(X_test) # doctest: +SKIP
array([24.11974147, 23.93866453, 21.3941765 , 22.26811959, 64.9630108 ])
"""

_tags = {
"anomaly_output_type": "anomaly_scores",
"learning_type:semi_supervised": True,
"capability:univariate": True,
"capability:multivariate": True,
"capability:missing_values": False,
"capability:multithreading": True,
"fit_is_empty": False,
}

def __init__(
self,
n_estimators=10,
n_kernels=100,
normalise=False,
n_neighbors=5,
metric="euclidean",
power_transform=True,
n_jobs=1,
random_state=42,
):

self.n_estimators = n_estimators
self.n_kernels = n_kernels
self.normalise = normalise
self.n_neighbors = n_neighbors
self.n_jobs = n_jobs
self.metric = metric
self.power_transform = power_transform
self.random_state = random_state

self.rocket_transformer_: Optional[Rocket] = None
self.list_baggers_: Optional[list[NearestNeighbors]] = None
self.power_transformer_: Optional[PowerTransformer] = None

super().__init__()

def _fit(self, X: np.ndarray, y: Optional[np.ndarray] = None) -> "ROCKAD":
_X = X
self._inner_fit(_X)

return self

def _inner_fit(self, X: np.ndarray) -> None:

self.rocket_transformer_ = Rocket(
n_kernels=self.n_kernels,
normalise=self.normalise,
n_jobs=self.n_jobs,
random_state=self.random_state,
)
# XT: (n_cases, n_kernels*2)
Xt = self.rocket_transformer_.fit_transform(X)
Xt = Xt.astype(np.float64)

if self.power_transform:
self.power_transformer_ = PowerTransformer()
try:
Xtp = self.power_transformer_.fit_transform(Xt)

except Exception:
warnings.warn(
"Power Transform failed and thus has been disabled. ",
UserWarning,
stacklevel=2,
)
self.power_transformer_ = None
Xtp = Xt
else:
Xtp = Xt

self.list_baggers_ = []

for idx_estimator in range(self.n_estimators):
# Initialize estimator
estimator = NearestNeighbors(
n_neighbors=self.n_neighbors,
n_jobs=self.n_jobs,
metric=self.metric,
algorithm="kd_tree",
)
# Bootstrap Aggregation
Xtp_scaled_sample = resample(
Xtp,
replace=True,
n_samples=None,
random_state=self.random_state + idx_estimator,
stratify=None,
)

# Fit estimator and append to estimator list
estimator.fit(Xtp_scaled_sample)
self.list_baggers_.append(estimator)

def _predict(self, X) -> np.ndarray:
_X = X
collection_anomaly_scores = self._inner_predict(_X)

return collection_anomaly_scores

def _inner_predict(self, X: np.ndarray) -> np.ndarray:
"""
Return the anomaly scores for the input data.

Parameters
----------
X (array-like): The input data.

Returns
-------
np.ndarray: The predicted probabilities.

"""
y_scores = np.zeros((len(X), self.n_estimators))
# Transform into rocket feature space
# XT: (n_cases, n_kernels*2)
Xt = self.rocket_transformer_.transform(X)

Xt = Xt.astype(np.float64)

if self.power_transformer_ is not None:
# Power Transform using yeo-johnson
Xtp = self.power_transformer_.transform(Xt)

else:
Xtp = Xt

for idx, bagger in enumerate(self.list_baggers_):
# Get scores from each estimator
distances, _ = bagger.kneighbors(Xtp)

# Compute mean distance of nearest points in window
scores = distances.mean(axis=1).reshape(-1, 1)
scores = scores.squeeze()

y_scores[:, idx] = scores

# Average the scores to get the final score for each whole-series
collection_anomaly_scores = y_scores.mean(axis=1)

return collection_anomaly_scores
1 change: 1 addition & 0 deletions aeon/anomaly_detection/collection/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Tests for whole-series anomaly detection."""
69 changes: 69 additions & 0 deletions aeon/anomaly_detection/collection/tests/test_rockad.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
"""Tests for the ROCKAD anomaly detector."""

import numpy as np
import pytest
from sklearn.utils import check_random_state

from aeon.anomaly_detection.collection import ROCKAD


def test_rockad_univariate():
"""Test ROCKAD univariate output."""
rng = check_random_state(seed=2)
train_series = rng.normal(loc=0.0, scale=1.0, size=(10, 100))
test_series = rng.normal(loc=0.0, scale=1.0, size=(5, 100))

test_series[0][50:58] -= 5

ad = ROCKAD(n_estimators=100, n_kernels=10, n_neighbors=9)

ad.fit(train_series)
pred = ad.predict(test_series)

assert pred.shape == (5,)
assert pred.dtype == np.float64
assert 0 <= np.argmax(pred) <= 1


def test_rockad_multivariate():
"""Test ROCKAD multivariate output."""
rng = check_random_state(seed=2)
train_series = rng.normal(loc=0.0, scale=1.0, size=(10, 3, 100))
test_series = rng.normal(loc=0.0, scale=1.0, size=(5, 3, 100))

test_series[0][0][50:58] -= 5

ad = ROCKAD(n_estimators=1000, n_kernels=100, n_neighbors=9)

ad.fit(train_series)
pred = ad.predict(test_series)

assert pred.shape == (5,)
assert pred.dtype == np.float64
assert 0 <= np.argmax(pred) <= 1


def test_rockad_incorrect_input():
"""Test ROCKAD with invalid inputs."""
rng = check_random_state(seed=2)
series = rng.normal(size=(10, 5))

with pytest.warns(
UserWarning, match=r"Power Transform failed and thus has been disabled."
):
ad = ROCKAD()
ad.fit(series)

train_series = rng.normal(loc=0.0, scale=1.0, size=(10, 100))
test_series = rng.normal(loc=0.0, scale=1.0, size=(3, 100))

with pytest.raises(
ValueError,
match=(
r"Expected n_neighbors <= n_samples_fit, but n_neighbors = 100, "
r"n_samples_fit = 10, n_samples = 3"
),
):
ad = ROCKAD(n_estimators=100, n_kernels=10, n_neighbors=100)
ad.fit(train_series)
ad.predict(test_series)