Commit 98d3969

Save and quit on sigint and sigterm (#260)
1 parent f08ac90 commit 98d3969

File tree

3 files changed: +48 -7 lines changed

fast_llm/engine/training/trainer.py

Lines changed: 15 additions & 7 deletions

@@ -9,7 +9,7 @@
 import torch
 
 from fast_llm.config import Configurable
-from fast_llm.core.distributed import safe_barrier
+from fast_llm.core.distributed import allreduce_scalar, safe_barrier
 from fast_llm.data.data.abstract import Data
 from fast_llm.data.dataset.config import SamplingParameters
 from fast_llm.engine.config_utils.run import Run, is_main_rank, log_main_rank, log_pipeline_parallel_main_rank
@@ -23,7 +23,7 @@
 from fast_llm.engine.training.config import TrainerConfig, TrainingCheckpointBaseConfig, TrainingCheckpointConfig
 from fast_llm.engine.training.wandb import Wandb
 from fast_llm.logging import format_metrics, get_memory_usage_mib, log_memory_usage
-from fast_llm.utils import Assert
+from fast_llm.utils import Assert, Interrupter
 
 logger = logging.getLogger(__name__)
 
@@ -214,6 +214,7 @@ def _train(self) -> tuple[bool, dict[PhaseType, dict[str, typing.Any]]]:
             distributed_config=self._config.model.distributed, start_step=self._completed_steps
         )
 
+        interrupter = Interrupter(self._config.training.checkpoint.enabled())
         train_iterator = self._get_data_iterator(
             PhaseType.training.value,
             self._completed_steps,
@@ -231,7 +232,7 @@ def _train(self) -> tuple[bool, dict[PhaseType, dict[str, typing.Any]]]:
         start_iteration = self._completed_steps
         last_iteration = start_iteration
         stop = False
-        with profiler:
+        with profiler, interrupter:
             while not stop:
                 # Iteration starts at 1, so we increment at the beginning.
                 self._completed_steps += 1
@@ -317,8 +318,7 @@ def _train(self) -> tuple[bool, dict[PhaseType, dict[str, typing.Any]]]:
                 profiler.step()
 
                 done = self._completed_steps >= self._config.training.train_iters
-                # TODO: Signal-based stop.
-                stop = done or self._config.training.shutdown.enabled(self._completed_steps)
+
                 # Evaluation
                 # TODO: Adjust valid iterator length.
                 if PhaseType.validation in self._samples_per_split and (
@@ -366,11 +366,19 @@ def _train(self) -> tuple[bool, dict[PhaseType, dict[str, typing.Any]]]:
                 if is_main_rank() and metrics:
                     self._wandb.log_metrics(self._completed_steps, metrics)
 
-                if self._config.training.checkpoint.enabled(None if stop else self._completed_steps):
-                    self._save_checkpoint(self._config.training.checkpoint, metrics)
+                stop = done or self._config.training.shutdown.enabled(self._completed_steps)
 
                 if self._config.training.export.enabled(None if done else self._completed_steps):
                     self._save_checkpoint(self._config.training.export, metrics)
+
+                if interrupter.enabled:
+                    stop = stop or allreduce_scalar(
+                        interrupter.interrupted, torch.int32, self._distributed.world_group
+                    )
+
+                if self._config.training.checkpoint.enabled(None if stop else self._completed_steps):
+                    self._save_checkpoint(self._config.training.checkpoint, metrics)
+
             # The profiler calls the trace_fn at the end and this could lead to
             profiler.step()
         return done, metrics
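
The ordering here matters: the checkpoint check now runs after `stop` has been decided, and `enabled(None if stop else self._completed_steps)` forces a final save whenever the run is stopping, including on an interrupt. The `allreduce_scalar` call is what makes the interrupt safe in a distributed job: SIGINT or SIGTERM may reach only some ranks, so each rank's local `interrupted` flag is combined across the world group before anyone decides to stop; otherwise some ranks would exit while others kept waiting in collectives. Below is a minimal sketch of how such a helper can be built on `torch.distributed`; the actual signature and internals of `fast_llm.core.distributed.allreduce_scalar` may differ.

# Sketch only, assuming plain torch.distributed semantics;
# not the actual fast_llm.core.distributed.allreduce_scalar implementation.
import torch
import torch.distributed as dist


def allreduce_scalar_sketch(value, dtype=torch.int32, group=None):
    if not (dist.is_available() and dist.is_initialized()):
        return value  # single-process fallback
    # NCCL requires a CUDA tensor; gloo can reduce on CPU.
    device = "cuda" if dist.get_backend(group) == "nccl" else "cpu"
    tensor = torch.tensor(value, dtype=dtype, device=device)
    # Summing the flags means a single interrupted rank makes the result
    # nonzero (truthy) on every rank, so all ranks stop together.
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group)
    return tensor.item()

In the trainer, the reduced value feeds `stop = stop or ...`, so a nonzero sum flips `stop` on every rank and the next `while not stop` check exits the loop everywhere at once.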

fast_llm/utils.py

Lines changed: 32 additions & 0 deletions

@@ -1,6 +1,7 @@
 import itertools
 import logging
 import math
+import signal
 import typing
 from typing import Callable
 
@@ -336,3 +337,34 @@ def compare_nested(config_a, config_b, errors: list | None = None, prefix: tuple
 def check_equal_nested(config_a, config_b):
     if errors := compare_nested(config_a, config_b):
         raise ValueError("\n".join(errors))
+
+
+class Interrupter:
+    def __init__(self, enabled: bool = True, signals: typing.Sequence[int] = (signal.SIGINT, signal.SIGTERM)):
+        self._enabled = enabled
+        self._signals = signals
+
+    def __enter__(self):
+        self._interrupted = False
+        self._old_signals = (
+            {signum: signal.signal(signum, self._handle_signal) for signum in self._signals} if self._enabled else {}
+        )
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        for signum, handler in self._old_signals.items():
+            signal.signal(signum, handler)
+
+    def _handle_signal(self, signum, frame):
+        logger.info(f"Interrupt signal {signal.Signals(signum).name} received.")
+        if self._interrupted:
+            # Raise for a repeated signal, ex. if a user really wants to ctrl-C.
+            self._old_signals[signum](signum, frame)
+        self._interrupted = True
+
+    @property
+    def enabled(self) -> bool:
+        return self._enabled
+
+    @property
+    def interrupted(self):
+        return self._interrupted
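
Because `__enter__` records the handlers that `signal.signal` returns, `__exit__` restores the previous behavior, so signals outside the `with` block act normally; on a repeated signal the handler re-invokes the old handler, so a second ctrl-C falls through to Python's usual `KeyboardInterrupt`. A hypothetical standalone use of the context manager, with a sleep standing in for a training step:

# Hypothetical usage of Interrupter outside the trainer; the loop body
# and step count are stand-ins, not Fast-LLM code.
import time

from fast_llm.utils import Interrupter

interrupter = Interrupter()
with interrupter:  # __enter__ installs the SIGINT/SIGTERM handlers
    for step in range(100):
        time.sleep(0.1)  # stand-in for one training step
        if interrupter.interrupted:
            # A real trainer would save a checkpoint here before exiting.
            print(f"Interrupted at step {step}, shutting down cleanly.")
            break

Pressing ctrl-C once flips `interrupted` and the loop breaks at the next check; pressing it twice raises `KeyboardInterrupt` through the saved previous handler.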

tests/test_mtp.py

Lines changed: 1 addition & 0 deletions

@@ -129,6 +129,7 @@ def test_transformer_mtp(config_dict: dict[str, typing.Any]):
     loss.backward()
 
 
+@pytest.mark.skip(reason="Too slow")
 @requires_cuda
 @pytest.mark.skipif(not run_hybrid_test, reason="No CUDA available or Mamba not installed")
 @pytest.mark.parametrize(
