Skip to content

Commit 371d671

Browse files
committed
feat: big base for distribution package
1 parent cc5fe62 commit 371d671

File tree

10 files changed

+411
-0
lines changed

10 files changed

+411
-0
lines changed

.pre-commit-config.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,3 +59,7 @@ repos:
5959
hooks:
6060
- id: mypy
6161
args: [ "--config-file=pyproject.toml" ]
62+
additional_dependencies:
63+
# Otherwise mypy cannot find the packages installed in the venv
64+
# The alternative is to use `repo: local`, but that is less clean from a CI standpoint
65+
- numpy>=2

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ urls.Issues = "https://github.com/PySATL/pysatl-core/issues"
4343

4444
[tool.poetry.dependencies]
4545
python = ">=3.12"
46+
numpy = "^2.0.0"
4647

4748
[tool.poetry.group.dev.dependencies]
4849
ruff = ">=0.6"

src/pysatl_core/distributions/__init__.py

Whitespace-only changes.
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from typing import Any, Protocol
2+
3+
from pysatl_core.types import (
4+
GenericCharacteristicName,
5+
)
6+
7+
from .distribution import Distribution
8+
9+
10+
class GenericCharacteristic[In, Out](Protocol):
11+
name: GenericCharacteristicName
12+
13+
def __call__(self, distribution: Distribution, data: In, **options: Any) -> Out: ...
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
from collections.abc import Callable, Sequence
2+
from dataclasses import dataclass
3+
from typing import Protocol, runtime_checkable
4+
5+
from pysatl_core.distributions.distribution import Distribution
6+
from pysatl_core.types import (
7+
GenericCharacteristicName,
8+
)
9+
10+
11+
@runtime_checkable
12+
class Computation[In, Out](Protocol):
13+
@property
14+
def target(self) -> GenericCharacteristicName: ...
15+
def __call__(self, data: In) -> Out: ...
16+
17+
18+
@runtime_checkable
19+
class FittedComputationMethodProtocol[In, Out](Protocol):
20+
@property
21+
def target(self) -> GenericCharacteristicName: ...
22+
@property
23+
def sources(self) -> Sequence[GenericCharacteristicName]: ...
24+
def __call__(self, data: In) -> Out: ...
25+
26+
27+
@runtime_checkable
28+
class ComputationMethodProtocol[In, Out](Protocol):
29+
@property
30+
def target(self) -> GenericCharacteristicName: ...
31+
@property
32+
def sources(self) -> Sequence[GenericCharacteristicName]: ...
33+
def fit(self, distribution: Distribution) -> FittedComputationMethodProtocol[In, Out]: ...
34+
35+
36+
@dataclass(frozen=True, slots=True)
37+
class AnalyticalComputation[In, Out]:
38+
target: GenericCharacteristicName
39+
func: Callable[[In], Out]
40+
41+
def __call__(self, data: In) -> Out:
42+
return self.func(data)
43+
44+
45+
@dataclass(frozen=True, slots=True)
46+
class FittedComputationMethod[In, Out]:
47+
target: GenericCharacteristicName
48+
sources: Sequence[GenericCharacteristicName]
49+
func: Callable[[In], Out]
50+
51+
def __call__(self, data: In) -> Out:
52+
return self.func(data)
53+
54+
55+
@dataclass(frozen=True, slots=True)
56+
class ComputationMethod[In, Out]:
57+
target: GenericCharacteristicName
58+
sources: Sequence[GenericCharacteristicName]
59+
fitter: Callable[[Distribution], FittedComputationMethod[In, Out]]
60+
61+
def fit(self, distribution: Distribution) -> FittedComputationMethod[In, Out]:
62+
return self.fitter(distribution)
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
from __future__ import annotations
2+
3+
from collections.abc import Iterable, Mapping
4+
from dataclasses import dataclass
5+
from typing import Any, Protocol, runtime_checkable
6+
7+
from pysatl_core.distributions.computation import AnalyticalComputation
8+
from pysatl_core.distributions.sampling import Sample
9+
from pysatl_core.distributions.strategies import (
10+
ComputationStrategy,
11+
DefaultComputationStrategy,
12+
DefaultSamplingStrategy,
13+
SamplingStrategy,
14+
)
15+
from pysatl_core.types import (
16+
Dimension,
17+
GenericCharacteristicName,
18+
Kind,
19+
)
20+
21+
22+
@runtime_checkable
class Distribution(Protocol):
    """Structural interface every distribution object is expected to provide."""

    @property
    def dimension(self) -> Dimension:
        """Dimension of the distribution."""
        ...

    @property
    def kind(self) -> Kind:
        """Kind of the distribution."""
        ...

    @property
    def analytical_computations(
        self,
    ) -> Mapping[GenericCharacteristicName, AnalyticalComputation[Any, Any]]:
        """Analytical computations available, keyed by characteristic name."""
        ...

    @property
    def sampling_strategy(self) -> SamplingStrategy:
        """Strategy object used to draw samples."""
        ...

    @property
    def computation_strategy(self) -> ComputationStrategy[Any, Any]:
        """Strategy object used to compute characteristics."""
        ...

    def sample(self, n: int, **options: Any) -> Sample:
        """Draw ``n`` observations; extra keyword options are implementation-defined."""
        ...

    def log_likelihood(self, batch: Sample) -> float:
        """Return the log-likelihood of ``batch`` under this distribution."""
        ...
