Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1289,34 +1289,51 @@ impl GenericInferenceChain {
}
};

let original_response = function_response.response.clone();
let max_input_tokens = ModelCapabilitiesManager::get_max_input_tokens(&provider_interface);
let max_tokens_for_response = ((max_input_tokens as f64 * 0.9) as usize).max(1024); // Allow 90% of the context window, minimum 1024 tokens
let response_tokens = count_tokens_from_message_llama3(&function_response.response);
if response_tokens > max_tokens_for_response {
let response_tokens = count_tokens_from_message_llama3(&original_response);
let response_exceeded_limit = response_tokens > max_tokens_for_response;

if response_exceeded_limit {
// Tell the LLM why the tool response was skipped while keeping user visibility of the original output.
function_response.response = json!({
"max_tokens_for_response": max_tokens_for_response,
"max_input_tokens": max_input_tokens,
"response_tokens": response_tokens,
"response": "IMPORTANT: Function response exceeded model context window, try again with a smaller response or a more capable model.",
}).to_string();
})
.to_string();
}

let user_visible_response = if response_exceeded_limit {
json!({
"error": format!("This tool response exceeded the model context window ({} tokens > allowed {}).", response_tokens, max_tokens_for_response),
"new_response": function_response.response,
"original_response": original_response,
}).to_string()
} else {
original_response.clone()
};

let mut function_call_with_router_key = function_call.clone();
function_call_with_router_key.tool_router_key =
Some(shinkai_tool.tool_router_key().to_string_without_version());
function_call_with_router_key.response = Some(function_response.response.clone());
function_call_with_router_key.response = Some(user_visible_response.clone());
tool_calls_history.push(function_call_with_router_key);

// Trigger WS update after receiving function_response
// Trigger WS update after receiving function_response (show user the full tool output when available)
let mut user_function_response = function_response.clone();
user_function_response.response = user_visible_response.clone();
Self::trigger_ws_update(
&ws_manager_trait,
&Some(full_job.job_id.clone()),
&function_response,
&user_function_response,
shinkai_tool.tool_router_key().to_string_without_version(),
)
.await;

// Store all function responses to use in the next prompt
// Store all function responses to use in the next prompt (LLM sees the sanitized version if needed)
iteration_function_responses.push(function_response);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ mod tests {
ToolOutputArg::empty(),
None,
"local:::__official_shinkai:::concat_strings".to_string(),
"1.0.0".to_string(),
);
let shinkai_tool = ShinkaiTool::Rust(tool, true);

Expand Down Expand Up @@ -216,6 +217,7 @@ mod tests {
ToolOutputArg::empty(),
None,
"local:::__official_shinkai:::concat_strings".to_string(),
"1.0.0".to_string(),
);
let shinkai_tool = ShinkaiTool::Rust(tool, true);

