Skip to content

Commit 1b78d06

Browse files
authored
refactor: rename VectorFieldInference to VectorFieldTrainer (#1614)
1 parent ce31030 commit 1b78d06

File tree

3 files changed

+7
-7
lines changed

3 files changed

+7
-7
lines changed

sbi/inference/trainers/fmpe/fmpe.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,13 @@
1111
from sbi.inference.posteriors.vector_field_posterior import VectorFieldPosterior
1212
from sbi.inference.trainers.npse.vector_field_inference import (
1313
VectorFieldEstimatorBuilder,
14-
VectorFieldInference,
14+
VectorFieldTrainer,
1515
)
1616
from sbi.neural_nets import flowmatching_nn
1717
from sbi.neural_nets.estimators import ConditionalVectorFieldEstimator
1818

1919

20-
class FMPE(VectorFieldInference):
20+
class FMPE(VectorFieldTrainer):
2121
"""Flow Matching Posterior Estimation (FMPE)."""
2222

2323
def __init__(

sbi/inference/trainers/npse/npse.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,13 @@
99
from sbi.inference.posteriors.vector_field_posterior import VectorFieldPosterior
1010
from sbi.inference.trainers.npse.vector_field_inference import (
1111
VectorFieldEstimatorBuilder,
12-
VectorFieldInference,
12+
VectorFieldTrainer,
1313
)
1414
from sbi.neural_nets.estimators import ConditionalVectorFieldEstimator
1515
from sbi.neural_nets.factory import posterior_score_nn
1616

1717

18-
class NPSE(VectorFieldInference):
18+
class NPSE(VectorFieldTrainer):
1919
"""Neural Posterior Score Estimation as in Geffner et al. and Sharrock et al.
2020
2121
Instead of performing conditonal *density* estimation, NPSE methods perform

sbi/inference/trainers/npse/vector_field_inference.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def __call__(self, theta: Tensor, x: Tensor) -> ConditionalVectorFieldEstimator:
5151
...
5252

5353

54-
class VectorFieldInference(NeuralInference, ABC):
54+
class VectorFieldTrainer(NeuralInference, ABC):
5555
def __init__(
5656
self,
5757
prior: Optional[Distribution] = None,
@@ -119,7 +119,7 @@ def append_simulations(
119119
proposal: Optional[DirectPosterior] = None,
120120
exclude_invalid_x: Optional[bool] = None,
121121
data_device: Optional[str] = None,
122-
) -> "VectorFieldInference":
122+
) -> "VectorFieldTrainer":
123123
r"""Store parameters and simulation outputs to use them for later training.
124124
125125
Data are stored as entries in lists for each type of variable (parameter/data).
@@ -146,7 +146,7 @@ def append_simulations(
146146
much VRAM can set to 'cpu' to store data on system memory instead.
147147
148148
Returns:
149-
VectorFieldInference object (returned so that this function is chainable).
149+
VectorFieldTrainer object (returned so that this function is chainable).
150150
"""
151151
inference_name = self.__class__.__name__
152152
assert proposal is None, (

0 commit comments

Comments
 (0)