From fbd73ddb5ce39b456e18a5f4d05e13d7d2726194 Mon Sep 17 00:00:00 2001 From: Harish Patidar Date: Thu, 4 Sep 2025 14:40:29 +0530 Subject: [PATCH] Implement lazy initialization for agent components --- swarms-rs/src/agent/swarms_agent.rs | 70 ++++++++++++++++++----------- swarms-rs/src/structs/agent.rs | 2 + swarms-rs/tests/example_test.rs | 31 +++++++++++++ 3 files changed, 77 insertions(+), 26 deletions(-) diff --git a/swarms-rs/src/agent/swarms_agent.rs b/swarms-rs/src/agent/swarms_agent.rs index a345c3b..787b749 100644 --- a/swarms-rs/src/agent/swarms_agent.rs +++ b/swarms-rs/src/agent/swarms_agent.rs @@ -82,14 +82,6 @@ //! - **State Persistence**: Optional automatic saving of agent state to disk //! - **Task Hashing**: Efficient state management using content-based hashing -use std::{ - ffi::OsStr, - hash::{Hash, Hasher}, - ops::Deref, - path::Path, - sync::Arc, -}; - use colored::*; use dashmap::DashMap; use futures::{StreamExt, future::BoxFuture, stream}; @@ -101,6 +93,15 @@ use rmcp::{ }; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; +use std::sync::OnceLock; +use std::sync::RwLock; +use std::{ + ffi::OsStr, + hash::{Hash, Hasher}, + ops::Deref, + path::Path, + sync::Arc, +}; use swarms_macro::tool; use thiserror::Error; use tokio::{ @@ -169,7 +170,7 @@ where /// List of tool definitions available to the agent tools: Vec, /// Implementation instances of tools, keyed by tool name - tools_impl: DashMap>, + tools_impl: OnceLock>>, } impl SwarmsAgentBuilder @@ -204,7 +205,7 @@ where config: AgentConfig::default(), system_prompt: None, tools: vec![], - tools_impl: DashMap::new(), + tools_impl: OnceLock::new(), } } @@ -315,7 +316,7 @@ where /// ``` pub fn add_tool(mut self, tool: T) -> Self { self.tools.push(tool.definition()); - self.tools_impl + self.tools_impl() .insert(tool.name().to_string(), Arc::new(tool) as Arc); self } @@ -461,7 +462,7 @@ where ); } self.tools.insert(0, ToolDyn::definition(&TaskEvaluator)); - self.tools_impl.insert( + self.tools_impl().insert( ToolDyn::name(&TaskEvaluator), Arc::new(TaskEvaluator) as Arc, ); @@ -471,7 +472,7 @@ where model: self.model, config: self.config.clone(), system_prompt: self.system_prompt, - short_memory: AgentShortMemory::new(), + short_memory: OnceLock::new(), tools: self.tools.clone(), tools_impl: self.tools_impl, }; @@ -603,6 +604,11 @@ where self.config.verbose = verbose; self } + + #[inline] + pub fn tools_impl(&mut self) -> &DashMap> { + self.tools_impl.get_or_init(|| DashMap::new()) + } } /// The main Swarms Agent implementation providing autonomous task execution capabilities. @@ -702,12 +708,13 @@ where /// Optional system prompt that guides agent behavior system_prompt: Option, /// Short-term memory for maintaining conversation history - short_memory: AgentShortMemory, + #[serde(skip)] + short_memory: OnceLock, /// List of available tool definitions tools: Vec, /// Tool implementation instances (not serialized) #[serde(skip)] - tools_impl: DashMap>, + tools_impl: OnceLock>>, } impl SwarmsAgent @@ -742,9 +749,9 @@ where model, system_prompt: system_prompt.into(), config: AgentConfig::default(), - short_memory: AgentShortMemory::new(), + short_memory: OnceLock::new(), tools: vec![], - tools_impl: DashMap::new(), + tools_impl: OnceLock::new(), } } @@ -841,7 +848,7 @@ where let results = Arc::clone(&results); async move { let tool = Arc::clone( - match self.tools_impl.get(&tool_call.name) { + match self.tools_impl().get(&tool_call.name) { Some(tool) => tool, None => { tracing::error!("Tool not found: {}", tool_call.name); @@ -885,7 +892,7 @@ where } else { for tool_call in all_tool_calls { let tool = Arc::clone( - self.tools_impl + self.tools_impl() .get(&tool_call.name) .ok_or(AgentError::ToolNotFound(tool_call.name.clone()))? .deref(), @@ -970,7 +977,8 @@ where let toolname = tool.name(); let definition = tool.definition(); self.tools.push(definition); - self.tools_impl.insert(toolname, Arc::new(tool)); + self.tools_impl() + .insert(toolname, Arc::new(tool)); self } @@ -999,6 +1007,16 @@ where }); } } + + #[inline] + pub fn memory(&self) -> &AgentShortMemory { + self.short_memory.get_or_init(AgentShortMemory::new) + } + + #[inline] + pub fn tools_impl(&self) -> &DashMap> { + self.tools_impl.get_or_init(|| DashMap::new()) + } } impl Agent for SwarmsAgent @@ -1020,7 +1038,7 @@ where ); } - self.short_memory.add( + self.memory().add( &task, &self.config.name, Role::User(self.config.user_name.clone()), @@ -1156,7 +1174,7 @@ where // } // Generate response using LLM - let history = self.short_memory.0.get(&task).unwrap(); // Safety: task is in short_memory + let history = self.memory().0.get(&task).unwrap(); // Safety: task is in short_memory let current_chat_response = match self.chat(¤t_prompt, history.deref()).await { Ok(response) => response, @@ -1261,7 +1279,7 @@ where // Update the flag for the *next* iteration based on *this* iteration's call was_prev_call_task_evaluator = is_task_evaluator_called && !task_complete; - self.short_memory.add( + self.memory().add( &task, &self.config.name, Role::Assistant(self.config.name.to_owned()), @@ -1331,7 +1349,7 @@ where // TODO: More flexible output types, e.g. JSON, CSV, etc. Ok(self - .short_memory + .memory() .0 .get(&task) .expect("Task should exist in short memory") @@ -1383,7 +1401,7 @@ where let plan = self.prompt(planning_prompt).await?; tracing::debug!("Plan: {}", plan); // Add plan to memory - self.short_memory.add( + self.memory().add( task, self.config.name.clone(), Role::Assistant(self.config.name.clone()), @@ -1416,7 +1434,7 @@ where .join(format!("{}_{}", self.name(), task_hash)) .with_extension("json"); - let json = serde_json::to_string_pretty(&self.short_memory.0.get(&task).unwrap())?; // TODO: Safety? + let json = serde_json::to_string_pretty(&self.memory().0.get(&task).unwrap())?; // TODO: Safety? persistence::save_to_file(&json, path).await?; } Ok(()) diff --git a/swarms-rs/src/structs/agent.rs b/swarms-rs/src/structs/agent.rs index 93aece8..499bfe6 100644 --- a/swarms-rs/src/structs/agent.rs +++ b/swarms-rs/src/structs/agent.rs @@ -13,6 +13,8 @@ use tokio::sync::broadcast; #[derive(Debug, Error)] pub enum AgentError { + #[error("Tool {0} not found")] + ConfigError(String), #[error("IO error: {0}")] IoError(#[from] std::io::Error), #[error("Serde json error: {0}")] diff --git a/swarms-rs/tests/example_test.rs b/swarms-rs/tests/example_test.rs index f616065..e8ee107 100644 --- a/swarms-rs/tests/example_test.rs +++ b/swarms-rs/tests/example_test.rs @@ -1,5 +1,9 @@ use anyhow::Result; use swarms_rs::{llm::provider::openai::OpenAI, structs::agent::Agent}; +use swarms_rs::agent::SwarmsAgentBuilder; +use swarms_rs::llm::request::ToolDefinition; +use swarms_rs::agent::SwarmsAgent; + #[tokio::test] async fn test_basic_agent_functionality() -> Result<()> { @@ -66,3 +70,30 @@ async fn test_agent_creation() -> Result<()> { assert_eq!(agent.name(), "MockAgent"); Ok(()) } + +#[tokio::test] +async fn test_lazy_initialized_for_SwarmsAgentBuilder() { + let mut builder = SwarmsAgentBuilder::new_with_model(OpenAI::new("mock-api-key".to_string())); + + // Initially, tools_impl should not be initialized + assert!( + builder.tools_impl().is_empty(), + "tools_impl should be empty before first access" + ); + + { + // First borrow + let tools_map = builder.tools_impl(); + assert!( + tools_map.is_empty(), + "tools_impl should be empty after first access" + ); + } // tools_map goes out of scope here + + // Now safe to re-borrow + let tools_map2 = builder.tools_impl(); + assert!( + tools_map2.is_empty(), + "tools_impl should still be empty when no tools are added" + ); +} \ No newline at end of file