11from __future__ import annotations
2+
23from collections .abc import Callable
3- from typing import Any , cast , TYPE_CHECKING
4- from scipy .integrate import nquad , quad
4+ from dataclasses import dataclass
5+ from typing import TYPE_CHECKING , Any , Iterable , Sized , cast
6+
57import numpy as np
8+ from scipy .integrate import nquad
9+ from scipy .linalg import det
10+ from scipy .differentiate import jacobian
611
7- from pysatl_core .distributions .fitters import _ppf_brentq_from_cdf
12+ from pysatl_core .distributions import (
13+ SamplingStrategy ,
14+ )
815from pysatl_core .families .parametric_family import ParametricFamily
916from pysatl_core .families .parametrizations import Parametrization , parametrization
1017from pysatl_core .types import (
18+ GenericCharacteristicName ,
1119 DistributionType ,
1220 ParametrizationName ,
1321)
14- from pysatl_core .distributions import (
15- SamplingStrategy ,
16- )
1722
1823if TYPE_CHECKING :
1924 from pysatl_core .distributions .support import Support
3237KURT = "kurtosis"
3338
3439
@dataclass
class ExponentialFamilyParametrization(Parametrization):
    """
    Standard parametrization of Exponential Family.

    The only field declared here is ``theta``, the parameter vector stored as
    a plain list of floats.
    """

    # Parameter vector; other code in this module reads it via ``.theta``.
    theta: list[float]  # TODO: maybe use something more clever than a list
4147
4248
class ExponentialConjugateHyperparameters:
    """
    Pair of hyperparameters ``(alpha, beta)`` used for a conjugate prior.

    The values are stored exactly as passed; no validation or conversion is
    performed.
    """

    def __init__(self, alpha: Any, beta: int) -> None:
        """Store ``alpha`` and ``beta`` as attributes of the same names."""
        self.alpha = alpha
        self.beta = beta

    def __str__(self) -> str:
        """Return a short human-readable ``alpha=..., beta=...`` summary."""
        return f"alpha={self.alpha}, beta={self.beta}"
