Skip to content

Commit 9fd6893

Browse files
authored
Merge pull request #196 from fooof-tools/df
[ENH] - Add support for converting model results, including to DFs
2 parents 208fc09 + 7ba9834 commit 9fd6893

File tree

10 files changed

+246
-4
lines changed

10 files changed

+246
-4
lines changed

README.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ There are also optional dependencies, which are not required for model fitting i
8282

8383
- `matplotlib <https://github.com/matplotlib/matplotlib>`_ is needed to visualize data and model fits
8484
- `tqdm <https://github.com/tqdm/tqdm>`_ is needed to print progress bars when fitting many models
85+
- `pandas <https://github.com/pandas-dev/pandas>`_ is needed to for exporting model fit results to dataframes
8586
- `pytest <https://github.com/pytest-dev/pytest>`_ is needed to run the test suite locally
8687

8788
We recommend using the `Anaconda <https://www.anaconda.com/distribution/>`_ distribution to manage these requirements.

fooof/data/conversions.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
"""Conversion functions for organizing model results into alternate representations."""
2+
3+
import numpy as np
4+
5+
from fooof import Bands
6+
from fooof.core.funcs import infer_ap_func
7+
from fooof.core.info import get_ap_indices, get_peak_indices
8+
from fooof.core.modutils import safe_import, check_dependency
9+
from fooof.analysis.periodic import get_band_peak
10+
11+
pd = safe_import('pandas')
12+
13+
###################################################################################################
14+
###################################################################################################
15+
16+
def model_to_dict(fit_results, peak_org):
17+
"""Convert model fit results to a dictionary.
18+
19+
Parameters
20+
----------
21+
fit_results : FOOOFResults
22+
Results of a model fit.
23+
peak_org : int or Bands
24+
How to organize peaks.
25+
If int, extracts the first n peaks.
26+
If Bands, extracts peaks based on band definitions.
27+
28+
Returns
29+
-------
30+
dict
31+
Model results organized into a dictionary.
32+
"""
33+
34+
fr_dict = {}
35+
36+
# aperiodic parameters
37+
for label, param in zip(get_ap_indices(infer_ap_func(fit_results.aperiodic_params)),
38+
fit_results.aperiodic_params):
39+
fr_dict[label] = param
40+
41+
# periodic parameters
42+
peaks = fit_results.peak_params
43+
44+
if isinstance(peak_org, int):
45+
46+
if len(peaks) < peak_org:
47+
nans = [np.array([np.nan] * 3) for ind in range(peak_org-len(peaks))]
48+
peaks = np.vstack((peaks, nans))
49+
50+
for ind, peak in enumerate(peaks[:peak_org, :]):
51+
for pe_label, pe_param in zip(get_peak_indices(), peak):
52+
fr_dict[pe_label.lower() + '_' + str(ind)] = pe_param
53+
54+
elif isinstance(peak_org, Bands):
55+
for band, f_range in peak_org:
56+
for label, param in zip(get_peak_indices(), get_band_peak(peaks, f_range)):
57+
fr_dict[band + '_' + label.lower()] = param
58+
59+
# goodness-of-fit metrics
60+
fr_dict['error'] = fit_results.error
61+
fr_dict['r_squared'] = fit_results.r_squared
62+
63+
return fr_dict
64+
65+
@check_dependency(pd, 'pandas')
66+
def model_to_dataframe(fit_results, peak_org):
67+
"""Convert model fit results to a dataframe.
68+
69+
Parameters
70+
----------
71+
fit_results : FOOOFResults
72+
Results of a model fit.
73+
peak_org : int or Bands
74+
How to organize peaks.
75+
If int, extracts the first n peaks.
76+
If Bands, extracts peaks based on band definitions.
77+
78+
Returns
79+
-------
80+
pd.Series
81+
Model results organized into a dataframe.
82+
"""
83+
84+
return pd.Series(model_to_dict(fit_results, peak_org))
85+
86+
87+
@check_dependency(pd, 'pandas')
88+
def group_to_dataframe(fit_results, peak_org):
89+
"""Convert a group of model fit results into a dataframe.
90+
91+
Parameters
92+
----------
93+
fit_results : list of FOOOFResults
94+
List of FOOOFResults objects.
95+
peak_org : int or Bands
96+
How to organize peaks.
97+
If int, extracts the first n peaks.
98+
If Bands, extracts peaks based on band definitions.
99+
100+
Returns
101+
-------
102+
pd.DataFrame
103+
Model results organized into a dataframe.
104+
"""
105+
106+
return pd.DataFrame([model_to_dataframe(f_res, peak_org) for f_res in fit_results])

