@@ -1289,34 +1289,51 @@ impl GenericInferenceChain {
}
};

let original_response = function_response.response.clone();
let max_input_tokens = ModelCapabilitiesManager::get_max_input_tokens(&provider_interface);
let max_tokens_for_response = ((max_input_tokens as f64 * 0.9) as usize).max(1024); // Allow 90% of the context window, minimum 1024 tokens
let response_tokens = count_tokens_from_message_llama3(&function_response.response);
if response_tokens > max_tokens_for_response {
let response_tokens = count_tokens_from_message_llama3(&original_response);
let response_exceeded_limit = response_tokens > max_tokens_for_response;

if response_exceeded_limit {
// Tell the LLM why the tool response was skipped while keeping user visibility of the original output.
function_response.response = json!({
"max_tokens_for_response": max_tokens_for_response,
"max_input_tokens": max_input_tokens,
"response_tokens": response_tokens,
"response": "IMPORTANT: Function response exceeded model context window, try again with a smaller response or a more capable model.",
}).to_string();
})
.to_string();
}

let user_visible_response = if response_exceeded_limit {
json!({
"error": format!("This tool response exceeded the model context window ({} tokens > allowed {}).", response_tokens, max_tokens_for_response),
"new_response": function_response.response,
"original_response": original_response,
}).to_string()
} else {
original_response.clone()
};

let mut function_call_with_router_key = function_call.clone();
function_call_with_router_key.tool_router_key =
Some(shinkai_tool.tool_router_key().to_string_without_version());
function_call_with_router_key.response = Some(function_response.response.clone());
function_call_with_router_key.response = Some(user_visible_response.clone());
tool_calls_history.push(function_call_with_router_key);

// Trigger WS update after receiving function_response
// Trigger WS update after receiving function_response (show user the full tool output when available)
let mut user_function_response = function_response.clone();
user_function_response.response = user_visible_response.clone();
Self::trigger_ws_update(
&ws_manager_trait,
&Some(full_job.job_id.clone()),
&function_response,
&user_function_response,
shinkai_tool.tool_router_key().to_string_without_version(),
)
.await;

// Store all function responses to use in the next prompt
// Store all function responses to use in the next prompt (LLM sees the sanitized version if needed)
iteration_function_responses.push(function_response);
}
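
The hunk above boils down to a guard that keeps oversized tool output out of the model's prompt while the user still sees the full result. A minimal sketch of that guard as a standalone helper: the helper name and the closure-based token counter are hypothetical stand-ins (the real code uses count_tokens_from_message_llama3 and ModelCapabilitiesManager::get_max_input_tokens), while the 90% / 1024-token budget and the two JSON payloads are taken from the diff.

use serde_json::json;

/// Hypothetical helper mirroring the guard above: if the tool response would
/// not fit in ~90% of the model's context window (never less than 1024 tokens),
/// replace what the LLM sees and wrap what the user sees.
fn guard_tool_response(
    original_response: &str,
    max_input_tokens: usize,
    count_tokens: impl Fn(&str) -> usize, // stand-in for count_tokens_from_message_llama3
) -> (String, String) {
    let max_tokens_for_response = ((max_input_tokens as f64 * 0.9) as usize).max(1024);
    let response_tokens = count_tokens(original_response);

    if response_tokens <= max_tokens_for_response {
        // Small enough: LLM and user both get the original output.
        return (original_response.to_string(), original_response.to_string());
    }

    // The LLM gets a short explanation instead of the oversized payload.
    let llm_response = json!({
        "max_tokens_for_response": max_tokens_for_response,
        "max_input_tokens": max_input_tokens,
        "response_tokens": response_tokens,
        "response": "IMPORTANT: Function response exceeded model context window, try again with a smaller response or a more capable model.",
    })
    .to_string();

    // The user still sees the full original output alongside the error.
    let user_response = json!({
        "error": format!(
            "This tool response exceeded the model context window ({} tokens > allowed {}).",
            response_tokens, max_tokens_for_response
        ),
        "new_response": llm_response,
        "original_response": original_response,
    })
    .to_string();

    (llm_response, user_response)
}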

@@ -52,6 +52,7 @@ mod tests {
ToolOutputArg::empty(),
None,
"local:::__official_shinkai:::concat_strings".to_string(),
"1.0.0".to_string(),
);
let shinkai_tool = ShinkaiTool::Rust(tool, true);

@@ -216,6 +217,7 @@ mod tests {
ToolOutputArg::empty(),
None,
"local:::__official_shinkai:::concat_strings".to_string(),
"1.0.0".to_string(),
);
let shinkai_tool = ShinkaiTool::Rust(tool, true);

@@ -8,27 +8,7 @@ use shinkai_message_primitives::schemas::llm_providers::serialized_llm_provider:
use shinkai_message_primitives::schemas::prompts::Prompt;
use std::collections::HashMap;
use uuid::Uuid;
use super::shared_model_logic::get_image_type;

fn sanitize_tool_name(name: &str) -> String {
let sanitized: String = name
.chars()
.map(|c| {
if c.is_alphanumeric() || c == '_' || c == '-' {
c.to_ascii_lowercase()
} else {
'_'
}
})
.collect();

// Ensure length is between 1 and 64 characters
if sanitized.is_empty() {
"tool".to_string()
} else {
sanitized.chars().take(64).collect()
}
}
use super::shared_model_logic::{get_image_type, sanitize_tool_name};

pub fn claude_prepare_messages(
model: &LLMProviderInterface,
@@ -1,4 +1,4 @@
use super::shared_model_logic::{get_image_type, get_video_type, get_audio_type};
use super::shared_model_logic::{get_image_type, get_video_type, get_audio_type, sanitize_tool_name};
use crate::llm_provider::error::LLMProviderError;
use crate::managers::model_capabilities_manager::ModelCapabilitiesManager;
use crate::managers::model_capabilities_manager::PromptResult;
@@ -301,9 +301,7 @@ pub fn gemini_prepare_messages(model: &LLMProviderInterface, prompt: Prompt) ->
}

serde_json::json!({
"name": function.name.chars()
.map(|c| if c.is_alphanumeric() || c == '_' || c == '-' { c.to_ascii_lowercase() } else { '_' })
.collect::<String>(),
"name": sanitize_tool_name(&function.name),
"description": function.description,
"parameters": function.parameters
})
@@ -1,6 +1,7 @@
use shinkai_message_primitives::schemas::{
llm_message::{LlmMessage, DetailedFunctionCall}, llm_providers::serialized_llm_provider::LLMProviderInterface, prompts::Prompt
llm_message::LlmMessage, llm_providers::serialized_llm_provider::LLMProviderInterface, prompts::Prompt
};
use super::shared_model_logic::sanitize_tool_name;
use serde::{Deserialize, Serialize};
use serde_json::Value;

@@ -174,19 +175,7 @@ pub fn ollama_conversation_prepare_messages_with_tooling(
.map(|mut tool| {
if let Some(functions) = tool.functions.as_mut() {
for function in functions {
// Replace any characters that aren't alphanumeric, underscore, or hyphen
function.name = function
.name
.chars()
.map(|c| {
if c.is_alphanumeric() || c == '_' || c == '-' {
c
} else {
'_'
}
})
.collect::<String>()
.to_lowercase();
function.name = sanitize_tool_name(&function.name);
}
}
tool
@@ -245,7 +234,7 @@ pub fn ollama_conversation_prepare_messages_with_tooling(
mod tests {
use serde_json::json;
use shinkai_message_primitives::schemas::{
llm_providers::serialized_llm_provider::SerializedLLMProvider, subprompts::{SubPrompt, SubPromptAssetType, SubPromptType}
llm_message::DetailedFunctionCall, llm_providers::serialized_llm_provider::SerializedLLMProvider, subprompts::{SubPrompt, SubPromptAssetType, SubPromptType}
};

use super::*;
@@ -7,9 +7,8 @@ use serde::{Deserialize, Serialize};
use serde_json::{self};
use shinkai_message_primitives::schemas::llm_providers::serialized_llm_provider::LLMProviderInterface;
use shinkai_message_primitives::schemas::prompts::Prompt;
use shinkai_message_primitives::schemas::subprompts::{SubPrompt, SubPromptType};

use super::shared_model_logic;
use super::shared_model_logic::{self, sanitize_tool_name};

#[derive(Debug, Deserialize)]
pub struct OpenAIResponse {
@@ -140,30 +139,7 @@ pub fn openai_prepare_messages(model: &LLMProviderInterface, prompt: Prompt) ->
.map(|mut tool| {
if let Some(functions) = tool.functions.as_mut() {
for function in functions {
// Replace any characters that aren't alphanumeric, underscore, or hyphen
let mut sanitized_name = function
.name
.chars()
.map(|c| {
if c.is_alphanumeric() || c == '_' || c == '-' {
c
} else {
'_'
}
})
.collect::<String>()
.to_lowercase();

// Truncate function name to OpenAI's 64-character limit
// If name is too long, keep the end part (similar to agent_id truncation)
let max_len = 64;
if sanitized_name.len() > max_len {
let chars: Vec<char> = sanitized_name.chars().collect();
let start_index = chars.len() - max_len;
sanitized_name = chars[start_index..].iter().collect();
}

function.name = sanitized_name;
function.name = sanitize_tool_name(&function.name);
}
}
tool
@@ -271,3 +271,32 @@ pub async fn send_tool_ws_update_with_status(
}
Ok(())
}

pub fn sanitize_tool_name(name: &str) -> String {
let sanitized: String = name
.chars()
.map(|c| {
if c.is_alphanumeric() || c == '_' || c == '-' {
c.to_ascii_lowercase()
} else {
'_'
}
})
.collect();

let mut result = if sanitized.is_empty() {
"tool".to_string()
} else {
sanitized.chars().take(64).collect()
};

// Ensure the name starts with a letter or underscore
if let Some(first_char) = result.chars().next() {
if !first_char.is_alphabetic() && first_char != '_' {
result = format!("t_{}", result);
}
}

// Ensure length is still within 64 characters after potential prefix
result.chars().take(64).collect()
}
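
For reference, a few illustrative inputs and the outputs the shared helper above would produce: lowercasing, underscores for disallowed characters, a "t_" prefix when the name does not start with a letter or underscore, a 64-character cap, and a "tool" fallback for empty names. The test module and case names are hypothetical.

#[cfg(test)]
mod sanitize_tool_name_tests {
    use super::sanitize_tool_name;

    #[test]
    fn illustrative_cases() {
        // Disallowed characters become underscores and everything is lowercased.
        assert_eq!(sanitize_tool_name("My Tool!"), "my_tool_");
        // Names that do not start with a letter or underscore get a "t_" prefix.
        assert_eq!(sanitize_tool_name("9lives"), "t_9lives");
        // Empty input falls back to "tool".
        assert_eq!(sanitize_tool_name(""), "tool");
        // Output is capped at 64 characters, even after the prefix is applied.
        assert_eq!(sanitize_tool_name(&"x".repeat(100)).len(), 64);
    }
}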
59 changes: 50 additions & 9 deletions shinkai-bin/shinkai-node/src/managers/tool_router.rs
@@ -596,6 +596,7 @@ impl ToolRouter {
println!("Adding {} Rust tools", rust_tools.len());
let mut added_count = 0;
let mut skipped_count = 0;
let mut upgraded_count = 0;

for tool in rust_tools {
let rust_tool = RustTool::new(
@@ -605,26 +606,66 @@
tool.output_arg,
None,
tool.tool_router_key,
tool.version,
);

let _ = match self.sqlite_manager.get_tool_by_key(&rust_tool.tool_router_key) {
let router_key = rust_tool.tool_router_key.clone();
let new_version = IndexableVersion::from_string(&rust_tool.version).map_err(|e| {
ToolError::ParseError(format!("Invalid Rust tool version '{}': {}", rust_tool.version, e))
})?;

match self.sqlite_manager.get_tool_header_by_key(&router_key) {
Err(SqliteManagerError::ToolNotFound(_)) => {
added_count += 1;
self.sqlite_manager
.add_tool(ShinkaiTool::Rust(rust_tool, true))
.await
.map_err(|e| ToolError::DatabaseError(e.to_string()))
.map_err(|e| ToolError::DatabaseError(e.to_string()))?;
}
Err(e) => Err(ToolError::DatabaseError(e.to_string())),
Ok(_db_tool) => {
skipped_count += 1;
continue;
Err(e) => return Err(ToolError::DatabaseError(e.to_string())),
Ok(header) => {
let current_version = IndexableVersion::from_string(&header.version).map_err(|e| {
ToolError::ParseError(format!(
"Invalid installed Rust tool version '{}': {}",
header.version, e
))
})?;

if new_version > current_version {
match self.sqlite_manager.get_tool_by_key(&router_key) {
Ok(ShinkaiTool::Rust(existing_rust_tool, is_enabled)) => {
let mut upgraded_tool = rust_tool.clone();
if upgraded_tool.mcp_enabled.is_none() {
upgraded_tool.mcp_enabled = existing_rust_tool.mcp_enabled;
Review comment on lines +636 to +639:

[P1] Preserve MCP flag when upgrading Rust tools

The upgrade branch rebuilds the new RustTool and only copies mcp_enabled from the existing tool when the new definition leaves the flag unset. Because RustTool::new always initialises mcp_enabled to Some(false), this condition never fires and an upgrade will always overwrite any user‑chosen MCP value with false. Users who enabled MCP for a tool will find it disabled after a version bump. The upgrade path should carry over the previous mcp_enabled value regardless, unless a new default is explicitly intended.
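
A minimal sketch of the carry-over the comment asks for, using the variable and field names from the diff (whether a newer tool definition should ever be allowed to override the stored flag is a design decision left open here):

// Always carry the previously stored MCP preference onto the upgraded tool,
// rather than only when the new definition leaves the flag unset (which never
// happens, since RustTool::new defaults mcp_enabled to Some(false)).
let mut upgraded_tool = rust_tool.clone();
upgraded_tool.mcp_enabled = existing_rust_tool.mcp_enabled;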


}

upgraded_count += 1;
self.sqlite_manager
.upgrade_tool(ShinkaiTool::Rust(upgraded_tool, is_enabled))
.await
.map_err(|e| ToolError::DatabaseError(e.to_string()))?;
}
Ok(other_variant) => {
skipped_count += 1;
eprintln!(
"Expected Rust tool for key '{}' but found {:?}, skipping",
router_key,
other_variant.tool_type()
);
}
Err(err) => {
return Err(ToolError::DatabaseError(err.to_string()));
}
}
} else {
skipped_count += 1;
}
}
}?;
}
}
println!(
"Rust tools installation complete - Added: {}, Skipped: {}",
added_count, skipped_count
"Rust tools installation complete - Added: {}, Upgraded: {}, Skipped: {}",
added_count, upgraded_count, skipped_count
);
Ok(())
}
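
In short, the install loop now makes a three-way decision per bundled Rust tool: add it when the router key is unknown, upgrade it when the bundled version is strictly newer than the stored one, and skip it otherwise. A compressed sketch of that decision with the database calls elided; the enum, the function names, and the tuple-based parse_version stand-in for IndexableVersion::from_string are hypothetical.

/// Hypothetical summary of the per-tool decision made in the loop above.
enum InstallAction {
    Add,     // router key not present in the database yet
    Upgrade, // bundled version is strictly newer than the stored one
    Skip,    // stored version is the same or newer
}

/// Stand-in for IndexableVersion::from_string: parse "major.minor.patch"
/// into a tuple whose lexicographic ordering matches version precedence.
fn parse_version(v: &str) -> Result<(u64, u64, u64), String> {
    let mut parts = v.split('.').map(|p| p.parse::<u64>());
    match (parts.next(), parts.next(), parts.next()) {
        (Some(Ok(a)), Some(Ok(b)), Some(Ok(c))) => Ok((a, b, c)),
        _ => Err(format!("Invalid version '{}'", v)),
    }
}

fn decide(bundled: &str, installed: Option<&str>) -> Result<InstallAction, String> {
    let new_version = parse_version(bundled)?;
    match installed {
        None => Ok(InstallAction::Add), // ToolNotFound -> add_tool
        Some(current) => {
            if new_version > parse_version(current)? {
                Ok(InstallAction::Upgrade) // newer bundled version -> upgrade_tool
            } else {
                Ok(InstallAction::Skip) // same or older -> leave the installed tool alone
            }
        }
    }
}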
@@ -8,7 +8,7 @@ use serde_json::{json, Value};
use shinkai_http_api::node_api_router::{APIError, SendResponseBody, SendResponseBodyData};
use shinkai_message_primitives::{
schemas::{
identity::Identity, inbox_name::InboxName, job::{ForkedJob, JobLike}, job_config::JobConfig, llm_providers::{common_agent_llm_provider::ProviderOrAgent, serialized_llm_provider::SerializedLLMProvider}, shinkai_name::{ShinkaiName, ShinkaiSubidentityType}, smart_inbox::{LLMProviderSubset, ProviderType, SmartInbox, V2SmartInbox}
identity::Identity, inbox_name::InboxName, inbox_permission::InboxPermission, job::{ForkedJob, JobLike}, job_config::JobConfig, llm_providers::{common_agent_llm_provider::ProviderOrAgent, serialized_llm_provider::SerializedLLMProvider}, shinkai_name::{ShinkaiName, ShinkaiSubidentityType}, smart_inbox::{LLMProviderSubset, ProviderType, SmartInbox, V2SmartInbox}
}, shinkai_message::{
shinkai_message::{MessageBody, MessageData}, shinkai_message_schemas::{
APIChangeJobAgentRequest, ExportInboxMessagesFormat, JobCreationInfo, JobMessage, MessageSchemaType, V2ChatMessage
@@ -1371,7 +1371,7 @@ impl Node {
pub async fn fork_job(
db: Arc<SqliteManager>,
_node_name: ShinkaiName,
_identity_manager: Arc<Mutex<IdentityManager>>,
identity_manager: Arc<Mutex<IdentityManager>>,
job_id: String,
message_id: Option<String>,
node_encryption_sk: EncryptionStaticKey,
@@ -1385,6 +1385,28 @@
message: format!("Failed to retrieve job: {}", err),
})?;

// Get the requesting identity to keep permissions aligned with the new forked job
let requesting_identity = {
let identity_manager = identity_manager.lock().await;
match identity_manager.get_main_identity() {
Some(Identity::Standard(identity)) => identity.clone(),
Some(_) => {
return Err(APIError {
code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(),
error: "Internal Server Error".to_string(),
message: "Main identity is not a standard identity".to_string(),
});
}
None => {
return Err(APIError {
code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(),
error: "Internal Server Error".to_string(),
message: "Failed to get main identity".to_string(),
});
}
}
};

// Retrieve the message from the inbox
let message_id = match message_id {
Some(message_id) => message_id,
@@ -1456,6 +1478,24 @@
message: format!("Failed to create new job: {}", err),
})?;

let forked_inbox_name =
InboxName::get_job_inbox_name_from_params(forked_job_id.clone()).map_err(|err| APIError {
code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(),
error: "Internal Server Error".to_string(),
message: format!("Failed to build forked inbox name: {}", err),
})?;

db.add_permission(
&forked_inbox_name.to_string(),
&requesting_identity,
InboxPermission::Admin,
)
.map_err(|err| APIError {
code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(),
error: "Internal Server Error".to_string(),
message: format!("Failed to add permissions for forked job: {}", err),
})?;

// Fork the messages
let mut forked_message_map: HashMap<String, String> = HashMap::new();
