Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 44 additions & 26 deletions swarms-rs/src/agent/swarms_agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,14 +82,6 @@
//! - **State Persistence**: Optional automatic saving of agent state to disk
//! - **Task Hashing**: Efficient state management using content-based hashing

use std::{
ffi::OsStr,
hash::{Hash, Hasher},
ops::Deref,
path::Path,
sync::Arc,
};

use colored::*;
use dashmap::DashMap;
use futures::{StreamExt, future::BoxFuture, stream};
Expand All @@ -101,6 +93,15 @@ use rmcp::{
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::sync::OnceLock;
use std::sync::RwLock;
use std::{
ffi::OsStr,
hash::{Hash, Hasher},
ops::Deref,
path::Path,
sync::Arc,
};
use swarms_macro::tool;
use thiserror::Error;
use tokio::{
Expand Down Expand Up @@ -169,7 +170,7 @@ where
/// List of tool definitions available to the agent
tools: Vec<ToolDefinition>,
/// Implementation instances of tools, keyed by tool name
tools_impl: DashMap<String, Arc<dyn ToolDyn>>,
tools_impl: OnceLock<DashMap<String, Arc<dyn ToolDyn>>>,
}

impl<M> SwarmsAgentBuilder<M>
Expand Down Expand Up @@ -204,7 +205,7 @@ where
config: AgentConfig::default(),
system_prompt: None,
tools: vec![],
tools_impl: DashMap::new(),
tools_impl: OnceLock::new(),
}
}

Expand Down Expand Up @@ -315,7 +316,7 @@ where
/// ```
pub fn add_tool<T: Tool + 'static>(mut self, tool: T) -> Self {
self.tools.push(tool.definition());
self.tools_impl
self.tools_impl()
.insert(tool.name().to_string(), Arc::new(tool) as Arc<dyn ToolDyn>);
self
}
Expand Down Expand Up @@ -461,7 +462,7 @@ where
);
}
self.tools.insert(0, ToolDyn::definition(&TaskEvaluator));
self.tools_impl.insert(
self.tools_impl().insert(
ToolDyn::name(&TaskEvaluator),
Arc::new(TaskEvaluator) as Arc<dyn ToolDyn>,
);
Expand All @@ -471,7 +472,7 @@ where
model: self.model,
config: self.config.clone(),
system_prompt: self.system_prompt,
short_memory: AgentShortMemory::new(),
short_memory: OnceLock::new(),
tools: self.tools.clone(),
tools_impl: self.tools_impl,
};
Expand Down Expand Up @@ -603,6 +604,11 @@ where
self.config.verbose = verbose;
self
}

#[inline]
pub fn tools_impl(&mut self) -> &DashMap<String, Arc<dyn ToolDyn>> {
self.tools_impl.get_or_init(|| DashMap::new())
}
}

/// The main Swarms Agent implementation providing autonomous task execution capabilities.
Expand Down Expand Up @@ -702,12 +708,13 @@ where
/// Optional system prompt that guides agent behavior
system_prompt: Option<String>,
/// Short-term memory for maintaining conversation history
short_memory: AgentShortMemory,
#[serde(skip)]
short_memory: OnceLock<AgentShortMemory>,
/// List of available tool definitions
tools: Vec<ToolDefinition>,
/// Tool implementation instances (not serialized)
#[serde(skip)]
tools_impl: DashMap<String, Arc<dyn ToolDyn>>,
tools_impl: OnceLock<DashMap<String, Arc<dyn ToolDyn>>>,
}

impl<M> SwarmsAgent<M>
Expand Down Expand Up @@ -742,9 +749,9 @@ where
model,
system_prompt: system_prompt.into(),
config: AgentConfig::default(),
short_memory: AgentShortMemory::new(),
short_memory: OnceLock::new(),
tools: vec![],
tools_impl: DashMap::new(),
tools_impl: OnceLock::new(),
}
}