fooof/objs/fit.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@
7777
from fooof.utils.data import trim_spectrum
7878
from fooof.utils.params import compute_gauss_std
7979
from fooof.data import FOOOFResults, FOOOFSettings, FOOOFMetaData
80+
from fooof.data.conversions import model_to_dataframe
8081
from fooof.sim.gen import gen_freqs, gen_aperiodic, gen_periodic, gen_model
8182

8283
###################################################################################################
@@ -716,6 +717,25 @@ def set_check_data_mode(self, check_data):
716717
self._check_data = check_data
717718

718719

720+
def to_df(self, peak_org):
721+
"""Convert and extract the model results as a pandas object.
722+
723+
Parameters
724+
----------
725+
peak_org : int or Bands
726+
How to organize peaks.
727+
If int, extracts the first n peaks.
728+
If Bands, extracts peaks based on band definitions.
729+
730+
Returns
731+
-------
732+
pd.Series
733+
Model results organized into a pandas object.
734+
"""
735+
736+
return model_to_dataframe(self.get_results(), peak_org)
737+
738+
719739
def _check_width_limits(self):
720740
"""Check and warn about peak width limits / frequency resolution interaction."""
721741

fooof/objs/group.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from fooof.core.strings import gen_results_fg_str
2222
from fooof.core.io import save_fg, load_jsonlines
2323
from fooof.core.modutils import copy_doc_func_to_method, safe_import
24+
from fooof.data.conversions import group_to_dataframe
2425

2526
###################################################################################################
2627
###################################################################################################
@@ -541,6 +542,25 @@ def print_results(self, concise=False):
541542
print(gen_results_fg_str(self, concise))
542543

543544

545+
def to_df(self, peak_org):
546+
"""Convert and extract the model results as a pandas object.
547+
548+
Parameters
549+
----------
550+
peak_org : int or Bands
551+
How to organize peaks.
552+
If int, extracts the first n peaks.
553+
If Bands, extracts peaks based on band definitions.
554+
555+
Returns
556+
-------
557+
pd.DataFrame
558+
Model results organized into a pandas object.
559+
"""
560+
561+
return group_to_dataframe(self.get_results(), peak_org)
562+
563+
544564
def _fit(self, *args, **kwargs):
545565
"""Create an alias to FOOOF.fit for FOOOFGroup object, for internal use."""
546566

fooof/tests/conftest.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from fooof.core.modutils import safe_import
1010

11-
from fooof.tests.tutils import get_tfm, get_tfg, get_tbands
11+
from fooof.tests.tutils import get_tfm, get_tfg, get_tbands, get_tresults
1212
from fooof.tests.settings import (BASE_TEST_FILE_PATH, TEST_DATA_PATH,
1313
TEST_REPORTS_PATH, TEST_PLOTS_PATH)
1414

@@ -48,7 +48,16 @@ def tfg():
4848
def tbands():
4949
yield get_tbands()
5050

51+
@pytest.fixture(scope='session')
52+
def tresults():
53+
yield get_tresults()
54+
5155
@pytest.fixture(scope='session')
5256
def skip_if_no_mpl():
5357
if not safe_import('matplotlib'):
5458
pytest.skip('Matplotlib not available: skipping test.')
59+
60+
@pytest.fixture(scope='session')
61+
def skip_if_no_pandas():
62+
if not safe_import('pandas'):
63+
pytest.skip('Pandas not available: skipping test.')

fooof/tests/data/test_conversions.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
"""Tests for the fooof.data.conversions."""
2+
3+
from copy import deepcopy
4+
5+
import numpy as np
6+
7+
from fooof.core.modutils import safe_import
8+
pd = safe_import('pandas')
9+
10+
from fooof.data.conversions import *
11+
12+
###################################################################################################
13+
###################################################################################################
14+
15+
def test_model_to_dict(tresults, tbands):
16+
17+
out = model_to_dict(tresults, peak_org=1)
18+
assert isinstance(out, dict)
19+
assert 'cf_0' in out
20+
assert out['cf_0'] == tresults.peak_params[0, 0]
21+
assert not 'cf_1' in out
22+
23+
out = model_to_dict(tresults, peak_org=2)
24+
assert 'cf_0' in out
25+
assert 'cf_1' in out
26+
assert out['cf_1'] == tresults.peak_params[1, 0]
27+
28+
out = model_to_dict(tresults, peak_org=3)
29+
assert 'cf_2' in out
30+
assert np.isnan(out['cf_2'])
31+
32+
out = model_to_dict(tresults, peak_org=tbands)
33+
assert 'alpha_cf' in out
34+
35+
def test_model_to_dataframe(tresults, tbands, skip_if_no_pandas):
36+
37+
for peak_org in [1, 2, 3]:
38+
out = model_to_dataframe(tresults, peak_org=peak_org)
39+
assert isinstance(out, pd.Series)
40+
41+
out = model_to_dataframe(tresults, peak_org=tbands)
42+
assert isinstance(out, pd.Series)
43+
44+
def test_group_to_dataframe(tresults, tbands, skip_if_no_pandas):
45+
46+
fit_results = [deepcopy(tresults), deepcopy(tresults), deepcopy(tresults)]
47+
48+
for peak_org in [1, 2, 3]:
49+
out = group_to_dataframe(fit_results, peak_org=peak_org)
50+
assert isinstance(out, pd.DataFrame)
51+
52+
out = group_to_dataframe(fit_results, peak_org=tbands)
53+
assert isinstance(out, pd.DataFrame)

