Skip to content

Commit d306edd

Browse files
committed
Add simple caching for dispersions
1 parent e8a03ab commit d306edd

File tree

2 files changed

+56
-6
lines changed

2 files changed

+56
-6
lines changed

src/elli/dispersions/base_dispersion.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import numpy as np
99
import numpy.typing as npt
1010
import pandas as pd
11+
from lmfit import Parameter
1112
from numpy.lib.scimath import sqrt
1213

1314
from .. import dispersions
@@ -44,6 +45,14 @@ def _guard_invalid_params(params1, params2):
4445
missing_param_strings = ", ".join(f"{p}" for p in missing_params)
4546
raise InvalidParameters(f"Invalid parameter(s): {missing_param_strings}")
4647

48+
@staticmethod
49+
def _hash_params(params: dict | list[dict]) -> int:
50+
"""Creates an single_params_dict or the repeating_params_list."""
51+
if isinstance(params, list):
52+
return hash(tuple([self._hash_params(dictionary) for dictionary in params]))
53+
else:
54+
return hash(tuple([item for _, item in params.items()]))
55+
4756
@staticmethod
4857
def _fill_params_dict(template: dict, *args, **kwargs) -> dict:
4958
BaseDispersion._guard_invalid_params(list(kwargs.keys()), list(template.keys()))
@@ -56,6 +65,8 @@ def _fill_params_dict(template: dict, *args, **kwargs) -> dict:
5665

5766
for i, val in enumerate(args):
5867
key = list(template.keys())[i]
68+
if isinstance(val, Parameter):
69+
val = val.value
5970
params[key] = val
6071
pos_arguments.add(key)
6172

@@ -64,6 +75,8 @@ def _fill_params_dict(template: dict, *args, **kwargs) -> dict:
6475
raise InvalidParameters(
6576
f"Parameter {key} already set by positional argument"
6677
)
78+
if isinstance(value, Parameter):
79+
value = value.value
6780
params[key] = value
6881

6982
return params
@@ -80,6 +93,10 @@ def __init__(self, *args, **kwargs):
8093
if self.single_params[param] is None:
8194
raise InvalidParameters(f"Please specify parameter {param}")
8295

96+
self.last_lbda = None
97+
self.hash_single_params = None
98+
self.hash_rep_params = None
99+
83100
@abstractmethod
84101
def dielectric_function(self, lbda: npt.ArrayLike) -> npt.NDArray:
85102
"""Calculates the dielectric function in a given wavelength window.
@@ -114,6 +131,39 @@ def get_dielectric(self, lbda: Optional[npt.ArrayLike] = None) -> npt.NDArray:
114131
"""Returns the dielectric constant for wavelength 'lbda' default unit (nm)
115132
in the convention ε1 + iε2."""
116133
lbda = self.default_lbda_range if lbda is None else lbda
134+
135+
from .table_epsilon import TableEpsilon
136+
from .table_index import Table
137+
138+
if not isinstance(self, (DispersionSum, IndexDispersionSum)):
139+
if isinstance(self, (TableEpsilon, Table)):
140+
if self.last_lbda is lbda:
141+
return self.cached_diel
142+
else:
143+
self.last_lbda = lbda
144+
self.cached_diel = np.asarray(
145+
self.dielectric_function(lbda), dtype=np.complex128
146+
)
147+
return self.cached_diel
148+
else:
149+
new_single_hash = self._hash_params(self.single_params)
150+
new_rep_hash = self._hash_params(self.rep_params)
151+
152+
if (
153+
self.last_lbda is lbda
154+
and self.hash_single_params == new_single_hash
155+
and self.hash_rep_params == new_rep_hash
156+
):
157+
return self.cached_diel
158+
else:
159+
self.last_lbda = lbda
160+
self.hash_single_params = new_single_hash
161+
self.hash_rep_params = new_rep_hash
162+
self.cached_diel = np.asarray(
163+
self.dielectric_function(lbda), dtype=np.complex128
164+
)
165+
return self.cached_diel
166+
117167
return np.asarray(self.dielectric_function(lbda), dtype=np.complex128)
118168

119169
def get_refractive_index(self, lbda: Optional[npt.ArrayLike] = None) -> npt.NDArray:

tests/benchmark_fitting.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -94,12 +94,12 @@ def test_fitting_structure_updates(benchmark, datadir):
9494

9595
@fit(psi_delta, params)
9696
def model(lbda, params):
97-
SiO2.single_params["n0"] = params["SiO2_n0"]
98-
SiO2.single_params["n1"] = params["SiO2_n1"]
99-
SiO2.single_params["n2"] = params["SiO2_n2"]
100-
SiO2.single_params["k0"] = params["SiO2_k0"]
101-
SiO2.single_params["k1"] = params["SiO2_k1"]
102-
SiO2.single_params["k2"] = params["SiO2_k2"]
97+
SiO2.single_params["n0"] = params["SiO2_n0"].value
98+
SiO2.single_params["n1"] = params["SiO2_n1"].value
99+
SiO2.single_params["n2"] = params["SiO2_n2"].value
100+
SiO2.single_params["k0"] = params["SiO2_k0"].value
101+
SiO2.single_params["k1"] = params["SiO2_k1"].value
102+
SiO2.single_params["k2"] = params["SiO2_k2"].value
103103

104104
layer.set_thickness(params["SiO2_d"])
105105

0 commit comments

Comments
 (0)