diff --git a/tonic/src/transport/channel/service/tls.rs b/tonic/src/transport/channel/service/tls.rs index 7510099a1..6aef12c05 100644 --- a/tonic/src/transport/channel/service/tls.rs +++ b/tonic/src/transport/channel/service/tls.rs @@ -13,10 +13,13 @@ use tokio_rustls::{ }; use super::io::BoxedIo; -use crate::transport::service::tls::{ - convert_certificate_to_pki_types, convert_identity_to_pki_types, TlsError, ALPN_H2, -}; use crate::transport::tls::{Certificate, Identity}; +use crate::transport::{ + channel::tls::ModdifyConfigFn, + service::tls::{ + convert_certificate_to_pki_types, convert_identity_to_pki_types, TlsError, ALPN_H2, + }, +}; #[derive(Clone)] pub(crate) struct TlsConnector { @@ -34,6 +37,7 @@ impl TlsConnector { domain: &str, assume_http2: bool, use_key_log: bool, + modify_config: Option, #[cfg(feature = "tls-native-roots")] with_native_roots: bool, #[cfg(feature = "tls-webpki-roots")] with_webpki_roots: bool, ) -> Result { @@ -94,6 +98,11 @@ impl TlsConnector { } config.alpn_protocols.push(ALPN_H2.into()); + + if let Some(modify_config) = modify_config { + modify_config.0(&mut config); + } + Ok(Self { config: Arc::new(config), domain: Arc::new(ServerName::try_from(domain)?.to_owned()), diff --git a/tonic/src/transport/channel/tls.rs b/tonic/src/transport/channel/tls.rs index 945384fd2..9cc38695d 100644 --- a/tonic/src/transport/channel/tls.rs +++ b/tonic/src/transport/channel/tls.rs @@ -13,6 +13,7 @@ pub struct ClientTlsConfig { certs: Vec, trust_anchors: Vec>, identity: Option, + modify_config: Option, assume_http2: bool, #[cfg(feature = "tls-native-roots")] with_native_roots: bool, @@ -21,6 +22,17 @@ pub struct ClientTlsConfig { use_key_log: bool, } +#[derive(Clone)] +pub(crate) struct ModdifyConfigFn( + pub(crate) std::sync::Arc, +); + +impl std::fmt::Debug for ModdifyConfigFn { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ModdifyConfigFn").finish() + } +} + impl ClientTlsConfig { /// Creates a new `ClientTlsConfig` using Rustls. pub fn new() -> Self { @@ -121,6 +133,18 @@ impl ClientTlsConfig { config } + /// Adds a function to modify the `ClientConfig` before it is used. + pub fn modify_config(self, f: F) -> Self + where + F: Fn(&mut tokio_rustls::rustls::ClientConfig) + Send + Sync + 'static, + { + let modify_config = ModdifyConfigFn(std::sync::Arc::new(f)); + ClientTlsConfig { + modify_config: Some(modify_config), + ..self + } + } + pub(crate) fn into_tls_connector(self, uri: &Uri) -> Result { let domain = match &self.domain { Some(domain) => domain, @@ -133,6 +157,7 @@ impl ClientTlsConfig { domain, self.assume_http2, self.use_key_log, + self.modify_config, #[cfg(feature = "tls-native-roots")] self.with_native_roots, #[cfg(feature = "tls-webpki-roots")]