Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions dev_scripts/convert_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,14 @@ def convert_param_files_from_repo(owner, repo, repo_path, local_path):
github account
repo : str
repository name
path : str
path to directory

repo_path : str
path to directory in repository
local_path : str
local path to save converted files

Returns
-------
None
None
"""
# Download param files
temp_dir = download_folder_contents(owner, repo, repo_path)
Expand Down
2 changes: 1 addition & 1 deletion examples/howto/optimize_rhythmic.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from hnn_core import (MPIBackend, jones_2009_model, simulate_dipole)

# The number of cores may need modifying depending on your current machine.
n_procs = 10
n_procs = 4
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note to self: despite the PR description, this line does change actual code. This means that there are other places that this PR could have changed actual code instead of documentation, and therefore this PR would need to be inspected very carefully for accidental code changes.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why would this change be necessary? Seems out of scope for this PR

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change was a mistake, resulting from testing example scripts on my local machine without reverting the changes.


###############################################################################
# First, we define a function that will tell the optimization routine how to
Expand Down
62 changes: 59 additions & 3 deletions hnn_core/check.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,32 @@


def _check_gids(gids, gid_ranges, valid_cells, arg_name, same_type=True):
"""Format different gid specifications into list of gids"""
"""Format different gid specifications into list of gids

Parameters
----------

gids : int, list, range, str, optional
The gids to check, can be a single gid or a list of gids.

gid_ranges : dict
Dict containing the gid ranges for each cell type.

valid_cells : list of str
List of valid cell type strings.

arg_name : str
The name of the argument to be checked.

same_type : bool, optional
If True, all gids must be of the same cell type. The default is True.

Returns
-------

gids : list of int
List of gids.
"""
_validate_type(
gids, (int, list, range, str, None), arg_name, "int list, range, str, or None"
)
Expand Down Expand Up @@ -37,14 +62,45 @@ def _check_gids(gids, gid_ranges, valid_cells, arg_name, same_type=True):


def _gid_to_type(gid, gid_ranges):
"""Reverse lookup of gid to type."""
"""Reverse lookup of gid to type.

Parameters
----------
gid : int
The gid to check.
gid_ranges : dict
Dict containing the gid ranges for each cell type.

Returns
-------
gidtype : str, None
The cell type of the gid, or None if not found.

"""
for gidtype, gids in gid_ranges.items():
if gid in gids:
return gidtype


def _string_input_to_list(input_str, valid_str, arg_name):
"""Convert input strings to list"""
"""Convert input strings to list

Parameters
----------
input_str : str, list of str, optional
The input string(s) to check.

valid_str : list of str
List of valid strings.

arg_name : str
The name of the argument to be checked.

