Skip to content

Commit 023f205

Browse files
authored
Merge pull request #1267 from dcSpark/fix/capabilities
Update GPT-5 capabilities / OpenRouter.
2 parents 35508fc + c70142a commit 023f205

File tree

2 files changed

+83
-49
lines changed

2 files changed

+83
-49
lines changed

shinkai-bin/shinkai-node/src/llm_provider/providers/openrouter.rs

Lines changed: 60 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -90,9 +90,19 @@ impl LLMService for OpenRouter {
9090
"max_tokens": result.remaining_output_tokens,
9191
});
9292

93-
// Conditionally add functions to the payload if tools_json is not empty
93+
// Conditionally add tools to the payload if tools_json is not empty
9494
if !tools_json.is_empty() {
95-
payload["functions"] = serde_json::Value::Array(tools_json.clone());
95+
// Remove tool_router_key from each tool before sending to OpenRouter
96+
let tools_payload = tools_json
97+
.clone()
98+
.into_iter()
99+
.map(|mut tool| {
100+
tool.as_object_mut().unwrap().remove("tool_router_key");
101+
tool
102+
})
103+
.collect::<Vec<JsonValue>>();
104+
105+
payload["tools"] = serde_json::Value::Array(tools_payload);
96106
}
97107

98108
// Add options to payload
@@ -252,38 +262,44 @@ async fn handle_streaming_response(
252262
if let Some(content) = message.get("content") {
253263
response_text.push_str(content.as_str().unwrap_or(""));
254264
}
255-
if let Some(fc) = message.get("function_call") {
256-
if let Some(name) = fc.get("name") {
257-
let fc_arguments = fc
258-
.get("arguments")
259-
.and_then(|args| args.as_str())
260-
.and_then(|args_str| serde_json::from_str(args_str).ok())
261-
.and_then(|args_value: serde_json::Value| {
262-
args_value.as_object().cloned()
263-
})
264-
.unwrap_or_else(|| serde_json::Map::new());
265-
266-
// Extract tool_router_key
267-
let tool_router_key = tools.as_ref().and_then(|tools_array| {
268-
tools_array.iter().find_map(|tool| {
269-
if tool.get("name")?.as_str()? == name.as_str().unwrap_or("") {
270-
tool.get("tool_router_key")
271-
.and_then(|key| key.as_str().map(|s| s.to_string()))
272-
} else {
273-
None
265+
if let Some(tool_calls) = message.get("tool_calls") {
266+
if let Some(tool_calls_array) = tool_calls.as_array() {
267+
for tool_call in tool_calls_array {
268+
if let Some(function) = tool_call.get("function") {
269+
if let Some(name) = function.get("name") {
270+
let fc_arguments = function
271+
.get("arguments")
272+
.and_then(|args| args.as_str())
273+
.and_then(|args_str| serde_json::from_str(args_str).ok())
274+
.and_then(|args_value: serde_json::Value| {
275+
args_value.as_object().cloned()
276+
})
277+
.unwrap_or_else(|| serde_json::Map::new());
278+
279+
// Extract tool_router_key
280+
let tool_router_key = tools.as_ref().and_then(|tools_array| {
281+
tools_array.iter().find_map(|tool| {
282+
if tool.get("name")?.as_str()? == name.as_str().unwrap_or("") {
283+
tool.get("tool_router_key")
284+
.and_then(|key| key.as_str().map(|s| s.to_string()))
285+
} else {
286+
None
287+
}
288+
})
289+
});
290+
291+
function_calls.push(FunctionCall {
292+
name: name.as_str().unwrap_or("").to_string(),
293+
arguments: fc_arguments.clone(),
294+
tool_router_key,
295+
response: None,
296+
index: function_calls.len() as u64,
297+
id: tool_call.get("id").and_then(|id| id.as_str()).map(|s| s.to_string()),
298+
call_type: tool_call.get("type").and_then(|t| t.as_str()).map(|s| s.to_string()).or(Some("function".to_string())),
299+
});
274300
}
275-
})
276-
});
277-
278-
function_calls.push(FunctionCall {
279-
name: name.as_str().unwrap_or("").to_string(),
280-
arguments: fc_arguments.clone(),
281-
tool_router_key,
282-
response: None,
283-
index: function_calls.len() as u64,
284-
id: None,
285-
call_type: Some("function".to_string()),
286-
});
301+
}
302+
}
287303
}
288304
}
289305
}
@@ -462,17 +478,17 @@ async fn handle_non_streaming_response(
462478
.collect::<Vec<String>>()
463479
.join(" ");
464480

465-
let function_call: Option<FunctionCall> = data.choices.iter().find_map(|choice| {
466-
choice.message.function_call.clone().map(|fc| {
467-
let arguments = serde_json::from_str::<serde_json::Value>(&fc.arguments)
481+
let function_calls: Vec<FunctionCall> = data.choices.iter().flat_map(|choice| {
482+
choice.message.tool_calls.as_ref().unwrap_or(&vec![]).iter().map(|tool_call| {
483+
let arguments = serde_json::from_str::<serde_json::Value>(&tool_call.function.arguments)
468484
.ok()
469485
.and_then(|args_value: serde_json::Value| args_value.as_object().cloned())
470486
.unwrap_or_else(|| serde_json::Map::new());
471487

472488
// Extract tool_router_key
473489
let tool_router_key = tools.as_ref().and_then(|tools_array| {
474490
tools_array.iter().find_map(|tool| {
475-
if tool.get("name")?.as_str()? == fc.name {
491+
if tool.get("name")?.as_str()? == tool_call.function.name {
476492
tool.get("tool_router_key").and_then(|key| key.as_str().map(|s| s.to_string()))
477493
} else {
478494
None
@@ -481,22 +497,22 @@ async fn handle_non_streaming_response(
481497
});
482498

483499
FunctionCall {
484-
name: fc.name,
500+
name: tool_call.function.name.clone(),
485501
arguments,
486502
tool_router_key,
487503
response: None,
488504
index: 0,
489-
id: None,
490-
call_type: Some("function".to_string()),
505+
id: Some(tool_call.id.clone()),
506+
call_type: Some(tool_call.call_type.clone()),
491507
}
492-
})
493-
});
494-
eprintln!("Function Call: {:?}", function_call);
508+
}).collect::<Vec<_>>()
509+
}).collect();
510+
eprintln!("Function Calls: {:?}", function_calls);
495511
eprintln!("Response String: {:?}", response_string);
496512
return Ok(LLMInferenceResponse::new(
497513
response_string,
498514
json!({}),
499-
function_call.map_or_else(Vec::new, |fc| vec![fc]),
515+
function_calls,
500516
None,
501517
));
502518
}

shinkai-bin/shinkai-node/src/managers/model_capabilities_manager.rs

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,9 @@ impl ModelCapabilitiesManager {
134134
match model {
135135
LLMProviderInterface::OpenAI(openai) => match openai.model_type.as_str() {
136136
"gpt-5" => vec![ModelCapability::ImageAnalysis, ModelCapability::TextInference],
137+
"gpt-5-mini" => vec![ModelCapability::ImageAnalysis, ModelCapability::TextInference],
138+
"gpt-5-nano" => vec![ModelCapability::ImageAnalysis, ModelCapability::TextInference],
139+
"gpt-5-chat-latest" => vec![ModelCapability::ImageAnalysis, ModelCapability::TextInference],
137140
"gpt-4o" => vec![ModelCapability::ImageAnalysis, ModelCapability::TextInference],
138141
"gpt-4o-mini" => vec![ModelCapability::ImageAnalysis, ModelCapability::TextInference],
139142
"gpt-4.1-nano" => vec![ModelCapability::ImageAnalysis, ModelCapability::TextInference],
@@ -387,7 +390,10 @@ impl ModelCapabilitiesManager {
387390
pub fn get_llm_provider_cost(model: &LLMProviderInterface) -> ModelCost {
388391
match model {
389392
LLMProviderInterface::OpenAI(openai) => match openai.model_type.as_str() {
390-
"gpt-5" => ModelCost::Expensive,
393+
"gpt-5" => ModelCost::GoodValue,
394+
"gpt-5-mini" => ModelCost::Cheap,
395+
"gpt-5-nano" => ModelCost::VeryCheap,
396+
"gpt-5-chat-latest" => ModelCost::GoodValue,
391397
"gpt-4o" => ModelCost::GoodValue,
392398
"gpt-3.5-turbo-1106" => ModelCost::VeryCheap,
393399
"gpt-4o-mini" => ModelCost::VeryCheap,
@@ -617,7 +623,7 @@ impl ModelCapabilitiesManager {
617623
} else if openai.model_type.starts_with("gpt-5") {
618624
400_000
619625
} else if openai.model_type.starts_with("gpt-3.5") {
620-
16384
626+
16_384
621627
} else {
622628
200_000 // New default for OpenAI models
623629
}
@@ -791,6 +797,12 @@ impl ModelCapabilitiesManager {
791797
65_536
792798
} else if openai.model_type.starts_with("o3") || openai.model_type.starts_with("o4-mini") {
793799
100_000
800+
} else if openai.model_type.starts_with("gpt-5-chat-latest") {
801+
16_384
802+
} else if openai.model_type.starts_with("gpt-5-mini") {
803+
128_000
804+
} else if openai.model_type.starts_with("gpt-5-nano") {
805+
128_000
794806
} else if openai.model_type.starts_with("gpt-5") {
795807
128_000
796808
} else if openai.model_type.starts_with("gpt-3.5") {
@@ -982,8 +994,14 @@ impl ModelCapabilitiesManager {
982994
eprintln!("has tool capabilities model: {:?}", model);
983995
match model {
984996
LLMProviderInterface::OpenAI(openai) => {
985-
// o1-mini specifically does not support function calling
986-
!openai.model_type.starts_with("o1-mini")
997+
// o1-mini and gpt-5-chat-latest specifically do not support function calling
998+
if openai.model_type.starts_with("o1-mini") {
999+
false
1000+
} else if openai.model_type.starts_with("gpt-5-chat-latest") {
1001+
false
1002+
} else {
1003+
true
1004+
}
9871005
}
9881006
LLMProviderInterface::Ollama(model) => {
9891007
// For Ollama, check model type and respect the passed stream parameter
@@ -1073,7 +1091,7 @@ impl ModelCapabilitiesManager {
10731091
|| openai.model_type.starts_with("o3")
10741092
|| openai.model_type.starts_with("o4")
10751093
|| openai.model_type.starts_with("o5")
1076-
|| openai.model_type.starts_with("gpt-5")
1094+
|| (openai.model_type.starts_with("gpt-5") && openai.model_type != "gpt-5-chat-latest")
10771095
}
10781096
LLMProviderInterface::Ollama(ollama) => {
10791097
ollama.model_type.starts_with("deepseek-r1")

0 commit comments

Comments
 (0)