diff --git a/CHANGES/10725.feature.rst b/CHANGES/10725.feature.rst new file mode 100644 index 00000000000..2cb096a58e7 --- /dev/null +++ b/CHANGES/10725.feature.rst @@ -0,0 +1,6 @@ +Added a comprehensive HTTP Digest Authentication client middleware (DigestAuthMiddleware) +that implements RFC 7616. The middleware supports all standard hash algorithms +(MD5, SHA, SHA-256, SHA-512) with session variants, handles both 'auth' and +'auth-int' quality of protection options, and automatically manages the +authentication flow by intercepting 401 responses and retrying with proper +credentials -- by :user:`feus4177`, :user:`TimMenninger`, and :user:`bdraco`. diff --git a/CHANGES/2213.feature.rst b/CHANGES/2213.feature.rst new file mode 120000 index 00000000000..d118975e478 --- /dev/null +++ b/CHANGES/2213.feature.rst @@ -0,0 +1 @@ +10725.feature.rst \ No newline at end of file diff --git a/CONTRIBUTORS.txt b/CONTRIBUTORS.txt index 89eb3ae621a..e54e97e8ce2 100644 --- a/CONTRIBUTORS.txt +++ b/CONTRIBUTORS.txt @@ -194,6 +194,7 @@ Jesus Cea Jian Zeng Jinkyu Yi Joel Watts +John Feusi John Parton Jon Nabozny Jonas Krüger Svensson diff --git a/aiohttp/__init__.py b/aiohttp/__init__.py index f2ada6bcf07..3f8a1cc62dc 100644 --- a/aiohttp/__init__.py +++ b/aiohttp/__init__.py @@ -47,6 +47,7 @@ WSServerHandshakeError, request, ) +from .client_middleware_digest_auth import DigestAuthMiddleware from .client_middlewares import ClientHandlerType, ClientMiddlewareType from .compression_utils import set_zlib_backend from .connector import AddrInfoType, SocketFactoryType @@ -169,6 +170,7 @@ # helpers "BasicAuth", "ChainMapProxy", + "DigestAuthMiddleware", "ETag", "set_zlib_backend", # http diff --git a/aiohttp/client_middleware_digest_auth.py b/aiohttp/client_middleware_digest_auth.py new file mode 100644 index 00000000000..e9eb3ba82e2 --- /dev/null +++ b/aiohttp/client_middleware_digest_auth.py @@ -0,0 +1,416 @@ +""" +Digest authentication middleware for aiohttp client. 
+ +This middleware implements HTTP Digest Authentication according to RFC 7616, +providing a more secure alternative to Basic Authentication. It supports all +standard hash algorithms including MD5, SHA, SHA-256, SHA-512 and their session +variants, as well as both 'auth' and 'auth-int' quality of protection (qop) options. +""" + +import hashlib +import os +import re +import time +from typing import ( + Callable, + Dict, + Final, + FrozenSet, + List, + Literal, + Tuple, + TypedDict, + Union, +) + +from yarl import URL + +from . import hdrs +from .client_exceptions import ClientError +from .client_middlewares import ClientHandlerType +from .client_reqrep import ClientRequest, ClientResponse + + +class DigestAuthChallenge(TypedDict, total=False): + realm: str + nonce: str + qop: str + algorithm: str + opaque: str + + +DigestFunctions: Dict[str, Callable[[bytes], "hashlib._Hash"]] = { + "MD5": hashlib.md5, + "MD5-SESS": hashlib.md5, + "SHA": hashlib.sha1, + "SHA-SESS": hashlib.sha1, + "SHA256": hashlib.sha256, + "SHA256-SESS": hashlib.sha256, + "SHA-256": hashlib.sha256, + "SHA-256-SESS": hashlib.sha256, + "SHA512": hashlib.sha512, + "SHA512-SESS": hashlib.sha512, + "SHA-512": hashlib.sha512, + "SHA-512-SESS": hashlib.sha512, +} + + +# Compile the regex pattern once at module level for performance +_HEADER_PAIRS_PATTERN = re.compile( + r'(\w+)\s*=\s*(?:"((?:[^"\\]|\\.)*)"|([^\s,]+))' + # | | | | | | | | | || | + # +----|--|-|-|--|----|------|----|--||-----|--> alphanumeric key + # +--|-|-|--|----|------|----|--||-----|--> maybe whitespace + # | | | | | | | || | + # +-|-|--|----|------|----|--||-----|--> = (delimiter) + # +-|--|----|------|----|--||-----|--> maybe whitespace + # | | | | | || | + # +--|----|------|----|--||-----|--> group quoted or unquoted + # | | | | || | + # +----|------|----|--||-----|--> if quoted... 
+ # +------|----|--||-----|--> anything but " or \ + # +----|--||-----|--> escaped characters allowed + # +--||-----|--> or can be empty string + # || | + # +|-----|--> if unquoted... + # +-----|--> anything but , or + # +--> at least one char req'd +) + + +# RFC 7616: Challenge parameters to extract +CHALLENGE_FIELDS: Final[ + Tuple[Literal["realm", "nonce", "qop", "algorithm", "opaque"], ...] +] = ( + "realm", + "nonce", + "qop", + "algorithm", + "opaque", +) + +# Supported digest authentication algorithms +# Use a tuple of sorted keys for predictable documentation and error messages +SUPPORTED_ALGORITHMS: Final[Tuple[str, ...]] = tuple(sorted(DigestFunctions.keys())) + +# RFC 7616: Fields that require quoting in the Digest auth header +# These fields must be enclosed in double quotes in the Authorization header. +# Algorithm, qop, and nc are never quoted per RFC specifications. +# This frozen set is used by the template-based header construction to +# automatically determine which fields need quotes. +QUOTED_AUTH_FIELDS: Final[FrozenSet[str]] = frozenset( + {"username", "realm", "nonce", "uri", "response", "opaque", "cnonce"} +) + + +def escape_quotes(value: str) -> str: + """Escape double quotes for HTTP header values.""" + return value.replace('"', '\\"') + + +def unescape_quotes(value: str) -> str: + """Unescape double quotes in HTTP header values.""" + return value.replace('\\"', '"') + + +def parse_header_pairs(header: str) -> Dict[str, str]: + """ + Parse key-value pairs from WWW-Authenticate or similar HTTP headers. + + This function handles the complex format of WWW-Authenticate header values, + supporting both quoted and unquoted values, proper handling of commas in + quoted values, and whitespace variations per RFC 7616. 
+ + Examples of supported formats: + - key1="value1", key2=value2 + - key1 = "value1" , key2="value, with, commas" + - key1=value1,key2="value2" + - realm="example.com", nonce="12345", qop="auth" + + Args: + header: The header value string to parse + + Returns: + Dictionary mapping parameter names to their values + """ + return { + stripped_key: unescape_quotes(quoted_val) if quoted_val else unquoted_val + for key, quoted_val, unquoted_val in _HEADER_PAIRS_PATTERN.findall(header) + if (stripped_key := key.strip()) + } + + +class DigestAuthMiddleware: + """ + HTTP digest authentication middleware for aiohttp client. + + This middleware intercepts 401 Unauthorized responses containing a Digest + authentication challenge, calculates the appropriate digest credentials, + and automatically retries the request with the proper Authorization header. + + Features: + - Handles all aspects of Digest authentication handshake automatically + - Supports all standard hash algorithms: + - MD5, MD5-SESS + - SHA, SHA-SESS + - SHA256, SHA256-SESS, SHA-256, SHA-256-SESS + - SHA512, SHA512-SESS, SHA-512, SHA-512-SESS + - Supports 'auth' and 'auth-int' quality of protection modes + - Properly handles quoted strings and parameter parsing + - Includes replay attack protection with client nonce count tracking + + Standards compliance: + - RFC 7616: HTTP Digest Access Authentication (primary reference) + - RFC 2617: HTTP Authentication (deprecated by RFC 7616) + - RFC 1945: Section 11.1 (username restrictions) + + Implementation notes: + The core digest calculation is inspired by the implementation in + https://github.com/requests/requests/blob/v2.18.4/requests/auth.py + with added support for modern digest auth features and error handling. 
+ """ + + def __init__( + self, + login: str, + password: str, + ) -> None: + if login is None: + raise ValueError("None is not allowed as login value") + + if password is None: + raise ValueError("None is not allowed as password value") + + if ":" in login: + raise ValueError('A ":" is not allowed in username (RFC 1945#section-11.1)') + + self._login_str: Final[str] = login + self._login_bytes: Final[bytes] = login.encode("utf-8") + self._password_bytes: Final[bytes] = password.encode("utf-8") + + self._last_nonce_bytes = b"" + self._nonce_count = 0 + self._challenge: DigestAuthChallenge = {} + + def _encode(self, method: str, url: URL, body: Union[bytes, str]) -> str: + """ + Build digest authorization header for the current challenge. + + Args: + method: The HTTP method (GET, POST, etc.) + url: The request URL + body: The request body (used for qop=auth-int) + + Returns: + A fully formatted Digest authorization header string + + Raises: + ClientError: If the challenge is missing required parameters or + contains unsupported values + """ + challenge = self._challenge + if "realm" not in challenge: + raise ClientError( + "Malformed Digest auth challenge: Missing 'realm' parameter" + ) + + if "nonce" not in challenge: + raise ClientError( + "Malformed Digest auth challenge: Missing 'nonce' parameter" + ) + + # Empty realm values are allowed per RFC 7616 (SHOULD, not MUST, contain host name) + realm = challenge["realm"] + nonce = challenge["nonce"] + + # Empty nonce values are not allowed as they are security-critical for replay protection + if not nonce: + raise ClientError( + "Security issue: Digest auth challenge contains empty 'nonce' value" + ) + + qop_raw = challenge.get("qop", "") + algorithm = challenge.get("algorithm", "MD5").upper() + opaque = challenge.get("opaque", "") + + # Convert string values to bytes once + nonce_bytes = nonce.encode("utf-8") + realm_bytes = realm.encode("utf-8") + path = URL(url).path_qs + + # Process QoP + qop = "" + qop_bytes = 
b"" + if qop_raw: + valid_qops = {"auth", "auth-int"}.intersection( + {q.strip() for q in qop_raw.split(",") if q.strip()} + ) + if not valid_qops: + raise ClientError( + f"Digest auth error: Unsupported Quality of Protection (qop) value(s): {qop_raw}" + ) + + qop = "auth-int" if "auth-int" in valid_qops else "auth" + qop_bytes = qop.encode("utf-8") + + if algorithm not in DigestFunctions: + raise ClientError( + f"Digest auth error: Unsupported hash algorithm: {algorithm}. " + f"Supported algorithms: {', '.join(SUPPORTED_ALGORITHMS)}" + ) + hash_fn: Final = DigestFunctions[algorithm] + + def H(x: bytes) -> bytes: + """RFC 7616 Section 3: Hash function H(data) = hex(hash(data)).""" + return hash_fn(x).hexdigest().encode() + + def KD(s: bytes, d: bytes) -> bytes: + """RFC 7616 Section 3: KD(secret, data) = H(concat(secret, ":", data)).""" + return H(b":".join((s, d))) + + # Calculate A1 and A2 + A1 = b":".join((self._login_bytes, realm_bytes, self._password_bytes)) + A2 = f"{method.upper()}:{path}".encode() + if qop == "auth-int": + if isinstance(body, str): + entity_str = body.encode("utf-8", errors="replace") + else: + entity_str = body + entity_hash = H(entity_str) + A2 = b":".join((A2, entity_hash)) + + HA1 = H(A1) + HA2 = H(A2) + + # Nonce count handling + if nonce_bytes == self._last_nonce_bytes: + self._nonce_count += 1 + else: + self._nonce_count = 1 + + self._last_nonce_bytes = nonce_bytes + ncvalue = f"{self._nonce_count:08x}" + ncvalue_bytes = ncvalue.encode("utf-8") + + # Generate client nonce + cnonce = hashlib.sha1( + b"".join( + [ + str(self._nonce_count).encode("utf-8"), + nonce_bytes, + time.ctime().encode("utf-8"), + os.urandom(8), + ] + ) + ).hexdigest()[:16] + cnonce_bytes = cnonce.encode("utf-8") + + # Special handling for session-based algorithms + if algorithm.upper().endswith("-SESS"): + HA1 = H(b":".join((HA1, nonce_bytes, cnonce_bytes))) + + # Calculate the response digest + if qop: + noncebit = b":".join( + (nonce_bytes, ncvalue_bytes, 
cnonce_bytes, qop_bytes, HA2) + ) + response_digest = KD(HA1, noncebit) + else: + response_digest = KD(HA1, b":".join((nonce_bytes, HA2))) + + # Define a dict mapping of header fields to their values + # Group fields into always-present, optional, and qop-dependent + header_fields = { + # Always present fields + "username": escape_quotes(self._login_str), + "realm": escape_quotes(realm), + "nonce": escape_quotes(nonce), + "uri": path, + "response": response_digest.decode(), + "algorithm": algorithm, + } + + # Optional fields + if opaque: + header_fields["opaque"] = escape_quotes(opaque) + + # QoP-dependent fields + if qop: + header_fields["qop"] = qop + header_fields["nc"] = ncvalue + header_fields["cnonce"] = cnonce + + # Build header using templates for each field type + pairs: List[str] = [] + for field, value in header_fields.items(): + if field in QUOTED_AUTH_FIELDS: + pairs.append(f'{field}="{value}"') + else: + pairs.append(f"{field}={value}") + + return f"Digest {', '.join(pairs)}" + + def _authenticate(self, response: ClientResponse) -> bool: + """ + Takes the given response and tries digest-auth, if needed. + + Returns true if the original request must be resent. + """ + if response.status != 401: + return False + + auth_header = response.headers.get("www-authenticate", "") + if not auth_header: + return False # No authentication header present + + method, sep, headers = auth_header.partition(" ") + if not sep: + # No space found in www-authenticate header + return False # Malformed auth header, missing scheme separator + + if method.lower() != "digest": + # Not a digest auth challenge (could be Basic, Bearer, etc.) 
+ return False + + if not headers: + # We have a digest scheme but no parameters + return False # Malformed digest header, missing parameters + + # We have a digest auth header with content + if not (header_pairs := parse_header_pairs(headers)): + # Failed to parse any key-value pairs + return False # Malformed digest header, no valid parameters + + # Extract challenge parameters + self._challenge = {} + for field in CHALLENGE_FIELDS: + if value := header_pairs.get(field): + self._challenge[field] = value + + # Return True only if we found at least one challenge parameter + return bool(self._challenge) + + async def __call__( + self, request: ClientRequest, handler: ClientHandlerType + ) -> ClientResponse: + """Run the digest auth middleware.""" + response = None + for retry_count in range(2): + # Apply authorization header if we have a challenge (on second attempt) + if retry_count > 0: + request.headers[hdrs.AUTHORIZATION] = self._encode( + request.method, request.url, request.body + ) + + # Send the request + response = await handler(request) + + # Check if we need to authenticate + if not self._authenticate(response): + break + elif retry_count < 1: + response.release() # Release the response to enable connection reuse on retry + + # At this point, response is guaranteed to be defined + assert response is not None + return response diff --git a/docs/client_advanced.rst b/docs/client_advanced.rst index 107141e69be..823ac45a8c7 100644 --- a/docs/client_advanced.rst +++ b/docs/client_advanced.rst @@ -67,6 +67,26 @@ argument. An instance of :class:`BasicAuth` can be passed in like this:: async with ClientSession(auth=auth) as session: ... 
+For HTTP digest authentication, use the :class:`DigestAuthMiddleware` client middleware:: + + from aiohttp import ClientSession, DigestAuthMiddleware + + # Create the middleware with your credentials + digest_auth = DigestAuthMiddleware(login="user", password="password") + + # Pass it to the ClientSession as a tuple + async with ClientSession(middlewares=(digest_auth,)) as session: + # The middleware will automatically handle auth challenges + async with session.get("https://example.com/protected") as resp: + print(await resp.text()) + +The :class:`DigestAuthMiddleware` implements HTTP Digest Authentication according to RFC 7616, +providing a more secure alternative to Basic Authentication. It supports all +standard hash algorithms including MD5, SHA, SHA-256, SHA-512 and their session +variants, as well as both 'auth' and 'auth-int' quality of protection (qop) options. +The middleware automatically handles the authentication flow by intercepting 401 responses +and retrying with proper credentials. + Note that if the request is redirected and the redirect URL contains credentials, those credentials will supersede any previously set credentials. In other words, if ``http://user@example.com`` redirects to diff --git a/docs/client_reference.rst b/docs/client_reference.rst index 84e2f0c7014..030e07d9ef4 100644 --- a/docs/client_reference.rst +++ b/docs/client_reference.rst @@ -1992,6 +1992,7 @@ Utilities .. versionadded:: 3.2 + .. class:: BasicAuth(login, password='', encoding='latin1') HTTP basic authentication helper. @@ -2032,6 +2033,34 @@ Utilities :return: encoded authentication data, :class:`str`. + +.. class:: DigestAuthMiddleware(login, password) + + HTTP digest authentication client middleware. + + :param str login: login + :param str password: password + + This middleware supports HTTP digest authentication with both `auth` and + `auth-int` quality of protection (qop) modes, and a variety of hashing algorithms. 
+ + It automatically handles the digest authentication handshake by: + + - Parsing 401 Unauthorized responses with `WWW-Authenticate: Digest` headers + - Generating appropriate `Authorization: Digest` headers on retry + - Tracking the nonce count and the latest challenge data on the middleware instance + + Usage:: + + digest_auth_middleware = DigestAuthMiddleware(login="user", password="pass") + async with ClientSession(middlewares=(digest_auth_middleware,)) as session: + async with session.get("http://protected.example.com") as resp: + # The middleware automatically handles the digest auth handshake + assert resp.status == 200 + + .. versionadded:: 3.12 + + .. class:: CookieJar(*, unsafe=False, quote_cookie=True, treat_as_secure_origin = []) The cookie jar instance is available as :attr:`ClientSession.cookie_jar`. diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 16c8aa789e9..db6c500d5f7 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -252,6 +252,7 @@ pyflakes pyright pytest Pytest +qop Quickstart quote’s rc diff --git a/examples/digest_auth_qop_auth.py b/examples/digest_auth_qop_auth.py new file mode 100644 index 00000000000..508f444e9f9 --- /dev/null +++ b/examples/digest_auth_qop_auth.py @@ -0,0 +1,68 @@ +#!/usr/bin/env python3 +""" +Example of using digest authentication middleware with aiohttp client. + +This example shows how to use the DigestAuthMiddleware from +aiohttp.client_middleware_digest_auth to authenticate with a server +that requires digest authentication with different qop options. + +In this case, it connects to httpbin.org's digest auth endpoint. 
+""" + +import asyncio +from itertools import product + +from yarl import URL + +from aiohttp import ClientSession +from aiohttp.client_middleware_digest_auth import DigestAuthMiddleware + +# Define QOP options available +QOP_OPTIONS = ["auth", "auth-int"] + +# Algorithms supported by httpbin.org +ALGORITHMS = ["MD5", "SHA-256", "SHA-512"] + +# Username and password for testing +USERNAME = "my" +PASSWORD = "dog" + +# All combinations of QOP options and algorithms +TEST_COMBINATIONS = list(product(QOP_OPTIONS, ALGORITHMS)) + + +async def main() -> None: + # Create a DigestAuthMiddleware instance with appropriate credentials + digest_auth = DigestAuthMiddleware(login=USERNAME, password=PASSWORD) + + # Create a client session with the digest auth middleware + async with ClientSession(middlewares=(digest_auth,)) as session: + # Test each combination of QOP and algorithm + for qop, algorithm in TEST_COMBINATIONS: + print(f"\n\n=== Testing with qop={qop}, algorithm={algorithm} ===\n") + + url = URL( + f"https://httpbin.org/digest-auth/{qop}/{USERNAME}/{PASSWORD}/{algorithm}" + ) + + async with session.get(url) as resp: + print(f"Status: {resp.status}") + print(f"Headers: {resp.headers}") + + # Parse the JSON response + json_response = await resp.json() + print(f"Response: {json_response}") + + # Verify authentication was successful + if resp.status == 200: + print("\nAuthentication successful!") + print(f"Authenticated user: {json_response.get('user')}") + print( + f"Authentication method: {json_response.get('authenticated')}" + ) + else: + print("\nAuthentication failed.") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/tests/test_client_middleware_digest_auth.py b/tests/test_client_middleware_digest_auth.py new file mode 100644 index 00000000000..26118288913 --- /dev/null +++ b/tests/test_client_middleware_digest_auth.py @@ -0,0 +1,801 @@ +"""Test digest authentication middleware for aiohttp client.""" + +from hashlib import md5, sha1 +from typing 
import Generator, Union +from unittest import mock + +import pytest +from yarl import URL + +from aiohttp import ClientSession, hdrs +from aiohttp.client_exceptions import ClientError +from aiohttp.client_middleware_digest_auth import ( + DigestAuthChallenge, + DigestAuthMiddleware, + DigestFunctions, + escape_quotes, + parse_header_pairs, + unescape_quotes, +) +from aiohttp.client_reqrep import ClientResponse +from aiohttp.pytest_plugin import AiohttpServer +from aiohttp.web import Application, Request, Response + + +@pytest.fixture +def digest_auth_mw() -> DigestAuthMiddleware: + return DigestAuthMiddleware("user", "pass") + + +@pytest.fixture +def basic_challenge() -> DigestAuthChallenge: + """Return a basic digest auth challenge with required fields only.""" + return DigestAuthChallenge(realm="test", nonce="abc") + + +@pytest.fixture +def complete_challenge() -> DigestAuthChallenge: + """Return a complete digest auth challenge with all fields.""" + return DigestAuthChallenge( + realm="test", nonce="abc", qop="auth", algorithm="MD5", opaque="xyz" + ) + + +@pytest.fixture +def qop_challenge() -> DigestAuthChallenge: + """Return a digest auth challenge with qop field.""" + return DigestAuthChallenge(realm="test", nonce="abc", qop="auth") + + +@pytest.fixture +def no_qop_challenge() -> DigestAuthChallenge: + """Return a digest auth challenge without qop.""" + return DigestAuthChallenge(realm="test-realm", nonce="testnonce", algorithm="MD5") + + +@pytest.fixture +def auth_mw_with_challenge( + digest_auth_mw: DigestAuthMiddleware, complete_challenge: DigestAuthChallenge +) -> DigestAuthMiddleware: + """Return a digest auth middleware with pre-set challenge.""" + digest_auth_mw._challenge = complete_challenge + digest_auth_mw._last_nonce_bytes = complete_challenge["nonce"].encode("utf-8") + digest_auth_mw._nonce_count = 0 + return digest_auth_mw + + +@pytest.fixture +def mock_sha1_digest() -> Generator[mock.MagicMock, None, None]: + """Mock SHA1 to return a 
predictable value for testing.""" + mock_digest = mock.MagicMock(spec=sha1()) + mock_digest.hexdigest.return_value = "deadbeefcafebabe" + with mock.patch("hashlib.sha1", return_value=mock_digest) as patched: + yield patched + + +@pytest.fixture +def mock_md5_digest() -> Generator[mock.MagicMock, None, None]: + """Mock MD5 to return a predictable value for testing.""" + mock_digest = mock.MagicMock(spec=md5()) + mock_digest.hexdigest.return_value = "abcdef0123456789" + with mock.patch("hashlib.md5", return_value=mock_digest) as patched: + yield patched + + +@pytest.mark.parametrize( + ("response_status", "headers", "expected_result", "expected_challenge"), + [ + # Valid digest with all fields + ( + 401, + { + "www-authenticate": 'Digest realm="test", nonce="abc", ' + 'qop="auth", opaque="xyz", algorithm=MD5' + }, + True, + { + "realm": "test", + "nonce": "abc", + "qop": "auth", + "algorithm": "MD5", + "opaque": "xyz", + }, + ), + # Valid digest without opaque + ( + 401, + {"www-authenticate": 'Digest realm="test", nonce="abc", qop="auth"'}, + True, + {"realm": "test", "nonce": "abc", "qop": "auth"}, + ), + # Non-401 status + (200, {}, False, {}), # No challenge should be set + ], +) +async def test_authenticate_scenarios( + digest_auth_mw: DigestAuthMiddleware, + response_status: int, + headers: dict[str, str], + expected_result: bool, + expected_challenge: dict[str, str], +) -> None: + """Test different authentication scenarios.""" + response = mock.MagicMock(spec=ClientResponse) + response.status = response_status + response.headers = headers + + result = digest_auth_mw._authenticate(response) + assert result == expected_result + + if expected_result: + challenge_dict = dict(digest_auth_mw._challenge) + for key, value in expected_challenge.items(): + assert challenge_dict[key] == value + + +@pytest.mark.parametrize( + ("challenge", "expected_error"), + [ + ( + DigestAuthChallenge(), + "Malformed Digest auth challenge: Missing 'realm' parameter", + ), + ( + 
DigestAuthChallenge(nonce="abc"), + "Malformed Digest auth challenge: Missing 'realm' parameter", + ), + ( + DigestAuthChallenge(realm="test"), + "Malformed Digest auth challenge: Missing 'nonce' parameter", + ), + ( + DigestAuthChallenge(realm="test", nonce=""), + "Security issue: Digest auth challenge contains empty 'nonce' value", + ), + ], +) +def test_encode_validation_errors( + digest_auth_mw: DigestAuthMiddleware, + challenge: DigestAuthChallenge, + expected_error: str, +) -> None: + """Test validation errors when encoding digest auth headers.""" + digest_auth_mw._challenge = challenge + with pytest.raises(ClientError, match=expected_error): + digest_auth_mw._encode("GET", URL("http://example.com/resource"), "") + + +def test_encode_digest_with_md5(auth_mw_with_challenge: DigestAuthMiddleware) -> None: + header = auth_mw_with_challenge._encode( + "GET", URL("http://example.com/resource"), "" + ) + assert header.startswith("Digest ") + assert 'username="user"' in header + assert "algorithm=MD5" in header + + +@pytest.mark.parametrize( + "algorithm", ["MD5-SESS", "SHA-SESS", "SHA-256-SESS", "SHA-512-SESS"] +) +def test_encode_digest_with_sess_algorithms( + digest_auth_mw: DigestAuthMiddleware, + qop_challenge: DigestAuthChallenge, + algorithm: str, +) -> None: + """Test that all session-based digest algorithms work correctly.""" + # Create a modified challenge with the test algorithm + challenge = qop_challenge.copy() + challenge["algorithm"] = algorithm + digest_auth_mw._challenge = challenge + + header = digest_auth_mw._encode("GET", URL("http://example.com/resource"), "") + assert f"algorithm={algorithm}" in header + + +def test_encode_unsupported_algorithm( + digest_auth_mw: DigestAuthMiddleware, basic_challenge: DigestAuthChallenge +) -> None: + """Test that unsupported algorithm raises ClientError.""" + # Create a modified challenge with an unsupported algorithm + challenge = basic_challenge.copy() + challenge["algorithm"] = "UNSUPPORTED" + 
digest_auth_mw._challenge = challenge + + with pytest.raises(ClientError, match="Unsupported hash algorithm"): + digest_auth_mw._encode("GET", URL("http://example.com/resource"), "") + + +def test_invalid_qop_rejected( + digest_auth_mw: DigestAuthMiddleware, basic_challenge: DigestAuthChallenge +) -> None: + """Test that invalid Quality of Protection values are rejected.""" + # Use bad QoP value to trigger error + challenge = basic_challenge.copy() + challenge["qop"] = "badvalue" + challenge["algorithm"] = "MD5" + digest_auth_mw._challenge = challenge + + # This should raise an error about unsupported QoP + with pytest.raises(ClientError, match="Unsupported Quality of Protection"): + digest_auth_mw._encode("GET", URL("http://example.com"), "") + + +def compute_expected_digest( + algorithm: str, + username: str, + password: str, + realm: str, + nonce: str, + uri: str, + method: str, + qop: str, + nc: str, + cnonce: str, + body: str = "", +) -> str: + hash_fn = DigestFunctions[algorithm] + + def H(x: str) -> str: + return hash_fn(x.encode()).hexdigest() + + def KD(secret: str, data: str) -> str: + return H(f"{secret}:{data}") + + A1 = f"{username}:{realm}:{password}" + HA1 = H(A1) + + if algorithm.upper().endswith("-SESS"): + HA1 = H(f"{HA1}:{nonce}:{cnonce}") + + A2 = f"{method}:{uri}" + if "auth-int" in qop: + entity_hash = H(body) + A2 = f"{A2}:{entity_hash}" + HA2 = H(A2) + + if qop: + return KD(HA1, f"{nonce}:{nc}:{cnonce}:{qop}:{HA2}") + else: + return KD(HA1, f"{nonce}:{HA2}") + + +@pytest.mark.parametrize("qop", ["auth", "auth-int", "auth,auth-int", ""]) +@pytest.mark.parametrize("algorithm", sorted(DigestFunctions.keys())) +@pytest.mark.parametrize( + ("body", "body_str"), + [ + ("this is a body", "this is a body"), # String case + (b"this is a body", "this is a body"), # Bytes case + ], +) +def test_digest_response_exact_match( + qop: str, + algorithm: str, + body: Union[str, bytes], + body_str: str, + mock_sha1_digest: mock.MagicMock, +) -> None: + # Fixed 
input values + login = "user" + password = "pass" + realm = "example.com" + nonce = "abc123nonce" + cnonce = "deadbeefcafebabe" + nc = 1 + ncvalue = f"{nc+1:08x}" + method = "GET" + uri = "/secret" + qop = "auth-int" if "auth-int" in qop else "auth" + + # Create the auth object + auth = DigestAuthMiddleware(login, password) + auth._challenge = DigestAuthChallenge( + realm=realm, nonce=nonce, qop=qop, algorithm=algorithm + ) + auth._last_nonce_bytes = nonce.encode("utf-8") + auth._nonce_count = nc + + header = auth._encode(method, URL(f"http://host{uri}"), body) + + # Get expected digest + expected = compute_expected_digest( + algorithm=algorithm, + username=login, + password=password, + realm=realm, + nonce=nonce, + uri=uri, + method=method, + qop=qop, + nc=ncvalue, + cnonce=cnonce, + body=body_str, + ) + + # Check that the response digest is exactly correct + assert f'response="{expected}"' in header + + +@pytest.mark.parametrize( + ("header", "expected_result"), + [ + # Normal quoted values + ( + 'realm="example.com", nonce="12345", qop="auth"', + {"realm": "example.com", "nonce": "12345", "qop": "auth"}, + ), + # Unquoted values + ( + "realm=example.com, nonce=12345, qop=auth", + {"realm": "example.com", "nonce": "12345", "qop": "auth"}, + ), + # Mixed quoted/unquoted with commas in quoted values + ( + 'realm="ex,ample", nonce=12345, qop="auth", domain="/test"', + { + "realm": "ex,ample", + "nonce": "12345", + "qop": "auth", + "domain": "/test", + }, + ), + # Header with scheme + ( + 'Digest realm="example.com", nonce="12345", qop="auth"', + {"realm": "example.com", "nonce": "12345", "qop": "auth"}, + ), + # No spaces after commas + ( + 'realm="test",nonce="123",qop="auth"', + {"realm": "test", "nonce": "123", "qop": "auth"}, + ), + # Extra whitespace + ( + 'realm = "test" , nonce = "123"', + {"realm": "test", "nonce": "123"}, + ), + # Escaped quotes + ( + 'realm="test\\"realm", nonce="123"', + {"realm": 'test"realm', "nonce": "123"}, + ), + # Single quotes 
(treated as regular chars) + ( + "realm='test', nonce=123", + {"realm": "'test'", "nonce": "123"}, + ), + # Empty header + ("", {}), + ], + ids=[ + "fully_quoted_header", + "unquoted_header", + "mixed_quoted_unquoted_with_commas", + "header_with_scheme", + "no_spaces_after_commas", + "extra_whitespace", + "escaped_quotes", + "single_quotes_as_regular_chars", + "empty_header", + ], +) +def test_parse_header_pairs(header: str, expected_result: dict[str, str]) -> None: + """Test parsing HTTP header pairs with various formats.""" + result = parse_header_pairs(header) + assert result == expected_result + + +def test_digest_auth_middleware_callable(digest_auth_mw: DigestAuthMiddleware) -> None: + """Test that DigestAuthMiddleware is callable.""" + assert callable(digest_auth_mw) + + +def test_middleware_invalid_login() -> None: + """Test that invalid login values raise errors.""" + with pytest.raises(ValueError, match="None is not allowed as login value"): + DigestAuthMiddleware(None, "pass") # type: ignore[arg-type] + + with pytest.raises(ValueError, match="None is not allowed as password value"): + DigestAuthMiddleware("user", None) # type: ignore[arg-type] + + with pytest.raises(ValueError, match=r"A \":\" is not allowed in username"): + DigestAuthMiddleware("user:name", "pass") + + +def test_escaping_quotes_in_auth_header() -> None: + """Test that double quotes are properly escaped in auth header.""" + auth = DigestAuthMiddleware('user"with"quotes', "pass") + auth._challenge = DigestAuthChallenge( + realm='realm"with"quotes', + nonce='nonce"with"quotes', + qop="auth", + algorithm="MD5", + opaque='opaque"with"quotes', + ) + + header = auth._encode("GET", URL("http://example.com/path"), "") + + # Check that quotes are escaped in the header + assert 'username="user\\"with\\"quotes"' in header + assert 'realm="realm\\"with\\"quotes"' in header + assert 'nonce="nonce\\"with\\"quotes"' in header + assert 'opaque="opaque\\"with\\"quotes"' in header + + +def 
test_template_based_header_construction( + auth_mw_with_challenge: DigestAuthMiddleware, + mock_sha1_digest: mock.MagicMock, + mock_md5_digest: mock.MagicMock, +) -> None: + """Test that the template-based header construction works correctly.""" + header = auth_mw_with_challenge._encode("GET", URL("http://example.com/test"), "") + + # Split the header into scheme and parameters + scheme, params_str = header.split(" ", 1) + assert scheme == "Digest" + + # Parse the parameters into a dictionary + params = { + key: value[1:-1] if value.startswith('"') and value.endswith('"') else value + for key, value in (param.split("=", 1) for param in params_str.split(", ")) + } + + # Check all required fields are present + assert "username" in params + assert "realm" in params + assert "nonce" in params + assert "uri" in params + assert "response" in params + assert "algorithm" in params + assert "qop" in params + assert "nc" in params + assert "cnonce" in params + assert "opaque" in params + + # Check that fields are quoted correctly + quoted_fields = [ + "username", + "realm", + "nonce", + "uri", + "response", + "opaque", + "cnonce", + ] + unquoted_fields = ["algorithm", "qop", "nc"] + + # Re-check the original header for proper quoting + for field in quoted_fields: + assert f'{field}="{params[field]}"' in header + + for field in unquoted_fields: + assert f"{field}={params[field]}" in header + + # Check specific values + assert params["username"] == "user" + assert params["realm"] == "test" + assert params["algorithm"] == "MD5" + assert params["nc"] == "00000001" # nonce_count = 1 (incremented from 0) + assert params["uri"] == "/test" # path component of URL + + +@pytest.mark.parametrize( + ("test_string", "expected_escaped", "description"), + [ + ('value"with"quotes', 'value\\"with\\"quotes', "Basic string with quotes"), + ("", "", "Empty string"), + ("no quotes", "no quotes", "String without quotes"), + ('with"one"quote', 'with\\"one\\"quote', "String with one quoted 
segment"), + ( + 'many"quotes"in"string', + 'many\\"quotes\\"in\\"string', + "String with multiple quoted segments", + ), + ('""', '\\"\\"', "Just double quotes"), + ('"', '\\"', "Single double quote"), + ('already\\"escaped', 'already\\\\"escaped', "Already escaped quotes"), + ], +) +def test_quote_escaping_functions( + test_string: str, expected_escaped: str, description: str +) -> None: + """Test that escape_quotes and unescape_quotes work correctly.""" + # Test escaping + escaped = escape_quotes(test_string) + assert escaped == expected_escaped + + # Test unescaping (should return to original) + unescaped = unescape_quotes(escaped) + assert unescaped == test_string + + # Test that they're inverse operations + assert unescape_quotes(escape_quotes(test_string)) == test_string + + +async def test_middleware_retry_on_401( + aiohttp_server: AiohttpServer, digest_auth_mw: DigestAuthMiddleware +) -> None: + """Test that the middleware retries on 401 errors.""" + request_count = 0 + + async def handler(request: Request) -> Response: + nonlocal request_count + request_count += 1 + + if request_count == 1: + # First request returns 401 with digest challenge + challenge = 'Digest realm="test", nonce="abc123", qop="auth", algorithm=MD5' + return Response( + status=401, + headers={"WWW-Authenticate": challenge}, + text="Unauthorized", + ) + + # Second request should have Authorization header + auth_header = request.headers.get(hdrs.AUTHORIZATION) + if auth_header and auth_header.startswith("Digest "): + # Return success response + return Response(text="OK") + + # This branch should not be reached in the tests + assert False, "This branch should not be reached" + + app = Application() + app.router.add_get("/", handler) + server = await aiohttp_server(app) + + async with ClientSession(middlewares=(digest_auth_mw,)) as session: + async with session.get(server.make_url("/")) as resp: + assert resp.status == 200 + text_content = await resp.text() + assert text_content == "OK" + 
+ assert request_count == 2 # Initial request + retry with auth + + +async def test_digest_auth_no_qop( + aiohttp_server: AiohttpServer, + digest_auth_mw: DigestAuthMiddleware, + no_qop_challenge: DigestAuthChallenge, + mock_sha1_digest: mock.MagicMock, +) -> None: + """Test digest auth with a server that doesn't provide a QoP parameter.""" + request_count = 0 + realm = no_qop_challenge["realm"] + nonce = no_qop_challenge["nonce"] + algorithm = no_qop_challenge["algorithm"] + username = "user" + password = "pass" + uri = "/" + + async def handler(request: Request) -> Response: + nonlocal request_count + request_count += 1 + + if request_count == 1: + # First request returns 401 with digest challenge without qop + challenge = ( + f'Digest realm="{realm}", nonce="{nonce}", algorithm={algorithm}' + ) + return Response( + status=401, + headers={"WWW-Authenticate": challenge}, + text="Unauthorized", + ) + + # Second request should have Authorization header + auth_header = request.headers.get(hdrs.AUTHORIZATION) + assert auth_header and auth_header.startswith("Digest ") + + # Successful auth should have no qop param + assert "qop=" not in auth_header + assert "nc=" not in auth_header + assert "cnonce=" not in auth_header + + expected_digest = compute_expected_digest( + algorithm=algorithm, + username=username, + password=password, + realm=realm, + nonce=nonce, + uri=uri, + method="GET", + qop="", # This is the key part - explicitly setting qop="" + nc="", # Not needed for non-qop digest + cnonce="", # Not needed for non-qop digest + ) + # We mock the cnonce, so we can check the expected digest + assert expected_digest in auth_header + + return Response(text="OK") + + app = Application() + app.router.add_get("/", handler) + server = await aiohttp_server(app) + + async with ClientSession(middlewares=(digest_auth_mw,)) as session: + async with session.get(server.make_url("/")) as resp: + assert resp.status == 200 + text_content = await resp.text() + assert text_content == 
"OK" + + assert request_count == 2 # Initial request + retry with auth + + +async def test_digest_auth_without_opaque( + aiohttp_server: AiohttpServer, digest_auth_mw: DigestAuthMiddleware +) -> None: + """Test digest auth with a server that doesn't provide an opaque parameter.""" + request_count = 0 + + async def handler(request: Request) -> Response: + nonlocal request_count + request_count += 1 + + if request_count == 1: + # First request returns 401 with digest challenge without opaque + challenge = ( + 'Digest realm="test-realm", nonce="testnonce", ' + 'qop="auth", algorithm=MD5' + ) + return Response( + status=401, + headers={"WWW-Authenticate": challenge}, + text="Unauthorized", + ) + + # Second request should have Authorization header + auth_header = request.headers.get(hdrs.AUTHORIZATION) + assert auth_header and auth_header.startswith("Digest ") + # Successful auth should have no opaque param + assert "opaque=" not in auth_header + + return Response(text="OK") + + app = Application() + app.router.add_get("/", handler) + server = await aiohttp_server(app) + + async with ClientSession(middlewares=(digest_auth_mw,)) as session: + async with session.get(server.make_url("/")) as resp: + assert resp.status == 200 + text_content = await resp.text() + assert text_content == "OK" + + assert request_count == 2 # Initial request + retry with auth + + +@pytest.mark.parametrize( + "www_authenticate", + [ + None, + "DigestWithoutSpace", + 'Basic realm="test"', + "Digest ", + "Digest =invalid, format", + ], +) +async def test_auth_header_no_retry( + aiohttp_server: AiohttpServer, + www_authenticate: str, + digest_auth_mw: DigestAuthMiddleware, +) -> None: + """Test that middleware doesn't retry with invalid WWW-Authenticate headers.""" + request_count = 0 + + async def handler(request: Request) -> Response: + nonlocal request_count + request_count += 1 + + # First (and only) request returns 401 + headers = {} + if www_authenticate is not None: + 
headers["WWW-Authenticate"] = www_authenticate + + # Use a custom HTTPUnauthorized instead of the helper since + # we're specifically testing malformed headers + return Response(status=401, headers=headers, text="Unauthorized") + + app = Application() + app.router.add_get("/", handler) + server = await aiohttp_server(app) + + async with ClientSession(middlewares=(digest_auth_mw,)) as session: + async with session.get(server.make_url("/")) as resp: + assert resp.status == 401 + + # No retry should happen + assert request_count == 1 + + +async def test_direct_success_no_auth_needed( + aiohttp_server: AiohttpServer, digest_auth_mw: DigestAuthMiddleware +) -> None: + """Test middleware with a direct 200 response with no auth challenge.""" + request_count = 0 + + async def handler(request: Request) -> Response: + nonlocal request_count + request_count += 1 + + # Return success without auth challenge + return Response(text="OK") + + app = Application() + app.router.add_get("/", handler) + server = await aiohttp_server(app) + + async with ClientSession(middlewares=(digest_auth_mw,)) as session: + async with session.get(server.make_url("/")) as resp: + text = await resp.text() + assert resp.status == 200 + assert text == "OK" + + # Verify only one request was made + assert request_count == 1 + + +async def test_no_retry_on_second_401( + aiohttp_server: AiohttpServer, digest_auth_mw: DigestAuthMiddleware +) -> None: + """Test digest auth does not retry on second 401.""" + request_count = 0 + + async def handler(request: Request) -> Response: + nonlocal request_count + request_count += 1 + + # Always return 401 challenge + challenge = 'Digest realm="test", nonce="abc123", qop="auth", algorithm=MD5' + return Response( + status=401, + headers={"WWW-Authenticate": challenge}, + text="Unauthorized", + ) + + app = Application() + app.router.add_get("/", handler) + server = await aiohttp_server(app) + + # Create a session that uses the digest auth middleware + async with 
ClientSession(middlewares=(digest_auth_mw,)) as session: + async with session.get(server.make_url("/")) as resp: + await resp.text() + assert resp.status == 401 + + # Verify we made exactly 2 requests (initial + 1 retry) + assert request_count == 2 + + +@pytest.mark.parametrize( + ("status", "headers", "expected"), + [ + (200, {}, False), + (401, {"www-authenticate": ""}, False), + (401, {"www-authenticate": "DigestWithoutSpace"}, False), + (401, {"www-authenticate": "Basic realm=test"}, False), + (401, {"www-authenticate": "Digest "}, False), + (401, {"www-authenticate": "Digest =invalid, format"}, False), + ], + ids=[ + "different_status_code", + "empty_www_authenticate_header", + "no_space_after_scheme", + "different_scheme", + "empty_parameters", + "malformed_parameters", + ], +) +def test_authenticate_with_malformed_headers( + digest_auth_mw: DigestAuthMiddleware, + status: int, + headers: dict[str, str], + expected: bool, +) -> None: + """Test _authenticate method with various edge cases.""" + response = mock.MagicMock(spec=ClientResponse) + response.status = status + response.headers = headers + + result = digest_auth_mw._authenticate(response) + assert result == expected