-
Notifications
You must be signed in to change notification settings - Fork 271
feat(python): eliminate SSE dependency 'httpx_sse' by hard-forking into core_utilities #9784
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
848a8b3
574461e
fe130b9
5ea9466
c29f4ae
8e6ad98
130ae61
9a8c4e8
6d4128e
dd4c6f4
f6491bb
7951ea7
56d26d3
7188132
ec6d9cc
2089bce
2ba8ea5
77ef32c
37b1c9b
1b8e8c8
6f55245
c837357
86accef
1485998
385ee44
6388b30
abf5eea
0ad05ab
f25543f
f384986
e0266e0
d2da4a7
706989d
00aa224
c271c1e
2c876f2
a4bb7d5
b5c5f7f
2050678
a699f38
35b45dc
49cbea8
c7f85c5
940fbca
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,14 @@ | ||
| from ._api import EventSource, aconnect_sse, connect_sse | ||
| from ._exceptions import SSEError | ||
| from ._models import ServerSentEvent | ||
|
|
||
| __version__ = "0.4.1" | ||
|
|
||
| __all__ = [ | ||
| "__version__", | ||
| "EventSource", | ||
| "connect_sse", | ||
| "aconnect_sse", | ||
| "ServerSentEvent", | ||
| "SSEError", | ||
| ] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,110 @@ | ||
| import re | ||
| from contextlib import asynccontextmanager, contextmanager | ||
| from typing import Any, AsyncGenerator, AsyncIterator, Iterator, cast | ||
|
|
||
| import httpx | ||
| from ._decoders import SSEDecoder | ||
| from ._exceptions import SSEError | ||
| from ._models import ServerSentEvent | ||
|
|
||
|
|
||
| class EventSource: | ||
| def __init__(self, response: httpx.Response) -> None: | ||
| self._response = response | ||
|
|
||
| def _check_content_type(self) -> None: | ||
| content_type = self._response.headers.get("content-type", "").partition(";")[0] | ||
| if "text/event-stream" not in content_type: | ||
| raise SSEError( | ||
| f"Expected response header Content-Type to contain 'text/event-stream', got {content_type!r}" | ||
| ) | ||
|
|
||
| def _get_charset(self) -> str: | ||
| """Extract charset from Content-Type header, fallback to UTF-8.""" | ||
| content_type = self._response.headers.get("content-type", "") | ||
|
|
||
| # Parse charset parameter using regex | ||
| charset_match = re.search(r"charset=([^;\s]+)", content_type, re.IGNORECASE) | ||
| if charset_match: | ||
| charset = charset_match.group(1).strip("\"'") | ||
| # Validate that it's a known encoding | ||
| try: | ||
| # Test if the charset is valid by trying to encode/decode | ||
| "test".encode(charset).decode(charset) | ||
| return charset | ||
| except (LookupError, UnicodeError): | ||
| # If charset is invalid, fall back to UTF-8 | ||
| pass | ||
|
|
||
| # Default to UTF-8 if no charset specified or invalid charset | ||
| return "utf-8" | ||
|
|
||
| @property | ||
| def response(self) -> httpx.Response: | ||
| return self._response | ||
|
|
||
| def iter_sse(self) -> Iterator[ServerSentEvent]: | ||
| self._check_content_type() | ||
| decoder = SSEDecoder() | ||
| charset = self._get_charset() | ||
|
|
||
| buffer = "" | ||
| for chunk in self._response.iter_bytes(): | ||
| # Decode chunk using detected charset | ||
| text_chunk = chunk.decode(charset, errors="replace") | ||
| buffer += text_chunk | ||
|
|
||
| # Process complete lines | ||
| while "\n" in buffer: | ||
| line, buffer = buffer.split("\n", 1) | ||
| line = line.rstrip("\r") | ||
| sse = decoder.decode(line) | ||
| # when we reach a "\n\n" => line = '' | ||
| # => decoder will attempt to return an SSE Event | ||
| if sse is not None: | ||
| yield sse | ||
|
|
||
| # Process any remaining data in buffer | ||
| if buffer.strip(): | ||
| line = buffer.rstrip("\r") | ||
| sse = decoder.decode(line) | ||
| if sse is not None: | ||
| yield sse | ||
|
|
||
| async def aiter_sse(self) -> AsyncGenerator[ServerSentEvent, None]: | ||
| self._check_content_type() | ||
| decoder = SSEDecoder() | ||
| lines = cast(AsyncGenerator[str, None], self._response.aiter_lines()) | ||
| try: | ||
| async for line in lines: | ||
| line = line.rstrip("\n") | ||
| sse = decoder.decode(line) | ||
| if sse is not None: | ||
| yield sse | ||
| finally: | ||
| await lines.aclose() | ||
|
|
||
|
|
||
| @contextmanager | ||
| def connect_sse(client: httpx.Client, method: str, url: str, **kwargs: Any) -> Iterator[EventSource]: | ||
| headers = kwargs.pop("headers", {}) | ||
| headers["Accept"] = "text/event-stream" | ||
| headers["Cache-Control"] = "no-store" | ||
|
|
||
| with client.stream(method, url, headers=headers, **kwargs) as response: | ||
| yield EventSource(response) | ||
|
|
||
|
|
||
| @asynccontextmanager | ||
| async def aconnect_sse( | ||
| client: httpx.AsyncClient, | ||
| method: str, | ||
| url: str, | ||
| **kwargs: Any, | ||
| ) -> AsyncIterator[EventSource]: | ||
| headers = kwargs.pop("headers", {}) | ||
| headers["Accept"] = "text/event-stream" | ||
| headers["Cache-Control"] = "no-store" | ||
|
|
||
| async with client.stream(method, url, headers=headers, **kwargs) as response: | ||
| yield EventSource(response) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,59 @@ | ||
| from typing import List, Optional | ||
|
|
||
| from ._models import ServerSentEvent | ||
|
|
||
|
|
||
| class SSEDecoder: | ||
| def __init__(self) -> None: | ||
| self._event = "" | ||
| self._data: List[str] = [] | ||
| self._last_event_id = "" | ||
| self._retry: Optional[int] = None | ||
|
|
||
| def decode(self, line: str) -> Optional[ServerSentEvent]: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The decode method should strip any trailing '\r' characters from the input line before processing it, similar to how it's done in the Spotted by Diamond |
||
| # See: https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation # noqa: E501 | ||
|
|
||
| if not line: | ||
| if not self._event and not self._data and not self._last_event_id and self._retry is None: | ||
| return None | ||
|
|
||
| sse = ServerSentEvent( | ||
| event=self._event, | ||
| data="\n".join(self._data), | ||
| id=self._last_event_id, | ||
| retry=self._retry, | ||
| ) | ||
|
|
||
| # NOTE: as per the SSE spec, do not reset last_event_id. | ||
| self._event = "" | ||
| self._data = [] | ||
| self._retry = None | ||
|
|
||
| return sse | ||
|
|
||
| if line.startswith(":"): | ||
| return None | ||
|
|
||
| fieldname, _, value = line.partition(":") | ||
|
|
||
| if value.startswith(" "): | ||
| value = value[1:] | ||
|
|
||
| if fieldname == "event": | ||
| self._event = value | ||
| elif fieldname == "data": | ||
| self._data.append(value) | ||
| elif fieldname == "id": | ||
| if "\0" in value: | ||
| pass | ||
| else: | ||
| self._last_event_id = value | ||
| elif fieldname == "retry": | ||
| try: | ||
| self._retry = int(value) | ||
| except (TypeError, ValueError): | ||
| pass | ||
| else: | ||
| pass # Field is ignored. | ||
|
|
||
| return None | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,5 @@ | ||
| import httpx | ||
|
|
||
|
|
||
| class SSEError(httpx.TransportError): | ||
| pass |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,15 @@ | ||
| import json | ||
| from dataclasses import dataclass | ||
| from typing import Any, Optional | ||
|
|
||
|
|
||
| @dataclass(frozen=True) | ||
| class ServerSentEvent: | ||
| event: str = "message" | ||
| data: str = "" | ||
| id: str = "" | ||
| retry: Optional[int] = None | ||
|
|
||
| def json(self) -> Any: | ||
| """Parse the data field as JSON.""" | ||
| return json.loads(self.data) | ||
tjb9dc marked this conversation as resolved.
Show resolved
Hide resolved
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
did you test this with a multiline event (like a JSON blob containing a content field with extra lines)? This doesn't appear to handle that to me
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I tested it out with Cohere's bug. The extra lines in the JSON blob must be escaped ("\n\n") and this works fine. If they are regular double new lines ("\n\n"), we will interpret it as a new event.