Skip to content

Commit 29c7846

Browse files
author
Laurent Mackay
committed
Adds more flexibility in the aperiodic/oscillatory fitting
DESCRIPTION: I have added a keyword `ap_range` to the FOOOF.fit() function. When `ap_range` is specified, the already existing `freq_range` keyword is used to determine the frequencies for the oscillatory fits while the ap_range keyword is used for the aperiodic fit. When it is not specified, everything proceeds as before. This allows the two parts FO and OOF to be more-or-less independent of one another if so-desired. Furthermore, the `ap_range` keyword may also be a set of indices in order to allow for a very simple implementation of exclusion zones. TESTING: Tested on 5 different EEG files, no errors unless inputs are the wrong size. Update: I have also modified the FOOOFGroup.fit() in order to work with this keyword.
1 parent 9a7fabb commit 29c7846

File tree

3 files changed

+38
-20
lines changed

3 files changed

+38
-20
lines changed

fooof/fit.py

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -332,12 +332,14 @@ def fit(self, freqs=None, power_spectrum=None, freq_range=None, ap_range=None):
332332
Power values, which must be input in linear space.
333333
freq_range : list of [float, float], optional
334334
Frequency range to restrict power spectrum to. If not provided, keeps the entire range.
335+
ap_range : list of [float, float], or np.ndarray of booleans of the same length as freqs, optional.
336+
Frequency range to restrict aperiodic fit to. If not provided, it will be fit on the range specified
337+
by freq_range.
335338
336339
Notes
337340
-----
338341
Data is optional if data has been already been added to FOOOF object.
339342
"""
340-
341343
# If freqs & power_spectrum provided together, add data to object.
342344
if freqs is not None and power_spectrum is not None:
343345
self.add_data(freqs, power_spectrum, freq_range if ap_range is None else None)
@@ -358,9 +360,15 @@ def fit(self, freqs=None, power_spectrum=None, freq_range=None, ap_range=None):
358360
# In rare cases, the model fails to fit. Therefore it's in a try/except
359361
# Cause of failure: RuntimeError, failure to find parameters in curve_fit
360362
try:
363+
364+
if ap_range is not None:#isolate aperiodic frequencies/spectrum
365+
if not isinstance(ap_range,np.ndarray) or ap_range.shape[-1]==2:
366+
ap_inds = (self.freqs >= ap_range[0]) & (self.freqs <= ap_range[1])
367+
elif ap_range.shape[-1]==self.freqs.shape[-1]:
368+
ap_inds = ap_range
369+
else:
370+
raise ValueError('ap_range must have the same length as freqs - can not proceed')
361371

362-
if ap_range:
363-
ap_inds = (self.freqs >= ap_range[0]) & (self.freqs <= ap_range[1])
364372
ap_freqs = self.freqs[ap_inds]
365373
ap_spectrum = self.power_spectrum[ap_inds]
366374
else:
@@ -374,22 +382,27 @@ def fit(self, freqs=None, power_spectrum=None, freq_range=None, ap_range=None):
374382
# Flatten the power_spectrum using fit aperiodic fit
375383
self._spectrum_flat = self.power_spectrum - self._ap_fit
376384

377-
if ap_range:
378-
per_inds = (self.freqs >= ap_range[0]) & (self.freqs <= ap_range[1])
385+
386+
if ap_range is not None:#isolate periodic frequencies/spectrum
387+
per_inds = (self.freqs >= freq_range[0]) & (self.freqs <= freq_range[1])
388+
per_spectrum_flat = np.copy(self._spectrum_flat[per_inds])
389+
self._spectrum_flat = per_spectrum_flat
390+
#save/set some attributes so peak fitting works properly
379391
freqs_0 = self.freqs
380392
self.freqs = self.freqs[per_inds]
381-
per_spectrum_flat = self._spectrum_flat[per_inds]
382393
if freq_range:
394+
freq_range_0 = self.freq_range
383395
self.freq_range = freq_range
384-
else:
385-
per_spectrum_flat = np.copy(self._spectrum_flat)
396+
386397

387398

388399
# Find peaks, and fit them with gaussians
389-
self.gaussian_params_ = self._fit_peaks(per_spectrum_flat)
400+
self.gaussian_params_ = self._fit_peaks(np.copy(self._spectrum_flat))
390401

391-
if ap_range:
392-
self.freqs=freqs_0
402+
if ap_range is not None:
403+
#restore attributes to initial values
404+
self.freqs = freqs_0
405+
self.freq_range = freq_range_0
393406

394407
# Calculate the peak fit
395408
# Note: if no peaks are found, this creates a flat (all zero) peak fit.
@@ -398,9 +411,14 @@ def fit(self, freqs=None, power_spectrum=None, freq_range=None, ap_range=None):
398411
# Create peak-removed (but not flattened) power spectrum.
399412
self._spectrum_peak_rm = self.power_spectrum - self._peak_fit
400413

414+
if ap_range is not None:
415+
ap_spectrum_peak_rm = self._spectrum_peak_rm[ap_inds]
416+
else:
417+
ap_spectrum_peak_rm = self._spectrum_peak_rm
418+
401419
# Run final aperiodic fit on peak-removed power spectrum
402420
# Note: This overwrites previous aperiodic fit
403-
self.aperiodic_params_ = self._simple_ap_fit(self.freqs, self._spectrum_peak_rm)
421+
self.aperiodic_params_ = self._simple_ap_fit(ap_freqs, ap_spectrum_peak_rm)
404422
self._ap_fit = gen_aperiodic(self.freqs, self.aperiodic_params_)
405423

406424
# Create full power_spectrum model fit

fooof/funcs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def combine_fooofs(fooofs):
113113
return fg
114114

115115

116-
def fit_fooof_group_3d(fg, freqs, power_spectra, freq_range=None, n_jobs=1):
116+
def fit_fooof_group_3d(fg, freqs, power_spectra, freq_range=None, ap_range=None, n_jobs=1):
117117
"""Run FOOOFGroup across a 3D collection of power spectra.
118118
119119
Parameters
@@ -138,7 +138,7 @@ def fit_fooof_group_3d(fg, freqs, power_spectra, freq_range=None, n_jobs=1):
138138

