Skip to content

Customizable token cache #759

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions msal/application.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from __future__ import annotations
import functools
import json
import time
Expand Down Expand Up @@ -238,6 +239,10 @@ class ClientApplication(object):
"You can enable broker by following these instructions. "
"https://msal-python.readthedocs.io/en/latest/#publicclientapplication")

_TOKEN_CACHE_DATA: dict[str, str] = { # field_in_data: field_in_cache

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The type annotation dict[str, str] for _TOKEN_CACHE_DATA uses Python 3.9+ syntax. If you intend to support Python 3.7 or 3.8 (as implied by the import on line 9), consider using Dict[str, str] from typing for compatibility.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider adding a comment
# Maps field names from input data to their corresponding field names in the token cache.
# This is used to ensure that certain token types (e.g., SSH certificates, POP tokens)
# are correctly associated with their identifying keys in the cache.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @maliksahil , it has been a while, how is it going? :-)

Thanks for catching that type annotation issue. I'll take a look at why that typo was not caught by our test automation, when we revive this PR.

Why did you notice this PR in the first place? Are you working on some cutting-edge scenario that needs the behavior of this PR?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Haha going well my friend. We are always looking at cutting edge stuff. Hope you and everyone else are well.

"key_id": "key_id", # Some token types (SSH-certs, POP) are bound to a key
}

def __init__(
self, client_id,
client_credential=None, authority=None, validate_authority=True,
Expand Down Expand Up @@ -651,6 +656,7 @@ def __init__(

self._decide_broker(allow_broker, enable_pii_log)
self.token_cache = token_cache or TokenCache()
self.token_cache._set(data_to_at=self._TOKEN_CACHE_DATA)
self._region_configured = azure_region
self._region_detected = None
self.client, self._regional_client = self._build_client(
Expand Down Expand Up @@ -1528,9 +1534,10 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it(
"realm": authority.tenant,
"home_account_id": (account or {}).get("home_account_id"),
}
key_id = kwargs.get("data", {}).get("key_id")
if key_id: # Some token types (SSH-certs, POP) are bound to a key
query["key_id"] = key_id
for field_in_data, field_in_cache in self._TOKEN_CACHE_DATA.items():
value = kwargs.get("data", {}).get(field_in_data)
if value:
query[field_in_cache] = value
now = time.time()
refresh_reason = msal.telemetry.AT_ABSENT
for entry in self.token_cache.search( # A generator allows us to
Expand Down
36 changes: 30 additions & 6 deletions msal/token_cache.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import json
from __future__ import annotations
import json
import threading
import time
from typing import Optional # Needed in Python 3.7 & 3.8
import logging
import warnings

Expand Down Expand Up @@ -39,6 +41,25 @@ class AuthorityType:
ADFS = "ADFS"
MSSTS = "MSSTS" # MSSTS means AAD v2 for both AAD & MSA

_data_to_at: dict[str, str] = { # field_in_data: field_in_cache
# Store extra data which we explicitly allow,
# so that we won't accidentally store a user's password etc.
# It can be used to store for example key_id used in SSH-cert or POP
}
_response_to_at: dict[str, str] = { # field_in_response: field_in_cache
}

def _set(
self,
*,
data_to_at: Optional[dict[str, str]] = None,
response_to_at: Optional[dict[str, str]] = None,
) -> None:
# This helper should probably be better in __init__(),
# but there is no easy way for MSAL EX to pick up a kwargs
self._data_to_at = data_to_at or {}
self._response_to_at = response_to_at or {}

def __init__(self):
self._lock = threading.RLock()
self._cache = {}
Expand Down Expand Up @@ -267,11 +288,14 @@ def __add(self, event, now=None):
"expires_on": str(now + expires_in), # Same here
"extended_expires_on": str(now + ext_expires_in) # Same here
}
at.update({k: data[k] for k in data if k in {
# Also store extra data which we explicitly allow
# So that we won't accidentally store a user's password etc.
"key_id", # It happens in SSH-cert or POP scenario
}})
for field_in_resp, field_in_cache in self._response_to_at.items():
value = response.get(field_in_resp)
if value:
at[field_in_cache] = value
for field_in_data, field_in_cache in self._data_to_at.items():
value = data.get(field_in_data)
if value:
at[field_in_cache] = value
if "refresh_in" in response:
refresh_in = response["refresh_in"] # It is an integer
at["refresh_on"] = str(now + refresh_in) # Schema wants a string
Expand Down
1 change: 1 addition & 0 deletions tests/test_token_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ def assertFoundAccessToken(self, *, scopes, query, data=None, now=None):
def _test_data_should_be_saved_and_searchable_in_access_token(self, data):
scopes = ["s2", "s1", "s3"] # Not in particular order
now = 1000
self.cache._set(data_to_at={"key_id": "key_id"})
self.cache.add({
"data": data,
"client_id": "my_client_id",
Expand Down