Skip to content

Commit 2fa9187

Browse files
committed
Skip over existing splits (when restarting a failed job)
1 parent 419cf85 commit 2fa9187

File tree

2 files changed

+99
-35
lines changed

2 files changed

+99
-35
lines changed

sotodlib/toast/ops/splits.py

Lines changed: 98 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -449,6 +449,12 @@ class Splits(Operator):
449449
"interval.",
450450
)
451451

452+
skip_existing = Bool(
453+
False,
454+
help="If the mapmaker reports that all requested output files exist "
455+
"for a split, skip the mapmaking step."
456+
)
457+
452458
shared_flags = Unicode(
453459
defaults.shared_flags,
454460
allow_none=True,
@@ -586,49 +592,57 @@ def _exec(self, data, detectors=None, **kwargs):
586592

587593
# Loop over splits
588594
for split_name, spl in self._split_obj.items():
589-
log.info_rank(f"Running Split '{split_name}'", comm=data.comm.comm_world)
590595
# Set mapmaker name based on split and the name of this
591596
# Splits operator.
592597
mname = f"{self.name}_{split_name}"
593598
self.mapmaker.name = mname
599+
log.info_rank(
600+
f"Running Split '{split_name}'", comm=data.comm.comm_world
601+
)
594602

595603
# Apply this split
596604
for ob in data.obs:
597605
spl.create_split(ob)
598606

599-
if self.splits_as_flags:
600-
# Split is applied through flagging, not as a view
601-
# Flag samples outside the intervals by prefixing '~'
602-
# to the view name
603-
FlagIntervals(
604-
shared_flags=self.shared_flags,
605-
shared_flag_bytes=1,
606-
view_mask=[
607-
(f"~{spl.split_intervals}", np.uint8(self.shared_flag_mask))
608-
],
609-
reset=True,
610-
).apply(data)
611-
map_binner.shared_flag_mask |= self.shared_flag_mask
607+
if self.skip_existing and self.splits_exist(data, split_name=split_name):
608+
log.info_rank(
609+
f"All outputs for split '{split_name}' exist, skipping...",
610+
comm=data.comm.comm_world,
611+
)
612612
else:
613-
# Set mapmaking tools to use the current split interval list
614-
map_binner.pixel_pointing.view = spl.split_intervals
615-
616-
if not map_binner.full_pointing:
617-
# We are not using full pointing and so we clear the
618-
# residual pointing for this split
619-
toast.ops.Delete(
620-
detdata=[
621-
map_binner.pixel_pointing.pixels,
622-
map_binner.stokes_weights.weights,
623-
map_binner.pixel_pointing.detector_pointing.quats,
624-
],
625-
).apply(data)
626-
627-
# Run mapmaking
628-
self.mapmaker.apply(data)
629-
630-
# Write
631-
self.write_splits(data, split_name=split_name)
613+
if self.splits_as_flags:
614+
# Split is applied through flagging, not as a view
615+
# Flag samples outside the intervals by prefixing '~'
616+
# to the view name
617+
FlagIntervals(
618+
shared_flags=self.shared_flags,
619+
shared_flag_bytes=1,
620+
view_mask=[
621+
(f"~{spl.split_intervals}", np.uint8(self.shared_flag_mask))
622+
],
623+
reset=True,
624+
).apply(data)
625+
map_binner.shared_flag_mask |= self.shared_flag_mask
626+
else:
627+
# Set mapmaking tools to use the current split interval list
628+
map_binner.pixel_pointing.view = spl.split_intervals
629+
630+
if not map_binner.full_pointing:
631+
# We are not using full pointing and so we clear the
632+
# residual pointing for this split
633+
toast.ops.Delete(
634+
detdata=[
635+
map_binner.pixel_pointing.pixels,
636+
map_binner.stokes_weights.weights,
637+
map_binner.pixel_pointing.detector_pointing.quats,
638+
],
639+
).apply(data)
640+
641+
# Run mapmaking
642+
self.mapmaker.apply(data)
643+
644+
# Write
645+
self.write_splits(data, split_name=split_name)
632646

633647
# Remove split
634648
for ob in data.obs:
@@ -640,6 +654,56 @@ def _exec(self, data, detectors=None, **kwargs):
640654
map_binner.pixel_pointing.view = pointing_view_save
641655
map_binner.pixel_pointing.shared_flag_mask = shared_flag_mask_save
642656

657+
def splits_exist(self, data, split_name=None):
658+
"""Write out all split products."""
659+
if not hasattr(self, "_split_obj"):
660+
msg = "No splits have been created yet, cannot check existence"
661+
raise RuntimeError(msg)
662+
663+
if hasattr(self.mapmaker, "map_binning"):
664+
pixel_pointing = self.mapmaker.map_binning.pixel_pointing
665+
else:
666+
pixel_pointing = self.mapmaker.binning.pixel_pointing
667+
668+
if split_name is None:
669+
to_write = dict(self._split_obj)
670+
else:
671+
to_write = {split_name: self._split_obj[split_name]}
672+
673+
if self.mapmaker.write_hdf5:
674+
fname_suffix = "h5"
675+
else:
676+
fname_suffix = "fits"
677+
678+
all_exist = True
679+
for spname, spl in to_write.items():
680+
mname = f"{self.name}_{split_name}"
681+
for prod, binner_key, write in [
682+
("hits", None, self.write_hits),
683+
("cov", None, self.write_cov),
684+
("invcov", None, self.write_invcov),
685+
("rcond", None, self.write_rcond),
686+
("map", "binned", self.write_map),
687+
("noiseweighted_map", "noiseweighted", self.write_noiseweighted_map),
688+
]:
689+
if not write:
690+
continue
691+
if binner_key is not None:
692+
# get the product name from BinMap
693+
mkey = getattr(self.mapmaker.binning, binner_key)
694+
else:
695+
# hits and covariance are not made by BinMap.
696+
# Try synthesizing the product name
697+
mkey = f"{mname}_{prod}"
698+
fname = os.path.join(
699+
self.output_dir,
700+
f"{self.name}_{spname}_{prod}.{fname_suffix}"
701+
)
702+
if not os.path.isfile(fname):
703+
all_exist = False
704+
break
705+
return all_exist
706+
643707
def write_splits(self, data, split_name=None):
644708
"""Write out all split products."""
645709
if not hasattr(self, "_split_obj"):
@@ -682,7 +746,7 @@ def write_splits(self, data, split_name=None):
682746
mkey = f"{mname}_{prod}"
683747
if mkey not in data:
684748
msg = f"'{mkey}' not found in data. "
685-
smg += f"Available keys are {data.keys()}"
749+
msg += f"Available keys are {data.keys()}"
686750
raise RuntimeError(msg)
687751
fname = os.path.join(
688752
self.output_dir,

sotodlib/toast/workflows/proc_demodulation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def demodulate(job, otherargs, runargs, data):
9292
timer=timer,
9393
)
9494
demod_weights = toast.ops.StokesWeightsDemod(
95-
mode=job_ops.weights_radec.mode
95+
mode=job_ops.demodulate.mode
9696
)
9797
job_ops.weights_radec = demod_weights
9898
if hasattr(job_ops, "binner"):

0 commit comments

Comments
 (0)