Skip to content

Commit f63204f

Browse files
committed
feat: Add flag to set TRT fallback behavior
1 parent 5cc8d56 commit f63204f

File tree

6 files changed

+43
-2
lines changed

6 files changed

+43
-2
lines changed

fmriprep/cli/parser.py

+20
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,16 @@ def _slice_time_ref(value, parser):
151151
raise parser.error(f'Slice time reference must be in range 0-1. Received {value}.')
152152
return value
153153

154+
def _fallback_trt(value, parser):
155+
if value == 'estimated':
156+
return value
157+
try:
158+
return float(value)
159+
except ValueError:
160+
raise parser.error(
161+
f'Falling back to TRT must be a number or "estimated". Received {value}.'
162+
) from None
163+
154164
verstr = f'fMRIPrep v{config.environment.version}'
155165
currentv = Version(config.environment.version)
156166
is_release = not any((currentv.is_devrelease, currentv.is_prerelease, currentv.is_postrelease))
@@ -165,6 +175,7 @@ def _slice_time_ref(value, parser):
165175
PositiveInt = partial(_min_one, parser=parser)
166176
BIDSFilter = partial(_bids_filter, parser=parser)
167177
SliceTimeRef = partial(_slice_time_ref, parser=parser)
178+
FallbackTRT = partial(_fallback_trt, parser=parser)
168179

169180
# Arguments as specified by BIDS-Apps
170181
# required, positional arguments
@@ -423,6 +434,15 @@ def _slice_time_ref(value, parser):
423434
type=int,
424435
help='Number of nonsteady-state volumes. Overrides automatic detection.',
425436
)
437+
g_conf.add_argument(
438+
'--fallback-total-readout-time',
439+
required=False,
440+
action='store',
441+
default=None,
442+
type=FallbackTRT,
443+
help='Fallback value for Total Readout Time (TRT) calculation. '
444+
'May be a number or "estimated".',
445+
)
426446
g_conf.add_argument(
427447
'--random-seed',
428448
dest='_random_seed',

fmriprep/config.py

+3
Original file line numberDiff line numberDiff line change
@@ -575,6 +575,9 @@ class workflow(_Config):
575575
"""Remove the mean from fieldmaps."""
576576
force_syn = None
577577
"""Run *fieldmap-less* susceptibility-derived distortions estimation."""
578+
fallback_total_readout_time = None
579+
"""Infer the total readout time if unavailable from authoritative metadata.
580+
This may be a number or the string "estimated"."""
578581
hires = None
579582
"""Run FreeSurfer ``recon-all`` with the ``-hires`` flag."""
580583
fs_no_resume = None

fmriprep/interfaces/resampling.py

+9
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,13 @@ def _run_interface(self, runtime):
184184
class DistortionParametersInputSpec(TraitedSpec):
185185
in_file = File(exists=True, desc='EPI image corresponding to the metadata')
186186
metadata = traits.Dict(mandatory=True, desc='metadata corresponding to the inputs')
187+
fallback = traits.Either(
188+
None,
189+
'estimated',
190+
traits.Float,
191+
usedefault=True,
192+
desc='Fallback value for missing metadata',
193+
)
187194

188195

189196
class DistortionParametersOutputSpec(TraitedSpec):
@@ -208,6 +215,8 @@ def _run_interface(self, runtime):
208215
self._results['readout_time'] = get_trt(
209216
self.inputs.metadata,
210217
self.inputs.in_file or None,
218+
use_estimate=self.inputs.fallback == 'estimated',
219+
fallback=self.inputs.fallback if isinstance(self.inputs.fallback, float) else None,
211220
)
212221
self._results['pe_direction'] = self.inputs.metadata['PhaseEncodingDirection']
213222
except (KeyError, ValueError):

fmriprep/workflows/bold/apply.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ def init_bold_volumetric_resample_wf(
1717
metadata: dict,
1818
mem_gb: dict[str, float],
1919
jacobian: bool,
20+
fallback_total_readout_time: str | float | None = None,
2021
fieldmap_id: str | None = None,
2122
omp_nthreads: int = 1,
2223
name: str = 'bold_volumetric_resample_wf',
@@ -161,7 +162,10 @@ def init_bold_volumetric_resample_wf(
161162
run_without_submitting=True,
162163
)
163164
distortion_params = pe.Node(
164-
DistortionParameters(metadata=metadata),
165+
DistortionParameters(
166+
metadata=metadata,
167+
fallback=fallback_total_readout_time,
168+
),
165169
name='distortion_params',
166170
run_without_submitting=True,
167171
)

fmriprep/workflows/bold/base.py

+1
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,7 @@ def init_bold_wf(
380380
# Resample to anatomical space
381381
bold_anat_wf = init_bold_volumetric_resample_wf(
382382
metadata=all_metadata[0],
383+
fallback_total_readout_time=config.workflow.fallback_total_readout_time,
383384
fieldmap_id=fieldmap_id if not multiecho else None,
384385
omp_nthreads=omp_nthreads,
385386
mem_gb=mem_gb,

fmriprep/workflows/bold/fit.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -843,7 +843,11 @@ def init_bold_native_wf(
843843
)
844844

845845
distortion_params = pe.Node(
846-
DistortionParameters(metadata=metadata, in_file=bold_file),
846+
DistortionParameters(
847+
metadata=metadata,
848+
in_file=bold_file,
849+
fallback=config.workflow.fallback_total_readout_time,
850+
),
847851
name='distortion_params',
848852
run_without_submitting=True,
849853
)

0 commit comments

Comments
 (0)