From c708817a8b82dbb2e2eb8c78491eb96b55d51dc0 Mon Sep 17 00:00:00 2001 From: peg Date: Tue, 17 Mar 2026 09:48:13 +0100 Subject: [PATCH 01/20] WIP use nested-attested-tls for proxy --- Cargo.lock | 765 ++++++++++++++++++++++++++++++++++++++++++++++++++--- Cargo.toml | 11 +- src/lib.rs | 296 ++++++++++----------- 3 files changed, 863 insertions(+), 209 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 0459228..e01c501 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,6 +2,12 @@ # It is not intended for manual editing. version = 4 +[[package]] +name = "adler2" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa" + [[package]] name = "aho-corasick" version = "1.1.4" @@ -11,6 +17,21 @@ dependencies = [ "memchr", ] +[[package]] +name = "alloc-no-stdlib" +version = "2.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc7bb162ec39d46ab1ca8c77bf72e890535becd1751bb45f64c597edb4c8c6b3" + +[[package]] +name = "alloc-stdlib" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94fb8275041c72129eb51b7d0322c29b8387a0386127718b096429201a5d6ece" +dependencies = [ + "alloc-no-stdlib", +] + [[package]] name = "alloy-json-rpc" version = "1.6.3" @@ -242,9 +263,9 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.100" +version = "1.0.102" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a23eb6b1614318a8071c9b2521f36b424b2c83db5eb3a0fead4a6c0809af6e61" +checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c" [[package]] name = "ark-ff" @@ -435,19 +456,41 @@ dependencies = [ "rand 0.8.5", ] +[[package]] +name = "arrayref" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76a2e8124351fda1ef8aaaa3bbd7ebbcb486bbcd4225aca0aa0d84bb2db8fecb" + [[package]] name = "arrayvec" version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" +[[package]] +name = "asn1-rs" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5493c3bedbacf7fd7382c6346bbd66687d12bbaad3a89a2d2c303ee6cf20b048" +dependencies = [ + "asn1-rs-derive 0.5.1", + "asn1-rs-impl", + "displaydoc", + "nom", + "num-traits", + "rusticata-macros", + "thiserror 1.0.69", + "time", +] + [[package]] name = "asn1-rs" version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "56624a96882bb8c26d61312ae18cb45868e5a9992ea73c58e45c3101e56a1e60" dependencies = [ - "asn1-rs-derive", + "asn1-rs-derive 0.6.0", "asn1-rs-impl", "displaydoc", "nom", @@ -457,6 +500,18 @@ dependencies = [ "time", ] +[[package]] +name = "asn1-rs-derive" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "965c2d33e53cb6b267e148a4cb0760bc01f4904c1cd4bb4002a085bb016d1490" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.108", + "synstructure", +] + [[package]] name = "asn1-rs-derive" version = "0.6.0" @@ -512,7 +567,39 @@ dependencies = [ "az-tdx-vtpm", "base64 0.22.1", "configfs-tsm", - "dcap-qvl", + "dcap-qvl 0.3.12 (git+https://github.com/flashbots/dcap-qvl.git?branch=peg%2Fazure-outdated-tcp-override)", + "hex", + "http", + "num-bigint", + "once_cell", + "openssl", + "parity-scale-codec", + "pem-rfc7468", + "rand_core 0.6.4", + "reqwest", + "rustls-webpki", + "serde", + "serde_json", + "tdx-quote", + "thiserror 2.0.17", + "time", + "tokio", + "tokio-rustls", + "tracing", + "tss-esapi", + "x509-parser 0.18.1", +] + +[[package]] +name = "attestation" +version = "0.0.1" +source = "git+https://github.com/flashbots/attested-tls?branch=peg%2Fattested-tls-crate#5c109dba74d4f9de58b4b846f480599752dfb1f9" +dependencies = [ + "anyhow", + "az-tdx-vtpm", + "base64 0.22.1", + "configfs-tsm", + "dcap-qvl 0.3.12 (git+https://github.com/flashbots/dcap-qvl.git?branch=peg%2Fazure-outdated-tcp-override)", "hex", "http", "num-bigint", @@ -532,7 +619,7 @@ dependencies = [ "tokio-rustls", "tracing", "tss-esapi", - "x509-parser", + "x509-parser 0.18.1", ] [[package]] @@ -557,7 +644,7 @@ version = "0.0.1" dependencies = [ "alloy-rpc-client", "alloy-transport-http", - "attestation", + "attestation 0.0.1 (git+https://github.com/flashbots/attested-tls?branch=peg%2Fadd-attestation-crate)", "bytes", "futures-util", "http", @@ -565,7 +652,7 @@ dependencies = [ "hyper", "hyper-util", "parity-scale-codec", - "rcgen", + "rcgen 0.14.7", "serde_json", "sha2", "tempfile", @@ -577,7 +664,26 @@ dependencies = [ "tracing", "url", "webpki-roots", - "x509-parser", + "x509-parser 0.18.1", +] + +[[package]] +name = "attested-tls" +version = "0.0.1" +source = "git+https://github.com/flashbots/attested-tls?branch=peg%2Fattested-tls-crate#5c109dba74d4f9de58b4b846f480599752dfb1f9" +dependencies = [ + "anyhow", + "attestation 0.0.1 (git+https://github.com/flashbots/attested-tls?branch=peg%2Fattested-tls-crate)", + "ra-tls", + "rcgen 0.14.7", + "rustls", + "serde_json", + "sha2", + "thiserror 2.0.17", + "tokio", + "tracing", + "x509-parser 0.18.1", + "yasna 0.5.2", ] [[package]] @@ -585,7 +691,8 @@ name = "attested-tls-proxy" version = "1.1.1" dependencies = [ "anyhow", - "attested-tls", + "attestation 0.0.1 (git+https://github.com/flashbots/attested-tls?branch=peg%2Fattested-tls-crate)", + "attested-tls 0.0.1 (git+https://github.com/flashbots/attested-tls?branch=peg%2Fattested-tls-crate)", "axum", "bytes", "clap", @@ -594,12 +701,13 @@ dependencies = [ "hyper", "hyper-util", "jsonrpsee", + "nested-tls", "p256", "pem-rfc7468", "pin-project-lite", "pkcs1", "pkcs8", - "rcgen", + "rcgen 0.14.7", "reqwest", "rsa", "rustls-pemfile", @@ -614,7 +722,7 @@ dependencies = [ "tracing", "tracing-subscriber", "webpki-roots", - "x509-parser", + "x509-parser 0.18.1", ] [[package]] @@ -855,6 +963,29 @@ dependencies = [ "wyz", ] +[[package]] +name = "blake2" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "46502ad458c9a52b69d4d4d32775c788b7a1b85e8bc9d482d92250fc0e3f8efe" +dependencies = [ + "digest 0.10.7", +] + +[[package]] +name = "blake3" +version = "1.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2468ef7d57b3fb7e16b576e8377cdbde2320c60e1491e961d11da40fc4f02a2d" +dependencies = [ + "arrayref", + "arrayvec", + "cc", + "cfg-if", + "constant_time_eq", + "cpufeatures", +] + [[package]] name = "block-buffer" version = "0.10.4" @@ -864,6 +995,31 @@ dependencies = [ "generic-array", ] +[[package]] +name = "bon" +version = "3.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f47dbe92550676ee653353c310dfb9cf6ba17ee70396e1f7cf0a2020ad49b2fe" +dependencies = [ + "bon-macros", + "rustversion", +] + +[[package]] +name = "bon-macros" +version = "3.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "519bd3116aeeb42d5372c29d982d16d0170d3d4a5ed85fc7dd91642ffff3c67c" +dependencies = [ + "darling", + "ident_case", + "prettyplease", + "proc-macro2", + "quote", + "rustversion", + "syn 2.0.108", +] + [[package]] name = "borsh" version = "1.6.0" @@ -887,6 +1043,27 @@ dependencies = [ "syn 2.0.108", ] +[[package]] +name = "brotli" +version = "8.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4bd8b9603c7aa97359dbd97ecf258968c95f3adddd6db2f7e7a5bef101c84560" +dependencies = [ + "alloc-no-stdlib", + "alloc-stdlib", + "brotli-decompressor", +] + +[[package]] +name = "brotli-decompressor" +version = "5.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "874bb8112abecc98cbd6d81ea4fa7e94fb9449648c93cc89aa40c81c24d7de03" +dependencies = [ + "alloc-no-stdlib", + "alloc-stdlib", +] + [[package]] name = "bumpalo" version = "3.19.0" @@ -924,6 +1101,23 @@ dependencies = [ "shlex", ] +[[package]] +name = "cc-eventlog" +version = "0.5.8" +source = "git+https://github.com/Dstack-TEE/dstack.git#4f602dddc0542cd34da031c90ac0b3a560f316ed" +dependencies = [ + "anyhow", + "digest 0.10.7", + "ez-hash", + "fs-err", + "hex", + "parity-scale-codec", + "serde", + "serde-human-bytes", + "serde_json", + "sha2", +] + [[package]] name = "cfg-if" version = "1.0.4" @@ -1004,6 +1198,18 @@ version = "0.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "187437900921c8172f33316ad51a3267df588e99a2aebfa5ca1a2ed44df9e703" +[[package]] +name = "console" +version = "0.15.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "054ccb5b10f9f2cbf51eb355ca1d05c2d279ce1804688d0db74b4733a5aeafd8" +dependencies = [ + "encode_unicode", + "libc", + "once_cell", + "windows-sys 0.59.0", +] + [[package]] name = "const-hex" version = "1.17.0" @@ -1042,6 +1248,12 @@ dependencies = [ "unicode-xid", ] +[[package]] +name = "constant_time_eq" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d52eff69cd5e647efe296129160853a42795992097e8af39800e1060caeea9b" + [[package]] name = "convert_case" version = "0.10.0" @@ -1086,6 +1298,15 @@ dependencies = [ "libc", ] +[[package]] +name = "crc32fast" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9481c1c90cbf2ac953f07c8d4a58aa3945c425b7185c9154d67a65e4230da511" +dependencies = [ + "cfg-if", +] + [[package]] name = "critical-section" version = "1.2.0" @@ -1170,6 +1391,40 @@ dependencies = [ "syn 2.0.108", ] +[[package]] +name = "darling" +version = "0.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "25ae13da2f202d56bd7f91c25fba009e7717a1e4a1cc98a76d844b65ae912e9d" +dependencies = [ + "darling_core", + "darling_macro", +] + +[[package]] +name = "darling_core" +version = "0.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9865a50f7c335f53564bb694ef660825eb8610e0a53d3e11bf1b0d3df31e03b0" +dependencies = [ + "ident_case", + "proc-macro2", + "quote", + "strsim", + "syn 2.0.108", +] + +[[package]] +name = "darling_macro" +version = "0.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac3984ec7bd6cfa798e62b4a642426a5be0e68f9401cfc2a01e3fa9ea2fcdb8d" +dependencies = [ + "darling_core", + "quote", + "syn 2.0.108", +] + [[package]] name = "data-encoding" version = "2.9.0" @@ -1179,7 +1434,44 @@ checksum = "2a2330da5de22e8a3cb63252ce2abb30116bf5265e89c0e01bc17015ce30a476" [[package]] name = "dcap-qvl" version = "0.3.12" -source = "git+https://github.com/flashbots/dcap-qvl.git?branch=peg%2Fazure-outdated-tcp-override#b61d8f3ffb59f225d7b98220e2185a66f1c7f8c7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67e7842b81018f3b991dc65ec0a95ff347332de58478c4ac43459095af00cc89" +dependencies = [ + "anyhow", + "asn1_der", + "base64 0.22.1", + "borsh", + "byteorder", + "chrono", + "const-oid", + "dcap-qvl-webpki", + "der", + "derive_more 2.1.1", + "futures", + "hex", + "log", + "p256", + "parity-scale-codec", + "pem", + "reqwest", + "ring", + "rustls-pki-types", + "scale-info", + "serde", + "serde-human-bytes", + "serde_json", + "sha2", + "signature", + "tracing", + "urlencoding", + "wasm-bindgen-futures", + "x509-cert", +] + +[[package]] +name = "dcap-qvl" +version = "0.3.12" +source = "git+https://github.com/flashbots/dcap-qvl.git?branch=peg%2Fazure-outdated-tcp-override#e38818f0b7b600ceadad1ec3efd9e681bcbdc1e5" dependencies = [ "anyhow", "asn1_der", @@ -1243,13 +1535,27 @@ dependencies = [ "zeroize", ] +[[package]] +name = "der-parser" +version = "9.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5cd0a5c643689626bec213c4d8bd4d96acc8ffdb4ad4bb6bc16abf27d5f4b553" +dependencies = [ + "asn1-rs 0.6.2", + "displaydoc", + "nom", + "num-bigint", + "num-traits", + "rusticata-macros", +] + [[package]] name = "der-parser" version = "10.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "07da5016415d5a3c4dd39b11ed26f915f52fc4e0dc197d87908bc916e51bc1a6" dependencies = [ - "asn1-rs", + "asn1-rs 0.7.1", "displaydoc", "nom", "num-bigint", @@ -1370,7 +1676,7 @@ dependencies = [ "libc", "option-ext", "redox_users", - "windows-sys 0.61.2", + "windows-sys 0.59.0", ] [[package]] @@ -1384,6 +1690,44 @@ dependencies = [ "syn 2.0.108", ] +[[package]] +name = "dstack-attest" +version = "0.5.8" +source = "git+https://github.com/Dstack-TEE/dstack.git#4f602dddc0542cd34da031c90ac0b3a560f316ed" +dependencies = [ + "anyhow", + "cc-eventlog", + "dcap-qvl 0.3.12 (registry+https://github.com/rust-lang/crates.io-index)", + "dstack-types", + "errify", + "ez-hash", + "fs-err", + "hex", + "hex_fmt", + "insta", + "or-panic", + "parity-scale-codec", + "serde", + "serde-human-bytes", + "serde_json", + "sha2", + "sha3", + "tdx-attest", + "tracing", +] + +[[package]] +name = "dstack-types" +version = "0.5.8" +source = "git+https://github.com/Dstack-TEE/dstack.git#4f602dddc0542cd34da031c90ac0b3a560f316ed" +dependencies = [ + "parity-scale-codec", + "serde", + "serde-human-bytes", + "sha3", + "size-parser", +] + [[package]] name = "dunce" version = "1.0.5" @@ -1463,6 +1807,12 @@ dependencies = [ "zeroize", ] +[[package]] +name = "encode_unicode" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34aa73646ffb006b8f5147f3dc182bd4bcb190227ce861fc4a4844bf8e3cb2c0" + [[package]] name = "enum-as-inner" version = "0.6.1" @@ -1521,6 +1871,28 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" +[[package]] +name = "errify" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4bb818c3c01af9cdeb367f7e92e290b9a080935cdc5fb6cc0c1193ae17032849" +dependencies = [ + "anyhow", + "errify-macros", +] + +[[package]] +name = "errify-macros" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e87afa19e6030c2cf5514b00d5a242a3ea9492a2aa618635076914f5d15e7af" +dependencies = [ + "proc-macro2", + "proc-macro2-diagnostics", + "quote", + "syn 2.0.108", +] + [[package]] name = "errno" version = "0.3.14" @@ -1528,7 +1900,22 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" dependencies = [ "libc", - "windows-sys 0.61.2", + "windows-sys 0.59.0", +] + +[[package]] +name = "ez-hash" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42b3b3adc5fbbc9e21416d5b721b1bccb501a87d7b32ac89f2c7cea229d40772" +dependencies = [ + "blake2", + "blake3", + "digest 0.10.7", + "md-5", + "sha1", + "sha2", + "sha3", ] [[package]] @@ -1599,6 +1986,16 @@ version = "0.4.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b7ac824320a75a52197e8f2d787f6a38b6718bb6897a35142d749af3c0e8f4fe" +[[package]] +name = "flate2" +version = "1.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "843fba2746e448b37e26a819579957415c8cef339bf08564fe8b7ddbd959573c" +dependencies = [ + "crc32fast", + "miniz_oxide", +] + [[package]] name = "fnv" version = "1.0.7" @@ -1635,6 +2032,15 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "fs-err" +version = "3.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73fde052dbfc920003cfd2c8e2c6e6d4cc7c1091538c3a24226cec0665ab08c0" +dependencies = [ + "autocfg", +] + [[package]] name = "funty" version = "2.0.0" @@ -1826,6 +2232,12 @@ version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" +[[package]] +name = "hex_fmt" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b07f60793ff0a4d9cef0f18e63b5357e06209987153a64648c972c1e5aff336f" + [[package]] name = "hickory-proto" version = "0.25.2" @@ -1872,6 +2284,15 @@ dependencies = [ "tracing", ] +[[package]] +name = "hkdf" +version = "0.12.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b5f8eb2ad728638ea2c7d47a21db23b7b58a72ed6a38256b8a1849f15fbbdf7" +dependencies = [ + "hmac", +] + [[package]] name = "hmac" version = "0.12.1" @@ -2102,6 +2523,12 @@ dependencies = [ "zerovec", ] +[[package]] +name = "ident_case" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" + [[package]] name = "idna" version = "1.1.0" @@ -2155,6 +2582,18 @@ dependencies = [ "serde_core", ] +[[package]] +name = "insta" +version = "1.46.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e82db8c87c7f1ccecb34ce0c24399b8a73081427f3c7c50a5d597925356115e4" +dependencies = [ + "console", + "once_cell", + "similar", + "tempfile", +] + [[package]] name = "iocuddle" version = "0.1.1" @@ -2374,9 +2813,9 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.177" +version = "0.2.183" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2874a2af47a2325c2001a6e6fad9b16a53b802102b528163885171cf92b15976" +checksum = "b5b646652bf6661599e1da8901b3b9522896f01e736bad5f723fe7a3a27f899d" [[package]] name = "libm" @@ -2463,6 +2902,16 @@ dependencies = [ "stable_deref_trait", ] +[[package]] +name = "md-5" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d89e7ee0cfbedfc4da3340218492196241d89eefb6dab27de5df917a6d2e78cf" +dependencies = [ + "cfg-if", + "digest 0.10.7", +] + [[package]] name = "memchr" version = "2.7.6" @@ -2500,6 +2949,16 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" +[[package]] +name = "miniz_oxide" +version = "0.8.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fa76a2c86f704bdb222d66965fb3d63269ce38518b83cb0575fca855ebb6316" +dependencies = [ + "adler2", + "simd-adler32", +] + [[package]] name = "mio" version = "1.1.0" @@ -2546,6 +3005,30 @@ dependencies = [ "tempfile", ] +[[package]] +name = "nested-tls" +version = "0.0.1" +source = "git+https://github.com/flashbots/attested-tls?branch=peg%2Fattested-tls-crate#5c109dba74d4f9de58b4b846f480599752dfb1f9" +dependencies = [ + "rustls", + "tokio", + "tokio-rustls", + "tracing", +] + +[[package]] +name = "nix" +version = "0.31.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d6d0705320c1e6ba1d912b5e37cf18071b6c2e9b7fa8215a1e8a7651966f5d3" +dependencies = [ + "bitflags 2.10.0", + "cfg-if", + "cfg_aliases", + "libc", + "memoffset", +] + [[package]] name = "nom" version = "7.1.3" @@ -2562,7 +3045,7 @@ version = "0.50.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5" dependencies = [ - "windows-sys 0.61.2", + "windows-sys 0.59.0", ] [[package]] @@ -2647,13 +3130,22 @@ dependencies = [ "serde", ] +[[package]] +name = "oid-registry" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8d8034d9489cdaf79228eb9f6a3b8d7bb32ba00d6645ebd48eef4077ceb5bd9" +dependencies = [ + "asn1-rs 0.6.2", +] + [[package]] name = "oid-registry" version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "12f40cff3dde1b6087cc5d5f5d4d65712f34016a03ed60e9c08dcc392736b5b7" dependencies = [ - "asn1-rs", + "asn1-rs 0.7.1", ] [[package]] @@ -2773,6 +3265,12 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" +[[package]] +name = "or-panic" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "596a79faf55e869e7bc0c2162cf2f18a54d4d1112876bceae587ad954fcbd574" + [[package]] name = "p256" version = "0.13.2" @@ -3026,6 +3524,16 @@ dependencies = [ "zerocopy", ] +[[package]] +name = "prettyplease" +version = "0.2.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b" +dependencies = [ + "proc-macro2", + "syn 2.0.108", +] + [[package]] name = "primeorder" version = "0.13.6" @@ -3086,6 +3594,19 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "proc-macro2-diagnostics" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af066a9c399a26e020ada66a034357a868728e72cd426f3adcd35f80d88d88c8" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.108", + "version_check", + "yansi", +] + [[package]] name = "proptest" version = "1.10.0" @@ -3163,7 +3684,7 @@ dependencies = [ "once_cell", "socket2 0.6.1", "tracing", - "windows-sys 0.60.2", + "windows-sys 0.59.0", ] [[package]] @@ -3181,6 +3702,43 @@ version = "5.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" +[[package]] +name = "ra-tls" +version = "0.5.8" +source = "git+https://github.com/Dstack-TEE/dstack.git#4f602dddc0542cd34da031c90ac0b3a560f316ed" +dependencies = [ + "anyhow", + "bon", + "cc-eventlog", + "dcap-qvl 0.3.12 (registry+https://github.com/rust-lang/crates.io-index)", + "dstack-attest", + "dstack-types", + "elliptic-curve", + "errify", + "ez-hash", + "flate2", + "fs-err", + "hex", + "hex_fmt", + "hkdf", + "or-panic", + "p256", + "parity-scale-codec", + "rand 0.8.5", + "rcgen 0.13.2", + "ring", + "rmp-serde", + "rustls-pki-types", + "serde", + "serde-human-bytes", + "serde_json", + "sha2", + "sha3", + "tracing", + "x509-parser 0.16.0", + "yasna 0.5.2", +] + [[package]] name = "radium" version = "0.7.0" @@ -3268,14 +3826,29 @@ dependencies = [ [[package]] name = "rcgen" -version = "0.14.5" +version = "0.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5fae430c6b28f1ad601274e78b7dffa0546de0b73b4cd32f46723c0c2a16f7a5" +checksum = "75e669e5202259b5314d1ea5397316ad400819437857b90861765f24c4cf80a2" dependencies = [ "pem", "ring", "rustls-pki-types", "time", + "x509-parser 0.16.0", + "yasna 0.5.2", +] + +[[package]] +name = "rcgen" +version = "0.14.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "10b99e0098aa4082912d4c649628623db6aba77335e4f4569ff5083a6448b32e" +dependencies = [ + "pem", + "ring", + "rustls-pki-types", + "time", + "x509-parser 0.18.1", "yasna 0.5.2", ] @@ -3422,6 +3995,25 @@ dependencies = [ "rustc-hex", ] +[[package]] +name = "rmp" +version = "0.8.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ba8be72d372b2c9b35542551678538b562e7cf86c3315773cae48dfbfe7790c" +dependencies = [ + "num-traits", +] + +[[package]] +name = "rmp-serde" +version = "1.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72f81bee8c8ef9b577d1681a70ebbc962c232461e397b22c208c43c04b67a155" +dependencies = [ + "rmp", + "serde", +] + [[package]] name = "route-recognizer" version = "0.3.1" @@ -3532,15 +4124,17 @@ dependencies = [ "errno", "libc", "linux-raw-sys", - "windows-sys 0.61.2", + "windows-sys 0.59.0", ] [[package]] name = "rustls" -version = "0.23.34" +version = "0.23.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a9586e9ee2b4f8fab52a0048ca7334d7024eef48e2cb9407e3497bb7cab7fa7" +checksum = "758025cb5fccfd3bc2fd74708fd4682be41d99e5dff73c377c0646c6012c73a4" dependencies = [ + "brotli", + "brotli-decompressor", "once_cell", "ring", "rustls-pki-types", @@ -3560,9 +4154,9 @@ dependencies = [ [[package]] name = "rustls-pki-types" -version = "1.13.0" +version = "1.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94182ad936a0c91c324cd46c6511b9510ed16af436d7b5bab34beab0afd55f7a" +checksum = "be040f8b0a225e40375822a563fa9524378b9d63112f53e19ffff34df5d33fdd" dependencies = [ "web-time", "zeroize", @@ -3725,10 +4319,11 @@ dependencies = [ [[package]] name = "serde-human-bytes" -version = "0.1.1" +version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78ef65cb41f3f9cef63c431193229067e8b98b53c4d4c4ed38a8ca87c4d07676" +checksum = "6a091af6294712930d01e375cce513e4ac416f823e033e8991ec4e5d6e6ef4c0" dependencies = [ + "base64 0.13.1", "hex", "serde", ] @@ -3902,6 +4497,28 @@ dependencies = [ "rand_core 0.6.4", ] +[[package]] +name = "simd-adler32" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e320a6c5ad31d271ad523dcf3ad13e2767ad8b1cb8f047f75a8aeaf8da139da2" + +[[package]] +name = "similar" +version = "2.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbbb5d9659141646ae647b42fe094daf6c6192d1620870b449d9557f748b2daa" + +[[package]] +name = "size-parser" +version = "0.5.8" +source = "git+https://github.com/Dstack-TEE/dstack.git#4f602dddc0542cd34da031c90ac0b3a560f316ed" +dependencies = [ + "anyhow", + "serde", + "thiserror 2.0.17", +] + [[package]] name = "slab" version = "0.4.11" @@ -4083,6 +4700,25 @@ version = "0.12.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" +[[package]] +name = "tdx-attest" +version = "0.5.8" +source = "git+https://github.com/Dstack-TEE/dstack.git#4f602dddc0542cd34da031c90ac0b3a560f316ed" +dependencies = [ + "anyhow", + "cc-eventlog", + "fs-err", + "hex", + "libc", + "parity-scale-codec", + "serde", + "serde-human-bytes", + "serde_json", + "sha2", + "thiserror 2.0.17", + "vsock", +] + [[package]] name = "tdx-quote" version = "0.0.5" @@ -4107,7 +4743,7 @@ dependencies = [ "getrandom 0.3.4", "once_cell", "rustix", - "windows-sys 0.61.2", + "windows-sys 0.59.0", ] [[package]] @@ -4217,9 +4853,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.48.0" +version = "1.50.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ff360e02eab121e0bc37a2d3b4d4dc622e6eda3a8e5253d5435ecf5bd4c68408" +checksum = "27ad5e34374e03cfffefc301becb44e9dc3c17584f414349ebe29ed26661822d" dependencies = [ "bytes", "libc", @@ -4389,9 +5025,9 @@ checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" [[package]] name = "tracing" -version = "0.1.41" +version = "0.1.44" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "784e0ac535deb450455cbfa28a6f0df145ea1bb7ae51b821cf5e7927fdcfbdd0" +checksum = "63e71662fa4b2a2c3a26f570f037eb95bb1f85397f3cd8076caed2f026a6d100" dependencies = [ "log", "pin-project-lite", @@ -4401,9 +5037,9 @@ dependencies = [ [[package]] name = "tracing-attributes" -version = "0.1.30" +version = "0.1.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81383ab64e72a7a8b8e13130c49e3dab29def6d0c7d76a03087b3cf71c5c6903" +checksum = "7490cfa5ec963746568740651ac6781f701c9c5ea257c58e057f3ba8cf69e8da" dependencies = [ "proc-macro2", "quote", @@ -4412,9 +5048,9 @@ dependencies = [ [[package]] name = "tracing-core" -version = "0.1.34" +version = "0.1.36" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b9d12581f227e93f094d3af2ae690a574abb8a2b9b7a96e7cfe9647b2b617678" +checksum = "db97caf9d906fbde555dd62fa95ddba9eecfd14cb388e4f491a66d74cd5fb79a" dependencies = [ "once_cell", "valuable", @@ -4689,6 +5325,16 @@ version = "0.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "051eb1abcf10076295e815102942cc58f9d5e3b4560e46e53c21e8ff6f3af7b1" +[[package]] +name = "vsock" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b82aeb12ad864eb8cd26a6c21175d0bdc66d398584ee6c93c76964c3bcfc78ff" +dependencies = [ + "libc", + "nix", +] + [[package]] name = "wait-timeout" version = "0.2.1" @@ -4889,6 +5535,15 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "windows-sys" +version = "0.59.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" +dependencies = [ + "windows-targets 0.52.6", +] + [[package]] name = "windows-sys" version = "0.60.2" @@ -5158,16 +5813,34 @@ dependencies = [ [[package]] name = "x509-parser" -version = "0.18.0" +version = "0.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eb3e137310115a65136898d2079f003ce33331a6c4b0d51f1531d1be082b6425" +checksum = "fcbc162f30700d6f3f82a24bf7cc62ffe7caea42c0b2cba8bf7f3ae50cf51f69" dependencies = [ - "asn1-rs", + "asn1-rs 0.6.2", "data-encoding", - "der-parser", + "der-parser 9.0.0", "lazy_static", "nom", - "oid-registry", + "oid-registry 0.7.1", + "ring", + "rusticata-macros", + "thiserror 1.0.69", + "time", +] + +[[package]] +name = "x509-parser" +version = "0.18.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d43b0f71ce057da06bc0851b23ee24f3f86190b07203dd8f567d0b706a185202" +dependencies = [ + "asn1-rs 0.7.1", + "data-encoding", + "der-parser 10.0.0", + "lazy_static", + "nom", + "oid-registry 0.8.1", "ring", "rusticata-macros", "thiserror 2.0.17", @@ -5195,6 +5868,12 @@ dependencies = [ "x509-ocsp", ] +[[package]] +name = "yansi" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cfe53a6657fd280eaa890a3bc59152892ffa3e30101319d168b781ed6529b049" + [[package]] name = "yasna" version = "0.4.0" diff --git a/Cargo.toml b/Cargo.toml index 40a3ab3..c285277 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,8 +11,10 @@ repository = "https://github.com/flashbots/attested-tls-proxy" keywords = ["attested-TLS", "CVM", "TDX"] [dependencies] -attested-tls = { path = "attested-tls", default-features = false } -tokio = { version = "1.48.0", features = ["full"] } +attested-tls = { git = "https://github.com/flashbots/attested-tls", branch = "peg/attested-tls-crate" } +nested-tls = { git = "https://github.com/flashbots/attested-tls", branch = "peg/attested-tls-crate" } +attestation = { git = "https://github.com/flashbots/attested-tls", branch = "peg/attested-tls-crate" } +tokio = { version = "1.50.0", features = ["full"] } tokio-rustls = { version = "0.26.4", default-features = false } x509-parser = { version = "0.18.0", features = ["verify"] } thiserror = "2.0.17" @@ -45,14 +47,15 @@ pin-project-lite = "0.2.16" [dev-dependencies] tempfile = "3.23.0" tdx-quote = { version = "0.0.5", features = ["mock"] } -attested-tls = { path = "attested-tls", features = ["test-helpers", "mock"] } +attestation = { git = "https://github.com/flashbots/attested-tls", branch = "peg/attested-tls-crate", features = ["mock"] } +tokio = { version = "1.48.0", features = ["full"] } jsonrpsee = { version = "0.26.0", features = ["server"] } [features] default = [] # Adds support for Microsoft Azure attestation generation and verification -azure = ["attested-tls/azure"] +azure = ["attestation/azure"] [package.metadata.deb] maintainer = "Flashbots Team " diff --git a/src/lib.rs b/src/lib.rs index 8e0ce80..3830941 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,24 +1,29 @@ //! An attested TLS protocol and HTTPS proxy -pub mod attested_get; +// pub mod attested_get; pub mod file_server; pub mod health_check; pub mod normalize_pem; -pub mod self_signed; +// pub mod self_signed; -pub use attested_tls; -pub use attested_tls::attestation; -pub use attested_tls::attestation::AttestationGenerator; +pub use attestation; +pub use attestation::AttestationGenerator; mod http_version; #[cfg(test)] mod test_helpers; +use attestation::{ + AttestationError, AttestationType, AttestationVerifier, measurements::MultiMeasurements, +}; +use attested_tls::{AttestedCertificateResolver, AttestedCertificateVerifier}; use bytes::Bytes; use http::{HeaderMap, HeaderName, HeaderValue}; use http_body_util::{BodyExt, combinators::BoxBody}; use hyper::{Response, service::service_fn}; use hyper_util::rt::TokioIo; +use nested_tls::server::NestingTlsStream; +use nested_tls::{client::NestingTlsConnector, server::NestingTlsAcceptor}; use std::{net::SocketAddr, num::TryFromIntError, sync::Arc, time::Duration}; use thiserror::Error; use tokio::io; @@ -26,17 +31,12 @@ use tokio::net::{TcpListener, TcpStream, ToSocketAddrs}; use tokio::sync::{mpsc, oneshot}; use tokio_rustls::rustls::server::{VerifierBuilderError, WebPkiClientVerifier}; use tokio_rustls::rustls::{ - self, ClientConfig, RootCertStore, ServerConfig, pki_types::CertificateDer, + self, ClientConfig, RootCertStore, ServerConfig, + pki_types::{CertificateDer, PrivateKeyDer, ServerName}, }; use tracing::{debug, error, warn}; use crate::http_version::{ALPN_H2, ALPN_HTTP11, HttpConnection, HttpSender, HttpVersion}; -use attested_tls::{ - AttestedTlsClient, AttestedTlsError, AttestedTlsServer, TlsCertAndKey, - attestation::{ - AttestationError, AttestationType, AttestationVerifier, measurements::MultiMeasurements, - }, -}; /// The header name for giving attestation type const ATTESTATION_TYPE_HEADER: &str = "X-Flashbots-Attestation-Type"; @@ -61,6 +61,14 @@ type RequestWithResponseSender = ( oneshot::Sender>, hyper::Error>>, ); +/// TLS Credentials +pub struct TlsCertAndKey { + /// Der-encoded TLS certificate chain + pub cert_chain: Vec>, + /// Der-encoded TLS private key + pub key: PrivateKeyDer<'static>, +} + /// Adds HTTP 1 and 2 to the list of allowed protocols fn ensure_proxy_alpn_protocols(alpn_protocols: &mut Vec>) { for protocol in [ALPN_H2, ALPN_HTTP11] { @@ -72,33 +80,32 @@ fn ensure_proxy_alpn_protocols(alpn_protocols: &mut Vec>) { } } -/// Retrieve the attested remote TLS certificate. -pub async fn get_tls_cert( - server_name: String, - attestation_verifier: AttestationVerifier, - remote_certificate: Option>, - allow_self_signed: bool, -) -> Result<(Vec>, Option), AttestedTlsError> { - let (cert, measurements) = if allow_self_signed { - let client_tls_config = self_signed::client_tls_config_allow_self_signed()?; - attested_tls::get_tls_cert_with_config( - &server_name, - attestation_verifier, - client_tls_config, - ) - .await? - } else { - attested_tls::get_tls_cert(server_name, attestation_verifier, remote_certificate).await? - }; - - debug!("[get-tls-cert] Connected to proxy server with measurements: {measurements:?}"); - Ok((cert, measurements)) -} +// /// Retrieve the attested remote TLS certificate. +// pub async fn get_tls_cert( +// server_name: String, +// attestation_verifier: AttestationVerifier, +// remote_certificate: Option>, +// allow_self_signed: bool, +// ) -> Result<(Vec>, Option), AttestedTlsError> { +// let (cert, measurements) = if allow_self_signed { +// let client_tls_config = self_signed::client_tls_config_allow_self_signed()?; +// attested_tls::get_tls_cert_with_config( +// &server_name, +// attestation_verifier, +// client_tls_config, +// ) +// .await? +// } else { +// attested_tls::get_tls_cert(server_name, attestation_verifier, remote_certificate).await? +// }; +// +// debug!("[get-tls-cert] Connected to proxy server with measurements: {measurements:?}"); +// Ok((cert, measurements)) +// } /// A TLS over TCP server which provides an attestation before forwarding traffic to a given target address pub struct ProxyServer { - /// The underlying attested TLS server - attested_tls_server: AttestedTlsServer, + nesting_tls_acceptor: NestingTlsAcceptor, /// The underlying TCP listener listener: Arc, /// The address/hostname of the target service we are proxying to @@ -114,7 +121,7 @@ impl ProxyServer { attestation_verifier: AttestationVerifier, client_auth: bool, ) -> Result { - let mut server_config = if client_auth { + let outer_server_config = if client_auth { let root_store = RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); let verifier = WebPkiClientVerifier::builder(Arc::new(root_store)).build()?; @@ -133,46 +140,49 @@ impl ProxyServer { cert_and_key.key.clone_key(), )? }; - ensure_proxy_alpn_protocols(&mut server_config.alpn_protocols); - let attested_tls_server = AttestedTlsServer::new_with_tls_config( + Self::new_with_tls_config( cert_and_key.cert_chain, - server_config, + outer_server_config, + local, + target, attestation_generator, attestation_verifier, - )?; - - let listener = TcpListener::bind(local).await?; - - Ok(Self { - attested_tls_server, - listener: listener.into(), - target, - }) + ) + .await } /// Start with preconfigured TLS pub async fn new_with_tls_config( cert_chain: Vec>, - mut server_config: ServerConfig, + mut outer_server_config: ServerConfig, local: impl ToSocketAddrs, target: String, attestation_generator: AttestationGenerator, attestation_verifier: AttestationVerifier, ) -> Result { - ensure_proxy_alpn_protocols(&mut server_config.alpn_protocols); - - let attested_tls_server = AttestedTlsServer::new_with_tls_config( - cert_chain, - server_config, + ensure_proxy_alpn_protocols(&mut outer_server_config.alpn_protocols); + let server_name = hostname_from_cert(cert_chain.get(0).unwrap()).unwrap(); + let inner_cert_resolver = AttestedCertificateResolver::new( attestation_generator, - attestation_verifier, - )?; + None, + server_name.to_string(), // TODO get name from outer certificate + vec![], + ) + .await + .unwrap(); + + let inner_server_config = + ServerConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) + .with_no_client_auth() // TODO + .with_cert_resolver(Arc::new(inner_cert_resolver)); + let nesting_tls_acceptor = + NestingTlsAcceptor::new(Arc::new(outer_server_config), Arc::new(inner_server_config)); let listener = TcpListener::bind(local).await?; Ok(Self { - attested_tls_server, + nesting_tls_acceptor, listener: listener.into(), target, }) @@ -184,19 +194,12 @@ impl ProxyServer { pub async fn accept(&self) -> Result, ProxyError> { let target = self.target.clone(); let (inbound, client_addr) = self.listener.accept().await?; - let attested_tls_server = self.attested_tls_server.clone(); + let nesting_tls_acceptor = self.nesting_tls_acceptor.clone(); let join_handle = tokio::spawn(async move { - match attested_tls_server.handle_connection(inbound).await { - Ok((tls_stream, measurements, attestation_type)) => { - if let Err(err) = Self::handle_connection( - tls_stream, - measurements, - attestation_type, - target, - client_addr, - ) - .await + match nesting_tls_acceptor.accept(inbound).await { + Ok(tls_stream) => { + if let Err(err) = Self::handle_connection(tls_stream, target, client_addr).await { warn!("Failed to handle connection: {err}"); } @@ -217,13 +220,11 @@ impl ProxyServer { /// Handle an incoming connection from a proxy-client async fn handle_connection( - tls_stream: tokio_rustls::server::TlsStream, - measurements: Option, - remote_attestation_type: AttestationType, + tls_stream: NestingTlsStream, target: String, client_addr: SocketAddr, ) -> Result<(), ProxyError> { - debug!("[proxy-server] accepted connection with measurements: {measurements:?}"); + debug!("[proxy-server] accepted connection"); let http_version = HttpVersion::from_negotiated_protocol_server(&tls_stream); @@ -251,27 +252,6 @@ impl ProxyServer { update_header(headers, &X_FORWARDED_FOR, &new_x_forwarded_for); - // If we have measurements, from the remote peer, add them to the request header - let measurements = measurements.clone(); - if let Some(measurements) = measurements { - match measurements.to_header_format() { - Ok(header_value) => { - headers.insert(MEASUREMENT_HEADER, header_value); - } - Err(e) => { - // This error is highly unlikely - that the measurement values fail to - // encode to JSON or fit in an HTTP header - error!("Failed to encode measurement values: {e}"); - } - } - } - - update_header( - headers, - ATTESTATION_TYPE_HEADER, - remote_attestation_type.as_str(), - ); - let target = target.clone(); async move { match Self::handle_http_request(req, target).await { @@ -381,7 +361,7 @@ impl ProxyClient { None => RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()), }; - let mut client_config = if let Some(ref cert_and_key) = cert_and_key { + let mut outer_client_config = if let Some(ref cert_and_key) = cert_and_key { ClientConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) .with_root_certificates(root_store) .with_client_auth_cert( @@ -393,43 +373,48 @@ impl ProxyClient { .with_root_certificates(root_store) .with_no_client_auth() }; - ensure_proxy_alpn_protocols(&mut client_config.alpn_protocols); - let attested_tls_client = AttestedTlsClient::new_with_tls_config( - client_config, + Self::new_with_tls_config( + outer_client_config, + address, + server_name, attestation_generator, attestation_verifier, - cert_and_key.map(|c| c.cert_chain), - )?; - - Self::new_with_inner(address, attested_tls_client, &server_name).await + remote_certificate, + ) + .await } /// Create a new proxy client with given TLS configuration pub async fn new_with_tls_config( - mut client_config: ClientConfig, + mut outer_client_config: ClientConfig, address: impl ToSocketAddrs, target_name: String, attestation_generator: AttestationGenerator, attestation_verifier: AttestationVerifier, cert_chain: Option>>, ) -> Result { - ensure_proxy_alpn_protocols(&mut client_config.alpn_protocols); + ensure_proxy_alpn_protocols(&mut outer_client_config.alpn_protocols); - let attested_tls_client = AttestedTlsClient::new_with_tls_config( - client_config, - attestation_generator, - attestation_verifier, - cert_chain, - )?; + let attested_cert_verifier = + AttestedCertificateVerifier::new(None, attestation_verifier).unwrap(); + + let inner_client_config = + ClientConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) + .dangerous() + .with_custom_certificate_verifier(Arc::new(attested_cert_verifier)) + .with_no_client_auth(); + + let nesting_tls_connector = + NestingTlsConnector::new(Arc::new(outer_client_config), Arc::new(inner_client_config)); - Self::new_with_inner(address, attested_tls_client, &target_name).await + Self::new_with_inner(address, nesting_tls_connector, &target_name).await } /// Create a new proxy client with given [AttestedTlsClient] pub async fn new_with_inner( address: impl ToSocketAddrs, - attested_tls_client: AttestedTlsClient, + nesting_tls_connector: NestingTlsConnector, target_name: &str, ) -> Result { let listener = TcpListener::bind(address).await?; @@ -452,9 +437,9 @@ impl ProxyClient { let mut first = true; let mut ready_tx = Some(ready_tx); 'reconnect: loop { - let (mut sender, conn, measurements, remote_attestation_type) = + let (mut sender, conn) = // Connect to the proxy server and provide / verify attestation - match Self::setup_connection_with_backoff(&target, &attested_tls_client, first) + match Self::setup_connection_with_backoff(&target, &nesting_tls_connector, first) .await { Ok(output) => { @@ -494,29 +479,8 @@ impl ProxyClient { debug!("[proxy-client] Read incoming request from source client: {req:?}"); // Attempt to forward it to the proxy server let (response, should_reconnect) = match sender.send_request(req).await { - Ok(mut resp) => { + Ok(resp) => { debug!("[proxy-client] Read response from proxy-server: {resp:?}"); - // If we have measurements from the proxy-server, inject them into the - // response header - let headers = resp.headers_mut(); - if let Some(measurements) = measurements.clone() { - match measurements.to_header_format() { - Ok(header_value) => { - headers.insert(MEASUREMENT_HEADER, header_value); - } - Err(e) => { - // This error is highly unlikely - that the measurement values fail to - // encode to JSON or fit in an HTTP header - error!("Failed to encode measurement values: {e}"); - } - } - } - - update_header( - headers, - ATTESTATION_TYPE_HEADER, - remote_attestation_type.as_str(), - ); (Ok(resp.map(|b| b.boxed())), false) } Err(e) => { @@ -622,22 +586,14 @@ impl ProxyClient { // If it fails retry with a backoff (indefinately) async fn setup_connection_with_backoff( target: &str, - attested_tls_client: &AttestedTlsClient, + nesting_tls_connector: &NestingTlsConnector, should_bail: bool, - ) -> Result< - ( - HttpSender, - HttpConnection, - Option, - AttestationType, - ), - ProxyError, - > { + ) -> Result<(HttpSender, HttpConnection), ProxyError> { let mut delay = Duration::from_secs(1); let max_delay = Duration::from_secs(SERVER_RECONNECT_MAX_BACKOFF_SECS); loop { - match Self::setup_connection(attested_tls_client, target).await { + match Self::setup_connection(nesting_tls_connector, target).await { Ok(output) => { return Ok(output); } @@ -659,19 +615,16 @@ impl ProxyClient { /// Connect to the proxy-server, do TLS handshake and remote attestation async fn setup_connection( - inner: &AttestedTlsClient, + nesting_tls_connector: &NestingTlsConnector, target: &str, - ) -> Result< - ( - HttpSender, - HttpConnection, - Option, - AttestationType, - ), - ProxyError, - > { - let (tls_stream, measurements, remote_attestation_type) = inner.connect_tcp(target).await?; - debug!("[proxy-client] Connected to proxy server with measurements: {measurements:?}"); + ) -> Result<(HttpSender, HttpConnection), ProxyError> { + let outbound_stream = tokio::net::TcpStream::connect(target).await?; + + let domain = ServerName::try_from(target).unwrap(); + let tls_stream = nesting_tls_connector + .connect(domain, outbound_stream) + .await?; + debug!("[proxy-client] Connected to proxy server"); // The attestation exchange is now complete - setup an HTTP client let http_version = HttpVersion::from_negotiated_protocol_client(&tls_stream); @@ -697,7 +650,7 @@ impl ProxyClient { }; // Return the HTTP client, as well as remote measurements - Ok((sender, conn, measurements, remote_attestation_type)) + Ok((sender, conn)) } // Handle a request from the source client to the proxy server @@ -755,8 +708,8 @@ pub enum ProxyError { OneShotRecv(#[from] oneshot::error::RecvError), #[error("Failed to send request, connection to proxy-server dropped")] MpscSend, - #[error("Attested TLS: {0}")] - AttestedTls(#[from] AttestedTlsError), + // #[error("Attested TLS: {0}")] + // AttestedTls(#[from] AttestedTlsError), } impl From> for ProxyError { @@ -765,6 +718,24 @@ impl From> for ProxyError { } } +/// Given a certifcate, get the hostname +fn hostname_from_cert(cert: &CertificateDer<'static>) -> Result { + let cert = x509_parser::parse_x509_certificate(cert.as_ref()) + .map(|(_, parsed)| parsed) + .unwrap(); + + Ok(cert + .subject() + .iter_common_name() + .next() + .unwrap() + // .ok_or_else(|| Self::bad_encoding("Missing common name"))? + .as_str() + // .map_err(|err| Self::bad_encoding(format!("Invalid common name: {err}"))) + .unwrap() + .to_string()) +} + /// If no port was provided, default to 443 pub(crate) fn host_to_host_with_port(host: &str) -> String { if host.contains(':') { @@ -1437,6 +1408,7 @@ mod tests { .unwrap() .to_str() .unwrap(); + assert_eq!(attestation_type, AttestationType::DcapTdx.as_str()); let measurements_json = headers.get(MEASUREMENT_HEADER).unwrap().to_str().unwrap(); From cec9c9e95c152ff2225d14b8baa0337bd196120b Mon Sep 17 00:00:00 2001 From: peg Date: Tue, 17 Mar 2026 10:37:29 +0100 Subject: [PATCH 02/20] Basic tests pass --- src/http_version.rs | 11 +- src/lib.rs | 1223 ++++++++++++++++++------------------- src/{main.rs => main.rs_} | 0 src/test_helpers.rs | 41 +- 4 files changed, 634 insertions(+), 641 deletions(-) rename src/{main.rs => main.rs_} (100%) diff --git a/src/http_version.rs b/src/http_version.rs index bef817c..901df66 100644 --- a/src/http_version.rs +++ b/src/http_version.rs @@ -7,6 +7,9 @@ use std::task::{Context, Poll}; pub const ALPN_H2: &[u8] = b"h2"; pub const ALPN_HTTP11: &[u8] = b"http/1.1"; +type ProxyClientTlsStream = + tokio_rustls::client::TlsStream>; + /// Supported HTTP versions #[derive(Debug)] pub enum HttpVersion { @@ -55,13 +58,11 @@ impl HttpVersion { type Http1Sender = hyper::client::conn::http1::SendRequest; type Http2Sender = hyper::client::conn::http2::SendRequest; -type Http1Connection = hyper::client::conn::http1::Connection< - TokioIo>, - hyper::body::Incoming, ->; +type Http1Connection = + hyper::client::conn::http1::Connection, hyper::body::Incoming>; type Http2Connection = hyper::client::conn::http2::Connection< - TokioIo>, + TokioIo, hyper::body::Incoming, crate::TokioExecutor, >; diff --git a/src/lib.rs b/src/lib.rs index 3830941..1d48551 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,6 @@ //! An attested TLS protocol and HTTPS proxy // pub mod attested_get; -pub mod file_server; +// pub mod file_server; pub mod health_check; pub mod normalize_pem; // pub mod self_signed; @@ -13,9 +13,7 @@ mod http_version; #[cfg(test)] mod test_helpers; -use attestation::{ - AttestationError, AttestationType, AttestationVerifier, measurements::MultiMeasurements, -}; +use attestation::{AttestationError, AttestationVerifier}; use attested_tls::{AttestedCertificateResolver, AttestedCertificateVerifier}; use bytes::Bytes; use http::{HeaderMap, HeaderName, HeaderValue}; @@ -38,12 +36,6 @@ use tracing::{debug, error, warn}; use crate::http_version::{ALPN_H2, ALPN_HTTP11, HttpConnection, HttpSender, HttpVersion}; -/// The header name for giving attestation type -const ATTESTATION_TYPE_HEADER: &str = "X-Flashbots-Attestation-Type"; - -/// The header name for giving measurements -const MEASUREMENT_HEADER: &str = "X-Flashbots-Measurement"; - /// The header name for giving the forwarded for IP static X_FORWARDED_FOR: HeaderName = HeaderName::from_static("x-forwarded-for"); @@ -352,16 +344,16 @@ impl ProxyClient { attestation_verifier: AttestationVerifier, remote_certificate: Option>, ) -> Result { - let root_store = match remote_certificate { + let root_store = match remote_certificate.as_ref() { Some(remote_certificate) => { let mut root_store = RootCertStore::empty(); - root_store.add(remote_certificate)?; + root_store.add(remote_certificate.clone())?; root_store } None => RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()), }; - let mut outer_client_config = if let Some(ref cert_and_key) = cert_and_key { + let outer_client_config = if let Some(ref cert_and_key) = cert_and_key { ClientConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) .with_root_certificates(root_store) .with_client_auth_cert( @@ -380,7 +372,7 @@ impl ProxyClient { server_name, attestation_generator, attestation_verifier, - remote_certificate, + remote_certificate.map(|certificate| vec![certificate]), ) .await } @@ -620,7 +612,7 @@ impl ProxyClient { ) -> Result<(HttpSender, HttpConnection), ProxyError> { let outbound_stream = tokio::net::TcpStream::connect(target).await?; - let domain = ServerName::try_from(target).unwrap(); + let domain = server_name_from_host(target).unwrap(); let tls_stream = nesting_tls_connector .connect(domain, outbound_stream) .await?; @@ -745,6 +737,16 @@ pub(crate) fn host_to_host_with_port(host: &str) -> String { } } +/// Given a hostname with or without port number, create a TLS [ServerName] with just the host part +fn server_name_from_host( + host: &str, +) -> Result, tokio_rustls::rustls::pki_types::InvalidDnsNameError> { + let host_part = host.rsplit_once(':').map(|(h, _)| h).unwrap_or(host); + let host_part = host_part.trim_matches(|c| c == '[' || c == ']'); + + ServerName::try_from(host_part.to_string()) +} + /// An Executor for hyper that uses the tokio runtime #[derive(Clone)] pub(crate) struct TokioExecutor; @@ -763,13 +765,12 @@ where #[cfg(test)] mod tests { - use crate::{ - attestation::measurements::MeasurementPolicy, attested_tls::get_tls_cert_with_config, - }; + use attestation::{AttestationType, measurements::MeasurementPolicy}; use super::*; use test_helpers::{ - example_http_service, generate_certificate_chain, generate_tls_config, + example_http_service, generate_certificate_chain, generate_certificate_chain_for_host, + generate_tls_config, generate_tls_config_with_client_auth, init_tracing, mock_dcap_measurements, }; @@ -789,11 +790,11 @@ mod tests { assert_eq!(protocols, vec![ALPN_HTTP11.to_vec(), ALPN_H2.to_vec()]); } - #[tokio::test] + #[tokio::test(flavor = "multi_thread")] async fn http_proxy_default_constructors_work() { let target_addr = example_http_service().await; - let (cert_chain, private_key) = generate_certificate_chain("127.0.0.1".parse().unwrap()); + let (cert_chain, private_key) = generate_certificate_chain_for_host("localhost"); let server_cert = cert_chain[0].clone(); let proxy_server = ProxyServer::new( @@ -819,7 +820,7 @@ mod tests { let proxy_client = ProxyClient::new( None, "127.0.0.1:0".to_string(), - proxy_addr.to_string(), + format!("localhost:{}", proxy_addr.port()), AttestationGenerator::with_no_attestation(), AttestationVerifier::mock(), Some(server_cert), @@ -837,510 +838,17 @@ mod tests { .await .unwrap(); - let headers = res.headers(); - - let attestation_type = headers - .get(ATTESTATION_TYPE_HEADER) - .unwrap() - .to_str() - .unwrap(); - assert_eq!(attestation_type, AttestationType::DcapTdx.as_str()); - - let measurements_json = headers.get(MEASUREMENT_HEADER).unwrap().to_str().unwrap(); - let measurements = - MultiMeasurements::from_header_format(measurements_json, AttestationType::DcapTdx) - .unwrap(); - assert_eq!(measurements, mock_dcap_measurements()); - let res_body = res.text().await.unwrap(); assert_eq!(res_body, "No measurements"); } // Server has mock DCAP, client has no attestation and no client auth - #[tokio::test] + #[tokio::test(flavor = "multi_thread")] async fn http_proxy_with_server_attestation() { let _ = tracing_subscriber::fmt::try_init(); let target_addr = example_http_service().await; - let (cert_chain, private_key) = generate_certificate_chain("127.0.0.1".parse().unwrap()); - let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key); - - let proxy_server = ProxyServer::new_with_tls_config( - cert_chain, - server_config, - "127.0.0.1:0", - target_addr.to_string(), - AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), - AttestationVerifier::expect_none(), - ) - .await - .unwrap(); - - let proxy_addr = proxy_server.local_addr().unwrap(); - - tokio::spawn(async move { - proxy_server.accept().await.unwrap(); - }); - - let proxy_client = ProxyClient::new_with_tls_config( - client_config, - "127.0.0.1:0".to_string(), - proxy_addr.to_string(), - AttestationGenerator::with_no_attestation(), - AttestationVerifier::mock(), - None, - ) - .await - .unwrap(); - - let proxy_client_addr = proxy_client.local_addr().unwrap(); - - tokio::spawn(async move { - proxy_client.accept().await.unwrap(); - }); - - let res = reqwest::get(format!("http://{}", proxy_client_addr.to_string())) - .await - .unwrap(); - - let headers = res.headers(); - - let attestation_type = headers - .get(ATTESTATION_TYPE_HEADER) - .unwrap() - .to_str() - .unwrap(); - assert_eq!(attestation_type, AttestationType::DcapTdx.as_str()); - - let measurements_json = headers.get(MEASUREMENT_HEADER).unwrap().to_str().unwrap(); - let measurements = - MultiMeasurements::from_header_format(measurements_json, AttestationType::DcapTdx) - .unwrap(); - assert_eq!(measurements, mock_dcap_measurements()); - - let res_body = res.text().await.unwrap(); - assert_eq!(res_body, "No measurements"); - } - - // Server has no attestation, client has mock DCAP and client auth - #[tokio::test] - async fn http_proxy_client_attestation() { - let target_addr = example_http_service().await; - - let (server_cert_chain, server_private_key) = - generate_certificate_chain("127.0.0.1".parse().unwrap()); - let (client_cert_chain, client_private_key) = - generate_certificate_chain("127.0.0.1".parse().unwrap()); - - let ( - (_client_tls_server_config, client_tls_client_config), - (server_tls_server_config, _server_tls_client_config), - ) = generate_tls_config_with_client_auth( - client_cert_chain.clone(), - client_private_key, - server_cert_chain.clone(), - server_private_key, - ); - - let proxy_server = ProxyServer::new_with_tls_config( - server_cert_chain, - server_tls_server_config, - "127.0.0.1:0", - target_addr.to_string(), - AttestationGenerator::with_no_attestation(), - AttestationVerifier::mock(), - ) - .await - .unwrap(); - - let proxy_addr = proxy_server.local_addr().unwrap(); - - tokio::spawn(async move { - // Accept one connection, then finish - proxy_server.accept().await.unwrap(); - }); - - let proxy_client = ProxyClient::new_with_tls_config( - client_tls_client_config, - "127.0.0.1:0", - proxy_addr.to_string(), - AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), - AttestationVerifier::expect_none(), - Some(client_cert_chain), - ) - .await - .unwrap(); - - let proxy_client_addr = proxy_client.local_addr().unwrap(); - - tokio::spawn(async move { - // Accept two connections, then finish - proxy_client.accept().await.unwrap(); - proxy_client.accept().await.unwrap(); - }); - - let res = reqwest::get(format!("http://{}", proxy_client_addr.to_string())) - .await - .unwrap(); - - // We expect no measurements from the server - let headers = res.headers(); - assert!(headers.get(MEASUREMENT_HEADER).is_none()); - - let attestation_type = headers - .get(ATTESTATION_TYPE_HEADER) - .unwrap() - .to_str() - .unwrap(); - assert_eq!(attestation_type, AttestationType::None.as_str()); - - let res_body = res.text().await.unwrap(); - - // The response body shows us what was in the request header (as the test http server - // handler puts them there) - let measurements = - MultiMeasurements::from_header_format(&res_body, AttestationType::DcapTdx).unwrap(); - assert_eq!(measurements, mock_dcap_measurements()); - } - - // Server has no attestation, client has mock DCAP but no client auth - #[tokio::test] - async fn http_proxy_client_attestation_no_client_auth() { - let target_addr = example_http_service().await; - - let (server_cert_chain, server_private_key) = - generate_certificate_chain("127.0.0.1".parse().unwrap()); - let (server_config, client_config) = - generate_tls_config(server_cert_chain.clone(), server_private_key); - - let proxy_server = ProxyServer::new_with_tls_config( - server_cert_chain, - server_config, - "127.0.0.1:0", - target_addr.to_string(), - AttestationGenerator::with_no_attestation(), - AttestationVerifier::mock(), - ) - .await - .unwrap(); - - let proxy_addr = proxy_server.local_addr().unwrap(); - - tokio::spawn(async move { - // Accept one connection, then finish - proxy_server.accept().await.unwrap(); - }); - - let proxy_client = ProxyClient::new_with_tls_config( - client_config, - "127.0.0.1:0", - proxy_addr.to_string(), - AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), - AttestationVerifier::expect_none(), - None, - ) - .await - .unwrap(); - - let proxy_client_addr = proxy_client.local_addr().unwrap(); - - tokio::spawn(async move { - // Accept two connections, then finish - proxy_client.accept().await.unwrap(); - proxy_client.accept().await.unwrap(); - }); - - let res = reqwest::get(format!("http://{}", proxy_client_addr.to_string())) - .await - .unwrap(); - - // We expect no measurements from the server - let headers = res.headers(); - assert!(headers.get(MEASUREMENT_HEADER).is_none()); - - let attestation_type = headers - .get(ATTESTATION_TYPE_HEADER) - .unwrap() - .to_str() - .unwrap(); - assert_eq!(attestation_type, AttestationType::None.as_str()); - - let res_body = res.text().await.unwrap(); - - // The response body shows us what was in the request header (as the test http server - // handler puts them there) - let measurements = - MultiMeasurements::from_header_format(&res_body, AttestationType::DcapTdx).unwrap(); - assert_eq!(measurements, mock_dcap_measurements()); - } - - // Server has mock DCAP, client has mock DCAP and client auth - #[tokio::test] - async fn http_proxy_mutual_attestation() { - let target_addr = example_http_service().await; - - let (server_cert_chain, server_private_key) = - generate_certificate_chain("127.0.0.1".parse().unwrap()); - let (client_cert_chain, client_private_key) = - generate_certificate_chain("127.0.0.1".parse().unwrap()); - - let ( - (_client_tls_server_config, client_tls_client_config), - (server_tls_server_config, _server_tls_client_config), - ) = generate_tls_config_with_client_auth( - client_cert_chain.clone(), - client_private_key, - server_cert_chain.clone(), - server_private_key, - ); - - let proxy_server = ProxyServer::new_with_tls_config( - server_cert_chain, - server_tls_server_config, - "127.0.0.1:0", - target_addr.to_string(), - AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), - AttestationVerifier::mock(), - ) - .await - .unwrap(); - - let proxy_addr = proxy_server.local_addr().unwrap(); - - tokio::spawn(async move { - // Accept one connection, then finish - proxy_server.accept().await.unwrap(); - }); - - let proxy_client = ProxyClient::new_with_tls_config( - client_tls_client_config, - "127.0.0.1:0", - proxy_addr.to_string(), - AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), - AttestationVerifier::mock(), - Some(client_cert_chain), - ) - .await - .unwrap(); - - let proxy_client_addr = proxy_client.local_addr().unwrap(); - - tokio::spawn(async move { - // Accept two connections, then finish - proxy_client.accept().await.unwrap(); - proxy_client.accept().await.unwrap(); - }); - - let res = reqwest::get(format!("http://{}", proxy_client_addr.to_string())) - .await - .unwrap(); - - let headers = res.headers(); - let measurements_json = headers.get(MEASUREMENT_HEADER).unwrap().to_str().unwrap(); - let measurements = - MultiMeasurements::from_header_format(measurements_json, AttestationType::DcapTdx) - .unwrap(); - assert_eq!(measurements, mock_dcap_measurements()); - - let attestation_type = headers - .get(ATTESTATION_TYPE_HEADER) - .unwrap() - .to_str() - .unwrap(); - assert_eq!(attestation_type, AttestationType::DcapTdx.as_str()); - - let res_body = res.text().await.unwrap(); - - // The response body shows us what was in the request header (as the test http server - // handler puts them there) - let measurements = - MultiMeasurements::from_header_format(&res_body, AttestationType::DcapTdx).unwrap(); - assert_eq!(measurements, mock_dcap_measurements()); - - // Now do another request - to check that the connection has stayed open - let res = reqwest::get(format!("http://{}", proxy_client_addr.to_string())) - .await - .unwrap(); - - let headers = res.headers(); - let measurements_json = headers.get(MEASUREMENT_HEADER).unwrap().to_str().unwrap(); - let measurements = - MultiMeasurements::from_header_format(measurements_json, AttestationType::DcapTdx) - .unwrap(); - assert_eq!(measurements, mock_dcap_measurements()); - - let attestation_type = headers - .get(ATTESTATION_TYPE_HEADER) - .unwrap() - .to_str() - .unwrap(); - assert_eq!(attestation_type, AttestationType::DcapTdx.as_str()); - - let res_body = res.text().await.unwrap(); - - // The response body shows us what was in the request header (as the test http server - // handler puts them there) - let measurements = - MultiMeasurements::from_header_format(&res_body, AttestationType::DcapTdx).unwrap(); - assert_eq!(measurements, mock_dcap_measurements()); - } - - // Server has mock DCAP, client no attestation - just get the server certificate - #[tokio::test] - async fn test_get_tls_cert() { - let target_addr = example_http_service().await; - - let (cert_chain, private_key) = generate_certificate_chain("127.0.0.1".parse().unwrap()); - let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key); - - let proxy_server = ProxyServer::new_with_tls_config( - cert_chain.clone(), - server_config, - "127.0.0.1:0", - target_addr.to_string(), - AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), - AttestationVerifier::expect_none(), - ) - .await - .unwrap(); - - let proxy_server_addr = proxy_server.local_addr().unwrap(); - - tokio::spawn(async move { - proxy_server.accept().await.unwrap(); - }); - - let (retrieved_chain, _measurements) = get_tls_cert_with_config( - &proxy_server_addr.to_string(), - AttestationVerifier::mock(), - client_config, - ) - .await - .unwrap(); - - assert_eq!(retrieved_chain, cert_chain); - } - - // Negative test - server does not provide attestation but client requires it - // Server has no attestaion, client has no attestation and no client auth - #[tokio::test] - async fn fails_on_no_attestation_when_expected() { - let target_addr = example_http_service().await; - - let (cert_chain, private_key) = generate_certificate_chain("127.0.0.1".parse().unwrap()); - let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key); - - let proxy_server = ProxyServer::new_with_tls_config( - cert_chain, - server_config, - "127.0.0.1:0", - target_addr.to_string(), - AttestationGenerator::with_no_attestation(), - AttestationVerifier::expect_none(), - ) - .await - .unwrap(); - - let proxy_addr = proxy_server.local_addr().unwrap(); - - tokio::spawn(async move { - proxy_server.accept().await.unwrap(); - }); - - let proxy_client_result = ProxyClient::new_with_tls_config( - client_config, - "127.0.0.1:0".to_string(), - proxy_addr.to_string(), - AttestationGenerator::with_no_attestation(), - AttestationVerifier::mock(), - None, - ) - .await; - - assert!(matches!( - proxy_client_result.unwrap_err(), - ProxyError::AttestedTls(AttestedTlsError::Attestation( - AttestationError::AttestationTypeNotAccepted - )) - )); - } - - // Negative test - server does not provide attestation but client requires it - // Server has no attestaion, client has no attestation and no client auth - #[tokio::test] - async fn fails_on_bad_measurements() { - let target_addr = example_http_service().await; - - let (cert_chain, private_key) = generate_certificate_chain("127.0.0.1".parse().unwrap()); - let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key); - - let proxy_server = ProxyServer::new_with_tls_config( - cert_chain, - server_config, - "127.0.0.1:0", - target_addr.to_string(), - AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), - AttestationVerifier::expect_none(), - ) - .await - .unwrap(); - - let proxy_addr = proxy_server.local_addr().unwrap(); - - tokio::spawn(async move { - proxy_server.accept().await.unwrap(); - }); - - let measurement_policy = MeasurementPolicy::from_json_bytes( - br#" - [{ - "measurement_id": "test", - "attestation_type": "dcap-tdx", - "measurements": { - "0": { "expected": "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000" }, - "1": { "expected": "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000" }, - "2": { "expected": "010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101" }, - "3": { "expected": "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000" }, - "4": { "expected": "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000" } - } - }] - "# - .to_vec(), - ) - .unwrap(); - - let attestation_verifier = AttestationVerifier { - measurement_policy, - pccs_url: None, - log_dcap_quote: false, - override_azure_outdated_tcb: false, - }; - - let proxy_client_result = ProxyClient::new_with_tls_config( - client_config, - "127.0.0.1:0".to_string(), - proxy_addr.to_string(), - AttestationGenerator::with_no_attestation(), - attestation_verifier, - None, - ) - .await; - - assert!(matches!( - proxy_client_result.unwrap_err(), - ProxyError::AttestedTls(AttestedTlsError::Attestation( - AttestationError::MeasurementsNotAccepted - )) - )); - } - - #[tokio::test] - async fn http_proxy_client_reconnects_on_lost_connection() { - init_tracing(); - - let target_addr = example_http_service().await; - - let (cert_chain, private_key) = generate_certificate_chain("127.0.0.1".parse().unwrap()); + let (cert_chain, private_key) = generate_certificate_chain_for_host("localhost"); let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key); let proxy_server = ProxyServer::new_with_tls_config( @@ -1356,25 +864,14 @@ mod tests { let proxy_addr = proxy_server.local_addr().unwrap(); - // This is used to trigger a dropped connection to the proxy server - let (connection_breaker_tx, connection_breaker_rx) = oneshot::channel(); - tokio::spawn(async move { - let connection_handle = proxy_server.accept().await.unwrap(); - - // Wait for a signal to simulate a dropped connection, then drop the task handling the - // connection - connection_breaker_rx.await.unwrap(); - connection_handle.abort(); - - // Now accept another connection proxy_server.accept().await.unwrap(); }); let proxy_client = ProxyClient::new_with_tls_config( client_config, "127.0.0.1:0".to_string(), - proxy_addr.to_string(), + format!("localhost:{}", proxy_addr.port()), AttestationGenerator::with_no_attestation(), AttestationVerifier::mock(), None, @@ -1386,111 +883,585 @@ mod tests { tokio::spawn(async move { proxy_client.accept().await.unwrap(); - proxy_client.accept().await.unwrap(); }); - let _initial_response = reqwest::get(format!("http://{}", proxy_client_addr.to_string())) - .await - .unwrap(); - - // Now break the connection - connection_breaker_tx.send(()).unwrap(); - - // Make another request let res = reqwest::get(format!("http://{}", proxy_client_addr.to_string())) .await .unwrap(); - let headers = res.headers(); - - let attestation_type = headers - .get(ATTESTATION_TYPE_HEADER) - .unwrap() - .to_str() - .unwrap(); - - assert_eq!(attestation_type, AttestationType::DcapTdx.as_str()); - - let measurements_json = headers.get(MEASUREMENT_HEADER).unwrap().to_str().unwrap(); - let measurements = - MultiMeasurements::from_header_format(measurements_json, AttestationType::DcapTdx) - .unwrap(); - assert_eq!(measurements, mock_dcap_measurements()); - let res_body = res.text().await.unwrap(); assert_eq!(res_body, "No measurements"); } - // Use HTTP 1.1 - #[tokio::test] - async fn http_proxy_with_http1() { - let target_addr = example_http_service().await; - - let (cert_chain, private_key) = generate_certificate_chain("127.0.0.1".parse().unwrap()); - let (mut server_config, client_config) = - generate_tls_config(cert_chain.clone(), private_key); - - server_config.alpn_protocols.push(ALPN_HTTP11.to_vec()); - - let attested_tls_server = AttestedTlsServer::new_with_tls_config( - cert_chain, - server_config, - AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), - AttestationVerifier::expect_none(), - ) - .unwrap(); - - let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); - - let proxy_server = ProxyServer { - attested_tls_server, - listener: listener.into(), - target: target_addr.to_string(), - }; - - let proxy_addr = proxy_server.local_addr().unwrap(); - - tokio::spawn(async move { - proxy_server.accept().await.unwrap(); - }); - - let proxy_client = ProxyClient::new_with_tls_config( - client_config, - "127.0.0.1:0".to_string(), - proxy_addr.to_string(), - AttestationGenerator::with_no_attestation(), - AttestationVerifier::mock(), - None, - ) - .await - .unwrap(); - - let proxy_client_addr = proxy_client.local_addr().unwrap(); - - tokio::spawn(async move { - proxy_client.accept().await.unwrap(); - }); - - let res = reqwest::get(format!("http://{}", proxy_client_addr.to_string())) - .await - .unwrap(); - - let headers = res.headers(); - - let attestation_type = headers - .get(ATTESTATION_TYPE_HEADER) - .unwrap() - .to_str() - .unwrap(); - assert_eq!(attestation_type, AttestationType::DcapTdx.as_str()); - - let measurements_json = headers.get(MEASUREMENT_HEADER).unwrap().to_str().unwrap(); - let measurements = - MultiMeasurements::from_header_format(measurements_json, AttestationType::DcapTdx) - .unwrap(); - assert_eq!(measurements, mock_dcap_measurements()); - - let res_body = res.text().await.unwrap(); - assert_eq!(res_body, "No measurements"); - } + // // Server has no attestation, client has mock DCAP and client auth + // #[tokio::test] + // async fn http_proxy_client_attestation() { + // let target_addr = example_http_service().await; + // + // let (server_cert_chain, server_private_key) = + // generate_certificate_chain("127.0.0.1".parse().unwrap()); + // let (client_cert_chain, client_private_key) = + // generate_certificate_chain("127.0.0.1".parse().unwrap()); + // + // let ( + // (_client_tls_server_config, client_tls_client_config), + // (server_tls_server_config, _server_tls_client_config), + // ) = generate_tls_config_with_client_auth( + // client_cert_chain.clone(), + // client_private_key, + // server_cert_chain.clone(), + // server_private_key, + // ); + // + // let proxy_server = ProxyServer::new_with_tls_config( + // server_cert_chain, + // server_tls_server_config, + // "127.0.0.1:0", + // target_addr.to_string(), + // AttestationGenerator::with_no_attestation(), + // AttestationVerifier::mock(), + // ) + // .await + // .unwrap(); + // + // let proxy_addr = proxy_server.local_addr().unwrap(); + // + // tokio::spawn(async move { + // // Accept one connection, then finish + // proxy_server.accept().await.unwrap(); + // }); + // + // let proxy_client = ProxyClient::new_with_tls_config( + // client_tls_client_config, + // "127.0.0.1:0", + // proxy_addr.to_string(), + // AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), + // AttestationVerifier::expect_none(), + // Some(client_cert_chain), + // ) + // .await + // .unwrap(); + // + // let proxy_client_addr = proxy_client.local_addr().unwrap(); + // + // tokio::spawn(async move { + // // Accept two connections, then finish + // proxy_client.accept().await.unwrap(); + // proxy_client.accept().await.unwrap(); + // }); + // + // let res = reqwest::get(format!("http://{}", proxy_client_addr.to_string())) + // .await + // .unwrap(); + // + // // We expect no measurements from the server + // let headers = res.headers(); + // assert!(headers.get(MEASUREMENT_HEADER).is_none()); + // + // let attestation_type = headers + // .get(ATTESTATION_TYPE_HEADER) + // .unwrap() + // .to_str() + // .unwrap(); + // assert_eq!(attestation_type, AttestationType::None.as_str()); + // + // let res_body = res.text().await.unwrap(); + // + // // The response body shows us what was in the request header (as the test http server + // // handler puts them there) + // let measurements = + // MultiMeasurements::from_header_format(&res_body, AttestationType::DcapTdx).unwrap(); + // assert_eq!(measurements, mock_dcap_measurements()); + // } + // + // // Server has no attestation, client has mock DCAP but no client auth + // #[tokio::test] + // async fn http_proxy_client_attestation_no_client_auth() { + // let target_addr = example_http_service().await; + // + // let (server_cert_chain, server_private_key) = + // generate_certificate_chain("127.0.0.1".parse().unwrap()); + // let (server_config, client_config) = + // generate_tls_config(server_cert_chain.clone(), server_private_key); + // + // let proxy_server = ProxyServer::new_with_tls_config( + // server_cert_chain, + // server_config, + // "127.0.0.1:0", + // target_addr.to_string(), + // AttestationGenerator::with_no_attestation(), + // AttestationVerifier::mock(), + // ) + // .await + // .unwrap(); + // + // let proxy_addr = proxy_server.local_addr().unwrap(); + // + // tokio::spawn(async move { + // // Accept one connection, then finish + // proxy_server.accept().await.unwrap(); + // }); + // + // let proxy_client = ProxyClient::new_with_tls_config( + // client_config, + // "127.0.0.1:0", + // proxy_addr.to_string(), + // AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), + // AttestationVerifier::expect_none(), + // None, + // ) + // .await + // .unwrap(); + // + // let proxy_client_addr = proxy_client.local_addr().unwrap(); + // + // tokio::spawn(async move { + // // Accept two connections, then finish + // proxy_client.accept().await.unwrap(); + // proxy_client.accept().await.unwrap(); + // }); + // + // let res = reqwest::get(format!("http://{}", proxy_client_addr.to_string())) + // .await + // .unwrap(); + // + // // We expect no measurements from the server + // let headers = res.headers(); + // assert!(headers.get(MEASUREMENT_HEADER).is_none()); + // + // let attestation_type = headers + // .get(ATTESTATION_TYPE_HEADER) + // .unwrap() + // .to_str() + // .unwrap(); + // assert_eq!(attestation_type, AttestationType::None.as_str()); + // + // let res_body = res.text().await.unwrap(); + // + // // The response body shows us what was in the request header (as the test http server + // // handler puts them there) + // let measurements = + // MultiMeasurements::from_header_format(&res_body, AttestationType::DcapTdx).unwrap(); + // assert_eq!(measurements, mock_dcap_measurements()); + // } + // + // // Server has mock DCAP, client has mock DCAP and client auth + // #[tokio::test] + // async fn http_proxy_mutual_attestation() { + // let target_addr = example_http_service().await; + // + // let (server_cert_chain, server_private_key) = + // generate_certificate_chain("127.0.0.1".parse().unwrap()); + // let (client_cert_chain, client_private_key) = + // generate_certificate_chain("127.0.0.1".parse().unwrap()); + // + // let ( + // (_client_tls_server_config, client_tls_client_config), + // (server_tls_server_config, _server_tls_client_config), + // ) = generate_tls_config_with_client_auth( + // client_cert_chain.clone(), + // client_private_key, + // server_cert_chain.clone(), + // server_private_key, + // ); + // + // let proxy_server = ProxyServer::new_with_tls_config( + // server_cert_chain, + // server_tls_server_config, + // "127.0.0.1:0", + // target_addr.to_string(), + // AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), + // AttestationVerifier::mock(), + // ) + // .await + // .unwrap(); + // + // let proxy_addr = proxy_server.local_addr().unwrap(); + // + // tokio::spawn(async move { + // // Accept one connection, then finish + // proxy_server.accept().await.unwrap(); + // }); + // + // let proxy_client = ProxyClient::new_with_tls_config( + // client_tls_client_config, + // "127.0.0.1:0", + // proxy_addr.to_string(), + // AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), + // AttestationVerifier::mock(), + // Some(client_cert_chain), + // ) + // .await + // .unwrap(); + // + // let proxy_client_addr = proxy_client.local_addr().unwrap(); + // + // tokio::spawn(async move { + // // Accept two connections, then finish + // proxy_client.accept().await.unwrap(); + // proxy_client.accept().await.unwrap(); + // }); + // + // let res = reqwest::get(format!("http://{}", proxy_client_addr.to_string())) + // .await + // .unwrap(); + // + // let headers = res.headers(); + // let measurements_json = headers.get(MEASUREMENT_HEADER).unwrap().to_str().unwrap(); + // let measurements = + // MultiMeasurements::from_header_format(measurements_json, AttestationType::DcapTdx) + // .unwrap(); + // assert_eq!(measurements, mock_dcap_measurements()); + // + // let attestation_type = headers + // .get(ATTESTATION_TYPE_HEADER) + // .unwrap() + // .to_str() + // .unwrap(); + // assert_eq!(attestation_type, AttestationType::DcapTdx.as_str()); + // + // let res_body = res.text().await.unwrap(); + // + // // The response body shows us what was in the request header (as the test http server + // // handler puts them there) + // let measurements = + // MultiMeasurements::from_header_format(&res_body, AttestationType::DcapTdx).unwrap(); + // assert_eq!(measurements, mock_dcap_measurements()); + // + // // Now do another request - to check that the connection has stayed open + // let res = reqwest::get(format!("http://{}", proxy_client_addr.to_string())) + // .await + // .unwrap(); + // + // let headers = res.headers(); + // let measurements_json = headers.get(MEASUREMENT_HEADER).unwrap().to_str().unwrap(); + // let measurements = + // MultiMeasurements::from_header_format(measurements_json, AttestationType::DcapTdx) + // .unwrap(); + // assert_eq!(measurements, mock_dcap_measurements()); + // + // let attestation_type = headers + // .get(ATTESTATION_TYPE_HEADER) + // .unwrap() + // .to_str() + // .unwrap(); + // assert_eq!(attestation_type, AttestationType::DcapTdx.as_str()); + // + // let res_body = res.text().await.unwrap(); + // + // // The response body shows us what was in the request header (as the test http server + // // handler puts them there) + // let measurements = + // MultiMeasurements::from_header_format(&res_body, AttestationType::DcapTdx).unwrap(); + // assert_eq!(measurements, mock_dcap_measurements()); + // } + // + // // Server has mock DCAP, client no attestation - just get the server certificate + // #[tokio::test] + // async fn test_get_tls_cert() { + // let target_addr = example_http_service().await; + // + // let (cert_chain, private_key) = generate_certificate_chain("127.0.0.1".parse().unwrap()); + // let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key); + // + // let proxy_server = ProxyServer::new_with_tls_config( + // cert_chain.clone(), + // server_config, + // "127.0.0.1:0", + // target_addr.to_string(), + // AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), + // AttestationVerifier::expect_none(), + // ) + // .await + // .unwrap(); + // + // let proxy_server_addr = proxy_server.local_addr().unwrap(); + // + // tokio::spawn(async move { + // proxy_server.accept().await.unwrap(); + // }); + // + // let (retrieved_chain, _measurements) = get_tls_cert_with_config( + // &proxy_server_addr.to_string(), + // AttestationVerifier::mock(), + // client_config, + // ) + // .await + // .unwrap(); + // + // assert_eq!(retrieved_chain, cert_chain); + // } + // + // // Negative test - server does not provide attestation but client requires it + // // Server has no attestaion, client has no attestation and no client auth + // #[tokio::test] + // async fn fails_on_no_attestation_when_expected() { + // let target_addr = example_http_service().await; + // + // let (cert_chain, private_key) = generate_certificate_chain("127.0.0.1".parse().unwrap()); + // let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key); + // + // let proxy_server = ProxyServer::new_with_tls_config( + // cert_chain, + // server_config, + // "127.0.0.1:0", + // target_addr.to_string(), + // AttestationGenerator::with_no_attestation(), + // AttestationVerifier::expect_none(), + // ) + // .await + // .unwrap(); + // + // let proxy_addr = proxy_server.local_addr().unwrap(); + // + // tokio::spawn(async move { + // proxy_server.accept().await.unwrap(); + // }); + // + // let proxy_client_result = ProxyClient::new_with_tls_config( + // client_config, + // "127.0.0.1:0".to_string(), + // proxy_addr.to_string(), + // AttestationGenerator::with_no_attestation(), + // AttestationVerifier::mock(), + // None, + // ) + // .await; + // + // assert!(matches!( + // proxy_client_result.unwrap_err(), + // ProxyError::AttestedTls(AttestedTlsError::Attestation( + // AttestationError::AttestationTypeNotAccepted + // )) + // )); + // } + // + // // Negative test - server does not provide attestation but client requires it + // // Server has no attestaion, client has no attestation and no client auth + // #[tokio::test] + // async fn fails_on_bad_measurements() { + // let target_addr = example_http_service().await; + // + // let (cert_chain, private_key) = generate_certificate_chain("127.0.0.1".parse().unwrap()); + // let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key); + // + // let proxy_server = ProxyServer::new_with_tls_config( + // cert_chain, + // server_config, + // "127.0.0.1:0", + // target_addr.to_string(), + // AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), + // AttestationVerifier::expect_none(), + // ) + // .await + // .unwrap(); + // + // let proxy_addr = proxy_server.local_addr().unwrap(); + // + // tokio::spawn(async move { + // proxy_server.accept().await.unwrap(); + // }); + // + // let measurement_policy = MeasurementPolicy::from_json_bytes( + // br#" + // [{ + // "measurement_id": "test", + // "attestation_type": "dcap-tdx", + // "measurements": { + // "0": { "expected": "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000" }, + // "1": { "expected": "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000" }, + // "2": { "expected": "010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101" }, + // "3": { "expected": "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000" }, + // "4": { "expected": "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000" } + // } + // }] + // "# + // .to_vec(), + // ) + // .unwrap(); + // + // let attestation_verifier = AttestationVerifier { + // measurement_policy, + // pccs_url: None, + // log_dcap_quote: false, + // override_azure_outdated_tcb: false, + // }; + // + // let proxy_client_result = ProxyClient::new_with_tls_config( + // client_config, + // "127.0.0.1:0".to_string(), + // proxy_addr.to_string(), + // AttestationGenerator::with_no_attestation(), + // attestation_verifier, + // None, + // ) + // .await; + // + // assert!(matches!( + // proxy_client_result.unwrap_err(), + // ProxyError::AttestedTls(AttestedTlsError::Attestation( + // AttestationError::MeasurementsNotAccepted + // )) + // )); + // } + // + // #[tokio::test] + // async fn http_proxy_client_reconnects_on_lost_connection() { + // init_tracing(); + // + // let target_addr = example_http_service().await; + // + // let (cert_chain, private_key) = generate_certificate_chain("127.0.0.1".parse().unwrap()); + // let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key); + // + // let proxy_server = ProxyServer::new_with_tls_config( + // cert_chain, + // server_config, + // "127.0.0.1:0", + // target_addr.to_string(), + // AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), + // AttestationVerifier::expect_none(), + // ) + // .await + // .unwrap(); + // + // let proxy_addr = proxy_server.local_addr().unwrap(); + // + // // This is used to trigger a dropped connection to the proxy server + // let (connection_breaker_tx, connection_breaker_rx) = oneshot::channel(); + // + // tokio::spawn(async move { + // let connection_handle = proxy_server.accept().await.unwrap(); + // + // // Wait for a signal to simulate a dropped connection, then drop the task handling the + // // connection + // connection_breaker_rx.await.unwrap(); + // connection_handle.abort(); + // + // // Now accept another connection + // proxy_server.accept().await.unwrap(); + // }); + // + // let proxy_client = ProxyClient::new_with_tls_config( + // client_config, + // "127.0.0.1:0".to_string(), + // proxy_addr.to_string(), + // AttestationGenerator::with_no_attestation(), + // AttestationVerifier::mock(), + // None, + // ) + // .await + // .unwrap(); + // + // let proxy_client_addr = proxy_client.local_addr().unwrap(); + // + // tokio::spawn(async move { + // proxy_client.accept().await.unwrap(); + // proxy_client.accept().await.unwrap(); + // }); + // + // let _initial_response = reqwest::get(format!("http://{}", proxy_client_addr.to_string())) + // .await + // .unwrap(); + // + // // Now break the connection + // connection_breaker_tx.send(()).unwrap(); + // + // // Make another request + // let res = reqwest::get(format!("http://{}", proxy_client_addr.to_string())) + // .await + // .unwrap(); + // + // let headers = res.headers(); + // + // let attestation_type = headers + // .get(ATTESTATION_TYPE_HEADER) + // .unwrap() + // .to_str() + // .unwrap(); + // + // assert_eq!(attestation_type, AttestationType::DcapTdx.as_str()); + // + // let measurements_json = headers.get(MEASUREMENT_HEADER).unwrap().to_str().unwrap(); + // let measurements = + // MultiMeasurements::from_header_format(measurements_json, AttestationType::DcapTdx) + // .unwrap(); + // assert_eq!(measurements, mock_dcap_measurements()); + // + // let res_body = res.text().await.unwrap(); + // assert_eq!(res_body, "No measurements"); + // } + // + // // Use HTTP 1.1 + // #[tokio::test] + // async fn http_proxy_with_http1() { + // let target_addr = example_http_service().await; + // + // let (cert_chain, private_key) = generate_certificate_chain("127.0.0.1".parse().unwrap()); + // let (mut server_config, client_config) = + // generate_tls_config(cert_chain.clone(), private_key); + // + // server_config.alpn_protocols.push(ALPN_HTTP11.to_vec()); + // + // let attested_tls_server = AttestedTlsServer::new_with_tls_config( + // cert_chain, + // server_config, + // AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), + // AttestationVerifier::expect_none(), + // ) + // .unwrap(); + // + // let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + // + // let proxy_server = ProxyServer { + // attested_tls_server, + // listener: listener.into(), + // target: target_addr.to_string(), + // }; + // + // let proxy_addr = proxy_server.local_addr().unwrap(); + // + // tokio::spawn(async move { + // proxy_server.accept().await.unwrap(); + // }); + // + // let proxy_client = ProxyClient::new_with_tls_config( + // client_config, + // "127.0.0.1:0".to_string(), + // proxy_addr.to_string(), + // AttestationGenerator::with_no_attestation(), + // AttestationVerifier::mock(), + // None, + // ) + // .await + // .unwrap(); + // + // let proxy_client_addr = proxy_client.local_addr().unwrap(); + // + // tokio::spawn(async move { + // proxy_client.accept().await.unwrap(); + // }); + // + // let res = reqwest::get(format!("http://{}", proxy_client_addr.to_string())) + // .await + // .unwrap(); + // + // let headers = res.headers(); + // + // let attestation_type = headers + // .get(ATTESTATION_TYPE_HEADER) + // .unwrap() + // .to_str() + // .unwrap(); + // assert_eq!(attestation_type, AttestationType::DcapTdx.as_str()); + // + // let measurements_json = headers.get(MEASUREMENT_HEADER).unwrap().to_str().unwrap(); + // let measurements = + // MultiMeasurements::from_header_format(measurements_json, AttestationType::DcapTdx) + // .unwrap(); + // assert_eq!(measurements, mock_dcap_measurements()); + // + // let res_body = res.text().await.unwrap(); + // assert_eq!(res_body, "No measurements"); + // } } diff --git a/src/main.rs b/src/main.rs_ similarity index 100% rename from src/main.rs rename to src/main.rs_ diff --git a/src/test_helpers.rs b/src/test_helpers.rs index 8990734..66ddc60 100644 --- a/src/test_helpers.rs +++ b/src/test_helpers.rs @@ -15,10 +15,7 @@ use tracing_subscriber::{EnvFilter, fmt}; static INIT: Once = Once::new(); -use crate::{ - MEASUREMENT_HEADER, - attestation::measurements::{DcapMeasurementRegister, MultiMeasurements}, -}; +use attestation::measurements::{DcapMeasurementRegister, MultiMeasurements}; /// Helper to generate a self-signed certificate for testing pub fn generate_certificate_chain( @@ -26,6 +23,9 @@ pub fn generate_certificate_chain( ) -> (Vec>, PrivateKeyDer<'static>) { let mut params = rcgen::CertificateParams::new(vec![]).unwrap(); params.subject_alt_names.push(rcgen::SanType::IpAddress(ip)); + params + .subject_alt_names + .push(rcgen::SanType::DnsName(ip.to_string().try_into().unwrap())); params .distinguished_name .push(rcgen::DnType::CommonName, ip.to_string()); @@ -38,6 +38,26 @@ pub fn generate_certificate_chain( (certs, key) } +/// Helper to generate a self-signed certificate for testing with a DNS subject name +pub fn generate_certificate_chain_for_host( + host: &str, +) -> (Vec>, PrivateKeyDer<'static>) { + let mut params = rcgen::CertificateParams::new(vec![host.to_string()]).unwrap(); + params + .subject_alt_names + .push(rcgen::SanType::DnsName(host.try_into().unwrap())); + params + .distinguished_name + .push(rcgen::DnType::CommonName, host); + + let keypair = rcgen::KeyPair::generate().unwrap(); + let cert = params.self_signed(&keypair).unwrap(); + + let certs = vec![CertificateDer::from(cert)]; + let key = PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(keypair.serialize_der())); + (certs, key) +} + /// Helper to generate TLS configuration for testing /// /// For the server: A given self-signed certificate @@ -131,12 +151,13 @@ pub async fn example_http_service() -> SocketAddr { addr } -async fn get_handler(headers: http::HeaderMap) -> impl IntoResponse { - headers - .get(MEASUREMENT_HEADER) - .and_then(|v| v.to_str().ok()) - .unwrap_or("No measurements") - .to_string() +async fn get_handler(_headers: http::HeaderMap) -> impl IntoResponse { + // headers + // .get(MEASUREMENT_HEADER) + // .and_then(|v| v.to_str().ok()) + // .unwrap_or("No measurements") + // .to_string() + "No measurements".to_string() } /// All-zero measurment values used in some tests From feec7ca604a63bbb48055e4bb37b329392550cca Mon Sep 17 00:00:00 2001 From: peg Date: Tue, 17 Mar 2026 11:54:27 +0100 Subject: [PATCH 03/20] Main compiles --- src/attested_get.rs | 40 +- src/file_server.rs | 12 +- src/lib.rs | 1235 ++++++++++++++++++------------------- src/{main.rs_ => main.rs} | 83 +-- src/self_signed.rs | 323 ---------- src/test_helpers.rs | 11 - 6 files changed, 663 insertions(+), 1041 deletions(-) rename src/{main.rs_ => main.rs} (89%) delete mode 100644 src/self_signed.rs diff --git a/src/attested_get.rs b/src/attested_get.rs index 27e7ad3..3b4575f 100644 --- a/src/attested_get.rs +++ b/src/attested_get.rs @@ -9,30 +9,16 @@ pub async fn attested_get( url_path: &str, attestation_verifier: AttestationVerifier, remote_certificate: Option>, - allow_self_signed: bool, ) -> Result { - let proxy_client = if allow_self_signed { - let client_config = crate::self_signed::client_tls_config_allow_self_signed()?; - ProxyClient::new_with_tls_config( - client_config, - "127.0.0.1:0".to_string(), - target_addr, - AttestationGenerator::with_no_attestation(), - attestation_verifier, - None, - ) - .await? - } else { - ProxyClient::new( - None, - "127.0.0.1:0".to_string(), - target_addr, - AttestationGenerator::with_no_attestation(), - attestation_verifier, - remote_certificate, - ) - .await? - }; + let proxy_client = ProxyClient::new( + None, + "127.0.0.1:0".to_string(), + target_addr, + AttestationGenerator::with_no_attestation(), + attestation_verifier, + remote_certificate, + ) + .await?; attested_get_with_client(proxy_client, url_path).await } @@ -72,11 +58,11 @@ mod tests { ProxyServer, attestation::AttestationType, file_server::static_file_server, - test_helpers::{generate_certificate_chain, generate_tls_config}, + test_helpers::{generate_certificate_chain_for_host, generate_tls_config}, }; use tempfile::tempdir; - #[tokio::test] + #[tokio::test(flavor = "multi_thread")] async fn test_attested_get() { // Create a temporary directory with a file to serve let dir = tempdir().unwrap(); @@ -87,7 +73,7 @@ mod tests { let target_addr = static_file_server(dir.path().to_path_buf()).await.unwrap(); // Create TLS configuration - let (cert_chain, private_key) = generate_certificate_chain("127.0.0.1".parse().unwrap()); + let (cert_chain, private_key) = generate_certificate_chain_for_host("localhost"); let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key); // Setup a proxy server targetting the static file server @@ -113,7 +99,7 @@ mod tests { let proxy_client = ProxyClient::new_with_tls_config( client_config, "127.0.0.1:0".to_string(), - proxy_addr.to_string(), + format!("localhost:{}", proxy_addr.port()), AttestationGenerator::with_no_attestation(), AttestationVerifier::mock(), None, diff --git a/src/file_server.rs b/src/file_server.rs index bce4804..424008e 100644 --- a/src/file_server.rs +++ b/src/file_server.rs @@ -55,7 +55,7 @@ mod tests { use crate::{ProxyClient, attestation::AttestationType}; use super::*; - use crate::test_helpers::{generate_certificate_chain, generate_tls_config}; + use crate::test_helpers::{generate_certificate_chain_for_host, generate_tls_config}; use tempfile::tempdir; /// Given a URL, fetch response body and content type header @@ -74,7 +74,7 @@ mod tests { (body.to_vec(), content_type) } - #[tokio::test] + #[tokio::test(flavor = "multi_thread")] async fn test_static_file_server() { // Create a temporary directory with some files to serve let dir = tempdir().unwrap(); @@ -94,7 +94,7 @@ mod tests { let target_addr = static_file_server(dir.path().to_path_buf()).await.unwrap(); // Create TLS configuration - let (cert_chain, private_key) = generate_certificate_chain("127.0.0.1".parse().unwrap()); + let (cert_chain, private_key) = generate_certificate_chain_for_host("localhost"); let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key); // Setup a proxy server targetting the static file server @@ -118,7 +118,7 @@ mod tests { let proxy_client = ProxyClient::new_with_tls_config( client_config, "127.0.0.1:0".to_string(), - proxy_addr.to_string(), + format!("localhost:{}", proxy_addr.port()), AttestationGenerator::with_no_attestation(), AttestationVerifier::mock(), None, @@ -128,9 +128,11 @@ mod tests { let proxy_client_addr = proxy_client.local_addr().unwrap(); - // Proxy cient accepts a single connection + // Accept one client connection per request. tokio::spawn(async move { proxy_client.accept().await.unwrap(); + proxy_client.accept().await.unwrap(); + proxy_client.accept().await.unwrap(); }); let client = reqwest::Client::new(); diff --git a/src/lib.rs b/src/lib.rs index 1d48551..e711527 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,9 +1,8 @@ //! An attested TLS protocol and HTTPS proxy -// pub mod attested_get; -// pub mod file_server; +pub mod attested_get; +pub mod file_server; pub mod health_check; pub mod normalize_pem; -// pub mod self_signed; pub use attestation; pub use attestation::AttestationGenerator; @@ -14,7 +13,7 @@ mod http_version; mod test_helpers; use attestation::{AttestationError, AttestationVerifier}; -use attested_tls::{AttestedCertificateResolver, AttestedCertificateVerifier}; +use attested_tls::{AttestedCertificateResolver, AttestedCertificateVerifier, AttestedTlsError}; use bytes::Bytes; use http::{HeaderMap, HeaderName, HeaderValue}; use http_body_util::{BodyExt, combinators::BoxBody}; @@ -24,7 +23,7 @@ use nested_tls::server::NestingTlsStream; use nested_tls::{client::NestingTlsConnector, server::NestingTlsAcceptor}; use std::{net::SocketAddr, num::TryFromIntError, sync::Arc, time::Duration}; use thiserror::Error; -use tokio::io; +use tokio::io::{self, AsyncWriteExt}; use tokio::net::{TcpListener, TcpStream, ToSocketAddrs}; use tokio::sync::{mpsc, oneshot}; use tokio_rustls::rustls::server::{VerifierBuilderError, WebPkiClientVerifier}; @@ -72,28 +71,65 @@ fn ensure_proxy_alpn_protocols(alpn_protocols: &mut Vec>) { } } -// /// Retrieve the attested remote TLS certificate. -// pub async fn get_tls_cert( -// server_name: String, -// attestation_verifier: AttestationVerifier, -// remote_certificate: Option>, -// allow_self_signed: bool, -// ) -> Result<(Vec>, Option), AttestedTlsError> { -// let (cert, measurements) = if allow_self_signed { -// let client_tls_config = self_signed::client_tls_config_allow_self_signed()?; -// attested_tls::get_tls_cert_with_config( -// &server_name, -// attestation_verifier, -// client_tls_config, -// ) -// .await? -// } else { -// attested_tls::get_tls_cert(server_name, attestation_verifier, remote_certificate).await? -// }; -// -// debug!("[get-tls-cert] Connected to proxy server with measurements: {measurements:?}"); -// Ok((cert, measurements)) -// } +/// Retrieve the inner attested remote TLS certificate. +pub async fn get_inner_tls_cert( + server_name: String, + attestation_verifier: AttestationVerifier, + remote_outer_certificate: Option>, +) -> Result>, ProxyError> { + let root_store = match remote_outer_certificate.as_ref() { + Some(remote_certificate) => { + let mut root_store = RootCertStore::empty(); + root_store.add(remote_certificate.clone())?; + root_store + } + None => RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()), + }; + + let outer_client_config = + ClientConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) + .with_root_certificates(root_store) + .with_no_client_auth(); + + get_inner_tls_cert_with_config(server_name, attestation_verifier, outer_client_config).await +} + +pub async fn get_inner_tls_cert_with_config( + server_name: String, + attestation_verifier: AttestationVerifier, + mut outer_client_config: ClientConfig, +) -> Result>, ProxyError> { + ensure_proxy_alpn_protocols(&mut outer_client_config.alpn_protocols); + let outbound_stream = tokio::net::TcpStream::connect(&server_name).await?; + + let domain = server_name_from_host(&server_name)?; + + let attested_cert_verifier = AttestedCertificateVerifier::new(None, attestation_verifier)?; + let inner_client_config = + ClientConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) + .dangerous() + .with_custom_certificate_verifier(Arc::new(attested_cert_verifier)) + .with_no_client_auth(); + + let nested_tls_connector = + NestingTlsConnector::new(Arc::new(outer_client_config), Arc::new(inner_client_config)); + + let mut tls_stream = nested_tls_connector + .connect(domain, outbound_stream) + .await?; + debug!("[get-tls-cert] Connected to proxy server"); + + let (_io, server_connection) = tls_stream.get_ref(); + + let remote_cert_chain = server_connection + .peer_certificates() + .ok_or(ProxyError::NoCertificate)? + .to_owned(); + + tls_stream.shutdown().await?; + + Ok(remote_cert_chain) +} /// A TLS over TCP server which provides an attestation before forwarding traffic to a given target address pub struct ProxyServer { @@ -133,13 +169,14 @@ impl ProxyServer { )? }; - Self::new_with_tls_config( + Self::new_with_tls_config_and_client_auth( cert_and_key.cert_chain, outer_server_config, local, target, attestation_generator, attestation_verifier, + client_auth, ) .await } @@ -147,27 +184,51 @@ impl ProxyServer { /// Start with preconfigured TLS pub async fn new_with_tls_config( cert_chain: Vec>, - mut outer_server_config: ServerConfig, + outer_server_config: ServerConfig, local: impl ToSocketAddrs, target: String, attestation_generator: AttestationGenerator, attestation_verifier: AttestationVerifier, ) -> Result { - ensure_proxy_alpn_protocols(&mut outer_server_config.alpn_protocols); - let server_name = hostname_from_cert(cert_chain.get(0).unwrap()).unwrap(); - let inner_cert_resolver = AttestedCertificateResolver::new( + Self::new_with_tls_config_and_client_auth( + cert_chain, + outer_server_config, + local, + target, attestation_generator, - None, - server_name.to_string(), // TODO get name from outer certificate - vec![], + attestation_verifier, + false, ) .await - .unwrap(); + } + + /// Start with preconfigured TLS and require client auth on both nested sessions + pub async fn new_with_tls_config_and_client_auth( + cert_chain: Vec>, + mut outer_server_config: ServerConfig, + local: impl ToSocketAddrs, + target: String, + attestation_generator: AttestationGenerator, + attestation_verifier: AttestationVerifier, + client_auth: bool, + ) -> Result { + ensure_proxy_alpn_protocols(&mut outer_server_config.alpn_protocols); + let server_name = certificate_identity_from_chain(&cert_chain)?; + let inner_cert_resolver = + build_attested_cert_resolver(attestation_generator, &cert_chain).await?; - let inner_server_config = + let inner_server_config = if client_auth { + let attested_cert_verifier = + AttestedCertificateVerifier::new(None, attestation_verifier)?; ServerConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) - .with_no_client_auth() // TODO - .with_cert_resolver(Arc::new(inner_cert_resolver)); + .with_client_cert_verifier(Arc::new(attested_cert_verifier)) + .with_cert_resolver(Arc::new(inner_cert_resolver)) + } else { + let _ = server_name; + ServerConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) + .with_no_client_auth() + .with_cert_resolver(Arc::new(inner_cert_resolver)) + }; let nesting_tls_acceptor = NestingTlsAcceptor::new(Arc::new(outer_server_config), Arc::new(inner_server_config)); @@ -372,7 +433,7 @@ impl ProxyClient { server_name, attestation_generator, attestation_verifier, - remote_certificate.map(|certificate| vec![certificate]), + cert_and_key.map(|cert_and_key| cert_and_key.cert_chain), ) .await } @@ -387,15 +448,28 @@ impl ProxyClient { cert_chain: Option>>, ) -> Result { ensure_proxy_alpn_protocols(&mut outer_client_config.alpn_protocols); + let outer_has_client_auth = outer_client_config.client_auth_cert_resolver.has_certs(); + let inner_has_client_auth = cert_chain.is_some(); + + if outer_has_client_auth != inner_has_client_auth { + return Err(ProxyError::ClientAuthMisconfigured); + } - let attested_cert_verifier = - AttestedCertificateVerifier::new(None, attestation_verifier).unwrap(); + let attested_cert_verifier = AttestedCertificateVerifier::new(None, attestation_verifier)?; - let inner_client_config = + let inner_client_config = if let Some(cert_chain) = cert_chain.as_ref() { + let inner_cert_resolver = + build_attested_cert_resolver(attestation_generator, cert_chain).await?; ClientConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) .dangerous() .with_custom_certificate_verifier(Arc::new(attested_cert_verifier)) - .with_no_client_auth(); + .with_client_cert_resolver(Arc::new(inner_cert_resolver)) + } else { + ClientConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) + .dangerous() + .with_custom_certificate_verifier(Arc::new(attested_cert_verifier)) + .with_no_client_auth() + }; let nesting_tls_connector = NestingTlsConnector::new(Arc::new(outer_client_config), Arc::new(inner_client_config)); @@ -590,7 +664,7 @@ impl ProxyClient { return Ok(output); } Err(e) => { - if matches!(e, ProxyError::Io(_)) || !should_bail { + if should_retry_setup_error(&e, should_bail) { warn!("Reconnect failed: {e}. Retrying in {:#?}...", delay); tokio::time::sleep(delay).await; @@ -656,6 +730,25 @@ impl ProxyClient { } } +fn should_retry_setup_error(error: &ProxyError, should_bail: bool) -> bool { + if !should_bail { + return true; + } + + match error { + ProxyError::Io(io_error) => matches!( + io_error.kind(), + std::io::ErrorKind::ConnectionRefused + | std::io::ErrorKind::ConnectionReset + | std::io::ErrorKind::ConnectionAborted + | std::io::ErrorKind::NotConnected + | std::io::ErrorKind::TimedOut + | std::io::ErrorKind::UnexpectedEof + ), + _ => false, + } +} + /// Update a request/response header if we are able to encode the header value /// /// This avoids bailing on bad header values - the headers are simply not updated @@ -694,14 +787,16 @@ pub enum ProxyError { BadDnsName(#[from] tokio_rustls::rustls::pki_types::InvalidDnsNameError), #[error("HTTP: {0}")] Hyper(#[from] hyper::Error), + #[error("Attested TLS: {0}")] + AttestedTls(#[from] AttestedTlsError), #[error("JSON: {0}")] Json(#[from] serde_json::Error), #[error("Could not forward response - sender was dropped")] OneShotRecv(#[from] oneshot::error::RecvError), #[error("Failed to send request, connection to proxy-server dropped")] MpscSend, - // #[error("Attested TLS: {0}")] - // AttestedTls(#[from] AttestedTlsError), + #[error("Client auth must be configured on both the inner and outer TLS sessions")] + ClientAuthMisconfigured, } impl From> for ProxyError { @@ -728,6 +823,23 @@ fn hostname_from_cert(cert: &CertificateDer<'static>) -> Result], +) -> Result { + hostname_from_cert(cert_chain.first().ok_or(ProxyError::NoCertificate)?) +} + +async fn build_attested_cert_resolver( + attestation_generator: AttestationGenerator, + cert_chain: &[CertificateDer<'static>], +) -> Result { + let certificate_name = certificate_identity_from_chain(cert_chain)?; + Ok( + AttestedCertificateResolver::new(attestation_generator, None, certificate_name, vec![]) + .await?, + ) +} + /// If no port was provided, default to 443 pub(crate) fn host_to_host_with_port(host: &str) -> String { if host.contains(':') { @@ -769,9 +881,8 @@ mod tests { use super::*; use test_helpers::{ - example_http_service, generate_certificate_chain, generate_certificate_chain_for_host, - generate_tls_config, - generate_tls_config_with_client_auth, init_tracing, mock_dcap_measurements, + example_http_service, generate_certificate_chain_for_host, generate_tls_config, + generate_tls_config_with_client_auth, init_tracing, }; #[test] @@ -893,575 +1004,457 @@ mod tests { assert_eq!(res_body, "No measurements"); } - // // Server has no attestation, client has mock DCAP and client auth - // #[tokio::test] - // async fn http_proxy_client_attestation() { - // let target_addr = example_http_service().await; - // - // let (server_cert_chain, server_private_key) = - // generate_certificate_chain("127.0.0.1".parse().unwrap()); - // let (client_cert_chain, client_private_key) = - // generate_certificate_chain("127.0.0.1".parse().unwrap()); - // - // let ( - // (_client_tls_server_config, client_tls_client_config), - // (server_tls_server_config, _server_tls_client_config), - // ) = generate_tls_config_with_client_auth( - // client_cert_chain.clone(), - // client_private_key, - // server_cert_chain.clone(), - // server_private_key, - // ); - // - // let proxy_server = ProxyServer::new_with_tls_config( - // server_cert_chain, - // server_tls_server_config, - // "127.0.0.1:0", - // target_addr.to_string(), - // AttestationGenerator::with_no_attestation(), - // AttestationVerifier::mock(), - // ) - // .await - // .unwrap(); - // - // let proxy_addr = proxy_server.local_addr().unwrap(); - // - // tokio::spawn(async move { - // // Accept one connection, then finish - // proxy_server.accept().await.unwrap(); - // }); - // - // let proxy_client = ProxyClient::new_with_tls_config( - // client_tls_client_config, - // "127.0.0.1:0", - // proxy_addr.to_string(), - // AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), - // AttestationVerifier::expect_none(), - // Some(client_cert_chain), - // ) - // .await - // .unwrap(); - // - // let proxy_client_addr = proxy_client.local_addr().unwrap(); - // - // tokio::spawn(async move { - // // Accept two connections, then finish - // proxy_client.accept().await.unwrap(); - // proxy_client.accept().await.unwrap(); - // }); - // - // let res = reqwest::get(format!("http://{}", proxy_client_addr.to_string())) - // .await - // .unwrap(); - // - // // We expect no measurements from the server - // let headers = res.headers(); - // assert!(headers.get(MEASUREMENT_HEADER).is_none()); - // - // let attestation_type = headers - // .get(ATTESTATION_TYPE_HEADER) - // .unwrap() - // .to_str() - // .unwrap(); - // assert_eq!(attestation_type, AttestationType::None.as_str()); - // - // let res_body = res.text().await.unwrap(); - // - // // The response body shows us what was in the request header (as the test http server - // // handler puts them there) - // let measurements = - // MultiMeasurements::from_header_format(&res_body, AttestationType::DcapTdx).unwrap(); - // assert_eq!(measurements, mock_dcap_measurements()); - // } - // - // // Server has no attestation, client has mock DCAP but no client auth - // #[tokio::test] - // async fn http_proxy_client_attestation_no_client_auth() { - // let target_addr = example_http_service().await; - // - // let (server_cert_chain, server_private_key) = - // generate_certificate_chain("127.0.0.1".parse().unwrap()); - // let (server_config, client_config) = - // generate_tls_config(server_cert_chain.clone(), server_private_key); - // - // let proxy_server = ProxyServer::new_with_tls_config( - // server_cert_chain, - // server_config, - // "127.0.0.1:0", - // target_addr.to_string(), - // AttestationGenerator::with_no_attestation(), - // AttestationVerifier::mock(), - // ) - // .await - // .unwrap(); - // - // let proxy_addr = proxy_server.local_addr().unwrap(); - // - // tokio::spawn(async move { - // // Accept one connection, then finish - // proxy_server.accept().await.unwrap(); - // }); - // - // let proxy_client = ProxyClient::new_with_tls_config( - // client_config, - // "127.0.0.1:0", - // proxy_addr.to_string(), - // AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), - // AttestationVerifier::expect_none(), - // None, - // ) - // .await - // .unwrap(); - // - // let proxy_client_addr = proxy_client.local_addr().unwrap(); - // - // tokio::spawn(async move { - // // Accept two connections, then finish - // proxy_client.accept().await.unwrap(); - // proxy_client.accept().await.unwrap(); - // }); - // - // let res = reqwest::get(format!("http://{}", proxy_client_addr.to_string())) - // .await - // .unwrap(); - // - // // We expect no measurements from the server - // let headers = res.headers(); - // assert!(headers.get(MEASUREMENT_HEADER).is_none()); - // - // let attestation_type = headers - // .get(ATTESTATION_TYPE_HEADER) - // .unwrap() - // .to_str() - // .unwrap(); - // assert_eq!(attestation_type, AttestationType::None.as_str()); - // - // let res_body = res.text().await.unwrap(); - // - // // The response body shows us what was in the request header (as the test http server - // // handler puts them there) - // let measurements = - // MultiMeasurements::from_header_format(&res_body, AttestationType::DcapTdx).unwrap(); - // assert_eq!(measurements, mock_dcap_measurements()); - // } - // - // // Server has mock DCAP, client has mock DCAP and client auth - // #[tokio::test] - // async fn http_proxy_mutual_attestation() { - // let target_addr = example_http_service().await; - // - // let (server_cert_chain, server_private_key) = - // generate_certificate_chain("127.0.0.1".parse().unwrap()); - // let (client_cert_chain, client_private_key) = - // generate_certificate_chain("127.0.0.1".parse().unwrap()); - // - // let ( - // (_client_tls_server_config, client_tls_client_config), - // (server_tls_server_config, _server_tls_client_config), - // ) = generate_tls_config_with_client_auth( - // client_cert_chain.clone(), - // client_private_key, - // server_cert_chain.clone(), - // server_private_key, - // ); - // - // let proxy_server = ProxyServer::new_with_tls_config( - // server_cert_chain, - // server_tls_server_config, - // "127.0.0.1:0", - // target_addr.to_string(), - // AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), - // AttestationVerifier::mock(), - // ) - // .await - // .unwrap(); - // - // let proxy_addr = proxy_server.local_addr().unwrap(); - // - // tokio::spawn(async move { - // // Accept one connection, then finish - // proxy_server.accept().await.unwrap(); - // }); - // - // let proxy_client = ProxyClient::new_with_tls_config( - // client_tls_client_config, - // "127.0.0.1:0", - // proxy_addr.to_string(), - // AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), - // AttestationVerifier::mock(), - // Some(client_cert_chain), - // ) - // .await - // .unwrap(); - // - // let proxy_client_addr = proxy_client.local_addr().unwrap(); - // - // tokio::spawn(async move { - // // Accept two connections, then finish - // proxy_client.accept().await.unwrap(); - // proxy_client.accept().await.unwrap(); - // }); - // - // let res = reqwest::get(format!("http://{}", proxy_client_addr.to_string())) - // .await - // .unwrap(); - // - // let headers = res.headers(); - // let measurements_json = headers.get(MEASUREMENT_HEADER).unwrap().to_str().unwrap(); - // let measurements = - // MultiMeasurements::from_header_format(measurements_json, AttestationType::DcapTdx) - // .unwrap(); - // assert_eq!(measurements, mock_dcap_measurements()); - // - // let attestation_type = headers - // .get(ATTESTATION_TYPE_HEADER) - // .unwrap() - // .to_str() - // .unwrap(); - // assert_eq!(attestation_type, AttestationType::DcapTdx.as_str()); - // - // let res_body = res.text().await.unwrap(); - // - // // The response body shows us what was in the request header (as the test http server - // // handler puts them there) - // let measurements = - // MultiMeasurements::from_header_format(&res_body, AttestationType::DcapTdx).unwrap(); - // assert_eq!(measurements, mock_dcap_measurements()); - // - // // Now do another request - to check that the connection has stayed open - // let res = reqwest::get(format!("http://{}", proxy_client_addr.to_string())) - // .await - // .unwrap(); - // - // let headers = res.headers(); - // let measurements_json = headers.get(MEASUREMENT_HEADER).unwrap().to_str().unwrap(); - // let measurements = - // MultiMeasurements::from_header_format(measurements_json, AttestationType::DcapTdx) - // .unwrap(); - // assert_eq!(measurements, mock_dcap_measurements()); - // - // let attestation_type = headers - // .get(ATTESTATION_TYPE_HEADER) - // .unwrap() - // .to_str() - // .unwrap(); - // assert_eq!(attestation_type, AttestationType::DcapTdx.as_str()); - // - // let res_body = res.text().await.unwrap(); - // - // // The response body shows us what was in the request header (as the test http server - // // handler puts them there) - // let measurements = - // MultiMeasurements::from_header_format(&res_body, AttestationType::DcapTdx).unwrap(); - // assert_eq!(measurements, mock_dcap_measurements()); - // } - // - // // Server has mock DCAP, client no attestation - just get the server certificate - // #[tokio::test] - // async fn test_get_tls_cert() { - // let target_addr = example_http_service().await; - // - // let (cert_chain, private_key) = generate_certificate_chain("127.0.0.1".parse().unwrap()); - // let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key); - // - // let proxy_server = ProxyServer::new_with_tls_config( - // cert_chain.clone(), - // server_config, - // "127.0.0.1:0", - // target_addr.to_string(), - // AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), - // AttestationVerifier::expect_none(), - // ) - // .await - // .unwrap(); - // - // let proxy_server_addr = proxy_server.local_addr().unwrap(); - // - // tokio::spawn(async move { - // proxy_server.accept().await.unwrap(); - // }); - // - // let (retrieved_chain, _measurements) = get_tls_cert_with_config( - // &proxy_server_addr.to_string(), - // AttestationVerifier::mock(), - // client_config, - // ) - // .await - // .unwrap(); - // - // assert_eq!(retrieved_chain, cert_chain); - // } - // - // // Negative test - server does not provide attestation but client requires it - // // Server has no attestaion, client has no attestation and no client auth - // #[tokio::test] - // async fn fails_on_no_attestation_when_expected() { - // let target_addr = example_http_service().await; - // - // let (cert_chain, private_key) = generate_certificate_chain("127.0.0.1".parse().unwrap()); - // let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key); - // - // let proxy_server = ProxyServer::new_with_tls_config( - // cert_chain, - // server_config, - // "127.0.0.1:0", - // target_addr.to_string(), - // AttestationGenerator::with_no_attestation(), - // AttestationVerifier::expect_none(), - // ) - // .await - // .unwrap(); - // - // let proxy_addr = proxy_server.local_addr().unwrap(); - // - // tokio::spawn(async move { - // proxy_server.accept().await.unwrap(); - // }); - // - // let proxy_client_result = ProxyClient::new_with_tls_config( - // client_config, - // "127.0.0.1:0".to_string(), - // proxy_addr.to_string(), - // AttestationGenerator::with_no_attestation(), - // AttestationVerifier::mock(), - // None, - // ) - // .await; - // - // assert!(matches!( - // proxy_client_result.unwrap_err(), - // ProxyError::AttestedTls(AttestedTlsError::Attestation( - // AttestationError::AttestationTypeNotAccepted - // )) - // )); - // } - // - // // Negative test - server does not provide attestation but client requires it - // // Server has no attestaion, client has no attestation and no client auth - // #[tokio::test] - // async fn fails_on_bad_measurements() { - // let target_addr = example_http_service().await; - // - // let (cert_chain, private_key) = generate_certificate_chain("127.0.0.1".parse().unwrap()); - // let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key); - // - // let proxy_server = ProxyServer::new_with_tls_config( - // cert_chain, - // server_config, - // "127.0.0.1:0", - // target_addr.to_string(), - // AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), - // AttestationVerifier::expect_none(), - // ) - // .await - // .unwrap(); - // - // let proxy_addr = proxy_server.local_addr().unwrap(); - // - // tokio::spawn(async move { - // proxy_server.accept().await.unwrap(); - // }); - // - // let measurement_policy = MeasurementPolicy::from_json_bytes( - // br#" - // [{ - // "measurement_id": "test", - // "attestation_type": "dcap-tdx", - // "measurements": { - // "0": { "expected": "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000" }, - // "1": { "expected": "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000" }, - // "2": { "expected": "010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101" }, - // "3": { "expected": "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000" }, - // "4": { "expected": "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000" } - // } - // }] - // "# - // .to_vec(), - // ) - // .unwrap(); - // - // let attestation_verifier = AttestationVerifier { - // measurement_policy, - // pccs_url: None, - // log_dcap_quote: false, - // override_azure_outdated_tcb: false, - // }; - // - // let proxy_client_result = ProxyClient::new_with_tls_config( - // client_config, - // "127.0.0.1:0".to_string(), - // proxy_addr.to_string(), - // AttestationGenerator::with_no_attestation(), - // attestation_verifier, - // None, - // ) - // .await; - // - // assert!(matches!( - // proxy_client_result.unwrap_err(), - // ProxyError::AttestedTls(AttestedTlsError::Attestation( - // AttestationError::MeasurementsNotAccepted - // )) - // )); - // } - // - // #[tokio::test] - // async fn http_proxy_client_reconnects_on_lost_connection() { - // init_tracing(); - // - // let target_addr = example_http_service().await; - // - // let (cert_chain, private_key) = generate_certificate_chain("127.0.0.1".parse().unwrap()); - // let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key); - // - // let proxy_server = ProxyServer::new_with_tls_config( - // cert_chain, - // server_config, - // "127.0.0.1:0", - // target_addr.to_string(), - // AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), - // AttestationVerifier::expect_none(), - // ) - // .await - // .unwrap(); - // - // let proxy_addr = proxy_server.local_addr().unwrap(); - // - // // This is used to trigger a dropped connection to the proxy server - // let (connection_breaker_tx, connection_breaker_rx) = oneshot::channel(); - // - // tokio::spawn(async move { - // let connection_handle = proxy_server.accept().await.unwrap(); - // - // // Wait for a signal to simulate a dropped connection, then drop the task handling the - // // connection - // connection_breaker_rx.await.unwrap(); - // connection_handle.abort(); - // - // // Now accept another connection - // proxy_server.accept().await.unwrap(); - // }); - // - // let proxy_client = ProxyClient::new_with_tls_config( - // client_config, - // "127.0.0.1:0".to_string(), - // proxy_addr.to_string(), - // AttestationGenerator::with_no_attestation(), - // AttestationVerifier::mock(), - // None, - // ) - // .await - // .unwrap(); - // - // let proxy_client_addr = proxy_client.local_addr().unwrap(); - // - // tokio::spawn(async move { - // proxy_client.accept().await.unwrap(); - // proxy_client.accept().await.unwrap(); - // }); - // - // let _initial_response = reqwest::get(format!("http://{}", proxy_client_addr.to_string())) - // .await - // .unwrap(); - // - // // Now break the connection - // connection_breaker_tx.send(()).unwrap(); - // - // // Make another request - // let res = reqwest::get(format!("http://{}", proxy_client_addr.to_string())) - // .await - // .unwrap(); - // - // let headers = res.headers(); - // - // let attestation_type = headers - // .get(ATTESTATION_TYPE_HEADER) - // .unwrap() - // .to_str() - // .unwrap(); - // - // assert_eq!(attestation_type, AttestationType::DcapTdx.as_str()); - // - // let measurements_json = headers.get(MEASUREMENT_HEADER).unwrap().to_str().unwrap(); - // let measurements = - // MultiMeasurements::from_header_format(measurements_json, AttestationType::DcapTdx) - // .unwrap(); - // assert_eq!(measurements, mock_dcap_measurements()); - // - // let res_body = res.text().await.unwrap(); - // assert_eq!(res_body, "No measurements"); - // } - // - // // Use HTTP 1.1 - // #[tokio::test] - // async fn http_proxy_with_http1() { - // let target_addr = example_http_service().await; - // - // let (cert_chain, private_key) = generate_certificate_chain("127.0.0.1".parse().unwrap()); - // let (mut server_config, client_config) = - // generate_tls_config(cert_chain.clone(), private_key); - // - // server_config.alpn_protocols.push(ALPN_HTTP11.to_vec()); - // - // let attested_tls_server = AttestedTlsServer::new_with_tls_config( - // cert_chain, - // server_config, - // AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), - // AttestationVerifier::expect_none(), - // ) - // .unwrap(); - // - // let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); - // - // let proxy_server = ProxyServer { - // attested_tls_server, - // listener: listener.into(), - // target: target_addr.to_string(), - // }; - // - // let proxy_addr = proxy_server.local_addr().unwrap(); - // - // tokio::spawn(async move { - // proxy_server.accept().await.unwrap(); - // }); - // - // let proxy_client = ProxyClient::new_with_tls_config( - // client_config, - // "127.0.0.1:0".to_string(), - // proxy_addr.to_string(), - // AttestationGenerator::with_no_attestation(), - // AttestationVerifier::mock(), - // None, - // ) - // .await - // .unwrap(); - // - // let proxy_client_addr = proxy_client.local_addr().unwrap(); - // - // tokio::spawn(async move { - // proxy_client.accept().await.unwrap(); - // }); - // - // let res = reqwest::get(format!("http://{}", proxy_client_addr.to_string())) - // .await - // .unwrap(); - // - // let headers = res.headers(); - // - // let attestation_type = headers - // .get(ATTESTATION_TYPE_HEADER) - // .unwrap() - // .to_str() - // .unwrap(); - // assert_eq!(attestation_type, AttestationType::DcapTdx.as_str()); - // - // let measurements_json = headers.get(MEASUREMENT_HEADER).unwrap().to_str().unwrap(); - // let measurements = - // MultiMeasurements::from_header_format(measurements_json, AttestationType::DcapTdx) - // .unwrap(); - // assert_eq!(measurements, mock_dcap_measurements()); - // - // let res_body = res.text().await.unwrap(); - // assert_eq!(res_body, "No measurements"); - // } + // Server has no attestation, client has mock DCAP and client auth + #[tokio::test(flavor = "multi_thread")] + async fn http_proxy_client_attestation() { + let target_addr = example_http_service().await; + + let (server_cert_chain, server_private_key) = + generate_certificate_chain_for_host("localhost"); + let (client_cert_chain, client_private_key) = + generate_certificate_chain_for_host("localhost"); + + let ( + (_client_tls_server_config, client_tls_client_config), + (server_tls_server_config, _server_tls_client_config), + ) = generate_tls_config_with_client_auth( + client_cert_chain.clone(), + client_private_key, + server_cert_chain.clone(), + server_private_key, + ); + + let proxy_server = ProxyServer::new_with_tls_config_and_client_auth( + server_cert_chain, + server_tls_server_config, + "127.0.0.1:0", + target_addr.to_string(), + AttestationGenerator::with_no_attestation(), + AttestationVerifier::mock(), + true, + ) + .await + .unwrap(); + + let proxy_addr = proxy_server.local_addr().unwrap(); + + tokio::spawn(async move { + proxy_server.accept().await.unwrap(); + }); + + let proxy_client = ProxyClient::new_with_tls_config( + client_tls_client_config, + "127.0.0.1:0", + format!("localhost:{}", proxy_addr.port()), + AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), + AttestationVerifier::expect_none(), + Some(client_cert_chain), + ) + .await + .unwrap(); + + let proxy_client_addr = proxy_client.local_addr().unwrap(); + + tokio::spawn(async move { + proxy_client.accept().await.unwrap(); + }); + + let res = reqwest::get(format!("http://{}", proxy_client_addr)) + .await + .unwrap(); + + let res_body = res.text().await.unwrap(); + assert_eq!(res_body, "No measurements"); + } + + // Server has no attestation, client has mock DCAP but no client auth + #[tokio::test(flavor = "multi_thread")] + async fn http_proxy_client_attestation_no_client_auth() { + let target_addr = example_http_service().await; + + let (server_cert_chain, server_private_key) = + generate_certificate_chain_for_host("localhost"); + let (server_config, client_config) = + generate_tls_config(server_cert_chain.clone(), server_private_key); + + let proxy_server = ProxyServer::new_with_tls_config( + server_cert_chain, + server_config, + "127.0.0.1:0", + target_addr.to_string(), + AttestationGenerator::with_no_attestation(), + AttestationVerifier::mock(), + ) + .await + .unwrap(); + + let proxy_addr = proxy_server.local_addr().unwrap(); + + tokio::spawn(async move { + // Accept one connection, then finish + proxy_server.accept().await.unwrap(); + }); + + let proxy_client = ProxyClient::new_with_tls_config( + client_config, + "127.0.0.1:0", + format!("localhost:{}", proxy_addr.port()), + AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), + AttestationVerifier::expect_none(), + None, + ) + .await + .unwrap(); + + let proxy_client_addr = proxy_client.local_addr().unwrap(); + + tokio::spawn(async move { + // Accept two connections, then finish + proxy_client.accept().await.unwrap(); + proxy_client.accept().await.unwrap(); + }); + + let res = reqwest::get(format!("http://{}", proxy_client_addr.to_string())) + .await + .unwrap(); + + let _res_body = res.text().await.unwrap(); + } + + // Server has mock DCAP, client has mock DCAP and client auth + #[tokio::test(flavor = "multi_thread")] + async fn http_proxy_mutual_attestation() { + let target_addr = example_http_service().await; + + let (server_cert_chain, server_private_key) = + generate_certificate_chain_for_host("localhost"); + let (client_cert_chain, client_private_key) = + generate_certificate_chain_for_host("localhost"); + + let ( + (_client_tls_server_config, client_tls_client_config), + (server_tls_server_config, _server_tls_client_config), + ) = generate_tls_config_with_client_auth( + client_cert_chain.clone(), + client_private_key, + server_cert_chain.clone(), + server_private_key, + ); + + let proxy_server = ProxyServer::new_with_tls_config_and_client_auth( + server_cert_chain, + server_tls_server_config, + "127.0.0.1:0", + target_addr.to_string(), + AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), + AttestationVerifier::mock(), + true, + ) + .await + .unwrap(); + + let proxy_addr = proxy_server.local_addr().unwrap(); + + tokio::spawn(async move { + proxy_server.accept().await.unwrap(); + }); + + let proxy_client = ProxyClient::new_with_tls_config( + client_tls_client_config, + "127.0.0.1:0", + format!("localhost:{}", proxy_addr.port()), + AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), + AttestationVerifier::mock(), + Some(client_cert_chain), + ) + .await + .unwrap(); + + let proxy_client_addr = proxy_client.local_addr().unwrap(); + + tokio::spawn(async move { + proxy_client.accept().await.unwrap(); + proxy_client.accept().await.unwrap(); + }); + + let res = reqwest::get(format!("http://{}", proxy_client_addr)) + .await + .unwrap(); + assert_eq!(res.text().await.unwrap(), "No measurements"); + + let res = reqwest::get(format!("http://{}", proxy_client_addr)) + .await + .unwrap(); + assert_eq!(res.text().await.unwrap(), "No measurements"); + } + + // Server has mock DCAP, client no attestation - just get the server certificate + #[tokio::test(flavor = "multi_thread")] + async fn test_get_tls_cert() { + let target_addr = example_http_service().await; + + let (cert_chain, private_key) = generate_certificate_chain_for_host("localhost"); + let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key); + + let proxy_server = ProxyServer::new_with_tls_config( + cert_chain.clone(), + server_config, + "127.0.0.1:0", + target_addr.to_string(), + AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), + AttestationVerifier::expect_none(), + ) + .await + .unwrap(); + + let proxy_server_addr = proxy_server.local_addr().unwrap(); + + tokio::spawn(async move { + proxy_server.accept().await.unwrap(); + }); + + let retrieved_chain = get_inner_tls_cert_with_config( + format!("localhost:{}", proxy_server_addr.port()), + AttestationVerifier::mock(), + client_config, + ) + .await + .unwrap(); + + assert_eq!(retrieved_chain.len(), 1); + assert_eq!( + hostname_from_cert(&retrieved_chain[0]).unwrap(), + "localhost" + ); + assert_ne!(retrieved_chain, cert_chain); + } + + // Negative test - server does not provide attestation but client requires it + // Server has no attestaion, client has no attestation and no client auth + #[tokio::test(flavor = "multi_thread")] + async fn fails_on_no_attestation_when_expected() { + let target_addr = example_http_service().await; + + let (cert_chain, private_key) = generate_certificate_chain_for_host("localhost"); + let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key); + + let proxy_server = ProxyServer::new_with_tls_config( + cert_chain, + server_config, + "127.0.0.1:0", + target_addr.to_string(), + AttestationGenerator::with_no_attestation(), + AttestationVerifier::expect_none(), + ) + .await + .unwrap(); + + let proxy_addr = proxy_server.local_addr().unwrap(); + + tokio::spawn(async move { + proxy_server.accept().await.unwrap(); + }); + + let proxy_client_result = ProxyClient::new_with_tls_config( + client_config, + "127.0.0.1:0".to_string(), + format!("localhost:{}", proxy_addr.port()), + AttestationGenerator::with_no_attestation(), + AttestationVerifier::mock(), + None, + ) + .await; + + let err = proxy_client_result.unwrap_err().to_string(); + assert!(err.contains("ApplicationVerificationFailure"), "{err}"); + } + + // Negative test - server does not provide attestation but client requires it + // Server has no attestaion, client has no attestation and no client auth + #[tokio::test(flavor = "multi_thread")] + async fn fails_on_bad_measurements() { + let target_addr = example_http_service().await; + + let (cert_chain, private_key) = generate_certificate_chain_for_host("localhost"); + let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key); + + let proxy_server = ProxyServer::new_with_tls_config( + cert_chain, + server_config, + "127.0.0.1:0", + target_addr.to_string(), + AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), + AttestationVerifier::expect_none(), + ) + .await + .unwrap(); + + let proxy_addr = proxy_server.local_addr().unwrap(); + + tokio::spawn(async move { + proxy_server.accept().await.unwrap(); + }); + + let measurement_policy = MeasurementPolicy::from_json_bytes( + br#" + [{ + "measurement_id": "test", + "attestation_type": "dcap-tdx", + "measurements": { + "0": { "expected": "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000" }, + "1": { "expected": "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000" }, + "2": { "expected": "010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101" }, + "3": { "expected": "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000" }, + "4": { "expected": "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000" } + } + }] + "# + .to_vec(), + ) + .unwrap(); + + let attestation_verifier = AttestationVerifier { + measurement_policy, + pccs_url: None, + log_dcap_quote: false, + override_azure_outdated_tcb: false, + }; + + let proxy_client_result = ProxyClient::new_with_tls_config( + client_config, + "127.0.0.1:0".to_string(), + format!("localhost:{}", proxy_addr.port()), + AttestationGenerator::with_no_attestation(), + attestation_verifier, + None, + ) + .await; + + let err = proxy_client_result.unwrap_err().to_string(); + assert!(err.contains("ApplicationVerificationFailure"), "{err}"); + } + + #[tokio::test(flavor = "multi_thread")] + async fn http_proxy_client_reconnects_on_lost_connection() { + init_tracing(); + + let target_addr = example_http_service().await; + + let (cert_chain, private_key) = generate_certificate_chain_for_host("localhost"); + let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key); + + let proxy_server = ProxyServer::new_with_tls_config( + cert_chain, + server_config, + "127.0.0.1:0", + target_addr.to_string(), + AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), + AttestationVerifier::expect_none(), + ) + .await + .unwrap(); + + let proxy_addr = proxy_server.local_addr().unwrap(); + + // This is used to trigger a dropped connection to the proxy server + let (connection_breaker_tx, connection_breaker_rx) = oneshot::channel(); + + tokio::spawn(async move { + let connection_handle = proxy_server.accept().await.unwrap(); + + // Wait for a signal to simulate a dropped connection, then drop the task handling the + // connection + connection_breaker_rx.await.unwrap(); + connection_handle.abort(); + + // Now accept another connection + proxy_server.accept().await.unwrap(); + }); + + let proxy_client = ProxyClient::new_with_tls_config( + client_config, + "127.0.0.1:0".to_string(), + format!("localhost:{}", proxy_addr.port()), + AttestationGenerator::with_no_attestation(), + AttestationVerifier::mock(), + None, + ) + .await + .unwrap(); + + let proxy_client_addr = proxy_client.local_addr().unwrap(); + + tokio::spawn(async move { + proxy_client.accept().await.unwrap(); + proxy_client.accept().await.unwrap(); + }); + + let _initial_response = reqwest::get(format!("http://{}", proxy_client_addr.to_string())) + .await + .unwrap(); + + // Now break the connection + connection_breaker_tx.send(()).unwrap(); + + // Make another request + let res = reqwest::get(format!("http://{}", proxy_client_addr.to_string())) + .await + .unwrap(); + + let res_body = res.text().await.unwrap(); + assert_eq!(res_body, "No measurements"); + } + + // Use HTTP 1.1 + #[tokio::test(flavor = "multi_thread")] + async fn http_proxy_with_http1() { + let target_addr = example_http_service().await; + + let (cert_chain, private_key) = generate_certificate_chain_for_host("localhost"); + let (mut server_config, client_config) = + generate_tls_config(cert_chain.clone(), private_key); + + server_config.alpn_protocols.push(ALPN_HTTP11.to_vec()); + + let proxy_server = ProxyServer::new_with_tls_config( + cert_chain, + server_config, + "127.0.0.1:0", + target_addr.to_string(), + AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), + AttestationVerifier::expect_none(), + ) + .await + .unwrap(); + + let proxy_addr = proxy_server.local_addr().unwrap(); + + tokio::spawn(async move { + proxy_server.accept().await.unwrap(); + }); + + let proxy_client = ProxyClient::new_with_tls_config( + client_config, + "127.0.0.1:0".to_string(), + format!("localhost:{}", proxy_addr.port()), + AttestationGenerator::with_no_attestation(), + AttestationVerifier::mock(), + None, + ) + .await + .unwrap(); + + let proxy_client_addr = proxy_client.local_addr().unwrap(); + + tokio::spawn(async move { + proxy_client.accept().await.unwrap(); + }); + + let res = reqwest::get(format!("http://{}", proxy_client_addr.to_string())) + .await + .unwrap(); + + let res_body = res.text().await.unwrap(); + assert_eq!(res_body, "No measurements"); + } } diff --git a/src/main.rs_ b/src/main.rs similarity index 89% rename from src/main.rs_ rename to src/main.rs index d929778..4704059 100644 --- a/src/main.rs_ +++ b/src/main.rs @@ -1,5 +1,5 @@ use anyhow::{anyhow, ensure}; -use attested_tls::attestation::measurements::MultiMeasurements; +use attestation::{AttestationType, AttestationVerifier, measurements::MeasurementPolicy}; use clap::{Parser, Subcommand}; use std::{ fs::File, @@ -11,14 +11,8 @@ use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer}; use tracing::level_filters::LevelFilter; use attested_tls_proxy::{ - AttestationGenerator, ProxyClient, ProxyServer, - attested_get::attested_get, - attested_tls::{ - TlsCertAndKey, - attestation::{AttestationType, AttestationVerifier, measurements::MeasurementPolicy}, - }, - file_server::attested_file_server, - get_tls_cert, health_check, + AttestationGenerator, ProxyClient, ProxyServer, TlsCertAndKey, attested_get::attested_get, + file_server::attested_file_server, get_inner_tls_cert, health_check, normalize_pem::normalize_private_key_pem_to_pkcs8, }; @@ -280,29 +274,15 @@ async fn main() -> anyhow::Result<()> { AttestationGenerator::new_with_detection(client_attestation_type, dev_dummy_dcap) .await?; - let client = if allow_self_signed { - let client_tls_config = - attested_tls_proxy::self_signed::client_tls_config_allow_self_signed()?; - ProxyClient::new_with_tls_config( - client_tls_config, - listen_addr, - target_addr, - client_attestation_generator, - attestation_verifier, - None, - ) - .await? - } else { - ProxyClient::new( - tls_cert_and_chain, - listen_addr, - target_addr, - client_attestation_generator, - attestation_verifier, - remote_tls_cert, - ) - .await? - }; + let client = ProxyClient::new( + tls_cert_and_chain, + listen_addr, + target_addr, + client_attestation_generator, + attestation_verifier, + remote_tls_cert, + ) + .await?; loop { if let Err(err) = client.accept().await { @@ -365,24 +345,19 @@ async fn main() -> anyhow::Result<()> { ), None => None, }; - let (cert_chain, measurements) = get_tls_cert( - server, - attestation_verifier, - remote_tls_cert, - allow_self_signed, - ) - .await?; - - // If the user chose to write measurements to a file as JSON - if let Some(path_to_write_measurements) = out_measurements { - std::fs::write( - path_to_write_measurements, - measurements - .unwrap_or(MultiMeasurements::NoAttestation) - .to_header_format()? - .as_bytes(), - )?; - } + let cert_chain = + get_inner_tls_cert(server, attestation_verifier, remote_tls_cert).await?; + + // // If the user chose to write measurements to a file as JSON + // if let Some(path_to_write_measurements) = out_measurements { + // std::fs::write( + // path_to_write_measurements, + // measurements + // .unwrap_or(MultiMeasurements::NoAttestation) + // .to_header_format()? + // .as_bytes(), + // )?; + // } println!("{}", certs_to_pem_string(&cert_chain)?); } CliCommand::AttestedFileServer { @@ -434,7 +409,6 @@ async fn main() -> anyhow::Result<()> { &url_path.unwrap_or_default(), attestation_verifier, remote_tls_cert, - allow_self_signed, ) .await?; @@ -467,9 +441,10 @@ fn load_tls_cert_and_key_server( return Err(anyhow!("Certificate chain provided but no private key")); } tracing::warn!("No TLS ceritifcate provided - generating self-signed"); - Ok(attested_tls_proxy::self_signed::generate_self_signed_cert( - ip, - )?) + todo!() + // Ok(attested_tls_proxy::self_signed::generate_self_signed_cert( + // ip, + // )?) } } diff --git a/src/self_signed.rs b/src/self_signed.rs deleted file mode 100644 index e41e826..0000000 --- a/src/self_signed.rs +++ /dev/null @@ -1,323 +0,0 @@ -use std::{net::IpAddr, sync::Arc}; -use tokio_rustls::rustls::{ - self, - crypto::CryptoProvider, - pki_types::{self, CertificateDer, PrivatePkcs8KeyDer}, -}; -use x509_parser::prelude::{FromDer, X509Certificate}; - -use crate::attested_tls::{AttestedTlsError, TlsCertAndKey}; - -/// Generate a self signed certifcate -pub fn generate_self_signed_cert(ip_address: IpAddr) -> Result { - let keypair = rcgen::KeyPair::generate()?; - let mut params = rcgen::CertificateParams::default(); - params - .subject_alt_names - .push(rcgen::SanType::IpAddress(ip_address)); - - let cert = params.self_signed(&keypair)?; - Ok(TlsCertAndKey { - cert_chain: vec![cert.der().clone()], - key: PrivatePkcs8KeyDer::from(keypair.serialize_der()).into(), - }) -} - -/// Client TLS configuration which accepts self-signed remote certificates -pub fn client_tls_config_allow_self_signed() -> Result { - Ok(rustls::ClientConfig::builder() - .dangerous() - .with_custom_certificate_verifier(SkipServerVerification::new()?) - .with_no_client_auth()) -} - -/// Used to allow verification of self-signed certificates -#[derive(Debug, Clone)] -pub struct SkipServerVerification { - supported_algs: rustls::crypto::WebPkiSupportedAlgorithms, -} - -impl SkipServerVerification { - pub fn new() -> Result, AttestedTlsError> { - Ok(Arc::new(Self { - supported_algs: Arc::new( - CryptoProvider::get_default().ok_or(AttestedTlsError::NoCryptoProvider)?, - ) - .clone() - .signature_verification_algorithms, - })) - } -} - -impl rustls::client::danger::ServerCertVerifier for SkipServerVerification { - fn verify_server_cert( - &self, - end_entity: &CertificateDer<'_>, - _intermediates: &[CertificateDer<'_>], - _server_name: &pki_types::ServerName<'_>, - _ocsp_response: &[u8], - _now: pki_types::UnixTime, - ) -> Result { - // Parse the certificate - let (_, cert) = X509Certificate::from_der(end_entity).map_err(|_| { - rustls::Error::InvalidCertificate(rustls::CertificateError::BadEncoding) - })?; - - // Verify signature - cert.verify_signature(None).map_err(|_| { - rustls::Error::InvalidCertificate(rustls::CertificateError::BadSignature) - })?; - - Ok(rustls::client::danger::ServerCertVerified::assertion()) - } - - fn verify_tls12_signature( - &self, - message: &[u8], - cert: &CertificateDer<'_>, - dss: &rustls::DigitallySignedStruct, - ) -> Result { - let provider = rustls::crypto::CryptoProvider::get_default() - .ok_or_else(|| rustls::Error::General("No crypto provider installed".into()))?; - - rustls::crypto::verify_tls12_signature( - message, - cert, - dss, - &provider.signature_verification_algorithms, - )?; - - Ok(rustls::client::danger::HandshakeSignatureValid::assertion()) - } - - fn verify_tls13_signature( - &self, - message: &[u8], - cert: &CertificateDer<'_>, - dss: &rustls::DigitallySignedStruct, - ) -> Result { - let provider = rustls::crypto::CryptoProvider::get_default() - .ok_or_else(|| rustls::Error::General("No crypto provider installed".into()))?; - - rustls::crypto::verify_tls13_signature( - message, - cert, - dss, - &provider.signature_verification_algorithms, - )?; - - Ok(rustls::client::danger::HandshakeSignatureValid::assertion()) - } - - fn supported_verify_schemes(&self) -> Vec { - self.supported_algs.supported_schemes() - } -} - -/// Used to allow verification of self-signed certificates during client authentication -#[derive(Debug)] -pub struct SkipClientVerification { - supported_algs: rustls::crypto::WebPkiSupportedAlgorithms, -} - -impl SkipClientVerification { - pub fn new() -> std::sync::Arc { - std::sync::Arc::new(Self { - supported_algs: Arc::new(CryptoProvider::get_default().unwrap()) - .clone() - .signature_verification_algorithms, - }) - } -} - -impl rustls::server::danger::ClientCertVerifier for SkipClientVerification { - fn verify_client_cert( - &self, - end_entity: &CertificateDer<'_>, - _intermediates: &[CertificateDer], - _now: rustls::pki_types::UnixTime, - ) -> Result { - // Parse the certificate - let (_, cert) = X509Certificate::from_der(end_entity).map_err(|_| { - rustls::Error::InvalidCertificate(rustls::CertificateError::BadEncoding) - })?; - - // Verify signature - cert.verify_signature(None).map_err(|_| { - rustls::Error::InvalidCertificate(rustls::CertificateError::BadSignature) - })?; - Ok(rustls::server::danger::ClientCertVerified::assertion()) - } - - fn root_hint_subjects(&self) -> &[rustls::DistinguishedName] { - &[] - } - - fn verify_tls12_signature( - &self, - message: &[u8], - cert: &CertificateDer<'_>, - dss: &rustls::DigitallySignedStruct, - ) -> Result { - let provider = rustls::crypto::CryptoProvider::get_default() - .ok_or_else(|| rustls::Error::General("No crypto provider installed".into()))?; - - rustls::crypto::verify_tls12_signature( - message, - cert, - dss, - &provider.signature_verification_algorithms, - )?; - - Ok(rustls::client::danger::HandshakeSignatureValid::assertion()) - } - - fn verify_tls13_signature( - &self, - message: &[u8], - cert: &CertificateDer<'_>, - dss: &rustls::DigitallySignedStruct, - ) -> Result { - let provider = rustls::crypto::CryptoProvider::get_default() - .ok_or_else(|| rustls::Error::General("No crypto provider installed".into()))?; - - rustls::crypto::verify_tls13_signature( - message, - cert, - dss, - &provider.signature_verification_algorithms, - )?; - - Ok(rustls::client::danger::HandshakeSignatureValid::assertion()) - } - - fn supported_verify_schemes(&self) -> Vec { - self.supported_algs.supported_schemes() - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::{ - AttestationGenerator, - attestation::{AttestationType, AttestationVerifier}, - attested_tls::{AttestedTlsClient, AttestedTlsServer}, - test_helpers::{generate_certificate_chain, generate_tls_config}, - }; - use tokio::net::TcpListener; - use tokio_rustls::rustls::pki_types::ServerName; - - #[tokio::test] - async fn self_signed_server_attestation() { - let cert_and_key = generate_self_signed_cert("127.0.0.1".parse().unwrap()).unwrap(); - - let server_config = rustls::ServerConfig::builder() - .with_no_client_auth() - .with_single_cert( - cert_and_key.cert_chain.clone().to_vec(), - cert_and_key.key.clone_key(), - ) - .unwrap(); - - let server = AttestedTlsServer::new_with_tls_config( - cert_and_key.cert_chain, - server_config.into(), - AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), - AttestationVerifier::expect_none(), - ) - .unwrap(); - - let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); - let server_addr = listener.local_addr().unwrap(); - - tokio::spawn(async move { - let (tcp_stream, _) = listener.accept().await.unwrap(); - let (_stream, _measurements, _attestation_type) = - server.handle_connection(tcp_stream).await.unwrap(); - }); - - let client_config = client_tls_config_allow_self_signed().unwrap(); - - let client = AttestedTlsClient::new_with_tls_config( - client_config.into(), - AttestationGenerator::with_no_attestation(), - AttestationVerifier::mock(), - None, - ) - .unwrap(); - - let (_stream, _measurements, _attestation_type) = - client.connect_tcp(&server_addr.to_string()).await.unwrap(); - } - - #[tokio::test] - async fn nested_tls_with_self_signed_server_attestation() { - // Outer TLS setup - let (cert_chain, private_key) = generate_certificate_chain("127.0.0.1".parse().unwrap()); - let (outer_server_config, outer_client_config) = - generate_tls_config(cert_chain.clone(), private_key); - - let outer_acceptor = tokio_rustls::TlsAcceptor::from(Arc::new(outer_server_config)); - let outer_connector = tokio_rustls::TlsConnector::from(Arc::new(outer_client_config)); - - // Inner TLS setup - let cert_and_key = generate_self_signed_cert("127.0.0.1".parse().unwrap()).unwrap(); - - let server_config = rustls::ServerConfig::builder() - .with_no_client_auth() - .with_single_cert( - cert_and_key.cert_chain.clone().to_vec(), - cert_and_key.key.clone_key(), - ) - .unwrap(); - - let server = AttestedTlsServer::new_with_tls_config( - cert_and_key.cert_chain, - server_config.into(), - AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), - AttestationVerifier::expect_none(), - ) - .unwrap(); - - let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); - let server_addr = listener.local_addr().unwrap(); - - tokio::spawn(async move { - let (tcp_stream, _) = listener.accept().await.unwrap(); - - // Do outer TLS handshake - let tls_stream = outer_acceptor.accept(tcp_stream).await.unwrap(); - - // Do inner (attested) TLS - let (_stream, _measurements, _attestation_type) = - server.handle_connection(tls_stream).await.unwrap(); - }); - - // Inner TLS config - let client_config = client_tls_config_allow_self_signed().unwrap(); - - let client = AttestedTlsClient::new_with_tls_config( - client_config.into(), - AttestationGenerator::with_no_attestation(), - AttestationVerifier::mock(), - None, - ) - .unwrap(); - - let client_tcp_stream = tokio::net::TcpStream::connect(&server_addr).await.unwrap(); - - // Outer TLS handshake - let server_name = ServerName::try_from(server_addr.ip().to_string()).unwrap(); - let tls_stream = outer_connector - .connect(server_name, client_tcp_stream) - .await - .unwrap(); - - // Inner (attested) TLS - let (_stream, _measurements, _attestation_type) = client - .connect(&server_addr.to_string(), tls_stream) - .await - .unwrap(); - } -} diff --git a/src/test_helpers.rs b/src/test_helpers.rs index 66ddc60..939deb4 100644 --- a/src/test_helpers.rs +++ b/src/test_helpers.rs @@ -160,17 +160,6 @@ async fn get_handler(_headers: http::HeaderMap) -> impl IntoResponse { "No measurements".to_string() } -/// All-zero measurment values used in some tests -pub fn mock_dcap_measurements() -> MultiMeasurements { - MultiMeasurements::Dcap(HashMap::from([ - (DcapMeasurementRegister::MRTD, [0u8; 48]), - (DcapMeasurementRegister::RTMR0, [0u8; 48]), - (DcapMeasurementRegister::RTMR1, [0u8; 48]), - (DcapMeasurementRegister::RTMR2, [0u8; 48]), - (DcapMeasurementRegister::RTMR3, [0u8; 48]), - ])) -} - pub fn init_tracing() { INIT.call_once(|| { let filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")); From 91e35e479fc2e899af9fa3a735bf905d0d8cdb73 Mon Sep 17 00:00:00 2001 From: peg Date: Tue, 17 Mar 2026 12:13:42 +0100 Subject: [PATCH 04/20] Rm no longer needed allow self signed option --- src/main.rs | 35 ++++++----------------------------- 1 file changed, 6 insertions(+), 29 deletions(-) diff --git a/src/main.rs b/src/main.rs index 4704059..1c56f20 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,11 +1,7 @@ use anyhow::{anyhow, ensure}; use attestation::{AttestationType, AttestationVerifier, measurements::MeasurementPolicy}; use clap::{Parser, Subcommand}; -use std::{ - fs::File, - net::{IpAddr, SocketAddr}, - path::PathBuf, -}; +use std::{fs::File, net::SocketAddr, path::PathBuf}; use tokio::io::AsyncWriteExt; use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer}; use tracing::level_filters::LevelFilter; @@ -78,9 +74,6 @@ enum CliCommand { // Address to listen on for health checks #[arg(long)] listen_addr_healthcheck: Option, - /// Enables verification of self-signed TLS certificates - #[arg(long)] - allow_self_signed: bool, }, /// Run a proxy server Server { @@ -118,9 +111,6 @@ enum CliCommand { /// Additional CA certificate to verify against (PEM) Defaults to no additional TLS certs. #[arg(long)] tls_ca_certificate: Option, - /// Enables verification of self-signed TLS certificates - #[arg(long)] - allow_self_signed: bool, /// Filename to write measurements as JSON to #[arg(long)] out_measurements: Option, @@ -158,9 +148,6 @@ enum CliCommand { /// Additional CA certificate to verify against (PEM) Defaults to no additional TLS certs. #[arg(long)] tls_ca_certificate: Option, - /// Enables verification of self-signed TLS certificates - #[arg(long)] - allow_self_signed: bool, }, } @@ -235,7 +222,6 @@ async fn main() -> anyhow::Result<()> { tls_ca_certificate, dev_dummy_dcap, listen_addr_healthcheck, - allow_self_signed, } => { let target_addr = target_addr .strip_prefix("https://") @@ -304,11 +290,8 @@ async fn main() -> anyhow::Result<()> { health_check::server(listen_addr_healthcheck).await?; } - let tls_cert_and_chain = load_tls_cert_and_key_server( - tls_certificate_path, - tls_private_key_path, - listen_addr.ip(), - )?; + let tls_cert_and_chain = + load_tls_cert_and_key_server(tls_certificate_path, tls_private_key_path)?; let local_attestation_generator = AttestationGenerator::new_with_detection(server_attestation_type, dev_dummy_dcap) @@ -333,7 +316,6 @@ async fn main() -> anyhow::Result<()> { CliCommand::GetTlsCert { server, tls_ca_certificate, - allow_self_signed, out_measurements, } => { let remote_tls_cert = match tls_ca_certificate { @@ -392,7 +374,6 @@ async fn main() -> anyhow::Result<()> { target_addr, url_path, tls_ca_certificate, - allow_self_signed, } => { let remote_tls_cert = match tls_ca_certificate { Some(remote_cert_filename) => Some( @@ -429,7 +410,6 @@ async fn main() -> anyhow::Result<()> { fn load_tls_cert_and_key_server( cert_chain: Option, private_key: Option, - ip: IpAddr, ) -> anyhow::Result { if let Some(private_key) = private_key { load_tls_cert_and_key( @@ -438,13 +418,10 @@ fn load_tls_cert_and_key_server( ) } else { if cert_chain.is_some() { - return Err(anyhow!("Certificate chain provided but no private key")); + Err(anyhow!("Certificate chain provided but no private key")) + } else { + Err(anyhow!("No private key provided")) } - tracing::warn!("No TLS ceritifcate provided - generating self-signed"); - todo!() - // Ok(attested_tls_proxy::self_signed::generate_self_signed_cert( - // ip, - // )?) } } From ad6ff187238a4e76375242a1fa9405f1fc05076c Mon Sep 17 00:00:00 2001 From: peg Date: Tue, 17 Mar 2026 12:15:10 +0100 Subject: [PATCH 05/20] Bump dcap-qvl --- Cargo.lock | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index e01c501..9ed81b1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1676,7 +1676,7 @@ dependencies = [ "libc", "option-ext", "redox_users", - "windows-sys 0.59.0", + "windows-sys 0.61.2", ] [[package]] @@ -1900,7 +1900,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" dependencies = [ "libc", - "windows-sys 0.59.0", + "windows-sys 0.61.2", ] [[package]] @@ -3045,7 +3045,7 @@ version = "0.50.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5" dependencies = [ - "windows-sys 0.59.0", + "windows-sys 0.61.2", ] [[package]] @@ -3684,7 +3684,7 @@ dependencies = [ "once_cell", "socket2 0.6.1", "tracing", - "windows-sys 0.59.0", + "windows-sys 0.60.2", ] [[package]] @@ -4124,7 +4124,7 @@ dependencies = [ "errno", "libc", "linux-raw-sys", - "windows-sys 0.59.0", + "windows-sys 0.61.2", ] [[package]] @@ -4743,7 +4743,7 @@ dependencies = [ "getrandom 0.3.4", "once_cell", "rustix", - "windows-sys 0.59.0", + "windows-sys 0.61.2", ] [[package]] From 3e6140ed52becf1744a50a6307a87464786f86bf Mon Sep 17 00:00:00 2001 From: peg Date: Tue, 17 Mar 2026 12:26:14 +0100 Subject: [PATCH 06/20] Fixes attested-tls crate branch --- Cargo.lock | 38 +++----------------------------------- attested-tls/Cargo.toml | 4 ++-- src/main.rs | 10 ++++------ src/test_helpers.rs | 26 +------------------------- 4 files changed, 10 insertions(+), 68 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 9ed81b1..a169090 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -558,38 +558,6 @@ version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" -[[package]] -name = "attestation" -version = "0.0.1" -source = "git+https://github.com/flashbots/attested-tls?branch=peg%2Fadd-attestation-crate#4ebc03703510e65fd1317736b8887fc388860481" -dependencies = [ - "anyhow", - "az-tdx-vtpm", - "base64 0.22.1", - "configfs-tsm", - "dcap-qvl 0.3.12 (git+https://github.com/flashbots/dcap-qvl.git?branch=peg%2Fazure-outdated-tcp-override)", - "hex", - "http", - "num-bigint", - "once_cell", - "openssl", - "parity-scale-codec", - "pem-rfc7468", - "rand_core 0.6.4", - "reqwest", - "rustls-webpki", - "serde", - "serde_json", - "tdx-quote", - "thiserror 2.0.17", - "time", - "tokio", - "tokio-rustls", - "tracing", - "tss-esapi", - "x509-parser 0.18.1", -] - [[package]] name = "attestation" version = "0.0.1" @@ -644,7 +612,7 @@ version = "0.0.1" dependencies = [ "alloy-rpc-client", "alloy-transport-http", - "attestation 0.0.1 (git+https://github.com/flashbots/attested-tls?branch=peg%2Fadd-attestation-crate)", + "attestation", "bytes", "futures-util", "http", @@ -673,7 +641,7 @@ version = "0.0.1" source = "git+https://github.com/flashbots/attested-tls?branch=peg%2Fattested-tls-crate#5c109dba74d4f9de58b4b846f480599752dfb1f9" dependencies = [ "anyhow", - "attestation 0.0.1 (git+https://github.com/flashbots/attested-tls?branch=peg%2Fattested-tls-crate)", + "attestation", "ra-tls", "rcgen 0.14.7", "rustls", @@ -691,7 +659,7 @@ name = "attested-tls-proxy" version = "1.1.1" dependencies = [ "anyhow", - "attestation 0.0.1 (git+https://github.com/flashbots/attested-tls?branch=peg%2Fattested-tls-crate)", + "attestation", "attested-tls 0.0.1 (git+https://github.com/flashbots/attested-tls?branch=peg%2Fattested-tls-crate)", "axum", "bytes", diff --git a/attested-tls/Cargo.toml b/attested-tls/Cargo.toml index 38a810e..d40e7b6 100644 --- a/attested-tls/Cargo.toml +++ b/attested-tls/Cargo.toml @@ -18,7 +18,7 @@ http = "1.3.1" serde_json = "1.0.145" tracing = "0.1.41" parity-scale-codec = "3.7.5" -attestation = { git = "https://github.com/flashbots/attested-tls", branch = "peg/add-attestation-crate" } +attestation = { git = "https://github.com/flashbots/attested-tls", branch = "peg/attested-tls-crate" } # Used for websocket support tokio-tungstenite = { version = "0.28.0", optional = true } @@ -40,7 +40,7 @@ rcgen = { version = "0.14.5", optional = true } [dev-dependencies] rcgen = "0.14.5" tempfile = "3.23.0" -attestation = { git = "https://github.com/flashbots/attested-tls", branch = "peg/add-attestation-crate", features = ["mock"] } +attestation = { git = "https://github.com/flashbots/attested-tls", branch = "peg/attested-tls-crate", features = ["mock"] } [features] default = ["ws", "rpc"] diff --git a/src/main.rs b/src/main.rs index 1c56f20..c14c414 100644 --- a/src/main.rs +++ b/src/main.rs @@ -316,7 +316,7 @@ async fn main() -> anyhow::Result<()> { CliCommand::GetTlsCert { server, tls_ca_certificate, - out_measurements, + out_measurements: _, // TODO } => { let remote_tls_cert = match tls_ca_certificate { Some(remote_cert_filename) => Some( @@ -416,12 +416,10 @@ fn load_tls_cert_and_key_server( cert_chain.ok_or(anyhow!("Private key given but no certificate chain"))?, private_key, ) + } else if cert_chain.is_some() { + Err(anyhow!("Certificate chain provided but no private key")) } else { - if cert_chain.is_some() { - Err(anyhow!("Certificate chain provided but no private key")) - } else { - Err(anyhow!("No private key provided")) - } + Err(anyhow!("No private key provided")) } } diff --git a/src/test_helpers.rs b/src/test_helpers.rs index 939deb4..431c5f8 100644 --- a/src/test_helpers.rs +++ b/src/test_helpers.rs @@ -1,8 +1,7 @@ //! Helper functions used in tests use axum::response::IntoResponse; use std::{ - collections::HashMap, - net::{IpAddr, SocketAddr}, + net::SocketAddr, sync::{Arc, Once}, }; use tokio::net::TcpListener; @@ -15,29 +14,6 @@ use tracing_subscriber::{EnvFilter, fmt}; static INIT: Once = Once::new(); -use attestation::measurements::{DcapMeasurementRegister, MultiMeasurements}; - -/// Helper to generate a self-signed certificate for testing -pub fn generate_certificate_chain( - ip: IpAddr, -) -> (Vec>, PrivateKeyDer<'static>) { - let mut params = rcgen::CertificateParams::new(vec![]).unwrap(); - params.subject_alt_names.push(rcgen::SanType::IpAddress(ip)); - params - .subject_alt_names - .push(rcgen::SanType::DnsName(ip.to_string().try_into().unwrap())); - params - .distinguished_name - .push(rcgen::DnType::CommonName, ip.to_string()); - - let keypair = rcgen::KeyPair::generate().unwrap(); - let cert = params.self_signed(&keypair).unwrap(); - - let certs = vec![CertificateDer::from(cert)]; - let key = PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(keypair.serialize_der())); - (certs, key) -} - /// Helper to generate a self-signed certificate for testing with a DNS subject name pub fn generate_certificate_chain_for_host( host: &str, From 9dfb3f2b9bb6628400fb66b11a796ef5c0fa08b6 Mon Sep 17 00:00:00 2001 From: peg Date: Tue, 17 Mar 2026 12:35:17 +0100 Subject: [PATCH 07/20] Rm unwraps --- src/lib.rs | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index e711527..c8a7841 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -686,7 +686,7 @@ impl ProxyClient { ) -> Result<(HttpSender, HttpConnection), ProxyError> { let outbound_stream = tokio::net::TcpStream::connect(target).await?; - let domain = server_name_from_host(target).unwrap(); + let domain = server_name_from_host(target)?; let tls_stream = nesting_tls_connector .connect(domain, outbound_stream) .await?; @@ -785,6 +785,12 @@ pub enum ProxyError { IntConversion(#[from] TryFromIntError), #[error("Bad host name: {0}")] BadDnsName(#[from] tokio_rustls::rustls::pki_types::InvalidDnsNameError), + #[error("Invalid certificate encoding")] + InvalidCertificateEncoding, + #[error("Missing common name in certificate subject")] + MissingCertificateName, + #[error("Certificate common name is not valid UTF-8")] + InvalidCertificateName, #[error("HTTP: {0}")] Hyper(#[from] hyper::Error), #[error("Attested TLS: {0}")] @@ -809,17 +815,15 @@ impl From> for ProxyError { fn hostname_from_cert(cert: &CertificateDer<'static>) -> Result { let cert = x509_parser::parse_x509_certificate(cert.as_ref()) .map(|(_, parsed)| parsed) - .unwrap(); + .map_err(|_| ProxyError::InvalidCertificateEncoding)?; Ok(cert .subject() .iter_common_name() .next() - .unwrap() - // .ok_or_else(|| Self::bad_encoding("Missing common name"))? + .ok_or(ProxyError::MissingCertificateName)? .as_str() - // .map_err(|err| Self::bad_encoding(format!("Invalid common name: {err}"))) - .unwrap() + .map_err(|_| ProxyError::InvalidCertificateName)? .to_string()) } From 4650de63620f288675bb75dafbfc7c0f2299583b Mon Sep 17 00:00:00 2001 From: peg Date: Thu, 19 Mar 2026 10:19:03 +0100 Subject: [PATCH 08/20] Fix ALPN --- src/lib.rs | 66 +++++++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 58 insertions(+), 8 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index c8a7841..319d07e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -97,9 +97,8 @@ pub async fn get_inner_tls_cert( pub async fn get_inner_tls_cert_with_config( server_name: String, attestation_verifier: AttestationVerifier, - mut outer_client_config: ClientConfig, + outer_client_config: ClientConfig, ) -> Result>, ProxyError> { - ensure_proxy_alpn_protocols(&mut outer_client_config.alpn_protocols); let outbound_stream = tokio::net::TcpStream::connect(&server_name).await?; let domain = server_name_from_host(&server_name)?; @@ -205,19 +204,18 @@ impl ProxyServer { /// Start with preconfigured TLS and require client auth on both nested sessions pub async fn new_with_tls_config_and_client_auth( cert_chain: Vec>, - mut outer_server_config: ServerConfig, + outer_server_config: ServerConfig, local: impl ToSocketAddrs, target: String, attestation_generator: AttestationGenerator, attestation_verifier: AttestationVerifier, client_auth: bool, ) -> Result { - ensure_proxy_alpn_protocols(&mut outer_server_config.alpn_protocols); let server_name = certificate_identity_from_chain(&cert_chain)?; let inner_cert_resolver = build_attested_cert_resolver(attestation_generator, &cert_chain).await?; - let inner_server_config = if client_auth { + let mut inner_server_config = if client_auth { let attested_cert_verifier = AttestedCertificateVerifier::new(None, attestation_verifier)?; ServerConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) @@ -230,6 +228,8 @@ impl ProxyServer { .with_cert_resolver(Arc::new(inner_cert_resolver)) }; + ensure_proxy_alpn_protocols(&mut inner_server_config.alpn_protocols); + let nesting_tls_acceptor = NestingTlsAcceptor::new(Arc::new(outer_server_config), Arc::new(inner_server_config)); let listener = TcpListener::bind(local).await?; @@ -440,14 +440,13 @@ impl ProxyClient { /// Create a new proxy client with given TLS configuration pub async fn new_with_tls_config( - mut outer_client_config: ClientConfig, + outer_client_config: ClientConfig, address: impl ToSocketAddrs, target_name: String, attestation_generator: AttestationGenerator, attestation_verifier: AttestationVerifier, cert_chain: Option>>, ) -> Result { - ensure_proxy_alpn_protocols(&mut outer_client_config.alpn_protocols); let outer_has_client_auth = outer_client_config.client_auth_cert_resolver.has_certs(); let inner_has_client_auth = cert_chain.is_some(); @@ -457,7 +456,7 @@ impl ProxyClient { let attested_cert_verifier = AttestedCertificateVerifier::new(None, attestation_verifier)?; - let inner_client_config = if let Some(cert_chain) = cert_chain.as_ref() { + let mut inner_client_config = if let Some(cert_chain) = cert_chain.as_ref() { let inner_cert_resolver = build_attested_cert_resolver(attestation_generator, cert_chain).await?; ClientConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) @@ -470,6 +469,7 @@ impl ProxyClient { .with_custom_certificate_verifier(Arc::new(attested_cert_verifier)) .with_no_client_auth() }; + ensure_proxy_alpn_protocols(&mut inner_client_config.alpn_protocols); let nesting_tls_connector = NestingTlsConnector::new(Arc::new(outer_client_config), Arc::new(inner_client_config)); @@ -905,6 +905,56 @@ mod tests { assert_eq!(protocols, vec![ALPN_HTTP11.to_vec(), ALPN_H2.to_vec()]); } + #[tokio::test(flavor = "multi_thread")] + async fn http_proxy_negotiates_http2_by_default() { + let target_addr = example_http_service().await; + + let (cert_chain, private_key) = generate_certificate_chain_for_host("localhost"); + let (server_config, outer_client_config) = + generate_tls_config(cert_chain.clone(), private_key); + + let proxy_server = ProxyServer::new_with_tls_config( + cert_chain, + server_config, + "127.0.0.1:0", + target_addr.to_string(), + AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), + AttestationVerifier::expect_none(), + ) + .await + .unwrap(); + + let proxy_addr = proxy_server.local_addr().unwrap(); + + tokio::spawn(async move { + proxy_server.accept().await.unwrap(); + }); + + let attested_cert_verifier = + AttestedCertificateVerifier::new(None, AttestationVerifier::mock()).unwrap(); + let mut inner_client_config = + ClientConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) + .dangerous() + .with_custom_certificate_verifier(Arc::new(attested_cert_verifier)) + .with_no_client_auth(); + ensure_proxy_alpn_protocols(&mut inner_client_config.alpn_protocols); + + let nesting_tls_connector = NestingTlsConnector::new( + Arc::new(outer_client_config), + Arc::new(inner_client_config), + ); + + let (sender, conn) = ProxyClient::setup_connection( + &nesting_tls_connector, + &format!("localhost:{}", proxy_addr.port()), + ) + .await + .unwrap(); + + assert!(matches!(sender, HttpSender::Http2(_))); + assert!(matches!(conn, HttpConnection::Http2 { .. })); + } + #[tokio::test(flavor = "multi_thread")] async fn http_proxy_default_constructors_work() { let target_addr = example_http_service().await; From 4fb29e2e4fad5a864f71cbe31a40bed737cd56dd Mon Sep 17 00:00:00 2001 From: peg Date: Thu, 19 Mar 2026 10:21:50 +0100 Subject: [PATCH 09/20] Fmt --- src/lib.rs | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 319d07e..cf4c327 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -939,10 +939,8 @@ mod tests { .with_no_client_auth(); ensure_proxy_alpn_protocols(&mut inner_client_config.alpn_protocols); - let nesting_tls_connector = NestingTlsConnector::new( - Arc::new(outer_client_config), - Arc::new(inner_client_config), - ); + let nesting_tls_connector = + NestingTlsConnector::new(Arc::new(outer_client_config), Arc::new(inner_client_config)); let (sender, conn) = ProxyClient::setup_connection( &nesting_tls_connector, From f6817210c96c15b5edbcfeab5f971b364825b83d Mon Sep 17 00:00:00 2001 From: peg Date: Thu, 19 Mar 2026 11:57:37 +0100 Subject: [PATCH 10/20] Make both nested and inner only listeners --- src/attested_get.rs | 6 +- src/file_server.rs | 16 +- src/lib.rs | 593 ++++++++++++++++++++++++++++++++------------ src/main.rs | 54 ++-- 4 files changed, 474 insertions(+), 195 deletions(-) diff --git a/src/attested_get.rs b/src/attested_get.rs index 3b4575f..97b346d 100644 --- a/src/attested_get.rs +++ b/src/attested_get.rs @@ -78,12 +78,14 @@ mod tests { // Setup a proxy server targetting the static file server let proxy_server = ProxyServer::new_with_tls_config( - cert_chain, - server_config, + Some(server_config), + Some("127.0.0.1:0"), "127.0.0.1:0", target_addr.to_string(), AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), AttestationVerifier::expect_none(), + false, + Some("localhost".to_string()), ) .await .unwrap(); diff --git a/src/file_server.rs b/src/file_server.rs index 424008e..8390c9f 100644 --- a/src/file_server.rs +++ b/src/file_server.rs @@ -7,8 +7,9 @@ use tower_http::services::ServeDir; /// Setup a static file server serving the given directory, and a proxy server targetting it pub async fn attested_file_server( path_to_serve: PathBuf, - cert_and_key: TlsCertAndKey, - listen_addr: impl ToSocketAddrs, + outer_cert_and_key: Option, + outer_listen_addr: impl ToSocketAddrs, + inner_listen_addr: impl ToSocketAddrs, attestation_generator: AttestationGenerator, attestation_verifier: AttestationVerifier, client_auth: bool, @@ -16,8 +17,9 @@ pub async fn attested_file_server( let target_addr = static_file_server(path_to_serve).await?; let server = ProxyServer::new( - cert_and_key, - listen_addr, + outer_cert_and_key, + Some(outer_listen_addr), + inner_listen_addr, target_addr.to_string(), attestation_generator, attestation_verifier, @@ -99,12 +101,14 @@ mod tests { // Setup a proxy server targetting the static file server let proxy_server = ProxyServer::new_with_tls_config( - cert_chain, - server_config, + Some(server_config), + Some("127.0.0.1:0"), "127.0.0.1:0", target_addr.to_string(), AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), AttestationVerifier::expect_none(), + false, + Some("localhost".to_string()), ) .await .unwrap(); diff --git a/src/lib.rs b/src/lib.rs index cf4c327..a36d5b3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -26,6 +26,7 @@ use thiserror::Error; use tokio::io::{self, AsyncWriteExt}; use tokio::net::{TcpListener, TcpStream, ToSocketAddrs}; use tokio::sync::{mpsc, oneshot}; +use tokio_rustls::TlsAcceptor; use tokio_rustls::rustls::server::{VerifierBuilderError, WebPkiClientVerifier}; use tokio_rustls::rustls::{ self, ClientConfig, RootCertStore, ServerConfig, @@ -46,6 +47,7 @@ const SERVER_RECONNECT_MAX_BACKOFF_SECS: u64 = 120; const KEEP_ALIVE_INTERVAL: u64 = 30; const KEEP_ALIVE_TIMEOUT: u64 = 10; +const DEFAULT_INNER_CERTIFICATE_NAME: &str = "localhost"; type RequestWithResponseSender = ( http::Request, @@ -132,111 +134,123 @@ pub async fn get_inner_tls_cert_with_config( /// A TLS over TCP server which provides an attestation before forwarding traffic to a given target address pub struct ProxyServer { - nesting_tls_acceptor: NestingTlsAcceptor, - /// The underlying TCP listener - listener: Arc, + outer_listener: Option>, + outer_tls_acceptor: Option, + inner_listener: Arc, + inner_tls_acceptor: TlsAcceptor, /// The address/hostname of the target service we are proxying to target: String, } impl ProxyServer { - pub async fn new( - cert_and_key: TlsCertAndKey, - local: impl ToSocketAddrs, + /// Start with dual listeners. The outer nested-TLS listener is optional. + pub async fn new( + outer_cert_and_key: Option, + outer_local: Option, + inner_local: impl ToSocketAddrs, target: String, attestation_generator: AttestationGenerator, attestation_verifier: AttestationVerifier, client_auth: bool, - ) -> Result { - let outer_server_config = if client_auth { - let root_store = - RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); - let verifier = WebPkiClientVerifier::builder(Arc::new(root_store)).build()?; - - ServerConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) - .with_client_cert_verifier(verifier) - .with_single_cert( - cert_and_key.cert_chain.clone(), - cert_and_key.key.clone_key(), - )? - } else { - ServerConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) - .with_no_client_auth() - .with_single_cert( - cert_and_key.cert_chain.clone(), - cert_and_key.key.clone_key(), - )? + ) -> Result + where + O: ToSocketAddrs, + { + if outer_cert_and_key.is_some() && outer_local.is_none() { + return Err(ProxyError::OuterTlsWithoutOuterListener); + } + + let outer_certificate_name = outer_cert_and_key + .as_ref() + .map(|cert_and_key| certificate_identity_from_chain(&cert_and_key.cert_chain)) + .transpose()?; + let outer_server_config = match outer_cert_and_key { + Some(cert_and_key) => { + let config = if client_auth { + let root_store = + RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); + let verifier = WebPkiClientVerifier::builder(Arc::new(root_store)).build()?; + + ServerConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) + .with_client_cert_verifier(verifier) + .with_single_cert( + cert_and_key.cert_chain.clone(), + cert_and_key.key.clone_key(), + )? + } else { + ServerConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) + .with_no_client_auth() + .with_single_cert( + cert_and_key.cert_chain.clone(), + cert_and_key.key.clone_key(), + )? + }; + Some(config) + } + None => None, }; - Self::new_with_tls_config_and_client_auth( - cert_and_key.cert_chain, + Self::new_with_tls_config( outer_server_config, - local, + outer_local, + inner_local, target, attestation_generator, attestation_verifier, client_auth, + outer_certificate_name, ) .await } /// Start with preconfigured TLS - pub async fn new_with_tls_config( - cert_chain: Vec>, - outer_server_config: ServerConfig, - local: impl ToSocketAddrs, - target: String, - attestation_generator: AttestationGenerator, - attestation_verifier: AttestationVerifier, - ) -> Result { - Self::new_with_tls_config_and_client_auth( - cert_chain, - outer_server_config, - local, - target, - attestation_generator, - attestation_verifier, - false, - ) - .await - } - - /// Start with preconfigured TLS and require client auth on both nested sessions - pub async fn new_with_tls_config_and_client_auth( - cert_chain: Vec>, - outer_server_config: ServerConfig, - local: impl ToSocketAddrs, + pub async fn new_with_tls_config( + outer_server_config: Option, + outer_local: Option, + inner_local: impl ToSocketAddrs, target: String, attestation_generator: AttestationGenerator, attestation_verifier: AttestationVerifier, client_auth: bool, - ) -> Result { - let server_name = certificate_identity_from_chain(&cert_chain)?; - let inner_cert_resolver = - build_attested_cert_resolver(attestation_generator, &cert_chain).await?; - - let mut inner_server_config = if client_auth { - let attested_cert_verifier = - AttestedCertificateVerifier::new(None, attestation_verifier)?; - ServerConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) - .with_client_cert_verifier(Arc::new(attested_cert_verifier)) - .with_cert_resolver(Arc::new(inner_cert_resolver)) - } else { - let _ = server_name; - ServerConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) - .with_no_client_auth() - .with_cert_resolver(Arc::new(inner_cert_resolver)) - }; - - ensure_proxy_alpn_protocols(&mut inner_server_config.alpn_protocols); + certificate_name: Option, + ) -> Result + where + O: ToSocketAddrs, + { + if outer_server_config.is_some() && outer_local.is_none() { + return Err(ProxyError::OuterTlsWithoutOuterListener); + } - let nesting_tls_acceptor = - NestingTlsAcceptor::new(Arc::new(outer_server_config), Arc::new(inner_server_config)); - let listener = TcpListener::bind(local).await?; + let inner_server_config = Arc::new( + build_inner_server_config( + attestation_generator, + attestation_verifier, + client_auth, + certificate_name, + ) + .await?, + ); + let inner_listener = Arc::new(TcpListener::bind(inner_local).await?); + let inner_tls_acceptor = TlsAcceptor::from(inner_server_config.clone()); + + let (outer_listener, outer_tls_acceptor) = match (outer_server_config, outer_local) { + (Some(outer_server_config), Some(outer_local)) => { + let outer_listener = Arc::new(TcpListener::bind(outer_local).await?); + let acceptor = NestingTlsAcceptor::new( + Arc::new(outer_server_config), + inner_server_config.clone(), + ); + (Some(outer_listener), Some(acceptor)) + } + (Some(_), None) => return Err(ProxyError::OuterTlsWithoutOuterListener), + (None, _) => (None, None), + }; Ok(Self { - nesting_tls_acceptor, - listener: listener.into(), + outer_listener, + outer_tls_acceptor, + inner_listener, + inner_tls_acceptor, target, }) } @@ -246,33 +260,92 @@ impl ProxyServer { /// Returns the handle for the task handling the connection pub async fn accept(&self) -> Result, ProxyError> { let target = self.target.clone(); - let (inbound, client_addr) = self.listener.accept().await?; - let nesting_tls_acceptor = self.nesting_tls_acceptor.clone(); + let outer_listener = self.outer_listener.clone(); + let outer_tls_acceptor = self.outer_tls_acceptor.clone(); + let inner_listener = self.inner_listener.clone(); + let inner_tls_acceptor = self.inner_tls_acceptor.clone(); + + let join_handle = match (outer_listener, outer_tls_acceptor) { + (Some(outer_listener), Some(outer_tls_acceptor)) => { + let ((inbound, client_addr), use_outer) = tokio::select! { + accepted = outer_listener.accept() => (accepted?, true), + accepted = inner_listener.accept() => (accepted?, false), + }; - let join_handle = tokio::spawn(async move { - match nesting_tls_acceptor.accept(inbound).await { - Ok(tls_stream) => { - if let Err(err) = Self::handle_connection(tls_stream, target, client_addr).await - { - warn!("Failed to handle connection: {err}"); + tokio::spawn(async move { + if use_outer { + match outer_tls_acceptor.accept(inbound).await { + Ok(tls_stream) => { + if let Err(err) = + Self::handle_outer_connection(tls_stream, target, client_addr) + .await + { + warn!("Failed to handle outer connection: {err}"); + } + } + Err(err) => { + warn!("Outer attestation exchange failed: {err}"); + } + } + } else { + match inner_tls_acceptor.accept(inbound).await { + Ok(tls_stream) => { + if let Err(err) = + Self::handle_inner_connection(tls_stream, target, client_addr) + .await + { + warn!("Failed to handle inner connection: {err}"); + } + } + Err(err) => { + warn!("Inner attestation exchange failed: {err}"); + } + } } - } - Err(err) => { - warn!("Attestation exchange failed: {err}"); - } + }) } - }); + _ => { + let (inbound, client_addr) = inner_listener.accept().await?; + tokio::spawn(async move { + match inner_tls_acceptor.accept(inbound).await { + Ok(tls_stream) => { + if let Err(err) = + Self::handle_inner_connection(tls_stream, target, client_addr).await + { + warn!("Failed to handle inner connection: {err}"); + } + } + Err(err) => { + warn!("Inner attestation exchange failed: {err}"); + } + } + }) + } + }; Ok(join_handle) } /// Helper to get the socket address of the underlying TCP listener pub fn local_addr(&self) -> std::io::Result { - self.listener.local_addr() + match &self.outer_listener { + Some(listener) => listener.local_addr(), + None => self.inner_listener.local_addr(), + } } - /// Handle an incoming connection from a proxy-client - async fn handle_connection( + pub fn outer_local_addr(&self) -> std::io::Result> { + self.outer_listener + .as_ref() + .map(|listener| listener.local_addr()) + .transpose() + } + + pub fn inner_local_addr(&self) -> std::io::Result { + self.inner_listener.local_addr() + } + + async fn handle_outer_connection( tls_stream: NestingTlsStream, target: String, client_addr: SocketAddr, @@ -280,7 +353,29 @@ impl ProxyServer { debug!("[proxy-server] accepted connection"); let http_version = HttpVersion::from_negotiated_protocol_server(&tls_stream); + Self::serve_tls_stream(tls_stream, http_version, target, client_addr).await + } + async fn handle_inner_connection( + tls_stream: tokio_rustls::server::TlsStream, + target: String, + client_addr: SocketAddr, + ) -> Result<(), ProxyError> { + debug!("[proxy-server] accepted inner-only connection"); + + let http_version = HttpVersion::from_negotiated_protocol_server(&tls_stream); + Self::serve_tls_stream(tls_stream, http_version, target, client_addr).await + } + + async fn serve_tls_stream( + tls_stream: IO, + http_version: HttpVersion, + target: String, + client_addr: SocketAddr, + ) -> Result<(), ProxyError> + where + IO: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static, + { // Setup a request handler let service = service_fn(move |mut req| { debug!("[proxy-server] Handling request {req:?}"); @@ -457,8 +552,11 @@ impl ProxyClient { let attested_cert_verifier = AttestedCertificateVerifier::new(None, attestation_verifier)?; let mut inner_client_config = if let Some(cert_chain) = cert_chain.as_ref() { - let inner_cert_resolver = - build_attested_cert_resolver(attestation_generator, cert_chain).await?; + let inner_cert_resolver = build_attested_cert_resolver( + attestation_generator, + Some(certificate_identity_from_chain(cert_chain)?), + ) + .await?; ClientConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) .dangerous() .with_custom_certificate_verifier(Arc::new(attested_cert_verifier)) @@ -803,6 +901,8 @@ pub enum ProxyError { MpscSend, #[error("Client auth must be configured on both the inner and outer TLS sessions")] ClientAuthMisconfigured, + #[error("Outer TLS configuration requires an outer listener address")] + OuterTlsWithoutOuterListener, } impl From> for ProxyError { @@ -835,13 +935,40 @@ fn certificate_identity_from_chain( async fn build_attested_cert_resolver( attestation_generator: AttestationGenerator, - cert_chain: &[CertificateDer<'static>], + certificate_name: Option, ) -> Result { - let certificate_name = certificate_identity_from_chain(cert_chain)?; - Ok( - AttestedCertificateResolver::new(attestation_generator, None, certificate_name, vec![]) - .await?, + Ok(AttestedCertificateResolver::new( + attestation_generator, + None, + certificate_name.unwrap_or_else(|| DEFAULT_INNER_CERTIFICATE_NAME.to_string()), + vec![], ) + .await?) +} + +async fn build_inner_server_config( + attestation_generator: AttestationGenerator, + attestation_verifier: AttestationVerifier, + client_auth: bool, + certificate_name: Option, +) -> Result { + let inner_cert_resolver = + build_attested_cert_resolver(attestation_generator, certificate_name).await?; + + let mut inner_server_config = if client_auth { + let attested_cert_verifier = AttestedCertificateVerifier::new(None, attestation_verifier)?; + ServerConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) + .with_client_cert_verifier(Arc::new(attested_cert_verifier)) + .with_cert_resolver(Arc::new(inner_cert_resolver)) + } else { + ServerConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) + .with_no_client_auth() + .with_cert_resolver(Arc::new(inner_cert_resolver)) + }; + + ensure_proxy_alpn_protocols(&mut inner_server_config.alpn_protocols); + + Ok(inner_server_config) } /// If no port was provided, default to 443 @@ -882,6 +1009,7 @@ where #[cfg(test)] mod tests { use attestation::{AttestationType, measurements::MeasurementPolicy}; + use tokio_rustls::TlsConnector; use super::*; use test_helpers::{ @@ -906,25 +1034,94 @@ mod tests { } #[tokio::test(flavor = "multi_thread")] - async fn http_proxy_negotiates_http2_by_default() { + async fn dual_listener_server_reports_expected_addresses() { let target_addr = example_http_service().await; let (cert_chain, private_key) = generate_certificate_chain_for_host("localhost"); - let (server_config, outer_client_config) = - generate_tls_config(cert_chain.clone(), private_key); + let tls_cert_and_key = TlsCertAndKey { + cert_chain, + key: private_key, + }; - let proxy_server = ProxyServer::new_with_tls_config( + let dual_listener_server = ProxyServer::new( + Some(tls_cert_and_key), + Some("127.0.0.1:0"), + "127.0.0.1:0", + target_addr.to_string(), + AttestationGenerator::with_no_attestation(), + AttestationVerifier::expect_none(), + false, + ) + .await + .unwrap(); + + let outer_addr = dual_listener_server.outer_local_addr().unwrap().unwrap(); + let inner_addr = dual_listener_server.inner_local_addr().unwrap(); + assert_eq!(dual_listener_server.local_addr().unwrap(), outer_addr); + assert_ne!(outer_addr, inner_addr); + + let inner_only_server = ProxyServer::new( + None, + None::<&str>, + "127.0.0.1:0", + target_addr.to_string(), + AttestationGenerator::with_no_attestation(), + AttestationVerifier::expect_none(), + false, + ) + .await + .unwrap(); + + let inner_only_addr = inner_only_server.inner_local_addr().unwrap(); + assert!(inner_only_server.outer_local_addr().unwrap().is_none()); + assert_eq!(inner_only_server.local_addr().unwrap(), inner_only_addr); + } + + #[tokio::test(flavor = "multi_thread")] + async fn outer_tls_requires_outer_listener_address() { + let target_addr = example_http_service().await; + + let (cert_chain, private_key) = generate_certificate_chain_for_host("localhost"); + let tls_cert_and_key = TlsCertAndKey { cert_chain, - server_config, + key: private_key, + }; + + let result = ProxyServer::new( + Some(tls_cert_and_key), + None::<&str>, + "127.0.0.1:0", + target_addr.to_string(), + AttestationGenerator::with_no_attestation(), + AttestationVerifier::expect_none(), + false, + ) + .await; + + assert!(matches!( + result, + Err(ProxyError::OuterTlsWithoutOuterListener) + )); + } + + #[tokio::test(flavor = "multi_thread")] + async fn inner_only_listener_negotiates_http2_by_default() { + let _ = rustls::crypto::ring::default_provider().install_default(); + let target_addr = example_http_service().await; + + let proxy_server = ProxyServer::new( + None, + None::<&str>, "127.0.0.1:0", target_addr.to_string(), AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), AttestationVerifier::expect_none(), + false, ) .await .unwrap(); - let proxy_addr = proxy_server.local_addr().unwrap(); + let inner_addr = proxy_server.inner_local_addr().unwrap(); tokio::spawn(async move { proxy_server.accept().await.unwrap(); @@ -932,44 +1129,46 @@ mod tests { let attested_cert_verifier = AttestedCertificateVerifier::new(None, AttestationVerifier::mock()).unwrap(); - let mut inner_client_config = + let mut client_config = ClientConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) .dangerous() .with_custom_certificate_verifier(Arc::new(attested_cert_verifier)) .with_no_client_auth(); - ensure_proxy_alpn_protocols(&mut inner_client_config.alpn_protocols); + ensure_proxy_alpn_protocols(&mut client_config.alpn_protocols); - let nesting_tls_connector = - NestingTlsConnector::new(Arc::new(outer_client_config), Arc::new(inner_client_config)); + let tls_connector = TlsConnector::from(Arc::new(client_config)); + let outbound_stream = TcpStream::connect(inner_addr).await.unwrap(); + let domain = ServerName::try_from("localhost".to_string()).unwrap(); + let mut tls_stream = tls_connector + .connect(domain, outbound_stream) + .await + .unwrap(); - let (sender, conn) = ProxyClient::setup_connection( - &nesting_tls_connector, - &format!("localhost:{}", proxy_addr.port()), - ) - .await - .unwrap(); + assert!(matches!( + HttpVersion::from_negotiated_protocol_client(&tls_stream), + HttpVersion::Http2 + )); - assert!(matches!(sender, HttpSender::Http2(_))); - assert!(matches!(conn, HttpConnection::Http2 { .. })); + tls_stream.shutdown().await.unwrap(); } #[tokio::test(flavor = "multi_thread")] - async fn http_proxy_default_constructors_work() { + async fn http_proxy_negotiates_http2_by_default() { let target_addr = example_http_service().await; let (cert_chain, private_key) = generate_certificate_chain_for_host("localhost"); - let server_cert = cert_chain[0].clone(); + let (server_config, outer_client_config) = + generate_tls_config(cert_chain.clone(), private_key); - let proxy_server = ProxyServer::new( - TlsCertAndKey { - cert_chain, - key: private_key, - }, + let proxy_server = ProxyServer::new_with_tls_config( + Some(server_config), + Some("127.0.0.1:0"), "127.0.0.1:0", target_addr.to_string(), AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), AttestationVerifier::expect_none(), false, + Some(certificate_identity_from_chain(&cert_chain).unwrap()), ) .await .unwrap(); @@ -980,31 +1179,81 @@ mod tests { proxy_server.accept().await.unwrap(); }); - let proxy_client = ProxyClient::new( - None, - "127.0.0.1:0".to_string(), - format!("localhost:{}", proxy_addr.port()), - AttestationGenerator::with_no_attestation(), - AttestationVerifier::mock(), - Some(server_cert), + let attested_cert_verifier = + AttestedCertificateVerifier::new(None, AttestationVerifier::mock()).unwrap(); + let mut inner_client_config = + ClientConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) + .dangerous() + .with_custom_certificate_verifier(Arc::new(attested_cert_verifier)) + .with_no_client_auth(); + ensure_proxy_alpn_protocols(&mut inner_client_config.alpn_protocols); + + let nesting_tls_connector = + NestingTlsConnector::new(Arc::new(outer_client_config), Arc::new(inner_client_config)); + + let (sender, conn) = ProxyClient::setup_connection( + &nesting_tls_connector, + &format!("localhost:{}", proxy_addr.port()), ) .await .unwrap(); - let proxy_client_addr = proxy_client.local_addr().unwrap(); - - tokio::spawn(async move { - proxy_client.accept().await.unwrap(); - }); - - let res = reqwest::get(format!("http://{}", proxy_client_addr)) - .await - .unwrap(); - - let res_body = res.text().await.unwrap(); - assert_eq!(res_body, "No measurements"); + assert!(matches!(sender, HttpSender::Http2(_))); + assert!(matches!(conn, HttpConnection::Http2 { .. })); } + // #[tokio::test(flavor = "multi_thread")] + // async fn http_proxy_default_constructors_work() { + // let target_addr = example_http_service().await; + // + // let (cert_chain, private_key) = generate_certificate_chain_for_host("localhost"); + // let server_cert = cert_chain[0].clone(); + // + // let proxy_server = ProxyServer::new( + // TlsCertAndKey { + // cert_chain, + // key: private_key, + // }, + // "127.0.0.1:0", + // target_addr.to_string(), + // AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), + // AttestationVerifier::expect_none(), + // false, + // ) + // .await + // .unwrap(); + // + // let proxy_addr = proxy_server.local_addr().unwrap(); + // + // tokio::spawn(async move { + // proxy_server.accept().await.unwrap(); + // }); + // + // let proxy_client = ProxyClient::new( + // None, + // "127.0.0.1:0".to_string(), + // format!("localhost:{}", proxy_addr.port()), + // AttestationGenerator::with_no_attestation(), + // AttestationVerifier::mock(), + // Some(server_cert), + // ) + // .await + // .unwrap(); + // + // let proxy_client_addr = proxy_client.local_addr().unwrap(); + // + // tokio::spawn(async move { + // proxy_client.accept().await.unwrap(); + // }); + // + // let res = reqwest::get(format!("http://{}", proxy_client_addr)) + // .await + // .unwrap(); + // + // let res_body = res.text().await.unwrap(); + // assert_eq!(res_body, "No measurements"); + // } + // Server has mock DCAP, client has no attestation and no client auth #[tokio::test(flavor = "multi_thread")] async fn http_proxy_with_server_attestation() { @@ -1015,12 +1264,14 @@ mod tests { let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key); let proxy_server = ProxyServer::new_with_tls_config( - cert_chain, - server_config, + Some(server_config), + Some("127.0.0.1:0"), "127.0.0.1:0", target_addr.to_string(), AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), AttestationVerifier::expect_none(), + false, + Some(certificate_identity_from_chain(&cert_chain).unwrap()), ) .await .unwrap(); @@ -1076,14 +1327,15 @@ mod tests { server_private_key, ); - let proxy_server = ProxyServer::new_with_tls_config_and_client_auth( - server_cert_chain, - server_tls_server_config, + let proxy_server = ProxyServer::new_with_tls_config( + Some(server_tls_server_config), + Some("127.0.0.1:0"), "127.0.0.1:0", target_addr.to_string(), AttestationGenerator::with_no_attestation(), AttestationVerifier::mock(), true, + Some(certificate_identity_from_chain(&server_cert_chain).unwrap()), ) .await .unwrap(); @@ -1130,12 +1382,14 @@ mod tests { generate_tls_config(server_cert_chain.clone(), server_private_key); let proxy_server = ProxyServer::new_with_tls_config( - server_cert_chain, - server_config, + Some(server_config), + Some("127.0.0.1:0"), "127.0.0.1:0", target_addr.to_string(), AttestationGenerator::with_no_attestation(), AttestationVerifier::mock(), + false, + Some(certificate_identity_from_chain(&server_cert_chain).unwrap()), ) .await .unwrap(); @@ -1193,14 +1447,15 @@ mod tests { server_private_key, ); - let proxy_server = ProxyServer::new_with_tls_config_and_client_auth( - server_cert_chain, - server_tls_server_config, + let proxy_server = ProxyServer::new_with_tls_config( + Some(server_tls_server_config), + Some("127.0.0.1:0"), "127.0.0.1:0", target_addr.to_string(), AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), AttestationVerifier::mock(), true, + Some(certificate_identity_from_chain(&server_cert_chain).unwrap()), ) .await .unwrap(); @@ -1249,12 +1504,14 @@ mod tests { let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key); let proxy_server = ProxyServer::new_with_tls_config( - cert_chain.clone(), - server_config, + Some(server_config), + Some("127.0.0.1:0"), "127.0.0.1:0", target_addr.to_string(), AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), AttestationVerifier::expect_none(), + false, + Some(certificate_identity_from_chain(&cert_chain).unwrap()), ) .await .unwrap(); @@ -1291,12 +1548,14 @@ mod tests { let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key); let proxy_server = ProxyServer::new_with_tls_config( - cert_chain, - server_config, + Some(server_config), + Some("127.0.0.1:0"), "127.0.0.1:0", target_addr.to_string(), AttestationGenerator::with_no_attestation(), AttestationVerifier::expect_none(), + false, + Some(certificate_identity_from_chain(&cert_chain).unwrap()), ) .await .unwrap(); @@ -1331,12 +1590,14 @@ mod tests { let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key); let proxy_server = ProxyServer::new_with_tls_config( - cert_chain, - server_config, + Some(server_config), + Some("127.0.0.1:0"), "127.0.0.1:0", target_addr.to_string(), AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), AttestationVerifier::expect_none(), + false, + Some(certificate_identity_from_chain(&cert_chain).unwrap()), ) .await .unwrap(); @@ -1396,12 +1657,14 @@ mod tests { let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key); let proxy_server = ProxyServer::new_with_tls_config( - cert_chain, - server_config, + Some(server_config), + Some("127.0.0.1:0"), "127.0.0.1:0", target_addr.to_string(), AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), AttestationVerifier::expect_none(), + false, + Some(certificate_identity_from_chain(&cert_chain).unwrap()), ) .await .unwrap(); @@ -1469,12 +1732,14 @@ mod tests { server_config.alpn_protocols.push(ALPN_HTTP11.to_vec()); let proxy_server = ProxyServer::new_with_tls_config( - cert_chain, - server_config, + Some(server_config), + Some("127.0.0.1:0"), "127.0.0.1:0", target_addr.to_string(), AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), AttestationVerifier::expect_none(), + false, + Some(certificate_identity_from_chain(&cert_chain).unwrap()), ) .await .unwrap(); diff --git a/src/main.rs b/src/main.rs index c14c414..e198b3a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -77,9 +77,12 @@ enum CliCommand { }, /// Run a proxy server Server { - /// Socket address to listen on - #[arg(short, long, default_value = "0.0.0.0:0", env = "LISTEN_ADDR")] - listen_addr: SocketAddr, + /// Socket address to listen on for the outer nested-TLS listener + #[arg(long, default_value = "0.0.0.0:443")] + outer_listen_addr: SocketAddr, + /// Socket address to listen on for the inner-only attested TLS listener + #[arg(long, default_value = "0.0.0.0:4433")] + inner_listen_addr: SocketAddr, /// The hostname:port or ip:port of the target service to forward traffic to target_addr: String, /// Type of attestation to present (dafaults to 'auto' for automatic detection) @@ -119,19 +122,22 @@ enum CliCommand { AttestedFileServer { /// Filesystem path to statically serve path_to_serve: PathBuf, - /// Socket address to listen on - #[arg(short, long, default_value = "0.0.0.0:0", env = "LISTEN_ADDR")] - listen_addr: SocketAddr, + /// Socket address to listen on for the outer nested-TLS listener + #[arg(long, default_value = "0.0.0.0:443")] + outer_listen_addr: SocketAddr, + /// Socket address to listen on for the inner-only attested TLS listener + #[arg(long, default_value = "0.0.0.0:4433")] + inner_listen_addr: SocketAddr, /// Type of attestation to present (dafaults to none) /// If other than None, a TLS key and certicate must also be given #[arg(long, env = "SERVER_ATTESTATION_TYPE")] server_attestation_type: Option, /// The path to a PEM encoded private key #[arg(long, env = "TLS_PRIVATE_KEY_PATH")] - tls_private_key_path: PathBuf, + tls_private_key_path: Option, /// Additional CA certificate to verify against (PEM) Defaults to no additional TLS certs. #[arg(long, env = "TLS_CERTIFICATE_PATH")] - tls_certificate_path: PathBuf, + tls_certificate_path: Option, /// URL of the remote dummy attestation service. Only use with --server-attestation-type /// dummy #[arg(long)] @@ -277,7 +283,8 @@ async fn main() -> anyhow::Result<()> { } } CliCommand::Server { - listen_addr, + outer_listen_addr, + inner_listen_addr, target_addr, tls_private_key_path, tls_certificate_path, @@ -299,7 +306,8 @@ async fn main() -> anyhow::Result<()> { let server = ProxyServer::new( tls_cert_and_chain, - listen_addr, + Some(outer_listen_addr), + inner_listen_addr, target_addr, local_attestation_generator, attestation_verifier, @@ -344,14 +352,15 @@ async fn main() -> anyhow::Result<()> { } CliCommand::AttestedFileServer { path_to_serve, - listen_addr, + outer_listen_addr, + inner_listen_addr, server_attestation_type, tls_private_key_path, tls_certificate_path, dev_dummy_dcap, } => { let tls_cert_and_chain = - load_tls_cert_and_key(tls_certificate_path, tls_private_key_path)?; + load_tls_cert_and_key_server(tls_certificate_path, tls_private_key_path)?; let server_attestation_type: AttestationType = serde_json::from_value( serde_json::Value::String(server_attestation_type.unwrap_or("none".to_string())), @@ -363,7 +372,8 @@ async fn main() -> anyhow::Result<()> { attested_file_server( path_to_serve, tls_cert_and_chain, - listen_addr, + outer_listen_addr, + inner_listen_addr, attestation_generator, attestation_verifier, false, @@ -410,16 +420,14 @@ async fn main() -> anyhow::Result<()> { fn load_tls_cert_and_key_server( cert_chain: Option, private_key: Option, -) -> anyhow::Result { - if let Some(private_key) = private_key { - load_tls_cert_and_key( - cert_chain.ok_or(anyhow!("Private key given but no certificate chain"))?, - private_key, - ) - } else if cert_chain.is_some() { - Err(anyhow!("Certificate chain provided but no private key")) - } else { - Err(anyhow!("No private key provided")) +) -> anyhow::Result> { + match (cert_chain, private_key) { + (Some(cert_chain), Some(private_key)) => { + Ok(Some(load_tls_cert_and_key(cert_chain, private_key)?)) + } + (Some(_), None) => Err(anyhow!("Certificate chain provided but no private key")), + (None, Some(_)) => Err(anyhow!("Private key given but no certificate chain")), + (None, None) => Ok(None), } } From b52645ac6e1afcd44e04ed7aefa8f586f68c453f Mon Sep 17 00:00:00 2001 From: peg Date: Thu, 19 Mar 2026 12:14:12 +0100 Subject: [PATCH 11/20] Clippy --- src/file_server.rs | 18 +++++++----------- src/lib.rs | 12 +++++++----- 2 files changed, 14 insertions(+), 16 deletions(-) diff --git a/src/file_server.rs b/src/file_server.rs index 8390c9f..1e07e0d 100644 --- a/src/file_server.rs +++ b/src/file_server.rs @@ -142,27 +142,23 @@ mod tests { let client = reqwest::Client::new(); // This makes the request - let (body, content_type) = get_body_and_content_type( - format!("http://{}/foo.txt", proxy_client_addr.to_string()), - &client, - ) - .await; + let (body, content_type) = + get_body_and_content_type(format!("http://{}/foo.txt", proxy_client_addr), &client) + .await; assert_eq!(content_type, "text/plain"); assert_eq!(body, b"bar"); let (body, content_type) = get_body_and_content_type( - format!("http://{}/index.html", proxy_client_addr.to_string()), + format!("http://{}/index.html", proxy_client_addr), &client, ) .await; assert_eq!(content_type, "text/html"); assert_eq!(body, b"foo"); - let (body, content_type) = get_body_and_content_type( - format!("http://{}/data.bin", proxy_client_addr.to_string()), - &client, - ) - .await; + let (body, content_type) = + get_body_and_content_type(format!("http://{}/data.bin", proxy_client_addr), &client) + .await; assert_eq!(content_type, "application/octet-stream"); assert_eq!(body, [0u8; 32]); } diff --git a/src/lib.rs b/src/lib.rs index a36d5b3..e6098bb 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -144,6 +144,7 @@ pub struct ProxyServer { impl ProxyServer { /// Start with dual listeners. The outer nested-TLS listener is optional. + #[allow(clippy::too_many_arguments)] pub async fn new( outer_cert_and_key: Option, outer_local: Option, @@ -204,6 +205,7 @@ impl ProxyServer { } /// Start with preconfigured TLS + #[allow(clippy::too_many_arguments)] pub async fn new_with_tls_config( outer_server_config: Option, outer_local: Option, @@ -1299,7 +1301,7 @@ mod tests { proxy_client.accept().await.unwrap(); }); - let res = reqwest::get(format!("http://{}", proxy_client_addr.to_string())) + let res = reqwest::get(format!("http://{}", proxy_client_addr)) .await .unwrap(); @@ -1420,7 +1422,7 @@ mod tests { proxy_client.accept().await.unwrap(); }); - let res = reqwest::get(format!("http://{}", proxy_client_addr.to_string())) + let res = reqwest::get(format!("http://{}", proxy_client_addr)) .await .unwrap(); @@ -1704,7 +1706,7 @@ mod tests { proxy_client.accept().await.unwrap(); }); - let _initial_response = reqwest::get(format!("http://{}", proxy_client_addr.to_string())) + let _initial_response = reqwest::get(format!("http://{}", proxy_client_addr)) .await .unwrap(); @@ -1712,7 +1714,7 @@ mod tests { connection_breaker_tx.send(()).unwrap(); // Make another request - let res = reqwest::get(format!("http://{}", proxy_client_addr.to_string())) + let res = reqwest::get(format!("http://{}", proxy_client_addr)) .await .unwrap(); @@ -1767,7 +1769,7 @@ mod tests { proxy_client.accept().await.unwrap(); }); - let res = reqwest::get(format!("http://{}", proxy_client_addr.to_string())) + let res = reqwest::get(format!("http://{}", proxy_client_addr)) .await .unwrap(); From edf728ee65cc617ebeae6b640feb55c4cdade4cd Mon Sep 17 00:00:00 2001 From: peg Date: Thu, 19 Mar 2026 12:26:57 +0100 Subject: [PATCH 12/20] Improve constructors --- src/attested_get.rs | 14 ++- src/file_server.rs | 33 +++--- src/lib.rs | 263 ++++++++++++++++++++++++++------------------ src/main.rs | 12 +- 4 files changed, 194 insertions(+), 128 deletions(-) diff --git a/src/attested_get.rs b/src/attested_get.rs index 97b346d..16fd82e 100644 --- a/src/attested_get.rs +++ b/src/attested_get.rs @@ -55,7 +55,7 @@ async fn attested_get_with_client( mod tests { use super::*; use crate::{ - ProxyServer, + OuterTlsConfig, OuterTlsMode, ProxyServer, attestation::AttestationType, file_server::static_file_server, test_helpers::{generate_certificate_chain_for_host, generate_tls_config}, @@ -77,15 +77,19 @@ mod tests { let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key); // Setup a proxy server targetting the static file server - let proxy_server = ProxyServer::new_with_tls_config( - Some(server_config), - Some("127.0.0.1:0"), + let proxy_server = ProxyServer::new( + Some(OuterTlsConfig { + listen_addr: "127.0.0.1:0", + tls: OuterTlsMode::Preconfigured { + server_config, + certificate_name: Some("localhost".to_string()), + }, + }), "127.0.0.1:0", target_addr.to_string(), AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), AttestationVerifier::expect_none(), false, - Some("localhost".to_string()), ) .await .unwrap(); diff --git a/src/file_server.rs b/src/file_server.rs index 1e07e0d..d2cfd90 100644 --- a/src/file_server.rs +++ b/src/file_server.rs @@ -1,5 +1,8 @@ //! Static HTTP file server provided by an attested TLS proxy server -use crate::{AttestationGenerator, AttestationVerifier, ProxyError, ProxyServer, TlsCertAndKey}; +use crate::{ + AttestationGenerator, AttestationVerifier, OuterTlsConfig, OuterTlsMode, ProxyError, + ProxyServer, TlsCertAndKey, +}; use std::{net::SocketAddr, path::PathBuf}; use tokio::net::ToSocketAddrs; use tower_http::services::ServeDir; @@ -17,8 +20,10 @@ pub async fn attested_file_server( let target_addr = static_file_server(path_to_serve).await?; let server = ProxyServer::new( - outer_cert_and_key, - Some(outer_listen_addr), + outer_cert_and_key.map(|cert_and_key| OuterTlsConfig { + listen_addr: outer_listen_addr, + tls: OuterTlsMode::CertAndKey(cert_and_key), + }), inner_listen_addr, target_addr.to_string(), attestation_generator, @@ -54,7 +59,7 @@ pub(crate) async fn static_file_server(path: PathBuf) -> Resultfoo"); diff --git a/src/lib.rs b/src/lib.rs index e6098bb..64bb2de 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -62,6 +62,19 @@ pub struct TlsCertAndKey { pub key: PrivateKeyDer<'static>, } +pub struct OuterTlsConfig { + pub listen_addr: A, + pub tls: OuterTlsMode, +} + +pub enum OuterTlsMode { + CertAndKey(TlsCertAndKey), + Preconfigured { + server_config: ServerConfig, + certificate_name: Option, + }, +} + /// Adds HTTP 1 and 2 to the list of allowed protocols fn ensure_proxy_alpn_protocols(alpn_protocols: &mut Vec>) { for protocol in [ALPN_H2, ALPN_HTTP11] { @@ -144,10 +157,8 @@ pub struct ProxyServer { impl ProxyServer { /// Start with dual listeners. The outer nested-TLS listener is optional. - #[allow(clippy::too_many_arguments)] pub async fn new( - outer_cert_and_key: Option, - outer_local: Option, + outer_session: Option>, inner_local: impl ToSocketAddrs, target: String, attestation_generator: AttestationGenerator, @@ -157,17 +168,14 @@ impl ProxyServer { where O: ToSocketAddrs, { - if outer_cert_and_key.is_some() && outer_local.is_none() { - return Err(ProxyError::OuterTlsWithoutOuterListener); - } - - let outer_certificate_name = outer_cert_and_key - .as_ref() - .map(|cert_and_key| certificate_identity_from_chain(&cert_and_key.cert_chain)) - .transpose()?; - let outer_server_config = match outer_cert_and_key { - Some(cert_and_key) => { - let config = if client_auth { + let outer_session = match outer_session { + Some(OuterTlsConfig { + listen_addr, + tls: OuterTlsMode::CertAndKey(cert_and_key), + }) => { + let certificate_name = + Some(certificate_identity_from_chain(&cert_and_key.cert_chain)?); + let server_config = if client_auth { let root_store = RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); let verifier = WebPkiClientVerifier::builder(Arc::new(root_store)).build()?; @@ -186,43 +194,69 @@ impl ProxyServer { cert_and_key.key.clone_key(), )? }; - Some(config) + + Some(OuterTlsConfig { + listen_addr, + tls: OuterTlsMode::Preconfigured { + server_config, + certificate_name, + }, + }) } + Some(OuterTlsConfig { + listen_addr, + tls: + OuterTlsMode::Preconfigured { + server_config, + certificate_name, + }, + }) => Some(OuterTlsConfig { + listen_addr, + tls: OuterTlsMode::Preconfigured { + server_config, + certificate_name, + }, + }), None => None, }; - Self::new_with_tls_config( - outer_server_config, - outer_local, + Self::new_inner( + outer_session, inner_local, target, attestation_generator, attestation_verifier, client_auth, - outer_certificate_name, ) .await } - /// Start with preconfigured TLS - #[allow(clippy::too_many_arguments)] - pub async fn new_with_tls_config( - outer_server_config: Option, - outer_local: Option, + async fn new_inner( + outer_session: Option>, inner_local: impl ToSocketAddrs, target: String, attestation_generator: AttestationGenerator, attestation_verifier: AttestationVerifier, client_auth: bool, - certificate_name: Option, ) -> Result where O: ToSocketAddrs, { - if outer_server_config.is_some() && outer_local.is_none() { - return Err(ProxyError::OuterTlsWithoutOuterListener); - } - + let (outer_server_config, certificate_name, outer_local) = match outer_session { + Some(OuterTlsConfig { + listen_addr, + tls: + OuterTlsMode::Preconfigured { + server_config, + certificate_name, + }, + }) => (Some(server_config), certificate_name, Some(listen_addr)), + Some(OuterTlsConfig { + listen_addr: _, + tls: OuterTlsMode::CertAndKey(_), + }) => unreachable!("cert/key outer session should be normalized via ProxyServer::new"), + None => (None, None, None), + }; let inner_server_config = Arc::new( build_inner_server_config( attestation_generator, @@ -244,7 +278,9 @@ impl ProxyServer { ); (Some(outer_listener), Some(acceptor)) } - (Some(_), None) => return Err(ProxyError::OuterTlsWithoutOuterListener), + (Some(_), None) => { + unreachable!("outer config without outer listener is unrepresentable") + } (None, _) => (None, None), }; @@ -903,8 +939,6 @@ pub enum ProxyError { MpscSend, #[error("Client auth must be configured on both the inner and outer TLS sessions")] ClientAuthMisconfigured, - #[error("Outer TLS configuration requires an outer listener address")] - OuterTlsWithoutOuterListener, } impl From> for ProxyError { @@ -1046,8 +1080,10 @@ mod tests { }; let dual_listener_server = ProxyServer::new( - Some(tls_cert_and_key), - Some("127.0.0.1:0"), + Some(OuterTlsConfig { + listen_addr: "127.0.0.1:0", + tls: OuterTlsMode::CertAndKey(tls_cert_and_key), + }), "127.0.0.1:0", target_addr.to_string(), AttestationGenerator::with_no_attestation(), @@ -1063,8 +1099,7 @@ mod tests { assert_ne!(outer_addr, inner_addr); let inner_only_server = ProxyServer::new( - None, - None::<&str>, + None::>, "127.0.0.1:0", target_addr.to_string(), AttestationGenerator::with_no_attestation(), @@ -1079,41 +1114,13 @@ mod tests { assert_eq!(inner_only_server.local_addr().unwrap(), inner_only_addr); } - #[tokio::test(flavor = "multi_thread")] - async fn outer_tls_requires_outer_listener_address() { - let target_addr = example_http_service().await; - - let (cert_chain, private_key) = generate_certificate_chain_for_host("localhost"); - let tls_cert_and_key = TlsCertAndKey { - cert_chain, - key: private_key, - }; - - let result = ProxyServer::new( - Some(tls_cert_and_key), - None::<&str>, - "127.0.0.1:0", - target_addr.to_string(), - AttestationGenerator::with_no_attestation(), - AttestationVerifier::expect_none(), - false, - ) - .await; - - assert!(matches!( - result, - Err(ProxyError::OuterTlsWithoutOuterListener) - )); - } - #[tokio::test(flavor = "multi_thread")] async fn inner_only_listener_negotiates_http2_by_default() { let _ = rustls::crypto::ring::default_provider().install_default(); let target_addr = example_http_service().await; let proxy_server = ProxyServer::new( - None, - None::<&str>, + None::>, "127.0.0.1:0", target_addr.to_string(), AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), @@ -1162,15 +1169,19 @@ mod tests { let (server_config, outer_client_config) = generate_tls_config(cert_chain.clone(), private_key); - let proxy_server = ProxyServer::new_with_tls_config( - Some(server_config), - Some("127.0.0.1:0"), + let proxy_server = ProxyServer::new( + Some(OuterTlsConfig { + listen_addr: "127.0.0.1:0", + tls: OuterTlsMode::Preconfigured { + server_config, + certificate_name: Some(certificate_identity_from_chain(&cert_chain).unwrap()), + }, + }), "127.0.0.1:0", target_addr.to_string(), AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), AttestationVerifier::expect_none(), false, - Some(certificate_identity_from_chain(&cert_chain).unwrap()), ) .await .unwrap(); @@ -1265,15 +1276,19 @@ mod tests { let (cert_chain, private_key) = generate_certificate_chain_for_host("localhost"); let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key); - let proxy_server = ProxyServer::new_with_tls_config( - Some(server_config), - Some("127.0.0.1:0"), + let proxy_server = ProxyServer::new( + Some(OuterTlsConfig { + listen_addr: "127.0.0.1:0", + tls: OuterTlsMode::Preconfigured { + server_config, + certificate_name: Some(certificate_identity_from_chain(&cert_chain).unwrap()), + }, + }), "127.0.0.1:0", target_addr.to_string(), AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), AttestationVerifier::expect_none(), false, - Some(certificate_identity_from_chain(&cert_chain).unwrap()), ) .await .unwrap(); @@ -1329,15 +1344,21 @@ mod tests { server_private_key, ); - let proxy_server = ProxyServer::new_with_tls_config( - Some(server_tls_server_config), - Some("127.0.0.1:0"), + let proxy_server = ProxyServer::new( + Some(OuterTlsConfig { + listen_addr: "127.0.0.1:0", + tls: OuterTlsMode::Preconfigured { + server_config: server_tls_server_config, + certificate_name: Some( + certificate_identity_from_chain(&server_cert_chain).unwrap(), + ), + }, + }), "127.0.0.1:0", target_addr.to_string(), AttestationGenerator::with_no_attestation(), AttestationVerifier::mock(), true, - Some(certificate_identity_from_chain(&server_cert_chain).unwrap()), ) .await .unwrap(); @@ -1383,15 +1404,21 @@ mod tests { let (server_config, client_config) = generate_tls_config(server_cert_chain.clone(), server_private_key); - let proxy_server = ProxyServer::new_with_tls_config( - Some(server_config), - Some("127.0.0.1:0"), + let proxy_server = ProxyServer::new( + Some(OuterTlsConfig { + listen_addr: "127.0.0.1:0", + tls: OuterTlsMode::Preconfigured { + server_config, + certificate_name: Some( + certificate_identity_from_chain(&server_cert_chain).unwrap(), + ), + }, + }), "127.0.0.1:0", target_addr.to_string(), AttestationGenerator::with_no_attestation(), AttestationVerifier::mock(), false, - Some(certificate_identity_from_chain(&server_cert_chain).unwrap()), ) .await .unwrap(); @@ -1449,15 +1476,21 @@ mod tests { server_private_key, ); - let proxy_server = ProxyServer::new_with_tls_config( - Some(server_tls_server_config), - Some("127.0.0.1:0"), + let proxy_server = ProxyServer::new( + Some(OuterTlsConfig { + listen_addr: "127.0.0.1:0", + tls: OuterTlsMode::Preconfigured { + server_config: server_tls_server_config, + certificate_name: Some( + certificate_identity_from_chain(&server_cert_chain).unwrap(), + ), + }, + }), "127.0.0.1:0", target_addr.to_string(), AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), AttestationVerifier::mock(), true, - Some(certificate_identity_from_chain(&server_cert_chain).unwrap()), ) .await .unwrap(); @@ -1505,15 +1538,19 @@ mod tests { let (cert_chain, private_key) = generate_certificate_chain_for_host("localhost"); let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key); - let proxy_server = ProxyServer::new_with_tls_config( - Some(server_config), - Some("127.0.0.1:0"), + let proxy_server = ProxyServer::new( + Some(OuterTlsConfig { + listen_addr: "127.0.0.1:0", + tls: OuterTlsMode::Preconfigured { + server_config, + certificate_name: Some(certificate_identity_from_chain(&cert_chain).unwrap()), + }, + }), "127.0.0.1:0", target_addr.to_string(), AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), AttestationVerifier::expect_none(), false, - Some(certificate_identity_from_chain(&cert_chain).unwrap()), ) .await .unwrap(); @@ -1549,15 +1586,19 @@ mod tests { let (cert_chain, private_key) = generate_certificate_chain_for_host("localhost"); let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key); - let proxy_server = ProxyServer::new_with_tls_config( - Some(server_config), - Some("127.0.0.1:0"), + let proxy_server = ProxyServer::new( + Some(OuterTlsConfig { + listen_addr: "127.0.0.1:0", + tls: OuterTlsMode::Preconfigured { + server_config, + certificate_name: Some(certificate_identity_from_chain(&cert_chain).unwrap()), + }, + }), "127.0.0.1:0", target_addr.to_string(), AttestationGenerator::with_no_attestation(), AttestationVerifier::expect_none(), false, - Some(certificate_identity_from_chain(&cert_chain).unwrap()), ) .await .unwrap(); @@ -1591,15 +1632,19 @@ mod tests { let (cert_chain, private_key) = generate_certificate_chain_for_host("localhost"); let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key); - let proxy_server = ProxyServer::new_with_tls_config( - Some(server_config), - Some("127.0.0.1:0"), + let proxy_server = ProxyServer::new( + Some(OuterTlsConfig { + listen_addr: "127.0.0.1:0", + tls: OuterTlsMode::Preconfigured { + server_config, + certificate_name: Some(certificate_identity_from_chain(&cert_chain).unwrap()), + }, + }), "127.0.0.1:0", target_addr.to_string(), AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), AttestationVerifier::expect_none(), false, - Some(certificate_identity_from_chain(&cert_chain).unwrap()), ) .await .unwrap(); @@ -1658,15 +1703,19 @@ mod tests { let (cert_chain, private_key) = generate_certificate_chain_for_host("localhost"); let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key); - let proxy_server = ProxyServer::new_with_tls_config( - Some(server_config), - Some("127.0.0.1:0"), + let proxy_server = ProxyServer::new( + Some(OuterTlsConfig { + listen_addr: "127.0.0.1:0", + tls: OuterTlsMode::Preconfigured { + server_config, + certificate_name: Some(certificate_identity_from_chain(&cert_chain).unwrap()), + }, + }), "127.0.0.1:0", target_addr.to_string(), AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), AttestationVerifier::expect_none(), false, - Some(certificate_identity_from_chain(&cert_chain).unwrap()), ) .await .unwrap(); @@ -1733,15 +1782,19 @@ mod tests { server_config.alpn_protocols.push(ALPN_HTTP11.to_vec()); - let proxy_server = ProxyServer::new_with_tls_config( - Some(server_config), - Some("127.0.0.1:0"), + let proxy_server = ProxyServer::new( + Some(OuterTlsConfig { + listen_addr: "127.0.0.1:0", + tls: OuterTlsMode::Preconfigured { + server_config, + certificate_name: Some(certificate_identity_from_chain(&cert_chain).unwrap()), + }, + }), "127.0.0.1:0", target_addr.to_string(), AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), AttestationVerifier::expect_none(), false, - Some(certificate_identity_from_chain(&cert_chain).unwrap()), ) .await .unwrap(); diff --git a/src/main.rs b/src/main.rs index e198b3a..ad8ad33 100644 --- a/src/main.rs +++ b/src/main.rs @@ -7,9 +7,9 @@ use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer}; use tracing::level_filters::LevelFilter; use attested_tls_proxy::{ - AttestationGenerator, ProxyClient, ProxyServer, TlsCertAndKey, attested_get::attested_get, - file_server::attested_file_server, get_inner_tls_cert, health_check, - normalize_pem::normalize_private_key_pem_to_pkcs8, + AttestationGenerator, OuterTlsConfig, OuterTlsMode, ProxyClient, ProxyServer, TlsCertAndKey, + attested_get::attested_get, file_server::attested_file_server, get_inner_tls_cert, + health_check, normalize_pem::normalize_private_key_pem_to_pkcs8, }; const GIT_REV: &str = match option_env!("GIT_REV") { @@ -305,8 +305,10 @@ async fn main() -> anyhow::Result<()> { .await?; let server = ProxyServer::new( - tls_cert_and_chain, - Some(outer_listen_addr), + tls_cert_and_chain.map(|cert_and_key| OuterTlsConfig { + listen_addr: outer_listen_addr, + tls: OuterTlsMode::CertAndKey(cert_and_key), + }), inner_listen_addr, target_addr, local_attestation_generator, From 22991befcf33c086230330fa8b017c1b37972fba Mon Sep 17 00:00:00 2001 From: peg Date: Thu, 19 Mar 2026 12:37:24 +0100 Subject: [PATCH 13/20] Improve constructors --- src/lib.rs | 177 +++++++++++++++++++++++------------------------------ 1 file changed, 76 insertions(+), 101 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 64bb2de..9b9f4fd 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -62,19 +62,83 @@ pub struct TlsCertAndKey { pub key: PrivateKeyDer<'static>, } +/// Configuration for the optional outer nested-TLS listener. pub struct OuterTlsConfig { + /// The socket address to bind for the outer listener. pub listen_addr: A, + /// How the outer TLS server configuration should be constructed. pub tls: OuterTlsMode, } +/// TLS configuration sources for the outer nested-TLS listener. pub enum OuterTlsMode { + /// Build the outer TLS server config from certificate and key material. CertAndKey(TlsCertAndKey), + /// Use an already-constructed outer TLS server config. Preconfigured { + /// The outer TLS server configuration to expose on the listener. server_config: ServerConfig, + /// The server identity to embed into the inner attested certificate. certificate_name: Option, }, } +impl OuterTlsConfig +where + A: ToSocketAddrs, +{ + fn certificate_name(&self) -> Result, ProxyError> { + match &self.tls { + OuterTlsMode::CertAndKey(cert_and_key) => { + Ok(Some(certificate_identity_from_chain(&cert_and_key.cert_chain)?)) + } + OuterTlsMode::Preconfigured { + certificate_name, .. + } => Ok(certificate_name.clone()), + } + } + + async fn into_listener_and_acceptor( + self, + inner_server_config: Arc, + client_auth: bool, + ) -> Result<(Arc, NestingTlsAcceptor), ProxyError> { + let listen_addr = self.listen_addr; + let outer_server_config = match self.tls { + OuterTlsMode::CertAndKey(cert_and_key) => { + if client_auth { + let root_store = + RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); + let verifier = WebPkiClientVerifier::builder(Arc::new(root_store)).build()?; + + ServerConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) + .with_client_cert_verifier(verifier) + .with_single_cert( + cert_and_key.cert_chain.clone(), + cert_and_key.key.clone_key(), + )? + } else { + ServerConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) + .with_no_client_auth() + .with_single_cert( + cert_and_key.cert_chain.clone(), + cert_and_key.key.clone_key(), + )? + } + } + OuterTlsMode::Preconfigured { server_config, .. } => server_config, + }; + + let outer_listener = Arc::new(TcpListener::bind(listen_addr).await?); + let outer_tls_acceptor = NestingTlsAcceptor::new( + Arc::new(outer_server_config), + inner_server_config, + ); + + Ok((outer_listener, outer_tls_acceptor)) + } +} + /// Adds HTTP 1 and 2 to the list of allowed protocols fn ensure_proxy_alpn_protocols(alpn_protocols: &mut Vec>) { for protocol in [ALPN_H2, ALPN_HTTP11] { @@ -168,95 +232,11 @@ impl ProxyServer { where O: ToSocketAddrs, { - let outer_session = match outer_session { - Some(OuterTlsConfig { - listen_addr, - tls: OuterTlsMode::CertAndKey(cert_and_key), - }) => { - let certificate_name = - Some(certificate_identity_from_chain(&cert_and_key.cert_chain)?); - let server_config = if client_auth { - let root_store = - RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); - let verifier = WebPkiClientVerifier::builder(Arc::new(root_store)).build()?; - - ServerConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) - .with_client_cert_verifier(verifier) - .with_single_cert( - cert_and_key.cert_chain.clone(), - cert_and_key.key.clone_key(), - )? - } else { - ServerConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) - .with_no_client_auth() - .with_single_cert( - cert_and_key.cert_chain.clone(), - cert_and_key.key.clone_key(), - )? - }; - - Some(OuterTlsConfig { - listen_addr, - tls: OuterTlsMode::Preconfigured { - server_config, - certificate_name, - }, - }) - } - Some(OuterTlsConfig { - listen_addr, - tls: - OuterTlsMode::Preconfigured { - server_config, - certificate_name, - }, - }) => Some(OuterTlsConfig { - listen_addr, - tls: OuterTlsMode::Preconfigured { - server_config, - certificate_name, - }, - }), - None => None, - }; - - Self::new_inner( - outer_session, - inner_local, - target, - attestation_generator, - attestation_verifier, - client_auth, - ) - .await - } - - async fn new_inner( - outer_session: Option>, - inner_local: impl ToSocketAddrs, - target: String, - attestation_generator: AttestationGenerator, - attestation_verifier: AttestationVerifier, - client_auth: bool, - ) -> Result - where - O: ToSocketAddrs, - { - let (outer_server_config, certificate_name, outer_local) = match outer_session { - Some(OuterTlsConfig { - listen_addr, - tls: - OuterTlsMode::Preconfigured { - server_config, - certificate_name, - }, - }) => (Some(server_config), certificate_name, Some(listen_addr)), - Some(OuterTlsConfig { - listen_addr: _, - tls: OuterTlsMode::CertAndKey(_), - }) => unreachable!("cert/key outer session should be normalized via ProxyServer::new"), - None => (None, None, None), - }; + let certificate_name = outer_session + .as_ref() + .map(OuterTlsConfig::certificate_name) + .transpose()? + .flatten(); let inner_server_config = Arc::new( build_inner_server_config( attestation_generator, @@ -269,19 +249,14 @@ impl ProxyServer { let inner_listener = Arc::new(TcpListener::bind(inner_local).await?); let inner_tls_acceptor = TlsAcceptor::from(inner_server_config.clone()); - let (outer_listener, outer_tls_acceptor) = match (outer_server_config, outer_local) { - (Some(outer_server_config), Some(outer_local)) => { - let outer_listener = Arc::new(TcpListener::bind(outer_local).await?); - let acceptor = NestingTlsAcceptor::new( - Arc::new(outer_server_config), - inner_server_config.clone(), - ); - (Some(outer_listener), Some(acceptor)) - } - (Some(_), None) => { - unreachable!("outer config without outer listener is unrepresentable") + let (outer_listener, outer_tls_acceptor) = match outer_session { + Some(outer_session) => { + let (outer_listener, outer_tls_acceptor) = outer_session + .into_listener_and_acceptor(inner_server_config.clone(), client_auth) + .await?; + (Some(outer_listener), Some(outer_tls_acceptor)) } - (None, _) => (None, None), + None => (None, None), }; Ok(Self { From f82d5632bebbb6f0c63ac7f48d008f6229710a5c Mon Sep 17 00:00:00 2001 From: peg Date: Thu, 19 Mar 2026 13:37:49 +0100 Subject: [PATCH 14/20] Fmt --- src/lib.rs | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 9b9f4fd..9c636e7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -89,9 +89,9 @@ where { fn certificate_name(&self) -> Result, ProxyError> { match &self.tls { - OuterTlsMode::CertAndKey(cert_and_key) => { - Ok(Some(certificate_identity_from_chain(&cert_and_key.cert_chain)?)) - } + OuterTlsMode::CertAndKey(cert_and_key) => Ok(Some(certificate_identity_from_chain( + &cert_and_key.cert_chain, + )?)), OuterTlsMode::Preconfigured { certificate_name, .. } => Ok(certificate_name.clone()), @@ -130,10 +130,8 @@ where }; let outer_listener = Arc::new(TcpListener::bind(listen_addr).await?); - let outer_tls_acceptor = NestingTlsAcceptor::new( - Arc::new(outer_server_config), - inner_server_config, - ); + let outer_tls_acceptor = + NestingTlsAcceptor::new(Arc::new(outer_server_config), inner_server_config); Ok((outer_listener, outer_tls_acceptor)) } From cfcb77fba3e0898f44ee2a4080509afae50d8ac4 Mon Sep 17 00:00:00 2001 From: peg Date: Fri, 20 Mar 2026 08:32:03 +0100 Subject: [PATCH 15/20] Preconfigured server name should not be optional --- src/attested_get.rs | 2 +- src/file_server.rs | 2 +- src/lib.rs | 61 +++++++++++++++++++-------------------------- 3 files changed, 28 insertions(+), 37 deletions(-) diff --git a/src/attested_get.rs b/src/attested_get.rs index 16fd82e..9bd7118 100644 --- a/src/attested_get.rs +++ b/src/attested_get.rs @@ -82,7 +82,7 @@ mod tests { listen_addr: "127.0.0.1:0", tls: OuterTlsMode::Preconfigured { server_config, - certificate_name: Some("localhost".to_string()), + certificate_name: "localhost".to_string(), }, }), "127.0.0.1:0", diff --git a/src/file_server.rs b/src/file_server.rs index d2cfd90..2867972 100644 --- a/src/file_server.rs +++ b/src/file_server.rs @@ -110,7 +110,7 @@ mod tests { listen_addr: "127.0.0.1:0", tls: OuterTlsMode::Preconfigured { server_config, - certificate_name: Some("localhost".to_string()), + certificate_name: "localhost".to_string(), }, }), "127.0.0.1:0", diff --git a/src/lib.rs b/src/lib.rs index 9c636e7..084f70c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -47,8 +47,6 @@ const SERVER_RECONNECT_MAX_BACKOFF_SECS: u64 = 120; const KEEP_ALIVE_INTERVAL: u64 = 30; const KEEP_ALIVE_TIMEOUT: u64 = 10; -const DEFAULT_INNER_CERTIFICATE_NAME: &str = "localhost"; - type RequestWithResponseSender = ( http::Request, oneshot::Sender>, hyper::Error>>, @@ -79,7 +77,7 @@ pub enum OuterTlsMode { /// The outer TLS server configuration to expose on the listener. server_config: ServerConfig, /// The server identity to embed into the inner attested certificate. - certificate_name: Option, + certificate_name: String, }, } @@ -87,11 +85,11 @@ impl OuterTlsConfig where A: ToSocketAddrs, { - fn certificate_name(&self) -> Result, ProxyError> { + fn certificate_name(&self) -> Result { match &self.tls { - OuterTlsMode::CertAndKey(cert_and_key) => Ok(Some(certificate_identity_from_chain( - &cert_and_key.cert_chain, - )?)), + OuterTlsMode::CertAndKey(cert_and_key) => { + Ok(certificate_identity_from_chain(&cert_and_key.cert_chain)?) + } OuterTlsMode::Preconfigured { certificate_name, .. } => Ok(certificate_name.clone()), @@ -233,8 +231,7 @@ impl ProxyServer { let certificate_name = outer_session .as_ref() .map(OuterTlsConfig::certificate_name) - .transpose()? - .flatten(); + .transpose()?; let inner_server_config = Arc::new( build_inner_server_config( attestation_generator, @@ -565,7 +562,7 @@ impl ProxyClient { let mut inner_client_config = if let Some(cert_chain) = cert_chain.as_ref() { let inner_cert_resolver = build_attested_cert_resolver( attestation_generator, - Some(certificate_identity_from_chain(cert_chain)?), + certificate_identity_from_chain(cert_chain)?, ) .await?; ClientConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) @@ -944,15 +941,12 @@ fn certificate_identity_from_chain( async fn build_attested_cert_resolver( attestation_generator: AttestationGenerator, - certificate_name: Option, + certificate_name: String, ) -> Result { - Ok(AttestedCertificateResolver::new( - attestation_generator, - None, - certificate_name.unwrap_or_else(|| DEFAULT_INNER_CERTIFICATE_NAME.to_string()), - vec![], + Ok( + AttestedCertificateResolver::new(attestation_generator, None, certificate_name, vec![]) + .await?, ) - .await?) } async fn build_inner_server_config( @@ -961,8 +955,11 @@ async fn build_inner_server_config( client_auth: bool, certificate_name: Option, ) -> Result { - let inner_cert_resolver = - build_attested_cert_resolver(attestation_generator, certificate_name).await?; + let inner_cert_resolver = build_attested_cert_resolver( + attestation_generator, + certificate_name.unwrap_or_else(|| "localhost".to_string()), + ) + .await?; let mut inner_server_config = if client_auth { let attested_cert_verifier = AttestedCertificateVerifier::new(None, attestation_verifier)?; @@ -1147,7 +1144,7 @@ mod tests { listen_addr: "127.0.0.1:0", tls: OuterTlsMode::Preconfigured { server_config, - certificate_name: Some(certificate_identity_from_chain(&cert_chain).unwrap()), + certificate_name: certificate_identity_from_chain(&cert_chain).unwrap(), }, }), "127.0.0.1:0", @@ -1254,7 +1251,7 @@ mod tests { listen_addr: "127.0.0.1:0", tls: OuterTlsMode::Preconfigured { server_config, - certificate_name: Some(certificate_identity_from_chain(&cert_chain).unwrap()), + certificate_name: certificate_identity_from_chain(&cert_chain).unwrap(), }, }), "127.0.0.1:0", @@ -1322,9 +1319,7 @@ mod tests { listen_addr: "127.0.0.1:0", tls: OuterTlsMode::Preconfigured { server_config: server_tls_server_config, - certificate_name: Some( - certificate_identity_from_chain(&server_cert_chain).unwrap(), - ), + certificate_name: certificate_identity_from_chain(&server_cert_chain).unwrap(), }, }), "127.0.0.1:0", @@ -1382,9 +1377,7 @@ mod tests { listen_addr: "127.0.0.1:0", tls: OuterTlsMode::Preconfigured { server_config, - certificate_name: Some( - certificate_identity_from_chain(&server_cert_chain).unwrap(), - ), + certificate_name: certificate_identity_from_chain(&server_cert_chain).unwrap(), }, }), "127.0.0.1:0", @@ -1454,9 +1447,7 @@ mod tests { listen_addr: "127.0.0.1:0", tls: OuterTlsMode::Preconfigured { server_config: server_tls_server_config, - certificate_name: Some( - certificate_identity_from_chain(&server_cert_chain).unwrap(), - ), + certificate_name: certificate_identity_from_chain(&server_cert_chain).unwrap(), }, }), "127.0.0.1:0", @@ -1516,7 +1507,7 @@ mod tests { listen_addr: "127.0.0.1:0", tls: OuterTlsMode::Preconfigured { server_config, - certificate_name: Some(certificate_identity_from_chain(&cert_chain).unwrap()), + certificate_name: certificate_identity_from_chain(&cert_chain).unwrap(), }, }), "127.0.0.1:0", @@ -1564,7 +1555,7 @@ mod tests { listen_addr: "127.0.0.1:0", tls: OuterTlsMode::Preconfigured { server_config, - certificate_name: Some(certificate_identity_from_chain(&cert_chain).unwrap()), + certificate_name: certificate_identity_from_chain(&cert_chain).unwrap(), }, }), "127.0.0.1:0", @@ -1610,7 +1601,7 @@ mod tests { listen_addr: "127.0.0.1:0", tls: OuterTlsMode::Preconfigured { server_config, - certificate_name: Some(certificate_identity_from_chain(&cert_chain).unwrap()), + certificate_name: certificate_identity_from_chain(&cert_chain).unwrap(), }, }), "127.0.0.1:0", @@ -1681,7 +1672,7 @@ mod tests { listen_addr: "127.0.0.1:0", tls: OuterTlsMode::Preconfigured { server_config, - certificate_name: Some(certificate_identity_from_chain(&cert_chain).unwrap()), + certificate_name: certificate_identity_from_chain(&cert_chain).unwrap(), }, }), "127.0.0.1:0", @@ -1760,7 +1751,7 @@ mod tests { listen_addr: "127.0.0.1:0", tls: OuterTlsMode::Preconfigured { server_config, - certificate_name: Some(certificate_identity_from_chain(&cert_chain).unwrap()), + certificate_name: certificate_identity_from_chain(&cert_chain).unwrap(), }, }), "127.0.0.1:0", From b3ebc9cd6be9378e1f35459b5b13419d71b84190 Mon Sep 17 00:00:00 2001 From: peg Date: Fri, 20 Mar 2026 08:38:45 +0100 Subject: [PATCH 16/20] Update CLI documentation --- src/main.rs | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/main.rs b/src/main.rs index ad8ad33..0d7eb61 100644 --- a/src/main.rs +++ b/src/main.rs @@ -77,7 +77,7 @@ enum CliCommand { }, /// Run a proxy server Server { - /// Socket address to listen on for the outer nested-TLS listener + /// Socket address to listen on for the outer nested-TLS listener, if enabled #[arg(long, default_value = "0.0.0.0:443")] outer_listen_addr: SocketAddr, /// Socket address to listen on for the inner-only attested TLS listener @@ -86,13 +86,13 @@ enum CliCommand { /// The hostname:port or ip:port of the target service to forward traffic to target_addr: String, /// Type of attestation to present (dafaults to 'auto' for automatic detection) - /// If other than None, a TLS key and certicate must also be given + /// This configures the inner attested TLS listener and does not require outer TLS certs. #[arg(long, env = "SERVER_ATTESTATION_TYPE")] server_attestation_type: Option, - /// The path to a PEM encoded private key + /// The path to a PEM encoded private key for the optional outer nested-TLS listener #[arg(long, env = "TLS_PRIVATE_KEY_PATH")] tls_private_key_path: Option, - /// Additional CA certificate to verify against (PEM) Defaults to no additional TLS certs. + /// PEM certificate chain for the optional outer nested-TLS listener #[arg(long, env = "TLS_CERTIFICATE_PATH")] tls_certificate_path: Option, /// Whether to use client authentication. If the client is running in a CVM this must be @@ -122,20 +122,20 @@ enum CliCommand { AttestedFileServer { /// Filesystem path to statically serve path_to_serve: PathBuf, - /// Socket address to listen on for the outer nested-TLS listener + /// Socket address to listen on for the outer nested-TLS listener, if enabled #[arg(long, default_value = "0.0.0.0:443")] outer_listen_addr: SocketAddr, /// Socket address to listen on for the inner-only attested TLS listener #[arg(long, default_value = "0.0.0.0:4433")] inner_listen_addr: SocketAddr, /// Type of attestation to present (dafaults to none) - /// If other than None, a TLS key and certicate must also be given + /// This configures the inner attested TLS listener and does not require outer TLS certs. #[arg(long, env = "SERVER_ATTESTATION_TYPE")] server_attestation_type: Option, - /// The path to a PEM encoded private key + /// The path to a PEM encoded private key for the optional outer nested-TLS listener #[arg(long, env = "TLS_PRIVATE_KEY_PATH")] tls_private_key_path: Option, - /// Additional CA certificate to verify against (PEM) Defaults to no additional TLS certs. + /// PEM certificate chain for the optional outer nested-TLS listener #[arg(long, env = "TLS_CERTIFICATE_PATH")] tls_certificate_path: Option, /// URL of the remote dummy attestation service. Only use with --server-attestation-type From a42c81d99c9f09db71e9a85423ea4b3703d167e4 Mon Sep 17 00:00:00 2001 From: peg Date: Fri, 20 Mar 2026 09:15:27 +0100 Subject: [PATCH 17/20] Make inner and outer session optional and dont use default ports --- src/attested_get.rs | 2 +- src/file_server.rs | 16 ++--- src/lib.rs | 146 ++++++++++++++++++++++++++++++-------------- src/main.rs | 62 +++++++++++++++---- 4 files changed, 160 insertions(+), 66 deletions(-) diff --git a/src/attested_get.rs b/src/attested_get.rs index 9bd7118..7fc40b9 100644 --- a/src/attested_get.rs +++ b/src/attested_get.rs @@ -85,7 +85,7 @@ mod tests { certificate_name: "localhost".to_string(), }, }), - "127.0.0.1:0", + Some("127.0.0.1:0"), target_addr.to_string(), AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), AttestationVerifier::expect_none(), diff --git a/src/file_server.rs b/src/file_server.rs index 2867972..a8dfbd0 100644 --- a/src/file_server.rs +++ b/src/file_server.rs @@ -11,8 +11,8 @@ use tower_http::services::ServeDir; pub async fn attested_file_server( path_to_serve: PathBuf, outer_cert_and_key: Option, - outer_listen_addr: impl ToSocketAddrs, - inner_listen_addr: impl ToSocketAddrs, + outer_listen_addr: Option, + inner_listen_addr: Option, attestation_generator: AttestationGenerator, attestation_verifier: AttestationVerifier, client_auth: bool, @@ -20,10 +20,12 @@ pub async fn attested_file_server( let target_addr = static_file_server(path_to_serve).await?; let server = ProxyServer::new( - outer_cert_and_key.map(|cert_and_key| OuterTlsConfig { - listen_addr: outer_listen_addr, - tls: OuterTlsMode::CertAndKey(cert_and_key), - }), + outer_cert_and_key + .zip(outer_listen_addr) + .map(|(cert_and_key, listen_addr)| OuterTlsConfig { + listen_addr, + tls: OuterTlsMode::CertAndKey(cert_and_key), + }), inner_listen_addr, target_addr.to_string(), attestation_generator, @@ -113,7 +115,7 @@ mod tests { certificate_name: "localhost".to_string(), }, }), - "127.0.0.1:0", + Some("127.0.0.1:0"), target_addr.to_string(), AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), AttestationVerifier::expect_none(), diff --git a/src/lib.rs b/src/lib.rs index 084f70c..213ed3f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -52,6 +52,9 @@ type RequestWithResponseSender = ( oneshot::Sender>, hyper::Error>>, ); +type OuterProxySession = (Arc, NestingTlsAcceptor); +type InnerProxySession = (Arc, TlsAcceptor); + /// TLS Credentials pub struct TlsCertAndKey { /// Der-encoded TLS certificate chain @@ -207,19 +210,17 @@ pub async fn get_inner_tls_cert_with_config( /// A TLS over TCP server which provides an attestation before forwarding traffic to a given target address pub struct ProxyServer { - outer_listener: Option>, - outer_tls_acceptor: Option, - inner_listener: Arc, - inner_tls_acceptor: TlsAcceptor, + outer: Option, + inner: Option, /// The address/hostname of the target service we are proxying to target: String, } impl ProxyServer { /// Start with dual listeners. The outer nested-TLS listener is optional. - pub async fn new( + pub async fn new( outer_session: Option>, - inner_local: impl ToSocketAddrs, + inner_local: Option, target: String, attestation_generator: AttestationGenerator, attestation_verifier: AttestationVerifier, @@ -227,7 +228,12 @@ impl ProxyServer { ) -> Result where O: ToSocketAddrs, + I: ToSocketAddrs, { + if outer_session.is_none() && inner_local.is_none() { + return Err(ProxyError::NoListenersConfigured); + } + let certificate_name = outer_session .as_ref() .map(OuterTlsConfig::certificate_name) @@ -241,24 +247,28 @@ impl ProxyServer { ) .await?, ); - let inner_listener = Arc::new(TcpListener::bind(inner_local).await?); - let inner_tls_acceptor = TlsAcceptor::from(inner_server_config.clone()); + let inner = match inner_local { + Some(inner_local) => { + let inner_listener = Arc::new(TcpListener::bind(inner_local).await?); + let inner_tls_acceptor = TlsAcceptor::from(inner_server_config.clone()); + Some((inner_listener, inner_tls_acceptor)) + } + None => None, + }; - let (outer_listener, outer_tls_acceptor) = match outer_session { + let outer = match outer_session { Some(outer_session) => { let (outer_listener, outer_tls_acceptor) = outer_session .into_listener_and_acceptor(inner_server_config.clone(), client_auth) .await?; - (Some(outer_listener), Some(outer_tls_acceptor)) + Some((outer_listener, outer_tls_acceptor)) } - None => (None, None), + None => None, }; Ok(Self { - outer_listener, - outer_tls_acceptor, - inner_listener, - inner_tls_acceptor, + outer, + inner, target, }) } @@ -268,13 +278,14 @@ impl ProxyServer { /// Returns the handle for the task handling the connection pub async fn accept(&self) -> Result, ProxyError> { let target = self.target.clone(); - let outer_listener = self.outer_listener.clone(); - let outer_tls_acceptor = self.outer_tls_acceptor.clone(); - let inner_listener = self.inner_listener.clone(); - let inner_tls_acceptor = self.inner_tls_acceptor.clone(); - - let join_handle = match (outer_listener, outer_tls_acceptor) { - (Some(outer_listener), Some(outer_tls_acceptor)) => { + let outer = self.outer.clone(); + let inner = self.inner.clone(); + + let join_handle = match (outer, inner) { + ( + Some((outer_listener, outer_tls_acceptor)), + Some((inner_listener, inner_tls_acceptor)), + ) => { let ((inbound, client_addr), use_outer) = tokio::select! { accepted = outer_listener.accept() => (accepted?, true), accepted = inner_listener.accept() => (accepted?, false), @@ -312,7 +323,7 @@ impl ProxyServer { } }) } - _ => { + (None, Some((inner_listener, inner_tls_acceptor))) => { let (inbound, client_addr) = inner_listener.accept().await?; tokio::spawn(async move { match inner_tls_acceptor.accept(inbound).await { @@ -329,6 +340,24 @@ impl ProxyServer { } }) } + (Some((outer_listener, outer_tls_acceptor)), None) => { + let (inbound, client_addr) = outer_listener.accept().await?; + tokio::spawn(async move { + match outer_tls_acceptor.accept(inbound).await { + Ok(tls_stream) => { + if let Err(err) = + Self::handle_outer_connection(tls_stream, target, client_addr).await + { + warn!("Failed to handle outer connection: {err}"); + } + } + Err(err) => { + warn!("Outer attestation exchange failed: {err}"); + } + } + }) + } + _ => return Err(ProxyError::NoListenersConfigured), }; Ok(join_handle) @@ -336,21 +365,29 @@ impl ProxyServer { /// Helper to get the socket address of the underlying TCP listener pub fn local_addr(&self) -> std::io::Result { - match &self.outer_listener { - Some(listener) => listener.local_addr(), - None => self.inner_listener.local_addr(), + match &self.outer { + Some((listener, _)) => listener.local_addr(), + None => self + .inner + .as_ref() + .map(|(listener, _)| listener) + .ok_or_else(|| std::io::Error::other("no listeners configured"))? + .local_addr(), } } pub fn outer_local_addr(&self) -> std::io::Result> { - self.outer_listener + self.outer .as_ref() - .map(|listener| listener.local_addr()) + .map(|(listener, _)| listener.local_addr()) .transpose() } - pub fn inner_local_addr(&self) -> std::io::Result { - self.inner_listener.local_addr() + pub fn inner_local_addr(&self) -> std::io::Result> { + self.inner + .as_ref() + .map(|(listener, _)| listener.local_addr()) + .transpose() } async fn handle_outer_connection( @@ -909,6 +946,8 @@ pub enum ProxyError { MpscSend, #[error("Client auth must be configured on both the inner and outer TLS sessions")] ClientAuthMisconfigured, + #[error("At least one server listener must be configured")] + NoListenersConfigured, } impl From> for ProxyError { @@ -1039,6 +1078,21 @@ mod tests { assert_eq!(protocols, vec![ALPN_HTTP11.to_vec(), ALPN_H2.to_vec()]); } + #[tokio::test(flavor = "multi_thread")] + async fn proxy_server_requires_at_least_one_listener() { + let result = ProxyServer::new( + None::>, + None::<&str>, + "127.0.0.1:1".to_string(), + AttestationGenerator::with_no_attestation(), + AttestationVerifier::expect_none(), + false, + ) + .await; + + assert!(matches!(result, Err(ProxyError::NoListenersConfigured))); + } + #[tokio::test(flavor = "multi_thread")] async fn dual_listener_server_reports_expected_addresses() { let target_addr = example_http_service().await; @@ -1054,7 +1108,7 @@ mod tests { listen_addr: "127.0.0.1:0", tls: OuterTlsMode::CertAndKey(tls_cert_and_key), }), - "127.0.0.1:0", + Some("127.0.0.1:0"), target_addr.to_string(), AttestationGenerator::with_no_attestation(), AttestationVerifier::expect_none(), @@ -1064,13 +1118,13 @@ mod tests { .unwrap(); let outer_addr = dual_listener_server.outer_local_addr().unwrap().unwrap(); - let inner_addr = dual_listener_server.inner_local_addr().unwrap(); + let inner_addr = dual_listener_server.inner_local_addr().unwrap().unwrap(); assert_eq!(dual_listener_server.local_addr().unwrap(), outer_addr); assert_ne!(outer_addr, inner_addr); let inner_only_server = ProxyServer::new( None::>, - "127.0.0.1:0", + Some("127.0.0.1:0"), target_addr.to_string(), AttestationGenerator::with_no_attestation(), AttestationVerifier::expect_none(), @@ -1079,7 +1133,7 @@ mod tests { .await .unwrap(); - let inner_only_addr = inner_only_server.inner_local_addr().unwrap(); + let inner_only_addr = inner_only_server.inner_local_addr().unwrap().unwrap(); assert!(inner_only_server.outer_local_addr().unwrap().is_none()); assert_eq!(inner_only_server.local_addr().unwrap(), inner_only_addr); } @@ -1091,7 +1145,7 @@ mod tests { let proxy_server = ProxyServer::new( None::>, - "127.0.0.1:0", + Some("127.0.0.1:0"), target_addr.to_string(), AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), AttestationVerifier::expect_none(), @@ -1100,7 +1154,7 @@ mod tests { .await .unwrap(); - let inner_addr = proxy_server.inner_local_addr().unwrap(); + let inner_addr = proxy_server.inner_local_addr().unwrap().unwrap(); tokio::spawn(async move { proxy_server.accept().await.unwrap(); @@ -1147,7 +1201,7 @@ mod tests { certificate_name: certificate_identity_from_chain(&cert_chain).unwrap(), }, }), - "127.0.0.1:0", + Some("127.0.0.1:0"), target_addr.to_string(), AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), AttestationVerifier::expect_none(), @@ -1254,7 +1308,7 @@ mod tests { certificate_name: certificate_identity_from_chain(&cert_chain).unwrap(), }, }), - "127.0.0.1:0", + Some("127.0.0.1:0"), target_addr.to_string(), AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), AttestationVerifier::expect_none(), @@ -1322,7 +1376,7 @@ mod tests { certificate_name: certificate_identity_from_chain(&server_cert_chain).unwrap(), }, }), - "127.0.0.1:0", + Some("127.0.0.1:0"), target_addr.to_string(), AttestationGenerator::with_no_attestation(), AttestationVerifier::mock(), @@ -1380,7 +1434,7 @@ mod tests { certificate_name: certificate_identity_from_chain(&server_cert_chain).unwrap(), }, }), - "127.0.0.1:0", + Some("127.0.0.1:0"), target_addr.to_string(), AttestationGenerator::with_no_attestation(), AttestationVerifier::mock(), @@ -1450,7 +1504,7 @@ mod tests { certificate_name: certificate_identity_from_chain(&server_cert_chain).unwrap(), }, }), - "127.0.0.1:0", + Some("127.0.0.1:0"), target_addr.to_string(), AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), AttestationVerifier::mock(), @@ -1510,7 +1564,7 @@ mod tests { certificate_name: certificate_identity_from_chain(&cert_chain).unwrap(), }, }), - "127.0.0.1:0", + Some("127.0.0.1:0"), target_addr.to_string(), AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), AttestationVerifier::expect_none(), @@ -1558,7 +1612,7 @@ mod tests { certificate_name: certificate_identity_from_chain(&cert_chain).unwrap(), }, }), - "127.0.0.1:0", + Some("127.0.0.1:0"), target_addr.to_string(), AttestationGenerator::with_no_attestation(), AttestationVerifier::expect_none(), @@ -1604,7 +1658,7 @@ mod tests { certificate_name: certificate_identity_from_chain(&cert_chain).unwrap(), }, }), - "127.0.0.1:0", + Some("127.0.0.1:0"), target_addr.to_string(), AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), AttestationVerifier::expect_none(), @@ -1675,7 +1729,7 @@ mod tests { certificate_name: certificate_identity_from_chain(&cert_chain).unwrap(), }, }), - "127.0.0.1:0", + Some("127.0.0.1:0"), target_addr.to_string(), AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), AttestationVerifier::expect_none(), @@ -1754,7 +1808,7 @@ mod tests { certificate_name: certificate_identity_from_chain(&cert_chain).unwrap(), }, }), - "127.0.0.1:0", + Some("127.0.0.1:0"), target_addr.to_string(), AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), AttestationVerifier::expect_none(), diff --git a/src/main.rs b/src/main.rs index 0d7eb61..a80a54b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -78,11 +78,11 @@ enum CliCommand { /// Run a proxy server Server { /// Socket address to listen on for the outer nested-TLS listener, if enabled - #[arg(long, default_value = "0.0.0.0:443")] - outer_listen_addr: SocketAddr, + #[arg(long)] + outer_listen_addr: Option, /// Socket address to listen on for the inner-only attested TLS listener - #[arg(long, default_value = "0.0.0.0:4433")] - inner_listen_addr: SocketAddr, + #[arg(long)] + inner_listen_addr: Option, /// The hostname:port or ip:port of the target service to forward traffic to target_addr: String, /// Type of attestation to present (dafaults to 'auto' for automatic detection) @@ -123,11 +123,11 @@ enum CliCommand { /// Filesystem path to statically serve path_to_serve: PathBuf, /// Socket address to listen on for the outer nested-TLS listener, if enabled - #[arg(long, default_value = "0.0.0.0:443")] - outer_listen_addr: SocketAddr, + #[arg(long)] + outer_listen_addr: Option, /// Socket address to listen on for the inner-only attested TLS listener - #[arg(long, default_value = "0.0.0.0:4433")] - inner_listen_addr: SocketAddr, + #[arg(long)] + inner_listen_addr: Option, /// Type of attestation to present (dafaults to none) /// This configures the inner attested TLS listener and does not require outer TLS certs. #[arg(long, env = "SERVER_ATTESTATION_TYPE")] @@ -299,16 +299,23 @@ async fn main() -> anyhow::Result<()> { let tls_cert_and_chain = load_tls_cert_and_key_server(tls_certificate_path, tls_private_key_path)?; + validate_listener_args( + inner_listen_addr, + outer_listen_addr, + tls_cert_and_chain.is_some(), + )?; let local_attestation_generator = AttestationGenerator::new_with_detection(server_attestation_type, dev_dummy_dcap) .await?; let server = ProxyServer::new( - tls_cert_and_chain.map(|cert_and_key| OuterTlsConfig { - listen_addr: outer_listen_addr, - tls: OuterTlsMode::CertAndKey(cert_and_key), - }), + tls_cert_and_chain + .zip(outer_listen_addr) + .map(|(cert_and_key, listen_addr)| OuterTlsConfig { + listen_addr, + tls: OuterTlsMode::CertAndKey(cert_and_key), + }), inner_listen_addr, target_addr, local_attestation_generator, @@ -363,6 +370,11 @@ async fn main() -> anyhow::Result<()> { } => { let tls_cert_and_chain = load_tls_cert_and_key_server(tls_certificate_path, tls_private_key_path)?; + validate_listener_args( + inner_listen_addr, + outer_listen_addr, + tls_cert_and_chain.is_some(), + )?; let server_attestation_type: AttestationType = serde_json::from_value( serde_json::Value::String(server_attestation_type.unwrap_or("none".to_string())), @@ -433,6 +445,32 @@ fn load_tls_cert_and_key_server( } } +fn validate_listener_args( + inner_listen_addr: Option, + outer_listen_addr: Option, + has_outer_tls: bool, +) -> anyhow::Result<()> { + if inner_listen_addr.is_none() && outer_listen_addr.is_none() { + return Err(anyhow!( + "At least one of --inner-listen-addr or --outer-listen-addr must be provided" + )); + } + + if has_outer_tls && outer_listen_addr.is_none() { + return Err(anyhow!( + "--outer-listen-addr is required when TLS certificate and key are provided" + )); + } + + if !has_outer_tls && outer_listen_addr.is_some() { + return Err(anyhow!( + "--outer-listen-addr requires TLS certificate and key" + )); + } + + Ok(()) +} + /// Load TLS details from storage fn load_tls_cert_and_key( cert_chain: PathBuf, From d2d9b7c5561d8a684dc273762479d20e4a009102 Mon Sep 17 00:00:00 2001 From: peg Date: Fri, 20 Mar 2026 09:18:05 +0100 Subject: [PATCH 18/20] Small fix for attested file server --- src/file_server.rs | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/src/file_server.rs b/src/file_server.rs index a8dfbd0..4c5c9bb 100644 --- a/src/file_server.rs +++ b/src/file_server.rs @@ -18,14 +18,19 @@ pub async fn attested_file_server( client_auth: bool, ) -> Result<(), ProxyError> { let target_addr = static_file_server(path_to_serve).await?; + let outer_session = match (outer_cert_and_key, outer_listen_addr) { + (Some(cert_and_key), Some(listen_addr)) => Some(OuterTlsConfig { + listen_addr, + tls: OuterTlsMode::CertAndKey(cert_and_key), + }), + (Some(_), None) | (None, Some(_)) => { + return Err(ProxyError::NoListenersConfigured); + } + (None, None) => None, + }; let server = ProxyServer::new( - outer_cert_and_key - .zip(outer_listen_addr) - .map(|(cert_and_key, listen_addr)| OuterTlsConfig { - listen_addr, - tls: OuterTlsMode::CertAndKey(cert_and_key), - }), + outer_session, inner_listen_addr, target_addr.to_string(), attestation_generator, From 634fb4fc94cb73387b72325665b76ee228302207 Mon Sep 17 00:00:00 2001 From: peg Date: Fri, 20 Mar 2026 09:28:10 +0100 Subject: [PATCH 19/20] Tidy, rm unneeded test --- src/lib.rs | 58 ++++-------------------------------------------------- 1 file changed, 4 insertions(+), 54 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 213ed3f..b3b69c7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -363,7 +363,7 @@ impl ProxyServer { Ok(join_handle) } - /// Helper to get the socket address of the underlying TCP listener + /// Helper to get the socket address of either underlying TCP listener pub fn local_addr(&self) -> std::io::Result { match &self.outer { Some((listener, _)) => listener.local_addr(), @@ -376,6 +376,7 @@ impl ProxyServer { } } + /// Helper to get the socket address of the underlying outer TCP listener if present pub fn outer_local_addr(&self) -> std::io::Result> { self.outer .as_ref() @@ -383,6 +384,7 @@ impl ProxyServer { .transpose() } + /// Helper to get the socket address of the underlying inner TCP listener if present pub fn inner_local_addr(&self) -> std::io::Result> { self.inner .as_ref() @@ -1239,58 +1241,6 @@ mod tests { assert!(matches!(conn, HttpConnection::Http2 { .. })); } - // #[tokio::test(flavor = "multi_thread")] - // async fn http_proxy_default_constructors_work() { - // let target_addr = example_http_service().await; - // - // let (cert_chain, private_key) = generate_certificate_chain_for_host("localhost"); - // let server_cert = cert_chain[0].clone(); - // - // let proxy_server = ProxyServer::new( - // TlsCertAndKey { - // cert_chain, - // key: private_key, - // }, - // "127.0.0.1:0", - // target_addr.to_string(), - // AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), - // AttestationVerifier::expect_none(), - // false, - // ) - // .await - // .unwrap(); - // - // let proxy_addr = proxy_server.local_addr().unwrap(); - // - // tokio::spawn(async move { - // proxy_server.accept().await.unwrap(); - // }); - // - // let proxy_client = ProxyClient::new( - // None, - // "127.0.0.1:0".to_string(), - // format!("localhost:{}", proxy_addr.port()), - // AttestationGenerator::with_no_attestation(), - // AttestationVerifier::mock(), - // Some(server_cert), - // ) - // .await - // .unwrap(); - // - // let proxy_client_addr = proxy_client.local_addr().unwrap(); - // - // tokio::spawn(async move { - // proxy_client.accept().await.unwrap(); - // }); - // - // let res = reqwest::get(format!("http://{}", proxy_client_addr)) - // .await - // .unwrap(); - // - // let res_body = res.text().await.unwrap(); - // assert_eq!(res_body, "No measurements"); - // } - // Server has mock DCAP, client has no attestation and no client auth #[tokio::test(flavor = "multi_thread")] async fn http_proxy_with_server_attestation() { @@ -1334,7 +1284,7 @@ mod tests { .await .unwrap(); - let proxy_client_addr = proxy_client.local_addr().unwrap(); + let proy_client_addr = proxy_client.local_addr().unwrap(); tokio::spawn(async move { proxy_client.accept().await.unwrap(); From a9bb332e9f77750b3e04b9d46f1751b67d57d3a6 Mon Sep 17 00:00:00 2001 From: peg Date: Fri, 20 Mar 2026 09:35:42 +0100 Subject: [PATCH 20/20] Typo --- src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lib.rs b/src/lib.rs index b3b69c7..88aa200 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1284,7 +1284,7 @@ mod tests { .await .unwrap(); - let proy_client_addr = proxy_client.local_addr().unwrap(); + let proxy_client_addr = proxy_client.local_addr().unwrap(); tokio::spawn(async move { proxy_client.accept().await.unwrap();