Skip to content

Commit 5d0ff08

Browse files
authored
Merge pull request #1265 from dcSpark/feature/update-ollama-provider
Add support for Ollama thinking flag, update gpt-oss model.
2 parents b050645 + 10cab65 commit 5d0ff08

File tree

7 files changed

+239
-10
lines changed

7 files changed

+239
-10
lines changed

shinkai-bin/shinkai-node/src/llm_provider/providers/ollama.rs

Lines changed: 183 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,9 @@ async fn process_stream(
243243
tools: Option<Vec<JsonValue>>,
244244
) -> Result<LLMInferenceResponse, LLMProviderError> {
245245
let mut response_text = String::new();
246+
let mut thinking_content = String::new();
247+
let mut thinking_started = false;
248+
let mut thinking_ended = false;
246249
let mut previous_json_chunk: String = String::new();
247250
let mut final_eval_count = None;
248251
let mut final_eval_duration = None;
@@ -315,7 +318,106 @@ async fn process_stream(
315318
match data_resp {
316319
Ok(data) => {
317320
previous_json_chunk = "".to_string();
318-
response_text.push_str(&data.message.content);
321+
322+
// Handle thinking tokens
323+
if let Some(thinking) = &data.message.thinking {
324+
if !thinking.is_empty() {
325+
if !thinking_started {
326+
thinking_started = true;
327+
// Send opening <think> tag immediately via WebSocket
328+
if let Some(ref manager) = ws_manager_trait {
329+
if let Some(ref inbox_name) = inbox_name {
330+
let m = manager.lock().await;
331+
let inbox_name_string = inbox_name.to_string();
332+
let metadata = WSMetadata {
333+
id: Some(session_id.clone()),
334+
is_done: false,
335+
done_reason: None,
336+
total_duration: None,
337+
eval_count: None,
338+
};
339+
let ws_message_type = WSMessageType::Metadata(metadata);
340+
let _ = m
341+
.queue_message(
342+
WSTopic::Inbox,
343+
inbox_name_string,
344+
"<think>".to_string(),
345+
ws_message_type,
346+
true,
347+
)
348+
.await;
349+
}
350+
}
351+
// Add to response text for final accumulation
352+
response_text.push_str("<think>");
353+
}
354+
355+
// Stream thinking content immediately via WebSocket
356+
if let Some(ref manager) = ws_manager_trait {
357+
if let Some(ref inbox_name) = inbox_name {
358+
let m = manager.lock().await;
359+
let inbox_name_string = inbox_name.to_string();
360+
let metadata = WSMetadata {
361+
id: Some(session_id.clone()),
362+
is_done: false,
363+
done_reason: None,
364+
total_duration: None,
365+
eval_count: None,
366+
};
367+
let ws_message_type = WSMessageType::Metadata(metadata);
368+
let _ = m
369+
.queue_message(
370+
WSTopic::Inbox,
371+
inbox_name_string,
372+
thinking.clone(),
373+
ws_message_type,
374+
true,
375+
)
376+
.await;
377+
}
378+
}
379+
380+
// Also accumulate for final response
381+
thinking_content.push_str(thinking);
382+
response_text.push_str(thinking);
383+
}
384+
}
385+
386+
// Handle regular content tokens
387+
if !data.message.content.is_empty() {
388+
// If we were processing thinking and now we have content,
389+
// close the thinking tags
390+
if thinking_started && !thinking_ended {
391+
thinking_ended = true;
392+
// Send closing </think> tag via WebSocket
393+
if let Some(ref manager) = ws_manager_trait {
394+
if let Some(ref inbox_name) = inbox_name {
395+
let m = manager.lock().await;
396+
let inbox_name_string = inbox_name.to_string();
397+
let metadata = WSMetadata {
398+
id: Some(session_id.clone()),
399+
is_done: false,
400+
done_reason: None,
401+
total_duration: None,
402+
eval_count: None,
403+
};
404+
let ws_message_type = WSMessageType::Metadata(metadata);
405+
let _ = m
406+
.queue_message(
407+
WSTopic::Inbox,
408+
inbox_name_string,
409+
"</think>".to_string(),
410+
ws_message_type,
411+
true,
412+
)
413+
.await;
414+
}
415+
}
416+
// Add to response text for final accumulation
417+
response_text.push_str("</think>");
418+
}
419+
response_text.push_str(&data.message.content);
420+
}
319421

320422
if let Some(tool_calls) = data.message.tool_calls {
321423
for tool_call in tool_calls {
@@ -461,6 +563,36 @@ async fn process_stream(
461563
}
462564
}
463565

566+
// If we ended with thinking content but no regular content, send closing tag
567+
if thinking_started && !thinking_ended && !thinking_content.is_empty() {
568+
// Send closing </think> tag via WebSocket
569+
if let Some(ref manager) = ws_manager_trait {
570+
if let Some(ref inbox_name) = inbox_name {
571+
let m = manager.lock().await;
572+
let inbox_name_string = inbox_name.to_string();
573+
let metadata = WSMetadata {
574+
id: Some(session_id.clone()),
575+
is_done: true,
576+
done_reason: None,
577+
total_duration: None,
578+
eval_count: None,
579+
};
580+
let ws_message_type = WSMessageType::Metadata(metadata);
581+
let _ = m
582+
.queue_message(
583+
WSTopic::Inbox,
584+
inbox_name_string,
585+
"</think>".to_string(),
586+
ws_message_type,
587+
true,
588+
)
589+
.await;
590+
}
591+
}
592+
// Add to response text for final accumulation
593+
response_text.push_str("</think>");
594+
}
595+
464596
let tps = if let (Some(eval_count), Some(eval_duration)) = (final_eval_count, final_eval_duration) {
465597
if eval_duration > 0 {
466598
Some(eval_count as f64 / eval_duration as f64 * 1e9)
@@ -530,6 +662,21 @@ async fn handle_non_streaming_response(
530662
if let Some(message) = response_json.get("message") {
531663
if let Some(content) = message.get("content") {
532664
if let Some(content_str) = content.as_str() {
665+
// Handle thinking content in non-streaming response
666+
let mut final_content = String::new();
667+
668+
// Check for thinking content and prepend it with tags
669+
if let Some(thinking) = message.get("thinking").and_then(|t| t.as_str()) {
670+
if !thinking.is_empty() {
671+
final_content.push_str("<think>");
672+
final_content.push_str(thinking);
673+
final_content.push_str("</think>");
674+
}
675+
}
676+
677+
// Add regular content
678+
final_content.push_str(content_str);
679+
533680
let mut function_calls = Vec::new();
534681

535682
if let Some(tool_calls) = message.get("tool_calls").and_then(|tc| tc.as_array()) {
@@ -606,6 +753,31 @@ async fn handle_non_streaming_response(
606753
format!("Function Calls: {:?}", function_calls).as_str(),
607754
);
608755

756+
// Send the final content (including thinking) via WebSocket in non-streaming mode
757+
if let Some(ref manager) = ws_manager_trait {
758+
if let Some(ref inbox_name) = inbox_name {
759+
let m = manager.lock().await;
760+
let inbox_name_string = inbox_name.to_string();
761+
let metadata = WSMetadata {
762+
id: None,
763+
is_done: true,
764+
done_reason: None,
765+
total_duration: None,
766+
eval_count: None,
767+
};
768+
let ws_message_type = WSMessageType::Metadata(metadata);
769+
let _ = m
770+
.queue_message(
771+
WSTopic::Inbox,
772+
inbox_name_string,
773+
final_content.clone(),
774+
ws_message_type,
775+
true,
776+
)
777+
.await;
778+
}
779+
}
780+
609781
let eval_count = response_json.get("eval_count").and_then(|v| v.as_u64()).unwrap_or(0);
610782
let eval_duration = response_json.get("eval_duration").and_then(|v| v.as_u64()).unwrap_or(1);
611783
let tps = if eval_duration > 0 {
@@ -615,7 +787,7 @@ async fn handle_non_streaming_response(
615787
};
616788

617789
break Ok(LLMInferenceResponse::new(
618-
content_str.to_string(),
790+
final_content,
619791
json!({}),
620792
function_calls,
621793
tps,
@@ -681,6 +853,15 @@ fn add_options_to_payload(
681853
let streaming = get_value("LLM_STREAMING", config.and_then(|c| c.stream.as_ref())).unwrap_or(true); // Default to true if not specified
682854
payload["stream"] = serde_json::json!(streaming);
683855

856+
// Handle thinking option (there are open issues with this feature)
857+
// https://github.com/ollama/ollama/issues/11712
858+
// https://github.com/ollama/ollama/issues/11751
859+
// https://github.com/ollama/ollama/issues/10976
860+
if ModelCapabilitiesManager::has_reasoning_capabilities(model) {
861+
let thinking = get_value("LLM_THINKING", config.and_then(|c| c.thinking.as_ref())).unwrap_or(true);
862+
payload["think"] = serde_json::json!(thinking);
863+
}
864+
684865
// Handle num_ctx setting
685866
let num_ctx_from_config = config
686867
.and_then(|c| c.other_model_params.as_ref())

shinkai-bin/shinkai-node/src/llm_provider/providers/openai.rs

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,8 @@ impl LLMService for OpenAI {
9797
let tools_json = result.functions.unwrap_or_else(Vec::new);
9898

9999
// Set up initial payload with appropriate token limit field based on model capabilities
100-
let mut payload = if ModelCapabilitiesManager::has_reasoning_capabilities(&model) {
100+
let mut payload = if ModelCapabilitiesManager::has_reasoning_capabilities(&model)
101+
{
101102
json!({
102103
"model": self.model_type,
103104
"messages": messages_json,
@@ -118,8 +119,19 @@ impl LLMService for OpenAI {
118119
payload["tools"] = serde_json::Value::Array(tools_json.clone());
119120
}
120121

121-
// Only add options to payload for non-reasoning models
122-
if !ModelCapabilitiesManager::has_reasoning_capabilities(&model) {
122+
// Only add options to payload for non-reasoning models, add reasoning_effort if thinking is enabled and the model has reasoning capabilities
123+
if ModelCapabilitiesManager::has_reasoning_capabilities(&model) {
124+
let thinking_enabled = config.as_ref().and_then(|c| c.thinking).unwrap_or(false);
125+
if thinking_enabled {
126+
let effort = config
127+
.as_ref()
128+
.and_then(|c| c.reasoning_effort.clone())
129+
.unwrap_or("medium".to_string());
130+
payload["reasoning_effort"] = serde_json::json!(effort);
131+
} else if let Some(obj) = payload.as_object_mut() {
132+
obj.remove("reasoning_effort");
133+
}
134+
} else {
123135
add_options_to_payload(&mut payload, config.as_ref());
124136
}
125137

shinkai-bin/shinkai-node/src/llm_provider/providers/shared/ollama_api.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ pub struct OllamaMessage {
4646
pub images: Option<Vec<String>>,
4747
#[serde(skip_serializing_if = "Option::is_none")]
4848
pub tool_calls: Option<Vec<ToolCall>>,
49+
#[serde(skip_serializing_if = "Option::is_none")]
50+
pub thinking: Option<String>,
4951
}
5052

5153
#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)]
@@ -131,6 +133,7 @@ fn from_chat_completion_messages(
131133
content,
132134
images,
133135
tool_calls: None,
136+
thinking: None,
134137
});
135138
}
136139
}
@@ -358,12 +361,14 @@ mod tests {
358361
content: "You are a very helpful assistant. You may be provided with documents or content to analyze and answer questions about them, in that case refer to the content provided in the user message for your responses.".to_string(),
359362
images: None,
360363
tool_calls: None,
364+
thinking: None,
361365
},
362366
OllamaMessage {
363367
role: "user".to_string(),
364368
content: "tell me what's the response when using shinkai echo tool with: say hello".to_string(),
365369
images: Some(vec![]),
366370
tool_calls: None,
371+
thinking: None,
367372
},
368373
];
369374

0 commit comments

Comments (0)