Skip to content

Commit 92e1bd5

Browse files
authored
FederatedAlgorithm should restore current iteration if crash during gather_results (#132)
1 parent d92cc9f commit 92e1bd5

2 files changed

Lines changed: 82 additions & 2 deletions

File tree

pfl/algorithm/base.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
from pfl.aggregate.base import Backend
1515
from pfl.callback.base import TrainingProcessCallback
16-
from pfl.common_types import Population, Saveable
16+
from pfl.common_types import Checkpointer, Population, Saveable
1717
from pfl.context import CentralContext
1818
from pfl.data.dataset import AbstractDatasetType
1919
from pfl.exception import CheckpointNotFoundError
@@ -77,10 +77,17 @@ def __init__(self):
7777
self._random_state = np.random.RandomState(
7878
np.random.randint(0, 2**32, dtype=np.uint32))
7979
self._current_central_iteration = 0
80+
self._checkpointer = None
8081

8182
def _get_seed(self):
8283
return self._random_state.randint(0, 2**32, dtype=np.uint32)
8384

85+
def set_checkpointer(self, checkpointer: Checkpointer) -> None:
86+
"""
87+
Set the checkpointer so that intermediate state can be saved.
88+
"""
89+
self._checkpointer = checkpointer
90+
8491
def save(self, dir_path: str) -> None:
8592
state_path = os.path.join(dir_path, 'algorithm_checkpoint.json')
8693
with open(state_path, 'w') as f:
@@ -274,11 +281,19 @@ def run(self,
274281
all_metrics) = self.get_next_central_contexts(
275282
model, self._current_central_iteration, algorithm_params,
276283
model_train_params, model_eval_params)
284+
277285
if new_central_contexts is None:
278286
break
279287
else:
280288
central_contexts = new_central_contexts
281289

290+
if self._checkpointer is not None:
291+
# Save the new algorithm state here so that, if we restore
292+
# from a crashed experiment, we resume at the current
293+
# central iteration rather than at the state from the last
294+
# time the checkpointing callback was called.
295+
self._checkpointer.invoke_save(self)
296+
282297
if not has_reported_on_train_metrics:
283298
all_metrics |= on_train_metrics
284299
has_reported_on_train_metrics = True

tests/algorithm/test_fedavg.py

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# Copyright © 2023-2024 Apple Inc.
2+
import json
23
import math
4+
import os
35
import pathlib
46
from unittest.mock import MagicMock, call, patch
57

@@ -9,7 +11,7 @@
911

1012
from pfl.algorithm import FederatedAveraging, NNAlgorithmParams
1113
from pfl.callback.checkpoint import ModelCheckpointingCallback
12-
from pfl.common_types import Population
14+
from pfl.common_types import LocalDiskCheckpointer, Population
1315
from pfl.context import CentralContext
1416
from pfl.hyperparam import NNEvalHyperParams, NNTrainHyperParams
1517
from pfl.internal.ops.common_ops import get_pytorch_major_version, get_tf_major_version
@@ -403,3 +405,66 @@ def on_train_begin(model):
403405
assert 'on_train_begin' not in mock_platform.consume_metrics.call_args_list[
404406
4][0][0]
405407
assert mock_platform.consume_metrics.call_count == 5
408+
409+
@patch('pfl.algorithm.base.get_platform')
410+
def test_checkpoint_during_gather_results_crash(self, mock_get_platform,
411+
fedavg_setup, mock_model,
412+
nn_eval_params,
413+
nn_train_params, tmp_path):
414+
"""
415+
Test that checkpoint is saved with current_central_iteration before
416+
gather_results, so if gather_results crashes, the restored checkpoint
417+
has the correct iteration number (not iteration - 1).
418+
"""
419+
420+
mock_platform = MagicMock()
421+
mock_get_platform.return_value = mock_platform
422+
423+
algo = FederatedAveraging()
424+
checkpointer = LocalDiskCheckpointer(str(tmp_path))
425+
algo.set_checkpointer(checkpointer)
426+
427+
mock_backend = MagicMock()
428+
429+
async def mock_gather_results_with_crash(*args, **kwargs):
430+
central_context = kwargs['central_context']
431+
# Crash on iteration=2
432+
if central_context.current_central_iteration == 2:
433+
checkpoint_path = os.path.join(tmp_path,
434+
'algorithm_checkpoint.json')
435+
assert os.path.exists(checkpoint_path)
436+
with open(checkpoint_path) as f:
437+
state = json.load(f)
438+
assert state[
439+
'current_central_iteration'] == central_context.current_central_iteration
440+
441+
raise RuntimeError("Simulated crash during gather_results")
442+
443+
if central_context.population == Population.TRAIN:
444+
stats = MappedVectorStatistics(
445+
{'var1': np.ones((2, 3))},
446+
weight=central_context.cohort_size)
447+
else:
448+
stats = None
449+
return stats, Metrics()
450+
451+
mock_backend.async_gather_results.side_effect = mock_gather_results_with_crash
452+
453+
with pytest.raises(RuntimeError,
454+
match="Simulated crash during gather_results"):
455+
algo.run(algorithm_params=fedavg_setup['algorithm_params'],
456+
backend=mock_backend,
457+
model=mock_model,
458+
model_train_params=nn_train_params,
459+
model_eval_params=nn_eval_params,
460+
callbacks=[])
461+
462+
checkpoint_path = os.path.join(tmp_path, 'algorithm_checkpoint.json')
463+
with open(checkpoint_path) as f:
464+
state = json.load(f)
465+
assert state['current_central_iteration'] == 2
466+
467+
# Restore from checkpoint
468+
restored_algo = FederatedAveraging()
469+
restored_algo.load(str(tmp_path))
470+
assert restored_algo._current_central_iteration == 2

0 commit comments

Comments
 (0)