Expand Down
59 changes: 50 additions & 9 deletions shinkai-bin/shinkai-node/src/managers/tool_router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -596,6 +596,7 @@ impl ToolRouter {
println!("Adding {} Rust tools", rust_tools.len());
let mut added_count = 0;
let mut skipped_count = 0;
let mut upgraded_count = 0;

for tool in rust_tools {
let rust_tool = RustTool::new(
Expand All @@ -605,26 +606,66 @@ impl ToolRouter {
tool.output_arg,
None,
tool.tool_router_key,
tool.version,
);

let _ = match self.sqlite_manager.get_tool_by_key(&rust_tool.tool_router_key) {
let router_key = rust_tool.tool_router_key.clone();
let new_version = IndexableVersion::from_string(&rust_tool.version).map_err(|e| {
ToolError::ParseError(format!("Invalid Rust tool version '{}': {}", rust_tool.version, e))
})?;

match self.sqlite_manager.get_tool_header_by_key(&router_key) {
Err(SqliteManagerError::ToolNotFound(_)) => {
added_count += 1;
self.sqlite_manager
.add_tool(ShinkaiTool::Rust(rust_tool, true))
.await
.map_err(|e| ToolError::DatabaseError(e.to_string()))
.map_err(|e| ToolError::DatabaseError(e.to_string()))?;
}
Err(e) => Err(ToolError::DatabaseError(e.to_string())),
Ok(_db_tool) => {
skipped_count += 1;
continue;
Err(e) => return Err(ToolError::DatabaseError(e.to_string())),
Ok(header) => {
let current_version = IndexableVersion::from_string(&header.version).map_err(|e| {
ToolError::ParseError(format!(
"Invalid installed Rust tool version '{}': {}",
header.version, e
))
})?;

if new_version > current_version {
match self.sqlite_manager.get_tool_by_key(&router_key) {
Ok(ShinkaiTool::Rust(existing_rust_tool, is_enabled)) => {
let mut upgraded_tool = rust_tool.clone();
if upgraded_tool.mcp_enabled.is_none() {
upgraded_tool.mcp_enabled = existing_rust_tool.mcp_enabled;
Comment on lines +636 to +639

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[P1] Preserve MCP flag when upgrading Rust tools

The upgrade branch rebuilds the new RustTool and only copies mcp_enabled from the existing tool when the new definition leaves the flag unset. Because RustTool::new always initialises mcp_enabled to Some(false), this condition never fires and an upgrade will always overwrite any user‑chosen MCP value with false. Users who enabled MCP for a tool will find it disabled after a version bump. The upgrade path should carry over the previous mcp_enabled value regardless, unless a new default is explicitly intended.

Useful? React with 👍 / 👎.

}

upgraded_count += 1;
self.sqlite_manager
.upgrade_tool(ShinkaiTool::Rust(upgraded_tool, is_enabled))
.await
.map_err(|e| ToolError::DatabaseError(e.to_string()))?;
}
Ok(other_variant) => {
skipped_count += 1;
eprintln!(
"Expected Rust tool for key '{}' but found {:?}, skipping",
router_key,
other_variant.tool_type()
);
}
Err(err) => {
return Err(ToolError::DatabaseError(err.to_string()));
}
}
} else {
skipped_count += 1;
}
}
}?;
}
}
println!(
"Rust tools installation complete - Added: {}, Skipped: {}",
added_count, skipped_count
"Rust tools installation complete - Added: {}, Upgraded: {}, Skipped: {}",
added_count, upgraded_count, skipped_count
);
Ok(())
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use serde_json::{json, Value};
use shinkai_http_api::node_api_router::{APIError, SendResponseBody, SendResponseBodyData};
use shinkai_message_primitives::{
schemas::{
identity::Identity, inbox_name::InboxName, job::{ForkedJob, JobLike}, job_config::JobConfig, llm_providers::{common_agent_llm_provider::ProviderOrAgent, serialized_llm_provider::SerializedLLMProvider}, shinkai_name::{ShinkaiName, ShinkaiSubidentityType}, smart_inbox::{LLMProviderSubset, ProviderType, SmartInbox, V2SmartInbox}
identity::Identity, inbox_name::InboxName, inbox_permission::InboxPermission, job::{ForkedJob, JobLike}, job_config::JobConfig, llm_providers::{common_agent_llm_provider::ProviderOrAgent, serialized_llm_provider::SerializedLLMProvider}, shinkai_name::{ShinkaiName, ShinkaiSubidentityType}, smart_inbox::{LLMProviderSubset, ProviderType, SmartInbox, V2SmartInbox}
}, shinkai_message::{
shinkai_message::{MessageBody, MessageData}, shinkai_message_schemas::{
APIChangeJobAgentRequest, ExportInboxMessagesFormat, JobCreationInfo, JobMessage, MessageSchemaType, V2ChatMessage
Expand Down Expand Up @@ -1371,7 +1371,7 @@ impl Node {
pub async fn fork_job(
db: Arc<SqliteManager>,
_node_name: ShinkaiName,
_identity_manager: Arc<Mutex<IdentityManager>>,
identity_manager: Arc<Mutex<IdentityManager>>,
job_id: String,
message_id: Option<String>,
node_encryption_sk: EncryptionStaticKey,
Expand All @@ -1385,6 +1385,28 @@ impl Node {
message: format!("Failed to retrieve job: {}", err),
})?;

// Get the requesting identity to keep permissions aligned with the new forked job
let requesting_identity = {
let identity_manager = identity_manager.lock().await;
match identity_manager.get_main_identity() {
Some(Identity::Standard(identity)) => identity.clone(),
Some(_) => {
return Err(APIError {
code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(),
error: "Internal Server Error".to_string(),
message: "Main identity is not a standard identity".to_string(),
});
}
None => {
return Err(APIError {
code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(),
error: "Internal Server Error".to_string(),
message: "Failed to get main identity".to_string(),
});
}
}
};

// Retrieve the message from the inbox
let message_id = match message_id {
Some(message_id) => message_id,
Expand Down Expand Up @@ -1456,6 +1478,24 @@ impl Node {
message: format!("Failed to create new job: {}", err),
})?;

let forked_inbox_name =
InboxName::get_job_inbox_name_from_params(forked_job_id.clone()).map_err(|err| APIError {
code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(),
error: "Internal Server Error".to_string(),
message: format!("Failed to build forked inbox name: {}", err),
})?;

db.add_permission(
&forked_inbox_name.to_string(),
&requesting_identity,
InboxPermission::Admin,
)
.map_err(|err| APIError {
code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(),
error: "Internal Server Error".to_string(),
message: format!("Failed to add permissions for forked job: {}", err),
})?;

// Fork the messages
let mut forked_message_map: HashMap<String, String> = HashMap::new();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,16 @@ INSERT INTO table_name (field_1, field_3, field_4)

-- Example read:
SELECT * FROM table_name WHERE field_2 > datetime('now', '-1 day');
SELECT field_1, field_3 FROM table_name WHERE field_3 > 100 ORDER BY field_2 DESC LIMIT 10;"#
SELECT field_1, field_3 FROM table_name WHERE field_3 > 100 ORDER BY field_2 DESC LIMIT 10;

-- Changelog:
- 1.0.1: Fixed parameters to be compliant with JSON schema."#
.to_string(),
tool_router_key: "local:::__official_shinkai:::shinkai_sqlite_query_executor".to_string(),
tool_type: "Rust".to_string(),
formatted_tool_summary_for_ui: "Execute SQLite queries".to_string(),
author: "@@official.shinkai".to_string(),
version: "1.0".to_string(),
version: "1.0.1".to_string(),
enabled: true,
mcp_enabled: Some(false),
input_args: {
Expand Down Expand Up @@ -311,6 +314,7 @@ mod tests {
mcp_enabled: sql_processor_tool.tool.mcp_enabled.clone(),
input_args: sql_processor_tool.tool.input_args.clone(),
output_arg: sql_processor_tool.tool.output_arg.clone(),
version: sql_processor_tool.tool.version.clone(),
tool_embedding: sql_processor_tool._tool_embedding.clone(),
tool_router_key: sql_processor_tool.tool.tool_router_key.clone(),
};
Expand Down
49 changes: 46 additions & 3 deletions shinkai-libs/shinkai-sqlite/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1425,7 +1425,8 @@ impl SqliteManager {
)?;
}

// Step 4: Regenerate embeddings for file chunks
// Step 4: Regenerate embeddings for file chunks with smart truncation
// Keep existing chunks but truncate only if new model has smaller context window
shinkai_log(
ShinkaiLogOption::Database,
ShinkaiLogLevel::Info,
Expand All @@ -1445,8 +1446,20 @@ impl SqliteManager {
chunk_iter.collect::<Result<Vec<_>, _>>()?
};

let new_model_max_tokens = new_model_type.max_input_token_count();
let mut truncated_count = 0;
let total_chunks = chunks.len();

for (chunk_id, parsed_file_id, text) in chunks {
let embedding = embedding_generator.generate_embedding_default(&text).await
// Only truncate if the chunk exceeds new model's token limit
let processed_text = if text.chars().count() > new_model_max_tokens {
truncated_count += 1;
text.chars().take(new_model_max_tokens).collect()
} else {
text
};

let embedding = embedding_generator.generate_embedding_default(&processed_text).await
.map_err(|e| SqliteManagerError::SerializationError(format!("Chunk embedding generation failed: {}", e)))?;

let conn = self.get_connection()?;
Expand All @@ -1456,7 +1469,37 @@ impl SqliteManager {
)?;
}

// Step 5: Update the database with the new model type
if truncated_count > 0 {
shinkai_log(
ShinkaiLogOption::Database,
ShinkaiLogLevel::Info,
&format!("Truncated {} out of {} chunks to fit new model's {} token limit",
truncated_count, total_chunks, new_model_max_tokens),
);
} else {
shinkai_log(
ShinkaiLogOption::Database,
ShinkaiLogLevel::Info,
"No chunks needed truncation - all fit within new model's token limit",
);
}

// Step 5: Update parsed_files.embedding_model_used column for consistency
shinkai_log(
ShinkaiLogOption::Database,
ShinkaiLogLevel::Info,
"Updating parsed_files.embedding_model_used column",
);

{
let conn = self.get_connection()?;
conn.execute(
"UPDATE parsed_files SET embedding_model_used = ?1",
rusqlite::params![new_model_type.to_string()],
)?;
}

// Step 6: Update the database with the new model type
self.update_default_embedding_model(new_model_type.clone())?;

shinkai_log(
Expand Down
3 changes: 3 additions & 0 deletions shinkai-libs/shinkai-sqlite/src/shinkai_tool_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,9 @@ impl SqliteManager {
new_python.config = merged_config;
(old_config, ShinkaiTool::Python(new_python, is_enabled))
}
(ShinkaiTool::Rust(_old_rust, _), ShinkaiTool::Rust(new_rust, is_enabled)) => {
(Vec::new(), ShinkaiTool::Rust(new_rust, is_enabled))
}
_ => return Err(SqliteManagerError::ToolTypeMismatch),
};

Expand Down
11 changes: 10 additions & 1 deletion shinkai-libs/shinkai-tools-primitives/src/tools/rust_tools.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,18 @@ impl fmt::Display for RustToolError {

impl std::error::Error for RustToolError {}

/// Fallback version string for a `RustTool` whose serialized form has no
/// `version` field; referenced by the `#[serde(default = ...)]` attribute on
/// that field.
fn default_rust_tool_version() -> String {
String::from("1.0.0")
}

#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
pub struct RustTool {
pub name: String,
pub description: String,
pub input_args: Parameters,
pub output_arg: ToolOutputArg,
#[serde(default = "default_rust_tool_version")]
pub version: String,
pub tool_embedding: Option<Vec<f32>>,
pub tool_router_key: String,
pub mcp_enabled: Option<bool>,
Expand All @@ -46,12 +52,14 @@ impl RustTool {
output_arg: ToolOutputArg,
tool_embedding: Option<Vec<f32>>,
tool_router_key: String,
version: String,
) -> Self {
Self {
name: utils::clean_string(&name),
description,
input_args,
output_arg,
version,
tool_embedding,
tool_router_key,
mcp_enabled: Some(false),
Expand Down Expand Up @@ -79,6 +87,7 @@ impl RustTool {
description: header.description.clone(),
input_args: header.input_args.clone(),
output_arg: header.output_arg.clone(),
version: header.version.clone(),
tool_embedding: None, // Assuming no embedding is provided in the header
tool_router_key: header.tool_router_key.clone(),
mcp_enabled: header.mcp_enabled,
Expand Down Expand Up @@ -106,7 +115,7 @@ impl RustTool {

ToolPlaygroundMetadata {
name: self.name.clone(),
version: "1.0.0".to_string(),
version: self.version.clone(),
homepage: None,
description: self.description.clone(),
author: self.author(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ impl ShinkaiTool {
/// Returns the version of the tool
pub fn version(&self) -> String {
match self {
ShinkaiTool::Rust(_r, _) => "1.0.0".to_string(),
ShinkaiTool::Rust(r, _) => r.version.clone(),
ShinkaiTool::Network(n, _) => n.version.clone(),
ShinkaiTool::Deno(d, _) => d.version.clone(),
ShinkaiTool::Python(p, _) => p.version.clone(),
Expand Down