Skip to content
Open
2 changes: 1 addition & 1 deletion crates/agent/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ r2d2_sqlite.workspace = true
rand.workspace = true
regex.workspace = true
reqwest.workspace = true
rmcp = { version = "0.8.0", features = ["client", "transport-async-rw", "transport-child-process", "transport-io"] }
rmcp.workspace = true
rusqlite.workspace = true
rustls.workspace = true
rustls-native-certs.workspace = true
Expand Down
173 changes: 169 additions & 4 deletions crates/agent/src/agent/agent_config/definitions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use serde::{
use super::types::ResourcePath;
use crate::agent::consts::DEFAULT_AGENT_NAME;
use crate::agent::tools::BuiltInToolName;
use crate::mcp::oauth_util::OAuthConfig;

#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(untagged)]
Expand Down Expand Up @@ -214,14 +215,61 @@ pub struct McpServers {
pub mcp_servers: HashMap<String, McpServerConfig>,
}

#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(untagged)]
#[derive(Debug, Clone, Serialize, JsonSchema)]
#[serde(tag = "type")]
pub enum McpServerConfig {
#[serde(rename = "stdio")]
Local(LocalMcpServerConfig),
StreamableHTTP(StreamableHTTPMcpServerConfig),
#[serde(rename = "http")]
Remote(RemoteMcpServerConfig),
}

impl<'de> Deserialize<'de> for McpServerConfig {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
use serde::de::Error;

// Helper enum with derived Deserialize to avoid infinite recursion
#[derive(Deserialize)]
#[serde(tag = "type")]
enum McpServerConfigHelper {
#[serde(rename = "stdio")]
Local(LocalMcpServerConfig),
#[serde(rename = "http")]
Remote(RemoteMcpServerConfig),
}

let value = serde_json::Value::deserialize(deserializer)?;

// Check if "type" field exists
if let Some(obj) = value.as_object() {
if !obj.contains_key("type") {
// If "type" is missing, default to "stdio" by adding it
let mut obj = obj.clone();
obj.insert("type".to_string(), serde_json::Value::String("stdio".to_string()));
let value_with_type = serde_json::Value::Object(obj);
let helper: McpServerConfigHelper =
serde_json::from_value(value_with_type).map_err(D::Error::custom)?;
return Ok(match helper {
McpServerConfigHelper::Local(config) => McpServerConfig::Local(config),
McpServerConfigHelper::Remote(config) => McpServerConfig::Remote(config),
});
}
}

// Normal deserialization with type field present
let helper: McpServerConfigHelper = serde_json::from_value(value).map_err(D::Error::custom)?;
Ok(match helper {
McpServerConfigHelper::Local(config) => McpServerConfig::Local(config),
McpServerConfigHelper::Remote(config) => McpServerConfig::Remote(config),
})
}
}

