Skip to content

Commit 0eae845

Browse files
author
domosedy
committed
[refactor] refactoring for mypy
1 parent 3186e96 commit 0eae845

File tree

2 files changed

+29
-62
lines changed

2 files changed

+29
-62
lines changed

src/pysatl_core/families/exponential_family.py

Lines changed: 17 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,21 @@
22

33
from collections.abc import Callable
44
from dataclasses import dataclass
5-
from typing import TYPE_CHECKING, Any, Iterable, Sized, cast
5+
from typing import TYPE_CHECKING, Any, cast
66

77
import numpy as np
8+
from scipy.differentiate import jacobian
89
from scipy.integrate import nquad
910
from scipy.linalg import det
10-
from scipy.differentiate import jacobian
1111

1212
from pysatl_core.distributions import (
1313
SamplingStrategy,
1414
)
1515
from pysatl_core.families.parametric_family import ParametricFamily
1616
from pysatl_core.families.parametrizations import Parametrization, parametrization
1717
from pysatl_core.types import (
18-
GenericCharacteristicName,
1918
DistributionType,
19+
GenericCharacteristicName,
2020
ParametrizationName,
2121
)
2222

@@ -137,9 +137,7 @@ def _transform_to_natural_parametrization(
137137
def log_density(self) -> ParametrizedFunction:
138138
def log_density_func(parametrization: Parametrization, x: Any) -> Any:
139139
parametrization = cast(ExponentialFamilyParametrization, parametrization)
140-
parametrization = self._transform_to_natural_parametrization(
141-
parametrization
142-
)
140+
parametrization = self._transform_to_natural_parametrization(parametrization)
143141
if not self._support.accepts(x):
144142
return float("-inf")
145143

@@ -150,9 +148,7 @@ def log_density_func(parametrization: Parametrization, x: Any) -> Any:
150148
dot = dot[0]
151149

152150
result = float(
153-
np.log(self._normalization(x))
154-
+ dot
155-
+ self._log_partition(parametrization)
151+
np.log(self._normalization(x)) + dot + self._log_partition(parametrization)
156152
)
157153
return result
158154

@@ -188,31 +184,24 @@ def pdf(theta: Any) -> Any:
188184
theta = [theta]
189185
parametrization = ExponentialFamilyParametrization(theta=theta)
190186
# parametrization.theta = theta
191-
return np.exp(
192-
np.dot(theta, alpha) + beta * self._log_partition(parametrization)
193-
)[0]
187+
return np.exp(np.dot(theta, alpha) + beta * self._log_partition(parametrization))[0]
194188

195189
all_value = nquad(
196-
lambda x: pdf(x) if self._parameter_space.accepts(x) else 0,
190+
lambda x: pdf(x) if self._parameter_space.accepts(x) else 0, # type: ignore[arg-type]
197191
[(float("-inf"), float("+inf"))],
198192
)[0]
199193
return -np.log(all_value)
200194

201-
# TODO: remove hardcoding - Done, all hardcoding is only on user's hands
202-
# 1. pr with prototype/draft - in progress
203-
# 2. write instruction about to add distributions as member of exponential family - not started
204-
# 3. parametrization's spaces (передавать в конструктор) - maybe impossible, discuss this with desiment on meeting
205-
206195
def conjugate_sufficient_accepts(
207196
parametrization: ExponentialFamilyParametrization,
208197
) -> bool:
209198
theta = parametrization.theta
210199
xi = theta[:-1]
211200
nu = theta[-1]
212201

213-
return self._sufficient_statistics_values.accepts(
214-
xi
215-
) and SpacePredicateArray([(0, float("+inf"))]).accepts(nu)
202+
return self._sufficient_statistics_values.accepts(xi) and SpacePredicateArray(
203+
[(0, float("+inf"))]
204+
).accepts(nu)
216205

217206
return NaturalExponentialFamily(
218207
log_partition=conjugate_log_partition,
@@ -269,10 +258,8 @@ def mean_func(parametrization: Parametrization, x: Any) -> Any:
269258
if hasattr(x, "__len__"):
270259
dimension_size = len(x)
271260
return nquad(
272-
lambda x: (
273-
np.dot(x, self.density(parametrization, x))
274-
if self._support.accepts(x)
275-
else 0
261+
lambda x: ( # type: ignore[arg-type]
262+
np.dot(x, self.density(parametrization, x)) if self._support.accepts(x) else 0
276263
),
277264
[(float("-inf"), float("inf"))] * dimension_size,
278265
)[0]
@@ -287,10 +274,8 @@ def func(parametrization: Parametrization, x: Any) -> Any:
287274
if hasattr(x, "__len__"):
288275
dimension_size = len(x)
289276
return nquad(
290-
lambda x: (
291-
x**2 * self.density(parametrization, x)
292-
if self._support.accepts(x)
293-
else 0
277+
lambda x: ( # type: ignore[arg-type]
278+
x**2 * self.density(parametrization, x) if self._support.accepts(x) else 0
294279
),
295280
[(float("-inf"), float("inf"))] * dimension_size,
296281
)[0]
@@ -301,10 +286,7 @@ def func(parametrization: Parametrization, x: Any) -> Any:
301286
def _var(self) -> ParametrizedFunction:
302287
def func(parametrization: Parametrization, x: Any) -> Any:
303288
parametrization = cast(ExponentialFamilyParametrization, parametrization)
304-
return (
305-
self._second_moment(parametrization, x)
306-
- self._mean(parametrization, x) ** 2
307-
)
289+
return self._second_moment(parametrization, x) - self._mean(parametrization, x) ** 2
308290

309291
return func
310292

@@ -323,9 +305,7 @@ def posterior_hyperparameters(
323305
alpha_post = self._sufficient(sample)
324306
beta_post = 1
325307

326-
return ExponentialConjugateHyperparameters(
327-
alpha=alpha + alpha_post, beta=beta + beta_post
328-
)
308+
return ExponentialConjugateHyperparameters(alpha=alpha + alpha_post, beta=beta + beta_post)
329309

330310

331311
class ExponentialFamily(NaturalExponentialFamily):
@@ -356,9 +336,7 @@ def natural_log_partition(
356336
return log_partition(ExponentialFamilyParametrization(theta=[theta]))
357337

358338
natural_sufficient_statistics_values = SpacePredicate(
359-
lambda eta: sufficient_statistics_values.accepts(
360-
parameter_from_natural_parameter(eta)
361-
)
339+
lambda eta: sufficient_statistics_values.accepts(parameter_from_natural_parameter(eta))
362340
)
363341

364342
self._natural_parameter = natural_parameter

tests/unit/families/test_exponential_family.py

Lines changed: 12 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,14 @@
1616

1717

1818
def gamma_pdf(alpha: float, beta: float, x: float) -> float:
19-
return scipy.stats.gamma(a=alpha, scale=1 / beta).pdf(x).item()
19+
return scipy.stats.gamma(a=alpha, scale=1 / beta).pdf(x).item() # type: ignore[attr-defined]
2020

2121

2222
@pytest.fixture(scope="function")
2323
def conjugate_for_exponential() -> ExponentialFamily:
2424
def get_parameter_from_natural_parameter(
25-
eta_parametrization: ExponentialFamilyParametrization,
26-
):
25+
eta_parametrization: Any,
26+
) -> Any:
2727
if hasattr(eta_parametrization, "__len__"):
2828
if len(eta_parametrization) > 1:
2929
return list(-1 * np.array(eta_parametrization))
@@ -34,18 +34,15 @@ def natural_parameter(
3434
theta_parametrization: Any,
3535
) -> Any:
3636
if type(theta_parametrization) is ExponentialFamilyParametrization:
37-
theta_parametrization = cast(
38-
ExponentialFamilyParametrization, theta_parametrization
39-
)
40-
eta = -theta_parametrization.theta
37+
eta = list(-np.array(theta_parametrization.theta))
4138
return ExponentialFamilyParametrization(theta=eta)
4239

4340
return -1 * theta_parametrization
4441

45-
def transform_function(x: list[Any]) -> list[Any]:
46-
if type(x) is not list:
47-
return -x
48-
return [-x[0]]
42+
def transform_function(x: list[float] | float) -> list[float] | float:
43+
if type(x) is list:
44+
return [-x[0]]
45+
return -x # type: ignore[operator]
4946

5047
fam = ExponentialFamily(
5148
log_partition=lambda parametrization: np.log(parametrization.theta[0]),
@@ -77,16 +74,12 @@ def test_exponential_pdf(theta1, theta2, conjugate_for_exponential):
7774
alpha = theta2 + 1
7875
beta = theta1
7976

80-
exponential = gamma_family(
81-
theta=np.array([theta1, theta2]), parametrization_name="theta"
82-
)
77+
exponential = gamma_family(theta=np.array([theta1, theta2]), parametrization_name="theta")
8378
pdf = exponential.computation_strategy.query_method("pdf", distr=exponential)
8479

8580
x = [i / 10 for i in range(100)]
8681

87-
assert_allclose(
88-
[pdf(xx) for xx in x], [gamma_pdf(alpha, beta, xx) for xx in x], rtol=1e-6
89-
)
82+
assert_allclose([pdf(xx) for xx in x], [gamma_pdf(alpha, beta, xx) for xx in x], rtol=1e-6)
9083

9184

9285
@pytest.mark.parametrize("theta1", range(2, 5))
@@ -97,9 +90,7 @@ def test_exponential_mean(theta1, theta2, conjugate_for_exponential):
9790
alpha = theta2 + 1
9891
beta = theta1
9992

100-
exponential = gamma_family(
101-
theta=np.array([theta1, theta2]), parametrization_name="theta"
102-
)
93+
exponential = gamma_family(theta=np.array([theta1, theta2]), parametrization_name="theta")
10394
mean = exponential.computation_strategy.query_method("mean", distr=exponential)
10495
assert np.isclose(mean(12), alpha / beta, rtol=1e-6)
10596

@@ -112,8 +103,6 @@ def test_exponential_var(theta1, theta2, conjugate_for_exponential):
112103
alpha = theta2 + 1
113104
beta = theta1
114105

115-
exponential = gamma_family(
116-
theta=np.array([theta1, theta2]), parametrization_name="theta"
117-
)
106+
exponential = gamma_family(theta=np.array([theta1, theta2]), parametrization_name="theta")
118107
var = exponential.computation_strategy.query_method("var", distr=exponential)
119108
assert np.isclose(var(12), alpha / beta**2, rtol=1e-6)

0 commit comments

Comments
 (0)