@@ -321,7 +321,7 @@ def report(self, freqs=None, power_spectrum=None, freq_range=None, plt_log=False
321
321
self .print_results (False )
322
322
323
323
324
- def fit (self , freqs = None , power_spectrum = None , freq_range = None ):
324
+ def fit (self , freqs = None , power_spectrum = None , freq_range = None , ap_range = None ):
325
325
"""Fit the full power spectrum as a combination of periodic and aperiodic components.
326
326
327
327
Parameters
@@ -340,7 +340,7 @@ def fit(self, freqs=None, power_spectrum=None, freq_range=None):
340
340
341
341
# If freqs & power_spectrum provided together, add data to object.
342
342
if freqs is not None and power_spectrum is not None :
343
- self .add_data (freqs , power_spectrum , freq_range )
343
+ self .add_data (freqs , power_spectrum , freq_range if ap_range is None else None )
344
344
# If power spectrum provided alone, add to object, and use existing frequency data
345
345
# Note: be careful passing in power_spectrum data like this:
346
346
# It assumes the power_spectrum is already logged, with correct freq_range.
@@ -359,15 +359,37 @@ def fit(self, freqs=None, power_spectrum=None, freq_range=None):
359
359
# Cause of failure: RuntimeError, failure to find parameters in curve_fit
360
360
try :
361
361
362
+ if ap_range :
363
+ ap_inds = (self .freqs >= ap_range [0 ]) & (self .freqs <= ap_range [1 ])
364
+ ap_freqs = self .freqs [ap_inds ]
365
+ ap_spectrum = self .power_spectrum [ap_inds ]
366
+ else :
367
+ ap_freqs = self .freqs
368
+ ap_spectrum = self .power_spectrum
369
+
362
370
# Fit the aperiodic component
363
- self .aperiodic_params_ = self ._robust_ap_fit (self . freqs , self . power_spectrum )
371
+ self .aperiodic_params_ = self ._robust_ap_fit (ap_freqs , ap_spectrum )
364
372
self ._ap_fit = gen_aperiodic (self .freqs , self .aperiodic_params_ )
365
373
366
374
# Flatten the power_spectrum using fit aperiodic fit
367
375
self ._spectrum_flat = self .power_spectrum - self ._ap_fit
376
+
377
+ if ap_range :
378
+ per_inds = (self .freqs >= ap_range [0 ]) & (self .freqs <= ap_range [1 ])
379
+ freqs_0 = self .freqs
380
+ self .freqs = self .freqs [per_inds ]
381
+ per_spectrum_flat = self ._spectrum_flat [per_inds ]
382
+ if freq_range :
383
+ self .freq_range = freq_range
384
+ else :
385
+ per_spectrum_flat = np .copy (self ._spectrum_flat )
386
+
368
387
369
388
# Find peaks, and fit them with gaussians
370
- self .gaussian_params_ = self ._fit_peaks (np .copy (self ._spectrum_flat ))
389
+ self .gaussian_params_ = self ._fit_peaks (per_spectrum_flat )
390
+
391
+ if ap_range :
392
+ self .freqs = freqs_0
371
393
372
394
# Calculate the peak fit
373
395
# Note: if no peaks are found, this creates a flat (all zero) peak fit.
0 commit comments