@@ -156,6 +156,7 @@ impl LLMService for OpenAI {
                 llm_stopper,
                 session_id,
                 Some(tools_json),
+                None,
             )
             .await
         } else {
@@ -168,6 +169,7 @@ impl LLMService for OpenAI {
                 llm_stopper,
                 ws_manager_trait,
                 Some(tools_json),
+                None,
             )
             .await
         }
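
Both OpenAI call sites above simply pass `None` for the new argument. For a caller that does want to forward identity metadata, a minimal sketch of assembling the `headers` value might look like this; the builder function name is hypothetical, and the lowercase keys are the ones the handlers look up further down in this diff.

```rust
use serde_json::{json, Value as JsonValue};

// Hypothetical helper, not part of this diff: builds the optional `headers` JSON
// that handle_streaming_response / handle_non_streaming_response now accept.
fn build_shinkai_headers(
    version: &str,
    identity: &str,
    signature: &str,
    metadata: &str,
) -> Option<JsonValue> {
    // Lowercase keys match the `h.get("x-shinkai-...")` lookups in the handlers.
    Some(json!({
        "x-shinkai-version": version,
        "x-shinkai-identity": identity,
        "x-shinkai-signature": signature,
        "x-shinkai-metadata": metadata,
    }))
}
```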
@@ -520,15 +522,42 @@ pub async fn handle_streaming_response(
     llm_stopper: Arc<LLMStopper>,
     session_id: String,
     tools: Option<Vec<JsonValue>>,
+    headers: Option<JsonValue>,
 ) -> Result<LLMInferenceResponse, LLMProviderError> {
     let res = client
         .post(url)
         .bearer_auth(api_key)
         .header("Content-Type", "application/json")
+        .header("X-Shinkai-Version", headers.as_ref().and_then(|h| h.get("x-shinkai-version")).and_then(|v| v.as_str()).unwrap_or(""))
+        .header("X-Shinkai-Identity", headers.as_ref().and_then(|h| h.get("x-shinkai-identity")).and_then(|v| v.as_str()).unwrap_or(""))
+        .header("X-Shinkai-Signature", headers.as_ref().and_then(|h| h.get("x-shinkai-signature")).and_then(|v| v.as_str()).unwrap_or(""))
+        .header("X-Shinkai-Metadata", headers.as_ref().and_then(|h| h.get("x-shinkai-metadata")).and_then(|v| v.as_str()).unwrap_or(""))
         .json(&payload)
         .send()
         .await?;
 
+    // Check for 429 status code
+    if res.status() == 429 {
+        let error_text = res.text().await?;
+        if let Ok(error_json) = serde_json::from_str::<JsonValue>(&error_text) {
+            if let Some(code) = error_json.get("code").and_then(|c| c.as_str()) {
+                if code == "QUOTA_EXCEEDED" &&
+                    payload.get("model").and_then(|m| m.as_str()).map_or(false, |model| {
+                        model == "FREE_TEXT_INFERENCE" ||
+                        model == "STANDARD_TEXT_INFERENCE" ||
+                        model == "PREMIUM_TEXT_INFERENCE"
+                    }) {
+                    let error_msg = error_json.get("error")
+                        .and_then(|e| e.as_str())
+                        .unwrap_or("Daily quota exceeded")
+                        .to_string();
+                    return Err(LLMProviderError::LLMServiceInferenceLimitReached(error_msg));
+                }
+            }
+        }
+        return Err(LLMProviderError::LLMServiceUnexpectedError("Rate limit exceeded".to_string()));
+    }
+
     let mut stream = res.bytes_stream();
     let mut response_text = String::new();
     let mut buffer = String::new();
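
The four header lookups repeat the same `as_ref` / `get` / `as_str` / `unwrap_or("")` chain. A small helper along these lines could express it once; the name `shinkai_header` is illustrative, and the diff keeps the logic inlined instead.

```rust
use serde_json::Value as JsonValue;

// Illustrative only: returns the named value from the optional headers JSON,
// falling back to an empty string exactly like the inlined chain in the diff.
fn shinkai_header<'a>(headers: Option<&'a JsonValue>, key: &str) -> &'a str {
    headers
        .and_then(|h| h.get(key))
        .and_then(|v| v.as_str())
        .unwrap_or("")
}

// Possible usage:
// .header("X-Shinkai-Version", shinkai_header(headers.as_ref(), "x-shinkai-version"))
```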
@@ -660,12 +689,17 @@ pub async fn handle_non_streaming_response(
     llm_stopper: Arc<LLMStopper>,
     ws_manager_trait: Option<Arc<Mutex<dyn WSUpdateHandler + Send>>>,
     tools: Option<Vec<JsonValue>>,
+    headers: Option<JsonValue>,
 ) -> Result<LLMInferenceResponse, LLMProviderError> {
     let mut interval = tokio::time::interval(tokio::time::Duration::from_millis(500));
     let response_fut = client
         .post(url)
         .bearer_auth(api_key)
         .header("Content-Type", "application/json")
+        .header("X-Shinkai-Version", headers.as_ref().and_then(|h| h.get("x-shinkai-version")).and_then(|v| v.as_str()).unwrap_or(""))
+        .header("X-Shinkai-Identity", headers.as_ref().and_then(|h| h.get("x-shinkai-identity")).and_then(|v| v.as_str()).unwrap_or(""))
+        .header("X-Shinkai-Signature", headers.as_ref().and_then(|h| h.get("x-shinkai-signature")).and_then(|v| v.as_str()).unwrap_or(""))
+        .header("X-Shinkai-Metadata", headers.as_ref().and_then(|h| h.get("x-shinkai-metadata")).and_then(|v| v.as_str()).unwrap_or(""))
         .json(&payload)
         .send();
     let mut response_fut = Box::pin(response_fut);
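
For context, the next hunk sits inside a `tokio::select!` loop that races the pinned request future against a 500 ms interval. A standalone sketch of that pattern, with illustrative names only, looks like this:

```rust
use std::time::Duration;

// Minimal sketch of the surrounding control flow: pin the request future and
// poll it alongside a periodic interval until it resolves.
#[tokio::main]
async fn main() {
    let mut interval = tokio::time::interval(Duration::from_millis(500));
    let response_fut = async { 200u16 }; // stand-in for client.post(url)...send()
    let mut response_fut = Box::pin(response_fut);

    loop {
        tokio::select! {
            _ = interval.tick() => {
                // periodic work, e.g. checking a stop flag
            }
            status = &mut response_fut => {
                println!("response arrived with status {}", status);
                break;
            }
        }
    }
}
```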
@@ -688,6 +722,29 @@ pub async fn handle_non_streaming_response(
             },
             response = &mut response_fut => {
                 let res = response?;
+
+                // Check for 429 status code
+                if res.status() == 429 {
+                    let error_text = res.text().await?;
+                    if let Ok(error_json) = serde_json::from_str::<JsonValue>(&error_text) {
+                        if let Some(code) = error_json.get("code").and_then(|c| c.as_str()) {
+                            if code == "QUOTA_EXCEEDED" &&
+                                payload.get("model").and_then(|m| m.as_str()).map_or(false, |model| {
+                                    model == "FREE_TEXT_INFERENCE" ||
+                                    model == "STANDARD_TEXT_INFERENCE" ||
+                                    model == "PREMIUM_TEXT_INFERENCE"
+                                }) {
+                                let error_msg = error_json.get("error")
+                                    .and_then(|e| e.as_str())
+                                    .unwrap_or("Daily quota exceeded")
+                                    .to_string();
+                                return Err(LLMProviderError::LLMServiceInferenceLimitReached(error_msg));
+                            }
+                        }
+                    }
+                    return Err(LLMProviderError::LLMServiceUnexpectedError("Rate limit exceeded".to_string()));
+                }
+
                 let response_text = res.text().await?;
                 eprintln!("Raw server response: {}", response_text);
                 let data_resp: Result<JsonValue, _> = serde_json::from_str(&response_text);
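
Both 429 branches key off the same error shape: a `code` of `QUOTA_EXCEEDED` plus a payload model matching one of the Shinkai text-inference tiers. A self-contained illustration of that check against a sample body shaped like what the handler expects (the predicate function and the sample JSON are hypothetical; the field and tier names come from the diff):

```rust
use serde_json::Value as JsonValue;

// Hypothetical predicate mirroring the 429 handling above: true only when the
// body carries code == "QUOTA_EXCEEDED" and the model is one of the checked tiers.
fn is_quota_exceeded(status: u16, body: &str, model: &str) -> bool {
    let tiers = ["FREE_TEXT_INFERENCE", "STANDARD_TEXT_INFERENCE", "PREMIUM_TEXT_INFERENCE"];
    status == 429
        && tiers.contains(&model)
        && serde_json::from_str::<JsonValue>(body)
            .ok()
            .and_then(|j| j.get("code").and_then(|c| c.as_str()).map(|c| c == "QUOTA_EXCEEDED"))
            .unwrap_or(false)
}

fn main() {
    let body = r#"{"code":"QUOTA_EXCEEDED","error":"Daily quota exceeded"}"#;
    assert!(is_quota_exceeded(429, body, "FREE_TEXT_INFERENCE"));
    assert!(!is_quota_exceeded(429, body, "gpt-4o")); // other models fall through to the generic error
}
```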