-
Notifications
You must be signed in to change notification settings - Fork 241
feat(mcp): Add MCP sampling support #2239
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -284,7 +284,7 @@ impl ChatArgs { | |
info!(?conversation_id, "Generated new conversation id"); | ||
let (prompt_request_sender, prompt_request_receiver) = std::sync::mpsc::channel::<Option<String>>(); | ||
let (prompt_response_sender, prompt_response_receiver) = std::sync::mpsc::channel::<Vec<String>>(); | ||
let mut tool_manager = ToolManagerBuilder::default() | ||
let (mut tool_manager, sampling_receiver) = ToolManagerBuilder::default() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I had left my question below on how |
||
.prompt_list_sender(prompt_response_sender) | ||
.prompt_list_receiver(prompt_request_receiver) | ||
.conversation_id(&conversation_id) | ||
|
@@ -293,6 +293,9 @@ impl ChatArgs { | |
.await?; | ||
let tool_config = tool_manager.load_tools(os, &mut stderr).await?; | ||
|
||
// Set the ApiClient for MCP clients that have sampling enabled | ||
tool_manager.set_streaming_client(std::sync::Arc::new(os.client.clone())); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Also, looking at the definition of the client:
#[derive(Clone, Debug)]
pub struct ApiClient {
client: CodewhispererClient,
streaming_client: Option<CodewhispererStreamingClient>,
sigv4_streaming_client: Option<QDeveloperStreamingClient>,
mock_client: Option<Arc<Mutex<std::vec::IntoIter<Vec<ChatResponseStream>>>>>,
profile: Option<AuthProfile>,
}
/// Client for Amazon CodeWhisperer
///
/// Client for invoking operations on Amazon CodeWhisperer. Each operation on Amazon CodeWhisperer
/// is a method on this struct. `.send()` MUST be invoked on the generated operations to
/// dispatch the request to the service.
#[derive(::std::clone::Clone, ::std::fmt::Debug)]
pub struct Client {
handle: ::std::sync::Arc<Handle>,
}
/// Client for Amazon CodeWhisperer Streaming
///
/// Client for invoking operations on Amazon CodeWhisperer Streaming. Each operation on Amazon
/// CodeWhisperer Streaming is a method on this struct. `.send()` MUST be invoked on the
/// generated operations to dispatch the request to the service.
#[derive(::std::clone::Clone, ::std::fmt::Debug)]
pub struct Client {
handle: ::std::sync::Arc<Handle>,
}
/// Client for Amazon Q Developer Streaming
///
/// Client for invoking operations on Amazon Q Developer Streaming. Each operation on Amazon Q
/// Developer Streaming is a method on this struct. `.send()` MUST be invoked on the generated
/// operations to dispatch the request to the service.
#[derive(::std::clone::Clone, ::std::fmt::Debug)]
pub struct Client {
handle: ::std::sync::Arc<Handle>,
}
All of which are already wrapped in `Arc`. |
||
|
||
ChatSession::new( | ||
os, | ||
stdout, | ||
|
@@ -307,6 +310,7 @@ impl ChatArgs { | |
model_id, | ||
tool_config, | ||
!self.no_interactive, | ||
sampling_receiver, | ||
) | ||
.await? | ||
.spawn(os) | ||
|
@@ -480,6 +484,8 @@ pub struct ChatSession { | |
conversation: ConversationState, | ||
tool_uses: Vec<QueuedTool>, | ||
pending_tool_index: Option<usize>, | ||
/// Channel receiver for incoming sampling requests from MCP servers | ||
sampling_receiver: tokio::sync::mpsc::UnboundedReceiver<crate::mcp_client::sampling_ipc::PendingSamplingRequest>, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
/// Telemetry events to be sent as part of the conversation. | ||
tool_use_telemetry_events: HashMap<String, ToolUseEventBuilder>, | ||
/// State used to keep track of tool use relation | ||
|
@@ -508,6 +514,7 @@ impl ChatSession { | |
model_id: Option<String>, | ||
tool_config: HashMap<String, ToolSpec>, | ||
interactive: bool, | ||
sampling_receiver: tokio::sync::mpsc::UnboundedReceiver<crate::mcp_client::sampling_ipc::PendingSamplingRequest>, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Feel free to just bring these types in with use statements. |
||
) -> Result<Self> { | ||
let valid_model_id = match model_id { | ||
Some(id) => id, | ||
|
@@ -584,6 +591,7 @@ impl ChatSession { | |
conversation, | ||
tool_uses: vec![], | ||
pending_tool_index: None, | ||
sampling_receiver, | ||
tool_use_telemetry_events: HashMap::new(), | ||
tool_use_status: ToolUseStatus::Idle, | ||
failed_request_ids: Vec::new(), | ||
|
@@ -1267,6 +1275,17 @@ impl ChatSession { | |
.put_skim_command_selector(os, Arc::new(context_manager.clone()), tool_names); | ||
} | ||
|
||
// Check for incoming sampling requests and automatically approve them | ||
// Since servers now opt-in via configuration, any request that comes through should be approved | ||
while let Ok(mut sampling_request) = self.sampling_receiver.try_recv() { | ||
tracing::info!(target: "mcp", "Auto-approving sampling request from configured server: {}", sampling_request.server_name); | ||
|
||
// Automatically approve the sampling request | ||
sampling_request.send_approval_result( | ||
crate::mcp_client::sampling_ipc::SamplingApprovalResult::approved() | ||
); | ||
} | ||
Comment on lines
+1280
to
+1287
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not entirely sure what this is doing. Some questions:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
The docs don't mention anything either way (though the dataflow diagram doesn't mention an ongoing tool-call when the server initiates sampling), but to my understanding this is incorrect. MCP servers support long-running operations (see for instance FastMCP's doc on progress reporting, which is meant to support such operations); it seems strange to support sampling for tool-calls but not such long-running operations. There is also at least one real customer use-case on the MCP discussions where it seems the server proactively starts sampling requests even though the client isn't explicitly expecting them.
This was missed in a refactoring. As mentioned in this other thread, `transport_ref` now handles sending data to the server, so we can probably get rid of the `sampling_result` call. I will test that this works when refactoring.
We decide whether a server has access to sampling by whether we provide a sampling callback at all, since right now it's a binary trust/don't-trust situation. If a server is able to make a sampling request and reach this section, it should be allowed. Supporting manual approve/deny is difficult from a CLI, given that we only have one input stream (and we can't rely on things like popups working without breaking customer orchestrations, either). Adding approvals only during tool calls and then rejecting untrusted sampling requests outside that context is maybe doable, but I don't think it should block an initial implementation of the feature. EDIT: The Python SDK, as currently written, also seems to support sampling at arbitrary times, using a callback-based approach similar to the one we're using here (assuming I'm reading it correctly). |
||
|
||
execute!( | ||
self.stderr, | ||
style::SetForegroundColor(Color::Reset), | ||
|
@@ -2360,6 +2379,12 @@ mod tests { | |
agents | ||
} | ||
|
||
#[cfg(test)] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This entire module is already configured to only be compiled during test. No need to annotate again here. |
||
/// Test helper: builds an unbounded channel and returns only its receiving half. | ||
/// | ||
/// The sending half is bound to `_sender` and dropped when this function returns, so | ||
/// no sampling request can ever arrive on the returned receiver. It exists only to | ||
/// fill the `sampling_receiver` argument of `ChatSession::new` in the tests below. | ||
fn create_dummy_sampling_receiver() -> tokio::sync::mpsc::UnboundedReceiver<crate::mcp_client::sampling_ipc::PendingSamplingRequest> { | ||
let (_sender, receiver) = tokio::sync::mpsc::unbounded_channel(); | ||
receiver | ||
} | ||
|
||
#[tokio::test] | ||
async fn test_flow() { | ||
let mut os = Os::new().await.unwrap(); | ||
|
@@ -2403,6 +2428,7 @@ mod tests { | |
None, | ||
tool_config, | ||
true, | ||
create_dummy_sampling_receiver(), | ||
) | ||
.await | ||
.unwrap() | ||
|
@@ -2544,6 +2570,7 @@ mod tests { | |
None, | ||
tool_config, | ||
true, | ||
create_dummy_sampling_receiver(), | ||
) | ||
.await | ||
.unwrap() | ||
|
@@ -2640,6 +2667,7 @@ mod tests { | |
None, | ||
tool_config, | ||
true, | ||
create_dummy_sampling_receiver(), | ||
) | ||
.await | ||
.unwrap() | ||
|
@@ -2714,6 +2742,7 @@ mod tests { | |
None, | ||
tool_config, | ||
true, | ||
create_dummy_sampling_receiver(), | ||
) | ||
.await | ||
.unwrap() | ||
|
@@ -2764,6 +2793,7 @@ mod tests { | |
None, | ||
tool_config, | ||
true, | ||
create_dummy_sampling_receiver(), | ||
) | ||
.await | ||
.unwrap() | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -183,7 +183,7 @@ impl ToolManagerBuilder { | |
os: &mut Os, | ||
mut output: Box<dyn Write + Send + Sync + 'static>, | ||
interactive: bool, | ||
) -> eyre::Result<ToolManager> { | ||
) -> eyre::Result<(ToolManager, tokio::sync::mpsc::UnboundedReceiver<crate::mcp_client::sampling_ipc::PendingSamplingRequest>)> { | ||
let McpServerConfig { mcp_servers } = self.mcp_server_config.ok_or(eyre::eyre!("Missing mcp server config"))?; | ||
debug_assert!(self.conversation_id.is_some()); | ||
let conversation_id = self.conversation_id.ok_or(eyre::eyre!("Missing conversation id"))?; | ||
|
@@ -199,6 +199,9 @@ impl ToolManagerBuilder { | |
.map(|(server_name, _)| server_name.clone()) | ||
.collect(); | ||
|
||
// Create channel for sampling requests | ||
let (sampling_sender, sampling_receiver) = tokio::sync::mpsc::unbounded_channel(); | ||
|
||
let pre_initialized = enabled_servers | ||
.into_iter() | ||
.filter_map(|(server_name, server_config)| { | ||
|
@@ -211,7 +214,11 @@ impl ToolManagerBuilder { | |
); | ||
None | ||
} else { | ||
let custom_tool_client = CustomToolClient::from_config(server_name.clone(), server_config); | ||
let custom_tool_client = CustomToolClient::from_config( | ||
server_name.clone(), | ||
server_config, | ||
Some(sampling_sender.clone()), | ||
); | ||
Some((server_name, custom_tool_client)) | ||
} | ||
}) | ||
|
@@ -687,7 +694,7 @@ impl ToolManagerBuilder { | |
}); | ||
} | ||
|
||
Ok(ToolManager { | ||
let tool_manager = ToolManager { | ||
conversation_id, | ||
clients, | ||
prompts, | ||
|
@@ -701,8 +708,11 @@ impl ToolManagerBuilder { | |
mcp_load_record: load_record, | ||
agent, | ||
disabled_servers: disabled_servers_display, | ||
sampling_request_sender: Some(sampling_sender), | ||
..Default::default() | ||
}) | ||
}; | ||
|
||
Ok((tool_manager, sampling_receiver)) | ||
} | ||
} | ||
|
||
|
@@ -829,6 +839,9 @@ pub struct ToolManager { | |
/// The value is the load message (i.e. load time, warnings, and errors) | ||
pub mcp_load_record: Arc<Mutex<HashMap<String, Vec<LoadingRecord>>>>, | ||
|
||
/// Channel sender for MCP clients to send sampling requests for approval | ||
pub sampling_request_sender: Option<tokio::sync::mpsc::UnboundedSender<crate::mcp_client::sampling_ipc::PendingSamplingRequest>>, | ||
|
||
/// List of disabled MCP server names for display purposes | ||
disabled_servers: Vec<String>, | ||
|
||
|
@@ -850,12 +863,24 @@ impl Clone for ToolManager { | |
is_interactive: self.is_interactive, | ||
mcp_load_record: self.mcp_load_record.clone(), | ||
disabled_servers: self.disabled_servers.clone(), | ||
sampling_request_sender: self.sampling_request_sender.clone(), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't see how this field is used. Why does |
||
..Default::default() | ||
} | ||
} | ||
} | ||
|
||
impl ToolManager { | ||
/// Hands the given `ApiClient` to the MCP clients held by this `ToolManager`. | ||
/// | ||
/// Every entry in `self.clients` receives its own clone of the `Arc`, so all of them | ||
/// share the same underlying `ApiClient`. | ||
/// | ||
/// NOTE(review): the loop below visits every client, not just a subset; limiting this | ||
/// to servers with sampling enabled presumably happens inside the per-client | ||
/// `set_streaming_client` - confirm. | ||
pub fn set_streaming_client(&self, api_client: std::sync::Arc<crate::api_client::ApiClient>) { | ||
tracing::info!(target: "mcp", "Setting ApiClient for MCP clients with sampling enabled"); | ||
|
||
for (server_name, client) in &self.clients { | ||
// `client` is only borrowed from `self.clients`; each call gets its own clone of the `Arc` | ||
client.set_streaming_client(api_client.clone()); | ||
tracing::debug!(target: "mcp", "Set ApiClient for server: {}", server_name); | ||
} | ||
} | ||
|
||
pub async fn load_tools( | ||
&mut self, | ||
os: &mut Os, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can probably just revert the changes in this file since it looks unintentional.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, this is accidental from the rebase, will remove.