139139
fgs = []
140140
for cond_spectra in power_spectra:
141-
fg.fit(freqs, cond_spectra, freq_range, n_jobs)
141+
fg.fit(freqs, cond_spectra, freq_range, ap_range, n_jobs)
142142
fgs.append(fg.copy())
143143

144144
return fgs

fooof/group.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def report(self, freqs=None, power_spectra=None, freq_range=None, n_jobs=1):
145145
self.print_results(False)
146146

147147

148-
def fit(self, freqs=None, power_spectra=None, freq_range=None, n_jobs=1):
148+
def fit(self, freqs=None, power_spectra=None, freq_range=None, ap_range=None, n_jobs=1):
149149
"""Run FOOOF across a group of power_spectra.
150150
151151
Parameters
@@ -167,22 +167,22 @@ def fit(self, freqs=None, power_spectra=None, freq_range=None, n_jobs=1):
167167

168168
# If freqs & power spectra provided together, add data to object.
169169
if freqs is not None and power_spectra is not None:
170-
self.add_data(freqs, power_spectra, freq_range)
170+
self.add_data(freqs, power_spectra, freq_range if ap_range is None else None)
171171

172172
# Run linearly
173173
if n_jobs == 1:
174174
self._reset_group_results(len(self.power_spectra))
175175
for ind, power_spectrum in \
176176
_progress(enumerate(self.power_spectra), self.verbose, len(self)):
177-
self._fit(power_spectrum=power_spectrum)
177+
self._fit(power_spectrum=power_spectrum, freq_range=freq_range, ap_range=ap_range)
178178
self.group_results[ind] = self._get_results()
179179

180180
# Run in parallel
181181
else:
182182
self._reset_group_results()
183183
n_jobs = cpu_count() if n_jobs == -1 else n_jobs
184184
with Pool(processes=n_jobs) as pool:
185-
self.group_results = list(_progress(pool.imap(partial(_par_fit, fg=self),
185+
self.group_results = list(_progress(pool.imap(partial(_par_fit, fg=self, freq_range=freq_range, ap_range=ap_range),
186186
self.power_spectra),
187187
self.verbose, len(self.power_spectra)))
188188

@@ -366,10 +366,10 @@ def _check_width_limits(self):
366366
###################################################################################################
367367
###################################################################################################
368368

369-
def _par_fit(power_spectrum, fg):
369+
def _par_fit(power_spectrum, fg, freq_range, ap_range):
370370
"""Helper function for running in parallel."""
371371

372-
fg._fit(power_spectrum=power_spectrum)
372+
fg._fit(power_spectrum=power_spectrum, freq_range=freq_range, ap_range=ap_range)
373373

374374
return fg._get_results()
375375

0 commit comments

Comments
 (0)