diff --git a/src/attested_get.rs b/src/attested_get.rs index 3b4575f..7fc40b9 100644 --- a/src/attested_get.rs +++ b/src/attested_get.rs @@ -55,7 +55,7 @@ async fn attested_get_with_client( mod tests { use super::*; use crate::{ - ProxyServer, + OuterTlsConfig, OuterTlsMode, ProxyServer, attestation::AttestationType, file_server::static_file_server, test_helpers::{generate_certificate_chain_for_host, generate_tls_config}, @@ -77,13 +77,19 @@ mod tests { let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key); // Setup a proxy server targetting the static file server - let proxy_server = ProxyServer::new_with_tls_config( - cert_chain, - server_config, - "127.0.0.1:0", + let proxy_server = ProxyServer::new( + Some(OuterTlsConfig { + listen_addr: "127.0.0.1:0", + tls: OuterTlsMode::Preconfigured { + server_config, + certificate_name: "localhost".to_string(), + }, + }), + Some("127.0.0.1:0"), target_addr.to_string(), AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), AttestationVerifier::expect_none(), + false, ) .await .unwrap(); diff --git a/src/file_server.rs b/src/file_server.rs index 424008e..4c5c9bb 100644 --- a/src/file_server.rs +++ b/src/file_server.rs @@ -1,5 +1,8 @@ //! Static HTTP file server provided by an attested TLS proxy server -use crate::{AttestationGenerator, AttestationVerifier, ProxyError, ProxyServer, TlsCertAndKey}; +use crate::{ + AttestationGenerator, AttestationVerifier, OuterTlsConfig, OuterTlsMode, ProxyError, + ProxyServer, TlsCertAndKey, +}; use std::{net::SocketAddr, path::PathBuf}; use tokio::net::ToSocketAddrs; use tower_http::services::ServeDir; @@ -7,17 +10,28 @@ use tower_http::services::ServeDir; /// Setup a static file server serving the given directory, and a proxy server targetting it pub async fn attested_file_server( path_to_serve: PathBuf, - cert_and_key: TlsCertAndKey, - listen_addr: impl ToSocketAddrs, + outer_cert_and_key: Option, + outer_listen_addr: Option, + inner_listen_addr: Option, attestation_generator: AttestationGenerator, attestation_verifier: AttestationVerifier, client_auth: bool, ) -> Result<(), ProxyError> { let target_addr = static_file_server(path_to_serve).await?; + let outer_session = match (outer_cert_and_key, outer_listen_addr) { + (Some(cert_and_key), Some(listen_addr)) => Some(OuterTlsConfig { + listen_addr, + tls: OuterTlsMode::CertAndKey(cert_and_key), + }), + (Some(_), None) | (None, Some(_)) => { + return Err(ProxyError::NoListenersConfigured); + } + (None, None) => None, + }; let server = ProxyServer::new( - cert_and_key, - listen_addr, + outer_session, + inner_listen_addr, target_addr.to_string(), attestation_generator, attestation_verifier, @@ -52,7 +66,7 @@ pub(crate) async fn static_file_server(path: PathBuf) -> Resultfoo"); - let (body, content_type) = get_body_and_content_type( - format!("http://{}/data.bin", proxy_client_addr.to_string()), - &client, - ) - .await; + let (body, content_type) = + get_body_and_content_type(format!("http://{}/data.bin", proxy_client_addr), &client) + .await; assert_eq!(content_type, "application/octet-stream"); assert_eq!(body, [0u8; 32]); } diff --git a/src/lib.rs b/src/lib.rs index cf4c327..88aa200 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -26,6 +26,7 @@ use thiserror::Error; use tokio::io::{self, AsyncWriteExt}; use tokio::net::{TcpListener, TcpStream, ToSocketAddrs}; use tokio::sync::{mpsc, oneshot}; +use tokio_rustls::TlsAcceptor; use tokio_rustls::rustls::server::{VerifierBuilderError, WebPkiClientVerifier}; use tokio_rustls::rustls::{ self, ClientConfig, RootCertStore, ServerConfig, @@ -46,12 +47,14 @@ const SERVER_RECONNECT_MAX_BACKOFF_SECS: u64 = 120; const KEEP_ALIVE_INTERVAL: u64 = 30; const KEEP_ALIVE_TIMEOUT: u64 = 10; - type RequestWithResponseSender = ( http::Request, oneshot::Sender>, hyper::Error>>, ); +type OuterProxySession = (Arc, NestingTlsAcceptor); +type InnerProxySession = (Arc, TlsAcceptor); + /// TLS Credentials pub struct TlsCertAndKey { /// Der-encoded TLS certificate chain @@ -60,6 +63,81 @@ pub struct TlsCertAndKey { pub key: PrivateKeyDer<'static>, } +/// Configuration for the optional outer nested-TLS listener. +pub struct OuterTlsConfig { + /// The socket address to bind for the outer listener. + pub listen_addr: A, + /// How the outer TLS server configuration should be constructed. + pub tls: OuterTlsMode, +} + +/// TLS configuration sources for the outer nested-TLS listener. +pub enum OuterTlsMode { + /// Build the outer TLS server config from certificate and key material. + CertAndKey(TlsCertAndKey), + /// Use an already-constructed outer TLS server config. + Preconfigured { + /// The outer TLS server configuration to expose on the listener. + server_config: ServerConfig, + /// The server identity to embed into the inner attested certificate. + certificate_name: String, + }, +} + +impl OuterTlsConfig +where + A: ToSocketAddrs, +{ + fn certificate_name(&self) -> Result { + match &self.tls { + OuterTlsMode::CertAndKey(cert_and_key) => { + Ok(certificate_identity_from_chain(&cert_and_key.cert_chain)?) + } + OuterTlsMode::Preconfigured { + certificate_name, .. + } => Ok(certificate_name.clone()), + } + } + + async fn into_listener_and_acceptor( + self, + inner_server_config: Arc, + client_auth: bool, + ) -> Result<(Arc, NestingTlsAcceptor), ProxyError> { + let listen_addr = self.listen_addr; + let outer_server_config = match self.tls { + OuterTlsMode::CertAndKey(cert_and_key) => { + if client_auth { + let root_store = + RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); + let verifier = WebPkiClientVerifier::builder(Arc::new(root_store)).build()?; + + ServerConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) + .with_client_cert_verifier(verifier) + .with_single_cert( + cert_and_key.cert_chain.clone(), + cert_and_key.key.clone_key(), + )? + } else { + ServerConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) + .with_no_client_auth() + .with_single_cert( + cert_and_key.cert_chain.clone(), + cert_and_key.key.clone_key(), + )? + } + } + OuterTlsMode::Preconfigured { server_config, .. } => server_config, + }; + + let outer_listener = Arc::new(TcpListener::bind(listen_addr).await?); + let outer_tls_acceptor = + NestingTlsAcceptor::new(Arc::new(outer_server_config), inner_server_config); + + Ok((outer_listener, outer_tls_acceptor)) + } +} + /// Adds HTTP 1 and 2 to the list of allowed protocols fn ensure_proxy_alpn_protocols(alpn_protocols: &mut Vec>) { for protocol in [ALPN_H2, ALPN_HTTP11] { @@ -132,111 +210,65 @@ pub async fn get_inner_tls_cert_with_config( /// A TLS over TCP server which provides an attestation before forwarding traffic to a given target address pub struct ProxyServer { - nesting_tls_acceptor: NestingTlsAcceptor, - /// The underlying TCP listener - listener: Arc, + outer: Option, + inner: Option, /// The address/hostname of the target service we are proxying to target: String, } impl ProxyServer { - pub async fn new( - cert_and_key: TlsCertAndKey, - local: impl ToSocketAddrs, + /// Start with dual listeners. The outer nested-TLS listener is optional. + pub async fn new( + outer_session: Option>, + inner_local: Option, target: String, attestation_generator: AttestationGenerator, attestation_verifier: AttestationVerifier, client_auth: bool, - ) -> Result { - let outer_server_config = if client_auth { - let root_store = - RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); - let verifier = WebPkiClientVerifier::builder(Arc::new(root_store)).build()?; - - ServerConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) - .with_client_cert_verifier(verifier) - .with_single_cert( - cert_and_key.cert_chain.clone(), - cert_and_key.key.clone_key(), - )? - } else { - ServerConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) - .with_no_client_auth() - .with_single_cert( - cert_and_key.cert_chain.clone(), - cert_and_key.key.clone_key(), - )? - }; - - Self::new_with_tls_config_and_client_auth( - cert_and_key.cert_chain, - outer_server_config, - local, - target, - attestation_generator, - attestation_verifier, - client_auth, - ) - .await - } - - /// Start with preconfigured TLS - pub async fn new_with_tls_config( - cert_chain: Vec>, - outer_server_config: ServerConfig, - local: impl ToSocketAddrs, - target: String, - attestation_generator: AttestationGenerator, - attestation_verifier: AttestationVerifier, - ) -> Result { - Self::new_with_tls_config_and_client_auth( - cert_chain, - outer_server_config, - local, - target, - attestation_generator, - attestation_verifier, - false, - ) - .await - } + ) -> Result + where + O: ToSocketAddrs, + I: ToSocketAddrs, + { + if outer_session.is_none() && inner_local.is_none() { + return Err(ProxyError::NoListenersConfigured); + } - /// Start with preconfigured TLS and require client auth on both nested sessions - pub async fn new_with_tls_config_and_client_auth( - cert_chain: Vec>, - outer_server_config: ServerConfig, - local: impl ToSocketAddrs, - target: String, - attestation_generator: AttestationGenerator, - attestation_verifier: AttestationVerifier, - client_auth: bool, - ) -> Result { - let server_name = certificate_identity_from_chain(&cert_chain)?; - let inner_cert_resolver = - build_attested_cert_resolver(attestation_generator, &cert_chain).await?; - - let mut inner_server_config = if client_auth { - let attested_cert_verifier = - AttestedCertificateVerifier::new(None, attestation_verifier)?; - ServerConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) - .with_client_cert_verifier(Arc::new(attested_cert_verifier)) - .with_cert_resolver(Arc::new(inner_cert_resolver)) - } else { - let _ = server_name; - ServerConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) - .with_no_client_auth() - .with_cert_resolver(Arc::new(inner_cert_resolver)) + let certificate_name = outer_session + .as_ref() + .map(OuterTlsConfig::certificate_name) + .transpose()?; + let inner_server_config = Arc::new( + build_inner_server_config( + attestation_generator, + attestation_verifier, + client_auth, + certificate_name, + ) + .await?, + ); + let inner = match inner_local { + Some(inner_local) => { + let inner_listener = Arc::new(TcpListener::bind(inner_local).await?); + let inner_tls_acceptor = TlsAcceptor::from(inner_server_config.clone()); + Some((inner_listener, inner_tls_acceptor)) + } + None => None, }; - ensure_proxy_alpn_protocols(&mut inner_server_config.alpn_protocols); - - let nesting_tls_acceptor = - NestingTlsAcceptor::new(Arc::new(outer_server_config), Arc::new(inner_server_config)); - let listener = TcpListener::bind(local).await?; + let outer = match outer_session { + Some(outer_session) => { + let (outer_listener, outer_tls_acceptor) = outer_session + .into_listener_and_acceptor(inner_server_config.clone(), client_auth) + .await?; + Some((outer_listener, outer_tls_acceptor)) + } + None => None, + }; Ok(Self { - nesting_tls_acceptor, - listener: listener.into(), + outer, + inner, target, }) } @@ -246,33 +278,121 @@ impl ProxyServer { /// Returns the handle for the task handling the connection pub async fn accept(&self) -> Result, ProxyError> { let target = self.target.clone(); - let (inbound, client_addr) = self.listener.accept().await?; - let nesting_tls_acceptor = self.nesting_tls_acceptor.clone(); + let outer = self.outer.clone(); + let inner = self.inner.clone(); + + let join_handle = match (outer, inner) { + ( + Some((outer_listener, outer_tls_acceptor)), + Some((inner_listener, inner_tls_acceptor)), + ) => { + let ((inbound, client_addr), use_outer) = tokio::select! { + accepted = outer_listener.accept() => (accepted?, true), + accepted = inner_listener.accept() => (accepted?, false), + }; - let join_handle = tokio::spawn(async move { - match nesting_tls_acceptor.accept(inbound).await { - Ok(tls_stream) => { - if let Err(err) = Self::handle_connection(tls_stream, target, client_addr).await - { - warn!("Failed to handle connection: {err}"); + tokio::spawn(async move { + if use_outer { + match outer_tls_acceptor.accept(inbound).await { + Ok(tls_stream) => { + if let Err(err) = + Self::handle_outer_connection(tls_stream, target, client_addr) + .await + { + warn!("Failed to handle outer connection: {err}"); + } + } + Err(err) => { + warn!("Outer attestation exchange failed: {err}"); + } + } + } else { + match inner_tls_acceptor.accept(inbound).await { + Ok(tls_stream) => { + if let Err(err) = + Self::handle_inner_connection(tls_stream, target, client_addr) + .await + { + warn!("Failed to handle inner connection: {err}"); + } + } + Err(err) => { + warn!("Inner attestation exchange failed: {err}"); + } + } } - } - Err(err) => { - warn!("Attestation exchange failed: {err}"); - } + }) } - }); + (None, Some((inner_listener, inner_tls_acceptor))) => { + let (inbound, client_addr) = inner_listener.accept().await?; + tokio::spawn(async move { + match inner_tls_acceptor.accept(inbound).await { + Ok(tls_stream) => { + if let Err(err) = + Self::handle_inner_connection(tls_stream, target, client_addr).await + { + warn!("Failed to handle inner connection: {err}"); + } + } + Err(err) => { + warn!("Inner attestation exchange failed: {err}"); + } + } + }) + } + (Some((outer_listener, outer_tls_acceptor)), None) => { + let (inbound, client_addr) = outer_listener.accept().await?; + tokio::spawn(async move { + match outer_tls_acceptor.accept(inbound).await { + Ok(tls_stream) => { + if let Err(err) = + Self::handle_outer_connection(tls_stream, target, client_addr).await + { + warn!("Failed to handle outer connection: {err}"); + } + } + Err(err) => { + warn!("Outer attestation exchange failed: {err}"); + } + } + }) + } + _ => return Err(ProxyError::NoListenersConfigured), + }; Ok(join_handle) } - /// Helper to get the socket address of the underlying TCP listener + /// Helper to get the socket address of either underlying TCP listener pub fn local_addr(&self) -> std::io::Result { - self.listener.local_addr() + match &self.outer { + Some((listener, _)) => listener.local_addr(), + None => self + .inner + .as_ref() + .map(|(listener, _)| listener) + .ok_or_else(|| std::io::Error::other("no listeners configured"))? + .local_addr(), + } } - /// Handle an incoming connection from a proxy-client - async fn handle_connection( + /// Helper to get the socket address of the underlying outer TCP listener if present + pub fn outer_local_addr(&self) -> std::io::Result> { + self.outer + .as_ref() + .map(|(listener, _)| listener.local_addr()) + .transpose() + } + + /// Helper to get the socket address of the underlying inner TCP listener if present + pub fn inner_local_addr(&self) -> std::io::Result> { + self.inner + .as_ref() + .map(|(listener, _)| listener.local_addr()) + .transpose() + } + + async fn handle_outer_connection( tls_stream: NestingTlsStream, target: String, client_addr: SocketAddr, @@ -280,7 +400,29 @@ impl ProxyServer { debug!("[proxy-server] accepted connection"); let http_version = HttpVersion::from_negotiated_protocol_server(&tls_stream); + Self::serve_tls_stream(tls_stream, http_version, target, client_addr).await + } + async fn handle_inner_connection( + tls_stream: tokio_rustls::server::TlsStream, + target: String, + client_addr: SocketAddr, + ) -> Result<(), ProxyError> { + debug!("[proxy-server] accepted inner-only connection"); + + let http_version = HttpVersion::from_negotiated_protocol_server(&tls_stream); + Self::serve_tls_stream(tls_stream, http_version, target, client_addr).await + } + + async fn serve_tls_stream( + tls_stream: IO, + http_version: HttpVersion, + target: String, + client_addr: SocketAddr, + ) -> Result<(), ProxyError> + where + IO: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static, + { // Setup a request handler let service = service_fn(move |mut req| { debug!("[proxy-server] Handling request {req:?}"); @@ -457,8 +599,11 @@ impl ProxyClient { let attested_cert_verifier = AttestedCertificateVerifier::new(None, attestation_verifier)?; let mut inner_client_config = if let Some(cert_chain) = cert_chain.as_ref() { - let inner_cert_resolver = - build_attested_cert_resolver(attestation_generator, cert_chain).await?; + let inner_cert_resolver = build_attested_cert_resolver( + attestation_generator, + certificate_identity_from_chain(cert_chain)?, + ) + .await?; ClientConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) .dangerous() .with_custom_certificate_verifier(Arc::new(attested_cert_verifier)) @@ -803,6 +948,8 @@ pub enum ProxyError { MpscSend, #[error("Client auth must be configured on both the inner and outer TLS sessions")] ClientAuthMisconfigured, + #[error("At least one server listener must be configured")] + NoListenersConfigured, } impl From> for ProxyError { @@ -835,15 +982,42 @@ fn certificate_identity_from_chain( async fn build_attested_cert_resolver( attestation_generator: AttestationGenerator, - cert_chain: &[CertificateDer<'static>], + certificate_name: String, ) -> Result { - let certificate_name = certificate_identity_from_chain(cert_chain)?; Ok( AttestedCertificateResolver::new(attestation_generator, None, certificate_name, vec![]) .await?, ) } +async fn build_inner_server_config( + attestation_generator: AttestationGenerator, + attestation_verifier: AttestationVerifier, + client_auth: bool, + certificate_name: Option, +) -> Result { + let inner_cert_resolver = build_attested_cert_resolver( + attestation_generator, + certificate_name.unwrap_or_else(|| "localhost".to_string()), + ) + .await?; + + let mut inner_server_config = if client_auth { + let attested_cert_verifier = AttestedCertificateVerifier::new(None, attestation_verifier)?; + ServerConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) + .with_client_cert_verifier(Arc::new(attested_cert_verifier)) + .with_cert_resolver(Arc::new(inner_cert_resolver)) + } else { + ServerConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) + .with_no_client_auth() + .with_cert_resolver(Arc::new(inner_cert_resolver)) + }; + + ensure_proxy_alpn_protocols(&mut inner_server_config.alpn_protocols); + + Ok(inner_server_config) +} + /// If no port was provided, default to 443 pub(crate) fn host_to_host_with_port(host: &str) -> String { if host.contains(':') { @@ -882,6 +1056,7 @@ where #[cfg(test)] mod tests { use attestation::{AttestationType, measurements::MeasurementPolicy}; + use tokio_rustls::TlsConnector; use super::*; use test_helpers::{ @@ -906,25 +1081,82 @@ mod tests { } #[tokio::test(flavor = "multi_thread")] - async fn http_proxy_negotiates_http2_by_default() { + async fn proxy_server_requires_at_least_one_listener() { + let result = ProxyServer::new( + None::>, + None::<&str>, + "127.0.0.1:1".to_string(), + AttestationGenerator::with_no_attestation(), + AttestationVerifier::expect_none(), + false, + ) + .await; + + assert!(matches!(result, Err(ProxyError::NoListenersConfigured))); + } + + #[tokio::test(flavor = "multi_thread")] + async fn dual_listener_server_reports_expected_addresses() { let target_addr = example_http_service().await; let (cert_chain, private_key) = generate_certificate_chain_for_host("localhost"); - let (server_config, outer_client_config) = - generate_tls_config(cert_chain.clone(), private_key); - - let proxy_server = ProxyServer::new_with_tls_config( + let tls_cert_and_key = TlsCertAndKey { cert_chain, - server_config, - "127.0.0.1:0", + key: private_key, + }; + + let dual_listener_server = ProxyServer::new( + Some(OuterTlsConfig { + listen_addr: "127.0.0.1:0", + tls: OuterTlsMode::CertAndKey(tls_cert_and_key), + }), + Some("127.0.0.1:0"), + target_addr.to_string(), + AttestationGenerator::with_no_attestation(), + AttestationVerifier::expect_none(), + false, + ) + .await + .unwrap(); + + let outer_addr = dual_listener_server.outer_local_addr().unwrap().unwrap(); + let inner_addr = dual_listener_server.inner_local_addr().unwrap().unwrap(); + assert_eq!(dual_listener_server.local_addr().unwrap(), outer_addr); + assert_ne!(outer_addr, inner_addr); + + let inner_only_server = ProxyServer::new( + None::>, + Some("127.0.0.1:0"), + target_addr.to_string(), + AttestationGenerator::with_no_attestation(), + AttestationVerifier::expect_none(), + false, + ) + .await + .unwrap(); + + let inner_only_addr = inner_only_server.inner_local_addr().unwrap().unwrap(); + assert!(inner_only_server.outer_local_addr().unwrap().is_none()); + assert_eq!(inner_only_server.local_addr().unwrap(), inner_only_addr); + } + + #[tokio::test(flavor = "multi_thread")] + async fn inner_only_listener_negotiates_http2_by_default() { + let _ = rustls::crypto::ring::default_provider().install_default(); + let target_addr = example_http_service().await; + + let proxy_server = ProxyServer::new( + None::>, + Some("127.0.0.1:0"), target_addr.to_string(), AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), AttestationVerifier::expect_none(), + false, ) .await .unwrap(); - let proxy_addr = proxy_server.local_addr().unwrap(); + let inner_addr = proxy_server.inner_local_addr().unwrap().unwrap(); tokio::spawn(async move { proxy_server.accept().await.unwrap(); @@ -932,40 +1164,46 @@ mod tests { let attested_cert_verifier = AttestedCertificateVerifier::new(None, AttestationVerifier::mock()).unwrap(); - let mut inner_client_config = + let mut client_config = ClientConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) .dangerous() .with_custom_certificate_verifier(Arc::new(attested_cert_verifier)) .with_no_client_auth(); - ensure_proxy_alpn_protocols(&mut inner_client_config.alpn_protocols); + ensure_proxy_alpn_protocols(&mut client_config.alpn_protocols); - let nesting_tls_connector = - NestingTlsConnector::new(Arc::new(outer_client_config), Arc::new(inner_client_config)); + let tls_connector = TlsConnector::from(Arc::new(client_config)); + let outbound_stream = TcpStream::connect(inner_addr).await.unwrap(); + let domain = ServerName::try_from("localhost".to_string()).unwrap(); + let mut tls_stream = tls_connector + .connect(domain, outbound_stream) + .await + .unwrap(); - let (sender, conn) = ProxyClient::setup_connection( - &nesting_tls_connector, - &format!("localhost:{}", proxy_addr.port()), - ) - .await - .unwrap(); + assert!(matches!( + HttpVersion::from_negotiated_protocol_client(&tls_stream), + HttpVersion::Http2 + )); - assert!(matches!(sender, HttpSender::Http2(_))); - assert!(matches!(conn, HttpConnection::Http2 { .. })); + tls_stream.shutdown().await.unwrap(); } #[tokio::test(flavor = "multi_thread")] - async fn http_proxy_default_constructors_work() { + async fn http_proxy_negotiates_http2_by_default() { let target_addr = example_http_service().await; let (cert_chain, private_key) = generate_certificate_chain_for_host("localhost"); - let server_cert = cert_chain[0].clone(); + let (server_config, outer_client_config) = + generate_tls_config(cert_chain.clone(), private_key); let proxy_server = ProxyServer::new( - TlsCertAndKey { - cert_chain, - key: private_key, - }, - "127.0.0.1:0", + Some(OuterTlsConfig { + listen_addr: "127.0.0.1:0", + tls: OuterTlsMode::Preconfigured { + server_config, + certificate_name: certificate_identity_from_chain(&cert_chain).unwrap(), + }, + }), + Some("127.0.0.1:0"), target_addr.to_string(), AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), AttestationVerifier::expect_none(), @@ -980,29 +1218,27 @@ mod tests { proxy_server.accept().await.unwrap(); }); - let proxy_client = ProxyClient::new( - None, - "127.0.0.1:0".to_string(), - format!("localhost:{}", proxy_addr.port()), - AttestationGenerator::with_no_attestation(), - AttestationVerifier::mock(), - Some(server_cert), + let attested_cert_verifier = + AttestedCertificateVerifier::new(None, AttestationVerifier::mock()).unwrap(); + let mut inner_client_config = + ClientConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) + .dangerous() + .with_custom_certificate_verifier(Arc::new(attested_cert_verifier)) + .with_no_client_auth(); + ensure_proxy_alpn_protocols(&mut inner_client_config.alpn_protocols); + + let nesting_tls_connector = + NestingTlsConnector::new(Arc::new(outer_client_config), Arc::new(inner_client_config)); + + let (sender, conn) = ProxyClient::setup_connection( + &nesting_tls_connector, + &format!("localhost:{}", proxy_addr.port()), ) .await .unwrap(); - let proxy_client_addr = proxy_client.local_addr().unwrap(); - - tokio::spawn(async move { - proxy_client.accept().await.unwrap(); - }); - - let res = reqwest::get(format!("http://{}", proxy_client_addr)) - .await - .unwrap(); - - let res_body = res.text().await.unwrap(); - assert_eq!(res_body, "No measurements"); + assert!(matches!(sender, HttpSender::Http2(_))); + assert!(matches!(conn, HttpConnection::Http2 { .. })); } // Server has mock DCAP, client has no attestation and no client auth @@ -1014,13 +1250,19 @@ mod tests { let (cert_chain, private_key) = generate_certificate_chain_for_host("localhost"); let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key); - let proxy_server = ProxyServer::new_with_tls_config( - cert_chain, - server_config, - "127.0.0.1:0", + let proxy_server = ProxyServer::new( + Some(OuterTlsConfig { + listen_addr: "127.0.0.1:0", + tls: OuterTlsMode::Preconfigured { + server_config, + certificate_name: certificate_identity_from_chain(&cert_chain).unwrap(), + }, + }), + Some("127.0.0.1:0"), target_addr.to_string(), AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), AttestationVerifier::expect_none(), + false, ) .await .unwrap(); @@ -1048,7 +1290,7 @@ mod tests { proxy_client.accept().await.unwrap(); }); - let res = reqwest::get(format!("http://{}", proxy_client_addr.to_string())) + let res = reqwest::get(format!("http://{}", proxy_client_addr)) .await .unwrap(); @@ -1076,10 +1318,15 @@ mod tests { server_private_key, ); - let proxy_server = ProxyServer::new_with_tls_config_and_client_auth( - server_cert_chain, - server_tls_server_config, - "127.0.0.1:0", + let proxy_server = ProxyServer::new( + Some(OuterTlsConfig { + listen_addr: "127.0.0.1:0", + tls: OuterTlsMode::Preconfigured { + server_config: server_tls_server_config, + certificate_name: certificate_identity_from_chain(&server_cert_chain).unwrap(), + }, + }), + Some("127.0.0.1:0"), target_addr.to_string(), AttestationGenerator::with_no_attestation(), AttestationVerifier::mock(), @@ -1129,13 +1376,19 @@ mod tests { let (server_config, client_config) = generate_tls_config(server_cert_chain.clone(), server_private_key); - let proxy_server = ProxyServer::new_with_tls_config( - server_cert_chain, - server_config, - "127.0.0.1:0", + let proxy_server = ProxyServer::new( + Some(OuterTlsConfig { + listen_addr: "127.0.0.1:0", + tls: OuterTlsMode::Preconfigured { + server_config, + certificate_name: certificate_identity_from_chain(&server_cert_chain).unwrap(), + }, + }), + Some("127.0.0.1:0"), target_addr.to_string(), AttestationGenerator::with_no_attestation(), AttestationVerifier::mock(), + false, ) .await .unwrap(); @@ -1166,7 +1419,7 @@ mod tests { proxy_client.accept().await.unwrap(); }); - let res = reqwest::get(format!("http://{}", proxy_client_addr.to_string())) + let res = reqwest::get(format!("http://{}", proxy_client_addr)) .await .unwrap(); @@ -1193,10 +1446,15 @@ mod tests { server_private_key, ); - let proxy_server = ProxyServer::new_with_tls_config_and_client_auth( - server_cert_chain, - server_tls_server_config, - "127.0.0.1:0", + let proxy_server = ProxyServer::new( + Some(OuterTlsConfig { + listen_addr: "127.0.0.1:0", + tls: OuterTlsMode::Preconfigured { + server_config: server_tls_server_config, + certificate_name: certificate_identity_from_chain(&server_cert_chain).unwrap(), + }, + }), + Some("127.0.0.1:0"), target_addr.to_string(), AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), AttestationVerifier::mock(), @@ -1248,13 +1506,19 @@ mod tests { let (cert_chain, private_key) = generate_certificate_chain_for_host("localhost"); let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key); - let proxy_server = ProxyServer::new_with_tls_config( - cert_chain.clone(), - server_config, - "127.0.0.1:0", + let proxy_server = ProxyServer::new( + Some(OuterTlsConfig { + listen_addr: "127.0.0.1:0", + tls: OuterTlsMode::Preconfigured { + server_config, + certificate_name: certificate_identity_from_chain(&cert_chain).unwrap(), + }, + }), + Some("127.0.0.1:0"), target_addr.to_string(), AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), AttestationVerifier::expect_none(), + false, ) .await .unwrap(); @@ -1290,13 +1554,19 @@ mod tests { let (cert_chain, private_key) = generate_certificate_chain_for_host("localhost"); let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key); - let proxy_server = ProxyServer::new_with_tls_config( - cert_chain, - server_config, - "127.0.0.1:0", + let proxy_server = ProxyServer::new( + Some(OuterTlsConfig { + listen_addr: "127.0.0.1:0", + tls: OuterTlsMode::Preconfigured { + server_config, + certificate_name: certificate_identity_from_chain(&cert_chain).unwrap(), + }, + }), + Some("127.0.0.1:0"), target_addr.to_string(), AttestationGenerator::with_no_attestation(), AttestationVerifier::expect_none(), + false, ) .await .unwrap(); @@ -1330,13 +1600,19 @@ mod tests { let (cert_chain, private_key) = generate_certificate_chain_for_host("localhost"); let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key); - let proxy_server = ProxyServer::new_with_tls_config( - cert_chain, - server_config, - "127.0.0.1:0", + let proxy_server = ProxyServer::new( + Some(OuterTlsConfig { + listen_addr: "127.0.0.1:0", + tls: OuterTlsMode::Preconfigured { + server_config, + certificate_name: certificate_identity_from_chain(&cert_chain).unwrap(), + }, + }), + Some("127.0.0.1:0"), target_addr.to_string(), AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), AttestationVerifier::expect_none(), + false, ) .await .unwrap(); @@ -1395,13 +1671,19 @@ mod tests { let (cert_chain, private_key) = generate_certificate_chain_for_host("localhost"); let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key); - let proxy_server = ProxyServer::new_with_tls_config( - cert_chain, - server_config, - "127.0.0.1:0", + let proxy_server = ProxyServer::new( + Some(OuterTlsConfig { + listen_addr: "127.0.0.1:0", + tls: OuterTlsMode::Preconfigured { + server_config, + certificate_name: certificate_identity_from_chain(&cert_chain).unwrap(), + }, + }), + Some("127.0.0.1:0"), target_addr.to_string(), AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), AttestationVerifier::expect_none(), + false, ) .await .unwrap(); @@ -1441,7 +1723,7 @@ mod tests { proxy_client.accept().await.unwrap(); }); - let _initial_response = reqwest::get(format!("http://{}", proxy_client_addr.to_string())) + let _initial_response = reqwest::get(format!("http://{}", proxy_client_addr)) .await .unwrap(); @@ -1449,7 +1731,7 @@ mod tests { connection_breaker_tx.send(()).unwrap(); // Make another request - let res = reqwest::get(format!("http://{}", proxy_client_addr.to_string())) + let res = reqwest::get(format!("http://{}", proxy_client_addr)) .await .unwrap(); @@ -1468,13 +1750,19 @@ mod tests { server_config.alpn_protocols.push(ALPN_HTTP11.to_vec()); - let proxy_server = ProxyServer::new_with_tls_config( - cert_chain, - server_config, - "127.0.0.1:0", + let proxy_server = ProxyServer::new( + Some(OuterTlsConfig { + listen_addr: "127.0.0.1:0", + tls: OuterTlsMode::Preconfigured { + server_config, + certificate_name: certificate_identity_from_chain(&cert_chain).unwrap(), + }, + }), + Some("127.0.0.1:0"), target_addr.to_string(), AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), AttestationVerifier::expect_none(), + false, ) .await .unwrap(); @@ -1502,7 +1790,7 @@ mod tests { proxy_client.accept().await.unwrap(); }); - let res = reqwest::get(format!("http://{}", proxy_client_addr.to_string())) + let res = reqwest::get(format!("http://{}", proxy_client_addr)) .await .unwrap(); diff --git a/src/main.rs b/src/main.rs index c14c414..a80a54b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -7,9 +7,9 @@ use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer}; use tracing::level_filters::LevelFilter; use attested_tls_proxy::{ - AttestationGenerator, ProxyClient, ProxyServer, TlsCertAndKey, attested_get::attested_get, - file_server::attested_file_server, get_inner_tls_cert, health_check, - normalize_pem::normalize_private_key_pem_to_pkcs8, + AttestationGenerator, OuterTlsConfig, OuterTlsMode, ProxyClient, ProxyServer, TlsCertAndKey, + attested_get::attested_get, file_server::attested_file_server, get_inner_tls_cert, + health_check, normalize_pem::normalize_private_key_pem_to_pkcs8, }; const GIT_REV: &str = match option_env!("GIT_REV") { @@ -77,19 +77,22 @@ enum CliCommand { }, /// Run a proxy server Server { - /// Socket address to listen on - #[arg(short, long, default_value = "0.0.0.0:0", env = "LISTEN_ADDR")] - listen_addr: SocketAddr, + /// Socket address to listen on for the outer nested-TLS listener, if enabled + #[arg(long)] + outer_listen_addr: Option, + /// Socket address to listen on for the inner-only attested TLS listener + #[arg(long)] + inner_listen_addr: Option, /// The hostname:port or ip:port of the target service to forward traffic to target_addr: String, /// Type of attestation to present (dafaults to 'auto' for automatic detection) - /// If other than None, a TLS key and certicate must also be given + /// This configures the inner attested TLS listener and does not require outer TLS certs. #[arg(long, env = "SERVER_ATTESTATION_TYPE")] server_attestation_type: Option, - /// The path to a PEM encoded private key + /// The path to a PEM encoded private key for the optional outer nested-TLS listener #[arg(long, env = "TLS_PRIVATE_KEY_PATH")] tls_private_key_path: Option, - /// Additional CA certificate to verify against (PEM) Defaults to no additional TLS certs. + /// PEM certificate chain for the optional outer nested-TLS listener #[arg(long, env = "TLS_CERTIFICATE_PATH")] tls_certificate_path: Option, /// Whether to use client authentication. If the client is running in a CVM this must be @@ -119,19 +122,22 @@ enum CliCommand { AttestedFileServer { /// Filesystem path to statically serve path_to_serve: PathBuf, - /// Socket address to listen on - #[arg(short, long, default_value = "0.0.0.0:0", env = "LISTEN_ADDR")] - listen_addr: SocketAddr, + /// Socket address to listen on for the outer nested-TLS listener, if enabled + #[arg(long)] + outer_listen_addr: Option, + /// Socket address to listen on for the inner-only attested TLS listener + #[arg(long)] + inner_listen_addr: Option, /// Type of attestation to present (dafaults to none) - /// If other than None, a TLS key and certicate must also be given + /// This configures the inner attested TLS listener and does not require outer TLS certs. #[arg(long, env = "SERVER_ATTESTATION_TYPE")] server_attestation_type: Option, - /// The path to a PEM encoded private key + /// The path to a PEM encoded private key for the optional outer nested-TLS listener #[arg(long, env = "TLS_PRIVATE_KEY_PATH")] - tls_private_key_path: PathBuf, - /// Additional CA certificate to verify against (PEM) Defaults to no additional TLS certs. + tls_private_key_path: Option, + /// PEM certificate chain for the optional outer nested-TLS listener #[arg(long, env = "TLS_CERTIFICATE_PATH")] - tls_certificate_path: PathBuf, + tls_certificate_path: Option, /// URL of the remote dummy attestation service. Only use with --server-attestation-type /// dummy #[arg(long)] @@ -277,7 +283,8 @@ async fn main() -> anyhow::Result<()> { } } CliCommand::Server { - listen_addr, + outer_listen_addr, + inner_listen_addr, target_addr, tls_private_key_path, tls_certificate_path, @@ -292,14 +299,24 @@ async fn main() -> anyhow::Result<()> { let tls_cert_and_chain = load_tls_cert_and_key_server(tls_certificate_path, tls_private_key_path)?; + validate_listener_args( + inner_listen_addr, + outer_listen_addr, + tls_cert_and_chain.is_some(), + )?; let local_attestation_generator = AttestationGenerator::new_with_detection(server_attestation_type, dev_dummy_dcap) .await?; let server = ProxyServer::new( - tls_cert_and_chain, - listen_addr, + tls_cert_and_chain + .zip(outer_listen_addr) + .map(|(cert_and_key, listen_addr)| OuterTlsConfig { + listen_addr, + tls: OuterTlsMode::CertAndKey(cert_and_key), + }), + inner_listen_addr, target_addr, local_attestation_generator, attestation_verifier, @@ -344,14 +361,20 @@ async fn main() -> anyhow::Result<()> { } CliCommand::AttestedFileServer { path_to_serve, - listen_addr, + outer_listen_addr, + inner_listen_addr, server_attestation_type, tls_private_key_path, tls_certificate_path, dev_dummy_dcap, } => { let tls_cert_and_chain = - load_tls_cert_and_key(tls_certificate_path, tls_private_key_path)?; + load_tls_cert_and_key_server(tls_certificate_path, tls_private_key_path)?; + validate_listener_args( + inner_listen_addr, + outer_listen_addr, + tls_cert_and_chain.is_some(), + )?; let server_attestation_type: AttestationType = serde_json::from_value( serde_json::Value::String(server_attestation_type.unwrap_or("none".to_string())), @@ -363,7 +386,8 @@ async fn main() -> anyhow::Result<()> { attested_file_server( path_to_serve, tls_cert_and_chain, - listen_addr, + outer_listen_addr, + inner_listen_addr, attestation_generator, attestation_verifier, false, @@ -410,19 +434,43 @@ async fn main() -> anyhow::Result<()> { fn load_tls_cert_and_key_server( cert_chain: Option, private_key: Option, -) -> anyhow::Result { - if let Some(private_key) = private_key { - load_tls_cert_and_key( - cert_chain.ok_or(anyhow!("Private key given but no certificate chain"))?, - private_key, - ) - } else if cert_chain.is_some() { - Err(anyhow!("Certificate chain provided but no private key")) - } else { - Err(anyhow!("No private key provided")) +) -> anyhow::Result> { + match (cert_chain, private_key) { + (Some(cert_chain), Some(private_key)) => { + Ok(Some(load_tls_cert_and_key(cert_chain, private_key)?)) + } + (Some(_), None) => Err(anyhow!("Certificate chain provided but no private key")), + (None, Some(_)) => Err(anyhow!("Private key given but no certificate chain")), + (None, None) => Ok(None), } } +fn validate_listener_args( + inner_listen_addr: Option, + outer_listen_addr: Option, + has_outer_tls: bool, +) -> anyhow::Result<()> { + if inner_listen_addr.is_none() && outer_listen_addr.is_none() { + return Err(anyhow!( + "At least one of --inner-listen-addr or --outer-listen-addr must be provided" + )); + } + + if has_outer_tls && outer_listen_addr.is_none() { + return Err(anyhow!( + "--outer-listen-addr is required when TLS certificate and key are provided" + )); + } + + if !has_outer_tls && outer_listen_addr.is_some() { + return Err(anyhow!( + "--outer-listen-addr requires TLS certificate and key" + )); + } + + Ok(()) +} + /// Load TLS details from storage fn load_tls_cert_and_key( cert_chain: PathBuf,