Returns
-------
input_str : list of str
The input strings as a list.
"""
if input_str is None:
input_str = list()
elif isinstance(input_str, str):
Expand Down
4 changes: 2 additions & 2 deletions hnn_core/externals/bayesopt.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,9 @@ def bayes_opt(func, x0, cons, acquisition, maxfun=200, debug=False, random_state
Parameter constraints in solver-specific format.
acquisition : func
Acquisition function we want to use to find query points.
maxfun : int, optional
maxfun : int, optional, default = 200
Maximum number of function evaluations. The default is 200.
debug : bool, optional
debug : bool, optional, default = False
The default is False.
random_state : int, optional
Random state of the GaussianProcessRegressor. The default is None.
Expand Down
119 changes: 93 additions & 26 deletions hnn_core/externals/mne.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,17 +462,17 @@ def morlet(sfreq, freqs, n_cycles=7.0, sigma=None, zero_mean=False):
The sampling Frequency.
freqs : array
Frequency range of interest (1 x Frequencies).
n_cycles : float | array of float, default 7.0
n_cycles : float | array of float, default = 7.0
Number of cycles. Fixed number or one per frequency.
sigma : float, default None
sigma : float, optional
It controls the width of the wavelet ie its temporal
resolution. If sigma is None the temporal resolution
is adapted with the frequency like for all wavelet transform.
The higher the frequency the shorter is the wavelet.
If sigma is fixed the temporal resolution is fixed
like for the short time Fourier transform and the number
of oscillations increases with the frequency.
zero_mean : bool, default False
of oscillations increases with the frequency. By default, None.
zero_mean : bool, default = False
Make sure the wavelet has a mean of zero.
Returns
-------
Expand Down Expand Up @@ -524,14 +524,14 @@ def _cwt_gen(X, Ws, *, fsize=0, mode="same", decim=1, use_fft=True):
fsize : int
FFT length.
mode : {'full', 'valid', 'same'}
See numpy.convolve.
decim : int | slice, default 1
Method of convolution. See numpy.convolve.
decim : int | slice, default = 1
To reduce memory usage, decimation factor after time-frequency
decomposition.
If `int`, returns tfr[..., ::decim].
If `slice`, returns tfr[..., decim].
.. note:: Decimation may create aliasing artifacts.
use_fft : bool, default True
use_fft : bool, default = True
Use the FFT for convolutions or not.
Returns
-------
Expand Down Expand Up @@ -597,7 +597,7 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim):
Ws : list, shape (n_tapers, n_wavelets, n_times)
The wavelets.
output : str
* 'complex' : single trial complex.
* 'complex' : single trial complex containing both amplitude and phase.
* 'power' : single trial power.
* 'phase' : single trial phase.
* 'avg_power' : average of single trial power.
Expand All @@ -607,9 +607,18 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim):
use_fft : bool
Use the FFT for convolutions or not.
mode : {'full', 'valid', 'same'}
See numpy.convolve.
Method of convolution. See numpy.convolve.
decim : slice
The decimation slice: e.g. power[:, decim]

Returns
-------
tfrs : ndarray
The time-frequency transform in the selected output format. If output is in ['complex', 'phase', 'power'], the shape is
(n_epochs, n_freqs, n_times). If output is in ['avg_power', 'itc', 'avg_power_itc'], the shape is
(n_freqs, n_times). For 'avg_power_itc', the real part contains average power and the
imaginary part contains inter-trial coherence (ITC), i.e., out = avg_power + i * itc.

"""
# Set output type
dtype = np.float64
Expand Down Expand Up @@ -692,33 +701,34 @@ def _compute_tfr(
The epochs.
freqs : array-like of floats, shape (n_freqs)
The frequencies.
sfreq : float | int, default 1.0
sfreq : float | int, default = 1.0
Sampling frequency of the data.
method : 'morlet'
The time-frequency method. 'morlet' convolves a Morlet wavelet.
n_cycles : float | array of float, default 7.0
method : {'morlet', 'multitaper'}
The time-frequency method. 'morlet' convolves a Morlet wavelet. 'multitaper' utilizes the
multitaper method.
n_cycles : float | array of float, default = 7.0
Number of cycles in the wavelet. Fixed number
or one per frequency.
zero_mean : bool | None, default None
zero_mean : bool | None, optional
None means True for method='multitaper' and False for method='morlet'.
If True, make sure the wavelets have a mean of zero.
time_bandwidth : float, default None
time_bandwidth : float, optional
If None and method=multitaper, will be set to 4.0 (3 tapers).
Time x (Full) Bandwidth product. Only applies if
method == 'multitaper'. The number of good tapers (low-bias) is
chosen automatically based on this to equal floor(time_bandwidth - 1).
use_fft : bool, default True
use_fft : bool, default = True
Use the FFT for convolutions or not.
decim : int | slice, default 1
decim : int | slice, default = 1
To reduce memory usage, decimation factor after time-frequency
decomposition.
If `int`, returns tfr[..., ::decim].
If `slice`, returns tfr[..., decim].
.. note::
Decimation may create aliasing artifacts, yet decimation
is done after the convolutions.
output : str, default 'complex'
* 'complex' : single trial complex.
output : str, default = 'complex'
* 'complex' : single trial complex containing both amplitude and phase
* 'power' : single trial power.
* 'phase' : single trial phase.
* 'avg_power' : average of single trial power.
Expand Down Expand Up @@ -838,29 +848,30 @@ def tfr_array_morlet(
Sampling frequency of the data.
freqs : array-like of float, shape (n_freqs,)
The frequencies.
n_cycles : float | array of float, default 7.0
n_cycles : float | array of float, default = 7.0
Number of cycles in the Morlet wavelet. Fixed number or one per
frequency.
zero_mean : bool | False
zero_mean : bool | False, default = False
If True, make sure the wavelets have a mean of zero. default False.
use_fft : bool
use_fft : bool, default = True
Use the FFT for convolutions or not. default True.
decim : int | slice
decim : int | slice, default = 1
To reduce memory usage, decimation factor after time-frequency
decomposition. default 1
If `int`, returns tfr[..., ::decim].
If `slice`, returns tfr[..., decim].
.. note::
Decimation may create aliasing artifacts, yet decimation
is done after the convolutions.
output : str, default 'complex'
* 'complex' : single trial complex.
output : str, default ='complex'
* 'complex' : single trial complex containing both amplitude and phase
* 'power' : single trial power.
* 'phase' : single trial phase.
* 'avg_power' : average of single trial power.
* 'itc' : inter-trial coherence.
* 'avg_power_itc' : average of single trial power and inter-trial
coherence across trials.
By default, the output is 'complex'.
%(n_jobs)s
The number of epochs to process at the same time. The parallelization
is implemented across channels. Default 1.
Expand Down Expand Up @@ -904,6 +915,24 @@ def tfr_array_morlet(


def _get_nfft(wavelets, X, use_fft=True, check=True):
""" Compute the optimal FFT length for convolving wavelets with signal X.

Parameters
----------
wavelets : list of arrays
List of wavelets to be convolved with the signal.
X : array
The signal data array of shape (n_signals, n_times).
use_fft : bool, default = True
Whether FFT-based convolution will be used.
check : bool, default = True
Whether to check and warn or raise an error if wavelets are longer than the signal.

Returns
-------
nfft : int
The optimized FFT length to use for convolution.
"""
n_times = X.shape[-1]
max_size = max(w.size for w in wavelets)
if max_size > n_times:
Expand All @@ -925,7 +954,45 @@ def _get_nfft(wavelets, X, use_fft=True, check=True):
def _check_tfr_param(
freqs, sfreq, method, zero_mean, n_cycles, time_bandwidth, use_fft, decim, output
):
"""Aux. function to _compute_tfr to check the params validity."""
"""Aux. function to _compute_tfr to check the params validity.
freqs : array-like of floats, shape (n_freqs)
The frequencies.
sfreq : float | int
Sampling frequency of the data.
method : {'morlet', 'multitaper'}
The time-frequency method. 'morlet' convolves a Morlet wavelet. 'multitaper' utilizes the
multitaper method.
zero_mean : bool | None
Whether to apply zero-mean normalization to the wavelets.
n_cycles : float | array of float
Number of cycles in the Morlet wavelet. Fixed number or one per
frequency.
time_bandwidth : float, optional
If None and method=multitaper, will be set to 4.0 (3 tapers).
Time x (Full) Bandwidth product. Only applies if
method == 'multitaper'. The number of good tapers (low-bias) is
chosen automatically based on this to equal floor(time_bandwidth - 1).
use_fft : bool, default = True
Whether FFT-based convolution will be used.
decim : int | slice, default = 1
To reduce memory usage, decimation factor after time-frequency
decomposition. default 1
If `int`, returns tfr[..., ::decim].
If `slice`, returns tfr[..., decim].
.. note::
Decimation may create aliasing artifacts, yet decimation
is done after the convolutions.
output : str, default ='complex'
* 'complex' : single trial complex containing both amplitude and phase
* 'power' : single trial power.
* 'phase' : single trial phase.
* 'avg_power' : average of single trial power.
* 'itc' : inter-trial coherence.
* 'avg_power_itc' : average of single trial power and inter-trial
coherence across trials.
By default, the output is 'complex'.

"""
# Check freqs
if not isinstance(freqs, (list, np.ndarray)):
raise ValueError("freqs must be an array-like, got %s instead." % type(freqs))
Expand Down
Loading