5056
5157
52- def doesAccept (x , support ) :
58+ def doesAccept (x : list [ float ] | float , support : list [ tuple [ float , float ]]) -> bool :
5359 if not hasattr (x , "__len__" ):
5460 x = [x ]
5561
56- def accept_1D (x , borders ):
62+ x = cast (list [float ], x )
63+
64+ def accept_1D (x : float , borders : tuple [float , float ]) -> bool :
5765 left , right = borders
5866 if abs (x ) == 0 and (abs (left ) == 0 or abs (right ) == 0 ):
5967 return False
6068 return left <= x <= right
6169
62- return all (accept_1D (x_i , border ) for x_i , border in zip (x , support ))
70+ return all (accept_1D (x_i , border ) for x_i , border in zip (x , support , strict = False ))
6371
6472
6573class SpacePredicate :
@@ -100,7 +108,10 @@ def __init__(
100108 self ._parameter_space = parameter_space
101109 self ._sufficient_statistics_values = sufficient_statistics_values
102110
103- distr_characteristics = {
111+ distr_characteristics : dict [
112+ GenericCharacteristicName ,
113+ dict [ParametrizationName , ParametrizedFunction ] | ParametrizedFunction ,
114+ ] = {
104115 PDF : self .density ,
105116 MEAN : self ._mean ,
106117 VAR : self ._var ,
@@ -115,20 +126,29 @@ def __init__(
115126 sampling_strategy = sampling_strategy ,
116127 support_by_parametrization = support_by_parametrization ,
117128 )
118- parametrization (family = self , name = "theta" )((ExponentialFamilyParametrization ))
129+ parametrization (family = self , name = "theta" )(ExponentialFamilyParametrization )
130+
def _transform_to_natural_parametrization(
    self, theta_parametrization: ExponentialFamilyParametrization
) -> ExponentialFamilyParametrization:
    """
    Return ``theta_parametrization`` unchanged.

    ``log_density`` passes every parametrization through this hook before
    reading ``theta``. This implementation is the identity; presumably the
    enclosing family is already expressed in natural parameters (TODO
    confirm against the class definition).
    """
    return theta_parametrization
119135
120136 @property
121137 def log_density (self ) -> ParametrizedFunction :
122- def log_density_func (
123- parametrization : ExponentialFamilyParametrization , x : Any
124- ) -> Any :
138+ def log_density_func (parametrization : Parametrization , x : Any ) -> Any :
139+ parametrization = cast (ExponentialFamilyParametrization , parametrization )
140+ parametrization = self ._transform_to_natural_parametrization (
141+ parametrization
142+ )
125143 if not self ._support .accepts (x ):
126144 return float ("-inf" )
127145
128- params = cast (ExponentialFamilyParametrization , parametrization )
129- theta = params .parameters .get ("theta" )
146+ theta = parametrization .theta
130147 sufficient = self ._sufficient (x )
131148 dot = np .dot (theta , sufficient )
149+ if hasattr (dot , "__len__" ):
150+ dot = dot [0 ]
151+
132152 result = float (
133153 np .log (self ._normalization (x ))
134154 + dot
@@ -143,26 +163,31 @@ def density(self) -> ParametrizedFunction:
143163 return lambda parametrization , x : np .exp (self .log_density (parametrization , x ))
144164
145165 @property
146- def conjugate_prior_family (self ):
147- def conjugate_sufficient (theta : Any ):
166+ def conjugate_prior_family (self ) -> NaturalExponentialFamily :
167+ def conjugate_sufficient (
168+ theta : float ,
169+ ) -> list [Any ]:
148170 if not self ._parameter_space .accepts (theta ):
149171 return [float ("-inf" ), float ("-inf" )]
150172
173+ parametrization = ExponentialFamilyParametrization ([theta ])
174+ # parametrization.theta = [theta]
151175 return [
152176 theta ,
153- self ._log_partition (ExponentialFamilyParametrization ( theta = [ theta ]) ),
177+ self ._log_partition (parametrization ),
154178 ]
155179
156- def conjugate_log_partition (parametrization : ExponentialFamilyParametrization ):
180+ def conjugate_log_partition (
181+ parametrization : ExponentialFamilyParametrization ,
182+ ) -> Any :
157183 alpha = parametrization .theta [0 ]
158184 beta = parametrization .theta [1 ]
159185
160- def pdf (theta : Any ):
186+ def pdf (theta : Any ) -> Any :
161187 if not hasattr (theta , "__len__" ):
162188 theta = [theta ]
163- parametrization = ExponentialFamilyParametrization (
164- theta = theta ,
165- )
189+ parametrization = ExponentialFamilyParametrization (theta = theta )
190+ # parametrization.theta = theta
166191 return np .exp (
167192 np .dot (theta , alpha ) + beta * self ._log_partition (parametrization )
168193 )[0 ]
@@ -180,15 +205,14 @@ def pdf(theta: Any):
180205
181206 def conjugate_sufficient_accepts (
182207 parametrization : ExponentialFamilyParametrization ,
183- ):
184- parametrization = cast (parametrization , ExponentialFamilyParametrization )
185- theta = parametrization .parameters .get ("theta" )
208+ ) -> bool :
209+ theta = parametrization .theta
186210 xi = theta [:- 1 ]
187211 nu = theta [- 1 ]
188212
189- return self ._sufficient_statistics_values ( xi ) and SpacePredicateArray (
190- [( 0 , float ( "+inf" ))]
191- ).accepts (nu )
213+ return self ._sufficient_statistics_values . accepts (
214+ xi
215+ ) and SpacePredicateArray ([( 0 , float ( "+inf" ))]) .accepts (nu )
192216
193217 return NaturalExponentialFamily (
194218 log_partition = conjugate_log_partition ,
@@ -197,15 +221,50 @@ def conjugate_sufficient_accepts(
197221 support = self ._parameter_space ,
198222 sufficient_statistics_values = self ._parameter_space , # TODO: write convex hull for this
199223 parameter_space = SpacePredicate (conjugate_sufficient_accepts ),
224+ name = self .name ,
200225 sampling_strategy = self .sampling_strategy ,
201226 distr_type = self ._distr_type ,
202227 distr_parametrizations = self .parametrization_names ,
203228 support_by_parametrization = self .support_resolver ,
204229 )
205230
def transform(
    self,
    transform_function: Callable[[Any], Any],
) -> NaturalExponentialFamily:
    """
    Build the family obtained from this one by a change of variables.

    ``transform_function`` is applied to a point of the new sample space
    before it is handed to this family's support predicate, sufficient
    statistics and normalization. The normalization is additionally
    multiplied by ``|det J|``, where ``J`` is the Jacobian of
    ``transform_function`` estimated numerically at that point.

    The log-partition, parameter space, sufficient-statistics range,
    distribution type, parametrization names, sampling strategy and support
    resolver are passed through from this family unchanged. The new family
    is named ``"Transformed" + self._name``.

    Parameters
    ----------
    transform_function
        Callable mapping a point of the new sample space to a point of this
        family's sample space. NOTE(review): it must accept and return
        arrays in the form ``scipy.differentiate.jacobian`` expects, and
        its Jacobian must be square for ``det`` to apply.

    Returns
    -------
    NaturalExponentialFamily
        The transformed family.
    """

    def calculate_jacobian(x: Any) -> Any:
        # A scalar becomes a 1-element vector; lists, tuples and arrays are
        # kept as vectors so an n-vector is never wrapped into a (1, n)
        # matrix before differentiation.
        point = np.atleast_1d(np.asarray(x, dtype=float))
        return np.abs(det(jacobian(transform_function, point).df))

    def new_support(x: Any) -> bool:
        return self._support.accepts(transform_function(x))

    def new_sufficient(x: Any) -> Any:
        return self._sufficient(transform_function(x))

    def new_normalization(x: Any) -> Any:
        # Evaluate the base normalization at the mapped point, in line with
        # new_support / new_sufficient, then apply the Jacobian factor.
        return self._normalization(transform_function(x)) * calculate_jacobian(x)

    return NaturalExponentialFamily(
        log_partition=self._log_partition,
        sufficient_statistics=new_sufficient,
        normalization_constant=new_normalization,
        support=SpacePredicate(new_support),
        parameter_space=self._parameter_space,
        sufficient_statistics_values=self._sufficient_statistics_values,
        name=f"Transformed{self._name}",
        distr_type=self._distr_type,
        distr_parametrizations=self.parametrization_names,
        sampling_strategy=self.sampling_strategy,
        support_by_parametrization=self.support_resolver,
    )
263+
206264 @property
207265 def _mean (self ) -> ParametrizedFunction :
208266 def mean_func (parametrization : Parametrization , x : Any ) -> Any :
267+ parametrization = cast (ExponentialFamilyParametrization , parametrization )
209268 dimension_size = 1
210269 if hasattr (x , "__len__" ):
211270 dimension_size = len (x )
@@ -223,6 +282,7 @@ def mean_func(parametrization: Parametrization, x: Any) -> Any:
223282 @property
224283 def _second_moment (self ) -> ParametrizedFunction :
225284 def func (parametrization : Parametrization , x : Any ) -> Any :
285+ parametrization = cast (ExponentialFamilyParametrization , parametrization )
226286 dimension_size = 1
227287 if hasattr (x , "__len__" ):
228288 dimension_size = len (x )
@@ -238,8 +298,9 @@ def func(parametrization: Parametrization, x: Any) -> Any:
238298 return func
239299
240300 @property
241- def _var (self ):
242- def func (parametrization , x : Any ):
301+ def _var (self ) -> ParametrizedFunction :
302+ def func (parametrization : Parametrization , x : Any ) -> Any :
303+ parametrization = cast (ExponentialFamilyParametrization , parametrization )
243304 return (
244305 self ._second_moment (parametrization , x )
245306 - self ._mean (parametrization , x ) ** 2
@@ -248,8 +309,8 @@ def func(parametrization, x: Any):
248309 return func
249310
250311 def posterior_hyperparameters (
251- self , prior_hyper : ExponentialConjugateHyperparameters , sample
252- ):
312+ self , prior_hyper : ExponentialConjugateHyperparameters , sample : list [ Any ]
313+ ) -> ExponentialConjugateHyperparameters :
253314 alpha = prior_hyper .alpha
254315 beta = prior_hyper .beta
255316
@@ -275,6 +336,9 @@ def __init__(
275336 sufficient_statistics : Callable [[Any ], Any ],
276337 normalization_constant : Callable [[Any ], Any ],
277338 parameter_from_natural_parameter : Callable [[Any ], Any ],
339+ natural_parameter : Callable [
340+ [ExponentialFamilyParametrization ], ExponentialFamilyParametrization
341+ ],
278342 support : SpacePredicate ,
279343 parameter_space : SpacePredicate ,
280344 sufficient_statistics_values : SpacePredicate ,
@@ -284,11 +348,10 @@ def __init__(
284348 name : str = "ExponentialFamily" ,
285349 support_by_parametrization : SupportArg = None ,
286350 ):
287- def natural_log_partition (eta_parametrizaion : ExponentialFamilyParametrization ):
288- eta_parametrizaion = cast (
289- ExponentialFamilyParametrization , eta_parametrizaion
290- )
291- eta = eta_parametrizaion .parameters .get ("theta" )
351+ def natural_log_partition (
352+ eta_parametrizaion : ExponentialFamilyParametrization ,
353+ ) -> Any :
354+ eta = eta_parametrizaion .theta
292355 theta = parameter_from_natural_parameter (eta )
293356 return log_partition (ExponentialFamilyParametrization (theta = [theta ]))
294357
@@ -297,6 +360,8 @@ def natural_log_partition(eta_parametrizaion: ExponentialFamilyParametrization):
297360 parameter_from_natural_parameter (eta )
298361 )
299362 )
363+
364+ self ._natural_parameter = natural_parameter
300365 natural_parameter_space = SpacePredicate (
301366 lambda eta : parameter_space .accepts (parameter_from_natural_parameter (eta )),
302367 )
@@ -315,3 +380,8 @@ def natural_log_partition(eta_parametrizaion: ExponentialFamilyParametrization):
315380 sampling_strategy = sampling_strategy ,
316381 support_by_parametrization = support_by_parametrization ,
317382 )
383+
def _transform_to_natural_parametrization(
    self, theta_parametrization: ExponentialFamilyParametrization
) -> ExponentialFamilyParametrization:
    """
    Convert ``theta_parametrization`` with the family's stored mapping.

    Delegates to ``self._natural_parameter``, the callable received as the
    ``natural_parameter`` constructor argument, and returns its result
    as-is.
    """
    return self._natural_parameter(theta_parametrization)
0 commit comments