@@ -332,12 +332,14 @@ def fit(self, freqs=None, power_spectrum=None, freq_range=None, ap_range=None):
332
332
Power values, which must be input in linear space.
333
333
freq_range : list of [float, float], optional
334
334
Frequency range to restrict power spectrum to. If not provided, keeps the entire range.
335
+ ap_range : list of [float, float], or np.ndarray of booleans of the same length as freqs, optional.
336
+ Frequency range to restrict aperiodic fit to. If not provided, it will be fit on the range specified
337
+ by freq_range.
335
338
336
339
Notes
337
340
-----
338
341
Data is optional if data has been already been added to FOOOF object.
339
342
"""
340
-
341
343
# If freqs & power_spectrum provided together, add data to object.
342
344
if freqs is not None and power_spectrum is not None :
343
345
self .add_data (freqs , power_spectrum , freq_range if ap_range is None else None )
@@ -358,9 +360,15 @@ def fit(self, freqs=None, power_spectrum=None, freq_range=None, ap_range=None):
358
360
# In rare cases, the model fails to fit. Therefore it's in a try/except
359
361
# Cause of failure: RuntimeError, failure to find parameters in curve_fit
360
362
try :
363
+
364
+ if ap_range is not None :#isolate aperiodic frequencies/spectrum
365
+ if not isinstance (ap_range ,np .ndarray ) or ap_range .shape [- 1 ]== 2 :
366
+ ap_inds = (self .freqs >= ap_range [0 ]) & (self .freqs <= ap_range [1 ])
367
+ elif ap_range .shape [- 1 ]== self .freqs .shape [- 1 ]:
368
+ ap_inds = ap_range
369
+ else :
370
+ raise ValueError ('ap_range must have the same length as freqs - can not proceed' )
361
371
362
- if ap_range :
363
- ap_inds = (self .freqs >= ap_range [0 ]) & (self .freqs <= ap_range [1 ])
364
372
ap_freqs = self .freqs [ap_inds ]
365
373
ap_spectrum = self .power_spectrum [ap_inds ]
366
374
else :
@@ -374,22 +382,27 @@ def fit(self, freqs=None, power_spectrum=None, freq_range=None, ap_range=None):
374
382
# Flatten the power_spectrum using fit aperiodic fit
375
383
self ._spectrum_flat = self .power_spectrum - self ._ap_fit
376
384
377
- if ap_range :
378
- per_inds = (self .freqs >= ap_range [0 ]) & (self .freqs <= ap_range [1 ])
385
+
386
+ if ap_range is not None :#isolate periodic frequencies/spectrum
387
+ per_inds = (self .freqs >= freq_range [0 ]) & (self .freqs <= freq_range [1 ])
388
+ per_spectrum_flat = np .copy (self ._spectrum_flat [per_inds ])
389
+ self ._spectrum_flat = per_spectrum_flat
390
+ #save/set some attributes so peak fitting works properly
379
391
freqs_0 = self .freqs
380
392
self .freqs = self .freqs [per_inds ]
381
- per_spectrum_flat = self ._spectrum_flat [per_inds ]
382
393
if freq_range :
394
+ freq_range_0 = self .freq_range
383
395
self .freq_range = freq_range
384
- else :
385
- per_spectrum_flat = np .copy (self ._spectrum_flat )
396
+
386
397
387
398
388
399
# Find peaks, and fit them with gaussians
389
- self .gaussian_params_ = self ._fit_peaks (per_spectrum_flat )
400
+ self .gaussian_params_ = self ._fit_peaks (np . copy ( self . _spectrum_flat ) )
390
401
391
- if ap_range :
392
- self .freqs = freqs_0
402
+ if ap_range is not None :
403
+ #restore attributes to initial values
404
+ self .freqs = freqs_0
405
+ self .freq_range = freq_range_0
393
406
394
407
# Calculate the peak fit
395
408
# Note: if no peaks are found, this creates a flat (all zero) peak fit.
@@ -398,9 +411,14 @@ def fit(self, freqs=None, power_spectrum=None, freq_range=None, ap_range=None):
398
411
# Create peak-removed (but not flattened) power spectrum.
399
412
self ._spectrum_peak_rm = self .power_spectrum - self ._peak_fit
400
413
414
+ if ap_range is not None :
415
+ ap_spectrum_peak_rm = self ._spectrum_peak_rm [ap_inds ]
416
+ else :
417
+ ap_spectrum_peak_rm = self ._spectrum_peak_rm
418
+
401
419
# Run final aperiodic fit on peak-removed power spectrum
402
420
# Note: This overwrites previous aperiodic fit
403
- self .aperiodic_params_ = self ._simple_ap_fit (self . freqs , self . _spectrum_peak_rm )
421
+ self .aperiodic_params_ = self ._simple_ap_fit (ap_freqs , ap_spectrum_peak_rm )
404
422
self ._ap_fit = gen_aperiodic (self .freqs , self .aperiodic_params_ )
405
423
406
424
# Create full power_spectrum model fit
0 commit comments