Skip to content

Commit 3a3857f

Browse files
committed
Drop Batch associated type from DetectionBatcher and generics, drop detector_id from DetectionStream, integrate single detection stream optimization
Signed-off-by: declark1 <44146800+declark1@users.noreply.github.com>
1 parent 2d3b5e4 commit 3a3857f

File tree

9 files changed

+134
-320
lines changed

9 files changed

+134
-320
lines changed

src/orchestrator/common/tasks.rs

Lines changed: 6 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -190,9 +190,9 @@ pub async fn text_contents_detections(
190190
ctx: Arc<Context>,
191191
headers: HeaderMap,
192192
detectors: HashMap<String, DetectorParams>,
193-
input_id: InputId,
193+
input_id: u32,
194194
inputs: Vec<(usize, String)>,
195-
) -> Result<(InputId, Detections), Error> {
195+
) -> Result<(u32, Detections), Error> {
196196
let chunkers = get_chunker_ids(&ctx, &detectors)?;
197197
let chunk_map = chunks(ctx.clone(), chunkers, inputs).await?;
198198
let inputs = detectors
@@ -249,7 +249,7 @@ pub async fn text_contents_detection_streams(
249249
ctx: Arc<Context>,
250250
headers: HeaderMap,
251251
detectors: HashMap<String, DetectorParams>,
252-
input_id: InputId,
252+
input_id: u32,
253253
input_rx: mpsc::Receiver<Result<(usize, String), Error>>, // (message_index, text)
254254
) -> Result<Vec<DetectionStream>, Error> {
255255
// Create chunk streams
@@ -294,14 +294,8 @@ pub async fn text_contents_detection_streams(
294294
.filter(|detection| detection.score >= threshold)
295295
.collect::<Detections>();
296296
// Send to detection channel
297-
let _ = detection_tx
298-
.send(Ok((
299-
input_id,
300-
detector_id.clone(),
301-
chunk,
302-
detections,
303-
)))
304-
.await;
297+
let _ =
298+
detection_tx.send(Ok((input_id, chunk, detections))).await;
305299
}
306300
Err(error) => {
307301
// Send error to detection channel
@@ -985,9 +979,7 @@ mod test {
985979

986980
let mut fake_detector_stream = detection_streams.swap_remove(0);
987981
let mut results = Vec::with_capacity(1);
988-
while let Some(Ok((_input_id, _detector_id, _chunk, detections))) =
989-
fake_detector_stream.next().await
990-
{
982+
while let Some(Ok((_input_id, _chunk, detections))) = fake_detector_stream.next().await {
991983
results.push(detections);
992984
}
993985
assert_eq!(results.len(), 1);

src/orchestrator/handlers/streaming_classification_with_gen.rs

Lines changed: 3 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,7 @@ use crate::{
4040
Context, Error, Orchestrator,
4141
common::{self, validate_detectors},
4242
types::{
43-
Chunk, DetectionBatchStream, DetectionStream, Detections, GenerationStream,
44-
MaxProcessedIndexBatcher,
43+
Chunk, DetectionBatchStream, Detections, GenerationStream, MaxProcessedIndexBatcher,
4544
},
4645
},
4746
};
@@ -254,12 +253,6 @@ async fn handle_output_detection(
254253
let generations = generations.clone();
255254
async move {
256255
match detection_streams {
257-
Ok(mut detection_streams) if detection_streams.len() == 1 => {
258-
// Process single detection stream, batching not applicable
259-
let detection_stream = detection_streams.swap_remove(0);
260-
process_detection_stream(trace_id, generations, detection_stream, response_tx)
261-
.await;
262-
}
263256
Ok(detection_streams) => {
264257
// Create detection batch stream
265258
let detection_batch_stream = DetectionBatchStream::new(
@@ -335,47 +328,17 @@ async fn forward_generation_stream(
335328
info!(%trace_id, "task completed: generation stream closed");
336329
}
337330

338-
/// Consumes a detection stream, builds responses, and sends them to a response channel.
339-
#[instrument(skip_all)]
340-
async fn process_detection_stream(
341-
trace_id: TraceId,
342-
generations: Arc<RwLock<Vec<ClassifiedGeneratedTextStreamResult>>>,
343-
mut detection_stream: DetectionStream,
344-
response_tx: mpsc::Sender<Result<ClassifiedGeneratedTextStreamResult, Error>>,
345-
) {
346-
while let Some(result) = detection_stream.next().await {
347-
match result {
348-
Ok((_, _detector_id, chunk, detections)) => {
349-
// Create response for this batch with output detections
350-
let response = output_detection_response(&generations, chunk, detections).unwrap();
351-
// Send message to response channel
352-
if response_tx.send(Ok(response)).await.is_err() {
353-
info!(%trace_id, "task completed: client disconnected");
354-
return;
355-
}
356-
}
357-
Err(error) => {
358-
error!(%trace_id, %error, "task failed: error received from detection stream");
359-
// Send error to response channel and terminate
360-
let _ = response_tx.send(Err(error)).await;
361-
return;
362-
}
363-
}
364-
}
365-
info!(%trace_id, "task completed: detection stream closed");
366-
}
367-
368331
/// Consumes a detection batch stream, builds responses, and sends them to a response channel.
369332
#[instrument(skip_all)]
370333
async fn process_detection_batch_stream(
371334
trace_id: TraceId,
372335
generations: Arc<RwLock<Vec<ClassifiedGeneratedTextStreamResult>>>,
373-
mut detection_batch_stream: DetectionBatchStream<MaxProcessedIndexBatcher>,
336+
mut detection_batch_stream: DetectionBatchStream,
374337
response_tx: mpsc::Sender<Result<ClassifiedGeneratedTextStreamResult, Error>>,
375338
) {
376339
while let Some(result) = detection_batch_stream.next().await {
377340
match result {
378-
Ok((chunk, detections)) => {
341+
Ok((_, chunk, detections)) => {
379342
// Create response for this batch with output detections
380343
let response = output_detection_response(&generations, chunk, detections).unwrap();
381344
// Send message to response channel

src/orchestrator/handlers/streaming_content_detection.rs

Lines changed: 3 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ use crate::{
3030
orchestrator::{
3131
Context, Error, Orchestrator,
3232
common::{self, validate_detectors},
33-
types::{BoxStream, DetectionBatchStream, DetectionStream, MaxProcessedIndexBatcher},
33+
types::{BoxStream, DetectionBatchStream, MaxProcessedIndexBatcher},
3434
},
3535
};
3636

@@ -133,11 +133,6 @@ async fn handle_detection(
133133
tokio::spawn(
134134
async move {
135135
match detection_streams {
136-
Ok(mut detection_streams) if detection_streams.len() == 1 => {
137-
// Process single detection stream, batching not applicable
138-
let detection_stream = detection_streams.swap_remove(0);
139-
process_detection_stream(trace_id, detection_stream, response_tx).await;
140-
}
141136
Ok(detection_streams) => {
142137
// Create detection batch stream
143138
let detection_batch_stream = DetectionBatchStream::new(
@@ -177,48 +172,16 @@ async fn handle_detection(
177172
);
178173
}
179174

180-
/// Consumes a detection stream, builds responses, and sends them to a response channel.
181-
#[instrument(skip_all)]
182-
async fn process_detection_stream(
183-
trace_id: TraceId,
184-
mut detection_stream: DetectionStream,
185-
response_tx: mpsc::Sender<Result<StreamingContentDetectionResponse, Error>>,
186-
) {
187-
while let Some(result) = detection_stream.next().await {
188-
match result {
189-
Ok((_, _detector_id, chunk, detections)) => {
190-
let response = StreamingContentDetectionResponse {
191-
start_index: chunk.start as u32,
192-
processed_index: chunk.end as u32,
193-
detections: detections.into(),
194-
};
195-
// Send message to response channel
196-
if response_tx.send(Ok(response)).await.is_err() {
197-
info!(%trace_id, "task completed: client disconnected");
198-
return;
199-
}
200-
}
201-
Err(error) => {
202-
error!(%trace_id, %error, "task failed: error received from detection stream");
203-
// Send error to response channel and terminate
204-
let _ = response_tx.send(Err(error)).await;
205-
return;
206-
}
207-
}
208-
}
209-
info!(%trace_id, "task completed: detection stream closed");
210-
}
211-
212175
/// Consumes a detection batch stream, builds responses, and sends them to a response channel.
213176
#[instrument(skip_all)]
214177
async fn process_detection_batch_stream(
215178
trace_id: TraceId,
216-
mut detection_batch_stream: DetectionBatchStream<MaxProcessedIndexBatcher>,
179+
mut detection_batch_stream: DetectionBatchStream,
217180
response_tx: mpsc::Sender<Result<StreamingContentDetectionResponse, Error>>,
218181
) {
219182
while let Some(result) = detection_batch_stream.next().await {
220183
match result {
221-
Ok((chunk, detections)) => {
184+
Ok((_, chunk, detections)) => {
222185
let response = StreamingContentDetectionResponse {
223186
start_index: chunk.start as u32,
224187
processed_index: chunk.end as u32,

src/orchestrator/types.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,11 @@ use crate::{
3737

3838
pub type ChunkerId = String;
3939
pub type DetectorId = String;
40-
pub type InputId = u32;
4140

4241
pub type BoxStream<T> = Pin<Box<dyn Stream<Item = T> + Send>>;
4342
pub type ChunkStream = BoxStream<Result<Chunk, Error>>;
4443
pub type InputStream = BoxStream<Result<(usize, String), Error>>;
45-
pub type DetectionStream = BoxStream<Result<(InputId, DetectorId, Chunk, Detections), Error>>;
44+
pub type DetectionStream = BoxStream<Result<(u32, Chunk, Detections), Error>>;
4645
pub type GenerationStream = BoxStream<(usize, Result<ClassifiedGeneratedTextStreamResult, Error>)>;
4746
pub type ChatCompletionStream = BoxStream<(usize, Result<Option<ChatCompletionChunk>, Error>)>;
4847
pub type CompletionStream = BoxStream<(usize, Result<Option<Completion>, Error>)>;

0 commit comments

Comments
 (0)