@@ -39,6 +39,12 @@ class AuthorityType:
39
39
def __init__ (self ):
40
40
self ._lock = threading .RLock ()
41
41
self ._cache = {}
42
+ self .key_makers = {
43
+ self .CredentialType .REFRESH_TOKEN : self ._build_rt_key ,
44
+ self .CredentialType .ACCESS_TOKEN : self ._build_at_key ,
45
+ self .CredentialType .ID_TOKEN : self ._build_idt_key ,
46
+ self .CredentialType .ACCOUNT : self ._build_account_key ,
47
+ }
42
48
43
49
def find (self , credential_type , target = None , query = None ):
44
50
target = target or []
@@ -83,14 +89,9 @@ def add(self, event, now=None):
83
89
with self ._lock :
84
90
85
91
if access_token :
86
- key = "-" .join ([
87
- home_account_id or "" ,
88
- environment or "" ,
89
- self .CredentialType .ACCESS_TOKEN ,
90
- event .get ("client_id" , "" ),
91
- realm or "" ,
92
- target ,
93
- ]).lower ()
92
+ key = self ._build_at_key (
93
+ home_account_id , environment , event .get ("client_id" , "" ),
94
+ realm , target )
94
95
now = time .time () if now is None else now
95
96
expires_in = response .get ("expires_in" , 3599 )
96
97
self ._cache .setdefault (self .CredentialType .ACCESS_TOKEN , {})[key ] = {
@@ -110,11 +111,7 @@ def add(self, event, now=None):
110
111
if client_info :
111
112
decoded_id_token = json .loads (
112
113
base64decode (id_token .split ('.' )[1 ])) if id_token else {}
113
- key = "-" .join ([
114
- home_account_id or "" ,
115
- environment or "" ,
116
- realm or "" ,
117
- ]).lower ()
114
+ key = self ._build_account_key (home_account_id , environment , realm )
118
115
self ._cache .setdefault (self .CredentialType .ACCOUNT , {})[key ] = {
119
116
"home_account_id" : home_account_id ,
120
117
"environment" : environment ,
@@ -129,14 +126,8 @@ def add(self, event, now=None):
129
126
}
130
127
131
128
if id_token :
132
- key = "-" .join ([
133
- home_account_id or "" ,
134
- environment or "" ,
135
- self .CredentialType .ID_TOKEN ,
136
- event .get ("client_id" , "" ),
137
- realm or "" ,
138
- "" # Albeit irrelevant, schema requires an empty scope here
139
- ]).lower ()
129
+ key = self ._build_idt_key (
130
+ home_account_id , environment , event .get ("client_id" , "" ), realm )
140
131
self ._cache .setdefault (self .CredentialType .ID_TOKEN , {})[key ] = {
141
132
"credential_type" : self .CredentialType .ID_TOKEN ,
142
133
"secret" : id_token ,
@@ -170,6 +161,24 @@ def add(self, event, now=None):
170
161
"family_id" : response .get ("foci" ), # None is also valid
171
162
}
172
163
164
+ def modify (self , credential_type , old_entry , new_key_value_pairs = None ):
165
+ # Modify the specified old_entry with new_key_value_pairs,
166
+ # or remove the old_entry if the new_key_value_pairs is None.
167
+
168
+ # This helper exists to consolidate all token modify/remove behaviors,
169
+ # so that the sub-classes will have only one method to work on,
170
+ # instead of patching a pair of update_xx() and remove_xx() per type.
171
+ # You can monkeypatch self.key_makers to support more types on-the-fly.
172
+ key = self .key_makers [credential_type ](** old_entry )
173
+ with self ._lock :
174
+ if new_key_value_pairs : # Update with them
175
+ entries = self ._cache .setdefault (credential_type , {})
176
+ entry = entries .get (key , {}) # key usually exists, but we'll survive its absence
177
+ entry .update (new_key_value_pairs )
178
+ else : # Remove old_entry
179
+ self ._cache .setdefault (credential_type , {}).pop (key , None )
180
+
181
+
173
182
@staticmethod
174
183
def _build_appmetadata_key (environment , client_id ):
175
184
return "appmetadata-{}-{}" .format (environment or "" , client_id or "" )
@@ -178,7 +187,7 @@ def _build_appmetadata_key(environment, client_id):
178
187
def _build_rt_key (
179
188
cls ,
180
189
home_account_id = None , environment = None , client_id = None , target = None ,
181
- ** ignored ):
190
+ ** ignored_payload_from_a_real_token ):
182
191
return "-" .join ([
183
192
home_account_id or "" ,
184
193
environment or "" ,
@@ -189,16 +198,61 @@ def _build_rt_key(
189
198
]).lower ()
190
199
191
200
def remove_rt (self , rt_item ):
192
- key = self ._build_rt_key (** rt_item )
193
- with self ._lock :
194
- self ._cache .setdefault (self .CredentialType .REFRESH_TOKEN , {}).pop (key , None )
201
+ assert rt_item .get ("credential_type" ) == self .CredentialType .REFRESH_TOKEN
202
+ return self .modify (self .CredentialType .REFRESH_TOKEN , rt_item )
195
203
196
204
def update_rt (self , rt_item , new_rt ):
197
- key = self ._build_rt_key (** rt_item )
198
- with self ._lock :
199
- RTs = self ._cache .setdefault (self .CredentialType .REFRESH_TOKEN , {})
200
- rt = RTs .get (key , {}) # key usually exists, but we'll survive its absence
201
- rt ["secret" ] = new_rt
205
+ assert rt_item .get ("credential_type" ) == self .CredentialType .REFRESH_TOKEN
206
+ return self .modify (
207
+ self .CredentialType .REFRESH_TOKEN , rt_item , {"secret" : new_rt })
208
+
209
+ @classmethod
210
+ def _build_at_key (cls ,
211
+ home_account_id = None , environment = None , client_id = None ,
212
+ realm = None , target = None , ** ignored_payload_from_a_real_token ):
213
+ return "-" .join ([
214
+ home_account_id or "" ,
215
+ environment or "" ,
216
+ cls .CredentialType .ACCESS_TOKEN ,
217
+ client_id ,
218
+ realm or "" ,
219
+ target or "" ,
220
+ ]).lower ()
221
+
222
+ def remove_at (self , at_item ):
223
+ assert at_item .get ("credential_type" ) == self .CredentialType .ACCESS_TOKEN
224
+ return self .modify (self .CredentialType .ACCESS_TOKEN , at_item )
225
+
226
+ @classmethod
227
+ def _build_idt_key (cls ,
228
+ home_account_id = None , environment = None , client_id = None , realm = None ,
229
+ ** ignored_payload_from_a_real_token ):
230
+ return "-" .join ([
231
+ home_account_id or "" ,
232
+ environment or "" ,
233
+ cls .CredentialType .ID_TOKEN ,
234
+ client_id or "" ,
235
+ realm or "" ,
236
+ "" # Albeit irrelevant, schema requires an empty scope here
237
+ ]).lower ()
238
+
239
+ def remove_idt (self , idt_item ):
240
+ assert idt_item .get ("credential_type" ) == self .CredentialType .ID_TOKEN
241
+ return self .modify (self .CredentialType .ID_TOKEN , idt_item )
242
+
243
+ @classmethod
244
+ def _build_account_key (cls ,
245
+ home_account_id = None , environment = None , realm = None ,
246
+ ** ignored_payload_from_a_real_entry ):
247
+ return "-" .join ([
248
+ home_account_id or "" ,
249
+ environment or "" ,
250
+ realm or "" ,
251
+ ]).lower ()
252
+
253
+ def remove_account (self , account_item ):
254
+ assert "authority_type" in account_item
255
+ return self .modify (self .CredentialType .ACCOUNT , account_item )
202
256
203
257
204
258
class SerializableTokenCache (TokenCache ):
@@ -221,7 +275,7 @@ class SerializableTokenCache(TokenCache):
221
275
...
222
276
223
277
:var bool has_state_changed:
224
- Indicates whether the cache state has changed since last
278
+ Indicates whether the cache state in the memory has changed since last
225
279
:func:`~serialize` or :func:`~deserialize` call.
226
280
"""
227
281
has_state_changed = False
@@ -230,12 +284,9 @@ def add(self, event, **kwargs):
230
284
super (SerializableTokenCache , self ).add (event , ** kwargs )
231
285
self .has_state_changed = True
232
286
233
- def remove_rt (self , rt_item ):
234
- super (SerializableTokenCache , self ).remove_rt (rt_item )
235
- self .has_state_changed = True
236
-
237
- def update_rt (self , rt_item , new_rt ):
238
- super (SerializableTokenCache , self ).update_rt (rt_item , new_rt )
287
+ def modify (self , credential_type , old_entry , new_key_value_pairs = None ):
288
+ super (SerializableTokenCache , self ).modify (
289
+ credential_type , old_entry , new_key_value_pairs )
239
290
self .has_state_changed = True
240
291
241
292
def deserialize (self , state ):
0 commit comments