Skip to content

Commit 3186e96

Browse files
author
domosedy
committed
[feat] added transform method to ExponentialFamily
1 parent 73da345 commit 3186e96

File tree

3 files changed

+184
-96
lines changed

3 files changed

+184
-96
lines changed

src/pysatl_core/families/__init__.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,6 @@
1414
from .builtins import __all__ as _builtins_all
1515
from .configuration import configure_families_register
1616
from .distribution import ParametricFamilyDistribution
17-
from .parametric_family import ParametricFamily
18-
from .parametrizations import (
19-
Parametrization,
20-
ParametrizationConstraint,
21-
constraint,
22-
parametrization,
23-
)
2417
from .exponential_family import (
2518
ExponentialConjugateHyperparameters,
2619
ExponentialFamily,
@@ -29,6 +22,13 @@
2922
SpacePredicate,
3023
SpacePredicateArray,
3124
)
25+
from .parametric_family import ParametricFamily
26+
from .parametrizations import (
27+
Parametrization,
28+
ParametrizationConstraint,
29+
constraint,
30+
parametrization,
31+
)
3232
from .registry import ParametricFamilyRegister
3333

3434
__all__ = [

src/pysatl_core/families/exponential_family.py

Lines changed: 111 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,24 @@
11
from __future__ import annotations
2+
23
from collections.abc import Callable
3-
from typing import Any, cast, TYPE_CHECKING
4-
from scipy.integrate import nquad, quad
4+
from dataclasses import dataclass
5+
from typing import TYPE_CHECKING, Any, Iterable, Sized, cast
6+
57
import numpy as np
8+
from scipy.integrate import nquad
9+
from scipy.linalg import det
10+
from scipy.differentiate import jacobian
611

7-
from pysatl_core.distributions.fitters import _ppf_brentq_from_cdf
12+
from pysatl_core.distributions import (
13+
SamplingStrategy,
14+
)
815
from pysatl_core.families.parametric_family import ParametricFamily
916
from pysatl_core.families.parametrizations import Parametrization, parametrization
1017
from pysatl_core.types import (
18+
GenericCharacteristicName,
1119
DistributionType,
1220
ParametrizationName,
1321
)
14-
from pysatl_core.distributions import (
15-
SamplingStrategy,
16-
)
1722

1823
if TYPE_CHECKING:
1924
from pysatl_core.distributions.support import Support
@@ -32,34 +37,37 @@
3237
KURT = "kurtosis"
3338

3439

40+
@dataclass
class ExponentialFamilyParametrization(Parametrization):
    """
    Parametrization used by the exponential-family classes of this module.

    It carries a single field, ``theta``, the parameter vector that the
    family functions read through ``parametrization.theta``.
    """

    # TODO: consider a richer representation than a plain list of floats.
    theta: list[float]
4147

4248

4349
class ExponentialConjugateHyperparameters:
    """
    Container for the two hyperparameters ``alpha`` and ``beta``.

    Both values are stored exactly as given; nothing is validated or copied.
    """

    def __init__(self, alpha: Any, beta: int):
        """Store ``alpha`` and ``beta`` as attributes of the same names."""
        self.alpha, self.beta = alpha, beta

    def __str__(self) -> str:
        """Return the hyperparameters rendered as ``alpha=..., beta=...``."""
        return f"alpha={self.alpha}, beta={self.beta}"
5056

5157

52-
def doesAccept(x, support):
58+
def doesAccept(x: list[float] | float, support: list[tuple[float, float]]) -> bool:
5359
if not hasattr(x, "__len__"):
5460
x = [x]
5561

56-
def accept_1D(x, borders):
62+
x = cast(list[float], x)
63+
64+
def accept_1D(x: float, borders: tuple[float, float]) -> bool:
5765
left, right = borders
5866
if abs(x) == 0 and (abs(left) == 0 or abs(right) == 0):
5967
return False
6068
return left <= x <= right
6169

62-
return all(accept_1D(x_i, border) for x_i, border in zip(x, support))
70+
return all(accept_1D(x_i, border) for x_i, border in zip(x, support, strict=False))
6371

6472

6573
class SpacePredicate:
@@ -100,7 +108,10 @@ def __init__(
100108
self._parameter_space = parameter_space
101109
self._sufficient_statistics_values = sufficient_statistics_values
102110

103-
distr_characteristics = {
111+
distr_characteristics: dict[
112+
GenericCharacteristicName,
113+
dict[ParametrizationName, ParametrizedFunction] | ParametrizedFunction,
114+
] = {
104115
PDF: self.density,
105116
MEAN: self._mean,
106117
VAR: self._var,
@@ -115,20 +126,29 @@ def __init__(
115126
sampling_strategy=sampling_strategy,
116127
support_by_parametrization=support_by_parametrization,
117128
)
118-
parametrization(family=self, name="theta")((ExponentialFamilyParametrization))
129+
parametrization(family=self, name="theta")(ExponentialFamilyParametrization)
130+
131+
def _transform_to_natural_parametrization(
    self, theta_parametrization: ExponentialFamilyParametrization
) -> ExponentialFamilyParametrization:
    """
    Return the parametrization to be used as the natural one.

    This implementation performs no conversion: the argument is handed back
    unchanged.
    """
    natural_parametrization = theta_parametrization
    return natural_parametrization
119135

120136
@property
121137
def log_density(self) -> ParametrizedFunction:
122-
def log_density_func(
123-
parametrization: ExponentialFamilyParametrization, x: Any
124-
) -> Any:
138+
def log_density_func(parametrization: Parametrization, x: Any) -> Any:
139+
parametrization = cast(ExponentialFamilyParametrization, parametrization)
140+
parametrization = self._transform_to_natural_parametrization(
141+
parametrization
142+
)
125143
if not self._support.accepts(x):
126144
return float("-inf")
127145

128-
params = cast(ExponentialFamilyParametrization, parametrization)
129-
theta = params.parameters.get("theta")
146+
theta = parametrization.theta
130147
sufficient = self._sufficient(x)
131148
dot = np.dot(theta, sufficient)
149+
if hasattr(dot, "__len__"):
150+
dot = dot[0]
151+
132152
result = float(
133153
np.log(self._normalization(x))
134154
+ dot
@@ -143,26 +163,31 @@ def density(self) -> ParametrizedFunction:
143163
return lambda parametrization, x: np.exp(self.log_density(parametrization, x))
144164

145165
@property
146-
def conjugate_prior_family(self):
147-
def conjugate_sufficient(theta: Any):
166+
def conjugate_prior_family(self) -> NaturalExponentialFamily:
167+
def conjugate_sufficient(
168+
theta: float,
169+
) -> list[Any]:
148170
if not self._parameter_space.accepts(theta):
149171
return [float("-inf"), float("-inf")]
150172

173+
parametrization = ExponentialFamilyParametrization([theta])
174+
# parametrization.theta = [theta]
151175
return [
152176
theta,
153-
self._log_partition(ExponentialFamilyParametrization(theta=[theta])),
177+
self._log_partition(parametrization),
154178
]
155179

156-
def conjugate_log_partition(parametrization: ExponentialFamilyParametrization):
180+
def conjugate_log_partition(
181+
parametrization: ExponentialFamilyParametrization,
182+
) -> Any:
157183
alpha = parametrization.theta[0]
158184
beta = parametrization.theta[1]
159185

160-
def pdf(theta: Any):
186+
def pdf(theta: Any) -> Any:
161187
if not hasattr(theta, "__len__"):
162188
theta = [theta]
163-
parametrization = ExponentialFamilyParametrization(
164-
theta=theta,
165-
)
189+
parametrization = ExponentialFamilyParametrization(theta=theta)
190+
# parametrization.theta = theta
166191
return np.exp(
167192
np.dot(theta, alpha) + beta * self._log_partition(parametrization)
168193
)[0]
@@ -180,15 +205,14 @@ def pdf(theta: Any):
180205

181206
def conjugate_sufficient_accepts(
182207
parametrization: ExponentialFamilyParametrization,
183-
):
184-
parametrization = cast(parametrization, ExponentialFamilyParametrization)
185-
theta = parametrization.parameters.get("theta")
208+
) -> bool:
209+
theta = parametrization.theta
186210
xi = theta[:-1]
187211
nu = theta[-1]
188212

189-
return self._sufficient_statistics_values(xi) and SpacePredicateArray(
190-
[(0, float("+inf"))]
191-
).accepts(nu)
213+
return self._sufficient_statistics_values.accepts(
214+
xi
215+
) and SpacePredicateArray([(0, float("+inf"))]).accepts(nu)
192216

193217
return NaturalExponentialFamily(
194218
log_partition=conjugate_log_partition,
@@ -197,15 +221,50 @@ def conjugate_sufficient_accepts(
197221
support=self._parameter_space,
198222
sufficient_statistics_values=self._parameter_space, # TODO: write convex hull for this
199223
parameter_space=SpacePredicate(conjugate_sufficient_accepts),
224+
name=self.name,
200225
sampling_strategy=self.sampling_strategy,
201226
distr_type=self._distr_type,
202227
distr_parametrizations=self.parametrization_names,
203228
support_by_parametrization=self.support_resolver,
204229
)
205230

231+
def transform(
    self,
    transform_function: Callable[[Any], Any],
) -> NaturalExponentialFamily:
    """
    Build the family obtained by a change of variables through ``transform_function``.

    For a point ``x`` of the new sample space, the support predicate, the
    sufficient statistics and the normalization (base-measure) function of
    this family are all evaluated at ``transform_function(x)``. The
    normalization is additionally multiplied by ``|det J(x)|``, the absolute
    Jacobian determinant of ``transform_function`` at ``x``. The
    log-partition, the parameter space, the sufficient-statistics space and
    the remaining constructor arguments are reused from this family.

    Parameters
    ----------
    transform_function
        Map from the new sample space into this family's sample space. It is
        differentiated numerically with ``scipy.differentiate.jacobian``, so
        it must accept the array input that routine passes to it.

    Returns
    -------
    NaturalExponentialFamily
        The new family, named ``Transformed<name of this family>``.
    """

    def calculate_jacobian(x: Any) -> Any:
        # Scalars, lists, tuples and arrays are all coerced to a float array
        # of at least one dimension. Wrapping an existing array in another
        # array would hand the Jacobian routine a wrongly shaped point.
        point = np.atleast_1d(np.asarray(x, dtype=float))
        return np.abs(det(jacobian(transform_function, point).df))

    def new_support(x: Any) -> bool:
        return self._support.accepts(transform_function(x))

    def new_sufficient(x: Any) -> Any:
        return self._sufficient(transform_function(x))

    def new_normalization(x: Any) -> Any:
        # Change of variables: the original base measure is taken at the
        # transformed point, exactly like the support and the sufficient
        # statistics, and then scaled by the Jacobian factor.
        return self._normalization(transform_function(x)) * calculate_jacobian(x)

    return NaturalExponentialFamily(
        log_partition=self._log_partition,
        sufficient_statistics=new_sufficient,
        normalization_constant=new_normalization,
        support=SpacePredicate(new_support),
        parameter_space=self._parameter_space,
        sufficient_statistics_values=self._sufficient_statistics_values,
        name=f"Transformed{self._name}",
        distr_type=self._distr_type,
        distr_parametrizations=self.parametrization_names,
        sampling_strategy=self.sampling_strategy,
        support_by_parametrization=self.support_resolver,
    )
263+
206264
@property
207265
def _mean(self) -> ParametrizedFunction:
208266
def mean_func(parametrization: Parametrization, x: Any) -> Any:
267+
parametrization = cast(ExponentialFamilyParametrization, parametrization)
209268
dimension_size = 1
210269
if hasattr(x, "__len__"):
211270
dimension_size = len(x)
@@ -223,6 +282,7 @@ def mean_func(parametrization: Parametrization, x: Any) -> Any:
223282
@property
224283
def _second_moment(self) -> ParametrizedFunction:
225284
def func(parametrization: Parametrization, x: Any) -> Any:
285+
parametrization = cast(ExponentialFamilyParametrization, parametrization)
226286
dimension_size = 1
227287
if hasattr(x, "__len__"):
228288
dimension_size = len(x)
@@ -238,8 +298,9 @@ def func(parametrization: Parametrization, x: Any) -> Any:
238298
return func
239299

240300
@property
241-
def _var(self):
242-
def func(parametrization, x: Any):
301+
def _var(self) -> ParametrizedFunction:
302+
def func(parametrization: Parametrization, x: Any) -> Any:
303+
parametrization = cast(ExponentialFamilyParametrization, parametrization)
243304
return (
244305
self._second_moment(parametrization, x)
245306
- self._mean(parametrization, x) ** 2
@@ -248,8 +309,8 @@ def func(parametrization, x: Any):
248309
return func
249310

250311
def posterior_hyperparameters(
251-
self, prior_hyper: ExponentialConjugateHyperparameters, sample
252-
):
312+
self, prior_hyper: ExponentialConjugateHyperparameters, sample: list[Any]
313+
) -> ExponentialConjugateHyperparameters:
253314
alpha = prior_hyper.alpha
254315
beta = prior_hyper.beta
255316

@@ -275,6 +336,9 @@ def __init__(
275336
sufficient_statistics: Callable[[Any], Any],
276337
normalization_constant: Callable[[Any], Any],
277338
parameter_from_natural_parameter: Callable[[Any], Any],
339+
natural_parameter: Callable[
340+
[ExponentialFamilyParametrization], ExponentialFamilyParametrization
341+
],
278342
support: SpacePredicate,
279343
parameter_space: SpacePredicate,
280344
sufficient_statistics_values: SpacePredicate,
@@ -284,11 +348,10 @@ def __init__(
284348
name: str = "ExponentialFamily",
285349
support_by_parametrization: SupportArg = None,
286350
):
287-
def natural_log_partition(eta_parametrizaion: ExponentialFamilyParametrization):
288-
eta_parametrizaion = cast(
289-
ExponentialFamilyParametrization, eta_parametrizaion
290-
)
291-
eta = eta_parametrizaion.parameters.get("theta")
351+
def natural_log_partition(
352+
eta_parametrizaion: ExponentialFamilyParametrization,
353+
) -> Any:
354+
eta = eta_parametrizaion.theta
292355
theta = parameter_from_natural_parameter(eta)
293356
return log_partition(ExponentialFamilyParametrization(theta=[theta]))
294357

@@ -297,6 +360,8 @@ def natural_log_partition(eta_parametrizaion: ExponentialFamilyParametrization):
297360
parameter_from_natural_parameter(eta)
298361
)
299362
)
363+
364+
self._natural_parameter = natural_parameter
300365
natural_parameter_space = SpacePredicate(
301366
lambda eta: parameter_space.accepts(parameter_from_natural_parameter(eta)),
302367
)
@@ -315,3 +380,8 @@ def natural_log_partition(eta_parametrizaion: ExponentialFamilyParametrization):
315380
sampling_strategy=sampling_strategy,
316381
support_by_parametrization=support_by_parametrization,
317382
)
383+
384+
def _transform_to_natural_parametrization(
    self, theta_parametrization: ExponentialFamilyParametrization
) -> ExponentialFamilyParametrization:
    """
    Convert ``theta_parametrization`` with the stored ``_natural_parameter`` callable.

    The callable's result is returned as-is.
    """
    to_natural = self._natural_parameter
    return to_natural(theta_parametrization)

0 commit comments

Comments
 (0)