@@ -9,7 +9,7 @@
 import torch
 
 from fast_llm.config import Configurable
-from fast_llm.core.distributed import safe_barrier
+from fast_llm.core.distributed import allreduce_scalar, safe_barrier
 from fast_llm.data.data.abstract import Data
 from fast_llm.data.dataset.config import SamplingParameters
 from fast_llm.engine.config_utils.run import Run, is_main_rank, log_main_rank, log_pipeline_parallel_main_rank
@@ -23,7 +23,7 @@
 from fast_llm.engine.training.config import TrainerConfig, TrainingCheckpointBaseConfig, TrainingCheckpointConfig
 from fast_llm.engine.training.wandb import Wandb
 from fast_llm.logging import format_metrics, get_memory_usage_mib, log_memory_usage
-from fast_llm.utils import Assert
+from fast_llm.utils import Assert, Interrupter
 
 logger = logging.getLogger(__name__)
 
@@ -214,6 +214,7 @@ def _train(self) -> tuple[bool, dict[PhaseType, dict[str, typing.Any]]]:
             distributed_config=self._config.model.distributed, start_step=self._completed_steps
         )
 
+        interrupter = Interrupter(self._config.training.checkpoint.enabled())
         train_iterator = self._get_data_iterator(
             PhaseType.training.value,
             self._completed_steps,
@@ -231,7 +232,7 @@ def _train(self) -> tuple[bool, dict[PhaseType, dict[str, typing.Any]]]:
         start_iteration = self._completed_steps
         last_iteration = start_iteration
         stop = False
-        with profiler:
+        with profiler, interrupter:
             while not stop:
                 # Iteration starts at 1, so we increment at the beginning.
                 self._completed_steps += 1
@@ -317,8 +318,7 @@ def _train(self) -> tuple[bool, dict[PhaseType, dict[str, typing.Any]]]:
                 profiler.step()
 
                 done = self._completed_steps >= self._config.training.train_iters
-                # TODO: Signal-based stop.
-                stop = done or self._config.training.shutdown.enabled(self._completed_steps)
+
                 # Evaluation
                 # TODO: Adjust valid iterator length.
                 if PhaseType.validation in self._samples_per_split and (
@@ -366,11 +366,19 @@ def _train(self) -> tuple[bool, dict[PhaseType, dict[str, typing.Any]]]:
                 if is_main_rank() and metrics:
                     self._wandb.log_metrics(self._completed_steps, metrics)
 
-                if self._config.training.checkpoint.enabled(None if stop else self._completed_steps):
-                    self._save_checkpoint(self._config.training.checkpoint, metrics)
+                stop = done or self._config.training.shutdown.enabled(self._completed_steps)
 
                 if self._config.training.export.enabled(None if done else self._completed_steps):
                     self._save_checkpoint(self._config.training.export, metrics)
+
+                if interrupter.enabled:
+                    stop = stop or allreduce_scalar(
+                        interrupter.interrupted, torch.int32, self._distributed.world_group
+                    )
+
+                if self._config.training.checkpoint.enabled(None if stop else self._completed_steps):
+                    self._save_checkpoint(self._config.training.checkpoint, metrics)
+
             # The profiler calls the trace_fn at the end and this could lead to
             profiler.step()
         return done, metrics
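
The diff replaces the `# TODO: Signal-based stop.` marker with an `Interrupter` context manager that is armed only when periodic checkpointing is enabled, so that a termination signal can be turned into a final checkpoint instead of an abrupt kill. The implementation of `fast_llm.utils.Interrupter` is not shown in this diff; the following is a minimal sketch of how such a context manager could work, assuming it traps SIGINT/SIGTERM while active and only records the interruption, leaving the training loop to decide when to stop.

```python
# Hedged sketch only: the real fast_llm.utils.Interrupter may differ.
import signal


class Interrupter:
    def __init__(self, enabled: bool = True, signals=(signal.SIGINT, signal.SIGTERM)):
        self.enabled = enabled
        self.interrupted = False
        self._signals = signals
        self._previous_handlers = {}

    def __enter__(self):
        if self.enabled:
            for sig in self._signals:
                # Install our handler and remember the old one so it can be restored.
                self._previous_handlers[sig] = signal.signal(sig, self._handle)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Restore the original handlers when leaving the `with` block.
        for sig, handler in self._previous_handlers.items():
            signal.signal(sig, handler)
        self._previous_handlers.clear()
        return False

    def _handle(self, signum, frame):
        # Only record the interruption; the training loop checks the flag
        # once per iteration and exits at a safe point.
        self.interrupted = True
```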
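Because a signal may reach only some ranks, or reach them at different times, the interrupted flag is all-reduced over `self._distributed.world_group` before it is allowed to set `stop`, so every rank agrees on whether the upcoming checkpoint is the last one. `allreduce_scalar` is imported from `fast_llm.core.distributed` and its body is not part of this diff; a rough sketch of such a helper, assuming it wraps `torch.distributed.all_reduce` around a one-element tensor, might look like this.

```python
# Rough sketch only: the real fast_llm.core.distributed.allreduce_scalar may differ.
import torch
import torch.distributed as dist


def allreduce_scalar(value, dtype: torch.dtype = torch.float64, group=None, op=dist.ReduceOp.SUM):
    """All-reduce a Python scalar over `group` and return the reduced value."""
    if group is None:
        # No process group: nothing to reduce, return the local value unchanged.
        return value
    # Wrap the scalar in a one-element tensor so the collective can operate on it.
    # With an NCCL group the tensor would need to live on the GPU; kept on CPU here for simplicity.
    tensor = torch.full([1], value, dtype=dtype)
    dist.all_reduce(tensor, op=op, group=group)
    return tensor.item()
```

In the diff the helper is called with `torch.int32` and the boolean `interrupter.interrupted`, so the summed result is nonzero, and therefore truthy, as soon as any rank has caught a signal, and all ranks then stop together after saving the checkpoint.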