feat(mcp): Add MCP sampling support #2239

Open
wants to merge 1 commit into
base: main
2 changes: 2 additions & 0 deletions crates/chat-cli/src/api_client/mod.rs
@@ -31,6 +31,7 @@ pub use endpoints::Endpoint;
pub use error::ApiClientError;
Contributor

We can probably just revert the changes in this file since it looks unintentional.

Contributor Author

Yes, this is accidental from the rebase, will remove.

use parking_lot::Mutex;
pub use profile::list_available_profiles;

use serde_json::Map;
use tracing::{
debug,
@@ -446,6 +447,7 @@ impl ApiClient {

self.mock_client = Some(Arc::new(Mutex::new(mock.into_iter())));
}

}

fn timeout_config(database: &Database) -> TimeoutConfig {
32 changes: 31 additions & 1 deletion crates/chat-cli/src/cli/chat/mod.rs
@@ -284,7 +284,7 @@ impl ChatArgs {
info!(?conversation_id, "Generated new conversation id");
let (prompt_request_sender, prompt_request_receiver) = std::sync::mpsc::channel::<Option<String>>();
let (prompt_response_sender, prompt_response_receiver) = std::sync::mpsc::channel::<Vec<String>>();
- let mut tool_manager = ToolManagerBuilder::default()
+ let (mut tool_manager, sampling_receiver) = ToolManagerBuilder::default()
Contributor

I left a question below about how sampling_receiver is used. That question aside, I think sampling_receiver can simply be owned by ToolManager.
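
One way to read the suggestion above, as a minimal sketch rather than the PR's actual code: the builder creates the channel, hands sender clones to the clients, and stores the receiver on the manager, so `build` returns a single value. `Manager`, `Client`, `Request`, and `drain_sampling_requests` are hypothetical names, and `std::sync::mpsc` stands in for the `tokio::sync::mpsc` channel used in the PR.

```rust
use std::sync::mpsc::{channel, Receiver, Sender};

/// Hypothetical stand-in for `PendingSamplingRequest`.
#[derive(Debug, PartialEq)]
pub struct Request {
    pub server_name: String,
}

/// Hypothetical stand-in for an MCP client that may ask for sampling.
pub struct Client {
    name: String,
    sampling_sender: Sender<Request>,
}

impl Client {
    /// Queues a sampling request; a send error (receiver gone) is ignored here.
    pub fn request_sampling(&self) {
        let _ = self.sampling_sender.send(Request {
            server_name: self.name.clone(),
        });
    }
}

/// Hypothetical stand-in for `ToolManager`: it owns the receiving half.
pub struct Manager {
    pub clients: Vec<Client>,
    sampling_receiver: Receiver<Request>,
}

impl Manager {
    /// Creates the channel internally, so callers get one value back.
    pub fn build(server_names: &[&str]) -> Manager {
        let (sender, receiver) = channel();
        let clients = server_names
            .iter()
            .map(|name| Client {
                name: name.to_string(),
                sampling_sender: sender.clone(),
            })
            .collect();
        Manager {
            clients,
            sampling_receiver: receiver,
        }
    }

    /// Returns every request queued so far without blocking.
    pub fn drain_sampling_requests(&self) -> Vec<Request> {
        self.sampling_receiver.try_iter().collect()
    }
}
```

Channel receivers are not `Clone`, and `ToolManager` has a manual `Clone` impl in this PR, so a real change would still have to decide what happens to the receiver when the manager is cloned. The sketch sidesteps that question.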

.prompt_list_sender(prompt_response_sender)
.prompt_list_receiver(prompt_request_receiver)
.conversation_id(&conversation_id)
@@ -293,6 +293,9 @@ impl ChatArgs {
.await?;
let tool_config = tool_manager.load_tools(os, &mut stderr).await?;

// Set the ApiClient for MCP clients that have sampling enabled
tool_manager.set_streaming_client(std::sync::Arc::new(os.client.clone()));
Contributor

Also, looking at the definition of the client in Os, we have the following:

#[derive(Clone, Debug)]
pub struct ApiClient {
    client: CodewhispererClient,
    streaming_client: Option<CodewhispererStreamingClient>,
    sigv4_streaming_client: Option<QDeveloperStreamingClient>,
    mock_client: Option<Arc<Mutex<std::vec::IntoIter<Vec<ChatResponseStream>>>>>,
    profile: Option<AuthProfile>,
}

CodewhispererClient is defined as follows:

/// Client for Amazon CodeWhisperer
///
/// Client for invoking operations on Amazon CodeWhisperer. Each operation on Amazon CodeWhisperer
/// is a method on this this struct. `.send()` MUST be invoked on the generated operations to
/// dispatch the request to the service.
#[derive(::std::clone::Clone, ::std::fmt::Debug)]
pub struct Client {
    handle: ::std::sync::Arc<Handle>,
}

CodewhispererStreamingClient is defined as follows:

/// Client for Amazon CodeWhisperer Streaming
///
/// Client for invoking operations on Amazon CodeWhisperer Streaming. Each operation on Amazon
/// CodeWhisperer Streaming is a method on this this struct. `.send()` MUST be invoked on the
/// generated operations to dispatch the request to the service.
#[derive(::std::clone::Clone, ::std::fmt::Debug)]
pub struct Client {
    handle: ::std::sync::Arc<Handle>,
}

QDeveloperStreamingClient:

/// Client for Amazon Q Developer Streaming
///
/// Client for invoking operations on Amazon Q Developer Streaming. Each operation on Amazon Q
/// Developer Streaming is a method on this this struct. `.send()` MUST be invoked on the generated
/// operations to dispatch the request to the service.
#[derive(::std::clone::Clone, ::std::fmt::Debug)]
pub struct Client {
    handle: ::std::sync::Arc<Handle>,
}

All of these are already wrapped in Arcs. Unless there is some other reason that I am not seeing, I am not convinced that we need to introduce another level of indirection here.
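
The point about the existing `Arc`s can be illustrated with a small sketch. It assumes a client shaped like the generated ones quoted above; `Handle`'s field and the helper methods are hypothetical and exist only to make the sharing observable.

```rust
use std::sync::Arc;

/// Hypothetical stand-in for the SDK's internal `Handle`.
#[derive(Debug)]
pub struct Handle {
    pub endpoint: String,
}

/// Mirrors the shape of the generated clients: the only field is an `Arc`.
#[derive(Clone, Debug)]
pub struct Client {
    handle: Arc<Handle>,
}

impl Client {
    pub fn new(endpoint: &str) -> Self {
        Client {
            handle: Arc::new(Handle {
                endpoint: endpoint.to_string(),
            }),
        }
    }

    /// Number of `Client` values currently sharing this handle.
    pub fn shared_count(&self) -> usize {
        Arc::strong_count(&self.handle)
    }

    /// True when both clients point at the same underlying handle.
    pub fn shares_handle_with(&self, other: &Client) -> bool {
        Arc::ptr_eq(&self.handle, &other.handle)
    }
}
```

Cloning such a client only bumps a reference count, so wrapping it in a second `Arc` adds a pointer hop without avoiding any copying of the handle. The sketch covers only the client fields; `ApiClient` also carries a `profile` field whose clone cost is not shown in this thread.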


ChatSession::new(
os,
stdout,
@@ -307,6 +310,7 @@ impl ChatArgs {
model_id,
tool_config,
!self.no_interactive,
sampling_receiver,
)
.await?
.spawn(os)
@@ -480,6 +484,8 @@ pub struct ChatSession {
conversation: ConversationState,
tool_uses: Vec<QueuedTool>,
pending_tool_index: Option<usize>,
/// Channel receiver for incoming sampling requests from MCP servers
sampling_receiver: tokio::sync::mpsc::UnboundedReceiver<crate::mcp_client::sampling_ipc::PendingSamplingRequest>,
/// Telemetry events to be sent as part of the conversation.
tool_use_telemetry_events: HashMap<String, ToolUseEventBuilder>,
/// State used to keep track of tool use relation
@@ -508,6 +514,7 @@ impl ChatSession {
model_id: Option<String>,
tool_config: HashMap<String, ToolSpec>,
interactive: bool,
sampling_receiver: tokio::sync::mpsc::UnboundedReceiver<crate::mcp_client::sampling_ipc::PendingSamplingRequest>,
Contributor

Feel free to just bring these types in with use statements.
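
A minimal sketch of what the suggestion amounts to. The module layout below is a hypothetical reconstruction of the path used in the diff, `next_server_name` is an illustrative helper, and `std::sync::mpsc::Receiver` stands in for tokio's `UnboundedReceiver`.

```rust
use std::sync::mpsc::Receiver;

// Hypothetical module layout mirroring the path that appears in the diff.
pub mod mcp_client {
    pub mod sampling_ipc {
        #[derive(Debug, PartialEq)]
        pub struct PendingSamplingRequest {
            pub server_name: String,
        }
    }
}

// Bring the type into scope once instead of repeating the full path.
use crate::mcp_client::sampling_ipc::PendingSamplingRequest;

/// The signature now reads without fully qualified paths.
pub fn next_server_name(receiver: &Receiver<PendingSamplingRequest>) -> Option<String> {
    receiver.try_recv().ok().map(|request| request.server_name)
}
```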

) -> Result<Self> {
let valid_model_id = match model_id {
Some(id) => id,
@@ -584,6 +591,7 @@ impl ChatSession {
conversation,
tool_uses: vec![],
pending_tool_index: None,
sampling_receiver,
tool_use_telemetry_events: HashMap::new(),
tool_use_status: ToolUseStatus::Idle,
failed_request_ids: Vec::new(),
@@ -1267,6 +1275,17 @@ impl ChatSession {
.put_skim_command_selector(os, Arc::new(context_manager.clone()), tool_names);
}

// Check for incoming sampling requests and automatically approve them
// Since servers now opt-in via configuration, any request that comes through should be approved
while let Ok(mut sampling_request) = self.sampling_receiver.try_recv() {
tracing::info!(target: "mcp", "Auto-approving sampling request from configured server: {}", sampling_request.server_name);

// Automatically approve the sampling request
sampling_request.send_approval_result(
crate::mcp_client::sampling_ipc::SamplingApprovalResult::approved()
);
}
Comment on lines +1280 to +1287
Contributor

Not entirely sure what this is doing. Some questions:

  • This is sending back an approval every time?
  • What is receiving what send_approval_result is sending?
  • This is called in the main chat loop under prompt_user, so it only runs when the user is about to be prompted (i.e. after Q CLI has responded). Correct me if I am wrong, but I think sampling would occur during a tool call (i.e. before the user gets their turn again). If so, this is not the right point in the logic flow to check for and respond to a sampling request. Let me know if I have misunderstood anything.
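
The ordering concern in the last bullet can be sketched as follows. This is an illustration under assumptions, not the PR's code: `Session`, `run_tool`, and the log strings are hypothetical, `std::sync::mpsc` stands in for the tokio channel, and whether a server really issues its request mid tool call is exactly what the thread is debating.

```rust
use std::sync::mpsc::{channel, Receiver, Sender};

/// Hypothetical stand-in for `ChatSession`, recording the order of events.
pub struct Session {
    sampling_receiver: Receiver<String>,
    pub log: Vec<String>,
}

impl Session {
    pub fn new() -> (Session, Sender<String>) {
        let (sender, receiver) = channel();
        let session = Session {
            sampling_receiver: receiver,
            log: Vec::new(),
        };
        (session, sender)
    }

    /// Stand-in for a tool call: the server queues a sampling request
    /// part-way through, but nothing polls the channel here.
    pub fn run_tool(&mut self, server: &Sender<String>) {
        self.log.push("tool call started".to_string());
        server
            .send("sampling request".to_string())
            .expect("receiver is owned by the session");
        self.log.push("tool call finished".to_string());
    }

    /// Stand-in for `prompt_user`: the only place the channel is drained.
    pub fn prompt_user(&mut self) {
        while let Ok(request) = self.sampling_receiver.try_recv() {
            self.log.push(format!("handled {}", request));
        }
        self.log.push("prompting user".to_string());
    }
}
```

With polling only in `prompt_user`, a request queued during the tool call is handled after the call has already finished. If the server were blocked waiting on that response, the tool call would not finish at all in a real session.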

Contributor Author

swapneils, Jul 14, 2025

> I think sampling would occur during a tool call (i.e. before the user is to have their turn again).

The docs don't mention anything either way (though the dataflow diagram doesn't mention an ongoing tool-call when the server initiates sampling), but to my understanding this is incorrect.

MCP servers support long-running operations (see for instance FastMCP's doc on progress reporting, which is meant to support such operations); it seems strange to support sampling for tool-calls but not such long-running operations.

There is also at least one real customer use-case on the MCP discussions where it seems the server proactively starts sampling requests even though the client isn't explicitly expecting them.

> What is receiving what send_approval_result is sending?

This was missed in a refactoring. As mentioned in this other thread, transport_ref now handles sending data to the server, so we can probably get rid of the sampling_result call. I will test that this works when refactoring.
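
The question of who receives the approval can be made concrete with a sketch. The internals of `PendingSamplingRequest` are not shown in this excerpt, so the shape below (a request carrying its own reply sender) is an assumption, and `PendingRequest`, `Approval`, and the boolean return value are hypothetical. `std::sync::mpsc` stands in for whatever channel type the PR uses.

```rust
use std::sync::mpsc::{channel, Receiver, Sender};

#[derive(Debug, PartialEq)]
pub enum Approval {
    Approved,
    Denied,
}

/// Hypothetical request that carries the channel its answer goes back on.
pub struct PendingRequest {
    pub server_name: String,
    reply: Sender<Approval>,
}

impl PendingRequest {
    /// Returns the request plus the receiver the requester would wait on.
    pub fn new(server_name: &str) -> (PendingRequest, Receiver<Approval>) {
        let (reply, answer) = channel();
        let request = PendingRequest {
            server_name: server_name.to_string(),
            reply,
        };
        (request, answer)
    }

    /// Sends the decision; returns false when nobody is listening anymore.
    pub fn send_approval_result(self, result: Approval) -> bool {
        self.reply.send(result).is_ok()
    }
}
```

If the receiving half has been dropped, as would happen after a refactor that moved the response path elsewhere, the send fails and the approval goes nowhere. Surfacing that failure is one way to notice a dead code path like the one described above.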

> This is sending back an approval every time?

We're filtering whether a server has access to sampling by whether we provide a sampling callback at all, since right now it's a binary trust/don't-trust situation. If a server is able to make a sampling request and reach this section, it should be allowed.
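
The binary trust model described above can be sketched as a gate on the optional sender. `gate_sender` is a hypothetical helper; it should behave the same as the `if sampling { sampling_sender } else { None }` expression in the custom_tool.rs diff, with `std::sync::mpsc::Sender` standing in for the tokio sender.

```rust
use std::sync::mpsc::Sender;

/// Keeps the sender only when sampling is enabled for this server.
/// A client that ends up with `None` has no way to raise a request.
pub fn gate_sender<T>(sampling_enabled: bool, sender: Option<Sender<T>>) -> Option<Sender<T>> {
    sender.filter(|_| sampling_enabled)
}
```

Under this model the approval step never sees a request from a server that was not opted in, which is why the handler in the chat loop can approve unconditionally.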

Supporting manual approve/deny is difficult from a CLI interface given we only have one input-stream (and we can't rely on things like popups working without breaking customer orchestrations, either).

Adding approvals only during tool calls and then rejecting untrusted sampling requests outside that context is maybe doable, but I don't think it should block an initial implementation of the feature.
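
One possible reading of the follow-up idea above, as a hedged sketch: trusted servers are always approved, untrusted servers get a prompt only while a tool call is in flight, and untrusted requests outside that window are rejected. `SessionPhase`, `Decision`, and `decide` are hypothetical names, and nothing like this exists in the PR.

```rust
#[derive(Debug, PartialEq, Clone, Copy)]
pub enum SessionPhase {
    /// A tool call is in progress, so the user is already watching output.
    ToolCall,
    /// No tool call is running.
    Idle,
}

#[derive(Debug, PartialEq)]
pub enum Decision {
    Approve,
    AskUser,
    Reject,
}

/// Decides how to treat a sampling request under the assumed policy.
pub fn decide(phase: SessionPhase, server_trusted: bool) -> Decision {
    match (server_trusted, phase) {
        (true, _) => Decision::Approve,
        (false, SessionPhase::ToolCall) => Decision::AskUser,
        (false, SessionPhase::Idle) => Decision::Reject,
    }
}
```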

EDIT: The code-as-written of the python SDK also seems to support sampling at arbitrary times with a similar callback-based approach as the one we're using here (assuming I'm reading it correctly).


execute!(
self.stderr,
style::SetForegroundColor(Color::Reset),
@@ -2360,6 +2379,12 @@ mod tests {
agents
}

#[cfg(test)]
Contributor

This entire module is already compiled only for tests, so there is no need to annotate it again here.
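
A small sketch of the point above: an attribute on the module applies to everything inside it, so helpers in that module need no attribute of their own. `double` and `sample_input` are hypothetical names used only for illustration.

```rust
pub fn double(x: i32) -> i32 {
    x * 2
}

#[cfg(test)]
mod tests {
    use super::double;

    // No extra `#[cfg(test)]` here: the module-level attribute already
    // excludes this helper from non-test builds.
    fn sample_input() -> i32 {
        21
    }

    #[test]
    fn doubles_sample_input() {
        assert_eq!(double(sample_input()), 42);
    }
}
```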

fn create_dummy_sampling_receiver() -> tokio::sync::mpsc::UnboundedReceiver<crate::mcp_client::sampling_ipc::PendingSamplingRequest> {
let (_sender, receiver) = tokio::sync::mpsc::unbounded_channel();
receiver
}

#[tokio::test]
async fn test_flow() {
let mut os = Os::new().await.unwrap();
@@ -2403,6 +2428,7 @@ mod tests {
None,
tool_config,
true,
create_dummy_sampling_receiver(),
)
.await
.unwrap()
@@ -2544,6 +2570,7 @@ mod tests {
None,
tool_config,
true,
create_dummy_sampling_receiver(),
)
.await
.unwrap()
@@ -2640,6 +2667,7 @@ mod tests {
None,
tool_config,
true,
create_dummy_sampling_receiver(),
)
.await
.unwrap()
@@ -2714,6 +2742,7 @@ mod tests {
None,
tool_config,
true,
create_dummy_sampling_receiver(),
)
.await
.unwrap()
@@ -2764,6 +2793,7 @@ mod tests {
None,
tool_config,
true,
create_dummy_sampling_receiver(),
)
.await
.unwrap()
33 changes: 29 additions & 4 deletions crates/chat-cli/src/cli/chat/tool_manager.rs
@@ -183,7 +183,7 @@ impl ToolManagerBuilder {
os: &mut Os,
mut output: Box<dyn Write + Send + Sync + 'static>,
interactive: bool,
- ) -> eyre::Result<ToolManager> {
+ ) -> eyre::Result<(ToolManager, tokio::sync::mpsc::UnboundedReceiver<crate::mcp_client::sampling_ipc::PendingSamplingRequest>)> {
let McpServerConfig { mcp_servers } = self.mcp_server_config.ok_or(eyre::eyre!("Missing mcp server config"))?;
debug_assert!(self.conversation_id.is_some());
let conversation_id = self.conversation_id.ok_or(eyre::eyre!("Missing conversation id"))?;
@@ -199,6 +199,9 @@ impl ToolManagerBuilder {
.map(|(server_name, _)| server_name.clone())
.collect();

// Create channel for sampling requests
let (sampling_sender, sampling_receiver) = tokio::sync::mpsc::unbounded_channel();

let pre_initialized = enabled_servers
.into_iter()
.filter_map(|(server_name, server_config)| {
@@ -211,7 +214,11 @@ impl ToolManagerBuilder {
);
None
} else {
- let custom_tool_client = CustomToolClient::from_config(server_name.clone(), server_config);
+ let custom_tool_client = CustomToolClient::from_config(
+ server_name.clone(),
+ server_config,
+ Some(sampling_sender.clone()),
+ );
Some((server_name, custom_tool_client))
}
})
@@ -687,7 +694,7 @@ impl ToolManagerBuilder {
});
}

- Ok(ToolManager {
+ let tool_manager = ToolManager {
conversation_id,
clients,
prompts,
@@ -701,8 +708,11 @@ impl ToolManagerBuilder {
mcp_load_record: load_record,
agent,
disabled_servers: disabled_servers_display,
sampling_request_sender: Some(sampling_sender),
..Default::default()
- })
+ };

Ok((tool_manager, sampling_receiver))
}
}

@@ -829,6 +839,9 @@ pub struct ToolManager {
/// The value is the load message (i.e. load time, warnings, and errors)
pub mcp_load_record: Arc<Mutex<HashMap<String, Vec<LoadingRecord>>>>,

/// Channel sender for MCP clients to send sampling requests for approval
pub sampling_request_sender: Option<tokio::sync::mpsc::UnboundedSender<crate::mcp_client::sampling_ipc::PendingSamplingRequest>>,

/// List of disabled MCP server names for display purposes
disabled_servers: Vec<String>,

@@ -850,12 +863,24 @@ impl Clone for ToolManager {
is_interactive: self.is_interactive,
mcp_load_record: self.mcp_load_record.clone(),
disabled_servers: self.disabled_servers.clone(),
sampling_request_sender: self.sampling_request_sender.clone(),
Contributor

I don't see how this field is used. Why does ToolManager need to own an instance of this mpsc sender?
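
One side effect of an otherwise unused sender is worth keeping in mind when answering this question. In the sketch below, `state_after_client_drop` is a hypothetical helper and `std::sync::mpsc` stands in for the tokio channel, which is assumed to behave analogously: as long as any sender clone is alive, the receiver reports an empty channel rather than a disconnected one.

```rust
use std::sync::mpsc::{channel, TryRecvError};

/// Reports what `try_recv` returns after the client's sender is dropped,
/// depending on whether a manager-like owner still holds a clone.
pub fn state_after_client_drop(manager_keeps_clone: bool) -> TryRecvError {
    let (client_sender, receiver) = channel::<String>();
    let manager_clone = if manager_keeps_clone {
        Some(client_sender.clone())
    } else {
        None
    };
    drop(client_sender);
    // Nothing was sent, so this is always an error: Empty or Disconnected.
    let state = receiver.try_recv().unwrap_err();
    drop(manager_clone);
    state
}
```

So a sender stored on the manager would keep the channel open after every client has shut down. Whether that is intended here is not stated in the thread.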

..Default::default()
}
}
}

impl ToolManager {
/// Set the ApiClient for all MCP clients that have sampling enabled
pub fn set_streaming_client(&self, api_client: std::sync::Arc<crate::api_client::ApiClient>) {
tracing::info!(target: "mcp", "Setting ApiClient for MCP clients with sampling enabled");

for (server_name, client) in &self.clients {
// Use the shared reference to call set_streaming_client
client.set_streaming_client(api_client.clone());
tracing::debug!(target: "mcp", "Set ApiClient for server: {}", server_name);
}
}

pub async fn load_tools(
&mut self,
os: &mut Os,
23 changes: 22 additions & 1 deletion crates/chat-cli/src/cli/chat/tools/custom_tool.rs
@@ -48,6 +48,8 @@ pub struct CustomToolConfig {
pub timeout: u64,
#[serde(default)]
pub disabled: bool,
#[serde(default)]
pub sampling: bool,
}

pub fn default_timeout() -> u64 {
@@ -66,14 +68,32 @@ pub enum CustomToolClient {

impl CustomToolClient {
// TODO: add support for http transport
- pub fn from_config(server_name: String, config: CustomToolConfig) -> Result<Self> {
+ /// Set the ApiClient for LLM integration in sampling requests
+ pub fn set_streaming_client(&self, api_client: std::sync::Arc<crate::api_client::ApiClient>) {
+ match self {
+ CustomToolClient::Stdio { client, .. } => {
+ client.set_streaming_client(api_client);
+ }
+ }
+ }
+
+ pub fn from_config(
+ server_name: String,
+ config: CustomToolConfig,
+ sampling_sender: Option<tokio::sync::mpsc::UnboundedSender<crate::mcp_client::sampling_ipc::PendingSamplingRequest>>,
+ ) -> Result<Self> {
let CustomToolConfig {
command,
args,
env,
timeout,
disabled: _,
sampling,
} = config;

// Only pass sampling_sender if sampling is enabled for this server
let conditional_sampling_sender = if sampling { sampling_sender } else { None };

let mcp_client_config = McpClientConfig {
server_name: server_name.clone(),
bin_path: command.clone(),
@@ -84,6 +104,7 @@ impl CustomToolClient {
"version": "1.0.0"
}),
env,
sampling_sender: conditional_sampling_sender,
};
let client = McpClient::<JsonRpcStdioTransport>::from_config(mcp_client_config)?;
Ok(CustomToolClient::Stdio {