Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -66,4 +66,3 @@ repos:
- numpy>=2
- scipy>=1.13
- pytest>=8
exclude: '^tests/.*'
81 changes: 58 additions & 23 deletions examples/example-parametric.ipynb

Large diffs are not rendered by default.

8 changes: 8 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ exclude_lines = [
]

[tool.mypy]
files = [ "src", "tests" ]
python_version = "3.12"
strict = true
warn_unused_configs = true
Expand All @@ -103,3 +104,10 @@ no_implicit_optional = true
namespace_packages = true
explicit_package_bases = true
mypy_path = [ "src" ]

[[tool.mypy.overrides]]
module = [ "tests.*" ]
disallow_untyped_defs = false
check_untyped_defs = true
warn_return_any = false
implicit_reexport = true
2 changes: 0 additions & 2 deletions src/pysatl_core/families/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from pysatl_core.families.parametrizations import (
Parametrization,
ParametrizationConstraint,
ParametrizationSpec,
constraint,
parametrization,
)
Expand All @@ -27,7 +26,6 @@
"ParametricFamilyRegister",
"ParametrizationConstraint",
"Parametrization",
"ParametrizationSpec",
"ParametricFamily",
"ParametricFamilyDistribution",
"constraint",
Expand Down
48 changes: 18 additions & 30 deletions src/pysatl_core/families/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

from collections.abc import Mapping
from dataclasses import dataclass
from functools import partial
from typing import TYPE_CHECKING, Any

from pysatl_core.distributions import (
Expand Down Expand Up @@ -73,37 +72,26 @@ def family(self) -> ParametricFamily:
def analytical_computations(
self,
) -> Mapping[GenericCharacteristicName, AnalyticalComputation[Any, Any]]:
"""
Get analytical computation functions for this distribution.
"""Lazily computed analytical computations for this distribution instance.

Returns
-------
Mapping[GenericCharacteristicName, AnalyticalComputation]
Mapping from characteristic names to computation functions.
Delegates construction to the parent family (precomputed plan) and
caches the result per-instance. The cache auto-invalidates when either
the **parametrization object** changes (by identity) or the
**parametrization name** changes.

*If you mutate numeric fields of the same parametrization object*,
the callables see fresh values because they close over that object.
"""
analytical_computations = {}

# First form list of all characteristics, available from current parametrization
for characteristic, forms in self.family.distr_characteristics.items():
if self.parameters.name in forms:
analytical_computations[characteristic] = AnalyticalComputation(
target=characteristic,
func=partial(forms[self.parameters.name], self.parameters),
)
# TODO: Second, apply rule set, for, e.g. approximations

# Finally, fill other chacteristics
base_name = self.family.parametrizations.base_parametrization_name
base_parameters = self.family.parametrizations.get_base_parameters(self.parameters)
for characteristic, forms in self.family.distr_characteristics.items():
if characteristic in analytical_computations:
continue
if base_name in forms:
analytical_computations[characteristic] = AnalyticalComputation(
target=characteristic, func=partial(forms[base_name], base_parameters)
)

return analytical_computations
key = (id(self.parameters), self.parameters.name)
cache_key = getattr(self, "_analytical_cache_key", None)
cache_val = getattr(self, "_analytical_cache_val", None)

if cache_key != key or cache_val is None:
cache_val = self.family._build_analytical_computations(self.parameters)
self._analytical_cache_key = key
self._analytical_cache_val = cache_val

return cache_val

@property
def sampling_strategy(self) -> SamplingStrategy:
Expand Down
Loading