8
8
import numpy as np
9
9
import numpy .typing as npt
10
10
import pandas as pd
11
+ from lmfit import Parameter
11
12
from numpy .lib .scimath import sqrt
12
13
13
14
from .. import dispersions
@@ -44,6 +45,14 @@ def _guard_invalid_params(params1, params2):
44
45
missing_param_strings = ", " .join (f"{ p } " for p in missing_params )
45
46
raise InvalidParameters (f"Invalid parameter(s): { missing_param_strings } " )
46
47
48
+ @staticmethod
49
+ def _hash_params (params : dict | list [dict ]) -> int :
50
+ """Creates an single_params_dict or the repeating_params_list."""
51
+ if isinstance (params , list ):
52
+ return hash (tuple ([self ._hash_params (dictionary ) for dictionary in params ]))
53
+ else :
54
+ return hash (tuple ([item for _ , item in params .items ()]))
55
+
47
56
@staticmethod
48
57
def _fill_params_dict (template : dict , * args , ** kwargs ) -> dict :
49
58
BaseDispersion ._guard_invalid_params (list (kwargs .keys ()), list (template .keys ()))
@@ -56,6 +65,8 @@ def _fill_params_dict(template: dict, *args, **kwargs) -> dict:
56
65
57
66
for i , val in enumerate (args ):
58
67
key = list (template .keys ())[i ]
68
+ if isinstance (val , Parameter ):
69
+ val = val .value
59
70
params [key ] = val
60
71
pos_arguments .add (key )
61
72
@@ -64,6 +75,8 @@ def _fill_params_dict(template: dict, *args, **kwargs) -> dict:
64
75
raise InvalidParameters (
65
76
f"Parameter { key } already set by positional argument"
66
77
)
78
+ if isinstance (value , Parameter ):
79
+ value = value .value
67
80
params [key ] = value
68
81
69
82
return params
@@ -80,6 +93,10 @@ def __init__(self, *args, **kwargs):
80
93
if self .single_params [param ] is None :
81
94
raise InvalidParameters (f"Please specify parameter { param } " )
82
95
96
+ self .last_lbda = None
97
+ self .hash_single_params = None
98
+ self .hash_rep_params = None
99
+
83
100
@abstractmethod
84
101
def dielectric_function (self , lbda : npt .ArrayLike ) -> npt .NDArray :
85
102
"""Calculates the dielectric function in a given wavelength window.
@@ -114,6 +131,39 @@ def get_dielectric(self, lbda: Optional[npt.ArrayLike] = None) -> npt.NDArray:
114
131
"""Returns the dielectric constant for wavelength 'lbda' default unit (nm)
115
132
in the convention ε1 + iε2."""
116
133
lbda = self .default_lbda_range if lbda is None else lbda
134
+
135
+ from .table_epsilon import TableEpsilon
136
+ from .table_index import Table
137
+
138
+ if not isinstance (self , (DispersionSum , IndexDispersionSum )):
139
+ if isinstance (self , (TableEpsilon , Table )):
140
+ if self .last_lbda is lbda :
141
+ return self .cached_diel
142
+ else :
143
+ self .last_lbda = lbda
144
+ self .cached_diel = np .asarray (
145
+ self .dielectric_function (lbda ), dtype = np .complex128
146
+ )
147
+ return self .cached_diel
148
+ else :
149
+ new_single_hash = self ._hash_params (self .single_params )
150
+ new_rep_hash = self ._hash_params (self .rep_params )
151
+
152
+ if (
153
+ self .last_lbda is lbda
154
+ and self .hash_single_params == new_single_hash
155
+ and self .hash_rep_params == new_rep_hash
156
+ ):
157
+ return self .cached_diel
158
+ else :
159
+ self .last_lbda = lbda
160
+ self .hash_single_params = new_single_hash
161
+ self .hash_rep_params = new_rep_hash
162
+ self .cached_diel = np .asarray (
163
+ self .dielectric_function (lbda ), dtype = np .complex128
164
+ )
165
+ return self .cached_diel
166
+
117
167
return np .asarray (self .dielectric_function (lbda ), dtype = np .complex128 )
118
168
119
169
def get_refractive_index (self , lbda : Optional [npt .ArrayLike ] = None ) -> npt .NDArray :
0 commit comments