41+
42+
43+
@dataclass(slots=True)
44+
class StandaloneDistribution:
45+
_dimension: Dimension
46+
_kind: Kind
47+
_analytical: dict[GenericCharacteristicName, AnalyticalComputation[Any, Any]]
48+
49+
def __init__(
50+
self,
51+
dimension: Dimension,
52+
kind: Kind,
53+
analytical_computations: Iterable[AnalyticalComputation[Any, Any]]
54+
| Mapping[GenericCharacteristicName, AnalyticalComputation[Any, Any]] = (),
55+
):
56+
self._dimension = dimension
57+
self._kind = kind
58+
if isinstance(analytical_computations, Mapping):
59+
self._analytical = dict(analytical_computations)
60+
else:
61+
self._analytical = {ac.target: ac for ac in analytical_computations}
62+
63+
@property
64+
def dimension(self) -> Dimension:
65+
return self._dimension
66+
67+
@property
68+
def kind(self) -> Kind:
69+
return self._kind
70+
71+
@property
72+
def analytical_computations(
73+
self,
74+
) -> Mapping[GenericCharacteristicName, AnalyticalComputation[Any, Any]]:
75+
return self._analytical
76+
77+
@property
78+
def sampling_strategy(self) -> SamplingStrategy:
79+
return DefaultSamplingStrategy()
80+
81+
@property
82+
def computation_strategy(self) -> ComputationStrategy[Any, Any]:
83+
return DefaultComputationStrategy()
84+
85+
def sample(self, n: int, **options: Any) -> Sample:
86+
return self.sampling_strategy.sample(n, d=self.dimension, **options)
87+
88+
def log_likelihood(self, batch: Sample) -> float:
89+
# TODO: Не ноль)
90+
return 0.0
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
from __future__ import annotations
2+
3+
from dataclasses import dataclass, field
4+
from functools import lru_cache
5+
from typing import Any, ClassVar, Self
6+
7+
from pysatl_core.distributions.computation import ComputationMethod
8+
from pysatl_core.types import (
9+
Dimension,
10+
GenericCharacteristicName,
11+
Kind,
12+
)
13+
14+
15+
@dataclass
16+
class GenericCharacteristicRegister:
17+
dimension: Dimension
18+
kind: Kind
19+
20+
__registered_indefinitive_chars: dict[
21+
GenericCharacteristicName,
22+
list[ComputationMethod[Any, Any]],
23+
] = field(default_factory=dict)
24+
25+
__registered_definitive_chars: dict[
26+
GenericCharacteristicName, list[ComputationMethod[Any, Any]]
27+
] = field(default_factory=dict)
28+
29+
def register_indefinitive_characteristic(
30+
self,
31+
default_computation_method: ComputationMethod[Any, Any],
32+
) -> None:
33+
self.__registered_indefinitive_chars[default_computation_method.target].append(
34+
default_computation_method
35+
)
36+
37+
def register_definitive_characteristic(
38+
self,
39+
default_computation_method: ComputationMethod[Any, Any],
40+
default_inversion_method: ComputationMethod[Any, Any],
41+
) -> None:
42+
if (
43+
default_computation_method.target not in default_inversion_method.sources
44+
or default_inversion_method.target not in default_computation_method.sources
45+
):
46+
raise AttributeError(
47+
"New characteristic always must be in targets for default "
48+
"computation method and in sources for default inversion method"
49+
)
50+
self.__registered_definitive_chars[default_computation_method.target].append(
51+
default_computation_method
52+
)
53+
self.__registered_definitive_chars[default_inversion_method.target].append(
54+
default_inversion_method
55+
)
56+
57+
def get_available_indefinitive_characteristics(
58+
self,
59+
) -> dict[GenericCharacteristicName, list[ComputationMethod[Any, Any]]]:
60+
return self.__registered_indefinitive_chars
61+
62+
def get_available_definitive_characteristics(
63+
self,
64+
) -> dict[GenericCharacteristicName, list[ComputationMethod[Any, Any]]]:
65+
return self.__registered_definitive_chars
66+
67+
def get_all_available_characteristics_keys(self) -> list[GenericCharacteristicName]:
68+
return list(self.__registered_definitive_chars.keys()) + list(
69+
self.__registered_indefinitive_chars.keys()
70+
)
71+
72+
73+
class DistributionTypeRegister:
74+
_instance: ClassVar[Self | None] = None
75+
_register_kinds: dict[tuple[Dimension, Kind], GenericCharacteristicRegister]
76+
77+
def __new__(cls) -> Self:
78+
if cls._instance is None:
79+
self = super().__new__(cls)
80+
self._register_kinds = {}
81+
cls._instance = self
82+
return cls._instance
83+
84+
def get(self, dimension: Dimension, kind: Kind) -> GenericCharacteristicRegister:
85+
key = (dimension, kind)
86+
reg = self._register_kinds.get(key)
87+
if reg is None:
88+
reg = GenericCharacteristicRegister(dimension=dimension, kind=kind)
89+
self._register_kinds[key] = reg
90+
if reg.dimension != dimension or reg.kind != kind:
91+
raise TypeError(
92+
f"Inconsistent registry under key ({dimension}, {kind}): "
93+
f"got ({reg.dimension}, {reg.kind}) inside"
94+
)
95+
return reg
96+
97+
__call__ = get
98+
99+
100+
def _configure(reg: DistributionTypeRegister) -> None:
    """Make sure the registers for 1-D discrete and continuous kinds exist."""
    for kind in (Kind.DISCRETE, Kind.CONTINUOUS):
        reg.get(1, kind)

    # Lots of default registrations are meant to go here.
