diff --git a/crates/agent/Cargo.toml b/crates/agent/Cargo.toml index 08f219024e..b2de0af897 100644 --- a/crates/agent/Cargo.toml +++ b/crates/agent/Cargo.toml @@ -51,7 +51,7 @@ r2d2_sqlite.workspace = true rand.workspace = true regex.workspace = true reqwest.workspace = true -rmcp = { version = "0.8.0", features = ["client", "transport-async-rw", "transport-child-process", "transport-io"] } +rmcp.workspace = true rusqlite.workspace = true rustls.workspace = true rustls-native-certs.workspace = true diff --git a/crates/agent/src/agent/agent_config/definitions.rs b/crates/agent/src/agent/agent_config/definitions.rs index 9532e16c75..ee58877490 100644 --- a/crates/agent/src/agent/agent_config/definitions.rs +++ b/crates/agent/src/agent/agent_config/definitions.rs @@ -12,6 +12,7 @@ use serde::{ use super::types::ResourcePath; use crate::agent::consts::DEFAULT_AGENT_NAME; use crate::agent::tools::BuiltInToolName; +use crate::mcp::oauth_util::OAuthConfig; #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] #[serde(untagged)] @@ -214,14 +215,61 @@ pub struct McpServers { pub mcp_servers: HashMap, } -#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] -#[serde(untagged)] +#[derive(Debug, Clone, Serialize, JsonSchema)] +#[serde(tag = "type")] pub enum McpServerConfig { + #[serde(rename = "stdio")] Local(LocalMcpServerConfig), - StreamableHTTP(StreamableHTTPMcpServerConfig), + #[serde(rename = "http")] + Remote(RemoteMcpServerConfig), +} + +impl<'de> Deserialize<'de> for McpServerConfig { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + use serde::de::Error; + + // Helper enum with derived Deserialize to avoid infinite recursion + #[derive(Deserialize)] + #[serde(tag = "type")] + enum McpServerConfigHelper { + #[serde(rename = "stdio")] + Local(LocalMcpServerConfig), + #[serde(rename = "http")] + Remote(RemoteMcpServerConfig), + } + + let value = serde_json::Value::deserialize(deserializer)?; + + // Check if "type" field exists + if let Some(obj) = value.as_object() { + if !obj.contains_key("type") { + // If "type" is missing, default to "stdio" by adding it + let mut obj = obj.clone(); + obj.insert("type".to_string(), serde_json::Value::String("stdio".to_string())); + let value_with_type = serde_json::Value::Object(obj); + let helper: McpServerConfigHelper = + serde_json::from_value(value_with_type).map_err(D::Error::custom)?; + return Ok(match helper { + McpServerConfigHelper::Local(config) => McpServerConfig::Local(config), + McpServerConfigHelper::Remote(config) => McpServerConfig::Remote(config), + }); + } + } + + // Normal deserialization with type field present + let helper: McpServerConfigHelper = serde_json::from_value(value).map_err(D::Error::custom)?; + Ok(match helper { + McpServerConfigHelper::Local(config) => McpServerConfig::Local(config), + McpServerConfigHelper::Remote(config) => McpServerConfig::Remote(config), + }) + } } #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] +#[serde(rename_all = "camelCase")] pub struct LocalMcpServerConfig { /// The command string used to initialize the mcp server pub command: String, @@ -241,7 +289,8 @@ pub struct LocalMcpServerConfig { } #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] -pub struct StreamableHTTPMcpServerConfig { +#[serde(rename_all = "camelCase")] +pub struct RemoteMcpServerConfig { /// The URL endpoint for HTTP-based MCP servers pub url: String, /// HTTP headers to include when communicating with HTTP-based MCP servers @@ -251,6 +300,15 @@ pub struct StreamableHTTPMcpServerConfig { #[serde(alias = "timeout")] #[serde(default = "default_timeout")] pub timeout_ms: u64, + /// OAuth scopes required for authentication with the remote MCP server + #[serde(default)] + pub oauth_scopes: Vec, + /// OAuth configuration for this server + #[serde(skip_serializing_if = "Option::is_none")] + pub oauth: Option, + /// A boolean flag to denote whether or not to load this mcp server + #[serde(default)] + pub disabled: bool, } pub fn default_timeout() -> u64 { @@ -392,4 +450,111 @@ mod tests { let _: AgentConfig = serde_json::from_value(agent).unwrap(); } + + #[test] + fn test_mcp_server_config_http_deser() { + // Test HTTP server without oauth scopes + let config = serde_json::json!({ + "type": "http", + "url": "https://mcp.api.coingecko.com/sse" + }); + let result: McpServerConfig = serde_json::from_value(config).unwrap(); + match result { + McpServerConfig::Remote(remote) => { + assert_eq!(remote.url, "https://mcp.api.coingecko.com/sse"); + assert!(remote.oauth_scopes.is_empty()); + }, + McpServerConfig::Local(_) => panic!("Expected Remote variant"), + } + + // Test HTTP server with oauth scopes + let config = serde_json::json!({ + "type": "http", + "url": "https://mcp.datadoghq.com/api/unstable/mcp-server/mcp", + "oauthScopes": ["mcp", "profile", "email"] + }); + let result: McpServerConfig = serde_json::from_value(config).unwrap(); + match result { + McpServerConfig::Remote(remote) => { + assert_eq!(remote.url, "https://mcp.datadoghq.com/api/unstable/mcp-server/mcp"); + assert_eq!(remote.oauth_scopes, vec!["mcp", "profile", "email"]); + }, + McpServerConfig::Local(_) => panic!("Expected Remote variant"), + } + + // Test HTTP server with empty oauth scopes + let config = serde_json::json!({ + "type": "http", + "url": "https://example-server.modelcontextprotocol.io/mcp", + "oauthScopes": [] + }); + let result: McpServerConfig = serde_json::from_value(config).unwrap(); + match result { + McpServerConfig::Remote(remote) => { + assert_eq!(remote.url, "https://example-server.modelcontextprotocol.io/mcp"); + assert!(remote.oauth_scopes.is_empty()); + }, + McpServerConfig::Local(_) => panic!("Expected Remote variant"), + } + } + + #[test] + fn test_mcp_server_config_stdio_deser() { + let config = serde_json::json!({ + "type": "stdio", + "command": "node", + "args": ["server.js"] + }); + let result: McpServerConfig = serde_json::from_value(config).unwrap(); + match result { + McpServerConfig::Local(local) => { + assert_eq!(local.command, "node"); + assert_eq!(local.args, vec!["server.js"]); + }, + McpServerConfig::Remote(_) => panic!("Expected Local variant"), + } + } + + #[test] + fn test_mcp_server_config_defaults_to_stdio() { + // Test that when "type" field is missing, it defaults to "stdio" (Local variant) + let config = serde_json::json!({ + "command": "node", + "args": ["server.js"] + }); + let result: McpServerConfig = serde_json::from_value(config).unwrap(); + match result { + McpServerConfig::Local(local) => { + assert_eq!(local.command, "node"); + assert_eq!(local.args, vec!["server.js"]); + }, + McpServerConfig::Remote(_) => panic!("Expected Local variant when type field is missing"), + } + } + + #[test] + fn test_mcp_servers_map_deser() { + let servers = serde_json::json!({ + "coin-gecko": { + "type": "http", + "url": "https://mcp.api.coingecko.com/sse" + }, + "datadog": { + "type": "http", + "url": "https://mcp.datadoghq.com/api/unstable/mcp-server/mcp", + "oauthScopes": ["mcp", "profile", "email"] + }, + "local-server": { + "type": "stdio", + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"] + } + }); + + let result: HashMap = serde_json::from_value(servers).unwrap(); + assert_eq!(result.len(), 3); + assert!(result.contains_key("coin-gecko")); + assert!(result.contains_key("datadog")); + assert!(result.contains_key("local-server")); + } } diff --git a/crates/agent/src/agent/agent_config/mod.rs b/crates/agent/src/agent/agent_config/mod.rs index 5f6d5efc01..5fb755dd5d 100644 --- a/crates/agent/src/agent/agent_config/mod.rs +++ b/crates/agent/src/agent/agent_config/mod.rs @@ -277,6 +277,13 @@ impl LoadedMcpServerConfig { source, } } + + pub fn is_enabled(&self) -> bool { + match &self.config { + McpServerConfig::Local(local_mcp_server_config) => !local_mcp_server_config.disabled, + McpServerConfig::Remote(remote_mcp_server_config) => !remote_mcp_server_config.disabled, + } + } } #[derive(Debug, Clone)] diff --git a/crates/agent/src/agent/consts.rs b/crates/agent/src/agent/consts.rs index d5bc44fcbb..258965a868 100644 --- a/crates/agent/src/agent/consts.rs +++ b/crates/agent/src/agent/consts.rs @@ -13,6 +13,8 @@ pub const MAX_TOOL_NAME_LEN: usize = 64; pub const MAX_TOOL_SPEC_DESCRIPTION_LEN: usize = 10_004; +pub const DEFAULT_MCP_CREDENTIAL_PATH: &str = "~/.aws/sso/cache"; + /// 10 MB pub const MAX_IMAGE_SIZE_BYTES: u64 = 10 * 1024 * 1024; diff --git a/crates/agent/src/agent/mcp/actor.rs b/crates/agent/src/agent/mcp/actor.rs index be986a2550..d924b2019b 100644 --- a/crates/agent/src/agent/mcp/actor.rs +++ b/crates/agent/src/agent/mcp/actor.rs @@ -1,4 +1,5 @@ use std::collections::HashMap; +use std::path::PathBuf; use std::sync::Arc; use std::time::Duration; @@ -50,14 +51,9 @@ pub enum McpMessage { pub struct McpServerActorHandle { _server_name: String, sender: RequestSender, - event_rx: mpsc::Receiver, } impl McpServerActorHandle { - pub async fn recv(&mut self) -> Option { - self.event_rx.recv().await - } - pub async fn get_tool_specs(&self) -> Result, McpServerActorError> { match self .sender @@ -153,6 +149,7 @@ impl From for McpServerActorError { pub enum McpServerActorEvent { /// The MCP server has launched successfully Initialized { + server_name: String, /// Time taken to launch the server serve_duration: Duration, /// Time taken to list all tools. @@ -165,7 +162,9 @@ pub enum McpServerActorEvent { list_prompts_duration: Option, }, /// The MCP server failed to initialize successfully - InitializeError(String), + InitializeError { server_name: String, error: String }, + /// An OAuth authentication request from the MCP server + OauthRequest { server_name: String, oauth_url: String }, } #[derive(Debug)] @@ -195,34 +194,38 @@ pub struct McpServerActor { impl McpServerActor { /// Spawns an actor to manage the MCP server, returning a [McpServerActorHandle]. - pub fn spawn(server_name: String, config: McpServerConfig) -> McpServerActorHandle { - let (event_tx, event_rx) = mpsc::channel(32); + pub fn spawn( + server_name: String, + config: McpServerConfig, + cred_path: PathBuf, + event_tx: mpsc::Sender, + ) -> McpServerActorHandle { let (req_tx, req_rx) = new_request_channel(); let server_name_clone = server_name.clone(); - tokio::spawn(async move { Self::launch(server_name_clone, config, req_rx, event_tx).await }); + tokio::spawn(async move { Self::launch(server_name_clone, config, cred_path, req_rx, event_tx).await }); McpServerActorHandle { _server_name: server_name, sender: req_tx, - event_rx, } } async fn launch( server_name: String, config: McpServerConfig, + cred_path: PathBuf, req_rx: RequestReceiver, event_tx: mpsc::Sender, ) { let (message_tx, message_rx) = mpsc::channel(32); - match McpService::new(server_name.clone(), config.clone(), message_tx.clone()) - .launch() + match McpService::new(server_name.clone(), config.clone(), cred_path, message_tx.clone()) + .launch(&event_tx) .await { Ok((service_handle, launch_md)) => { let s = Self { - server_name, + server_name: server_name.clone(), _config: config, tools: launch_md.tools.unwrap_or_default(), prompts: launch_md.prompts.unwrap_or_default(), @@ -237,6 +240,7 @@ impl McpServerActor { let _ = s .event_tx .send(McpServerActorEvent::Initialized { + server_name, serve_duration: launch_md.serve_time_taken, list_tools_duration: launch_md.list_tools_duration, list_prompts_duration: launch_md.list_prompts_duration, @@ -246,7 +250,10 @@ impl McpServerActor { }, Err(err) => { let _ = event_tx - .send(McpServerActorEvent::InitializeError(err.to_string())) + .send(McpServerActorEvent::InitializeError { + server_name, + error: err.to_string(), + }) .await; }, } @@ -340,7 +347,7 @@ impl McpServerActor { let service_handle = self.service_handle.clone(); let tx = self.message_tx.clone(); tokio::spawn(async move { - let res = service_handle.list_tools().await; + let res = service_handle.list_all_tools().await; let _ = tx.send(McpMessage::Tools(res)).await; }); } @@ -351,7 +358,7 @@ impl McpServerActor { let service_handle = self.service_handle.clone(); let tx = self.message_tx.clone(); tokio::spawn(async move { - let res = service_handle.list_prompts().await; + let res = service_handle.list_all_prompts().await; let _ = tx.send(McpMessage::Prompts(res)).await; }); } diff --git a/crates/agent/src/agent/mcp/mod.rs b/crates/agent/src/agent/mcp/mod.rs index ee58dba629..c07a64df7b 100644 --- a/crates/agent/src/agent/mcp/mod.rs +++ b/crates/agent/src/agent/mcp/mod.rs @@ -1,8 +1,116 @@ -mod actor; +//! # MCP (Model Context Protocol) Module +//! +//! This module provides a manager for launching and interacting with multiple MCP servers. +//! It implements a multi-layered architecture with asynchronous communication between components. +//! +//! ## Architecture Overview +//! +//! The module consists of the following key constructs organized in multiple layers: +//! +//! ### Management Layer +//! +//! - **[`McpManager`]**: The central manager that runs in its own async task. It maintains the +//! lifecycle of multiple MCP server instances and routes requests to the appropriate servers. +//! +//! - **[`McpManagerHandle`]**: A cloneable handle for interacting with the `McpManager` from other +//! parts of the application. It provides a safe, async API for launching servers, querying tool +//! specifications, executing tools, and receiving server events. +//! +//! ### Actor Layer +//! +//! - **[`McpServerActor`]** (in [`actor`] module): Individual server actors that manage the +//! lifecycle of a single MCP server process. Each actor handles initialization, tool execution, +//! and communication with its associated server. +//! +//! - **[`McpServerActorHandle`]** (in [`actor`] module): A handle for interacting with a specific +//! `McpServerActor`. Used internally by `McpManager` to communicate with servers. +//! +//! ### Service Layer +//! +//! - **`McpService`** (in `service` module): Implements the `rmcp::Service` trait to handle +//! server-to-client requests and notifications. Created during server launch and consumed by the +//! rmcp crate. +//! +//! - **`RunningMcpService`** (in `service` module): A handle to a running MCP server that wraps the +//! rmcp service. Provides methods for calling tools, listing tools/prompts, and handles +//! authentication/token refresh for remote servers. +//! +//! - **`rmcp::RunningService`** (from rmcp crate): The underlying service from the rmcp library +//! that handles the actual MCP protocol communication over stdio (for local servers) or HTTP (for +//! remote servers). +//! +//! ## Communication Patterns +//! +//! The module uses two primary communication patterns: +//! +//! ### 1. Request/Response Pattern +//! +//! ```text +//! McpManagerHandle McpManager McpServerActor RunningMcpService rmcp::RunningService +//! | | | | | +//! |--[LaunchServer]--->| | | | +//! | |----[spawn]----->| | | +//! | | |--[McpService]--->| | +//! | | | |--[serve]----------->| +//! |<--[response]-------| (initializing) | | | +//! | | |<--[initialized]--| | +//! | | | | | +//! |--[GetToolSpecs]--->| | | | +//! | |--[get_tools]--->| | | +//! | | | (returns cached) | | +//! | |<--[tools]-------| | | +//! |<--[tools]----------| | | | +//! | | | | | +//! |--[ExecuteTool]---->| | | | +//! | |--[execute]----->| | | +//! | | |--[call_tool]---->| | +//! | | | |--[call_tool]------->| +//! |<--[oneshot rx]-----| | | | +//! | | | |<--[result]----------| +//! | | |<--[result]-------| | +//! |<--[result via rx]------------------------[async]--------| | +//! ``` +//! +//! ### 2. Event Broadcasting Pattern +//! +//! ```text +//! McpServerActor McpManager McpManagerHandle +//! | | | +//! |--[Initialized event]---->| | +//! | |--[forward event]------->| +//! | | (moves server from | +//! | | initializing_servers | +//! | | to servers HashMap) | +//! | | | +//! |--[OauthRequest event]--->| | +//! | |--[forward event]------->| +//! | | | +//! |--[InitializeError]------>| | +//! | |--[forward event]------->| +//! | | (removes from | +//! | | initializing_servers) | +//! ``` +//! +//! ## Server Lifecycle +//! +//! MCP servers go through the following states: +//! +//! 1. **Not Launched**: Server configuration exists but no actor has been spawned +//! 2. **Initializing**: `McpServerActor` has been spawned and is stored in +//! `McpManager::initializing_servers`. The actor is establishing connection and fetching initial +//! metadata (tools, prompts) +//! 3. **Initialized**: Server is ready and stored in `McpManager::servers`. Tools can now be +//! executed +//! 4. **Error**: Initialization failed, server is removed from `initializing_servers` + +pub mod actor; +pub mod oauth_util; mod service; pub mod types; use std::collections::HashMap; +use std::path::PathBuf; +use std::time::Duration; use actor::{ McpServerActor, @@ -10,23 +118,30 @@ use actor::{ McpServerActorEvent, McpServerActorHandle, }; -use futures::stream::FuturesUnordered; use rmcp::model::CallToolResult; use serde::{ Deserialize, Serialize, }; use serde_json::Value; -use tokio::sync::oneshot; -use tokio_stream::StreamExt as _; +use tokio::sync::broadcast::error::RecvError; +use tokio::sync::{ + broadcast, + mpsc, + oneshot, +}; use tracing::{ debug, error, + info, warn, }; use types::Prompt; use super::agent_loop::types::ToolSpec; +use super::consts::DEFAULT_MCP_CREDENTIAL_PATH; +use super::util::path::expand_path; +use super::util::providers::RealProvider; use super::util::request_channel::{ RequestReceiver, new_request_channel, @@ -37,24 +152,40 @@ use crate::agent::util::request_channel::{ respond, }; -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct McpManagerHandle { /// Sender for sending requests to the tool manager task - sender: RequestSender, + request_tx: RequestSender, + mcp_main_loop_to_handle_server_event_rx: broadcast::Receiver, +} + +impl Clone for McpManagerHandle { + fn clone(&self) -> Self { + Self { + request_tx: self.request_tx.clone(), + mcp_main_loop_to_handle_server_event_rx: self.mcp_main_loop_to_handle_server_event_rx.resubscribe(), + } + } } impl McpManagerHandle { - fn new(sender: RequestSender) -> Self { - Self { sender } + fn new( + request_tx: RequestSender, + mcp_main_loop_to_handle_server_event_rx: broadcast::Receiver, + ) -> Self { + Self { + request_tx, + mcp_main_loop_to_handle_server_event_rx, + } } pub async fn launch_server( - &self, + &mut self, name: String, config: McpServerConfig, ) -> Result, McpManagerError> { match self - .sender + .request_tx .send_recv(McpManagerRequest::LaunchServer { server_name: name, config, @@ -72,7 +203,7 @@ impl McpManagerHandle { pub async fn get_tool_specs(&self, server_name: String) -> Result, McpManagerError> { match self - .sender + .request_tx .send_recv(McpManagerRequest::GetToolSpecs { server_name }) .await .unwrap_or(Err(McpManagerError::Channel))? @@ -87,7 +218,7 @@ impl McpManagerHandle { pub async fn get_prompts(&self, server_name: String) -> Result, McpManagerError> { match self - .sender + .request_tx .send_recv(McpManagerRequest::GetPrompts { server_name }) .await .unwrap_or(Err(McpManagerError::Channel))? @@ -107,7 +238,7 @@ impl McpManagerHandle { args: Option>, ) -> Result, McpManagerError> { match self - .sender + .request_tx .send_recv(McpManagerRequest::ExecuteTool { server_name, tool_name, @@ -123,55 +254,66 @@ impl McpManagerHandle { ))), } } + + pub async fn recv(&mut self) -> Result { + self.mcp_main_loop_to_handle_server_event_rx + .recv() + .await + .map(|evt| evt.into()) + } } #[derive(Debug)] pub struct McpManager { request_tx: RequestSender, request_rx: RequestReceiver, + server_event_tx: mpsc::Sender, + server_event_rx: mpsc::Receiver, + + cred_path: PathBuf, initializing_servers: HashMap)>, servers: HashMap, + event_buf: Vec, } impl McpManager { - pub fn new() -> Self { + pub fn new(cred_path: PathBuf) -> Self { let (request_tx, request_rx) = new_request_channel(); + let (server_event_tx, server_event_rx) = mpsc::channel::(100); + Self { request_tx, request_rx, + server_event_tx, + server_event_rx, + cred_path, initializing_servers: HashMap::new(), servers: HashMap::new(), + event_buf: Vec::::new(), } } pub fn spawn(self) -> McpManagerHandle { let request_tx = self.request_tx.clone(); + let (mcp_main_loop_to_handle_server_event_tx, mcp_main_loop_to_handle_server_event_rx) = + broadcast::channel::(100); tokio::spawn(async move { - self.main_loop().await; + self.main_loop(mcp_main_loop_to_handle_server_event_tx).await; }); - McpManagerHandle::new(request_tx) + McpManagerHandle::new(request_tx, mcp_main_loop_to_handle_server_event_rx) } - async fn main_loop(mut self) { + async fn main_loop(mut self, mcp_main_loop_to_handle_server_event_tx: broadcast::Sender) { loop { - let mut initializing_servers = FuturesUnordered::new(); - for (name, (handle, _)) in &mut self.initializing_servers { - let name_clone = name.clone(); - initializing_servers.push(async { (name_clone, handle.recv().await) }); - } - let mut initialized_servers = FuturesUnordered::new(); - for (name, handle) in &mut self.servers { - let name_clone = name.clone(); - initialized_servers.push(async { (name_clone, handle.recv().await) }); - } + self.event_buf + .drain(..) + .for_each(|evt| _ = mcp_main_loop_to_handle_server_event_tx.send(evt)); tokio::select! { req = self.request_rx.recv() => { - std::mem::drop(initializing_servers); - std::mem::drop(initialized_servers); let Some(req) = req else { warn!("Tool manager request channel has closed, exiting"); break; @@ -179,20 +321,11 @@ impl McpManager { let res = self.handle_mcp_manager_request(req.payload).await; respond!(req, res); }, - res = initializing_servers.next(), if !initializing_servers.is_empty() => { - std::mem::drop(initializing_servers); - std::mem::drop(initialized_servers); - if let Some((name, evt)) = res { - self.handle_initializing_mcp_actor_event(name, evt).await; + res = self.server_event_rx.recv() => { + if let Some(evt) = res { + self.handle_mcp_actor_event(evt); } - }, - res = initialized_servers.next(), if !initialized_servers.is_empty() => { - std::mem::drop(initializing_servers); - std::mem::drop(initialized_servers); - if let Some((name, evt)) = res { - self.handle_mcp_actor_event(name, evt).await; - } - }, + } } } } @@ -212,8 +345,10 @@ impl McpManager { } else if self.servers.contains_key(&name) { return Err(McpManagerError::ServerAlreadyLaunched { name }); } + let event_tx = self.server_event_tx.clone(); + let handle = McpServerActor::spawn(name.clone(), config, self.cred_path.clone(), event_tx); let (tx, rx) = oneshot::channel(); - let handle = McpServerActor::spawn(name.clone(), config); + self.initializing_servers.insert(name, (handle, tx)); Ok(McpManagerResponse::LaunchServer(rx)) }, @@ -238,44 +373,50 @@ impl McpManager { } } - async fn handle_mcp_actor_event(&mut self, server_name: String, evt: Option) { - debug!(?server_name, ?evt, "Received event from an MCP actor"); - debug_assert!(self.servers.contains_key(&server_name)); - } + fn handle_mcp_actor_event(&mut self, evt: McpServerActorEvent) { + // TODO: keep a record of all the different server events received in this layer? + match &evt { + McpServerActorEvent::Initialized { + server_name, + serve_duration: _, + list_tools_duration: _, + list_prompts_duration: _, + } => { + let Some((handle, result_tx)) = self.initializing_servers.remove(server_name) else { + warn!(?server_name, ?evt, "event was not from an initializing MCP server"); + return; + }; + + if let Err(e) = result_tx.send(Ok(())) { + error!(?server_name, ?e, "failed to send server initialized message"); + } - async fn handle_initializing_mcp_actor_event(&mut self, server_name: String, evt: Option) { - debug!(?server_name, ?evt, "Received event from initializing MCP actor"); - debug_assert!(self.initializing_servers.contains_key(&server_name)); - - let Some((handle, tx)) = self.initializing_servers.remove(&server_name) else { - warn!(?server_name, ?evt, "event was not from an initializing MCP server"); - return; - }; - - // Event should always exist, otherwise indicates a bug with the initialization logic. - let Some(evt) = evt else { - let _ = tx.send(Err(McpManagerError::Custom("Server channel closed".to_string()))); - self.initializing_servers.remove(&server_name); - return; - }; - - // First event from an initializing server should only be either of these Initialize variants. - match evt { - McpServerActorEvent::Initialized { .. } => { - let _ = tx.send(Ok(())); - self.servers.insert(server_name, handle); + if self.servers.insert(server_name.clone(), handle).is_some() { + warn!(?server_name, "duplicated server. old server dropped"); + } + }, + McpServerActorEvent::InitializeError { server_name, error } => { + if let Some((_, result_tx)) = self.initializing_servers.remove(server_name) { + if let Err(e) = result_tx.send(Err(McpManagerError::Custom(error.clone()))) { + error!(?server_name, ?e, "failed to send server initialized message"); + } + } }, - McpServerActorEvent::InitializeError(msg) => { - let _ = tx.send(Err(McpManagerError::Custom(msg))); - self.initializing_servers.remove(&server_name); + McpServerActorEvent::OauthRequest { server_name, oauth_url } => { + info!(?server_name, ?oauth_url, "received oauth request"); }, } + self.event_buf.push(evt); } } impl Default for McpManager { fn default() -> Self { - Self::new() + let expanded_path = + expand_path(DEFAULT_MCP_CREDENTIAL_PATH, &RealProvider).expect("failed to expand default credential path"); + let default_path = PathBuf::from(expanded_path.as_ref()); + + Self::new(default_path) } } @@ -327,3 +468,49 @@ pub enum McpManagerError { #[error("{}", .0)] Custom(String), } + +/// MCP events relevant to agent operations. +/// Provides abstraction over [McpServerActorEvent] to avoid leaking implementation details. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum McpServerEvent { + /// The MCP server has launched successfully + Initialized { + server_name: String, + /// Time taken to launch the server + serve_duration: Duration, + /// Time taken to list all tools. + /// + /// None if the server does not support tools, or there was an error fetching tools. + list_tools_duration: Option, + /// Time taken to list all prompts + /// + /// None if the server does not support prompts, or there was an error fetching prompts. + list_prompts_duration: Option, + }, + /// The MCP server failed to initialize successfully + InitializeError { server_name: String, error: String }, + /// An OAuth authentication request from the MCP server + OauthRequest { server_name: String, oauth_url: String }, +} + +impl From for McpServerEvent { + fn from(value: McpServerActorEvent) -> Self { + match value { + McpServerActorEvent::Initialized { + server_name, + serve_duration, + list_tools_duration, + list_prompts_duration, + } => Self::Initialized { + server_name, + serve_duration, + list_tools_duration, + list_prompts_duration, + }, + McpServerActorEvent::InitializeError { server_name, error } => Self::InitializeError { server_name, error }, + McpServerActorEvent::OauthRequest { server_name, oauth_url } => { + Self::OauthRequest { server_name, oauth_url } + }, + } + } +} diff --git a/crates/agent/src/agent/mcp/oauth_util.rs b/crates/agent/src/agent/mcp/oauth_util.rs new file mode 100644 index 0000000000..33df2ce15c --- /dev/null +++ b/crates/agent/src/agent/mcp/oauth_util.rs @@ -0,0 +1,753 @@ +use std::collections::HashMap; +use std::net::SocketAddr; +use std::path::{ + Path, + PathBuf, +}; +use std::pin::Pin; +use std::str::FromStr; +use std::sync::Arc; + +use http::{ + HeaderMap, + StatusCode, +}; +use http_body_util::Full; +use hyper::Response; +use hyper::body::Bytes; +use hyper::server::conn::http1; +use hyper_util::rt::TokioIo; +use reqwest::Client; +use rmcp::service::{ + DynService, + ServiceExt, +}; +use rmcp::transport::auth::{ + AuthClient, + OAuthClientConfig, + OAuthState, + OAuthTokenResponse, +}; +use rmcp::transport::sse_client::SseClientConfig; +use rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig; +use rmcp::transport::{ + AuthorizationManager, + AuthorizationSession, + SseClientTransport, + StreamableHttpClientTransport, +}; +use rmcp::{ + RoleClient, + Service, + serde_json, +}; +use schemars::JsonSchema; +use serde::{ + Deserialize, + Serialize, +}; +use sha2::{ + Digest, + Sha256, +}; +use tokio::sync::mpsc; +use tokio::sync::oneshot::Sender; +use tokio_util::sync::CancellationToken; +use tracing::{ + debug, + error, + info, +}; +use url::Url; + +use super::actor::McpServerActorEvent; + +#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq, JsonSchema)] +#[serde(rename_all = "camelCase")] +pub struct OAuthConfig { + /// Custom redirect URI for OAuth flow (e.g., "127.0.0.1:7778") + /// If not specified, a random available port will be assigned by the OS + #[serde(skip_serializing_if = "Option::is_none")] + pub redirect_uri: Option, +} + +#[derive(Debug, thiserror::Error)] +pub enum OauthUtilError { + #[error(transparent)] + Io(#[from] std::io::Error), + #[error(transparent)] + Parse(#[from] url::ParseError), + #[error(transparent)] + Auth(#[from] rmcp::transport::AuthError), + #[error(transparent)] + Serde(#[from] serde_json::Error), + #[error("Missing authorization manager")] + MissingAuthorizationManager, + #[error("Missing auth client when token refresh is needed")] + MissingAuthClient, + #[error(transparent)] + OneshotRecv(#[from] tokio::sync::oneshot::error::RecvError), + #[error(transparent)] + Reqwest(#[from] reqwest::Error), + #[error("{0}")] + Http(String), + #[error("Malformed directory")] + MalformDirectory, + #[error("Missing credential")] + MissingCredentials, + #[error("Failed to create a running service after running through all fallbacks: {0}")] + ServiceNotObtained(String), + #[error("{0}")] + SseTransport(String), +} + +/// A guard that automatically cancels the cancellation token when dropped. +/// This ensures that the OAuth loopback server is properly cleaned up +/// when the guard goes out of scope. +struct LoopBackDropGuard { + cancellation_token: CancellationToken, +} + +impl Drop for LoopBackDropGuard { + fn drop(&mut self) { + self.cancellation_token.cancel(); + } +} + +/// OAuth Authorization Server metadata for endpoint discovery +#[derive(Clone, Serialize, Deserialize, Debug)] +pub struct OAuthMeta { + pub authorization_endpoint: String, + pub token_endpoint: String, + pub registration_endpoint: Option, +} + +/// This is modeled after [OAuthClientConfig] +/// It's only here because [OAuthClientConfig] does not implement Serialize and Deserialize +#[derive(Clone, Serialize, Deserialize, Debug)] +pub struct Registration { + pub client_id: String, + pub client_secret: Option, + pub scopes: Vec, + pub redirect_uri: String, +} + +impl From for Registration { + fn from(value: OAuthClientConfig) -> Self { + Self { + client_id: value.client_id, + client_secret: value.client_secret, + scopes: value.scopes, + redirect_uri: value.redirect_uri, + } + } +} + +/// A wrapper that manages an authenticated MCP client. +/// +/// This struct wraps an `AuthClient` and provides access to OAuth credentials +/// for MCP server connections that require authentication. The credentials +/// are managed separately from this wrapper's lifecycle. +#[derive(Clone, Debug)] +pub struct AuthClientWrapper { + pub cred_full_path: PathBuf, + pub auth_client: AuthClient, +} + +impl AuthClientWrapper { + pub fn new(cred_full_path: PathBuf, auth_client: AuthClient) -> Self { + Self { + cred_full_path, + auth_client, + } + } + + /// Refreshes token in memory using the registration read from when the auth client was + /// spawned. This also persists the retrieved token + pub async fn refresh_token(&self) -> Result<(), OauthUtilError> { + let cred = self.auth_client.auth_manager.lock().await.refresh_token().await?; + let parent_path = self.cred_full_path.parent().ok_or(OauthUtilError::MalformDirectory)?; + tokio::fs::create_dir_all(parent_path).await?; + + let cred_as_bytes = serde_json::to_string_pretty(&cred)?; + tokio::fs::write(&self.cred_full_path, &cred_as_bytes).await?; + + Ok(()) + } +} + +pub fn get_default_scopes() -> &'static [&'static str] { + &["openid", "email", "profile", "offline_access"] +} + +enum TransportType { + Http, + Sse, +} + +enum HttpServiceBuilderState { + AttemptConnection(TransportType, bool), + FailedBecauseTokenMightBeExpired, + Exhausted, +} + +pub type HttpRunningService = ( + rmcp::service::RunningService>>, + Option, +); + +pub struct HttpServiceBuilder<'a> { + pub server_name: &'a str, + pub url: &'a str, + pub timeout: u64, + pub scopes: &'a [String], + pub headers: &'a HashMap, + pub oauth_config: &'a Option, + pub server_actor_event_tx: &'a mpsc::Sender, +} + +impl<'a> HttpServiceBuilder<'a> { + #[allow(clippy::too_many_arguments)] + pub fn new( + server_name: &'a str, + url: &'a str, + timeout: u64, + scopes: &'a [String], + headers: &'a HashMap, + oauth_config: &'a Option, + server_actor_event_tx: &'a mpsc::Sender, + ) -> Self { + Self { + server_name, + url, + timeout, + scopes, + headers, + oauth_config, + server_actor_event_tx, + } + } + + pub async fn try_build + Clone>( + self, + service: &S, + cred_dir: &Path, + ) -> Result { + let HttpServiceBuilder { + server_name, + url, + timeout, + scopes, + headers, + oauth_config, + server_actor_event_tx, + } = self; + + let mut state = HttpServiceBuilderState::AttemptConnection(TransportType::Http, false); + let url = Url::from_str(url)?; + let key = compute_key(&url); + let cred_full_path = cred_dir.join(format!("{key}.token.json")); + let reg_full_path = cred_dir.join(format!("{key}.registration.json")); + let mut auth_client = None::>; + + let mut client_builder = reqwest::ClientBuilder::new().timeout(std::time::Duration::from_millis(timeout)); + if !headers.is_empty() { + let headers = HeaderMap::try_from(headers).map_err(|e| OauthUtilError::Http(e.to_string()))?; + client_builder = client_builder.default_headers(headers); + }; + let reqwest_client = client_builder.build()?; + + // The probe request, like all other request, should adhere to the standards as per https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#sending-messages-to-the-server + let probe_resp = reqwest_client + .post(url.clone()) + .header("Accept", "application/json, text/event-stream") + .send() + .await; + let is_probe_err = probe_resp.is_err(); + let is_status_401_or_403 = probe_resp + .as_ref() + .is_ok_and(|resp| resp.status() == StatusCode::UNAUTHORIZED || resp.status() == StatusCode::FORBIDDEN); + + let contains_auth_header = probe_resp.is_ok_and(|resp| { + resp.headers().get("www-authenticate").is_some_and(|v| { + let value_as_str = v.to_str(); + if let Ok(value) = value_as_str { + value.to_lowercase().contains("bearer") + } else { + false + } + }) + }); + let needs_auth = is_probe_err || is_status_401_or_403 || contains_auth_header; + + // Here we attempt the following in the order they are presented: + // 1. Build transport, first assume http on attempt one, sse on attempt two + // - If it fails and it needs auth, attempt to refresh token (#2) + // - If it fails and it does not need auth OR if it fails after a refresh, attempt sse (#3) + // 2. Refresh token, go back to #1 + // 3. Attempt sse + // - If it fails, abort (because at this point we have run out of things to try, note that + // refreshing of token is agnostic to the type of transport) + loop { + match state { + HttpServiceBuilderState::AttemptConnection(transport_type, has_refreshed) => { + if needs_auth { + let ac = match auth_client { + Some(ref auth_client) => auth_client.clone(), + None => { + let am = get_auth_manager( + server_name, + url.clone(), + cred_full_path.clone(), + reg_full_path.clone(), + scopes, + oauth_config, + server_actor_event_tx, + ) + .await?; + + let ac = AuthClient::new(reqwest_client.clone(), am); + auth_client.replace(ac.clone()); + ac + }, + }; + + match transport_type { + TransportType::Http => { + let transport = StreamableHttpClientTransport::with_client( + ac.clone(), + StreamableHttpClientTransportConfig { + uri: url.as_str().into(), + allow_stateless: true, + ..Default::default() + }, + ); + + match service.clone().into_dyn().serve(transport).await { + Ok(service) => { + let auth_client_wrapper = AuthClientWrapper::new(cred_full_path, ac); + return Ok((service, Some(auth_client_wrapper))); + }, + Err(e) => { + if !has_refreshed { + error!( + "## mcp: http handshake attempt failed for {server_name}: {:?}. Attempting to refresh token", + e + ); + // first we'll try refreshing the token + state = HttpServiceBuilderState::FailedBecauseTokenMightBeExpired; + } else { + error!( + "## mcp: http handshake attempt failed for {server_name}: {:?}. Attempting sse", + e + ); + state = + HttpServiceBuilderState::AttemptConnection(TransportType::Sse, true); + } + }, + } + }, + TransportType::Sse => { + let transport = SseClientTransport::start_with_client(ac.clone(), SseClientConfig { + sse_endpoint: url.as_str().into(), + ..Default::default() + }) + .await + .map_err(|e| OauthUtilError::SseTransport(e.to_string()))?; + + match service.clone().into_dyn().serve(transport).await { + Ok(service) => { + let auth_client_wrapper = AuthClientWrapper::new(cred_full_path, ac); + return Ok((service, Some(auth_client_wrapper))); + }, + Err(e) => { + // at this point we would have already tried refreshing + // we are out of things to try and should just fail + error!( + "## mcp: sse handshake attempted failed for {server_name}: {:?}. Aborting", + e + ); + state = HttpServiceBuilderState::Exhausted; + }, + } + }, + } + } else { + info!( + "## mcp: No OAuth endpoints discovered for {server_name}, using unauthenticated transport" + ); + + match transport_type { + TransportType::Http => { + info!("## mcp: attempting open http handshake for {server_name}"); + let transport = StreamableHttpClientTransport::with_client( + reqwest_client.clone(), + StreamableHttpClientTransportConfig { + uri: url.as_str().into(), + allow_stateless: true, + ..Default::default() + }, + ); + + match service.clone().into_dyn().serve(transport).await { + Ok(service) => return Ok((service, None)), + Err(e) => { + error!( + "## mcp: open http handshake attempted failed for {server_name}: {:?}. Attempting sse", + e + ); + state = HttpServiceBuilderState::AttemptConnection(TransportType::Sse, false); + }, + } + }, + TransportType::Sse => { + info!("## mcp: attempting open sse handshake for {server_name}"); + let transport = + SseClientTransport::start_with_client(reqwest_client.clone(), SseClientConfig { + sse_endpoint: url.as_str().into(), + ..Default::default() + }) + .await + .map_err(|e| OauthUtilError::SseTransport(e.to_string()))?; + + match service.clone().into_dyn().serve(transport).await { + Ok(service) => return Ok((service, None)), + Err(e) => { + error!( + "## mcp: open sse handshake attempted failed for {server_name}: {:?}. Aborting", + e + ); + state = HttpServiceBuilderState::Exhausted; + }, + } + }, + } + } + }, + HttpServiceBuilderState::FailedBecauseTokenMightBeExpired => { + let auth_client_ref = auth_client.as_ref().ok_or(OauthUtilError::MissingAuthClient)?; + let auth_client_wrapper = AuthClientWrapper::new(cred_full_path.clone(), auth_client_ref.clone()); + let refresh_res = auth_client_wrapper.refresh_token().await; + + if let Err(e) = refresh_res { + error!("## mcp: token refresh attempt failed: {:?}", e); + info!("Retry for http transport failed {e}. Possible reauth needed"); + // This could be because the refresh token is expired, in which + // case we would need to have user go through the auth flow + // again. We do this by deleting the cred + // and discarding the client to trigger a full auth flow + if cred_full_path.is_file() { + tokio::fs::remove_file(&cred_full_path).await?; + } + + // we'll also need to remove the auth client to force a reauth when we go + // back to attempt the first step again + auth_client.take(); + } + + state = HttpServiceBuilderState::AttemptConnection(TransportType::Http, true); + }, + HttpServiceBuilderState::Exhausted => { + return Err(OauthUtilError::ServiceNotObtained( + "Max number of retries exhausted".to_string(), + )); + }, + } + } + } +} + +async fn get_auth_manager( + server_name: &str, + url: Url, + cred_full_path: PathBuf, + reg_full_path: PathBuf, + scopes: &[String], + oauth_config: &Option, + server_actor_event_tx: &mpsc::Sender, +) -> Result { + let cred_as_bytes = tokio::fs::read(&cred_full_path).await; + let reg_as_bytes = tokio::fs::read(®_full_path).await; + let mut oauth_state = OAuthState::new(url, None).await?; + + match (cred_as_bytes, reg_as_bytes) { + (Ok(cred_as_bytes), Ok(reg_as_bytes)) => { + let token = serde_json::from_slice::(&cred_as_bytes)?; + let reg = serde_json::from_slice::(®_as_bytes)?; + + oauth_state.set_credentials(®.client_id, token).await?; + + debug!("## mcp: credentials set with cache"); + + Ok(oauth_state + .into_authorization_manager() + .ok_or(OauthUtilError::MissingAuthorizationManager)?) + }, + _ => { + info!("Error reading cached credentials"); + debug!("## mcp: cache read failed. constructing auth manager from scratch"); + let (am, redirect_uri) = + get_auth_manager_impl(server_name, oauth_state, scopes, oauth_config, server_actor_event_tx).await?; + + // Client registration is done in [start_authorization] + // If we have gotten past that point that means we have the info to persist the + // registration on disk. + let (client_id, credentials) = am.get_credentials().await?; + let reg = Registration { + client_id, + client_secret: None, + scopes: get_default_scopes() + .iter() + .map(|s| (*s).to_string()) + .collect::>(), + redirect_uri, + }; + let reg_as_str = serde_json::to_string_pretty(®)?; + let reg_parent_path = reg_full_path.parent().ok_or(OauthUtilError::MalformDirectory)?; + tokio::fs::create_dir_all(reg_parent_path).await?; + tokio::fs::write(reg_full_path, ®_as_str).await?; + + let credentials = credentials.ok_or(OauthUtilError::MissingCredentials)?; + + let cred_parent_path = cred_full_path.parent().ok_or(OauthUtilError::MalformDirectory)?; + tokio::fs::create_dir_all(cred_parent_path).await?; + let reg_as_str = serde_json::to_string_pretty(&credentials)?; + tokio::fs::write(cred_full_path, ®_as_str).await?; + + Ok(am) + }, + } +} + +async fn get_auth_manager_impl( + server_name: &str, + mut oauth_state: OAuthState, + scopes: &[String], + oauth_config: &Option, + server_actor_event_tx: &mpsc::Sender, +) -> Result<(AuthorizationManager, String), OauthUtilError> { + // Get port from per-server oauth config, or use 0 for random port assignment + let port = oauth_config + .as_ref() + .and_then(|cfg| cfg.redirect_uri.as_ref()) + .and_then(|uri| { + // Parse port from redirect_uri like "127.0.0.1:7778" or ":7778" + uri.split(':').next_back().and_then(|p| p.parse::().ok()) + }) + .unwrap_or(0); // Port 0 = OS assigns random available port + + let socket_addr = SocketAddr::from(([127, 0, 0, 1], port)); + let cancellation_token = tokio_util::sync::CancellationToken::new(); + let (tx, rx) = tokio::sync::oneshot::channel::<(String, String)>(); + + let (actual_addr, _dg) = make_svc(tx, socket_addr, cancellation_token).await?; + info!("Listening on local host port {:?} for oauth", actual_addr); + + let redirect_uri = format!("http://{}", actual_addr); + let scopes_as_str = scopes.iter().map(String::as_str).collect::>(); + let scopes_as_slice = scopes_as_str.as_slice(); + start_authorization(&mut oauth_state, scopes_as_slice, &redirect_uri).await?; + + let oauth_url = oauth_state.get_authorization_url().await?; + debug!(?oauth_url, "generated auth url"); + if let Err(e) = server_actor_event_tx + .send(McpServerActorEvent::OauthRequest { + server_name: server_name.to_string(), + oauth_url, + }) + .await + { + error!(?e, "failed to send auth url"); + } + + let (auth_code, csrf_token) = rx.await?; + oauth_state.handle_callback(&auth_code, &csrf_token).await?; + let am = oauth_state + .into_authorization_manager() + .ok_or(OauthUtilError::MissingAuthorizationManager)?; + + Ok((am, redirect_uri)) +} + +pub fn compute_key(rs: &Url) -> String { + let mut hasher = Sha256::new(); + let input = format!("{}{}", rs.origin().ascii_serialization(), rs.path()); + hasher.update(input.as_bytes()); + format!("{:x}", hasher.finalize()) +} + +/// This is our own implementation of [OAuthState::start_authorization]. +/// This differs from [OAuthState::start_authorization] by assigning our own client_id for DCR. +/// We need this because the SDK hardcodes their own client id. And some servers will use client_id +/// to identify if a client is even allowed to perform the auth handshake. +async fn start_authorization( + oauth_state: &mut OAuthState, + scopes: &[&str], + redirect_uri: &str, +) -> Result<(), OauthUtilError> { + // DO NOT CHANGE THIS + // This string has significance as it is used for remote servers to identify us + const CLIENT_ID: &str = "Q DEV CLI"; + + let stub_cred = get_stub_credentials()?; + oauth_state.set_credentials(CLIENT_ID, stub_cred).await?; + + // The setting of credentials would put the oauth state into authorize. + if let OAuthState::Authorized(auth_manager) = oauth_state { + // set redirect uri + let config = OAuthClientConfig { + client_id: CLIENT_ID.to_string(), + client_secret: None, + scopes: scopes.iter().map(|s| (*s).to_string()).collect(), + redirect_uri: redirect_uri.to_string(), + }; + + // try to dynamic register client + let config = match auth_manager.register_client(CLIENT_ID, redirect_uri).await { + Ok(config) => config, + Err(e) => { + eprintln!("Dynamic registration failed: {}", e); + // fallback to default config + config + }, + }; + // reset client config + auth_manager.configure_client(config)?; + let auth_url = auth_manager.get_authorization_url(scopes).await?; + + let mut stub_auth_manager = AuthorizationManager::new("http://localhost").await?; + std::mem::swap(auth_manager, &mut stub_auth_manager); + + let session = AuthorizationSession { + auth_manager: stub_auth_manager, + auth_url, + redirect_uri: redirect_uri.to_string(), + }; + + let mut new_oauth_state = OAuthState::Session(session); + std::mem::swap(oauth_state, &mut new_oauth_state); + } else { + unreachable!() + } + + Ok(()) +} + +/// This looks silly but [rmcp::transport::auth::OAuthTokenResponse] is private and there is no +/// other way to create this directly +fn get_stub_credentials() -> Result { + const STUB_TOKEN: &str = r#" + { + "access_token": "stub", + "token_type": "bearer", + "expires_in": 3600, + "refresh_token": "stub", + "scope": "stub" + } + "#; + + serde_json::from_str::(STUB_TOKEN) +} + +async fn make_svc( + one_shot_sender: Sender<(String, String)>, + socket_addr: SocketAddr, + cancellation_token: CancellationToken, +) -> Result<(SocketAddr, LoopBackDropGuard), OauthUtilError> { + type AuthCodeSender = Sender<(String, String)>; + #[derive(Clone, Debug)] + struct LoopBackForSendingAuthCode { + one_shot_sender: Arc>>, + } + + #[derive(Debug, thiserror::Error)] + enum LoopBackError { + #[error("Poison error encountered: {0}")] + Poison(String), + #[error(transparent)] + Http(#[from] http::Error), + #[error("Failed to send auth code")] + Send((String, String)), + } + + fn mk_response(s: String) -> Result>, LoopBackError> { + Ok(Response::builder().body(Full::new(Bytes::from(s)))?) + } + + impl hyper::service::Service> for LoopBackForSendingAuthCode { + type Error = LoopBackError; + type Future = Pin> + Send>>; + type Response = Response>; + + fn call(&self, req: hyper::Request) -> Self::Future { + let uri = req.uri(); + let query = uri.query().unwrap_or(""); + let params: std::collections::HashMap = + url::form_urlencoded::parse(query.as_bytes()).into_owned().collect(); + debug!("## mcp: uri: {}, query: {}, params: {:?}", uri, query, params); + + let self_clone = self.clone(); + Box::pin(async move { + let error = params.get("error"); + let resp = if let Some(err) = error { + mk_response(format!( + "OAuth failed. Check URL for precise reasons. Possible reasons: {}.\n\ + If this is scope related, you can try configuring the server scopes \n\ + to be an empty array by adding \"oauthScopes\": [] to your server config.\n\ + Example: {{\"type\": \"http\", \"uri\": \"https://example.com/mcp\", \"oauthScopes\": []}}\n", + err + )) + } else { + mk_response("You can close this page now".to_string()) + }; + + let code = params.get("code").cloned().unwrap_or_default(); + let state = params.get("state").cloned().unwrap_or_default(); + if let Some(sender) = self_clone + .one_shot_sender + .lock() + .map_err(|e| LoopBackError::Poison(e.to_string()))? + .take() + { + sender.send((code, state)).map_err(LoopBackError::Send)?; + } + + resp + }) + } + } + + let listener = tokio::net::TcpListener::bind(socket_addr).await?; + let actual_addr = listener.local_addr()?; + let cancellation_token_clone = cancellation_token.clone(); + let dg = LoopBackDropGuard { + cancellation_token: cancellation_token_clone, + }; + + let loop_back = LoopBackForSendingAuthCode { + one_shot_sender: Arc::new(std::sync::Mutex::new(Some(one_shot_sender))), + }; + + // This is one and done + // This server only needs to last as long as it takes to send the auth code or to fail the auth + // flow + tokio::spawn(async move { + let (stream, _) = listener.accept().await?; + let io = TokioIo::new(stream); + + tokio::select! { + _ = cancellation_token.cancelled() => { + info!("Oauth loopback server cancelled"); + }, + res = http1::Builder::new().serve_connection(io, loop_back) => { + if let Err(err) = res { + error!("Auth code loop back has failed: {:?}", err); + } + } + } + + Ok::<(), eyre::Report>(()) + }); + + Ok((actual_addr, dg)) +} diff --git a/crates/agent/src/agent/mcp/service.rs b/crates/agent/src/agent/mcp/service.rs index eb1d8a38f0..55c995c85d 100644 --- a/crates/agent/src/agent/mcp/service.rs +++ b/crates/agent/src/agent/mcp/service.rs @@ -1,9 +1,11 @@ +use std::path::PathBuf; use std::process::Stdio; use std::time::{ Duration, Instant, }; +use rmcp::RoleClient; use rmcp::model::{ CallToolRequestParam, CallToolResult, @@ -16,15 +18,14 @@ use rmcp::model::{ ServerRequest, Tool as RmcpTool, }; +use rmcp::service::{ + DynService, + ServiceExt, +}; use rmcp::transport::{ ConfigureCommandExt as _, TokioChildProcess, }; -use rmcp::{ - RoleClient, - ServiceError, - ServiceExt as _, -}; use tokio::io::AsyncReadExt as _; use tokio::process::{ ChildStderr, @@ -39,38 +40,59 @@ use tracing::{ warn, }; -use super::actor::McpMessage; +use super::actor::{ + McpMessage, + McpServerActorEvent, +}; +use super::oauth_util::{ + AuthClientWrapper, + HttpServiceBuilder, +}; use super::types::Prompt; use crate::agent::agent_config::definitions::McpServerConfig; use crate::agent::agent_loop::types::ToolSpec; use crate::agent::util::expand_env_vars; use crate::agent::util::path::expand_path; +use crate::agent_config::definitions::RemoteMcpServerConfig; use crate::util::providers::RealProvider; /// This struct is consumed by the [rmcp] crate on server launch. The only purpose of this struct /// is to handle server-to-client requests. Client-side code will own a [RunningMcpService] /// instance. -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct McpService { server_name: String, config: McpServerConfig, + cred_path: PathBuf, /// Sender to the related [McpServerActor] message_tx: mpsc::Sender, } impl McpService { - pub fn new(server_name: String, config: McpServerConfig, message_tx: mpsc::Sender) -> Self { + pub fn new( + server_name: String, + config: McpServerConfig, + cred_path: PathBuf, + message_tx: mpsc::Sender, + ) -> Self { Self { server_name, config, + cred_path, message_tx, } } /// Launches the provided MCP server, returning a client handle to the server for sending /// requests. - pub async fn launch(self) -> eyre::Result<(RunningMcpService, LaunchMetadata)> { - match &self.config { + pub async fn launch( + self, + event_tx: &mpsc::Sender, + ) -> eyre::Result<(RunningMcpService, LaunchMetadata)> { + let serve_time_taken: std::time::Duration; + let server_name = self.server_name.clone(); + + let (service, child_stderr, auth_client) = match &self.config { McpServerConfig::Local(config) => { // TODO - don't use real provider let cmd = expand_path(&config.command, &RealProvider)?; @@ -93,74 +115,106 @@ impl McpService { let start_time = Instant::now(); info!(?server_name, "Launching MCP server"); - let service = self.serve(process).await?; - let serve_time_taken = start_time.elapsed(); + let service = self.into_dyn().serve(process).await?; + serve_time_taken = start_time.elapsed(); info!(?serve_time_taken, ?server_name, "MCP server launched successfully"); - let launch_md = match service.peer_info() { - Some(info) => { - debug!(?server_name, ?info, "peer info found"); - - // Fetch tools, if we can - let (tools, list_tools_duration) = if info.capabilities.tools.is_some() { - let start_time = Instant::now(); - match service.list_all_tools().await { - Ok(tools) => ( - Some(tools.into_iter().map(Into::into).collect()), - Some(start_time.elapsed()), - ), - Err(err) => { - error!(?err, "failed to list tools during server initialization"); - (None, None) - }, - } - } else { + (service, stderr, None) + }, + McpServerConfig::Remote(config) => { + let RemoteMcpServerConfig { + url, + headers, + timeout_ms: timeout, + oauth_scopes: scopes, + oauth: oauth_config, + disabled: _, + } = config; + + let start_time = Instant::now(); + info!(?self.server_name, "Launching MCP server"); + + let mut processed_headers = headers.clone(); + expand_env_vars(&mut processed_headers); + + let http_service_builder = HttpServiceBuilder::new( + &self.server_name, + url, + *timeout, + scopes, + &processed_headers, + oauth_config, + event_tx, + ); + let (service, auth_client) = http_service_builder.try_build(&self, &self.cred_path).await?; + serve_time_taken = start_time.elapsed(); + + (service, None, auth_client) + }, + }; + + let launch_md = match service.peer_info() { + Some(info) => { + debug!(?server_name, ?info, "peer info found"); + + // Fetch tools, if we can + let (tools, list_tools_duration) = if info.capabilities.tools.is_some() { + let start_time = Instant::now(); + match service.list_all_tools().await { + Ok(tools) => ( + Some(tools.into_iter().map(Into::into).collect()), + Some(start_time.elapsed()), + ), + Err(err) => { + error!(?err, "failed to list tools during server initialization"); (None, None) - }; - - // Fetch prompts, if we can - let (prompts, list_prompts_duration) = if info.capabilities.prompts.is_some() { - let start_time = Instant::now(); - match service.list_all_prompts().await { - Ok(prompts) => ( - Some(prompts.into_iter().map(Into::into).collect()), - Some(start_time.elapsed()), - ), - Err(err) => { - error!(?err, "failed to list prompts during server initialization"); - (None, None) - }, - } - } else { + }, + } + } else { + (None, None) + }; + + // Fetch prompts, if we can + let (prompts, list_prompts_duration) = if info.capabilities.prompts.is_some() { + let start_time = Instant::now(); + match service.list_all_prompts().await { + Ok(prompts) => ( + Some(prompts.into_iter().map(Into::into).collect()), + Some(start_time.elapsed()), + ), + Err(err) => { + error!(?err, "failed to list prompts during server initialization"); (None, None) - }; - - LaunchMetadata { - serve_time_taken, - tools, - list_tools_duration, - prompts, - list_prompts_duration, - } - }, - None => { - warn!(?server_name, "no peer info found"); - LaunchMetadata { - serve_time_taken, - tools: None, - list_tools_duration: None, - prompts: None, - list_prompts_duration: None, - } - }, + }, + } + } else { + (None, None) }; - Ok((RunningMcpService::new(server_name, service, stderr), launch_md)) + LaunchMetadata { + serve_time_taken, + tools, + list_tools_duration, + prompts, + list_prompts_duration, + } }, - McpServerConfig::StreamableHTTP(_) => { - eyre::bail!("not supported"); + None => { + warn!(?server_name, "no peer info found"); + LaunchMetadata { + serve_time_taken, + tools: None, + list_tools_duration: None, + prompts: None, + list_prompts_duration: None, + } }, - } + }; + + Ok(( + RunningMcpService::new(server_name, service, child_stderr, auth_client), + launch_md, + )) } } @@ -253,6 +307,97 @@ pub struct LaunchMetadata { pub list_prompts_duration: Option, } +/// Decorates the method passed in with retry logic, but only if the [RunningService] has an +/// instance of [AuthClientDropGuard]. +/// The various methods to interact with the mcp server provided by RMCP supposedly does refresh +/// token once the token expires but that logic would require us to also note down the time at +/// which a token is obtained since the only time related information in the token is the duration +/// for which a token is valid. However, if we do solely rely on the internals of these methods to +/// refresh tokens, we would have no way of knowing when a token is obtained. (Maybe there is a +/// method that would allow us to configure what extra info to include in the token. If you find it, +/// feel free to remove this. That would also enable us to simplify the definition of +/// [RunningService]) +macro_rules! decorate_with_auth_retry { + ($param_type:ty, $method_name:ident, $return_type:ty) => { + pub async fn $method_name(&self, param: $param_type) -> Result<$return_type, rmcp::ServiceError> { + let first_attempt = match &self.running_service { + InnerService::Original(rs) => rs.$method_name(param.clone()).await, + InnerService::Peer(peer) => peer.$method_name(param.clone()).await, + }; + + match first_attempt { + Ok(result) => Ok(result), + Err(e) => { + // TODO: discern error type prior to retrying + // Not entirely sure what is thrown when auth is required + if let Some(auth_client) = self.auth_client.as_ref() { + let refresh_result = auth_client.refresh_token().await; + match refresh_result { + Ok(_) => { + info!("Token refreshed"); + // Retry the operation after token refresh + match &self.running_service { + InnerService::Original(rs) => rs.$method_name(param).await, + InnerService::Peer(peer) => peer.$method_name(param).await, + } + }, + Err(_) => { + // If refresh fails, return the original error + // Currently our event loop just does not allow us easy ways to + // reauth entirely once a session starts since this would mean + // swapping of transport (which also means swapping of client) + Err(e) + }, + } + } else { + // No auth client available, return original error + Err(e) + } + }, + } + } + }; + ($method_name:ident, $return_type:ty) => { + pub async fn $method_name(&self) -> Result<$return_type, rmcp::ServiceError> { + let first_attempt = match &self.running_service { + InnerService::Original(rs) => rs.$method_name().await, + InnerService::Peer(peer) => peer.$method_name().await, + }; + + match first_attempt { + Ok(result) => Ok(result), + Err(e) => { + // TODO: discern error type prior to retrying + // Not entirely sure what is thrown when auth is required + if let Some(auth_client) = self.auth_client.as_ref() { + let refresh_result = auth_client.refresh_token().await; + match refresh_result { + Ok(_) => { + info!("Token refreshed"); + // Retry the operation after token refresh + match &self.running_service { + InnerService::Original(rs) => rs.$method_name().await, + InnerService::Peer(peer) => peer.$method_name().await, + } + }, + Err(_) => { + // If refresh fails, return the original error + // Currently our event loop just does not allow us easy ways to + // reauth entirely once a session starts since this would mean + // swapping of transport (which also means swapping of client) + Err(e) + }, + } + } else { + // No auth client available, return original error + Err(e) + } + }, + } + } + }; +} + /// Represents a handle to a running MCP server. #[derive(Debug, Clone)] pub struct RunningMcpService { @@ -262,13 +407,21 @@ pub struct RunningMcpService { /// TODO - maybe replace RunningMcpService with just InnerService? Probably not, once OAuth is /// implemented since that may require holding an auth guard. running_service: InnerService, + auth_client: Option, } impl RunningMcpService { + decorate_with_auth_retry!(CallToolRequestParam, call_tool, CallToolResult); + + decorate_with_auth_retry!(list_all_tools, Vec); + + decorate_with_auth_retry!(list_all_prompts, Vec); + fn new( server_name: String, - running_service: rmcp::service::RunningService, + running_service: rmcp::service::RunningService>>, child_stderr: Option, + auth_client: Option, ) -> Self { // We need to read from the child process stderr - otherwise, ?? will happen if let Some(mut stderr) = child_stderr { @@ -295,20 +448,9 @@ impl RunningMcpService { Self { running_service: InnerService::Original(running_service), + auth_client, } } - - pub async fn call_tool(&self, param: CallToolRequestParam) -> Result { - self.running_service.peer().call_tool(param).await - } - - pub async fn list_tools(&self) -> Result, ServiceError> { - self.running_service.peer().list_all_tools().await - } - - pub async fn list_prompts(&self) -> Result, ServiceError> { - self.running_service.peer().list_all_prompts().await - } } /// Wrapper around rmcp service types to enable cloning. @@ -319,19 +461,10 @@ impl RunningMcpService { /// pointer type to `Peer`. This enum allows us to hold either the original service or its /// peer representation, enabling cloning by converting the original service to a peer when needed. pub enum InnerService { - Original(rmcp::service::RunningService), + Original(rmcp::service::RunningService>>), Peer(rmcp::service::Peer), } -impl InnerService { - fn peer(&self) -> &rmcp::Peer { - match self { - InnerService::Original(service) => service.peer(), - InnerService::Peer(peer) => peer, - } - } -} - impl std::fmt::Debug for InnerService { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { diff --git a/crates/agent/src/agent/mod.rs b/crates/agent/src/agent/mod.rs index e0f0809796..2af446b6ba 100644 --- a/crates/agent/src/agent/mod.rs +++ b/crates/agent/src/agent/mod.rs @@ -58,6 +58,7 @@ use agent_loop::{ use chrono::Utc; use consts::MAX_RESOURCE_FILE_LENGTH; use futures::stream::FuturesUnordered; +use mcp::McpServerEvent; use permissions::evaluate_tool_permission; use protocol::{ AgentError, @@ -347,7 +348,13 @@ impl Agent { } let mut results = FuturesUnordered::new(); - for config in &self.cached_mcp_configs.configs { + + for config in self + .cached_mcp_configs + .configs + .iter() + .filter(|config| config.is_enabled()) + { let Ok(rx) = self .mcp_manager_handle .launch_server(config.server_name.clone(), config.config.clone()) @@ -482,7 +489,18 @@ impl Agent { } self.agent_event_buf.push(evt.into()); } - } + }, + + evt = self.mcp_manager_handle.recv() => { + match evt { + Ok(evt) => { + self.handle_mcp_events(evt).await; + }, + Err(e) => { + error!(?e, "mcp manager handle closed"); + } + } + }, } } } @@ -1679,6 +1697,11 @@ impl Agent { self.set_active_state(ActiveState::ExecutingRequest).await; Ok(()) } + + async fn handle_mcp_events(&mut self, evt: McpServerEvent) { + let converted_evt = AgentEvent::Mcp(evt); + self.agent_event_buf.push(converted_evt); + } } /// Creates a request structure for sending to the model. diff --git a/crates/agent/src/agent/protocol.rs b/crates/agent/src/agent/protocol.rs index d0e83ed895..2a0bc4012f 100644 --- a/crates/agent/src/agent/protocol.rs +++ b/crates/agent/src/agent/protocol.rs @@ -17,8 +17,11 @@ use super::agent_loop::types::{ ImageBlock, ToolUseBlock, }; -use super::mcp::McpManagerError; use super::mcp::types::Prompt; +use super::mcp::{ + McpManagerError, + McpServerEvent, +}; use super::task_executor::TaskExecutorEvent; use super::tools::{ Tool, @@ -73,6 +76,9 @@ pub enum AgentEvent { /// Lower-level events associated with the agent's execution. Generally only useful for /// debugging or telemetry purposes. Internal(InternalEvent), + + /// Events from MCP (Model Context Protocol) servers + Mcp(McpServerEvent), } impl From for AgentEvent { diff --git a/crates/agent/src/cli/run.rs b/crates/agent/src/cli/run.rs index 49e045677f..f02df2f99a 100644 --- a/crates/agent/src/cli/run.rs +++ b/crates/agent/src/cli/run.rs @@ -1,4 +1,5 @@ use std::io::Write as _; +use std::path::PathBuf; use std::process::ExitCode; use std::sync::Arc; @@ -61,6 +62,9 @@ pub struct RunArgs { /// Trust all tools #[arg(long)] dangerously_trust_all_tools: bool, + /// Credential path for remote mcp + #[arg(long)] + mcp_cred_path: Option, /// The initial prompt. prompt: Vec, } @@ -102,7 +106,13 @@ impl RunArgs { } }; - let agent = Agent::new(snapshot, model, McpManager::new().spawn()).await?.spawn(); + let mcp_manager_handle = if let Some(path) = self.mcp_cred_path.as_ref() { + let cred_path = PathBuf::from(path); + McpManager::new(cred_path).spawn() + } else { + McpManager::default().spawn() + }; + let agent = Agent::new(snapshot, model, mcp_manager_handle).await?.spawn(); self.main_loop(agent).await } @@ -112,6 +122,10 @@ impl RunArgs { // First, wait for agent initialization while let Ok(evt) = agent.recv().await { + if matches!(evt, AgentEvent::Mcp(_)) { + info!(?evt, "received mcp agent event"); + // TODO: Send it through conduit + } if matches!(evt, AgentEvent::Initialized) { break; } @@ -159,7 +173,10 @@ impl RunArgs { .await?; } }, - _ => (), + AgentEvent::Mcp(evt) => { + info!(?evt, "received mcp agent event"); + }, + _ => {}, } } diff --git a/crates/agent/tests/common/mod.rs b/crates/agent/tests/common/mod.rs index 6bf02e33dd..985de2eeac 100644 --- a/crates/agent/tests/common/mod.rs +++ b/crates/agent/tests/common/mod.rs @@ -100,7 +100,7 @@ impl TestCaseBuilder { model = model.with_response(response); } - let mut agent = Agent::new(snapshot, Arc::new(model), McpManager::new().spawn()).await?; + let mut agent = Agent::new(snapshot, Arc::new(model), McpManager::default().spawn()).await?; let mut test_base = TestBase::new().await; for file in self.files {