22
33from collections .abc import Callable
44from dataclasses import dataclass
5- from typing import TYPE_CHECKING , Any , Iterable , Sized , cast
5+ from typing import TYPE_CHECKING , Any , cast
66
77import numpy as np
8+ from scipy .differentiate import jacobian
89from scipy .integrate import nquad
910from scipy .linalg import det
10- from scipy .differentiate import jacobian
1111
1212from pysatl_core .distributions import (
1313 SamplingStrategy ,
1414)
1515from pysatl_core .families .parametric_family import ParametricFamily
1616from pysatl_core .families .parametrizations import Parametrization , parametrization
1717from pysatl_core .types import (
18- GenericCharacteristicName ,
1918 DistributionType ,
19+ GenericCharacteristicName ,
2020 ParametrizationName ,
2121)
2222
@@ -137,9 +137,7 @@ def _transform_to_natural_parametrization(
137137 def log_density (self ) -> ParametrizedFunction :
138138 def log_density_func (parametrization : Parametrization , x : Any ) -> Any :
139139 parametrization = cast (ExponentialFamilyParametrization , parametrization )
140- parametrization = self ._transform_to_natural_parametrization (
141- parametrization
142- )
140+ parametrization = self ._transform_to_natural_parametrization (parametrization )
143141 if not self ._support .accepts (x ):
144142 return float ("-inf" )
145143
@@ -150,9 +148,7 @@ def log_density_func(parametrization: Parametrization, x: Any) -> Any:
150148 dot = dot [0 ]
151149
152150 result = float (
153- np .log (self ._normalization (x ))
154- + dot
155- + self ._log_partition (parametrization )
151+ np .log (self ._normalization (x )) + dot + self ._log_partition (parametrization )
156152 )
157153 return result
158154
@@ -188,31 +184,24 @@ def pdf(theta: Any) -> Any:
188184 theta = [theta ]
189185 parametrization = ExponentialFamilyParametrization (theta = theta )
190186 # parametrization.theta = theta
191- return np .exp (
192- np .dot (theta , alpha ) + beta * self ._log_partition (parametrization )
193- )[0 ]
187+ return np .exp (np .dot (theta , alpha ) + beta * self ._log_partition (parametrization ))[0 ]
194188
195189 all_value = nquad (
196- lambda x : pdf (x ) if self ._parameter_space .accepts (x ) else 0 ,
190+ lambda x : pdf (x ) if self ._parameter_space .accepts (x ) else 0 , # type: ignore[arg-type]
197191 [(float ("-inf" ), float ("+inf" ))],
198192 )[0 ]
199193 return - np .log (all_value )
200194
201- # TODO: remove hardcoding - Done, all hardcoding is only on user's hands
202- # 1. pr with prototype/draft - in progress
203- # 2. write instruction about to add distributions as member of exponential family - not started
204- # 3. parametrization's spaces (передавать в конструктор) - maybe impossible, discuss this with desiment on meeting
205-
206195 def conjugate_sufficient_accepts (
207196 parametrization : ExponentialFamilyParametrization ,
208197 ) -> bool :
209198 theta = parametrization .theta
210199 xi = theta [:- 1 ]
211200 nu = theta [- 1 ]
212201
213- return self ._sufficient_statistics_values .accepts (
214- xi
215- ) and SpacePredicateArray ([( 0 , float ( "+inf" ))]) .accepts (nu )
202+ return self ._sufficient_statistics_values .accepts (xi ) and SpacePredicateArray (
203+ [( 0 , float ( "+inf" ))]
204+ ).accepts (nu )
216205
217206 return NaturalExponentialFamily (
218207 log_partition = conjugate_log_partition ,
@@ -269,10 +258,8 @@ def mean_func(parametrization: Parametrization, x: Any) -> Any:
269258 if hasattr (x , "__len__" ):
270259 dimension_size = len (x )
271260 return nquad (
272- lambda x : (
273- np .dot (x , self .density (parametrization , x ))
274- if self ._support .accepts (x )
275- else 0
261+ lambda x : ( # type: ignore[arg-type]
262+ np .dot (x , self .density (parametrization , x )) if self ._support .accepts (x ) else 0
276263 ),
277264 [(float ("-inf" ), float ("inf" ))] * dimension_size ,
278265 )[0 ]
@@ -287,10 +274,8 @@ def func(parametrization: Parametrization, x: Any) -> Any:
287274 if hasattr (x , "__len__" ):
288275 dimension_size = len (x )
289276 return nquad (
290- lambda x : (
291- x ** 2 * self .density (parametrization , x )
292- if self ._support .accepts (x )
293- else 0
277+ lambda x : ( # type: ignore[arg-type]
278+ x ** 2 * self .density (parametrization , x ) if self ._support .accepts (x ) else 0
294279 ),
295280 [(float ("-inf" ), float ("inf" ))] * dimension_size ,
296281 )[0 ]
@@ -301,10 +286,7 @@ def func(parametrization: Parametrization, x: Any) -> Any:
301286 def _var (self ) -> ParametrizedFunction :
302287 def func (parametrization : Parametrization , x : Any ) -> Any :
303288 parametrization = cast (ExponentialFamilyParametrization , parametrization )
304- return (
305- self ._second_moment (parametrization , x )
306- - self ._mean (parametrization , x ) ** 2
307- )
289+ return self ._second_moment (parametrization , x ) - self ._mean (parametrization , x ) ** 2
308290
309291 return func
310292
@@ -323,9 +305,7 @@ def posterior_hyperparameters(
323305 alpha_post = self ._sufficient (sample )
324306 beta_post = 1
325307
326- return ExponentialConjugateHyperparameters (
327- alpha = alpha + alpha_post , beta = beta + beta_post
328- )
308+ return ExponentialConjugateHyperparameters (alpha = alpha + alpha_post , beta = beta + beta_post )
329309
330310
331311class ExponentialFamily (NaturalExponentialFamily ):
@@ -356,9 +336,7 @@ def natural_log_partition(
356336 return log_partition (ExponentialFamilyParametrization (theta = [theta ]))
357337
358338 natural_sufficient_statistics_values = SpacePredicate (
359- lambda eta : sufficient_statistics_values .accepts (
360- parameter_from_natural_parameter (eta )
361- )
339+ lambda eta : sufficient_statistics_values .accepts (parameter_from_natural_parameter (eta ))
362340 )
363341
364342 self ._natural_parameter = natural_parameter
0 commit comments