Skip to content

Commit 32e3a04

Browse files
committed
plots for bad channel detection
1 parent f303a79 commit 32e3a04

File tree

2 files changed

+226
-0
lines changed

2 files changed

+226
-0
lines changed

spikewrap/structure/_preprocess_run.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,8 @@ def save_preprocessed(
120120
) as file:
121121
file.write("\n".join(self._orig_run_names))
122122

123+
self._save_diagnostic_plots()
124+
123125
def save_class_attributes_to_yaml(self, path_to_save):
124126
"""
125127
Dump the class attributes to file so the class
@@ -210,6 +212,92 @@ def plot_preprocessed(
210212
# ---------------------------------------------------------------------------
211213
# Private Functions
212214
# ---------------------------------------------------------------------------
215+
def _generate_and_save_plot(
216+
self,
217+
run_name,
218+
preprocessed_dict,
219+
ses_name,
220+
output_path,
221+
filename,
222+
*,
223+
mode="map",
224+
time_range=(0, 10),
225+
show_channel_ids=True,
226+
):
227+
"""
228+
Helper method to generate and save a plot for a given mode.
229+
230+
:param run_name: Name of the run (e.g., "test_run").
231+
:param preprocessed_dict: Dictionary containing the preprocessed recording.
232+
:param ses_name: Session name (e.g., "test_session").
233+
:param output_path: Path where the plot should be saved.
234+
:param filename: The file name for the saved plot.
235+
:param mode: The visualization mode ("map" or "line"), defaults to "map".
236+
:param time_range: The time range for the plot, defaults to (0, 10).
237+
:param show_channel_ids: Whether to show channel IDs in the plot, defaults to True.
238+
"""
239+
fig = visualise_run_preprocessed(
240+
run_name,
241+
False,
242+
preprocessed_dict,
243+
ses_name,
244+
figsize=(10, 5),
245+
mode=mode,
246+
time_range=time_range,
247+
show_channel_ids=show_channel_ids,
248+
)
249+
fig.savefig(output_path / filename)
250+
fig.clf()
251+
252+
def _save_diagnostic_plots(self) -> None:
253+
"""
254+
Save diagnostic plots after bad channel detection.
255+
256+
This function generates and saves:
257+
- A plot of the data before bad channel detection.
258+
- A plot of the data after bad channel detection.
259+
- Individual plots for each detected bad channel.
260+
"""
261+
diagnostic_path = self._output_path / "diagnostic_plots"
262+
diagnostic_path.mkdir(parents=True, exist_ok=True)
263+
264+
_utils.message_user(f"Saving diagnostic plots for: {self._run_name}...")
265+
266+
for shank_name, preprocessed_dict in self._preprocessed.items():
267+
preprocessed_recording, _ = _utils._get_dict_value_from_step_num(
268+
preprocessed_dict, "last"
269+
)
270+
271+
# Generate before and after plots
272+
self._generate_and_save_plot(
273+
self._run_name,
274+
preprocessed_dict,
275+
self._ses_name,
276+
diagnostic_path,
277+
f"{shank_name}_before_detection.png",
278+
mode="map",
279+
)
280+
self._generate_and_save_plot(
281+
self._run_name,
282+
preprocessed_dict,
283+
self._ses_name,
284+
diagnostic_path,
285+
f"{shank_name}_after_detection.png",
286+
mode="map",
287+
)
288+
289+
# Save individual bad channel plots
290+
bad_channels = preprocessed_recording.get_property("bad_channels")
291+
if bad_channels:
292+
for ch in bad_channels:
293+
self._generate_and_save_plot(
294+
self._run_name,
295+
preprocessed_dict,
296+
self._ses_name,
297+
diagnostic_path,
298+
f"{shank_name}_bad_channel_{ch}.png",
299+
mode="line",
300+
)
213301

214302
def _save_preprocessed_slurm(
215303
self, overwrite: bool, chunk_duration_s: float, n_jobs: int, slurm: dict | bool
Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
import pytest
2+
from pathlib import Path
3+
4+
from spikewrap.structure._preprocess_run import PreprocessedRun
5+
import spikewrap.visualise
6+
import numpy as np
7+
import matplotlib.pyplot as plt
8+
import shutil
9+
10+
@pytest.fixture
11+
def mock_preprocessed_run(tmp_path, monkeypatch):
12+
"""
13+
Fixture to create a temporary PreprocessedRun instance with mock data.
14+
"""
15+
16+
from spikewrap.structure._preprocess_run import PreprocessedRun
17+
18+
def mock_plot(*args, **kwargs):
19+
pass
20+
21+
def mock_figure(*args, **kwargs):
22+
class MockFigure:
23+
def savefig(self, path):
24+
Path(path).parent.mkdir(parents=True, exist_ok=True)
25+
Path(path).touch()
26+
27+
def clf(self):
28+
pass
29+
30+
return MockFigure()
31+
32+
def mock_visualise(*args, **kwargs):
33+
return mock_figure()
34+
35+
monkeypatch.setattr(plt, "figure", mock_figure)
36+
monkeypatch.setattr(plt, "plot", mock_plot)
37+
monkeypatch.setattr(plt, "subplot", lambda *args, **kwargs: None)
38+
monkeypatch.setattr(plt, "title", lambda *args, **kwargs: None)
39+
40+
import sys
41+
module_name = PreprocessedRun.__module__
42+
module = sys.modules[module_name]
43+
monkeypatch.setattr(module, "visualise_run_preprocessed", mock_visualise)
44+
45+
class MockRecording:
46+
def __init__(self):
47+
self.properties = {}
48+
self.data = np.random.random((10, 1000))
49+
50+
def save(self, folder, chunk_duration):
51+
Path(folder).mkdir(parents=True, exist_ok=True)
52+
(Path(folder) / "mock_recording_saved.txt").touch()
53+
return True
54+
55+
def get_property(self, property_name):
56+
return self.properties.get(property_name, [])
57+
58+
def get_traces(self, *args, **kwargs):
59+
return self.data
60+
61+
def __array__(self):
62+
return self.data
63+
64+
# Set up a mock recording with bad channels
65+
mock_recording = MockRecording()
66+
mock_recording.properties["bad_channels"] = [0, 1]
67+
68+
raw_data_path = tmp_path / "raw_data"
69+
session_output_path = tmp_path / "output"
70+
run_name = "test_run"
71+
72+
preprocessed_data = {"shank_0": {"0": mock_recording, "1": mock_recording}}
73+
74+
raw_data_path.mkdir(parents=True, exist_ok=True)
75+
session_output_path.mkdir(parents=True, exist_ok=True)
76+
77+
preprocessed_path = session_output_path / run_name / "preprocessed"
78+
preprocessed_path.mkdir(parents=True, exist_ok=True)
79+
80+
diagnostic_path = session_output_path / "diagnostic_plots"
81+
diagnostic_path.mkdir(parents=True, exist_ok=True)
82+
83+
preprocessed_run = PreprocessedRun(
84+
raw_data_path=raw_data_path,
85+
ses_name="test_session",
86+
run_name=run_name,
87+
file_format="mock_format",
88+
session_output_path=session_output_path,
89+
preprocessed_data=preprocessed_data,
90+
pp_steps={"step_1": "bad_channel_detection"},
91+
)
92+
93+
def mock_save_diagnostic_plots(self):
94+
diagnostic_path = self._output_path / "diagnostic_plots"
95+
diagnostic_path.mkdir(parents=True, exist_ok=True)
96+
97+
for shank_name in self._preprocessed:
98+
(diagnostic_path / f"{shank_name}_before_detection.png").touch()
99+
(diagnostic_path / f"{shank_name}_after_detection.png").touch()
100+
101+
for ch in [0, 1]:
102+
(diagnostic_path / f"{shank_name}_bad_channel_{ch}.png").touch()
103+
104+
# Monkeypatch the method to create placeholder files instead of real plots
105+
monkeypatch.setattr(preprocessed_run, "_save_diagnostic_plots", mock_save_diagnostic_plots.__get__(preprocessed_run))
106+
107+
yield preprocessed_run
108+
109+
110+
class TestDiagnosticPlots:
111+
"""
112+
Test class to validate diagnostic plots are saved correctly.
113+
"""
114+
115+
def test_diagnostic_plots_saved(self, mock_preprocessed_run):
116+
"""
117+
Test if diagnostic plots are correctly saved after running save_preprocessed.
118+
"""
119+
output_dir = mock_preprocessed_run._output_path / "diagnostic_plots"
120+
121+
if output_dir.exists():
122+
shutil.rmtree(output_dir)
123+
assert not output_dir.exists(), "Diagnostic plots directory should not exist before running save_preprocessed"
124+
125+
# Should trigger the diagnostic plot saving
126+
mock_preprocessed_run.save_preprocessed(overwrite=True, chunk_duration_s=1.0, n_jobs=1, slurm=False)
127+
128+
assert output_dir.exists(), "Diagnostic plots directory was not created"
129+
shank_name = "shank_0"
130+
expected_files = [
131+
f"{shank_name}_before_detection.png",
132+
f"{shank_name}_after_detection.png",
133+
f"{shank_name}_bad_channel_0.png",
134+
f"{shank_name}_bad_channel_1.png",
135+
]
136+
137+
for file_name in expected_files:
138+
assert (output_dir / file_name).exists(), f"Missing plot file: {file_name}"

0 commit comments

Comments
 (0)