Skip to content

Commit a70814d

Browse files
authored
Merge pull request #879 from dcSpark/feature/shinkai-llm-provider
2 parents 2eca44f + fcbe9e5 commit a70814d

File tree

4 files changed

+165
-170
lines changed

4 files changed

+165
-170
lines changed

shinkai-bin/shinkai-node/src/llm_provider/providers/openai.rs

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,7 @@ impl LLMService for OpenAI {
156156
llm_stopper,
157157
session_id,
158158
Some(tools_json),
159+
None,
159160
)
160161
.await
161162
} else {
@@ -168,6 +169,7 @@ impl LLMService for OpenAI {
168169
llm_stopper,
169170
ws_manager_trait,
170171
Some(tools_json),
172+
None,
171173
)
172174
.await
173175
}
@@ -520,15 +522,42 @@ pub async fn handle_streaming_response(
520522
llm_stopper: Arc<LLMStopper>,
521523
session_id: String,
522524
tools: Option<Vec<JsonValue>>,
525+
headers: Option<JsonValue>,
523526
) -> Result<LLMInferenceResponse, LLMProviderError> {
524527
let res = client
525528
.post(url)
526529
.bearer_auth(api_key)
527530
.header("Content-Type", "application/json")
531+
.header("X-Shinkai-Version", headers.as_ref().and_then(|h| h.get("x-shinkai-version")).and_then(|v| v.as_str()).unwrap_or(""))
532+
.header("X-Shinkai-Identity", headers.as_ref().and_then(|h| h.get("x-shinkai-identity")).and_then(|v| v.as_str()).unwrap_or(""))
533+
.header("X-Shinkai-Signature", headers.as_ref().and_then(|h| h.get("x-shinkai-signature")).and_then(|v| v.as_str()).unwrap_or(""))
534+
.header("X-Shinkai-Metadata", headers.as_ref().and_then(|h| h.get("x-shinkai-metadata")).and_then(|v| v.as_str()).unwrap_or(""))
528535
.json(&payload)
529536
.send()
530537
.await?;
531538

539+
// Check for 429 status code
540+
if res.status() == 429 {
541+
let error_text = res.text().await?;
542+
if let Ok(error_json) = serde_json::from_str::<JsonValue>(&error_text) {
543+
if let Some(code) = error_json.get("code").and_then(|c| c.as_str()) {
544+
if code == "QUOTA_EXCEEDED" &&
545+
payload.get("model").and_then(|m| m.as_str()).map_or(false, |model| {
546+
model == "FREE_TEXT_INFERENCE" ||
547+
model == "STANDARD_TEXT_INFERENCE" ||
548+
model == "PREMIUM_TEXT_INFERENCE"
549+
}) {
550+
let error_msg = error_json.get("error")
551+
.and_then(|e| e.as_str())
552+
.unwrap_or("Daily quota exceeded")
553+
.to_string();
554+
return Err(LLMProviderError::LLMServiceInferenceLimitReached(error_msg));
555+
}
556+
}
557+
}
558+
return Err(LLMProviderError::LLMServiceUnexpectedError("Rate limit exceeded".to_string()));
559+
}
560+
532561
let mut stream = res.bytes_stream();
533562
let mut response_text = String::new();
534563
let mut buffer = String::new();
@@ -660,12 +689,17 @@ pub async fn handle_non_streaming_response(
660689
llm_stopper: Arc<LLMStopper>,
661690
ws_manager_trait: Option<Arc<Mutex<dyn WSUpdateHandler + Send>>>,
662691
tools: Option<Vec<JsonValue>>,
692+
headers: Option<JsonValue>,
663693
) -> Result<LLMInferenceResponse, LLMProviderError> {
664694
let mut interval = tokio::time::interval(tokio::time::Duration::from_millis(500));
665695
let response_fut = client
666696
.post(url)
667697
.bearer_auth(api_key)
668698
.header("Content-Type", "application/json")
699+
.header("X-Shinkai-Version", headers.as_ref().and_then(|h| h.get("x-shinkai-version")).and_then(|v| v.as_str()).unwrap_or(""))
700+
.header("X-Shinkai-Identity", headers.as_ref().and_then(|h| h.get("x-shinkai-identity")).and_then(|v| v.as_str()).unwrap_or(""))
701+
.header("X-Shinkai-Signature", headers.as_ref().and_then(|h| h.get("x-shinkai-signature")).and_then(|v| v.as_str()).unwrap_or(""))
702+
.header("X-Shinkai-Metadata", headers.as_ref().and_then(|h| h.get("x-shinkai-metadata")).and_then(|v| v.as_str()).unwrap_or(""))
669703
.json(&payload)
670704
.send();
671705
let mut response_fut = Box::pin(response_fut);
@@ -688,6 +722,29 @@ pub async fn handle_non_streaming_response(
688722
},
689723
response = &mut response_fut => {
690724
let res = response?;
725+
726+
// Check for 429 status code
727+
if res.status() == 429 {
728+
let error_text = res.text().await?;
729+
if let Ok(error_json) = serde_json::from_str::<JsonValue>(&error_text) {
730+
if let Some(code) = error_json.get("code").and_then(|c| c.as_str()) {
731+
if code == "QUOTA_EXCEEDED" &&
732+
payload.get("model").and_then(|m| m.as_str()).map_or(false, |model| {
733+
model == "FREE_TEXT_INFERENCE" ||
734+
model == "STANDARD_TEXT_INFERENCE" ||
735+
model == "PREMIUM_TEXT_INFERENCE"
736+
}) {
737+
let error_msg = error_json.get("error")
738+
.and_then(|e| e.as_str())
739+
.unwrap_or("Daily quota exceeded")
740+
.to_string();
741+
return Err(LLMProviderError::LLMServiceInferenceLimitReached(error_msg));
742+
}
743+
}
744+
}
745+
return Err(LLMProviderError::LLMServiceUnexpectedError("Rate limit exceeded".to_string()));
746+
}
747+
691748
let response_text = res.text().await?;
692749
eprintln!("Raw server response: {}", response_text);
693750
let data_resp: Result<JsonValue, _> = serde_json::from_str(&response_text);

0 commit comments

Comments
 (0)