Expand Down Expand Up @@ -841,7 +848,7 @@ where
let results = Arc::clone(&results);
async move {
let tool = Arc::clone(
match self.tools_impl.get(&tool_call.name) {
match self.tools_impl().get(&tool_call.name) {
Some(tool) => tool,
None => {
tracing::error!("Tool not found: {}", tool_call.name);
Expand Down Expand Up @@ -885,7 +892,7 @@ where
} else {
for tool_call in all_tool_calls {
let tool = Arc::clone(
self.tools_impl
self.tools_impl()
.get(&tool_call.name)
.ok_or(AgentError::ToolNotFound(tool_call.name.clone()))?
.deref(),
Expand Down Expand Up @@ -970,7 +977,8 @@ where
let toolname = tool.name();
let definition = tool.definition();
self.tools.push(definition);
self.tools_impl.insert(toolname, Arc::new(tool));
self.tools_impl()
.insert(toolname, Arc::new(tool));
self
}

Expand Down Expand Up @@ -999,6 +1007,16 @@ where
});
}
}

#[inline]
pub fn memory(&self) -> &AgentShortMemory {
self.short_memory.get_or_init(AgentShortMemory::new)
}

#[inline]
pub fn tools_impl(&self) -> &DashMap<String, Arc<dyn ToolDyn>> {
self.tools_impl.get_or_init(|| DashMap::new())
}
}

impl<M> Agent for SwarmsAgent<M>
Expand All @@ -1020,7 +1038,7 @@ where
);
}

self.short_memory.add(
self.memory().add(
&task,
&self.config.name,
Role::User(self.config.user_name.clone()),
Expand Down Expand Up @@ -1156,7 +1174,7 @@ where
// }

// Generate response using LLM
let history = self.short_memory.0.get(&task).unwrap(); // Safety: task is in short_memory
let history = self.memory().0.get(&task).unwrap(); // Safety: task is in short_memory
let current_chat_response =
match self.chat(&current_prompt, history.deref()).await {
Ok(response) => response,
Expand Down Expand Up @@ -1261,7 +1279,7 @@ where
// Update the flag for the *next* iteration based on *this* iteration's call
was_prev_call_task_evaluator = is_task_evaluator_called && !task_complete;

self.short_memory.add(
self.memory().add(
&task,
&self.config.name,
Role::Assistant(self.config.name.to_owned()),
Expand Down Expand Up @@ -1331,7 +1349,7 @@ where

// TODO: More flexible output types, e.g. JSON, CSV, etc.
Ok(self
.short_memory
.memory()
.0
.get(&task)
.expect("Task should exist in short memory")
Expand Down Expand Up @@ -1383,7 +1401,7 @@ where
let plan = self.prompt(planning_prompt).await?;
tracing::debug!("Plan: {}", plan);
// Add plan to memory
self.short_memory.add(
self.memory().add(
task,
self.config.name.clone(),
Role::Assistant(self.config.name.clone()),
Expand Down Expand Up @@ -1416,7 +1434,7 @@ where
.join(format!("{}_{}", self.name(), task_hash))
.with_extension("json");

let json = serde_json::to_string_pretty(&self.short_memory.0.get(&task).unwrap())?; // TODO: Safety?
let json = serde_json::to_string_pretty(&self.memory().0.get(&task).unwrap())?; // TODO: Safety?
persistence::save_to_file(&json, path).await?;
}
Ok(())
Expand Down
2 changes: 2 additions & 0 deletions swarms-rs/src/structs/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ use tokio::sync::broadcast;

#[derive(Debug, Error)]
pub enum AgentError {
#[error("Tool {0} not found")]
ConfigError(String),
#[error("IO error: {0}")]
IoError(#[from] std::io::Error),
#[error("Serde json error: {0}")]
Expand Down
31 changes: 31 additions & 0 deletions swarms-rs/tests/example_test.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
use anyhow::Result;
use swarms_rs::{llm::provider::openai::OpenAI, structs::agent::Agent};
use swarms_rs::agent::SwarmsAgentBuilder;
use swarms_rs::llm::request::ToolDefinition;
use swarms_rs::agent::SwarmsAgent;


#[tokio::test]
async fn test_basic_agent_functionality() -> Result<()> {
Expand Down Expand Up @@ -66,3 +70,30 @@ async fn test_agent_creation() -> Result<()> {
assert_eq!(agent.name(), "MockAgent");
Ok(())
}

#[tokio::test]
async fn test_lazy_initialized_for_SwarmsAgentBuilder() {
let mut builder = SwarmsAgentBuilder::new_with_model(OpenAI::new("mock-api-key".to_string()));

// Initially, tools_impl should not be initialized
assert!(
builder.tools_impl().is_empty(),
"tools_impl should be empty before first access"
);

{
// First borrow
let tools_map = builder.tools_impl();
assert!(
tools_map.is_empty(),
"tools_impl should be empty after first access"
);
} // tools_map goes out of scope here

// Now safe to re-borrow
let tools_map2 = builder.tools_impl();
assert!(
tools_map2.is_empty(),
"tools_impl should still be empty when no tools are added"
);
}
Loading