 import time
 from abc import ABC, abstractmethod
 from copy import deepcopy
-from typing import Any, Callable, Optional, Protocol, Tuple, Union
+from typing import Any, Callable, Literal, Optional, Protocol, Tuple, Union

 import torch
 from torch import Tensor, ones
@@ -60,7 +60,10 @@ class VectorFieldTrainer(NeuralInference, ABC):
     def __init__(
         self,
         prior: Optional[Distribution] = None,
-        vector_field_estimator_builder: Union[str, VectorFieldEstimatorBuilder] = "mlp",
+        vector_field_estimator_builder: Union[
+            Literal["mlp", "ada_mlp", "transformer", "transformer_cross_attn"],
+            VectorFieldEstimatorBuilder,
+        ] = "mlp",
         device: str = "cpu",
         logging_level: Union[int, str] = "WARNING",
         summary_writer: Optional[SummaryWriter] = None,
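Spelling the presets out as a `Literal` (rather than `str`) lets type checkers and IDEs flag unknown model names at the call site, while custom builder callables are still accepted. A minimal, self-contained illustration of that effect (generic names, not the sbi API):

```python
from typing import Callable, Literal, Union

ModelPreset = Literal["mlp", "ada_mlp", "transformer", "transformer_cross_attn"]


def pick_builder(model: Union[ModelPreset, Callable]) -> str:
    # A type checker now rejects e.g. pick_builder("mpl") as an invalid literal,
    # while arbitrary callables still pass through unchanged.
    return model if isinstance(model, str) else model.__name__


print(pick_builder("ada_mlp"))  # ada_mlp
```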
@@ -106,15 +109,19 @@ def __init__(
         check_estimator_arg(vector_field_estimator_builder)
         if isinstance(vector_field_estimator_builder, str):
             self._build_neural_net = self._build_default_nn_fn(
-                vector_field_estimator_builder=vector_field_estimator_builder, **kwargs
+                model=vector_field_estimator_builder, **kwargs
             )
         else:
             self._build_neural_net = vector_field_estimator_builder

         self._proposal_roundwise = []

     @abstractmethod
-    def _build_default_nn_fn(self, **kwargs) -> VectorFieldEstimatorBuilder:
+    def _build_default_nn_fn(
+        self,
+        model: Literal["mlp", "ada_mlp", "transformer", "transformer_cross_attn"],
+        **kwargs,
+    ) -> VectorFieldEstimatorBuilder:
         pass

     def append_simulations(
@@ -209,12 +216,13 @@ def train(
         training_batch_size: int = 200,
         learning_rate: float = 5e-4,
         validation_fraction: float = 0.1,
-        stop_after_epochs: int = 50,
-        max_num_epochs: int = 500,
+        stop_after_epochs: int = 20,
+        max_num_epochs: int = 2**31 - 1,
         clip_max_norm: Optional[float] = 5.0,
         calibration_kernel: Optional[Callable] = None,
         ema_loss_decay: float = 0.1,
-        validation_times: Union[Tensor, int] = 20,
+        validation_times: Union[Tensor, int] = 10,
+        validation_times_nugget: float = 0.05,
         resume_training: bool = False,
         force_first_round_loss: bool = False,
         discard_prior_samples: bool = False,
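With `max_num_epochs` effectively unbounded, the end of training is now governed by the statistical early stopping in `_converged` (added further below) via `stop_after_epochs`. A hypothetical call sketch, with trainer construction and data omitted and only argument names from this diff:

```python
# Hypothetical usage; `trainer` is any concrete VectorFieldTrainer subclass.
density_estimator = trainer.train(
    stop_after_epochs=20,          # statistical early stopping, see _converged
    validation_times=10,           # number of diffusion times for validation
    validation_times_nugget=0.05,  # keep validation times away from t_min / t_max
)
```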
@@ -253,6 +261,9 @@ def train(
                 training and validation losses.
             validation_times: Diffusion times at which to evaluate the validation loss
                 to reduce variance of validation loss.
+            validation_times_nugget: Since diffusion and flow matching losses often
+                have high variance near the ends of the time interval, the
+                validation times span [t_min + nugget, t_max - nugget]. Default: 0.05.
             resume_training: Can be used in case training time is limited, e.g. on a
                 cluster. If `True`, the split between train and validation set, the
                 optimizer, the number of epochs, and the best validation log-prob will
@@ -341,7 +352,9 @@ def default_calibration_kernel(x):

         if isinstance(validation_times, int):
             validation_times = torch.linspace(
-                self._neural_net.t_min, self._neural_net.t_max, validation_times
+                self._neural_net.t_min + validation_times_nugget,
+                self._neural_net.t_max - validation_times_nugget,
+                validation_times,
             )
         assert isinstance(
             validation_times, Tensor
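For illustration, a small standalone sketch of the grid this produces (bounds are assumed to be 0 and 1 here; in the trainer they come from `self._neural_net`):

```python
import torch

t_min, t_max, nugget = 0.0, 1.0, 0.05  # assumed bounds, for illustration only
validation_times = torch.linspace(t_min + nugget, t_max - nugget, 10)
# tensor([0.0500, 0.1500, 0.2500, ..., 0.8500, 0.9500]) -- endpoints are avoided.
```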
@@ -431,6 +444,8 @@ def default_calibration_kernel(x):
                         times_batch, *([1] * (masks_batch.ndim - 1))
                     )

+                    # Repeat each validation time once per element of the batch, so
+                    # the loss is evaluated at every time for every validation sample.
                     validation_times_rep = validation_times.repeat_interleave(
                         val_batch_size, dim=0
                     )
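A toy illustration of the `repeat_interleave` pattern (shapes chosen arbitrarily for this sketch):

```python
import torch

validation_times = torch.tensor([0.1, 0.5, 0.9])
val_batch_size = 2
validation_times_rep = validation_times.repeat_interleave(val_batch_size, dim=0)
# tensor([0.1000, 0.1000, 0.5000, 0.5000, 0.9000, 0.9000])
```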
@@ -486,6 +501,76 @@ def default_calibration_kernel(x):

         return deepcopy(self._neural_net)

+    def _converged(self, epoch: int, stop_after_epochs: int) -> bool:
+        """Return whether the training converged yet and save best model state so far.
+
+        Diffusion and flow matching objectives are inherently more stochastic than the
+        MLE objective used e.g. for NPE, because they add "noise" by construction. We
+        therefore use a statistical approach to detect convergence by tracking the
+        standard deviation of the validation losses: training is considered converged
+        when the current loss is significantly worse than the best loss (more than 2
+        standard deviations above it) for a sustained period.
+
+        NOTE: The standard deviation of the `validation_loss` is computed in a running
+        fashion over the most recent 2 * stop_after_epochs loss values.
+
+        Args:
+            epoch: Current epoch in training.
+            stop_after_epochs: How many fruitless epochs to let pass before stopping.
+
+        Returns:
+            Whether the training has stopped improving, i.e. has converged.
+        """
+        converged = False
+
+        assert self._neural_net is not None
+        neural_net = self._neural_net
+
+        # Initialize tracking variables if they do not exist yet.
+        if not hasattr(self, '_best_val_loss'):
+            self._best_val_loss = float('inf')
+            self._epochs_since_last_improvement = 0
+            self._best_model_state_dict = None
+
+        # Check whether we have a new best loss.
+        if self._val_loss < self._best_val_loss:
+            self._best_val_loss = self._val_loss
+            self._epochs_since_last_improvement = 0
+            self._best_model_state_dict = deepcopy(neural_net.state_dict())
+        else:
+            # Only start the statistical analysis once enough data is available.
+            if len(self._summary["validation_loss"]) >= stop_after_epochs:
+                # Compute running statistics of the recent validation losses.
+                recent_losses = torch.tensor(
+                    self._summary["validation_loss"][-stop_after_epochs * 2 :]
+                )
+                loss_std = recent_losses.std().item()
+
+                # Compute how many standard deviations the current loss lies above
+                # the best loss.
+                diff_to_best_normalized = (
+                    self._val_loss - self._best_val_loss
+                ) / loss_std
+                # Consider it "no improvement" if the current loss is significantly
+                # worse than the best loss (more than 2 std deviations above it).
+                # This accounts for natural fluctuations while staying sensitive to
+                # real degradation.
+                if diff_to_best_normalized > 2.0:
+                    self._epochs_since_last_improvement += 1
+                else:
+                    # Reset the counter if the loss is within the acceptable range.
+                    self._epochs_since_last_improvement = 0
+            else:
+                return False
+
+        # If there was no validation improvement over many epochs, stop training.
+        if self._epochs_since_last_improvement > stop_after_epochs - 1:
+            if self._best_model_state_dict is not None:
+                neural_net.load_state_dict(self._best_model_state_dict)
+            converged = True
+
+        return converged
+
     def _get_potential_function(
         self, prior: Distribution, estimator: ConditionalVectorFieldEstimator
     ) -> Tuple[VectorFieldBasedPotential, TorchTransform]:
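As a standalone illustration of the convergence rule added above (toy numbers, not from the diff): an epoch only counts as fruitless when the current validation loss exceeds the best loss by more than two standard deviations of the recent loss history.

```python
import torch


def fruitless_epoch(current, best, recent_losses, threshold=2.0):
    """Toy version of the criterion used in `_converged` above."""
    loss_std = torch.tensor(recent_losses).std().item()
    return (current - best) / loss_std > threshold


recent = [0.55, 0.49, 0.53, 0.51, 0.50, 0.54]
print(fruitless_epoch(0.52, 0.50, recent))  # False: within normal fluctuation
print(fruitless_epoch(0.65, 0.50, recent))  # True: clear degradation
```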