Skip to content

Commit 6623988

Browse files
Refactoring to individual classes
1 parent 37b4101 commit 6623988

10 files changed

+306
-488
lines changed

examples/run_example.py

Lines changed: 0 additions & 65 deletions
This file was deleted.

src/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,7 @@
1-
from .signal_processor import SignalProcessor
1+
from .signal_generator import SignalGenerator
2+
from .signal_filter import SignalFilter
3+
from .signal_fitter import SignalFitter
4+
from .signal_visualizer import SignalVisualizer
5+
from .statistical_analyzer import StatisticalAnalyzer
6+
7+
__version__ = '1.6.0'

src/main.py

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,31 @@
1-
from signal_processor import SignalProcessor
1+
from src.signal_generator import SignalGenerator
2+
from src.signal_visualizer import SignalVisualizer
23
import numpy as np
34

45
def runProcessing():
56
"""
6-
Run the signal processing steps using the default parameters of the SignalProcessor class.
7+
Run the signal processing steps using the default parameters of the SignalProcessor classes.
78
"""
89
timeVector = np.linspace(0, 1, 1000, endpoint = False) # Consider importing or modifying your time vector
9-
processor = SignalProcessor(timeVector)
10-
processor.generateNoisySignal()
11-
processor.applyFilter()
12-
processor.fitDampedSineWave()
13-
processor.performTTest()
14-
processor.plotResults()
15-
processor.plotInteractiveResults()
16-
processor.printResults()
10+
11+
noisyInstance = SignalGenerator(timeVector)
12+
13+
filteredInstance = noisyInstance.generateNoisySignal().applyFilter()
14+
fittedInstance = filteredInstance.fitDampedSineWave()
15+
analyzedInstance = fittedInstance.analyzeFit()
16+
17+
# Retrieve results
18+
noisySignal = noisyInstance.getNoisySignal()
19+
filteredSignal = filteredInstance.getFilteredSignal()
20+
fittedSignal = fittedInstance.getFittedSignal()
21+
tTestResults = analyzedInstance.getTTestResults()
22+
23+
print(f"T-test result: statistic={tTestResults[0]}, p-value={tTestResults[1]}")
24+
25+
visualizer = SignalVisualizer(timeVector, noisySignal, filteredSignal, fittedSignal)
26+
27+
visualizer.plotResults()
28+
visualizer.plotInteractiveResults()
1729

1830
if __name__ == "__main__":
1931
try:

src/signal_filter.py

Lines changed: 55 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,78 @@
1+
from src.signal_fitter import SignalFitter
12
from scipy import signal
23
from typing import Optional
34
import numpy as np
45

56
class SignalFilter:
6-
def __init__(self, noisySignal: np.ndarray, sampleRate: float = 1.0):
7-
self.sampleRate = sampleRate
7+
def __init__(self, timeVector: np.ndarray, noisySignal: np.ndarray):
8+
"""
9+
Initialize with a noisy signal to filter.
10+
Defaults:
11+
filterOrder: Order of the Butterworth filter (default is 4).
12+
cutoffFrequency: Cutoff frequency for the low-pass filter (default is 0.2).
13+
filterType: Type of filter ('butter', 'bessel', 'highpass'). Default is 'butter'.
14+
:param timeVector: The time vector associated with the signal.
15+
:param noisySignal: The noisy signal to be filtered.
16+
"""
17+
self.timeVector = timeVector
818
self.noisySignal = noisySignal
919
self.filteredSignal: Optional[np.ndarray] = None
1020

21+
# Default parameters
22+
self.filterType: str = 'butter'
23+
self.filterOrder: int = 4
24+
self.cutOffFrequency: float = 0.2
25+
self.bType: str = 'low'
26+
1127
self.filterTypes = {
1228
"butter": lambda order, cutoff, btype: signal.butter(order, cutoff, btype, analog = False),
13-
"chebyshev1": lambda order, cutoff, btype: signal.cheby1(order, 0.5, cutoff, btype, analog = False),
14-
"chebyshev2": lambda order, cutoff, btype: signal.cheby2(order, 20, cutoff, btype, analog = False),
15-
"elliptic": lambda order, cutoff , btype: signal.ellip(order, 0.5, 20, cutoff, analog = False),
1629
"bessel": lambda order, cutoff, btype: signal.bessel(order, cutoff, btype, analog = False),
17-
"notch": lambda notchFreq, Q: signal.iirnotch(notchFreq, Q, fs = self.sampleRate),
18-
"highpass": lambda order, cutoff, btype: signal.butter(order, cutoff, btype, analog = False),
19-
"bandpass": lambda low_cutoff, high_cutoff, btype: signal.butter(4, [low_cutoff, high_cutoff], btype, analog = False),
20-
"bandstop": lambda low_cutoff, high_cutoff, btype: signal.butter(4, [low_cutoff, high_cutoff], btype, analog = False)
30+
"highpass": lambda order, cutoff, btype: signal.butter(order, cutoff, btype, analog = False)
2131
}
2232

