|
1 | 1 | # Copyright © 2023-2024 Apple Inc. |
| 2 | +import json |
2 | 3 | import math |
| 4 | +import os |
3 | 5 | import pathlib |
4 | 6 | from unittest.mock import MagicMock, call, patch |
5 | 7 |
|
|
9 | 11 |
|
10 | 12 | from pfl.algorithm import FederatedAveraging, NNAlgorithmParams |
11 | 13 | from pfl.callback.checkpoint import ModelCheckpointingCallback |
12 | | -from pfl.common_types import Population |
| 14 | +from pfl.common_types import LocalDiskCheckpointer, Population |
13 | 15 | from pfl.context import CentralContext |
14 | 16 | from pfl.hyperparam import NNEvalHyperParams, NNTrainHyperParams |
15 | 17 | from pfl.internal.ops.common_ops import get_pytorch_major_version, get_tf_major_version |
@@ -403,3 +405,66 @@ def on_train_begin(model): |
403 | 405 | assert 'on_train_begin' not in mock_platform.consume_metrics.call_args_list[ |
404 | 406 | 4][0][0] |
405 | 407 | assert mock_platform.consume_metrics.call_count == 5 |
| 408 | + |
@patch('pfl.algorithm.base.get_platform')
def test_checkpoint_during_gather_results_crash(self, mock_get_platform,
                                                fedavg_setup, mock_model,
                                                nn_eval_params,
                                                nn_train_params, tmp_path):
    """
    Simulate a failure inside ``gather_results`` on central iteration 2
    and check the algorithm checkpoint written to ``tmp_path``.

    The checkpoint file must already exist when the failure happens and
    must record the iteration that was in progress (2, not 1). An
    algorithm restored from that directory must report the same
    iteration.
    """
    crash_message = "Simulated crash during gather_results"
    # Single source of truth for the file inspected both during and
    # after the simulated crash.
    checkpoint_path = os.path.join(tmp_path, 'algorithm_checkpoint.json')

    def load_saved_state():
        # Parse whatever algorithm state is currently on disk.
        with open(checkpoint_path) as checkpoint_file:
            return json.load(checkpoint_file)

    mock_get_platform.return_value = MagicMock()

    algo = FederatedAveraging()
    algo.set_checkpointer(LocalDiskCheckpointer(str(tmp_path)))

    async def gather_results_crashing_on_second_iteration(*args, **kwargs):
        context = kwargs['central_context']

        if context.current_central_iteration == 2:
            # The checkpoint has to be on disk, and up to date, before
            # the backend is asked for results.
            assert os.path.exists(checkpoint_path)
            saved = load_saved_state()
            assert saved[
                'current_central_iteration'] == context.current_central_iteration

            raise RuntimeError(crash_message)

        # Iterations other than 2: statistics for the train population
        # only, plus empty metrics.
        aggregated = (MappedVectorStatistics({'var1': np.ones((2, 3))},
                                             weight=context.cohort_size)
                      if context.population == Population.TRAIN else None)
        return aggregated, Metrics()

    backend = MagicMock()
    backend.async_gather_results.side_effect = (
        gather_results_crashing_on_second_iteration)

    with pytest.raises(RuntimeError, match=crash_message):
        algo.run(algorithm_params=fedavg_setup['algorithm_params'],
                 backend=backend,
                 model=mock_model,
                 model_train_params=nn_train_params,
                 model_eval_params=nn_eval_params,
                 callbacks=[])

    # The state left on disk after the crash still points at iteration 2.
    assert load_saved_state()['current_central_iteration'] == 2

    # A fresh algorithm loaded from the same directory resumes there.
    restored_algo = FederatedAveraging()
    restored_algo.load(str(tmp_path))
    assert restored_algo._current_central_iteration == 2
0 commit comments