#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "camelCase")]
pub struct LocalMcpServerConfig {
/// The command string used to initialize the mcp server
pub command: String,
Expand All @@ -241,7 +289,8 @@ pub struct LocalMcpServerConfig {
}

#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct StreamableHTTPMcpServerConfig {
#[serde(rename_all = "camelCase")]
pub struct RemoteMcpServerConfig {
/// The URL endpoint for HTTP-based MCP servers
pub url: String,
/// HTTP headers to include when communicating with HTTP-based MCP servers
Expand All @@ -251,6 +300,15 @@ pub struct StreamableHTTPMcpServerConfig {
#[serde(alias = "timeout")]
#[serde(default = "default_timeout")]
pub timeout_ms: u64,
/// OAuth scopes required for authentication with the remote MCP server
#[serde(default)]
pub oauth_scopes: Vec<String>,
/// OAuth configuration for this server
#[serde(skip_serializing_if = "Option::is_none")]
pub oauth: Option<OAuthConfig>,
/// A boolean flag to denote whether or not to load this mcp server
#[serde(default)]
pub disabled: bool,
}

pub fn default_timeout() -> u64 {
Expand Down Expand Up @@ -392,4 +450,111 @@ mod tests {

let _: AgentConfig = serde_json::from_value(agent).unwrap();
}

#[test]
fn test_mcp_server_config_http_deser() {
// Test HTTP server without oauth scopes
let config = serde_json::json!({
"type": "http",
"url": "https://mcp.api.coingecko.com/sse"
});
let result: McpServerConfig = serde_json::from_value(config).unwrap();
match result {
McpServerConfig::Remote(remote) => {
assert_eq!(remote.url, "https://mcp.api.coingecko.com/sse");
assert!(remote.oauth_scopes.is_empty());
},
McpServerConfig::Local(_) => panic!("Expected Remote variant"),
}

// Test HTTP server with oauth scopes
let config = serde_json::json!({
"type": "http",
"url": "https://mcp.datadoghq.com/api/unstable/mcp-server/mcp",
"oauthScopes": ["mcp", "profile", "email"]
});
let result: McpServerConfig = serde_json::from_value(config).unwrap();
match result {
McpServerConfig::Remote(remote) => {
assert_eq!(remote.url, "https://mcp.datadoghq.com/api/unstable/mcp-server/mcp");
assert_eq!(remote.oauth_scopes, vec!["mcp", "profile", "email"]);
},
McpServerConfig::Local(_) => panic!("Expected Remote variant"),
}

// Test HTTP server with empty oauth scopes
let config = serde_json::json!({
"type": "http",
"url": "https://example-server.modelcontextprotocol.io/mcp",
"oauthScopes": []
});
let result: McpServerConfig = serde_json::from_value(config).unwrap();
match result {
McpServerConfig::Remote(remote) => {
assert_eq!(remote.url, "https://example-server.modelcontextprotocol.io/mcp");
assert!(remote.oauth_scopes.is_empty());
},
McpServerConfig::Local(_) => panic!("Expected Remote variant"),
}
}

#[test]
fn test_mcp_server_config_stdio_deser() {
let config = serde_json::json!({
"type": "stdio",
"command": "node",
"args": ["server.js"]
});
let result: McpServerConfig = serde_json::from_value(config).unwrap();
match result {
McpServerConfig::Local(local) => {
assert_eq!(local.command, "node");
assert_eq!(local.args, vec!["server.js"]);
},
McpServerConfig::Remote(_) => panic!("Expected Local variant"),
}
}

#[test]
fn test_mcp_server_config_defaults_to_stdio() {
// Test that when "type" field is missing, it defaults to "stdio" (Local variant)
let config = serde_json::json!({
"command": "node",
"args": ["server.js"]
});
let result: McpServerConfig = serde_json::from_value(config).unwrap();
match result {
McpServerConfig::Local(local) => {
assert_eq!(local.command, "node");
assert_eq!(local.args, vec!["server.js"]);
},
McpServerConfig::Remote(_) => panic!("Expected Local variant when type field is missing"),
}
}

#[test]
fn test_mcp_servers_map_deser() {
let servers = serde_json::json!({
"coin-gecko": {
"type": "http",
"url": "https://mcp.api.coingecko.com/sse"
},
"datadog": {
"type": "http",
"url": "https://mcp.datadoghq.com/api/unstable/mcp-server/mcp",
"oauthScopes": ["mcp", "profile", "email"]
},
"local-server": {
"type": "stdio",
"command": "npx",
"args": ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"]
}
});

let result: HashMap<String, McpServerConfig> = serde_json::from_value(servers).unwrap();
assert_eq!(result.len(), 3);
assert!(result.contains_key("coin-gecko"));
assert!(result.contains_key("datadog"));
assert!(result.contains_key("local-server"));
}
}
2 changes: 2 additions & 0 deletions crates/agent/src/agent/consts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ pub const MAX_TOOL_NAME_LEN: usize = 64;

pub const MAX_TOOL_SPEC_DESCRIPTION_LEN: usize = 10_004;

pub const DEFAULT_MCP_CREDENTIAL_PATH: &str = "~/.aws/sso/cache";

/// 10 MB
pub const MAX_IMAGE_SIZE_BYTES: u64 = 10 * 1024 * 1024;

Expand Down
39 changes: 23 additions & 16 deletions crates/agent/src/agent/mcp/actor.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;

Expand Down Expand Up @@ -50,14 +51,9 @@ pub enum McpMessage {
pub struct McpServerActorHandle {
_server_name: String,
sender: RequestSender<McpServerActorRequest, McpServerActorResponse, McpServerActorError>,
event_rx: mpsc::Receiver<McpServerActorEvent>,
}

impl McpServerActorHandle {
pub async fn recv(&mut self) -> Option<McpServerActorEvent> {
self.event_rx.recv().await
}

pub async fn get_tool_specs(&self) -> Result<Vec<ToolSpec>, McpServerActorError> {
match self
.sender
Expand Down Expand Up @@ -153,6 +149,7 @@ impl From<ServiceError> for McpServerActorError {
pub enum McpServerActorEvent {
/// The MCP server has launched successfully
Initialized {
server_name: String,
/// Time taken to launch the server
serve_duration: Duration,
/// Time taken to list all tools.
Expand All @@ -165,7 +162,9 @@ pub enum McpServerActorEvent {
list_prompts_duration: Option<Duration>,
},
/// The MCP server failed to initialize successfully
InitializeError(String),
InitializeError { server_name: String, error: String },
/// An OAuth authentication request from the MCP server
OauthRequest { server_name: String, oauth_url: String },
}

#[derive(Debug)]
Expand Down Expand Up @@ -195,34 +194,38 @@ pub struct McpServerActor {

impl McpServerActor {
/// Spawns an actor to manage the MCP server, returning a [McpServerActorHandle].
pub fn spawn(server_name: String, config: McpServerConfig) -> McpServerActorHandle {
let (event_tx, event_rx) = mpsc::channel(32);
pub fn spawn(
server_name: String,
config: McpServerConfig,
cred_path: PathBuf,
event_tx: mpsc::Sender<McpServerActorEvent>,
) -> McpServerActorHandle {
let (req_tx, req_rx) = new_request_channel();

let server_name_clone = server_name.clone();
tokio::spawn(async move { Self::launch(server_name_clone, config, req_rx, event_tx).await });
tokio::spawn(async move { Self::launch(server_name_clone, config, cred_path, req_rx, event_tx).await });

McpServerActorHandle {
_server_name: server_name,
sender: req_tx,
event_rx,
}
}

async fn launch(
server_name: String,
config: McpServerConfig,
cred_path: PathBuf,
req_rx: RequestReceiver<McpServerActorRequest, McpServerActorResponse, McpServerActorError>,
event_tx: mpsc::Sender<McpServerActorEvent>,
) {
let (message_tx, message_rx) = mpsc::channel(32);
match McpService::new(server_name.clone(), config.clone(), message_tx.clone())
.launch()
match McpService::new(server_name.clone(), config.clone(), cred_path, message_tx.clone())
.launch(&event_tx)
.await
{
Ok((service_handle, launch_md)) => {
let s = Self {
server_name,
server_name: server_name.clone(),
_config: config,
tools: launch_md.tools.unwrap_or_default(),
prompts: launch_md.prompts.unwrap_or_default(),
Expand All @@ -237,6 +240,7 @@ impl McpServerActor {
let _ = s
.event_tx
.send(McpServerActorEvent::Initialized {
server_name,
serve_duration: launch_md.serve_time_taken,
list_tools_duration: launch_md.list_tools_duration,
list_prompts_duration: launch_md.list_prompts_duration,
Expand All @@ -246,7 +250,10 @@ impl McpServerActor {
},
Err(err) => {
let _ = event_tx
.send(McpServerActorEvent::InitializeError(err.to_string()))
.send(McpServerActorEvent::InitializeError {
server_name,
error: err.to_string(),
})
.await;
},
}
Expand Down Expand Up @@ -340,7 +347,7 @@ impl McpServerActor {
let service_handle = self.service_handle.clone();
let tx = self.message_tx.clone();
tokio::spawn(async move {
let res = service_handle.list_tools().await;
let res = service_handle.list_all_tools().await;
let _ = tx.send(McpMessage::Tools(res)).await;
});
}
Expand All @@ -351,7 +358,7 @@ impl McpServerActor {
let service_handle = self.service_handle.clone();
let tx = self.message_tx.clone();
tokio::spawn(async move {
let res = service_handle.list_prompts().await;
let res = service_handle.list_all_prompts().await;
let _ = tx.send(McpMessage::Prompts(res)).await;
});
}
Expand Down
Loading