67
67
from fooof .core .reports import save_report_fm
68
68
from fooof .core .modutils import copy_doc_func_to_method
69
69
from fooof .core .utils import group_three , check_array_dim
70
- from fooof .core .funcs import gaussian_function , get_ap_func , infer_ap_func
70
+ from fooof .core .funcs import get_pe_func , get_ap_func , infer_ap_func
71
71
from fooof .core .errors import (FitError , NoModelError , DataError ,
72
72
NoDataError , InconsistentDataError )
73
73
from fooof .core .strings import (gen_settings_str , gen_results_fm_str ,
@@ -154,8 +154,9 @@ class FOOOF():
154
154
"""
155
155
# pylint: disable=attribute-defined-outside-init
156
156
157
- def __init__ (self , peak_width_limits = (0.5 , 12.0 ), max_n_peaks = np .inf , min_peak_height = 0.0 ,
158
- peak_threshold = 2.0 , aperiodic_mode = 'fixed' , verbose = True ):
157
+ def __init__ (self , peak_width_limits = (0.5 , 12.0 ), max_n_peaks = np .inf ,
158
+ min_peak_height = 0.0 , peak_threshold = 2.0 , aperiodic_mode = 'fixed' ,
159
+ periodic_mode = 'gaussian' , verbose = True ):
159
160
"""Initialize object with desired settings."""
160
161
161
162
# Set input settings
@@ -164,6 +165,7 @@ def __init__(self, peak_width_limits=(0.5, 12.0), max_n_peaks=np.inf, min_peak_h
164
165
self .min_peak_height = min_peak_height
165
166
self .peak_threshold = peak_threshold
166
167
self .aperiodic_mode = aperiodic_mode
168
+ self .periodic_mode = periodic_mode
167
169
self .verbose = verbose
168
170
169
171
## PRIVATE SETTINGS
@@ -439,6 +441,9 @@ def fit(self, freqs=None, power_spectrum=None, freq_range=None):
439
441
if self .verbose :
440
442
self ._check_width_limits ()
441
443
444
+ # Determine the aperiodic and periodic fit funcs
445
+ self ._set_fit_funcs ()
446
+
442
447
# In rare cases, the model fails to fit, and so uses try / except
443
448
try :
444
449
@@ -715,6 +720,11 @@ def set_check_data_mode(self, check_data):
715
720
716
721
self ._check_data = check_data
717
722
723
+ def _set_fit_funcs (self ):
724
+ """Set the requested aperiodic and periodic fit functions."""
725
+
726
+ self ._pe_func = get_pe_func (self .periodic_mode )
727
+ self ._ap_func = get_ap_func (self .aperiodic_mode )
718
728
719
729
def _check_width_limits (self ):
720
730
"""Check and warn about peak width limits / frequency resolution interaction."""
@@ -762,8 +772,7 @@ def _simple_ap_fit(self, freqs, power_spectrum):
762
772
try :
763
773
with warnings .catch_warnings ():
764
774
warnings .simplefilter ("ignore" )
765
- aperiodic_params , _ = curve_fit (get_ap_func (self .aperiodic_mode ),
766
- freqs , power_spectrum , p0 = guess ,
775
+ aperiodic_params , _ = curve_fit (self ._ap_func , freqs , power_spectrum , p0 = guess ,
767
776
maxfev = self ._maxfev , bounds = ap_bounds )
768
777
except RuntimeError :
769
778
raise FitError ("Model fitting failed due to not finding parameters in "
@@ -818,9 +827,8 @@ def _robust_ap_fit(self, freqs, power_spectrum):
818
827
try :
819
828
with warnings .catch_warnings ():
820
829
warnings .simplefilter ("ignore" )
821
- aperiodic_params , _ = curve_fit (get_ap_func (self .aperiodic_mode ),
822
- freqs_ignore , spectrum_ignore , p0 = popt ,
823
- maxfev = self ._maxfev , bounds = ap_bounds )
830
+ aperiodic_params , _ = curve_fit (self ._ap_func , freqs_ignore , spectrum_ignore ,
831
+ p0 = popt , maxfev = self ._maxfev , bounds = ap_bounds )
824
832
except RuntimeError :
825
833
raise FitError ("Model fitting failed due to not finding "
826
834
"parameters in the robust aperiodic fit." )
@@ -904,7 +912,7 @@ def _fit_peaks(self, flat_iter):
904
912
905
913
# Collect guess parameters and subtract this guess gaussian from the data
906
914
guess = np .vstack ((guess , (guess_freq , guess_height , guess_std )))
907
- peak_gauss = gaussian_function (self .freqs , guess_freq , guess_height , guess_std )
915
+ peak_gauss = self . _pe_func (self .freqs , guess_freq , guess_height , guess_std )
908
916
flat_iter = flat_iter - peak_gauss
909
917
910
918
# Check peaks based on edges, and on overlap, dropping any that violate requirements
@@ -963,7 +971,7 @@ def _fit_peak_guess(self, guess):
963
971
964
972
# Fit the peaks
965
973
try :
966
- gaussian_params , _ = curve_fit (gaussian_function , self .freqs , self ._spectrum_flat ,
974
+ gaussian_params , _ = curve_fit (self . _pe_func , self .freqs , self ._spectrum_flat ,
967
975
p0 = guess , maxfev = self ._maxfev , bounds = gaus_param_bounds )
968
976
except RuntimeError :
969
977
raise FitError ("Model fitting failed due to not finding "
0 commit comments