diff --git a/CHANGES/10725.feature.rst b/CHANGES/10725.feature.rst new file mode 100644 index 00000000000..1adf6f0c38a --- /dev/null +++ b/CHANGES/10725.feature.rst @@ -0,0 +1 @@ +Added a digest authentication helper class. diff --git a/CHANGES/2213.feature.rst b/CHANGES/2213.feature.rst new file mode 120000 index 00000000000..d118975e478 --- /dev/null +++ b/CHANGES/2213.feature.rst @@ -0,0 +1 @@ +10725.feature.rst \ No newline at end of file diff --git a/CONTRIBUTORS.txt b/CONTRIBUTORS.txt index e3ddd3e3d6a..066987ef655 100644 --- a/CONTRIBUTORS.txt +++ b/CONTRIBUTORS.txt @@ -193,6 +193,7 @@ Jesus Cea Jian Zeng Jinkyu Yi Joel Watts +John Feusi John Parton Jon Nabozny Jonas Krüger Svensson diff --git a/aiohttp/__init__.py b/aiohttp/__init__.py index 7759a997cb9..688450e8a4b 100644 --- a/aiohttp/__init__.py +++ b/aiohttp/__init__.py @@ -50,7 +50,7 @@ from .connector import AddrInfoType, SocketFactoryType from .cookiejar import CookieJar, DummyCookieJar from .formdata import FormData -from .helpers import BasicAuth, ChainMapProxy, ETag +from .helpers import BasicAuth, ChainMapProxy, DigestAuth, ETag from .http import ( HttpVersion, HttpVersion10, @@ -164,6 +164,7 @@ # helpers "BasicAuth", "ChainMapProxy", + "DigestAuth", "ETag", # http "HttpVersion", diff --git a/aiohttp/client.py b/aiohttp/client.py index 04f03b710f0..8f50384e4d6 100644 --- a/aiohttp/client.py +++ b/aiohttp/client.py @@ -95,6 +95,7 @@ from .helpers import ( _SENTINEL, EMPTY_BODY_METHODS, + AuthBase, BasicAuth, TimeoutHandle, frozen_dataclass_decorator, @@ -174,7 +175,7 @@ class _RequestOptions(TypedDict, total=False): cookies: Union[LooseCookies, None] headers: Union[LooseHeaders, None] skip_auto_headers: Union[Iterable[str], None] - auth: Union[BasicAuth, None] + auth: Union[AuthBase, None] allow_redirects: bool max_redirects: int compress: Union[str, bool] @@ -272,7 +273,7 @@ def __init__( proxy: Optional[StrOrURL] = None, proxy_auth: Optional[BasicAuth] = None, skip_auto_headers: Optional[Iterable[str]] = None, - auth: Optional[BasicAuth] = None, + auth: Optional[AuthBase] = None, json_serialize: JSONEncoder = json.dumps, request_class: Type[ClientRequest] = ClientRequest, response_class: Type[ClientResponse] = ClientResponse, @@ -429,7 +430,7 @@ async def _request( cookies: Optional[LooseCookies] = None, headers: Optional[LooseHeaders] = None, skip_auto_headers: Optional[Iterable[str]] = None, - auth: Optional[BasicAuth] = None, + auth: Optional[AuthBase] = None, allow_redirects: bool = True, max_redirects: int = 10, compress: Union[str, bool] = False, @@ -672,6 +673,13 @@ async def _request( resp = await req.send(conn) try: await resp.start(conn) + # Try performing digest authentication. It returns + # True if we need to resend the request, as + # DigestAuth is a bit of a handshake. This is + # a no-op for BasicAuth. If it + if auth is not None and auth.authenticate(resp): + resp.close() + continue except BaseException: resp.close() raise @@ -824,7 +832,7 @@ def ws_connect( autoclose: bool = True, autoping: bool = True, heartbeat: Optional[float] = None, - auth: Optional[BasicAuth] = None, + auth: Optional[AuthBase] = None, origin: Optional[str] = None, params: Query = None, headers: Optional[LooseHeaders] = None, @@ -872,7 +880,7 @@ async def _ws_connect( autoclose: bool = True, autoping: bool = True, heartbeat: Optional[float] = None, - auth: Optional[BasicAuth] = None, + auth: Optional[AuthBase] = None, origin: Optional[str] = None, params: Query = None, headers: Optional[LooseHeaders] = None, @@ -1247,7 +1255,7 @@ def skip_auto_headers(self) -> FrozenSet[istr]: return self._skip_auto_headers @property - def auth(self) -> Optional[BasicAuth]: # type: ignore[misc] + def auth(self) -> Optional[AuthBase]: """An object that represents HTTP Basic Authorization""" return self._default_auth @@ -1412,8 +1420,7 @@ def request( headers - (optional) Dictionary of HTTP Headers to send with the request cookies - (optional) Dict object to send with the request - auth - (optional) BasicAuth named tuple represent HTTP Basic Auth - auth - aiohttp.helpers.BasicAuth + auth - (optional) something implementing AuthBase for authentication allow_redirects - (optional) If set to False, do not follow redirects version - Request HTTP version. diff --git a/aiohttp/client_reqrep.py b/aiohttp/client_reqrep.py index d30e8704d3e..8e0cd708960 100644 --- a/aiohttp/client_reqrep.py +++ b/aiohttp/client_reqrep.py @@ -43,8 +43,10 @@ from .hdrs import CONTENT_TYPE from .helpers import ( _SENTINEL, + AuthBase, BaseTimerContext, BasicAuth, + DigestAuth, HeadersMixin, TimerNoop, basicauth_from_netrc, @@ -230,7 +232,7 @@ def __init__( skip_auto_headers: Optional[Iterable[str]] = None, data: Any = None, cookies: Optional[LooseCookies] = None, - auth: Optional[BasicAuth] = None, + auth: Optional[AuthBase] = None, version: http.HttpVersion = http.HttpVersion11, compress: Union[str, bool] = False, chunked: Optional[bool] = None, @@ -287,12 +289,12 @@ def __init__( self.update_auto_headers(skip_auto_headers) self.update_cookies(cookies) self.update_content_encoding(data, compress) - self.update_auth(auth, trust_env) self.update_proxy(proxy, proxy_auth, proxy_headers) self.update_body_from_data(data) if data is not None or self.method not in self.GET_METHODS: self.update_transfer_encoding() + self.update_auth(auth, trust_env) # Must be after we set the body self.update_expect_continue(expect100) self._traces = [] if traces is None else traces @@ -322,7 +324,7 @@ def ssl(self) -> Union["SSLContext", bool, Fingerprint]: return self._ssl @property - def connection_key(self) -> ConnectionKey: # type: ignore[misc] + def connection_key(self) -> ConnectionKey: if proxy_headers := self.proxy_headers: h: Optional[int] = hash(tuple(proxy_headers.items())) else: @@ -370,7 +372,7 @@ def update_host(self, url: URL) -> None: # basic auth info if url.raw_user or url.raw_password: - self.auth = helpers.BasicAuth(url.user or "", url.password or "") + self.auth = BasicAuth(url.user or "", url.password or "") def update_version(self, version: Union[http.HttpVersion, str]) -> None: """Convert request version to two elements tuple. @@ -494,7 +496,7 @@ def update_transfer_encoding(self) -> None: if hdrs.CONTENT_LENGTH not in self.headers: self.headers[hdrs.CONTENT_LENGTH] = str(len(self.body)) - def update_auth(self, auth: Optional[BasicAuth], trust_env: bool = False) -> None: + def update_auth(self, auth: Optional[AuthBase], trust_env: bool = False) -> None: """Set basic auth.""" if auth is None: auth = self.auth @@ -505,10 +507,12 @@ def update_auth(self, auth: Optional[BasicAuth], trust_env: bool = False) -> Non if auth is None: return - if not isinstance(auth, helpers.BasicAuth): - raise TypeError("BasicAuth() tuple is required instead") + if not isinstance(auth, BasicAuth) and not isinstance(auth, DigestAuth): + raise TypeError("BasicAuth() or DigestAuth() is required instead") - self.headers[hdrs.AUTHORIZATION] = auth.encode() + authorization_hdr = auth.encode(self.method, self.url, self.body) + if authorization_hdr: + self.headers[hdrs.AUTHORIZATION] = authorization_hdr def update_body_from_data(self, body: Any) -> None: if body is None: @@ -570,7 +574,7 @@ def update_proxy( self.proxy_headers = None return - if proxy_auth and not isinstance(proxy_auth, helpers.BasicAuth): + if proxy_auth and not isinstance(proxy_auth, BasicAuth): raise ValueError("proxy_auth must be None or BasicAuth() tuple") self.proxy_auth = proxy_auth diff --git a/aiohttp/helpers.py b/aiohttp/helpers.py index 22a459586c7..628d0a27833 100644 --- a/aiohttp/helpers.py +++ b/aiohttp/helpers.py @@ -8,16 +8,18 @@ import datetime import enum import functools +import hashlib import inspect import netrc import os import platform import re import sys +import threading import time import warnings import weakref -from collections import namedtuple +from abc import ABC, abstractmethod from contextlib import suppress from email.parser import HeaderParser from email.utils import parsedate @@ -31,6 +33,7 @@ Callable, ContextManager, Dict, + Final, Generic, Iterable, Iterator, @@ -40,6 +43,7 @@ Protocol, Tuple, Type, + TypedDict, TypeVar, Union, final, @@ -53,7 +57,7 @@ from propcache.api import under_cached_property as reify from yarl import URL -from . import hdrs +from . import client_exceptions, hdrs from .log import client_logger from .typedefs import PathLike # noqa @@ -64,6 +68,8 @@ if TYPE_CHECKING: from dataclasses import dataclass as frozen_dataclass_decorator + + from .client import ClientResponse elif sys.version_info < (3, 10): frozen_dataclass_decorator = functools.partial(dataclasses.dataclass, frozen=True) else: @@ -71,7 +77,16 @@ dataclasses.dataclass, frozen=True, slots=True ) -__all__ = ("BasicAuth", "ChainMapProxy", "ETag", "frozen_dataclass_decorator", "reify") +__all__ = ( + "AuthBase", + "BasicAuth", + "ChainMapProxy", + "ETag", + "frozen_dataclass_decorator", + "reify", + "DigestAuth", + "DigestFunctions", +) PY_310 = sys.version_info >= (3, 10) @@ -127,12 +142,14 @@ json_re = re.compile(r"(?:application/|[\w.-]+/[\w.+-]+?\+)json$", re.IGNORECASE) -class BasicAuth(namedtuple("BasicAuth", ["login", "password", "encoding"])): - """Http basic authentication helper.""" +class AuthBase(ABC): + """Abstract base class for HTTP auth helpers.""" - def __new__( - cls, login: str, password: str = "", encoding: str = "latin1" - ) -> "BasicAuth": + def __init__( + self, + login: str = "", + password: str = "", + ) -> None: if login is None: raise ValueError("None is not allowed as login value") @@ -140,12 +157,39 @@ def __new__( raise ValueError("None is not allowed as password value") if ":" in login: - raise ValueError('A ":" is not allowed in login (RFC 1945#section-11.1)') + raise ValueError('A ":" is not allowed in username (RFC 1945#section-11.1)') + + self.login: Final = login + self.password: Final = password + + @abstractmethod + def encode(self, method: str, url: URL, body: Any) -> str: # type: ignore[misc] + pass - return super().__new__(cls, login, password, encoding) + @abstractmethod + def authenticate(self, resp: "ClientResponse") -> bool: + pass + + +class BasicAuth(AuthBase): + """Http basic authentication helper.""" + + def __init__( + self, login: str = "", password: str = "", encoding: str = "latin1" + ) -> None: + super().__init__(login, password) + self.encoding = encoding + + def __eq__(self, other: object) -> bool: + return ( + isinstance(other, BasicAuth) + and self.login == other.login + and self.password == other.password + and self.encoding == other.encoding + ) @classmethod - def decode(cls, auth_header: str, encoding: str = "latin1") -> "BasicAuth": # type: ignore[misc] + def decode(cls, auth_header: str, encoding: str = "latin1") -> "BasicAuth": """Create a BasicAuth object from an Authorization HTTP header.""" try: auth_type, encoded_credentials = auth_header.split(" ", 1) @@ -171,10 +215,10 @@ def decode(cls, auth_header: str, encoding: str = "latin1") -> "BasicAuth": # t except ValueError: raise ValueError("Invalid credentials.") - return cls(username, password, encoding=encoding) + return cls(username, password, encoding) @classmethod - def from_url(cls, url: URL, *, encoding: str = "latin1") -> Optional["BasicAuth"]: # type: ignore[misc] + def from_url(cls, url: URL, *, encoding: str = "latin1") -> Optional["BasicAuth"]: """Create BasicAuth from url.""" if not isinstance(url, URL): raise TypeError("url should be yarl.URL instance") @@ -182,12 +226,18 @@ def from_url(cls, url: URL, *, encoding: str = "latin1") -> Optional["BasicAuth" # to already have these values parsed from the netloc in the cache. if url.raw_user is None and url.raw_password is None: return None - return cls(url.user or "", url.password or "", encoding=encoding) + return cls(url.user or "", url.password or "", encoding) - def encode(self) -> str: + def encode( + self, unused_method: str = "", unused_url: URL = URL(""), body: Any = None + ) -> str: """Encode credentials.""" creds = (f"{self.login}:{self.password}").encode(self.encoding) - return "Basic %s" % base64.b64encode(creds).decode(self.encoding) + return f"Basic {base64.b64encode(creds).decode(self.encoding)}" + + def authenticate(self, resp: "ClientResponse") -> bool: + """Always returns false because request need not be resent for BasicAuth""" + return False def strip_auth_from_url(url: URL) -> Tuple[URL, Optional[BasicAuth]]: @@ -245,7 +295,7 @@ def netrc_from_env() -> Optional[netrc.netrc]: @frozen_dataclass_decorator -class ProxyInfo: # type: ignore[misc] +class ProxyInfo: proxy: URL proxy_auth: Optional[BasicAuth] @@ -279,6 +329,246 @@ def basicauth_from_netrc(netrc_obj: Optional[netrc.netrc], host: str) -> BasicAu return BasicAuth(username, password) +DigestFunctions: Dict[str, Callable[[bytes], "hashlib._Hash"]] = { + "MD5": hashlib.md5, + "MD5-SESS": hashlib.md5, + "SHA": hashlib.sha1, + "SHA256": hashlib.sha256, + "SHA512": hashlib.sha512, +} + + +class DigestAuthChallenge(TypedDict, total=False): + realm: str + nonce: str + qop: str + algorithm: str + opaque: str + ... + + +class DigestAuthContext(threading.local): + """Thread-local storage for DigestAuth""" + + init: bool + last_nonce: str + nonce_count: int + challenge: DigestAuthChallenge + handled_401: bool + + def __init__(self) -> None: + super().__init__() + self.init = False + + def init_thread(self) -> None: + if self.init: + return + self.init = True + self.last_nonce = "" + self.nonce_count = 0 + self.challenge = {} + self.handled_401 = False + + +def parse_header_pairs(header: str) -> Dict[str, str]: + """Parses header pairs in the www-authenticate header value""" + # RFC 7616 accepts header key/values that look like + # key1="value1", key2=value2, key3="some value, with, commas" + # + # This regex attempts to parse that out + pattern = re.compile( + r'(\w+)\s*=\s*(?:"((?:[^"\\]|\\.)*)"|([^\s,]+))' + # | | | | | | | | | || | + # +----|--|-|-|--|----|------|----|--||-----|--> alphanumeric key + # +--|-|-|--|----|------|----|--||-----|--> maybe whitespace + # | | | | | | | || | + # +-|-|--|----|------|----|--||-----|--> = (delimiter) + # +-|--|----|------|----|--||-----|--> maybe whitespace + # | | | | | || | + # +--|----|------|----|--||-----|--> group quoted or unquoted + # | | | | || | + # +----|------|----|--||-----|--> if quoted... + # +------|----|--||-----|--> anything but " or \ + # +----|--||-----|--> escaped characters allowed + # +--||-----|--> or can be empty string + # || | + # +|-----|--> if unquoted... + # +-----|--> anything but , or + # +--> at least one char req'd + ) + + header_pairs = {} + for key, quoted_val, unquoted_val in pattern.findall(header): + val = quoted_val if quoted_val else unquoted_val + if val: + val = val.replace('\\"', '"') # unescape any escaped quotes + header_pairs[key] = val + + return header_pairs + + +class DigestAuth(AuthBase): + """ + HTTP digest authentication helper. + + The work here is based off of + https://github.com/requests/requests/blob/v2.18.4/requests/auth.py. + + Please also refer to RFC7616. + """ + + def __init__( + self, + login: str, + password: str, + ) -> None: + super().__init__(login, password) + self.ctx = DigestAuthContext() + + def encode(self, method: str, url: URL, body: Any) -> str: + """Build digest header""" + self.ctx.init_thread() + + if not self.ctx.handled_401: + return "" + + if "realm" not in self.ctx.challenge: + raise client_exceptions.ClientError("Challenge is missing realm") + + if "nonce" not in self.ctx.challenge: + raise client_exceptions.ClientError("Challenge is missing nonce") + + realm: str = self.ctx.challenge.get("realm", "") + nonce: str = self.ctx.challenge.get("nonce", "") + qop_raw: str = self.ctx.challenge.get("qop", "") + algorithm: str = self.ctx.challenge.get("algorithm", "MD5").upper() + opaque: str = self.ctx.challenge.get("opaque", "") + + qop: str = "" + if qop_raw: + qop_list = [q.strip() for q in qop_raw.split(",") if q.strip()] + valid_qops = {"auth", "auth-int"}.intersection(qop_list) + if not valid_qops: + raise client_exceptions.ClientError( + f"Unsupported qop value(s): {qop_raw}" + ) + + qop = "auth-int" if "auth-int" in valid_qops else "auth" + + if algorithm not in DigestFunctions: + return "" + hash_fn: Final = DigestFunctions[algorithm] + + def H(x: str) -> str: + return hash_fn(x.encode()).hexdigest() + + def KD(s: str, d: str) -> str: + return H(f"{s}:{d}") + + path = URL(url).path_qs + A1 = f"{self.login}:{realm}:{self.password}" + A2 = f"{method.upper()}:{path}" + if qop == "auth-int": + if isinstance(body, bytes): + entity_str = body.decode("utf-8", errors="replace") + elif isinstance(body, str): + entity_str = body + else: + entity_str = "" + entity_hash = H(entity_str) + A2 = f"{A2}:{entity_hash}" + + HA1 = H(A1) + HA2 = H(A2) + + if nonce == self.ctx.last_nonce: + self.ctx.nonce_count += 1 + else: + self.ctx.nonce_count = 1 + + self.ctx.last_nonce = nonce + + ncvalue = f"{self.ctx.nonce_count:08x}" + + # cnonce is just a random string generated by the client. + cnonce_data = "".join( + [ + str(self.ctx.nonce_count), + nonce, + time.ctime(), + os.urandom(8).decode(errors="ignore"), + ] + ).encode() + cnonce = hashlib.sha1(cnonce_data).hexdigest()[:16] + + if algorithm == "MD5-SESS": + HA1 = H(f"{HA1}:{nonce}:{cnonce}") + + if qop: + noncebit = ":".join([nonce, ncvalue, cnonce, qop, HA2]) + response_digest = KD(HA1, noncebit) + else: + response_digest = KD(HA1, f"{nonce}:{HA2}") + + pairs = [ + f'username="{self.login}"', + f'realm="{realm}"', + f'nonce="{nonce}"', + f'uri="{path}"', + f'response="{response_digest}"', + f'algorithm="{algorithm}"', + ] + if opaque: + pairs.append(f'opaque="{opaque}"') + if qop: + pairs.append(f'qop="{qop}"') + pairs.append(f"nc={ncvalue}") + pairs.append(f'cnonce="{cnonce}"') + + self.ctx.handled_401 = False + + return f"Digest {', '.join(pairs)}" + + def authenticate(self, response: "ClientResponse") -> bool: + """ + Takes the given response and tries digest-auth, if needed. + + Returns true if the original request must be resent. + """ + # Effective recursion guard + self.ctx.init_thread() + if self.ctx.handled_401: + return False + + if response.status != 401: + self.ctx.handled_401 = False + return False + + auth_header = response.headers.get("www-authenticate", "") + + parts = auth_header.split(" ", 1) + if "digest" == parts[0].lower() and len(parts) > 1 and not self.ctx.handled_401: + self.ctx.handled_401 = True + + header_pairs = parse_header_pairs(parts[1]) + + self.ctx.challenge = {} + if "realm" in header_pairs and header_pairs["realm"]: + self.ctx.challenge["realm"] = header_pairs["realm"] + if "nonce" in header_pairs and header_pairs["nonce"]: + self.ctx.challenge["nonce"] = header_pairs["nonce"] + if "qop" in header_pairs and header_pairs["qop"]: + self.ctx.challenge["qop"] = header_pairs["qop"] + if "algorithm" in header_pairs and header_pairs["algorithm"]: + self.ctx.challenge["algorithm"] = header_pairs["algorithm"] + if "opaque" in header_pairs and header_pairs["opaque"]: + self.ctx.challenge["opaque"] = header_pairs["opaque"] + + return True + + return False + + def proxies_from_env() -> Dict[str, ProxyInfo]: proxy_urls = { k: URL(v) diff --git a/docs/client_advanced.rst b/docs/client_advanced.rst index 0f6eb99974b..28c7aaed4bb 100644 --- a/docs/client_advanced.rst +++ b/docs/client_advanced.rst @@ -67,6 +67,8 @@ argument. An instance of :class:`BasicAuth` can be passed in like this:: async with ClientSession(auth=auth) as session: ... +Similarly for :class:`DigestAuth`. + Note that if the request is redirected and the redirect URL contains credentials, those credentials will supersede any previously set credentials. In other words, if ``http://user@example.com`` redirects to diff --git a/docs/client_reference.rst b/docs/client_reference.rst index 43e02ebfeaa..be7d98346c1 100644 --- a/docs/client_reference.rst +++ b/docs/client_reference.rst @@ -106,13 +106,13 @@ The client session supports the context manager protocol for self closing. Iterable of :class:`str` or :class:`~multidict.istr` (optional) - :param aiohttp.BasicAuth auth: an object that represents HTTP Basic - Authorization (optional). It will be included - with any request. However, if the - ``_base_url`` parameter is set, the request - URL's origin must match the base URL's origin; - otherwise, the default auth will not be - included. + :param aiohttp.AuthBase auth: an object that represents HTTP Authorization + (optional). It will be included + with any request. However, if the + ``_base_url`` parameter is set, the request + URL's origin must match the base URL's origin; + otherwise, the default auth will not be + included. :param collections.abc.Callable json_serialize: Json *serializer* callable. @@ -306,7 +306,7 @@ The client session supports the context manager protocol for self closing. An object that represents HTTP Basic Authorization. - :class:`~aiohttp.BasicAuth` (optional) + :class:`~aiohttp.AuthBase` (optional) .. versionadded:: 3.7 @@ -433,8 +433,8 @@ The client session supports the context manager protocol for self closing. Iterable of :class:`str` or :class:`~multidict.istr` (optional) - :param aiohttp.BasicAuth auth: an object that represents HTTP - Basic Authorization (optional) + :param aiohttp.AuthBase auth: an object that represents HTTP + Authorization (optional) :param bool allow_redirects: Whether to process redirects or not. When ``True``, redirects are followed (up to ``max_redirects`` times) @@ -698,8 +698,8 @@ The client session supports the context manager protocol for self closing. (``10.0`` seconds for the websocket to close). ``None`` means no timeout will be used. - :param aiohttp.BasicAuth auth: an object that represents HTTP - Basic Authorization (optional) + :param aiohttp.AuthBase auth: an object that represents HTTP + Authorization (optional) :param bool autoclose: Automatically close websocket connection on close message from server. If *autoclose* is False @@ -903,8 +903,8 @@ certification chaining. Iterable of :class:`str` or :class:`~multidict.istr` (optional) - :param aiohttp.BasicAuth auth: an object that represents HTTP Basic - Authorization (optional) + :param aiohttp.AuthBase auth: an object that represents HTTP + Authorization (optional) :param bool allow_redirects: Whether to process redirects or not. When ``True``, redirects are followed (up to ``max_redirects`` times) @@ -1971,6 +1971,34 @@ Utilities .. versionadded:: 3.2 +.. class:: AuthBase(login, password) + + Abstract base class for HTTP authentication helpers. + + :param str login: Username used for authentication. + :param str password: Password used for authentication. + + :raises ValueError: If login or password is None or login contains a colon. + + .. method:: encode(method, url, body) -> str + + :param str method: HTTP method (e.g., GET, POST). + :param ~yarl.URL url: Full request URL. + :param body: Optional request body (used for `auth-int` qop). + + :return: A complete `Authorization` header. + + Abstract method to generate the `Authorization` header. + + .. method:: authenticate(resp: ClientResponse) -> bool + + :param ClientResponse response: The HTTP response to inspect. + + :return: ``True`` if a retry is required to complete authentication. + + Abstract method to handle 401 responses and extract challenges. + + .. class:: BasicAuth(login, password='', encoding='latin1') HTTP basic authentication helper. @@ -2010,6 +2038,52 @@ Utilities :return: encoded authentication data, :class:`str`. + .. method:: authenticate(resp) + + :param ClientResponse resp: Unused + + :return: Always returns ``False`` because no request resending is required with basic authorization. + + .. versionadded:: 3.12 + + +.. class:: DigestAuth(login, password) + + HTTP digest authentication helper. + + :param str login: login + :param str password: password + + This class builds digest authentication headers, supporting both `auth` and + `auth-int` qop modes, and a variety of hashing algorithms. + + .. versionadded:: 3.12 + + .. method:: encode(method, url, body) -> str + + Generates the `Authorization: Digest ...` header using stored challenge data. + + :param str method: HTTP method (e.g., GET, POST). + :param ~yarl.URL url: Full request URL. + :param body: Optional request body (used for `auth-int` qop). + + :return: A complete `Authorization` header. + + :raises ClientError: If the challenge is missing required fields or contains unsupported qop values. + + .. versionadded:: 3.12 + + .. method:: authenticate(response) -> bool + + Parses the `WWW-Authenticate` header from a 401 response and stores the challenge + in thread-local state. + + :param ClientResponse response: The HTTP response to inspect. + + :return: ``True`` if a retry is required to complete authentication. + + .. versionadded:: 3.12 + .. class:: CookieJar(*, unsafe=False, quote_cookie=True, treat_as_secure_origin = []) diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 5eabd185d05..1070c33002c 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -252,6 +252,7 @@ pyflakes pyright pytest Pytest +qop Quickstart quote’s rc diff --git a/tests/test_client_request.py b/tests/test_client_request.py index 2b5e2725c49..1b5d3a88fff 100644 --- a/tests/test_client_request.py +++ b/tests/test_client_request.py @@ -1465,7 +1465,7 @@ def test_gen_default_accept_encoding(has_brotli: bool, expected: str) -> None: indirect=("netrc_contents",), ) @pytest.mark.usefixtures("netrc_contents") -def test_basicauth_from_netrc_present( # type: ignore[misc] +def test_basicauth_from_netrc_present( make_request: _RequestMaker, expected_auth: helpers.BasicAuth, ) -> None: diff --git a/tests/test_client_session.py b/tests/test_client_session.py index 974d330a3c9..d257d5090b5 100644 --- a/tests/test_client_session.py +++ b/tests/test_client_session.py @@ -31,6 +31,7 @@ from aiohttp.client_reqrep import ClientRequest, ConnectionKey from aiohttp.connector import BaseConnector, Connection, TCPConnector, UnixConnector from aiohttp.cookiejar import CookieJar +from aiohttp.helpers import DigestAuth from aiohttp.http import RawResponseMessage from aiohttp.pytest_plugin import AiohttpClient, AiohttpServer from aiohttp.test_utils import make_mocked_coro @@ -455,6 +456,38 @@ async def test_borrow_connector_loop( assert session._loop is loop +async def test_retry_on_401( + create_session: Callable[..., Awaitable[ClientSession]], + create_mocked_conn: Callable[[], ResponseHandler], +) -> None: + resp = mock.create_autospec(aiohttp.ClientResponse) + resp.status = 401 + resp.url = URL("http://example.com") + resp.cookies = SimpleCookie() + resp.headers = {"www-authenticate": "Digest realm=foo nonce=abcd, algorithm=SHA512"} + resp.start = mock.AsyncMock() + + req = mock.Mock() + req_factory = mock.Mock(return_value=req) + req.send = mock.AsyncMock(return_value=resp) + session = await create_session( + request_class=req_factory, auth=DigestAuth("user", "pw") + ) + + async def create_connection( + req: object, traces: object, timeout: object + ) -> ResponseHandler: + # return self.transport, self.protocol + return create_mocked_conn() + + with mock.patch.object(session._connector, "_create_connection", create_connection): + with mock.patch.object( + session._connector, "_release", autospec=True, spec_set=True + ): + await session.request("get", "http://example.com") + assert req.send.call_count == 2 + + async def test_reraise_os_error( create_session: Callable[..., Awaitable[ClientSession]], create_mocked_conn: Callable[[], ResponseHandler], diff --git a/tests/test_connector.py b/tests/test_connector.py index c4019df3cdf..e4b13800e6e 100644 --- a/tests/test_connector.py +++ b/tests/test_connector.py @@ -54,25 +54,25 @@ @pytest.fixture -def key() -> ConnectionKey: # type: ignore[misc] +def key() -> ConnectionKey: # Connection key return ConnectionKey("localhost", 80, False, True, None, None, None) @pytest.fixture -def key2() -> ConnectionKey: # type: ignore[misc] +def key2() -> ConnectionKey: # Connection key return ConnectionKey("localhost", 80, False, True, None, None, None) @pytest.fixture -def other_host_key2() -> ConnectionKey: # type: ignore[misc] +def other_host_key2() -> ConnectionKey: # Connection key return ConnectionKey("otherhost", 80, False, True, None, None, None) @pytest.fixture -def ssl_key() -> ConnectionKey: # type: ignore[misc] +def ssl_key() -> ConnectionKey: # Connection key return ConnectionKey("localhost", 80, True, True, None, None, None) @@ -221,7 +221,7 @@ async def test_del(loop: asyncio.AbstractEventLoop, key: ConnectionKey) -> None: @pytest.mark.xfail -async def test_del_with_scheduled_cleanup( # type: ignore[misc] +async def test_del_with_scheduled_cleanup( loop: asyncio.AbstractEventLoop, key: ConnectionKey ) -> None: loop.set_debug(True) @@ -251,7 +251,7 @@ async def test_del_with_scheduled_cleanup( # type: ignore[misc] @pytest.mark.skipif( sys.implementation.name != "cpython", reason="CPython GC is required for the test" ) -def test_del_with_closed_loop( # type: ignore[misc] +def test_del_with_closed_loop( loop: asyncio.AbstractEventLoop, key: ConnectionKey ) -> None: async def make_conn() -> aiohttp.BaseConnector: @@ -444,7 +444,7 @@ async def test_release(loop: asyncio.AbstractEventLoop, key: ConnectionKey) -> N @pytest.mark.usefixtures("enable_cleanup_closed") -async def test_release_ssl_transport( # type: ignore[misc] +async def test_release_ssl_transport( loop: asyncio.AbstractEventLoop, ssl_key: ConnectionKey ) -> None: conn = aiohttp.BaseConnector(enable_cleanup_closed=True) @@ -1924,7 +1924,7 @@ async def test_cleanup(key: ConnectionKey) -> None: @pytest.mark.usefixtures("enable_cleanup_closed") -async def test_cleanup_close_ssl_transport( # type: ignore[misc] +async def test_cleanup_close_ssl_transport( loop: asyncio.AbstractEventLoop, ssl_key: ConnectionKey ) -> None: proto = create_mocked_conn(loop) @@ -3887,7 +3887,7 @@ async def test_available_connections_with_limit_per_host( @pytest.mark.parametrize("limit_per_host", [0, 10]) -async def test_available_connections_without_limit_per_host( # type: ignore[misc] +async def test_available_connections_without_limit_per_host( key: ConnectionKey, other_host_key2: ConnectionKey, limit_per_host: int ) -> None: """Verify expected values based on active connections with higher host limit.""" diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 8adc33f53fc..3cf5aa6b7e2 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -14,11 +14,15 @@ from multidict import CIMultiDict, MultiDict, MultiDictProxy from yarl import URL -from aiohttp import helpers, web +from aiohttp import client_exceptions, helpers, web from aiohttp.helpers import ( EMPTY_BODY_METHODS, + DigestAuth, + DigestAuthContext, + DigestFunctions, is_expected_content_type, must_be_empty_body, + parse_header_pairs, parse_http_date, should_remove_content_length, ) @@ -189,7 +193,7 @@ def test_basic_auth_decode_invalid_credentials() -> None: ), ), ) -def test_basic_auth_decode_blank_username( # type: ignore[misc] +def test_basic_auth_decode_blank_username( credentials: str, expected_auth: helpers.BasicAuth ) -> None: header = f"Basic {base64.b64encode(credentials.encode()).decode()}" @@ -628,6 +632,7 @@ def test_proxies_from_env_http_with_auth(url_input: str, expected_scheme: str) - assert ret.keys() == {expected_scheme} assert ret[expected_scheme].proxy == url.with_user(None) proxy_auth = ret[expected_scheme].proxy_auth + assert isinstance(proxy_auth, helpers.BasicAuth) assert proxy_auth is not None assert proxy_auth.login == "user" assert proxy_auth.password == "pass" @@ -1105,7 +1110,7 @@ def test_netrc_from_home_does_not_raise_if_access_denied( indirect=("netrc_contents",), ) @pytest.mark.usefixtures("netrc_contents") -def test_basicauth_present_in_netrc( # type: ignore[misc] +def test_basicauth_present_in_netrc( expected_auth: helpers.BasicAuth, ) -> None: """Test that netrc file contents are properly parsed into BasicAuth tuples""" @@ -1168,3 +1173,254 @@ def test_should_remove_content_length_is_subset_of_must_be_empty_body() -> None: assert should_remove_content_length("CONNECT", 300) is False assert must_be_empty_body("CONNECT", 300) is False + + +@pytest.mark.parametrize( + "header, expected", + [ + ( + 'realm="testrealm", nonce="abc123", qop="auth"', + {"realm": "testrealm", "nonce": "abc123", "qop": "auth"}, + ), + ("qop=auth, algorithm=MD5", {"qop": "auth", "algorithm": "MD5"}), + ('nonce="abc", opaque=""', {"nonce": "abc", "opaque": ""}), + ( + 'realm="a realm with spaces", qop=auth', + {"realm": "a realm with spaces", "qop": "auth"}, + ), + ( + 'realm="escaped \\"quote\\"", qop=auth', + {"realm": 'escaped "quote"', "qop": "auth"}, + ), + (' algorithm = MD5 , realm = "x" ', {"algorithm": "MD5", "realm": "x"}), + ], +) +def test_parse_key_value_header(header: str, expected: Dict[str, str]) -> None: + result = parse_header_pairs(header) + assert result == expected + + +# ------------------- DigestAuth ----------------------------------- + + +@pytest.fixture +def digest_auth() -> DigestAuth: + return DigestAuth("user", "pass") + + +def test_context_initialization() -> None: + ctx = DigestAuthContext() + assert not ctx.init + ctx.init_thread() + assert ctx.nonce_count == 0 + assert ctx.last_nonce == "" + assert ctx.challenge == {} + assert ctx.handled_401 is False + assert ctx.init is True + + +def test_authenticate_valid_digest(digest_auth: DigestAuth) -> None: + response = mock.Mock() + response.status = 401 + response.headers = { + "www-authenticate": 'Digest realm="test", nonce="abc", qop="auth", opaque="xyz", algorithm=MD5' + } + + assert digest_auth.authenticate(response) + assert digest_auth.ctx.challenge["realm"] == "test" + assert digest_auth.ctx.challenge["nonce"] == "abc" + assert digest_auth.ctx.challenge["qop"] == "auth" + assert digest_auth.ctx.challenge["algorithm"] == "MD5" + assert digest_auth.ctx.challenge["opaque"] == "xyz" + + +def test_authenticate_invalid_status(digest_auth: DigestAuth) -> None: + response = mock.Mock() + response.status = 200 + response.headers = {} + assert not digest_auth.authenticate(response) + + +def test_authenticate_already_handled(digest_auth: DigestAuth) -> None: + response = mock.Mock() + response.status = 401 + response.headers = { + "www-authenticate": 'Digest realm="test", nonce="abc", qop="auth"' + } + digest_auth.ctx.init_thread() + digest_auth.ctx.handled_401 = True + assert not digest_auth.authenticate(response) + + +def test_encode_without_challenge(digest_auth: DigestAuth) -> None: + digest_auth.ctx.init_thread() + digest_auth.ctx.handled_401 = False + assert digest_auth.encode("GET", URL("http://example.com/resource"), "") == "" + + +def test_encode_missing_realm_or_nonce(digest_auth: DigestAuth) -> None: + digest_auth.ctx.init_thread() + digest_auth.ctx.handled_401 = True + digest_auth.ctx.challenge = {"nonce": "abc"} + with pytest.raises(Exception): + digest_auth.encode("GET", URL("http://example.com/resource"), "") + + +def test_encode_digest_with_md5(digest_auth: DigestAuth) -> None: + digest_auth.ctx.init_thread() + digest_auth.ctx.handled_401 = True + digest_auth.ctx.challenge = { + "realm": "test", + "nonce": "abc", + "qop": "auth", + "algorithm": "MD5", + "opaque": "xyz", + } + header = digest_auth.encode("GET", URL("http://example.com/resource"), "") + assert header.startswith("Digest ") + assert 'username="user"' in header + assert 'algorithm="MD5"' in header + + +def test_encode_digest_with_md5_sess(digest_auth: DigestAuth) -> None: + digest_auth.ctx.init_thread() + digest_auth.ctx.handled_401 = True + digest_auth.ctx.challenge = { + "realm": "test", + "nonce": "abc", + "qop": "auth", + "algorithm": "MD5-SESS", + } + header = digest_auth.encode("GET", URL("http://example.com/resource"), "") + assert 'algorithm="MD5-SESS"' in header + + +def test_encode_unsupported_algorithm(digest_auth: DigestAuth) -> None: + digest_auth.ctx.init_thread() + digest_auth.ctx.handled_401 = True + digest_auth.ctx.challenge = { + "realm": "test", + "nonce": "abc", + "algorithm": "UNSUPPORTED", + } + assert digest_auth.encode("GET", URL("http://example.com/resource"), "") == "" + + +def test_invalid_qop_rejected() -> None: + auth = DigestAuth("u", "p") + auth.ctx.init_thread() + auth.ctx.challenge = { + "realm": "r", + "nonce": "n", + "qop": "badvalue", + "algorithm": "MD5", + } + auth.ctx.handled_401 = True + with pytest.raises(client_exceptions.ClientError): + auth.encode("GET", URL("http://x"), "") + + +def compute_expected_digest( + algorithm: str, + username: str, + password: str, + realm: str, + nonce: str, + uri: str, + method: str, + qop: str, + nc: str, + cnonce: str, + body: str = "", +) -> str: + hash_fn = DigestFunctions[algorithm] + + def H(x: str) -> str: + return hash_fn(x.encode()).hexdigest() + + def KD(secret: str, data: str) -> str: + return H(f"{secret}:{data}") + + A1 = f"{username}:{realm}:{password}" + HA1 = H(A1) + + if algorithm.upper() == "MD5-SESS": + HA1 = H(f"{HA1}:{nonce}:{cnonce}") + + A2 = f"{method}:{uri}" + if "auth-int" in qop: + entity_hash = H(body) + A2 = f"{A2}:{entity_hash}" + HA2 = H(A2) + + if qop: + response = KD(HA1, f"{nonce}:{nc}:{cnonce}:{qop}:{HA2}") + else: + response = KD(HA1, f"{nonce}:{HA2}") + + return response + + +@pytest.mark.parametrize("qop", ["auth", "auth-int", "auth,auth-int"]) +@pytest.mark.parametrize("algorithm", list(DigestFunctions.keys())) +def test_digest_response_exact_match(qop: str, algorithm: str) -> None: + # Fixed input values + login = "user" + password = "pass" + realm = "example.com" + nonce = "abc123nonce" + cnonce = "deadbeefcafebabe" + nc = 1 + ncvalue = f"{nc+1:08x}" + method = "GET" + uri = "/secret" + body = "this is a body" + qop = "auth-int" if "auth-int" in qop else "auth" + + # Create the auth object + auth = DigestAuth(login, password) + auth.ctx.init_thread() + auth.ctx.challenge = { + "realm": realm, + "nonce": nonce, + "qop": qop, + "algorithm": algorithm, + } + auth.ctx.handled_401 = True + auth.ctx.last_nonce = nonce + auth.ctx.nonce_count = nc + + # Patch cnonce manually by replacing the auth.encode() logic + # We'll monkey-patch hashlib.sha1 to return a fixed cnonce if needed + import hashlib as real_hashlib + + original_sha1 = real_hashlib.sha1 + + class FakeSHA1(mock.Mock): + def hexdigest(self) -> str: + return cnonce + + real_hashlib.sha1 = lambda *_: FakeSHA1() + + try: + header = auth.encode(method, URL(f"http://host{uri}"), body) + finally: + real_hashlib.sha1 = original_sha1 + + # Get expected digest + expected = compute_expected_digest( + algorithm=algorithm, + username=login, + password=password, + realm=realm, + nonce=nonce, + uri=uri, + method=method, + qop=qop, + nc=ncvalue, + cnonce=cnonce, + body=body, + ) + + # Check that the response digest is exactly correct + assert f'response="{expected}"' in header diff --git a/tests/test_proxy.py b/tests/test_proxy.py index a4baabb4047..acfdd5496cf 100644 --- a/tests/test_proxy.py +++ b/tests/test_proxy.py @@ -961,7 +961,10 @@ def test_proxy_auth_property(self) -> None: proxy_auth=aiohttp.helpers.BasicAuth("user", "pass"), loop=self.loop, ) - self.assertEqual(("user", "pass", "latin1"), req.proxy_auth) + assert isinstance(req.proxy_auth, aiohttp.helpers.BasicAuth) + self.assertEqual(req.proxy_auth.login, "user") + self.assertEqual(req.proxy_auth.password, "pass") + self.assertEqual(req.proxy_auth.encoding, "latin1") def test_proxy_auth_property_default(self) -> None: req = aiohttp.ClientRequest(