Skip to content

Commit e82dfbf

Browse files
authored
Merge pull request #68 from gkumbhat/add_mock_lib
Add mock lib
2 parents 2354acd + 627cfab commit e82dfbf

File tree

6 files changed

+143
-9
lines changed

6 files changed

+143
-9
lines changed

Cargo.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ uuid = { version = "1.8.0", features = ["v4", "fast-rng"] }
3838
[build-dependencies]
3939
tonic-build = "0.11.0"
4040

41+
[dev-dependencies]
42+
faux = "0.1.10"
43+
4144
[profile.release]
4245
debug = false
4346
incremental = true

src/clients/chunker.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ use crate::{
2020

2121
const MODEL_ID_HEADER_NAME: &str = "mm-model-id";
2222

23+
#[cfg_attr(test, derive(Default))]
2324
#[derive(Clone)]
2425
pub struct ChunkerClient {
2526
clients: HashMap<String, ChunkersServiceClient<LoadBalancedChannel>>,

src/clients/detector.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use crate::config::ServiceConfig;
77

88
const DETECTOR_ID_HEADER_NAME: &str = "detector-id";
99

10+
#[cfg_attr(test, derive(Default))]
1011
#[derive(Clone)]
1112
pub struct DetectorClient {
1213
clients: HashMap<String, HttpClient>,

src/clients/tgis.rs

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,13 @@ use crate::{
1515
},
1616
};
1717

18+
#[cfg_attr(test, faux::create)]
1819
#[derive(Clone)]
1920
pub struct TgisClient {
2021
clients: HashMap<String, GenerationServiceClient<LoadBalancedChannel>>,
2122
}
2223

24+
#[cfg_attr(test, faux::methods)]
2325
impl TgisClient {
2426
pub async fn new(default_port: u16, config: &[(String, ServiceConfig)]) -> Self {
2527
let clients = create_grpc_clients(default_port, config, GenerationServiceClient::new).await;
@@ -85,3 +87,36 @@ impl TgisClient {
8587
.into_inner())
8688
}
8789
}
90+
91+
#[cfg(test)]
92+
mod tests {
93+
use super::*;
94+
use crate::pb::fmaas::model_info_response;
95+
96+
#[tokio::test]
97+
async fn test_model_info() {
98+
// Initialize a mock object from `TgisClient`
99+
let mut mock_client = TgisClient::faux();
100+
101+
let request = ModelInfoRequest {
102+
model_id: "test-model-1".to_string(),
103+
};
104+
105+
let expected_response = ModelInfoResponse {
106+
max_sequence_length: 2,
107+
max_new_tokens: 20,
108+
max_beam_width: 3,
109+
model_kind: model_info_response::ModelKind::DecoderOnly.into(),
110+
max_beam_sequence_lengths: [].to_vec(),
111+
};
112+
// Construct a behavior for the mock object
113+
faux::when!(mock_client.model_info(request.clone()))
114+
.once()
115+
.then_return(Ok(expected_response.clone()));
116+
// Test the mock object's behaviour
117+
assert_eq!(
118+
mock_client.model_info(request).await.unwrap(),
119+
expected_response
120+
);
121+
}
122+
}

src/config.rs

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,39 +8,42 @@ use tracing::debug;
88

99
/// Configuration for service needed for
1010
/// orchestrator to communicate with it
11-
#[derive(Debug, Clone, Deserialize)]
11+
#[derive(Clone, Debug, Default, Deserialize)]
1212
pub struct ServiceConfig {
1313
pub hostname: String,
1414
pub port: Option<u16>,
1515
pub tls: Option<Tls>,
1616
}
1717

1818
/// TLS provider
19-
#[derive(Debug, Clone, Deserialize)]
19+
#[derive(Clone, Debug, Deserialize)]
2020
#[serde(untagged)]
2121
pub enum Tls {
2222
Name(String),
2323
Config(TlsConfig),
2424
}
2525

2626
/// Client TLS configuration
27-
#[derive(Debug, Clone, Deserialize)]
27+
#[derive(Clone, Debug, Default, Deserialize)]
2828
pub struct TlsConfig {
2929
pub cert_path: Option<PathBuf>,
3030
pub key_path: Option<PathBuf>,
3131
pub client_ca_cert_path: Option<PathBuf>,
3232
}
3333

3434
/// Generation service provider
35-
#[derive(Debug, Clone, Copy, Deserialize)]
35+
#[cfg_attr(test, derive(Default))]
36+
#[derive(Clone, Copy, Debug, Deserialize)]
3637
#[serde(rename_all = "lowercase")]
3738
pub enum GenerationProvider {
39+
#[cfg_attr(test, default)]
3840
Tgis,
3941
Nlp,
4042
}
4143

4244
/// Generate service configuration
43-
#[derive(Debug, Clone, Deserialize)]
45+
#[cfg_attr(test, derive(Default))]
46+
#[derive(Clone, Debug, Deserialize)]
4447
pub struct GenerationConfig {
4548
/// Generation service provider
4649
pub provider: GenerationProvider,
@@ -49,16 +52,19 @@ pub struct GenerationConfig {
4952
}
5053

5154
/// Chunker parser type
52-
#[derive(Debug, Clone, Copy, Deserialize)]
55+
#[cfg_attr(test, derive(Default))]
56+
#[derive(Clone, Copy, Debug, Deserialize)]
5357
#[serde(rename_all = "lowercase")]
5458
pub enum ChunkerType {
59+
#[cfg_attr(test, default)]
5560
Sentence,
5661
All,
5762
}
5863

5964
/// Configuration for each chunker
65+
#[cfg_attr(test, derive(Default))]
6066
#[allow(dead_code)]
61-
#[derive(Debug, Clone, Deserialize)]
67+
#[derive(Clone, Debug, Deserialize)]
6268
pub struct ChunkerConfig {
6369
/// Chunker type
6470
pub r#type: ChunkerType,
@@ -67,7 +73,7 @@ pub struct ChunkerConfig {
6773
}
6874

6975
/// Configuration for each detector
70-
#[derive(Debug, Clone, Deserialize)]
76+
#[derive(Clone, Debug, Default, Deserialize)]
7177
pub struct DetectorConfig {
7278
/// Detector service connection information
7379
pub service: ServiceConfig,
@@ -78,7 +84,8 @@ pub struct DetectorConfig {
7884
}
7985

8086
/// Overall orchestrator server configuration
81-
#[derive(Debug, Clone, Deserialize)]
87+
#[cfg_attr(test, derive(Default))]
88+
#[derive(Clone, Debug, Deserialize)]
8289
pub struct OrchestratorConfig {
8390
/// Generation service and associated configuration
8491
pub generation: GenerationConfig,
@@ -150,6 +157,13 @@ fn service_tls_name_to_config(
150157
service
151158
}
152159

160+
#[cfg(test)]
161+
impl Default for Tls {
162+
fn default() -> Self {
163+
Tls::Name("dummy_tls".to_string())
164+
}
165+
}
166+
153167
#[cfg(test)]
154168
mod tests {
155169
use anyhow::Error;

src/orchestrator.rs

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -689,6 +689,28 @@ impl StreamingClassificationWithGenTask {
689689
#[cfg(test)]
690690
mod tests {
691691
use super::*;
692+
use crate::{
693+
models::FinishReason,
694+
pb::fmaas::{
695+
BatchedGenerationRequest, BatchedGenerationResponse, GenerationResponse, StopReason,
696+
},
697+
};
698+
699+
async fn get_test_context(
700+
gen_client: GenerationClient,
701+
chunker_client: Option<ChunkerClient>,
702+
detector_client: Option<DetectorClient>,
703+
) -> Context {
704+
let chunker_client = chunker_client.unwrap_or_default();
705+
let detector_client = detector_client.unwrap_or_default();
706+
707+
Context {
708+
generation_client: gen_client,
709+
chunker_client,
710+
detector_client,
711+
config: OrchestratorConfig::default(),
712+
}
713+
}
692714

693715
#[test]
694716
fn test_apply_masks() {
@@ -709,4 +731,62 @@ mod tests {
709731
let s = "哈囉世界";
710732
assert_eq!(slice_codepoints(s, 3, 4), "界");
711733
}
734+
735+
// Test for TGIS generation with default parameter
736+
#[tokio::test]
737+
async fn test_tgis_generate_with_default_params() {
738+
// Initialize a mock object from `TgisClient`
739+
let mut mock_client = TgisClient::faux();
740+
741+
let sample_text = String::from("sample text");
742+
let text_gen_model_id = String::from("test-llm-id-1");
743+
744+
let generation_response = GenerationResponse {
745+
text: String::from("sample response worked"),
746+
stop_reason: StopReason::EosToken.into(),
747+
stop_sequence: String::from("\n"),
748+
generated_token_count: 3,
749+
seed: 7,
750+
..Default::default()
751+
};
752+
753+
let client_generation_response = BatchedGenerationResponse {
754+
responses: [generation_response].to_vec(),
755+
};
756+
757+
let expected_generate_req_args = BatchedGenerationRequest {
758+
model_id: text_gen_model_id.clone(),
759+
prefix_id: None,
760+
requests: [GenerationRequest {
761+
text: sample_text.clone(),
762+
}]
763+
.to_vec(),
764+
params: None,
765+
};
766+
767+
let expected_generate_response = ClassifiedGeneratedTextResult {
768+
generated_text: Some(client_generation_response.responses[0].text.clone()),
769+
finish_reason: Some(FinishReason::EosToken),
770+
generated_token_count: Some(3),
771+
seed: Some(7),
772+
..Default::default()
773+
};
774+
775+
// Construct a behavior for the mock object
776+
faux::when!(mock_client.generate(expected_generate_req_args))
777+
.once() // TODO: Add with_args
778+
.then_return(Ok(client_generation_response));
779+
780+
let mock_generation_client = GenerationClient::Tgis(mock_client.clone());
781+
782+
let ctx: Context = get_test_context(mock_generation_client, None, None).await;
783+
784+
// Test request formulation and response processing is as expected
785+
assert_eq!(
786+
generate(ctx.into(), text_gen_model_id, sample_text, None)
787+
.await
788+
.unwrap(),
789+
expected_generate_response
790+
);
791+
}
712792
}

0 commit comments

Comments
 (0)