Skip to content

Commit 6210fe2

Browse files
Also pyright ignore string types
1 parent 961792d commit 6210fe2

File tree

3 files changed

+6
-6
lines changed

3 files changed

+6
-6
lines changed

sbi/samplers/score/diffuser.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ class Diffuser:
1717

1818
def __init__(
1919
self,
20-
score_based_potential: 'PosteriorScoreBasedPotential', # noqa: F821
20+
score_based_potential: 'PosteriorScoreBasedPotential', # noqa: F821 # type: ignore
2121
predictor: Union[str, Predictor],
2222
corrector: Optional[Union[str, Corrector]] = None,
2323
predictor_params: Optional[dict] = None,
@@ -62,7 +62,7 @@ def __init__(
6262
def set_predictor(
6363
self,
6464
predictor: Union[str, Predictor],
65-
score_based_potential: 'PosteriorScoreBasedPotential', # noqa: F821
65+
score_based_potential: 'PosteriorScoreBasedPotential', # noqa: F821 # type: ignore
6666
**kwargs,
6767
):
6868
"""Set the predictor for the diffusion-based sampler."""

sbi/samplers/score/predictors.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
def get_predictor(
1414
name: str,
15-
score_based_potential: 'PosteriorScoreBasedPotential', # noqa: F821
15+
score_based_potential: 'PosteriorScoreBasedPotential', # noqa: F821 # type: ignore
1616
**kwargs,
1717
) -> "Predictor":
1818
"""Helper function to get predictor by name.
@@ -52,7 +52,7 @@ class Predictor(ABC):
5252

5353
def __init__(
5454
self,
55-
potential_fn: 'PosteriorScoreBasedPotential', # noqa: F821
55+
potential_fn: 'PosteriorScoreBasedPotential', # noqa: F821 # type: ignore
5656
):
5757
"""Initialize predictor.
5858
@@ -92,7 +92,7 @@ def predict(self, theta: Tensor, t1: Tensor, t0: Tensor) -> Tensor:
9292
class EulerMaruyama(Predictor):
9393
def __init__(
9494
self,
95-
potential_fn: 'PosteriorScoreBasedPotential', # noqa: F821
95+
potential_fn: 'PosteriorScoreBasedPotential', # noqa: F821 # type: ignore
9696
eta: float = 1.0,
9797
):
9898
"""Simple Euler-Maruyama discretization of the associated family of reverse

sbi/samplers/vi/vi_divergence_optimizers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ class DivergenceOptimizer(ABC):
4646

4747
def __init__(
4848
self,
49-
potential_fn: 'BasePotential', # noqa: F821
49+
potential_fn: 'BasePotential', # noqa: F821 # type: ignore
5050
q: PyroTransformedDistribution,
5151
prior: Optional[Distribution] = None,
5252
n_particles: int = 256,

0 commit comments

Comments
 (0)