Skip to content

Commit 70f391d

Browse files
committed
adds http transport option to mcp root command
1 parent c90c87d commit 70f391d

File tree

2 files changed

+207
-40
lines changed

2 files changed

+207
-40
lines changed

crates/chat-cli/src/cli/chat/tools/custom_tool.rs

Lines changed: 51 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
use std::borrow::Cow;
2-
use std::collections::HashMap;
2+
use std::collections::{
3+
HashMap,
4+
HashSet,
5+
};
36
use std::io::Write;
47

58
use crossterm::{
@@ -39,32 +42,53 @@ pub enum TransportType {
3942
Http,
4043
}
4144

45+
impl std::str::FromStr for TransportType {
46+
type Err = String;
47+
48+
fn from_str(s: &str) -> Result<Self, Self::Err> {
49+
match s.to_lowercase().as_str() {
50+
"stdio" => Ok(TransportType::Stdio),
51+
"http" => Ok(TransportType::Http),
52+
_ => Err(format!("Invalid transport type: {}", s)),
53+
}
54+
}
55+
}
56+
57+
impl std::fmt::Display for TransportType {
58+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59+
match self {
60+
TransportType::Stdio => write!(f, "stdio"),
61+
TransportType::Http => write!(f, "http"),
62+
}
63+
}
64+
}
65+
4266
impl Default for TransportType {
4367
fn default() -> Self {
4468
Self::Stdio
4569
}
4670
}
4771

