Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 51 additions & 9 deletions crates/chat-cli/src/cli/chat/tools/custom_tool.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
use std::borrow::Cow;
use std::collections::HashMap;
use std::collections::{
HashMap,
HashSet,
};
use std::io::Write;

use crossterm::{
Expand Down Expand Up @@ -39,32 +42,53 @@ pub enum TransportType {
Http,
}

impl std::str::FromStr for TransportType {
type Err = String;

fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.to_lowercase().as_str() {
"stdio" => Ok(TransportType::Stdio),
"http" => Ok(TransportType::Http),
_ => Err(format!("Invalid transport type: {}", s)),
}
}
}

impl std::fmt::Display for TransportType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
TransportType::Stdio => write!(f, "stdio"),
TransportType::Http => write!(f, "http"),
}
}
}

impl Default for TransportType {
fn default() -> Self {
Self::Stdio
}
}

#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq, JsonSchema)]
#[derive(Clone, Serialize, Deserialize, Debug, Default, Eq, PartialEq, JsonSchema)]
#[serde(rename_all = "camelCase")]
pub struct CustomToolConfig {
/// The transport type to use for communication with the MCP server
#[serde(default)]
#[serde(default, skip_serializing_if = "should_skip_ser_transport_type")]
pub r#type: TransportType,
/// The URL for HTTP-based MCP server communication
#[serde(default)]
#[serde(default, skip_serializing_if = "String::is_empty")]
pub url: String,
/// HTTP headers to include when communicating with HTTP-based MCP servers
#[serde(default)]
#[serde(default, skip_serializing_if = "HashMap::is_empty")]
pub headers: HashMap<String, String>,
/// Scopes with which oauth is done
#[serde(default = "get_default_scopes")]
#[serde(default = "get_default_scopes", skip_serializing_if = "should_skip_ser_scope")]
pub oauth_scopes: Vec<String>,
/// The command string used to initialize the mcp server
#[serde(default)]
#[serde(default, skip_serializing_if = "String::is_empty")]
pub command: String,
/// A list of arguments to be used to run the command with
#[serde(default)]
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub args: Vec<String>,
/// A list of environment variables to run the command with
#[serde(skip_serializing_if = "Option::is_none")]
Expand All @@ -73,13 +97,31 @@ pub struct CustomToolConfig {
#[serde(default = "default_timeout")]
pub timeout: u64,
/// A boolean flag to denote whether or not to load this mcp server
#[serde(default)]
#[serde(default, skip_serializing_if = "should_skip_ser_disabled")]
pub disabled: bool,
/// A flag to denote whether this is a server from the legacy mcp.json
#[serde(skip)]
pub is_from_legacy_mcp_json: bool,
}

fn should_skip_ser_disabled(disabled: &bool) -> bool {
!*disabled
}

fn should_skip_ser_transport_type(transport_type: &TransportType) -> bool {
matches!(transport_type, &TransportType::Stdio)
}

fn should_skip_ser_scope(scopes: &[String]) -> bool {
let mut set = HashSet::<&str>::new();
let default_scopes = oauth_util::get_default_scopes();
for scope in default_scopes {
set.insert(*scope);
}

scopes.iter().all(|s| set.contains(s.as_str()))
}

pub fn get_default_scopes() -> Vec<String> {
oauth_util::get_default_scopes()
.iter()
Expand Down
187 changes: 156 additions & 31 deletions crates/chat-cli/src/cli/mcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ use super::agent::{
DEFAULT_AGENT_NAME,
McpServerConfig,
};
use super::chat::tools::custom_tool::TransportType;
use crate::cli::chat::tool_manager::{
global_mcp_config_path,
workspace_mcp_config_path,
Expand Down Expand Up @@ -85,8 +86,11 @@ impl McpSubcommand {
}
}

#[derive(Debug, Clone, PartialEq, Eq, Args)]
#[derive(Debug, Default, Clone, PartialEq, Eq, Args)]
pub struct AddArgs {
/// Transport type for the MCP server (e.g., stdio, http)
#[arg(long)]
pub r#type: Option<TransportType>,
/// Name for the server
#[arg(long)]
pub name: String,
Expand All @@ -95,7 +99,13 @@ pub struct AddArgs {
pub scope: Option<Scope>,
/// The command used to launch the server
#[arg(long)]
pub command: String,
pub command: Option<String>,
/// URL for HTTP-based MCP servers
#[arg(long)]
pub url: Option<String>,
/// HTTP headers to include with requests for HTTP-based MCP servers
#[arg(long, value_parser = parse_key_val_pair)]
pub headers: Vec<HashMap<String, String>>,
/// Arguments to pass to the command. Can be provided as:
/// 1. Multiple --args flags: --args arg1 --args arg2 --args "arg,with,commas"
/// 2. Comma-separated with escaping: --args "arg1,arg2,arg\,with\,commas"
Expand All @@ -107,7 +117,7 @@ pub struct AddArgs {
#[arg(long)]
pub agent: Option<String>,
/// Environment variables to use when launching the server
#[arg(long, value_parser = parse_env_vars)]
#[arg(long, value_parser = parse_key_val_pair)]
pub env: Vec<HashMap<String, String>>,
/// Server launch timeout, in milliseconds
#[arg(long)]
Expand All @@ -121,11 +131,8 @@ pub struct AddArgs {
}

impl AddArgs {
pub async fn execute(self, os: &Os, output: &mut impl Write) -> Result<()> {
// Process args to handle comma-separated values, escaping, and JSON arrays
let processed_args = self.process_args()?;

match self.agent.as_deref() {
pub async fn execute(mut self, os: &Os, output: &mut impl Write) -> Result<()> {
match self.agent.take().as_deref() {
Some(agent_name) => {
let (mut agent, config_path) = Agent::get_agent_by_name(os, agent_name).await?;
let mcp_servers = &mut agent.mcp_servers.mcp_servers;
Expand All @@ -139,19 +146,13 @@ impl AddArgs {
);
}

let merged_env = self.env.into_iter().flatten().collect::<HashMap<_, _>>();
let tool: CustomToolConfig = serde_json::from_value(serde_json::json!({
"command": self.command,
"args": processed_args,
"env": merged_env,
"timeout": self.timeout.unwrap_or(default_timeout()),
"disabled": self.disabled,
}))?;
let name = self.name.clone();
let tool = self.into_custom_tool_config()?;

mcp_servers.insert(self.name.clone(), tool);
mcp_servers.insert(name.clone(), tool);
let json = agent.to_str_pretty()?;
os.fs.write(config_path, json).await?;
writeln!(output, "✓ Added MCP server '{}' to agent {}\n", self.name, agent_name)?;
writeln!(output, "✓ Added MCP server '{}' to agent {}\n", name, agent_name)?;
},
None => {
let legacy_mcp_config_path = match self.scope {
Expand All @@ -172,21 +173,15 @@ impl AddArgs {
);
}

let merged_env = self.env.into_iter().flatten().collect::<HashMap<_, _>>();
let tool: CustomToolConfig = serde_json::from_value(serde_json::json!({
"command": self.command,
"args": processed_args,
"env": merged_env,
"timeout": self.timeout.unwrap_or(default_timeout()),
"disabled": self.disabled,
}))?;
let name = self.name.clone();
let tool = self.into_custom_tool_config()?;

mcp_servers.mcp_servers.insert(self.name.clone(), tool);
mcp_servers.mcp_servers.insert(name.clone(), tool);
mcp_servers.save_to_file(os, &legacy_mcp_config_path).await?;
writeln!(
output,
"✓ Added MCP server '{}' to global config in {}\n",
self.name,
name,
legacy_mcp_config_path.display()
)?;
},
Expand All @@ -195,6 +190,45 @@ impl AddArgs {
Ok(())
}

fn into_custom_tool_config(self) -> Result<CustomToolConfig> {
match self.r#type {
Some(TransportType::Http) => {
if let Some(url) = self.url {
let merged_headers = self.headers.into_iter().flatten().collect::<HashMap<_, _>>();
Ok(CustomToolConfig {
r#type: TransportType::Http,
url,
headers: merged_headers,
timeout: self.timeout.unwrap_or(default_timeout()),
disabled: self.disabled,
..Default::default()
})
} else {
bail!("Transport type is specified to be http but url is not provided");
}
},
Some(TransportType::Stdio) | None => {
if self.command.is_some() {
let processed_args = self.process_args()?;
let merged_env = self.env.into_iter().flatten().collect::<HashMap<_, _>>();
Ok(CustomToolConfig {
r#type: TransportType::Stdio,
// Doing this saves us an allocation and this is safe because we have
// already verified that command is Some
command: self.command.unwrap(),
args: processed_args,
env: Some(merged_env),
timeout: self.timeout.unwrap_or(default_timeout()),
disabled: self.disabled,
..Default::default()
})
} else {
bail!("Transport type is specified to be stdio but command is not provided")
}
},
}
}

fn process_args(&self) -> Result<Vec<String>> {
let mut processed_args = Vec::new();

Expand Down Expand Up @@ -504,7 +538,7 @@ async fn ensure_config_file(os: &Os, path: &PathBuf, output: &mut impl Write) ->
load_cfg(os, path).await
}

fn parse_env_vars(arg: &str) -> Result<HashMap<String, String>> {
fn parse_key_val_pair(arg: &str) -> Result<HashMap<String, String>> {
let mut vars = HashMap::new();

for pair in arg.split(",") {
Expand Down Expand Up @@ -640,7 +674,7 @@ mod tests {
AddArgs {
name: "local".into(),
scope: None,
command: "echo hi".into(),
command: Some("echo hi".into()),
args: vec![
"awslabs.eks-mcp-server".to_string(),
"--allow-write".to_string(),
Expand All @@ -651,6 +685,7 @@ mod tests {
agent: None,
disabled: false,
force: false,
..Default::default()
}
.execute(&os, &mut vec![])
.await
Expand Down Expand Up @@ -693,7 +728,7 @@ mod tests {
RootSubcommand::Mcp(McpSubcommand::Add(AddArgs {
name: "test_server".to_string(),
scope: None,
command: "test_command".to_string(),
command: Some("test_command".to_string()),
args: vec!["awslabs.eks-mcp-server,--allow-write,--allow-sensitive-data-access".to_string(),],
agent: None,
env: vec![
Expand All @@ -707,6 +742,7 @@ mod tests {
timeout: None,
disabled: false,
force: false,
..Default::default()
}))
);
}
Expand Down Expand Up @@ -794,4 +830,93 @@ mod tests {
let result = parse_args(r#"["invalid json"#);
assert!(result.is_err());
}

#[test]
fn test_parse_http_transport_type() {
let add_args = AddArgs {
r#type: Some(TransportType::Http),
name: "test_http_server".to_string(),
url: Some("https://api.example.com".to_string()),
headers: vec![
[("Authorization".to_string(), "Bearer token123".to_string())]
.into_iter()
.collect(),
[("Content-Type".to_string(), "application/json".to_string())]
.into_iter()
.collect(),
],
timeout: Some(5000),
disabled: false,
..Default::default()
};

let config = add_args.into_custom_tool_config().unwrap();

assert_eq!(config.r#type, TransportType::Http);
assert_eq!(config.url, "https://api.example.com");
assert_eq!(config.timeout, 5000);
assert!(!config.disabled);
assert_eq!(
config.headers.get("Authorization"),
Some(&"Bearer token123".to_string())
);
assert_eq!(
config.headers.get("Content-Type"),
Some(&"application/json".to_string())
);
}

#[test]
fn test_incorrect_transport_type_should_fail() {
// Test HTTP transport without URL should fail
let add_args = AddArgs {
r#type: Some(TransportType::Http),
name: "test_http_server".to_string(),
url: None,
..Default::default()
};

let result = add_args.into_custom_tool_config();
assert!(result.is_err());
assert!(
result
.unwrap_err()
.to_string()
.contains("Transport type is specified to be http but url is not provided")
);

// Test STDIO transport without command should fail
let add_args = AddArgs {
r#type: Some(TransportType::Stdio),
name: "test_stdio_server".to_string(),
command: None,
..Default::default()
};

let result = add_args.into_custom_tool_config();
assert!(result.is_err());
assert!(
result
.unwrap_err()
.to_string()
.contains("Transport type is specified to be stdio but command is not provided")
);

// Test default (stdio) transport without command should fail
let add_args = AddArgs {
r#type: None,
name: "test_default_server".to_string(),
command: None,
..Default::default()
};

let result = add_args.into_custom_tool_config();
assert!(result.is_err());
assert!(
result
.unwrap_err()
.to_string()
.contains("Transport type is specified to be stdio but command is not provided")
);
}
}
Loading