23-
def apply_filter(self, filterType: str = 'butter', **kwargs) -> np.ndarray:
33+
self.applyFilter(order = self.filterOrder, cutoff = self.cutOffFrequency, btype = self.bType)
34+
35+
def setFilterParameters(self, filterType: str, filterOrder: int, cutOffFrequency: float, bType: str) -> None:
2436
"""
25-
Apply the specified filter type.
26-
:param filterType: Type of filter ('butter', 'chebyshev1', 'chebyshev2', 'elliptic', 'bessel', 'notch', 'highpass', 'bandpass', 'bandstop').
27-
:param kwargs: Additional parameters for specific filters.
28-
:return: Filtered signal as a numpy array.
37+
Set or change filter parameters.
38+
:param filterType: Type of filter ('butter', 'bessel', 'highpass').
39+
:param filterOrder: Order of the filter.
40+
:param cutOffFrequency: Cutoff frequency for the filter.
41+
:param bType: Filter band type ('lowpass', 'highpass', etc.).
2942
"""
30-
if filterType not in self.filterTypes:
31-
raise ValueError(f"Filter type '{filterType}' is not recognized.")
43+
self.filterType = filterType
44+
self.throwIfNotSupported()
45+
self.filterOrder = filterOrder
46+
self.cutOffFrequency = cutOffFrequency
47+
self.bType = bType
3248

33-
if self.noisySignal is None:
34-
raise ValueError("Noisy signal is not generated. Please call 'generateNoisySignal' first.")
35-
36-
# Design Butterworth low-pass filter
37-
[filterCoefficientsB, filterCoefficientsA] = self.filterTypes[filterType](**kwargs)
49+
def throwIfNotSupported(self):
50+
"""
51+
Raise an error if the filter type is not supported.
52+
"""
53+
if self.filterType not in self.filterTypes:
54+
raise ValueError(f"Filter type '{self.filterType}' is not recognized.")
3855

56+
def applyFilter(self, **kwargs) -> 'SignalFilter':
57+
"""
58+
Apply the configured filter to the noisy signal.
59+
:param kwargs: Additional parameters for specific filters. Given as Dict with keys: Order, cutoff, btype
60+
"""
61+
[filterCoefficientsB, filterCoefficientsA] = self.filterTypes[self.filterType](**kwargs)
3962
self.filteredSignal = signal.filtfilt(filterCoefficientsB, filterCoefficientsA, self.noisySignal)
40-
4163
print("Filter applied.")
64+
return self
4265

43-
return self.filteredSignal
66+
def fitDampedSineWave(self) -> SignalFitter:
67+
"""
68+
Create a SignalFitter instance to fit a damped sine wave to the filtered signal.
69+
:return: A SignalFitter instance.
70+
"""
71+
return SignalFitter(self.timeVector, self.filteredSignal)
4472

4573
def getFilteredSignal(self) -> Optional[np.ndarray]:
74+
"""
75+
Retrieve the filtered signal.
76+
:return: The filtered signal or None if filtering hasn't been performed.
77+
"""
4678
return self.filteredSignal

src/signal_fitter.py

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,18 @@
1+
from src.statistical_analyzer import StatisticalAnalyzer
12
from scipy import optimize
23
import numpy as np
34
from typing import Optional
45

56
class SignalFitter:
67
def __init__(self, timeVector: np.ndarray, filteredSignal: np.ndarray):
8+
"""
9+
Initialize the SignalFitter with a time vector and a filtered signal.
10+
:param timeVector: The time vector associated with the signal.
11+
:param filteredSignal: The filtered signal to fit.
12+
"""
713
self.timeVector = timeVector
814
self.filteredSignal = filteredSignal
15+
916
self.optimalParamsDamped: Optional[np.ndarray] = None
1017
self.fittedSignalDamped: Optional[np.ndarray] = None
1118

@@ -18,6 +25,8 @@ def __init__(self, timeVector: np.ndarray, filteredSignal: np.ndarray):
1825
# Default bounds
1926
self.bounds: tuple = ([0, 0, -np.pi, 0], [np.inf, np.inf, np.pi, np.inf])
2027