105+
106+
107+
@lru_cache(maxsize=1)
def distribution_type_register() -> DistributionTypeRegister:
    """Return the configured registry; the result is memoised by ``lru_cache``."""
    register = DistributionTypeRegister()
    _configure(register)
    return register
112+
113+
114+
def _reset_distribution_type_register_for_tests() -> None:
    """Discard the cached registry so the next access builds a fresh one.

    Intended for test isolation only.
    """
    distribution_type_register.cache_clear()
    # Clearing the lru_cache alone is not enough: DistributionTypeRegister is
    # a singleton, so without dropping the stored instance the next call
    # would hand back the same object with all of its registered state.
    DistributionTypeRegister._instance = None
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
from __future__ import annotations
2+
3+
from collections.abc import Iterator
4+
from typing import Any, Protocol
5+
6+
import numpy as np
7+
import numpy.typing as npt
8+
9+
from pysatl_core.types import Dimension
10+
11+
12+
class Sample(Protocol):
    """Structural interface of a sized sample exposing its data as an array."""

    def __len__(self) -> int:
        """Number of observations in the sample."""
        ...

    @property
    def dim(self) -> Dimension:
        """Dimension of the sample."""
        ...

    @property
    def array(self) -> npt.NDArray[np.floating[Any]]:
        """Sample data as a floating-point NumPy array."""
        ...

    @property
    def shape(self) -> tuple[int, ...]:
        """Shape of the sample as a tuple of ints."""
        ...
20+
21+
22+
class ArraySample(Sample):
    """Sample backed by a NumPy array.

    The constructor reads ``data.shape[1]`` and ``shape`` unpacks
    ``data.shape`` into exactly two values, so a 2-D array is required:
    axis 0 is reported as the length, axis 1 as the dimension.
    """

    dimension: Dimension
    data: npt.NDArray[np.floating[Any]]

    def __init__(self, data: npt.NDArray[np.floating[Any]]) -> None:
        """Keep a reference to ``data`` and record the size of its second axis."""
        self.data = data
        self.dimension = data.shape[1]

    def __len__(self) -> int:
        """Return the size of the first axis."""
        rows = self.data.shape[0]
        return int(rows)

    @property
    def dim(self) -> Dimension:
        """Dimension recorded at construction time."""
        return self.dimension

    def __iter__(self) -> Iterator[npt.NDArray[np.floating[Any]]]:
        """Yield the entries of the array along its first axis."""
        for row in self.data:
            yield row

    @property
    def array(self) -> npt.NDArray[np.floating[Any]]:
        """The underlying array itself (not a copy)."""
        return self.data

    @property
    def shape(self) -> tuple[int, ...]:
        """Return ``(n, d)`` as plain Python ints."""
        rows, cols = self.data.shape
        return (int(rows), int(cols))

0 commit comments

Comments
 (0)