Commit f644c1c

Merge pull request #1288 from dcSpark/feature/add-previously-used-tools
Add previously used tools to tool selection.
2 parents: 9679d0f + 54f36a9 · commit f644c1c

File tree: 8 files changed (+114, -677 lines)


shinkai-bin/shinkai-node/src/llm_provider/execution/chains/generic_chain/generic_inference_chain.rs

Lines changed: 84 additions & 0 deletions
@@ -14,6 +14,7 @@ use crate::network::agent_payments_manager::external_agent_offerings_manager::Ex
 use crate::network::agent_payments_manager::my_agent_offerings_manager::MyAgentOfferingsManager;
 use shinkai_fs::shinkai_file_manager::ShinkaiFileManager;
 use shinkai_message_primitives::schemas::tool_router_key::ToolRouterKey;
+use shinkai_message_primitives::shinkai_message::shinkai_message::ShinkaiMessage;
 
 use crate::utils::environment::{fetch_node_environment, NodeEnvironment};
 use async_trait::async_trait;
@@ -770,6 +771,44 @@ impl GenericInferenceChain {
                         )));
                     }
                 }
+
+                // Append tools from conversation history to provide contextual continuity
+                let historical_tool_keys = Self::extract_tools_from_conversation_history(&full_job.step_history, 5);
+                for tool_key in historical_tool_keys {
+                    // Check if this tool is not already in our tools list to avoid duplicates
+                    let tool_already_exists = tools.iter().any(|existing_tool| {
+                        existing_tool.tool_router_key().to_string_without_version() == tool_key
+                    });
+
+                    if !tool_already_exists {
+                        match tool_router.get_tool_by_name(&tool_key).await {
+                            Ok(Some(tool)) => {
+                                tools.push(tool);
+                                shinkai_log(
+                                    ShinkaiLogOption::JobExecution,
+                                    ShinkaiLogLevel::Debug,
+                                    &format!("Added historical tool from conversation: {}", tool_key),
+                                );
+                            }
+                            Ok(None) => {
+                                // Tool not found - this is okay, it might have been removed
+                                shinkai_log(
+                                    ShinkaiLogOption::JobExecution,
+                                    ShinkaiLogLevel::Debug,
+                                    &format!("Historical tool not found (may have been removed): {}", tool_key),
+                                );
+                            }
+                            Err(e) => {
+                                // Log error but don't fail the whole process
+                                shinkai_log(
+                                    ShinkaiLogOption::JobExecution,
+                                    ShinkaiLogLevel::Debug,
+                                    &format!("Error retrieving historical tool {}: {:?}", tool_key, e),
+                                );
+                            }
+                        }
+                    }
+                }
             }
         }
     }
@@ -1407,4 +1446,49 @@ impl GenericInferenceChain {
 
         Ok(additional_files)
     }
+
+    /// Extract tool router keys from conversation history
+    /// Returns a list of unique tool router keys that were used in the last N messages
+    fn extract_tools_from_conversation_history(
+        step_history: &[ShinkaiMessage],
+        max_messages: usize
+    ) -> Vec<String> {
+        let mut tool_keys = Vec::new();
+        let mut seen_keys = std::collections::HashSet::new();
+        let mut message_count = 0;
+
+        for msg in step_history.iter().rev() {
+            if message_count >= max_messages {
+                break;
+            }
+
+            if let Ok(content) = msg.get_message_content() {
+                if content.trim().is_empty() {
+                    message_count += 1;
+                    continue;
+                }
+
+                // Parse the JSON content to extract tool calls from metadata
+                if let Ok(job_message) = serde_json::from_str::<serde_json::Value>(&content) {
+                    if let Some(metadata) = job_message.get("metadata") {
+                        if let Some(function_calls) = metadata.get("function_calls").and_then(|fc| fc.as_array()) {
+                            for call in function_calls {
+                                if let Some(tool_router_key) = call.get("tool_router_key").and_then(|k| k.as_str()) {
+                                    // Only add if we haven't seen this key before
+                                    if !seen_keys.contains(tool_router_key) {
+                                        seen_keys.insert(tool_router_key.to_string());
+                                        tool_keys.push(tool_router_key.to_string());
+                                    }
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+
+            message_count += 1;
+        }
+
+        tool_keys
+    }
 }
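
The new helper assumes each step-history entry may serialize its tool calls as JSON under metadata.function_calls[].tool_router_key. Below is a minimal, self-contained sketch of that walk, not the repo's code: it uses plain JSON strings in place of ShinkaiMessage, and the key value local:::web_search is purely illustrative.

use std::collections::HashSet;

// Sketch of the extraction above, with plain strings standing in for
// step-history messages. Requires the serde_json crate.
fn extract_tool_keys(history: &[&str], max_messages: usize) -> Vec<String> {
    let mut tool_keys = Vec::new();
    let mut seen_keys = HashSet::new();

    // Walk the history newest-first, capped at `max_messages`; like the real
    // helper, non-JSON entries still count against the cap.
    for content in history.iter().rev().take(max_messages) {
        if let Ok(value) = serde_json::from_str::<serde_json::Value>(content) {
            if let Some(calls) = value
                .get("metadata")
                .and_then(|m| m.get("function_calls"))
                .and_then(|fc| fc.as_array())
            {
                for call in calls {
                    if let Some(key) = call.get("tool_router_key").and_then(|k| k.as_str()) {
                        // HashSet::insert returns false for duplicates, so each
                        // key is kept once, in most-recent-first order.
                        if seen_keys.insert(key.to_string()) {
                            tool_keys.push(key.to_string());
                        }
                    }
                }
            }
        }
    }
    tool_keys
}

fn main() {
    let history = [
        r#"{"content":"ok","metadata":{"function_calls":[{"tool_router_key":"local:::web_search"}]}}"#,
        "plain text reply",
        r#"{"content":"ok","metadata":{"function_calls":[{"tool_router_key":"local:::web_search"}]}}"#,
    ];
    // Prints ["local:::web_search"]: the plain-text entry is skipped and the
    // repeated call is deduplicated, mirroring the HashSet guard in the diff.
    println!("{:?}", extract_tool_keys(&history, 5));
}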

shinkai-bin/shinkai-node/src/llm_provider/providers/grok.rs

Lines changed: 3 additions & 3 deletions
@@ -3,7 +3,7 @@ use std::sync::Arc;
 
 use super::super::error::LLMProviderError;
 use super::shared::openai_api::openai_prepare_messages;
-use super::shared::openai_api_deprecated::{MessageContent, OpenAIResponse};
+use super::shared::openai_api::{MessageContent, OpenAIResponse};
 use super::shared::shared_model_logic::{send_tool_ws_update, send_ws_update};
 use super::LLMService;
 use crate::llm_provider::execution::chains::inference_chain_trait::{FunctionCall, LLMInferenceResponse};
@@ -228,12 +228,12 @@ async fn handle_streaming_response(
     }
 
     // If we got valid JSON but expected streaming, return the response anyway
-    if let Ok(data) = serde_json::from_value::<super::shared::openai_api_deprecated::OpenAIResponse>(response_json.clone()) {
+    if let Ok(data) = serde_json::from_value::<OpenAIResponse>(response_json.clone()) {
         let response_string: String = data
             .choices
             .iter()
             .filter_map(|choice| match &choice.message.content {
-                Some(super::shared::openai_api_deprecated::MessageContent::Text(text)) => Some(text.clone()),
+                Some(MessageContent::Text(text)) => Some(text.clone()),
                 _ => None,
             })
             .collect::<Vec<String>>()

shinkai-bin/shinkai-node/src/llm_provider/providers/groq.rs

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@ use std::error::Error;
 use std::sync::Arc;
 
 use super::super::error::LLMProviderError;
-use super::shared::openai_api_deprecated::{MessageContent, OpenAIResponse};
+use super::shared::openai_api::{MessageContent, OpenAIResponse};
 use super::shared::shared_model_logic::{send_tool_ws_update, send_ws_update};
 use super::LLMService;
 use crate::llm_provider::execution::chains::inference_chain_trait::{FunctionCall, LLMInferenceResponse};

shinkai-bin/shinkai-node/src/llm_provider/providers/shared/deepseek_api.rs

Lines changed: 13 additions & 1 deletion
@@ -3,9 +3,10 @@
 
 use crate::llm_provider::error::LLMProviderError;
 use crate::llm_provider::providers::shared::openai_api;
-use crate::managers::model_capabilities_manager::{PromptResult, PromptResultEnum};
+use crate::managers::model_capabilities_manager::{ModelCapabilitiesManager, PromptResult, PromptResultEnum};
 use shinkai_message_primitives::schemas::llm_providers::serialized_llm_provider::LLMProviderInterface;
 use shinkai_message_primitives::schemas::prompts::Prompt;
+use shinkai_message_primitives::schemas::subprompts::{SubPrompt, SubPromptType};
 
 // DeepSeek is compatible with the OpenAI API, so we reuse its message
 // preparation and response handling logic.
@@ -17,6 +18,17 @@ pub fn deepseek_prepare_messages(
     prompt: Prompt,
     session_id: String,
 ) -> Result<PromptResult, LLMProviderError> {
+    let mut prompt = prompt.clone();
+
+    // If this is a reasoning model, filter out system prompts before any processing
+    if ModelCapabilitiesManager::has_reasoning_capabilities(model) {
+        prompt.sub_prompts.retain(|sp| match sp {
+            SubPrompt::Content(SubPromptType::System, _, _) => false,
+            SubPrompt::Omni(SubPromptType::System, _, _, _) => false,
+            _ => true,
+        });
+    }
+
     let result = openai_api::openai_prepare_messages(model, prompt)?;
     let tools_json = result.functions.unwrap_or_else(Vec::new);
     let messages_json = result.messages.clone();
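
This moves the reasoning-model filter out of openai_prepare_messages (see openai_api.rs below) and into the DeepSeek path alone, so other OpenAI-compatible providers keep their system prompts. A reduced, runnable sketch of the retain logic, with stand-in enums since the real SubPrompt and SubPromptType types carry more variants and fields:

// Stand-in types; the real enums live in shinkai_message_primitives.
#[derive(Debug)]
enum SubPromptType {
    System,
    User,
}

#[derive(Debug)]
enum SubPrompt {
    Content(SubPromptType, String),
}

fn strip_system_prompts(sub_prompts: &mut Vec<SubPrompt>) {
    // Same retain pattern as the diff: keep everything except system-role entries.
    sub_prompts.retain(|sp| match sp {
        SubPrompt::Content(SubPromptType::System, _) => false,
        _ => true,
    });
}

fn main() {
    let mut prompts = vec![
        SubPrompt::Content(SubPromptType::System, "You are a helpful assistant.".into()),
        SubPrompt::Content(SubPromptType::User, "Summarize this file.".into()),
    ];
    strip_system_prompts(&mut prompts);
    // Only the user sub-prompt survives the filter.
    println!("{:?}", prompts);
}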

shinkai-bin/shinkai-node/src/llm_provider/providers/shared/mod.rs

Lines changed: 0 additions & 1 deletion
@@ -4,6 +4,5 @@ pub mod gemini_api;
 pub mod groq_api;
 pub mod ollama_api;
 pub mod openai_api;
-pub mod openai_api_deprecated;
 pub mod shared_model_logic;
 pub mod togetherai;
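
Note that only six of the eight changed files are expanded in this view; the bulk of the -677 deleted lines most likely comes from removing the openai_api_deprecated.rs file itself, which this mod.rs change unregisters.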

shinkai-bin/shinkai-node/src/llm_provider/providers/shared/openai_api.rs

Lines changed: 2 additions & 12 deletions
@@ -110,17 +110,6 @@ pub struct Usage {
 }
 
 pub fn openai_prepare_messages(model: &LLMProviderInterface, prompt: Prompt) -> Result<PromptResult, LLMProviderError> {
-    let mut prompt = prompt.clone();
-
-    // If this is a reasoning model, filter out system prompts before any processing
-    if ModelCapabilitiesManager::has_reasoning_capabilities(model) {
-        prompt.sub_prompts.retain(|sp| match sp {
-            SubPrompt::Content(SubPromptType::System, _, _) => false,
-            SubPrompt::Omni(SubPromptType::System, _, _, _) => false,
-            _ => true,
-        });
-    }
-
     let max_input_tokens = ModelCapabilitiesManager::get_max_input_tokens(model);
 
     // Generate the messages and filter out images
@@ -645,7 +634,8 @@ mod tests {
         let model = SerializedLLMProvider::mock_provider_with_reasoning().model;
 
         // Process the prompt
-        let result = openai_prepare_messages(&model, prompt).expect("Failed to prepare messages");
+        let session_id = uuid::Uuid::new_v4().to_string();
+        let result = crate::llm_provider::providers::shared::deepseek_api::deepseek_prepare_messages(&model, prompt, session_id).expect("Failed to prepare messages");
 
         // Extract the messages from the result
         let messages = match &result.messages {