fooof/tests/objs/test_fit.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,12 @@
1212
from fooof.core.items import OBJ_DESC
1313
from fooof.core.errors import FitError
1414
from fooof.core.utils import group_three
15+
from fooof.core.modutils import safe_import
16+
from fooof.core.errors import DataError, NoDataError, InconsistentDataError
1517
from fooof.sim import gen_freqs, gen_power_spectrum
1618
from fooof.data import FOOOFSettings, FOOOFMetaData, FOOOFResults
17-
from fooof.core.errors import DataError, NoDataError, InconsistentDataError
19+
20+
pd = safe_import('pandas')
1821

1922
from fooof.tests.settings import TEST_DATA_PATH
2023
from fooof.tests.tutils import get_tfm, plot_test
@@ -425,3 +428,10 @@ def test_fooof_check_data():
425428
# Model fitting should execute, but return a null model fit, given the NaNs, without failing
426429
tfm.fit()
427430
assert not tfm.has_model
431+
432+
def test_fooof_to_df(tfm, tbands, skip_if_no_pandas):
433+
434+
df1 = tfm.to_df(2)
435+
assert isinstance(df1, pd.Series)
436+
df2 = tfm.to_df(tbands)
437+
assert isinstance(df2, pd.Series)

fooof/tests/objs/test_group.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,14 @@
99
import numpy as np
1010
from numpy.testing import assert_equal
1111

12-
from fooof.data import FOOOFResults
1312
from fooof.core.items import OBJ_DESC
13+
from fooof.core.modutils import safe_import
14+
from fooof.core.errors import DataError, NoDataError, InconsistentDataError
15+
from fooof.data import FOOOFResults
1416
from fooof.sim import gen_group_power_spectra
1517

18+
pd = safe_import('pandas')
19+
1620
from fooof.tests.settings import TEST_DATA_PATH
1721
from fooof.tests.tutils import default_group_params, plot_test
1822

@@ -349,3 +353,10 @@ def test_fg_get_group(tfg):
349353
# Check that the correct results are extracted
350354
assert [tfg.group_results[ind] for ind in inds1] == nfg1.group_results
351355
assert [tfg.group_results[ind] for ind in inds2] == nfg2.group_results
356+
357+
def test_fg_to_df(tfg, tbands, skip_if_no_pandas):
358+
359+
df1 = tfg.to_df(2)
360+
assert isinstance(df1, pd.DataFrame)
361+
df2 = tfg.to_df(tbands)
362+
assert isinstance(df2, pd.DataFrame)

fooof/tests/tutils.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@
22

33
from functools import wraps
44

5+
import numpy as np
6+
57
from fooof.bands import Bands
8+
from fooof.data import FOOOFResults
69
from fooof.objs import FOOOF, FOOOFGroup
710
from fooof.core.modutils import safe_import
811
from fooof.sim.params import param_sampler
@@ -43,6 +46,14 @@ def get_tbands():
4346

4447
return Bands({'theta' : (4, 8), 'alpha' : (8, 12), 'beta' : (13, 30)})
4548

49+
def get_tresults():
50+
"""Get a FOOOFResults objet, for testing."""
51+
52+
return FOOOFResults(aperiodic_params=np.array([1.0, 1.00]),
53+
peak_params=np.array([[10.0, 1.25, 2.0], [20.0, 1.0, 3.0]]),
54+
r_squared=0.97, error=0.01,
55+
gaussian_params=np.array([[10.0, 1.25, 1.0], [20.0, 1.0, 1.5]]))
56+
4657
def default_group_params():
4758
"""Create default parameters for generating a test group of power spectra."""
4859

optional-requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
matplotlib
2-
tqdm
2+
tqdm
3+
pandas

0 commit comments

Comments
 (0)