Skip to content

Commit b454788

Browse files
committed
🥅 Use error handling objects
Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com>
1 parent 45b12ef commit b454788

File tree

2 files changed

+5
-40
lines changed

2 files changed

+5
-40
lines changed

src/models.rs

Lines changed: 4 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
use garde::Validate;
44

5-
use crate::pb;
5+
use crate::{pb, server};
66
use std::collections::HashMap;
77

88
pub type DetectorParams = HashMap<String, serde_json::Value>;
@@ -36,11 +36,10 @@ pub struct GuardrailsHttpRequest {
3636

3737
impl GuardrailsHttpRequest {
3838
/// Upfront validation of user request
39-
// TODO: Change to validation error when present
40-
pub fn upfront_validate(&self) -> Result<(), crate::Error> {
39+
pub fn upfront_validate(&self) -> Result<(), server::Error> {
4140
// Invoke garde validation for various fields
4241
if let Err(e) = self.validate(&()) {
43-
return Err(crate::Error::ValidationError(e.to_string())); // TODO: update on presence of validation error
42+
return Err(server::Error::Validation(e.to_string()));
4443
};
4544
// Validate masks
4645
let input_range = 0..self.inputs.len();
@@ -52,8 +51,7 @@ impl GuardrailsHttpRequest {
5251
if !input_masks.iter().all(|(start, end)| {
5352
input_range.contains(start) && input_range.contains(end) && start < end
5453
}) {
55-
return Err(crate::Error::ValidationError("invalid masks".into()));
56-
// TODO: update on presence of validation error
54+
return Err(server::Error::Validation("invalid masks".into()));
5755
}
5856
}
5957
Ok(())
@@ -594,39 +592,6 @@ pub struct GeneratedTextStreamResult {
594592
pub input_tokens: Option<Vec<GeneratedToken>>,
595593
}
596594

597-
// TODO: The below errors follow FastAPI concepts esp. for loc
598-
// It may be worth revisiting if the orchestrator without FastAPI
599-
// should be using these error types
600-
601-
/// HTTP validation error
602-
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
603-
#[cfg_attr(feature = "conversion", derive(frunk::LabelledGeneric))]
604-
pub struct HttpValidationError {
605-
#[serde(rename = "detail")]
606-
#[serde(skip_serializing_if = "Option::is_none")]
607-
pub detail: Option<Vec<ValidationError>>,
608-
}
609-
610-
/// Validation error
611-
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
612-
#[cfg_attr(feature = "conversion", derive(frunk::LabelledGeneric))]
613-
pub struct ValidationError {
614-
#[serde(rename = "loc")]
615-
pub loc: Vec<LocationInner>,
616-
617-
/// Error message
618-
#[serde(rename = "msg")]
619-
pub msg: String,
620-
621-
/// Error type
622-
#[serde(rename = "type")]
623-
pub r#type: String,
624-
}
625-
626-
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
627-
#[cfg_attr(feature = "conversion", derive(frunk::LabelledGeneric))]
628-
pub struct LocationInner {}
629-
630595
impl From<ExponentialDecayLengthPenalty> for pb::fmaas::decoding_parameters::LengthPenalty {
631596
fn from(value: ExponentialDecayLengthPenalty) -> Self {
632597
Self {

src/server.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ async fn classification_with_gen(
8484
let request_id = Uuid::new_v4();
8585
// Upfront request validation
8686
if let Err(e) = request.upfront_validate() {
87-
return Err((StatusCode::BAD_REQUEST, Json(e.to_string())));
87+
return Err(e.into());
8888
};
8989
let task = ClassificationWithGenTask::new(request_id, request);
9090
match state

0 commit comments

Comments
 (0)