Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -821,32 +821,42 @@ impl GenericInferenceChain {
.and_then(|config| config.web_search_enabled)
.unwrap_or(false);
if tools_allowed && web_search_enabled {
// Check if web search tool is not already in the tools list
// Check if web search related tools are not already in the tools list and add them
let web_search_tool_key = "local:::__official_shinkai:::web_search";
let has_web_search = tools
.iter()
.any(|tool| tool.tool_router_key().to_string_without_version() == web_search_tool_key);

if !has_web_search {
// Add the web search tool
if let Some(tool_router) = &tool_router {
match tool_router.get_tool_by_name(web_search_tool_key).await {
Ok(Some(web_search_tool)) => {
tools.push(web_search_tool);
}
Ok(None) => {
shinkai_log(
ShinkaiLogOption::JobExecution,
ShinkaiLogLevel::Error,
&format!("Web search tool not found: {}", web_search_tool_key),
);
}
Err(e) => {
shinkai_log(
ShinkaiLogOption::JobExecution,
ShinkaiLogLevel::Error,
&format!("Error retrieving web search tool: {:?}", e),
);
let download_pages_tool_key = "local:::__official_shinkai:::download_pages";
let pdf_text_extractor_tool_key = "local:::__official_shinkai:::pdf_text_extractor";

let web_search_tools = vec![
(web_search_tool_key, "Web search tool"),
(download_pages_tool_key, "Download pages tool"),
(pdf_text_extractor_tool_key, "PDF text extractor tool"),
];

for (tool_key, tool_description) in web_search_tools {
let has_tool = tools
.iter()
.any(|tool| tool.tool_router_key().to_string_without_version() == tool_key);

if !has_tool {
if let Some(tool_router) = &tool_router {
match tool_router.get_tool_by_name(tool_key).await {
Ok(Some(tool)) => {
tools.push(tool);
}
Ok(None) => {
shinkai_log(
ShinkaiLogOption::JobExecution,
ShinkaiLogLevel::Error,
&format!("{} not found: {}", tool_description, tool_key),
);
}
Err(e) => {
shinkai_log(
ShinkaiLogOption::JobExecution,
ShinkaiLogLevel::Error,
&format!("Error retrieving {}: {:?}", tool_description, e),
);
}
}
}
}
Expand Down
72 changes: 53 additions & 19 deletions shinkai-libs/shinkai-mcp/src/mcp_methods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ pub mod tests_mcp_manager {
let params_map = params.as_object().unwrap().clone();

let result = run_tool_via_command(
"npx -y @modelcontextprotocol/server-everything".to_string(),
"npx -y @modelcontextprotocol/server-everything@2025.9.12".to_string(),
"add".to_string(),
HashMap::new(),
params_map,
Expand All @@ -295,7 +295,7 @@ pub mod tests_mcp_manager {
"npx".to_string(),
Some(vec![
"-y".to_string(),
"@modelcontextprotocol/server-everything".to_string(),
"@modelcontextprotocol/server-everything@2025.9.12".to_string(),
"sse".to_string(),
]) as Option<Vec<String>>,
Some(envs),
Expand Down Expand Up @@ -335,25 +335,31 @@ pub mod tests_mcp_manager {

#[tokio::test]
async fn test_list_tools_via_command() {
let result = list_tools_via_command("npx -y @modelcontextprotocol/server-everything", None).await;
let result = list_tools_via_command("npx -y @modelcontextprotocol/server-everything@2025.9.12", None).await;
assert!(result.is_ok());
let unwrapped = result.unwrap();
assert!(unwrapped.len() == 11);
let tools = [

// Debug output to see actual tools
println!("Actual number of tools: {}", unwrapped.len());
println!("Actual tools: {:?}", unwrapped.iter().map(|t| &t.name).collect::<Vec<_>>());

// The MCP server-everything package now returns 10 tools
assert_eq!(unwrapped.len(), 10, "Expected exactly 10 tools, got {}", unwrapped.len());

let expected_tools = [
"echo",
"add",
"add",
"longRunningOperation",
"printEnv",
"sampleLLM",
"getTinyImage",
"annotatedMessage",
"getResourceReference",
"startElicitation",
"getResourceLinks",
"structuredContent",
];
for tool in tools {
assert!(unwrapped.iter().any(|t| t.name == tool));
for tool in expected_tools {
assert!(unwrapped.iter().any(|t| t.name == tool), "Missing expected tool: {}", tool);
}
}

Expand All @@ -365,7 +371,7 @@ pub mod tests_mcp_manager {
"npx".to_string(),
Some(vec![
"-y".to_string(),
"@modelcontextprotocol/server-everything".to_string(),
"@modelcontextprotocol/server-everything@2025.9.12".to_string(),
"sse".to_string(),
]) as Option<Vec<String>>,
Some(envs),
Expand All @@ -390,22 +396,28 @@ pub mod tests_mcp_manager {
});
assert!(result.is_ok());
let unwrapped = result.unwrap();
assert!(unwrapped.len() == 11);
let tools = [

// Debug output to see actual tools
println!("SSE - Actual number of tools: {}", unwrapped.len());
println!("SSE - Actual tools: {:?}", unwrapped.iter().map(|t| &t.name).collect::<Vec<_>>());

// The MCP server-everything package now returns 10 tools
assert_eq!(unwrapped.len(), 10, "Expected exactly 10 tools, got {}", unwrapped.len());

let expected_tools = [
"echo",
"add",
"add",
"longRunningOperation",
"printEnv",
"sampleLLM",
"getTinyImage",
"annotatedMessage",
"getResourceReference",
"startElicitation",
"getResourceLinks",
"structuredContent",
];
for tool in tools {
assert!(unwrapped.iter().any(|t| t.name == tool));
for tool in expected_tools {
assert!(unwrapped.iter().any(|t| t.name == tool), "Missing expected tool: {}", tool);
}
}

Expand All @@ -417,7 +429,7 @@ pub mod tests_mcp_manager {
"npx".to_string(),
Some(vec![
"-y".to_string(),
"@modelcontextprotocol/server-everything".to_string(),
"@modelcontextprotocol/server-everything@2025.9.12".to_string(),
"streamableHttp".to_string(),
]) as Option<Vec<String>>,
Some(envs),
Expand All @@ -435,7 +447,29 @@ pub mod tests_mcp_manager {
let result = list_tools_via_http("http://localhost:8002/mcp", None).await;
assert!(result.is_ok());
let unwrapped = result.unwrap();
assert!(unwrapped.len() == 11);

// Debug output to see actual tools
println!("HTTP - Actual number of tools: {}", unwrapped.len());
println!("HTTP - Actual tools: {:?}", unwrapped.iter().map(|t| &t.name).collect::<Vec<_>>());

// The MCP server-everything package now returns 10 tools
assert_eq!(unwrapped.len(), 10, "Expected exactly 10 tools, got {}", unwrapped.len());

let expected_tools = [
"echo",
"add",
"longRunningOperation",
"printEnv",
"sampleLLM",
"getTinyImage",
"annotatedMessage",
"getResourceReference",
"getResourceLinks",
"structuredContent",
];
for tool in expected_tools {
assert!(unwrapped.iter().any(|t| t.name == tool), "Missing expected tool: {}", tool);
}
}

#[tokio::test]
Expand All @@ -446,7 +480,7 @@ pub mod tests_mcp_manager {
"npx".to_string(),
Some(vec![
"-y".to_string(),
"@modelcontextprotocol/server-everything".to_string(),
"@modelcontextprotocol/server-everything@2025.9.12".to_string(),
"streamableHttp".to_string(),
]) as Option<Vec<String>>,
Some(envs),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,8 @@ mod tests {

#[test]
fn test_same_command_same_hash() {
let server1 = create_test_server(Some("npx @modelcontextprotocol/server-everything".to_string()));
let server2 = create_test_server(Some("npx @modelcontextprotocol/server-everything".to_string()));
let server1 = create_test_server(Some("npx @modelcontextprotocol/server-everything@2025.9.12".to_string()));
let server2 = create_test_server(Some("npx @modelcontextprotocol/server-everything@2025.9.12".to_string()));

let hash1 = server1.get_command_hash();
let hash2 = server2.get_command_hash();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ mod tests {
name: "@modelcontextprotocol/server-everything".to_string(),
r#type: MCPServerType::Command,
url: None,
command: Some("npx -y @modelcontextprotocol/server-everything".to_string()),
command: Some("npx -y @modelcontextprotocol/server-everything@2025.9.12".to_string()),
is_enabled: true,
env: None,
},
Expand Down