diff --git a/src/transport/tls.rs b/src/transport/tls.rs index 52bd611..b8cf768 100644 --- a/src/transport/tls.rs +++ b/src/transport/tls.rs @@ -29,6 +29,34 @@ pub struct TlsConfig { pub client_key: Option>, // Root CA certificates in PEM format pub ca_certs: Option>, + // SNI hostname for TLS client connections (overrides the hostname derived from the remote address) + pub sni_hostname: Option, +} + +fn parse_private_key(key_data: &[u8]) -> Result> { + // Try PKCS8 format first + let mut reader = std::io::BufReader::new(key_data); + let keys = rustls_pemfile::pkcs8_private_keys(&mut reader) + .collect::, std::io::Error>>() + .map_err(|e| Error::Error(format!("Failed to parse PKCS8 key: {}", e)))?; + + if !keys.is_empty() { + let key_der = pki_types::PrivatePkcs8KeyDer::from(keys[0].clone_key()); + return Ok(pki_types::PrivateKeyDer::Pkcs8(key_der)); + } + + // Try PKCS1 format + let mut reader = std::io::BufReader::new(key_data); + let keys = rustls_pemfile::rsa_private_keys(&mut reader) + .collect::, std::io::Error>>() + .map_err(|e| Error::Error(format!("Failed to parse RSA key: {}", e)))?; + + if !keys.is_empty() { + let key_der = pki_types::PrivatePkcs1KeyDer::from(keys[0].clone_key()); + return Ok(pki_types::PrivateKeyDer::Pkcs1(key_der)); + } + + Err(Error::Error("No valid private key found".to_string())) } // TLS Listener Connection Structure @@ -151,31 +179,7 @@ impl TlsListenerConnection { // Load private key let key = match &config.key { - Some(key_data) => { - let mut reader = std::io::BufReader::new(key_data.as_slice()); - // Try PKCS8 format first - let keys = rustls_pemfile::pkcs8_private_keys(&mut reader) - .collect::, std::io::Error>>() - .map_err(|e| Error::Error(format!("Failed to parse PKCS8 key: {}", e)))?; - - if !keys.is_empty() { - let key_der = pki_types::PrivatePkcs8KeyDer::from(keys[0].clone_key()); - pki_types::PrivateKeyDer::Pkcs8(key_der) - } else { - // Try PKCS1 format - let mut reader = std::io::BufReader::new(key_data.as_slice()); - let keys = rustls_pemfile::rsa_private_keys(&mut reader) - .collect::, std::io::Error>>() - .map_err(|e| Error::Error(format!("Failed to parse RSA key: {}", e)))?; - - if !keys.is_empty() { - let key_der = pki_types::PrivatePkcs1KeyDer::from(keys[0].clone_key()); - pki_types::PrivateKeyDer::Pkcs1(key_der) - } else { - return Err(Error::Error("No valid private key found".to_string())); - } - } - } + Some(key_data) => parse_private_key(key_data)?, None => return Err(Error::Error("No private key provided".to_string())), }; @@ -239,22 +243,61 @@ impl TlsConnection { // Connect to a remote TLS server pub async fn connect( remote_addr: &SipAddr, + tls_config: Option<&TlsConfig>, custom_verifier: Option>, cancel_token: Option, ) -> Result { - let root_store = RootCertStore::empty(); - - let mut config = ClientConfig::builder() - .with_root_certificates(root_store) - .with_no_client_auth(); + let mut root_store = RootCertStore::empty(); + + // Load CA certificates if provided + if let Some(ca_data) = tls_config.and_then(|c| c.ca_certs.as_ref()) { + let mut reader = std::io::BufReader::new(ca_data.as_slice()); + let certs = rustls_pemfile::certs(&mut reader) + .collect::, std::io::Error>>() + .map_err(|e| Error::Error(format!("Failed to parse CA certificates: {}", e)))?; + for cert in certs { + root_store + .add(cert) + .map_err(|e| Error::Error(format!("Failed to add CA certificate: {}", e)))?; + } + } - match custom_verifier { - Some(verifier) => { - config.dangerous().set_certificate_verifier(verifier); + // Build client config with optional mutual TLS + let mut client_config = match ( + tls_config.and_then(|c| c.client_cert.as_ref()), + tls_config.and_then(|c| c.client_key.as_ref()), + ) { + (Some(cert_data), Some(key_data)) => { + let mut reader = std::io::BufReader::new(cert_data.as_slice()); + let certs = rustls_pemfile::certs(&mut reader) + .collect::, std::io::Error>>() + .map_err(|e| { + Error::Error(format!("Failed to parse client certificate: {}", e)) + })?; + let key = parse_private_key(key_data)?; + ClientConfig::builder() + .with_root_certificates(root_store) + .with_client_auth_cert(certs, key) + .map_err(|e| Error::Error(format!("Client auth configuration error: {}", e)))? } - None => {} + _ => ClientConfig::builder() + .with_root_certificates(root_store) + .with_no_client_auth(), + }; + + if let Some(verifier) = custom_verifier { + client_config.dangerous().set_certificate_verifier(verifier); } - let connector = TlsConnector::from(Arc::new(config)); + + // Prefer explicit SNI, otherwise use the remote host. + let domain_string = tls_config + .and_then(|c| c.sni_hostname.clone()) + .unwrap_or_else(|| match &remote_addr.addr.host { + rsip::host_with_port::Host::Domain(domain) => domain.to_string(), + rsip::host_with_port::Host::IpAddr(ip) => ip.to_string(), + }); + + let connector = TlsConnector::from(Arc::new(client_config)); let socket_addr = match &remote_addr.addr.host { rsip::host_with_port::Host::Domain(domain) => { @@ -267,11 +310,6 @@ impl TlsConnection { } }; - let domain_string = match &remote_addr.addr.host { - rsip::host_with_port::Host::Domain(domain) => domain.to_string(), - rsip::host_with_port::Host::IpAddr(ip) => ip.to_string(), - }; - let server_name = pki_types::ServerName::try_from(domain_string.as_str()) .map_err(|_| Error::Error(format!("Invalid DNS name: {}", domain_string)))? .to_owned(); diff --git a/src/transport/transport_layer.rs b/src/transport/transport_layer.rs index d1a1717..859c49c 100644 --- a/src/transport/transport_layer.rs +++ b/src/transport/transport_layer.rs @@ -1,4 +1,4 @@ -use super::tls::TlsConnection; +use super::tls::{TlsConfig, TlsConnection}; use super::websocket::WebSocketConnection; use super::{connection::TransportSender, sip_addr::SipAddr, tcp::TcpConnection, SipConnection}; use crate::resolver::SipResolver; @@ -110,6 +110,7 @@ pub struct TransportLayerInner { pub(crate) transport_rx: Mutex>, pub domain_resolver: Box, whitelist: RwLock>, + tls_config: RwLock>, } pub(crate) type TransportLayerInnerRef = Arc; @@ -133,6 +134,7 @@ impl TransportLayer { transport_rx: Mutex::new(Some(transport_rx)), domain_resolver, whitelist: RwLock::new(None), + tls_config: RwLock::new(None), }; Self { outbound: None, @@ -213,6 +215,16 @@ impl TransportLayer { pub fn clear_whitelist(&self) { self.inner.set_whitelist(None); } + + /// Set the TLS configuration used for future outbound TLS connections. + pub fn set_tls_config(&self, tls_config: TlsConfig) { + self.inner.set_tls_config(Some(tls_config)); + } + + /// Remove the TLS configuration used for future outbound TLS connections. + pub fn clear_tls_config(&self) { + self.inner.set_tls_config(None); + } } impl TransportLayerInner { @@ -242,6 +254,27 @@ impl TransportLayerInner { } } + fn set_tls_config(&self, tls_config: Option) { + match self.tls_config.write() { + Ok(mut guard) => { + *guard = tls_config; + } + Err(e) => { + warn!(error = ?e, "Failed to update tls config"); + } + } + } + + fn tls_config(&self) -> Option { + match self.tls_config.read() { + Ok(guard) => guard.clone(), + Err(e) => { + warn!(error = ?e, "Failed to read tls config"); + None + } + } + } + pub fn add_listener(&self, connection: SipConnection) { match self.listens.write() { Ok(mut listens) => { @@ -294,6 +327,14 @@ impl TransportLayerInner { key: Option<&TransactionKey>, ) -> Result<(SipConnection, SipAddr)> { let target = outbound.unwrap_or(destination); + let tls_config = self.tls_config(); + + // Capture the original domain name before DNS resolution for TLS SNI + let original_domain = match &target.addr.host { + rsip::Host::Domain(domain) => Some(domain.to_string()), + _ => None, + }; + let target = if matches!(target.addr.host, rsip::Host::Domain(_)) { &self.domain_resolver.resolve(target).await? } else { @@ -330,8 +371,14 @@ impl TransportLayerInner { SipConnection::Tcp(connection) } Some(rsip::transport::Transport::Tls) => { + // Build effective TLS config with SNI from the original domain + let mut effective_config = tls_config.clone().unwrap_or_default(); + if effective_config.sni_hostname.is_none() { + effective_config.sni_hostname = original_domain; + } let connection = TlsConnection::connect( target, + Some(&effective_config), None, Some(self.cancel_token.child_token()), )