Skip to content

[ENH] Annotations in dss_line_iter plots and specification of a saving path #89

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 26 additions & 34 deletions examples/example_dss_line.ipynb

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions examples/example_dss_line.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,12 @@
print(data.shape) # n_samples, n_chans, n_trials

# Apply dss_line(), removing only one component
out1, _ = dss.dss_line(data, fline, sfreq, nremove=1, nfft=400)
out1, _ = dss.dss_line(data, fline, sfreq, nfft=400, nremove=1)

###############################################################################
# Now try dss_line_iter(). This applies dss_line() repeatedly until the
# artifact is gone
out2, iterations = dss.dss_line_iter(data, fline, sfreq, nfft=400)
out2, iterations = dss.dss_line_iter(data, fline, sfreq, nfft=400, show=True)
print(f"Removed {iterations} components")

###############################################################################
Expand Down
38 changes: 26 additions & 12 deletions meegkit/dss.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Authors: Nicolas Barascud <nicolas.barascud@gmail.com>
# Maciej Szul <maciej.szul@isc.cnrs.fr>
import numpy as np
from pathlib import Path
from numpy.lib.stride_tricks import sliding_window_view
from scipy import linalg
from scipy.signal import welch
Expand Down Expand Up @@ -264,7 +265,7 @@ def dss_line(X, fline, sfreq, nremove=1, nfft=1024, nkeep=None, blocksize=None,


def dss_line_iter(data, fline, sfreq, win_sz=10, spot_sz=2.5,
nfft=512, show=False, prefix="dss_iter", n_iter_max=100):
nfft=512, show=False, dirname=None, extension=".png", n_iter_max=100):
"""Remove power line artifact iteratively.

This method applies dss_line() until the artifact has been smoothed out
Expand All @@ -288,9 +289,12 @@ def dss_line_iter(data, fline, sfreq, win_sz=10, spot_sz=2.5,
FFT size for the internal PSD calculation (default=512).
show: bool
Produce a visual output of each iteration (default=False).
prefix : str
Path and first part of the visualisation output file
"{prefix}_{iteration number}.png" (default="dss_iter").
dirname: str
Path to the directory where visual outputs are saved when show is 'True'.
If 'None', does not save the outputs. (default=None)
extension: str
Extension of the images filenames. Must be compatible with plt.savefig()
function. (default=".png")
n_iter_max : int
Maximum number of iterations (default=100).

Expand Down Expand Up @@ -357,26 +361,36 @@ def nan_basic_interp(array):
y = mean_sens[freq_rn_ix]
ax.flat[0].plot(freq_used, y)
ax.flat[0].set_title("Mean PSD across trials")
ax.flat[0].set_xlabel("Frequency (Hz)")
ax.flat[0].set_ylabel("Power")

ax.flat[1].plot(freq_used, mean_psd_tf, c="gray")
ax.flat[1].plot(freq_used, mean_psd, c="blue")
ax.flat[1].plot(freq_used, clean_fit_line, c="red")
ax.flat[1].plot(freq_used, mean_psd_tf, c="gray", label="Interpolated mean PSD")
ax.flat[1].plot(freq_used, mean_psd, c="blue", label="Mean PSD")
ax.flat[1].plot(freq_used, clean_fit_line, c="red", label="Fitted polynomial")
ax.flat[1].set_title("Mean PSD across trials and sensors")
ax.flat[1].set_xlabel("Frequency (Hz)")
ax.flat[1].set_ylabel("Power")
ax.flat[1].legend()

tf_ix = np.where(freq_used <= fline)[0][-1]
ax.flat[2].plot(residuals, freq_used)
ax.flat[2].plot(freq_used, residuals)
color = "green"
if mean_score <= 0:
color = "red"
ax.flat[2].scatter(residuals[tf_ix], freq_used[tf_ix], c=color)
ax.flat[2].scatter(freq_used[tf_ix], residuals[tf_ix], c=color)
ax.flat[2].set_title("Residuals")
ax.flat[2].set_xlabel("Frequency (Hz)")
ax.flat[2].set_ylabel("Power")

ax.flat[3].plot(np.arange(iterations + 1), aggr_resid, marker="o")
ax.flat[3].set_title("Iterations")
ax.flat[3].set_title("Aggregated residuals")
ax.flat[3].set_xlabel("Iteration")
ax.flat[3].set_ylabel("Power")

plt.tight_layout()
plt.savefig(f"{prefix}_{iterations:03}.png")
plt.close("all")
if dirname is not None:
plt.savefig(Path(dirname) / f"dss_iter_{iterations:03}{extension}")
plt.show()

if mean_score <= 0:
break
Expand Down
11 changes: 6 additions & 5 deletions tests/test_dss.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def test_dss_line_iter():
# # time x channel x trial sf=200 fline=50

sr = 200
fline = 25
fline = 50
n_samples = 9000
n_chans = 10

Expand All @@ -147,9 +147,8 @@ def test_dss_line_iter():
show=False, n_iter_max=2)

with TemporaryDirectory() as tmpdir:
out, _ = dss.dss_line_iter(x, fline + .5, sr,
prefix=os.path.join(tmpdir, "dss_iter_"),
show=True)
out, _ = dss.dss_line_iter(x, fline + 1, sr,
show=True, dirname=tmpdir)

def _plot(before, after):
f, ax = plt.subplots(1, 2, sharey=True)
Expand All @@ -171,7 +170,9 @@ def _plot(before, after):
# # Test n_trials > 1 TODO
x, _ = create_line_data(n_samples, n_chans=n_chans, n_trials=2,
noise_dim=10, SNR=2, fline=fline / sr)
out, _ = dss.dss_line_iter(x, fline, sr, show=False)
with TemporaryDirectory() as tmpdir:
out, _ = dss.dss_line_iter(x, fline, sr,
show=True, dirname=tmpdir)
plt.close("all")


Expand Down
Loading