48-
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq, JsonSchema)]
72+
#[derive(Clone, Serialize, Deserialize, Debug, Default, Eq, PartialEq, JsonSchema)]
4973
#[serde(rename_all = "camelCase")]
5074
pub struct CustomToolConfig {
5175
/// The transport type to use for communication with the MCP server
52-
#[serde(default)]
76+
#[serde(default, skip_serializing_if = "should_skip_ser_transport_type")]
5377
pub r#type: TransportType,
5478
/// The URL for HTTP-based MCP server communication
55-
#[serde(default)]
79+
#[serde(default, skip_serializing_if = "String::is_empty")]
5680
pub url: String,
5781
/// HTTP headers to include when communicating with HTTP-based MCP servers
58-
#[serde(default)]
82+
#[serde(default, skip_serializing_if = "HashMap::is_empty")]
5983
pub headers: HashMap<String, String>,
6084
/// Scopes with which oauth is done
61-
#[serde(default = "get_default_scopes")]
85+
#[serde(default = "get_default_scopes", skip_serializing_if = "should_skip_ser_scope")]
6286
pub oauth_scopes: Vec<String>,
6387
/// The command string used to initialize the mcp server
64-
#[serde(default)]
88+
#[serde(default, skip_serializing_if = "String::is_empty")]
6589
pub command: String,
6690
/// A list of arguments to be used to run the command with
67-
#[serde(default)]
91+
#[serde(default, skip_serializing_if = "Vec::is_empty")]
6892
pub args: Vec<String>,
6993
/// A list of environment variables to run the command with
7094
#[serde(skip_serializing_if = "Option::is_none")]
@@ -73,13 +97,31 @@ pub struct CustomToolConfig {
7397
#[serde(default = "default_timeout")]
7498
pub timeout: u64,
7599
/// A boolean flag to denote whether or not to load this mcp server
76-
#[serde(default)]
100+
#[serde(default, skip_serializing_if = "should_skip_ser_disabled")]
77101
pub disabled: bool,
78102
/// A flag to denote whether this is a server from the legacy mcp.json
79103
#[serde(skip)]
80104
pub is_from_legacy_mcp_json: bool,
81105
}
82106

107+
fn should_skip_ser_disabled(disabled: &bool) -> bool {
108+
!*disabled
109+
}
110+
111+
fn should_skip_ser_transport_type(transport_type: &TransportType) -> bool {
112+
matches!(transport_type, &TransportType::Stdio)
113+
}
114+
115+
fn should_skip_ser_scope(scopes: &[String]) -> bool {
116+
let mut set = HashSet::<&str>::new();
117+
let default_scopes = oauth_util::get_default_scopes();
118+
for scope in default_scopes {
119+
set.insert(*scope);
120+
}
121+
122+
scopes.iter().all(|s| set.contains(s.as_str()))
123+
}
124+
83125
pub fn get_default_scopes() -> Vec<String> {
84126
oauth_util::get_default_scopes()
85127
.iter()

crates/chat-cli/src/cli/mcp.rs

Lines changed: 156 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ use super::agent::{
2727
DEFAULT_AGENT_NAME,
2828
McpServerConfig,
2929
};
30+
use super::chat::tools::custom_tool::TransportType;
3031
use crate::cli::chat::tool_manager::{
3132
global_mcp_config_path,
3233
workspace_mcp_config_path,
@@ -85,8 +86,11 @@ impl McpSubcommand {
8586
}
8687
}
8788

88-
#[derive(Debug, Clone, PartialEq, Eq, Args)]
89+
#[derive(Debug, Default, Clone, PartialEq, Eq, Args)]
8990
pub struct AddArgs {
91+
/// Transport type for the MCP server (e.g., stdio, http)
92+
#[arg(long)]
93+
pub r#type: Option<TransportType>,
9094
/// Name for the server
9195
#[arg(long)]
9296
pub name: String,
@@ -95,7 +99,13 @@ pub struct AddArgs {
9599
pub scope: Option<Scope>,
96100
/// The command used to launch the server
97101
#[arg(long)]
98-
pub command: String,
102+
pub command: Option<String>,
103+
/// URL for HTTP-based MCP servers
104+
#[arg(long)]
105+
pub url: Option<String>,
106+
/// HTTP headers to include with requests for HTTP-based MCP servers
107+
#[arg(long, value_parser = parse_key_val_pair)]
108+
pub headers: Vec<HashMap<String, String>>,
99109
/// Arguments to pass to the command. Can be provided as:
100110
/// 1. Multiple --args flags: --args arg1 --args arg2 --args "arg,with,commas"
101111
/// 2. Comma-separated with escaping: --args "arg1,arg2,arg\,with\,commas"
@@ -107,7 +117,7 @@ pub struct AddArgs {
107117
#[arg(long)]
108118
pub agent: Option<String>,
109119
/// Environment variables to use when launching the server
110-
#[arg(long, value_parser = parse_env_vars)]
120+
#[arg(long, value_parser = parse_key_val_pair)]
111121
pub env: Vec<HashMap<String, String>>,
112122
/// Server launch timeout, in milliseconds
113123
#[arg(long)]
@@ -121,11 +131,8 @@ pub struct AddArgs {
121131
}
122132

123133
impl AddArgs {
124-
pub async fn execute(self, os: &Os, output: &mut impl Write) -> Result<()> {
125-
// Process args to handle comma-separated values, escaping, and JSON arrays
126-
let processed_args = self.process_args()?;
127-
128-
match self.agent.as_deref() {
134+
pub async fn execute(mut self, os: &Os, output: &mut impl Write) -> Result<()> {
135+
match self.agent.take().as_deref() {
129136
Some(agent_name) => {
130137
let (mut agent, config_path) = Agent::get_agent_by_name(os, agent_name).await?;
131138
let mcp_servers = &mut agent.mcp_servers.mcp_servers;
@@ -139,19 +146,13 @@ impl AddArgs {
139146
);
140147
}
141148

142-
let merged_env = self.env.into_iter().flatten().collect::<HashMap<_, _>>();
143-
let tool: CustomToolConfig = serde_json::from_value(serde_json::json!({
144-
"command": self.command,
145-
"args": processed_args,
146-
"env": merged_env,
147-
"timeout": self.timeout.unwrap_or(default_timeout()),
148-
"disabled": self.disabled,
149-
}))?;
149+
let name = self.name.clone();
150+
let tool = self.into_custom_tool_config()?;
150151

151-
mcp_servers.insert(self.name.clone(), tool);
152+
mcp_servers.insert(name.clone(), tool);
152153
let json = agent.to_str_pretty()?;
153154
os.fs.write(config_path, json).await?;
154-
writeln!(output, "✓ Added MCP server '{}' to agent {}\n", self.name, agent_name)?;
155+
writeln!(output, "✓ Added MCP server '{}' to agent {}\n", name, agent_name)?;
155156
},
156157
None => {
157158
let legacy_mcp_config_path = match self.scope {
@@ -172,21 +173,15 @@ impl AddArgs {
172173
);
173174
}
174175

175-
let merged_env = self.env.into_iter().flatten().collect::<HashMap<_, _>>();
176-
let tool: CustomToolConfig = serde_json::from_value(serde_json::json!({
177-
"command": self.command,
178-
"args": processed_args,
179-
"env": merged_env,
180-
"timeout": self.timeout.unwrap_or(default_timeout()),
181-
"disabled": self.disabled,
182-
}))?;
176+
let name = self.name.clone();
177+
let tool = self.into_custom_tool_config()?;
183178

184-
mcp_servers.mcp_servers.insert(self.name.clone(), tool);
179+
mcp_servers.mcp_servers.insert(name.clone(), tool);
185180
mcp_servers.save_to_file(os, &legacy_mcp_config_path).await?;
186181
writeln!(
187182
output,
188183
"✓ Added MCP server '{}' to global config in {}\n",
189-
self.name,
184+
name,
190185
legacy_mcp_config_path.display()
191186
)?;
192187
},
@@ -195,6 +190,45 @@ impl AddArgs {
195190
Ok(())
196191
}
197192

193+
fn into_custom_tool_config(self) -> Result<CustomToolConfig> {
194+
match self.r#type {
195+
Some(TransportType::Http) => {
196+
if let Some(url) = self.url {
197+
let merged_headers = self.headers.into_iter().flatten().collect::<HashMap<_, _>>();
198+
Ok(CustomToolConfig {
199+
r#type: TransportType::Http,
200+
url,
201+
headers: merged_headers,
202+
timeout: self.timeout.unwrap_or(default_timeout()),
203+
disabled: self.disabled,
204+
..Default::default()
205+
})
206+
} else {
207+
bail!("Transport type is specified to be http but url is not provided");
208+
}
209+
},
210+
Some(TransportType::Stdio) | None => {
211+
if self.command.is_some() {
212+
let processed_args = self.process_args()?;
213+
let merged_env = self.env.into_iter().flatten().collect::<HashMap<_, _>>();
214+
Ok(CustomToolConfig {
215+
r#type: TransportType::Stdio,
216+
// Doing this saves us an allocation and this is safe because we have
217+
// already verified that command is Some
218+
command: self.command.unwrap(),
219+
args: processed_args,
220+
env: Some(merged_env),
221+
timeout: self.timeout.unwrap_or(default_timeout()),
222+
disabled: self.disabled,
223+
..Default::default()
224+
})
225+
} else {
226+
bail!("Transport type is specified to be stdio but command is not provided")
227+
}
228+
},
229+
}
230+
}
231+
198232
fn process_args(&self) -> Result<Vec<String>> {
199233
let mut processed_args = Vec::new();
200234

@@ -504,7 +538,7 @@ async fn ensure_config_file(os: &Os, path: &PathBuf, output: &mut impl Write) ->
504538
load_cfg(os, path).await
505539
}
506540

507-
fn parse_env_vars(arg: &str) -> Result<HashMap<String, String>> {
541+
fn parse_key_val_pair(arg: &str) -> Result<HashMap<String, String>> {
508542
let mut vars = HashMap::new();
509543

510544
for pair in arg.split(",") {
@@ -640,7 +674,7 @@ mod tests {
640674
AddArgs {
641675
name: "local".into(),
642676
scope: None,
643-
command: "echo hi".into(),
677+
command: Some("echo hi".into()),
644678
args: vec![
645679
"awslabs.eks-mcp-server".to_string(),
646680
"--allow-write".to_string(),
@@ -651,6 +685,7 @@ mod tests {
651685
agent: None,
652686
disabled: false,
653687
force: false,
688+
..Default::default()
654689
}
655690
.execute(&os, &mut vec![])
656691
.await
@@ -693,7 +728,7 @@ mod tests {
693728
RootSubcommand::Mcp(McpSubcommand::Add(AddArgs {
694729
name: "test_server".to_string(),
695730
scope: None,
696-
command: "test_command".to_string(),
731+
command: Some("test_command".to_string()),
697732
args: vec!["awslabs.eks-mcp-server,--allow-write,--allow-sensitive-data-access".to_string(),],
698733
agent: None,
699734
env: vec![
@@ -707,6 +742,7 @@ mod tests {
707742
timeout: None,
708743
disabled: false,
709744
force: false,
745+
..Default::default()
710746
}))
711747
);
712748
}
@@ -794,4 +830,93 @@ mod tests {
794830
let result = parse_args(r#"["invalid json"#);
795831
assert!(result.is_err());
796832
}
833+
834+
#[test]
835+
fn test_parse_http_transport_type() {
836+
let add_args = AddArgs {
837+
r#type: Some(TransportType::Http),
838+
name: "test_http_server".to_string(),
839+
url: Some("https://api.example.com".to_string()),
840+
headers: vec![
841+
[("Authorization".to_string(), "Bearer token123".to_string())]
842+
.into_iter()
843+
.collect(),
844+
[("Content-Type".to_string(), "application/json".to_string())]
845+
.into_iter()
846+
.collect(),
847+
],
848+
timeout: Some(5000),
849+
disabled: false,
850+
..Default::default()
851+
};
852+
853+
let config = add_args.into_custom_tool_config().unwrap();
854+
855+
assert_eq!(config.r#type, TransportType::Http);
856+
assert_eq!(config.url, "https://api.example.com");
857+
assert_eq!(config.timeout, 5000);
858+
assert!(!config.disabled);
859+
assert_eq!(
860+
config.headers.get("Authorization"),
861+
Some(&"Bearer token123".to_string())
862+
);
863+
assert_eq!(
864+
config.headers.get("Content-Type"),
865+
Some(&"application/json".to_string())
866+
);
867+
}
868+
869+
#[test]
870+
fn test_incorrect_transport_type_should_fail() {
871+
// Test HTTP transport without URL should fail
872+
let add_args = AddArgs {
873+
r#type: Some(TransportType::Http),
874+
name: "test_http_server".to_string(),
875+
url: None,
876+
..Default::default()
877+
};
878+
879+
let result = add_args.into_custom_tool_config();
880+
assert!(result.is_err());
881+
assert!(
882+
result
883+
.unwrap_err()
884+
.to_string()
885+
.contains("Transport type is specified to be http but url is not provided")
886+
);
887+
888+
// Test STDIO transport without command should fail
889+
let add_args = AddArgs {
890+
r#type: Some(TransportType::Stdio),
891+
name: "test_stdio_server".to_string(),
892+
command: None,
893+
..Default::default()
894+
};
895+
896+
let result = add_args.into_custom_tool_config();
897+
assert!(result.is_err());
898+
assert!(
899+
result
900+
.unwrap_err()
901+
.to_string()
902+
.contains("Transport type is specified to be stdio but command is not provided")
903+
);
904+
905+
// Test default (stdio) transport without command should fail
906+
let add_args = AddArgs {
907+
r#type: None,
908+
name: "test_default_server".to_string(),
909+
command: None,
910+
..Default::default()
911+
};
912+
913+
let result = add_args.into_custom_tool_config();
914+
assert!(result.is_err());
915+
assert!(
916+
result
917+
.unwrap_err()
918+
.to_string()
919+
.contains("Transport type is specified to be stdio but command is not provided")
920+
);
921+
}
797922
}

0 commit comments

Comments
 (0)