
Commit 81fac10

✨ Chunker client

Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com>

1 parent b2c160d, commit 81fac10

7 files changed: +129 −28 lines changed

build.rs

Lines changed: 1 addition & 0 deletions

@@ -9,6 +9,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
         .include_file("mod.rs")
         .compile(
             &[
+                "protos/caikit_runtime_Chunkers.proto",
                 "protos/caikit_runtime_Nlp.proto",
                 "protos/generation.proto",
                 "protos/caikit_data_model_caikit_nlp.proto",

protos/caikit_runtime_Chunkers.proto

Lines changed: 31 additions & 0 deletions

@@ -0,0 +1,31 @@
+
+/*------------------------------------------------------------------------------
+ * AUTO GENERATED
+ *----------------------------------------------------------------------------*/
+
+syntax = "proto3";
+package caikit.runtime.Chunkers;
+import "caikit_data_model_nlp.proto";
+
+
+/*-- MESSAGES ----------------------------------------------------------------*/
+
+message BidiStreamingTokenizationTaskRequest {
+
+    /*-- fields --*/
+    string text_stream = 1;
+}
+
+message TokenizationTaskRequest {
+
+    /*-- fields --*/
+    string text = 1;
+}
+
+
+/*-- SERVICES ----------------------------------------------------------------*/
+
+service ChunkersService {
+    rpc BidiStreamingTokenizationTaskPredict(stream caikit.runtime.Chunkers.BidiStreamingTokenizationTaskRequest) returns (stream caikit_data_model.nlp.TokenizationStreamResult);
+    rpc TokenizationTaskPredict(caikit.runtime.Chunkers.TokenizationTaskRequest) returns (caikit_data_model.nlp.TokenizationResults);
+}
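After codegen these messages are plain Rust structs under crate::pb::caikit::runtime::chunkers (the module path matches the imports in src/clients/chunker.rs below). A minimal construction sketch, with illustrative sample text:

// Sketch only; field names come from the proto above, the module path
// from src/clients/chunker.rs, and the text values are illustrative.
use crate::pb::caikit::runtime::chunkers::{
    BidiStreamingTokenizationTaskRequest, TokenizationTaskRequest,
};

fn example_requests() -> (TokenizationTaskRequest, BidiStreamingTokenizationTaskRequest) {
    // Unary variant: the whole text in a single request.
    let unary = TokenizationTaskRequest {
        text: "One sentence. Another sentence.".into(),
    };
    // Streaming variant: one fragment of the text stream per message.
    let fragment = BidiStreamingTokenizationTaskRequest {
        text_stream: "One sent".into(),
    };
    (unary, fragment)
}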

protos/caikit_runtime_Nlp.proto

Lines changed: 0 additions & 1 deletion

@@ -83,7 +83,6 @@
 /*-- SERVICES ----------------------------------------------------------------*/
 
 service NlpService {
-    rpc BidiStreamingTokenizationTaskPredict(stream caikit.runtime.Nlp.BidiStreamingTokenizationTaskRequest) returns (stream caikit_data_model.nlp.TokenizationStreamResult);
     rpc ServerStreamingTextGenerationTaskPredict(caikit.runtime.Nlp.ServerStreamingTextGenerationTaskRequest) returns (stream caikit_data_model.nlp.GeneratedTextStreamResult);
     rpc TextGenerationTaskPredict(caikit.runtime.Nlp.TextGenerationTaskRequest) returns (caikit_data_model.nlp.GeneratedTextResult);
     rpc TokenizationTaskPredict(caikit.runtime.Nlp.TokenizationTaskRequest) returns (caikit_data_model.nlp.TokenizationResults);

src/clients.rs

Lines changed: 4 additions & 0 deletions

@@ -7,6 +7,9 @@ use url::Url;
 
 use crate::config::{ServiceConfig, Tls};
 
+pub mod chunker;
+pub use chunker::ChunkerClient;
+
 pub mod detector;
 pub use detector::DetectorClient;
 
@@ -18,6 +21,7 @@ pub use nlp::NlpClient;
 
 pub const DEFAULT_TGIS_PORT: u16 = 8033;
 pub const DEFAULT_CAIKIT_NLP_PORT: u16 = 8085;
+pub const DEFAULT_CHUNKER_PORT: u16 = 8085;
 pub const DEFAULT_DETECTOR_PORT: u16 = 8080;
 const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(5);
 const DEFAULT_REQUEST_TIMEOUT: Duration = Duration::from_secs(10);

src/clients/chunker.rs

Lines changed: 85 additions & 0 deletions

@@ -0,0 +1,85 @@
+use std::{collections::HashMap, pin::Pin};
+
+use futures::{Stream, StreamExt};
+use ginepro::LoadBalancedChannel;
+use tokio::sync::mpsc;
+use tokio_stream::wrappers::ReceiverStream;
+use tonic::Request;
+
+use super::{create_grpc_clients, Error};
+use crate::{
+    config::ServiceConfig,
+    pb::{
+        caikit::runtime::chunkers::{
+            chunkers_service_client::ChunkersServiceClient, BidiStreamingTokenizationTaskRequest,
+            TokenizationTaskRequest,
+        },
+        caikit_data_model::nlp::{
+            TokenizationResults, TokenizationStreamResult,
+        },
+    },
+};
+
+const MODEL_ID_HEADER_NAME: &str = "mm-model-id";
+
+#[derive(Clone)]
+pub struct ChunkerClient {
+    clients: HashMap<String, ChunkersServiceClient<LoadBalancedChannel>>,
+}
+
+impl ChunkerClient {
+    pub async fn new(default_port: u16, config: &[(String, ServiceConfig)]) -> Result<Self, Error> {
+        let clients = create_grpc_clients(default_port, config, ChunkersServiceClient::new).await?;
+        Ok(Self { clients })
+    }
+
+    fn client(&self, model_id: &str) -> Result<ChunkersServiceClient<LoadBalancedChannel>, Error> {
+        Ok(self
+            .clients
+            .get(model_id)
+            .ok_or_else(|| Error::ModelNotFound(model_id.into()))?
+            .clone())
+    }
+
+    pub async fn tokenization_task_predict(
+        &self,
+        model_id: &str,
+        request: TokenizationTaskRequest,
+    ) -> Result<TokenizationResults, Error> {
+        let request = request_with_model_id(request, model_id);
+        Ok(self
+            .client(model_id)?
+            .tokenization_task_predict(request)
+            .await?
+            .into_inner())
+    }
+
+    pub async fn bidi_streaming_tokenization_task_predict(
+        &self,
+        model_id: &str,
+        request: Pin<Box<dyn Stream<Item = BidiStreamingTokenizationTaskRequest> + Send + 'static>>,
+    ) -> Result<ReceiverStream<TokenizationStreamResult>, Error> {
+        let request = request_with_model_id(request, model_id);
+        let mut response_stream = self
+            .client(model_id)?
+            .bidi_streaming_tokenization_task_predict(request)
+            .await?
+            .into_inner();
+        let (tx, rx) = mpsc::channel(128);
+        tokio::spawn(async move {
+            while let Some(Ok(message)) = response_stream.next().await {
+                let _ = tx.send(message).await;
+            }
+        });
+        Ok(ReceiverStream::new(rx))
+    }
+
+}
+
+fn request_with_model_id<T>(request: T, model_id: &str) -> Request<T> {
+    let mut request = Request::new(request);
+    request
+        .metadata_mut()
+        .insert(MODEL_ID_HEADER_NAME, model_id.parse().unwrap());
+    request
+}
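A usage sketch for the new client. The model id "sentence-chunker" and the service_config value are hypothetical placeholders, and the printed output is only for illustration:

// Hypothetical usage sketch; "sentence-chunker" and `service_config`
// are placeholders, not values from this commit.
use futures::{stream, StreamExt};

async fn demo(service_config: ServiceConfig) -> Result<(), Error> {
    let config = vec![("sentence-chunker".to_string(), service_config)];
    let client = ChunkerClient::new(crate::clients::DEFAULT_CHUNKER_PORT, &config).await?;

    // Unary tokenization: one request in, one TokenizationResults out.
    let results = client
        .tokenization_task_predict(
            "sentence-chunker",
            TokenizationTaskRequest { text: "One. Two.".into() },
        )
        .await?;
    println!("{results:?}");

    // Bidirectional streaming: a boxed, pinned stream of requests in,
    // a ReceiverStream of TokenizationStreamResult out.
    let outbound = Box::pin(stream::iter(vec![
        BidiStreamingTokenizationTaskRequest { text_stream: "One. ".into() },
        BidiStreamingTokenizationTaskRequest { text_stream: "Two.".into() },
    ]));
    let mut inbound = client
        .bidi_streaming_tokenization_task_predict("sentence-chunker", outbound)
        .await?;
    while let Some(result) = inbound.next().await {
        println!("{result:?}");
    }
    Ok(())
}

One design note on bidi_streaming_tokenization_task_predict: responses are relayed through an mpsc channel onto a ReceiverStream, and the relay task exits at the first Err from the transport, so a mid-stream failure ends the output stream silently rather than surfacing the error to the caller.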

src/clients/nlp.rs

Lines changed: 2 additions & 22 deletions

@@ -11,13 +11,13 @@ use crate::{
     config::ServiceConfig,
     pb::{
         caikit::runtime::nlp::{
-            nlp_service_client::NlpServiceClient, BidiStreamingTokenizationTaskRequest,
+            nlp_service_client::NlpServiceClient,
             ServerStreamingTextGenerationTaskRequest, TextGenerationTaskRequest,
             TokenClassificationTaskRequest, TokenizationTaskRequest,
         },
         caikit_data_model::nlp::{
             GeneratedTextResult, GeneratedTextStreamResult, TokenClassificationResults,
-            TokenizationResults, TokenizationStreamResult,
+            TokenizationResults,
         },
     },
 };

@@ -56,26 +56,6 @@ impl NlpClient {
             .into_inner())
     }
 
-    pub async fn bidi_streaming_tokenization_task_predict(
-        &self,
-        model_id: &str,
-        request: Pin<Box<dyn Stream<Item = BidiStreamingTokenizationTaskRequest> + Send + 'static>>,
-    ) -> Result<ReceiverStream<TokenizationStreamResult>, Error> {
-        let request = request_with_model_id(request, model_id);
-        let mut response_stream = self
-            .client(model_id)?
-            .bidi_streaming_tokenization_task_predict(request)
-            .await?
-            .into_inner();
-        let (tx, rx) = mpsc::channel(128);
-        tokio::spawn(async move {
-            while let Some(Ok(message)) = response_stream.next().await {
-                let _ = tx.send(message).await;
-            }
-        });
-        Ok(ReceiverStream::new(rx))
-    }
-
     pub async fn token_classification_task_predict(
         &self,
         model_id: &str,

src/orchestrator.rs

Lines changed: 6 additions & 5 deletions

@@ -10,7 +10,7 @@ use uuid::Uuid;
 
 use crate::{
     clients::{
-        self, detector::DetectorRequest, DetectorClient, GenerationClient, NlpClient, TgisClient,
+        self, detector::DetectorRequest, ChunkerClient, DetectorClient, GenerationClient, NlpClient, TgisClient,
     },
     config::{GenerationProvider, OrchestratorConfig},
     models::{

@@ -19,6 +19,7 @@ use crate::{
         InputWarningReason, TextGenTokenClassificationResults, TokenClassificationResult,
     },
     pb::{
+        caikit::runtime::chunkers::TokenizationTaskRequest as ChunkersTokenizationTaskRequest,
        caikit::runtime::nlp::{TextGenerationTaskRequest, TokenizationTaskRequest},
        fmaas::{
            BatchedGenerationRequest, BatchedTokenizeRequest, GenerationRequest, TokenizeRequest,

@@ -34,7 +35,7 @@ const UNSUITABLE_INPUT_MESSAGE: &str = "Unsuitable input detected. \
 struct Context {
     config: OrchestratorConfig,
     generation_client: GenerationClient,
-    chunker_client: NlpClient,
+    chunker_client: ChunkerClient,
     detector_client: DetectorClient,
 }
 
@@ -248,7 +249,7 @@ async fn handle_chunk_task(
     let ctx = ctx.clone();
     let chunker_id = chunker_id.clone();
     async move {
-        let request = TokenizationTaskRequest { text };
+        let request = ChunkersTokenizationTaskRequest { text };
         debug!(
             %chunker_id,
             ?request,

@@ -495,7 +496,7 @@ fn apply_masks(text: &str, masks: &[(usize, usize)]) -> Vec<(usize, String)> {
 
 async fn create_clients(
     config: &OrchestratorConfig,
-) -> Result<(GenerationClient, NlpClient, DetectorClient), Error> {
+) -> Result<(GenerationClient, ChunkerClient, DetectorClient), Error> {
     // TODO: create better solution for routers
     let generation_client = match config.generation.provider {
         GenerationProvider::Tgis => {

@@ -521,7 +522,7 @@ async fn create_clients(
         .iter()
         .map(|(chunker_id, config)| (chunker_id.clone(), config.service.clone()))
         .collect::<Vec<_>>();
-    let chunker_client = NlpClient::new(clients::DEFAULT_CAIKIT_NLP_PORT, &chunker_config).await?;
+    let chunker_client = ChunkerClient::new(clients::DEFAULT_CHUNKER_PORT, &chunker_config).await?;
 
     let detector_config = config
         .detectors
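One reason for the import alias: the nlp and chunkers protos each generate their own TokenizationTaskRequest, so the orchestrator must disambiguate the two. A small sketch (the sample text is illustrative):

// Both generated modules export `TokenizationTaskRequest`; the alias
// keeps the chunkers variant distinct from the nlp variant.
use crate::pb::caikit::runtime::chunkers::TokenizationTaskRequest as ChunkersTokenizationTaskRequest;
use crate::pb::caikit::runtime::nlp::TokenizationTaskRequest;

fn build_requests(text: String) -> (ChunkersTokenizationTaskRequest, TokenizationTaskRequest) {
    (
        ChunkersTokenizationTaskRequest { text: text.clone() }, // to the chunker service
        TokenizationTaskRequest { text },                       // to the NLP service
    )
}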
