Skip to content

Commit 856ab5c

Browse files
committed
Merge branch 'fit' into results
2 parents 31d03a7 + f767c28 commit 856ab5c

File tree

2 files changed

+24
-12
lines changed

2 files changed

+24
-12
lines changed

fooof/core/funcs.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,9 @@ def get_pe_func(periodic_mode):
167167
168168
"""
169169

170-
if periodic_mode == 'gaussian':
170+
if isinstance(periodic_mode, function):
171+
pe_func = periodic_mode
172+
elif periodic_mode == 'gaussian':
171173
pe_func = gaussian_function
172174
else:
173175
raise ValueError("Requested periodic mode not understood.")
@@ -194,7 +196,9 @@ def get_ap_func(aperiodic_mode):
194196
If the specified aperiodic mode label is not understood.
195197
"""
196198

197-
if aperiodic_mode == 'fixed':
199+
if isinstance(aperiodic_mode, function):
200+
ap_func = aperiodic_mode
201+
elif aperiodic_mode == 'fixed':
198202
ap_func = expo_nk_function
199203
elif aperiodic_mode == 'knee':
200204
ap_func = expo_function

fooof/objs/fit.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@
6767
from fooof.core.reports import save_report_fm
6868
from fooof.core.modutils import copy_doc_func_to_method
6969
from fooof.core.utils import group_three, check_array_dim
70-
from fooof.core.funcs import gaussian_function, get_ap_func, infer_ap_func
70+
from fooof.core.funcs import get_pe_func, get_ap_func, infer_ap_func
7171
from fooof.core.errors import (FitError, NoModelError, DataError,
7272
NoDataError, InconsistentDataError)
7373
from fooof.core.strings import (gen_settings_str, gen_results_fm_str,
@@ -154,8 +154,9 @@ class FOOOF():
154154
"""
155155
# pylint: disable=attribute-defined-outside-init
156156

157-
def __init__(self, peak_width_limits=(0.5, 12.0), max_n_peaks=np.inf, min_peak_height=0.0,
158-
peak_threshold=2.0, aperiodic_mode='fixed', verbose=True):
157+
def __init__(self, peak_width_limits=(0.5, 12.0), max_n_peaks=np.inf,
158+
min_peak_height=0.0, peak_threshold=2.0, aperiodic_mode='fixed',
159+
periodic_mode='gaussian', verbose=True):
159160
"""Initialize object with desired settings."""
160161

161162
# Set input settings
@@ -164,6 +165,7 @@ def __init__(self, peak_width_limits=(0.5, 12.0), max_n_peaks=np.inf, min_peak_h
164165
self.min_peak_height = min_peak_height
165166
self.peak_threshold = peak_threshold
166167
self.aperiodic_mode = aperiodic_mode
168+
self.periodic_mode = periodic_mode
167169
self.verbose = verbose
168170

169171
## PRIVATE SETTINGS
@@ -439,6 +441,9 @@ def fit(self, freqs=None, power_spectrum=None, freq_range=None):
439441
if self.verbose:
440442
self._check_width_limits()
441443

444+
# Determine the aperiodic and periodic fit funcs
445+
self._set_fit_funcs()
446+
442447
# In rare cases, the model fails to fit, and so uses try / except
443448
try:
444449

@@ -715,6 +720,11 @@ def set_check_data_mode(self, check_data):
715720

716721
self._check_data = check_data
717722

723+
def _set_fit_funcs(self):
724+
"""Set the requested aperiodic and periodic fit functions."""
725+
726+
self._pe_func = get_pe_func(self.periodic_mode)
727+
self._ap_func = get_ap_func(self.aperiodic_mode)
718728

719729
def _check_width_limits(self):
720730
"""Check and warn about peak width limits / frequency resolution interaction."""
@@ -762,8 +772,7 @@ def _simple_ap_fit(self, freqs, power_spectrum):
762772
try:
763773
with warnings.catch_warnings():
764774
warnings.simplefilter("ignore")
765-
aperiodic_params, _ = curve_fit(get_ap_func(self.aperiodic_mode),
766-
freqs, power_spectrum, p0=guess,
775+
aperiodic_params, _ = curve_fit(self._ap_func, freqs, power_spectrum, p0=guess,
767776
maxfev=self._maxfev, bounds=ap_bounds)
768777
except RuntimeError:
769778
raise FitError("Model fitting failed due to not finding parameters in "
@@ -818,9 +827,8 @@ def _robust_ap_fit(self, freqs, power_spectrum):
818827
try:
819828
with warnings.catch_warnings():
820829
warnings.simplefilter("ignore")
821-
aperiodic_params, _ = curve_fit(get_ap_func(self.aperiodic_mode),
822-
freqs_ignore, spectrum_ignore, p0=popt,
823-
maxfev=self._maxfev, bounds=ap_bounds)
830+
aperiodic_params, _ = curve_fit(self._ap_func, freqs_ignore, spectrum_ignore,
831+
p0=popt, maxfev=self._maxfev, bounds=ap_bounds)
824832
except RuntimeError:
825833
raise FitError("Model fitting failed due to not finding "
826834
"parameters in the robust aperiodic fit.")
@@ -904,7 +912,7 @@ def _fit_peaks(self, flat_iter):
904912

905913
# Collect guess parameters and subtract this guess gaussian from the data
906914
guess = np.vstack((guess, (guess_freq, guess_height, guess_std)))
907-
peak_gauss = gaussian_function(self.freqs, guess_freq, guess_height, guess_std)
915+
peak_gauss = self._pe_func(self.freqs, guess_freq, guess_height, guess_std)
908916
flat_iter = flat_iter - peak_gauss
909917

910918
# Check peaks based on edges, and on overlap, dropping any that violate requirements
@@ -963,7 +971,7 @@ def _fit_peak_guess(self, guess):
963971

964972
# Fit the peaks
965973
try:
966-
gaussian_params, _ = curve_fit(gaussian_function, self.freqs, self._spectrum_flat,
974+
gaussian_params, _ = curve_fit(self._pe_func, self.freqs, self._spectrum_flat,
967975
p0=guess, maxfev=self._maxfev, bounds=gaus_param_bounds)
968976
except RuntimeError:
969977
raise FitError("Model fitting failed due to not finding "

0 commit comments

Comments
 (0)