Skip to content

Commit 0081403

Browse files
author
Rutik Thakre
committed
Refactor event source handling: Introduce NdJsonStream and Stream trait for improved stream management
1 parent 1825e58 commit 0081403

File tree

5 files changed

+278
-82
lines changed

5 files changed

+278
-82
lines changed

llm/src/event_source/error.rs

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
1-
use crate::event_source::event_stream::EventStreamError;
21
use core::fmt;
3-
use golem_rust::bindings::wasi::io::streams::StreamError;
2+
use golem_rust::bindings::wasi::io::streams::StreamError as WasiStreamError;
43
use nom::error::Error as NomError;
54
use reqwest::header::HeaderValue;
65
use reqwest::Error as ReqwestError;
@@ -9,6 +8,8 @@ use reqwest::StatusCode;
98
use std::string::FromUtf8Error;
109
use thiserror::Error;
1110

11+
use super::stream::StreamError;
12+
1213
/// Error raised when a [`RequestBuilder`] cannot be cloned. See [`RequestBuilder::try_clone`] for
1314
/// more information
1415
#[derive(Debug, Clone, Copy)]
@@ -51,24 +52,24 @@ pub enum Error {
5152
StreamEnded,
5253
}
5354

54-
impl From<EventStreamError<ReqwestError>> for Error {
55-
fn from(err: EventStreamError<ReqwestError>) -> Self {
55+
impl From<StreamError<ReqwestError>> for Error {
56+
fn from(err: StreamError<ReqwestError>) -> Self {
5657
match err {
57-
EventStreamError::Utf8(err) => Self::Utf8(err),
58-
EventStreamError::Parser(err) => Self::Parser(err),
59-
EventStreamError::Transport(err) => Self::Transport(err),
58+
StreamError::Utf8(err) => Self::Utf8(err),
59+
StreamError::Parser(err) => Self::Parser(err),
60+
StreamError::Transport(err) => Self::Transport(err),
6061
}
6162
}
6263
}
6364

64-
impl From<EventStreamError<StreamError>> for Error {
65-
fn from(err: EventStreamError<StreamError>) -> Self {
65+
impl From<StreamError<WasiStreamError>> for Error {
66+
fn from(err: StreamError<WasiStreamError>) -> Self {
6667
match err {
67-
EventStreamError::Utf8(err) => Self::Utf8(err),
68-
EventStreamError::Parser(err) => Self::Parser(err),
69-
EventStreamError::Transport(err) => match err {
70-
StreamError::Closed => Self::StreamEnded,
71-
StreamError::LastOperationFailed(err) => {
68+
StreamError::Utf8(err) => Self::Utf8(err),
69+
StreamError::Parser(err) => Self::Parser(err),
70+
StreamError::Transport(err) => match err {
71+
WasiStreamError::Closed => Self::StreamEnded,
72+
WasiStreamError::LastOperationFailed(err) => {
7273
Self::TransportStream(err.to_debug_string())
7374
}
7475
},

llm/src/event_source/event_stream.rs

Lines changed: 8 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
11
use crate::event_source::parser::{is_bom, is_lf, line, RawEventLine};
22
use crate::event_source::utf8_stream::{Utf8Stream, Utf8StreamError};
33
use crate::event_source::MessageEvent;
4-
use core::fmt;
54
use core::time::Duration;
65
use golem_rust::bindings::wasi::io::streams::{InputStream, StreamError};
76
use golem_rust::wasm_rpc::Pollable;
87
use log::trace;
9-
use nom::error::Error as NomError;
10-
use std::string::FromUtf8Error;
118
use std::task::Poll;
129

10+
use super::stream::{Stream, StreamError as EventStreamError};
11+
1312
#[derive(Default, Debug)]
1413
struct EventBuilder {
1514
event: MessageEvent,
@@ -133,9 +132,9 @@ pub struct EventStream {
133132
last_event_id: String,
134133
}
135134

136-
impl EventStream {
135+
impl Stream for EventStream {
137136
/// Initialize the EventStream with a Stream
138-
pub fn new(stream: InputStream) -> Self {
137+
fn new(stream: InputStream) -> Self {
139138
println!("EventStream::new");
140139
Self {
141140
stream: Utf8Stream::new(stream),
@@ -148,23 +147,23 @@ impl EventStream {
148147

149148
/// Set the last event ID of the stream. Useful for initializing the stream with a previous
150149
/// last event ID
151-
pub fn set_last_event_id(&mut self, id: impl Into<String>) {
150+
fn set_last_event_id(&mut self, id: impl Into<String>) {
152151
println!("EventStream::set_last_event_id");
153152
self.last_event_id = id.into();
154153
}
155154

156155
/// Get the last event ID of the stream
157-
pub fn last_event_id(&self) -> &str {
156+
fn last_event_id(&self) -> &str {
158157
println!("EventStream::last_event_id");
159158
&self.last_event_id
160159
}
161160

162-
pub fn subscribe(&self) -> Pollable {
161+
fn subscribe(&self) -> Pollable {
163162
println!("EventStream::subscribe");
164163
self.stream.subscribe()
165164
}
166165

167-
pub fn poll_next(
166+
fn poll_next(
168167
&mut self,
169168
) -> Poll<Option<Result<MessageEvent, EventStreamError<StreamError>>>> {
170169
println!("EventStream::poll_next buffer {} ", self.buffer.as_str());
@@ -222,50 +221,6 @@ impl EventStream {
222221
}
223222
}
224223

225-
/// Error thrown while parsing an event line
226-
#[derive(Debug, PartialEq)]
227-
pub enum EventStreamError<E> {
228-
/// Source stream is not valid UTF8
229-
Utf8(FromUtf8Error),
230-
/// Source stream is not a valid EventStream
231-
Parser(NomError<String>),
232-
/// Underlying source stream error
233-
Transport(E),
234-
}
235-
236-
impl<E> From<Utf8StreamError<E>> for EventStreamError<E> {
237-
fn from(err: Utf8StreamError<E>) -> Self {
238-
println!("EventStreamError::from");
239-
match err {
240-
Utf8StreamError::Utf8(err) => Self::Utf8(err),
241-
Utf8StreamError::Transport(err) => Self::Transport(err),
242-
}
243-
}
244-
}
245-
246-
impl<E> From<NomError<&str>> for EventStreamError<E> {
247-
fn from(err: NomError<&str>) -> Self {
248-
println!("EventStreamError::from");
249-
EventStreamError::Parser(NomError::new(err.input.to_string(), err.code))
250-
}
251-
}
252-
253-
impl<E> fmt::Display for EventStreamError<E>
254-
where
255-
E: fmt::Display,
256-
{
257-
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
258-
println!("EventStreamError::fmt");
259-
match self {
260-
Self::Utf8(err) => f.write_fmt(format_args!("UTF8 error: {}", err)),
261-
Self::Parser(err) => f.write_fmt(format_args!("Parse error: {}", err)),
262-
Self::Transport(err) => f.write_fmt(format_args!("Transport error: {}", err)),
263-
}
264-
}
265-
}
266-
267-
impl<E> std::error::Error for EventStreamError<E> where E: fmt::Display + fmt::Debug + Send + Sync {}
268-
269224
fn parse_event<E>(
270225
buffer: &mut String,
271226
builder: &mut EventBuilder,

llm/src/event_source/mod.rs

Lines changed: 39 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,20 @@
44
pub mod error;
55
mod event_stream;
66
mod message_event;
7+
mod ndjson_stream;
78
mod parser;
9+
mod stream;
810
mod utf8_stream;
911

1012
use crate::event_source::error::Error;
1113
use crate::event_source::event_stream::EventStream;
1214
use golem_rust::wasm_rpc::Pollable;
1315
pub use message_event::MessageEvent;
16+
use ndjson_stream::NdJsonStream;
1417
use reqwest::header::HeaderValue;
1518
use reqwest::{Response, StatusCode};
1619
use std::task::Poll;
20+
use stream::{Stream, StreamType};
1721

1822
/// The ready state of an [`EventSource`]
1923
#[derive(Debug, Clone, Copy, Eq, PartialEq, Ord, PartialOrd)]
@@ -28,7 +32,8 @@ pub enum ReadyState {
2832
}
2933

3034
pub struct EventSource {
31-
stream: EventStream,
35+
/// stream is the type which implements Stream trait
36+
stream: StreamType,
3237
response: Response,
3338
is_closed: bool,
3439
}
@@ -45,7 +50,19 @@ impl EventSource {
4550
golem_rust::bindings::wasi::io::streams::InputStream,
4651
>(response.get_raw_input_stream())
4752
};
48-
let stream = EventStream::new(handle);
53+
54+
let stream = if response
55+
.headers()
56+
.get(&reqwest::header::CONTENT_TYPE)
57+
.unwrap()
58+
.to_str()
59+
.unwrap()
60+
.contains("ndjson")
61+
{
62+
StreamType::NdJsonStream(NdJsonStream::new(handle))
63+
} else {
64+
StreamType::EventStream(EventStream::new(handle))
65+
};
4966
Ok(Self {
5067
response,
5168
stream,
@@ -74,7 +91,10 @@ impl EventSource {
7491

7592
pub fn subscribe(&self) -> Pollable {
7693
println!("EventSource::subscribe");
77-
self.stream.subscribe()
94+
match &self.stream {
95+
StreamType::EventStream(stream) => stream.subscribe(),
96+
StreamType::NdJsonStream(stream) => stream.subscribe(),
97+
}
7898
}
7999

80100
pub fn poll_next(&mut self) -> Poll<Option<Result<Event, Error>>> {
@@ -83,19 +103,23 @@ impl EventSource {
83103
return Poll::Ready(None);
84104
}
85105

86-
match self.stream.poll_next() {
87-
Poll::Ready(Some(Err(err))) => {
88-
let err = err.into();
89-
self.is_closed = true;
90-
Poll::Ready(Some(Err(err)))
106+
match &mut self.stream {
107+
StreamType::EventStream( stream) => {
108+
match stream.poll_next() {
109+
Poll::Ready(Some(Ok(event))) => Poll::Ready(Some(Ok(Event::Message(event)))),
110+
Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err.into()))),
111+
Poll::Ready(None) => Poll::Ready(None),
112+
Poll::Pending => Poll::Pending,
113+
}
91114
}
92-
Poll::Ready(Some(Ok(event))) => Poll::Ready(Some(Ok(event.into()))),
93-
Poll::Ready(None) => {
94-
let err = Error::StreamEnded;
95-
self.is_closed = true;
96-
Poll::Ready(Some(Err(err)))
115+
StreamType::NdJsonStream(stream) => {
116+
match stream.poll_next() {
117+
Poll::Ready(Some(Ok(event))) => Poll::Ready(Some(Ok(Event::Message(event)))),
118+
Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err.into()))),
119+
Poll::Ready(None) => Poll::Ready(None),
120+
Poll::Pending => Poll::Pending,
121+
}
97122
}
98-
Poll::Pending => Poll::Pending,
99123
}
100124
}
101125
}
@@ -126,7 +150,7 @@ fn check_response(response: Response) -> Result<Response, Error> {
126150
matches!(
127151
(mime_type.type_(), mime_type.subtype()),
128152
(mime::TEXT, mime::EVENT_STREAM)
129-
) || mime_type.subtype().as_str().contains("ndjson")
153+
) || mime_type.subtype().as_str().contains("ndjson")
130154
})
131155
.unwrap_or(false)
132156
{

0 commit comments

Comments
 (0)