28+
self.fitDampedSineWave()
29+
2130
def setDampedSineWaveParameters(self, amplitudeParam: float, frequencyParam: float, phaseParam: float, decayRateParam: float) -> None:
2231
"""
2332
Set the parameters for the damped sine wave fitting.
@@ -42,34 +51,56 @@ def setDampedSineWaveBounds(self, lower: list[float], upper: list[float]) -> Non
4251
raise ValueError("Bounds should be lists of length 4.")
4352
self.bounds = (lower, upper)
4453

45-
def fitDampedSineWave(self) -> np.ndarray:
54+
def fitDampedSineWave(self) -> 'SignalFitter':
4655
"""
4756
Fit a damped sine wave to the filtered signal using nonlinear least squares.
4857
default sine wave parameters: amplitudeParam = 1.0, frequencyParam = 10.0, phaseParam = 0.0, decayRateParam = 0.1
4958
change with setDampedSineWaveParameters().
5059
"""
60+
if self.filteredSignal is None or len(self.filteredSignal) == 0:
61+
raise ValueError("Filtered signal is None or empty. Cannot fit a damped sine wave.")
62+
5163
def dampedSineFunc(time: np.ndarray, amplitude: float, frequency: float, phase: float, decayRate: float) -> np.ndarray:
64+
"""Define the damped sine wave function."""
5265
return amplitude * np.exp(-decayRate * time) * np.sin(2 * np.pi * frequency * time + phase)
5366

5467
def residualsDamped(params: np.ndarray, time: np.ndarray, data: np.ndarray) -> np.ndarray:
68+
"""Calculate residuals for the damped sine wave fitting."""
5569
return dampedSineFunc(time, *params) - data
5670

5771
initialParams = np.array([self.amplitudeParam, self.frequencyParam, self.phaseParam, self.decayRateParam])
5872

5973
try:
6074
result = optimize.least_squares(residualsDamped, initialParams, bounds = self.bounds, args = (self.timeVector, self.filteredSignal))
6175
self.optimalParamsDamped = result.x
76+
6277
except Exception as e:
6378
raise RuntimeError(f"Optimization failed: {e}")
6479

6580
self.fittedSignalDamped = dampedSineFunc(self.timeVector, *self.optimalParamsDamped)
66-
6781
print("Damped sine wave fitted.")
68-
69-
return self.fittedSignalDamped
82+
return self
7083

7184
def getFittedSignal(self) -> Optional[np.ndarray]:
85+
"""
86+
Retrieve the fitted damped sine wave signal.
87+
:return: The fitted signal or None if fitting hasn't been performed.
88+
"""
7289
return self.fittedSignalDamped
7390

7491
def getOptimalParams(self) -> Optional[np.ndarray]:
92+
"""
93+
Retrieve the optimal parameters obtained from fitting a damped sine wave to the signal.
94+
returns Parameters: Amplitude (A), Frequency (f), Phase (φ), Decay rate (λ)
95+
:return: A numpy array containing the optimal parameters or None if fitting hasn't been performed.
96+
"""
7597
return self.optimalParamsDamped
98+
99+
def analyzeFit(self):
100+
"""
101+
Create a StatisticalAnalyzer instance to analyze the fitted signal against the filtered signal.
102+
:return: A StatisticalAnalyzer instance.
103+
"""
104+
if self.fittedSignalDamped is None:
105+
raise ValueError("Fitted signal is not available. Perform fitting first.")
106+
return StatisticalAnalyzer(self.filteredSignal, self.fittedSignalDamped)

src/signal_generator.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,48 @@
1+
from src.signal_filter import SignalFilter
12
from typing import Optional
23
import numpy as np
34

45
class SignalGenerator:
56
def __init__(self, timeVector: np.ndarray):
7+
"""
8+
Initialize the SignalGenerator with a time vector.
9+
"""
610
self.timeVector = timeVector
711
self.noisySignal: Optional[np.ndarray] = None
812

9-
def generateNoisySignal(self, frequency: float = 10, noiseStdDev: float = 0.5) -> np.ndarray:
13+
def generateNoisySignal(self, frequency: float = 10, noiseStdDev: float = 0.5) -> 'SignalGenerator':
1014
"""
1115
Generate a noisy signal using the defined time vector.
1216
:param frequency: Frequency of the sine wave (default is 10 Hz).
1317
:param noiseStdDev: Standard deviation of the noise (default is 0.5).
1418
"""
19+
if self.timeVector is None or len(self.timeVector) == 0:
20+
raise ValueError("Time array 'timeVector' is not properly initialised.")
21+
1522
self.noisySignal = np.sin(2 * np.pi * frequency * self.timeVector) + noiseStdDev * np.random.randn(len(self.timeVector))
16-
return self.noisySignal
23+
return self
1724

18-
def importNoisySignal(self, signal: np.ndarray) -> None:
25+
def importNoisySignal(self, signal: np.ndarray) -> 'SignalGenerator':
26+
"""
27+
Import an externally generated noisy signal.
28+
:param signal: The noisy signal to import.
29+
"""
1930
self.noisySignal = signal
31+
return self
2032

21-
def getSignal(self) -> Optional[np.ndarray]:
33+
def getNoisySignal(self) -> Optional[np.ndarray]:
34+
"""
35+
Retrieve the current noisy signal.
36+
:return: The noisy signal or None if it hasn't been set.
37+
"""
2238
return self.noisySignal
39+
40+
def applyFilter(self) -> SignalFilter:
41+
"""
42+
Create a SignalFilter instance to filter the noisy signal.
43+
:return: A SignalFilter instance.
44+
"""
45+
if self.noisySignal is None:
46+
raise ValueError("Noisy signal has not been generated or imported.")
47+
return SignalFilter(self.timeVector, self.noisySignal)
48+

0 commit comments

Comments
 (0)