Skip to content

Commit 4f8d8ef

Browse files
committed
FIX: update pet mask workflow
1 parent 52bf8bd commit 4f8d8ef

File tree

4 files changed

+53
-36
lines changed

4 files changed

+53
-36
lines changed

fmriprep/workflows/pet/confounds.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636
from ...config import DEFAULT_MEMORY_MIN_GB
3737
from ...interfaces import DerivativesDataSink
3838
from ...interfaces.confounds import (
39-
FilterDropped,
4039
PETSummary,
4140
FramewiseDisplacement,
4241
FSLMotionParams,
@@ -148,7 +147,7 @@ def init_pet_confs_wf(
148147
from niworkflows.interfaces.fixes import FixHeaderApplyTransforms as ApplyTransforms
149148
from niworkflows.interfaces.images import SignalExtraction
150149
from niworkflows.interfaces.morphology import BinaryDilation, BinarySubtraction
151-
from niworkflows.interfaces.nibabel import ApplyMask, Binarize
150+
from niworkflows.interfaces.nibabel import Binarize
152151
from niworkflows.interfaces.utility import AddTSVHeader, DictMerge
153152

154153
from ...interfaces.confounds import aCompCorMasks
@@ -395,8 +394,8 @@ def _select_cols(table):
395394
(acompcor_tfm, acompcor_bin, [('output_image', 'in_file')]),
396395
(acompcor_bin, merge_rois, [
397396
(('out_mask', _last), 'in3'),
398-
(('out_mask', lambda l: l[0]), 'in1'),
399-
(('out_mask', lambda l: l[1]), 'in2'),
397+
(('out_mask', lambda masks: masks[0]), 'in1'),
398+
(('out_mask', lambda masks: masks[1]), 'in2'),
400399
]),
401400
(merge_rois, signals, [('out', 'label_files')]),
402401

fmriprep/workflows/pet/fit.py

Lines changed: 23 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -198,11 +198,7 @@ def init_pet_fit_wf(
198198

199199
summary = pe.Node(
200200
FunctionalSummary(
201-
registration=(
202-
'Precomputed'
203-
if petref2anat_xform
204-
else 'mri_coreg'
205-
),
201+
registration=('Precomputed' if petref2anat_xform else 'mri_coreg'),
206202
registration_dof=config.workflow.pet2anat_dof,
207203
orientation=orientation,
208204
),
@@ -220,7 +216,6 @@ def init_pet_fit_wf(
220216
(petref_buffer, outputnode, [
221217
('petref', 'petref'),
222218
]),
223-
(merge_mask, outputnode, [('out', 'pet_mask')]),
224219
(hmc_buffer, outputnode, [
225220
('hmc_xforms', 'motion_xfm'),
226221
]),
@@ -265,11 +260,17 @@ def init_pet_fit_wf(
265260
# Ensure all stage-1 workflows were created successfully before
266261
# attempting to connect them. Nipype's ``connect`` call will fail
267262
# with a ``NoneType`` error if any node is undefined.
268-
stage1_nodes = [petref_wf, petref_buffer, ds_petref_wf,
269-
func_fit_reports_wf, petref_source_buffer]
263+
stage1_nodes = [
264+
petref_wf,
265+
petref_buffer,
266+
ds_petref_wf,
267+
func_fit_reports_wf,
268+
petref_source_buffer,
269+
]
270270
if any(node is None for node in stage1_nodes):
271-
raise RuntimeError('PET reference stage could not be built - '
272-
'check inputs and configuration.')
271+
raise RuntimeError(
272+
'PET reference stage could not be built - check inputs and configuration.'
273+
)
273274

274275
workflow.connect([
275276
(petref_wf, petref_buffer, [
@@ -373,14 +374,16 @@ def init_pet_fit_wf(
373374
else:
374375
t1w_mask_tfm.inputs.transforms = petref2anat_xform
375376

376-
workflow.connect([
377-
(inputnode, t1w_mask_tfm, [('t1w_mask', 'input_image')]),
378-
(petref_buffer, t1w_mask_tfm, [('petref', 'reference_image')]),
379-
(petref_buffer, petref_mask, [('petref', 'in_file')]),
380-
(petref_mask, merge_mask, [('out_mask', 'mask1')]),
381-
(t1w_mask_tfm, merge_mask, [('output_image', 'mask2')]),
382-
(merge_mask, petref_buffer, [('out', 'pet_mask')]),
383-
])
377+
workflow.connect(
378+
[
379+
(inputnode, t1w_mask_tfm, [('t1w_mask', 'input_image')]),
380+
(petref_buffer, t1w_mask_tfm, [('petref', 'reference_image')]),
381+
(petref_buffer, petref_mask, [('petref', 'in_file')]),
382+
(petref_mask, merge_mask, [('out_mask', 'mask1')]),
383+
(t1w_mask_tfm, merge_mask, [('output_image', 'mask2')]),
384+
(merge_mask, outputnode, [('out', 'pet_mask')]),
385+
]
386+
)
384387

385388
ds_petmask_wf = init_ds_petmask_wf(
386389
output_dir=config.execution.petprep_dir,
@@ -482,18 +485,13 @@ def init_pet_native_wf(
482485
)
483486
outputnode.inputs.metadata = metadata
484487

485-
petbuffer = pe.Node(
486-
niu.IdentityInterface(fields=['pet_file']), name='petbuffer'
487-
)
488+
petbuffer = pe.Node(niu.IdentityInterface(fields=['pet_file']), name='petbuffer')
488489

489490
# PET source: track original PET file(s)
490491
# The Select interface requires an index to choose from ``inlist``. Since
491492
# ``pet_file`` is a single path, explicitly set the index to ``0`` to avoid
492493
# missing mandatory input errors when the node runs.
493-
pet_source = pe.Node(
494-
niu.Select(inlist=[pet_file], index=0),
495-
name='pet_source'
496-
)
494+
pet_source = pe.Node(niu.Select(inlist=[pet_file], index=0), name='pet_source')
497495
validate_pet = pe.Node(ValidateImage(), name='validate_pet')
498496
workflow.connect([
499497
(pet_source, validate_pet, [('out', 'in_file')]),

fmriprep/workflows/pet/tests/test_fit.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,3 +134,23 @@ def test_pet_native_precomputes(
134134

135135
flatgraph = wf._create_flat_graph()
136136
generate_expanded_graph(flatgraph)
137+
138+
139+
def test_pet_fit_mask_connections(bids_root: Path, tmp_path: Path):
140+
"""Ensure the PET mask is generated and connected correctly."""
141+
pet_file = str(bids_root / 'sub-01' / 'pet' / 'sub-01_task-rest_run-1_pet.nii.gz')
142+
img = nb.Nifti1Image(np.zeros((2, 2, 2, 1)), np.eye(4))
143+
img.to_filename(pet_file)
144+
145+
with mock_config(bids_dir=bids_root):
146+
wf = init_pet_fit_wf(pet_file=pet_file, precomputed={}, omp_nthreads=1)
147+
148+
assert 'merge_mask' in wf.list_node_names()
149+
assert 'ds_petmask_wf.ds_petmask' in wf.list_node_names()
150+
151+
merge_mask = wf.get_node('merge_mask')
152+
edge = wf._graph.get_edge_data(merge_mask, wf.get_node('outputnode'))
153+
assert ('out', 'pet_mask') in edge['connect']
154+
155+
ds_edge = wf._graph.get_edge_data(merge_mask, wf.get_node('ds_petmask_wf'))
156+
assert ('out', 'inputnode.petmask') in ds_edge['connect']

fmriprep/workflows/pet/tests/test_reference.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,22 +6,22 @@
66

77
def test_reference_frame_select(tmp_path):
88
img = nb.Nifti1Image(np.zeros((5, 5, 5, 4)), np.eye(4))
9-
pet_file = tmp_path / "pet.nii.gz"
9+
pet_file = tmp_path / 'pet.nii.gz'
1010
img.to_filename(pet_file)
1111

1212
wf = init_raw_petref_wf(pet_file=str(pet_file), reference_frame=2)
1313
node_names = [n.name for n in wf._get_all_nodes()]
14-
assert "extract_frame" in node_names
15-
assert "gen_avg" not in node_names
16-
node = wf.get_node("extract_frame")
14+
assert 'extract_frame' in node_names
15+
assert 'gen_avg' not in node_names
16+
node = wf.get_node('extract_frame')
1717
assert node.interface.inputs.t_min == 2
1818

1919

2020
def test_reference_frame_average(tmp_path):
2121
img = nb.Nifti1Image(np.zeros((5, 5, 5, 4)), np.eye(4))
22-
pet_file = tmp_path / "pet.nii.gz"
22+
pet_file = tmp_path / 'pet.nii.gz'
2323
img.to_filename(pet_file)
2424

25-
wf = init_raw_petref_wf(pet_file=str(pet_file), reference_frame="average")
25+
wf = init_raw_petref_wf(pet_file=str(pet_file), reference_frame='average')
2626
node_names = [n.name for n in wf._get_all_nodes()]
27-
assert "gen_avg" in node_names
27+
assert 'gen_avg' in node_names

0 commit comments

Comments
 (0)