diff --git a/Cargo.lock b/Cargo.lock index 0459228..a169090 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,6 +2,12 @@ # It is not intended for manual editing. version = 4 +[[package]] +name = "adler2" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa" + [[package]] name = "aho-corasick" version = "1.1.4" @@ -11,6 +17,21 @@ dependencies = [ "memchr", ] +[[package]] +name = "alloc-no-stdlib" +version = "2.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc7bb162ec39d46ab1ca8c77bf72e890535becd1751bb45f64c597edb4c8c6b3" + +[[package]] +name = "alloc-stdlib" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94fb8275041c72129eb51b7d0322c29b8387a0386127718b096429201a5d6ece" +dependencies = [ + "alloc-no-stdlib", +] + [[package]] name = "alloy-json-rpc" version = "1.6.3" @@ -242,9 +263,9 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.100" +version = "1.0.102" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a23eb6b1614318a8071c9b2521f36b424b2c83db5eb3a0fead4a6c0809af6e61" +checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c" [[package]] name = "ark-ff" @@ -435,19 +456,41 @@ dependencies = [ "rand 0.8.5", ] +[[package]] +name = "arrayref" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76a2e8124351fda1ef8aaaa3bbd7ebbcb486bbcd4225aca0aa0d84bb2db8fecb" + [[package]] name = "arrayvec" version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" +[[package]] +name = "asn1-rs" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5493c3bedbacf7fd7382c6346bbd66687d12bbaad3a89a2d2c303ee6cf20b048" +dependencies = [ + "asn1-rs-derive 0.5.1", + "asn1-rs-impl", + "displaydoc", + "nom", + "num-traits", + "rusticata-macros", + "thiserror 1.0.69", + "time", +] + [[package]] name = "asn1-rs" version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "56624a96882bb8c26d61312ae18cb45868e5a9992ea73c58e45c3101e56a1e60" dependencies = [ - "asn1-rs-derive", + "asn1-rs-derive 0.6.0", "asn1-rs-impl", "displaydoc", "nom", @@ -457,6 +500,18 @@ dependencies = [ "time", ] +[[package]] +name = "asn1-rs-derive" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "965c2d33e53cb6b267e148a4cb0760bc01f4904c1cd4bb4002a085bb016d1490" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.108", + "synstructure", +] + [[package]] name = "asn1-rs-derive" version = "0.6.0" @@ -506,13 +561,13 @@ checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" [[package]] name = "attestation" version = "0.0.1" -source = "git+https://github.com/flashbots/attested-tls?branch=peg%2Fadd-attestation-crate#4ebc03703510e65fd1317736b8887fc388860481" +source = "git+https://github.com/flashbots/attested-tls?branch=peg%2Fattested-tls-crate#5c109dba74d4f9de58b4b846f480599752dfb1f9" dependencies = [ "anyhow", "az-tdx-vtpm", "base64 0.22.1", "configfs-tsm", - "dcap-qvl", + "dcap-qvl 0.3.12 (git+https://github.com/flashbots/dcap-qvl.git?branch=peg%2Fazure-outdated-tcp-override)", "hex", "http", "num-bigint", @@ -532,7 +587,7 @@ dependencies = [ "tokio-rustls", "tracing", "tss-esapi", - "x509-parser", + "x509-parser 0.18.1", ] [[package]] @@ -565,7 +620,7 @@ dependencies = [ "hyper", "hyper-util", "parity-scale-codec", - "rcgen", + "rcgen 0.14.7", "serde_json", "sha2", "tempfile", @@ -577,7 +632,26 @@ dependencies = [ "tracing", "url", "webpki-roots", - "x509-parser", + "x509-parser 0.18.1", +] + +[[package]] +name = "attested-tls" +version = "0.0.1" +source = "git+https://github.com/flashbots/attested-tls?branch=peg%2Fattested-tls-crate#5c109dba74d4f9de58b4b846f480599752dfb1f9" +dependencies = [ + "anyhow", + "attestation", + "ra-tls", + "rcgen 0.14.7", + "rustls", + "serde_json", + "sha2", + "thiserror 2.0.17", + "tokio", + "tracing", + "x509-parser 0.18.1", + "yasna 0.5.2", ] [[package]] @@ -585,7 +659,8 @@ name = "attested-tls-proxy" version = "1.1.1" dependencies = [ "anyhow", - "attested-tls", + "attestation", + "attested-tls 0.0.1 (git+https://github.com/flashbots/attested-tls?branch=peg%2Fattested-tls-crate)", "axum", "bytes", "clap", @@ -594,12 +669,13 @@ dependencies = [ "hyper", "hyper-util", "jsonrpsee", + "nested-tls", "p256", "pem-rfc7468", "pin-project-lite", "pkcs1", "pkcs8", - "rcgen", + "rcgen 0.14.7", "reqwest", "rsa", "rustls-pemfile", @@ -614,7 +690,7 @@ dependencies = [ "tracing", "tracing-subscriber", "webpki-roots", - "x509-parser", + "x509-parser 0.18.1", ] [[package]] @@ -855,6 +931,29 @@ dependencies = [ "wyz", ] +[[package]] +name = "blake2" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "46502ad458c9a52b69d4d4d32775c788b7a1b85e8bc9d482d92250fc0e3f8efe" +dependencies = [ + "digest 0.10.7", +] + +[[package]] +name = "blake3" +version = "1.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2468ef7d57b3fb7e16b576e8377cdbde2320c60e1491e961d11da40fc4f02a2d" +dependencies = [ + "arrayref", + "arrayvec", + "cc", + "cfg-if", + "constant_time_eq", + "cpufeatures", +] + [[package]] name = "block-buffer" version = "0.10.4" @@ -864,6 +963,31 @@ dependencies = [ "generic-array", ] +[[package]] +name = "bon" +version = "3.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f47dbe92550676ee653353c310dfb9cf6ba17ee70396e1f7cf0a2020ad49b2fe" +dependencies = [ + "bon-macros", + "rustversion", +] + +[[package]] +name = "bon-macros" +version = "3.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "519bd3116aeeb42d5372c29d982d16d0170d3d4a5ed85fc7dd91642ffff3c67c" +dependencies = [ + "darling", + "ident_case", + "prettyplease", + "proc-macro2", + "quote", + "rustversion", + "syn 2.0.108", +] + [[package]] name = "borsh" version = "1.6.0" @@ -887,6 +1011,27 @@ dependencies = [ "syn 2.0.108", ] +[[package]] +name = "brotli" +version = "8.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4bd8b9603c7aa97359dbd97ecf258968c95f3adddd6db2f7e7a5bef101c84560" +dependencies = [ + "alloc-no-stdlib", + "alloc-stdlib", + "brotli-decompressor", +] + +[[package]] +name = "brotli-decompressor" +version = "5.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "874bb8112abecc98cbd6d81ea4fa7e94fb9449648c93cc89aa40c81c24d7de03" +dependencies = [ + "alloc-no-stdlib", + "alloc-stdlib", +] + [[package]] name = "bumpalo" version = "3.19.0" @@ -924,6 +1069,23 @@ dependencies = [ "shlex", ] +[[package]] +name = "cc-eventlog" +version = "0.5.8" +source = "git+https://github.com/Dstack-TEE/dstack.git#4f602dddc0542cd34da031c90ac0b3a560f316ed" +dependencies = [ + "anyhow", + "digest 0.10.7", + "ez-hash", + "fs-err", + "hex", + "parity-scale-codec", + "serde", + "serde-human-bytes", + "serde_json", + "sha2", +] + [[package]] name = "cfg-if" version = "1.0.4" @@ -1004,6 +1166,18 @@ version = "0.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "187437900921c8172f33316ad51a3267df588e99a2aebfa5ca1a2ed44df9e703" +[[package]] +name = "console" +version = "0.15.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "054ccb5b10f9f2cbf51eb355ca1d05c2d279ce1804688d0db74b4733a5aeafd8" +dependencies = [ + "encode_unicode", + "libc", + "once_cell", + "windows-sys 0.59.0", +] + [[package]] name = "const-hex" version = "1.17.0" @@ -1042,6 +1216,12 @@ dependencies = [ "unicode-xid", ] +[[package]] +name = "constant_time_eq" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d52eff69cd5e647efe296129160853a42795992097e8af39800e1060caeea9b" + [[package]] name = "convert_case" version = "0.10.0" @@ -1086,6 +1266,15 @@ dependencies = [ "libc", ] +[[package]] +name = "crc32fast" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9481c1c90cbf2ac953f07c8d4a58aa3945c425b7185c9154d67a65e4230da511" +dependencies = [ + "cfg-if", +] + [[package]] name = "critical-section" version = "1.2.0" @@ -1170,6 +1359,40 @@ dependencies = [ "syn 2.0.108", ] +[[package]] +name = "darling" +version = "0.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "25ae13da2f202d56bd7f91c25fba009e7717a1e4a1cc98a76d844b65ae912e9d" +dependencies = [ + "darling_core", + "darling_macro", +] + +[[package]] +name = "darling_core" +version = "0.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9865a50f7c335f53564bb694ef660825eb8610e0a53d3e11bf1b0d3df31e03b0" +dependencies = [ + "ident_case", + "proc-macro2", + "quote", + "strsim", + "syn 2.0.108", +] + +[[package]] +name = "darling_macro" +version = "0.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac3984ec7bd6cfa798e62b4a642426a5be0e68f9401cfc2a01e3fa9ea2fcdb8d" +dependencies = [ + "darling_core", + "quote", + "syn 2.0.108", +] + [[package]] name = "data-encoding" version = "2.9.0" @@ -1179,7 +1402,44 @@ checksum = "2a2330da5de22e8a3cb63252ce2abb30116bf5265e89c0e01bc17015ce30a476" [[package]] name = "dcap-qvl" version = "0.3.12" -source = "git+https://github.com/flashbots/dcap-qvl.git?branch=peg%2Fazure-outdated-tcp-override#b61d8f3ffb59f225d7b98220e2185a66f1c7f8c7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67e7842b81018f3b991dc65ec0a95ff347332de58478c4ac43459095af00cc89" +dependencies = [ + "anyhow", + "asn1_der", + "base64 0.22.1", + "borsh", + "byteorder", + "chrono", + "const-oid", + "dcap-qvl-webpki", + "der", + "derive_more 2.1.1", + "futures", + "hex", + "log", + "p256", + "parity-scale-codec", + "pem", + "reqwest", + "ring", + "rustls-pki-types", + "scale-info", + "serde", + "serde-human-bytes", + "serde_json", + "sha2", + "signature", + "tracing", + "urlencoding", + "wasm-bindgen-futures", + "x509-cert", +] + +[[package]] +name = "dcap-qvl" +version = "0.3.12" +source = "git+https://github.com/flashbots/dcap-qvl.git?branch=peg%2Fazure-outdated-tcp-override#e38818f0b7b600ceadad1ec3efd9e681bcbdc1e5" dependencies = [ "anyhow", "asn1_der", @@ -1243,13 +1503,27 @@ dependencies = [ "zeroize", ] +[[package]] +name = "der-parser" +version = "9.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5cd0a5c643689626bec213c4d8bd4d96acc8ffdb4ad4bb6bc16abf27d5f4b553" +dependencies = [ + "asn1-rs 0.6.2", + "displaydoc", + "nom", + "num-bigint", + "num-traits", + "rusticata-macros", +] + [[package]] name = "der-parser" version = "10.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "07da5016415d5a3c4dd39b11ed26f915f52fc4e0dc197d87908bc916e51bc1a6" dependencies = [ - "asn1-rs", + "asn1-rs 0.7.1", "displaydoc", "nom", "num-bigint", @@ -1384,6 +1658,44 @@ dependencies = [ "syn 2.0.108", ] +[[package]] +name = "dstack-attest" +version = "0.5.8" +source = "git+https://github.com/Dstack-TEE/dstack.git#4f602dddc0542cd34da031c90ac0b3a560f316ed" +dependencies = [ + "anyhow", + "cc-eventlog", + "dcap-qvl 0.3.12 (registry+https://github.com/rust-lang/crates.io-index)", + "dstack-types", + "errify", + "ez-hash", + "fs-err", + "hex", + "hex_fmt", + "insta", + "or-panic", + "parity-scale-codec", + "serde", + "serde-human-bytes", + "serde_json", + "sha2", + "sha3", + "tdx-attest", + "tracing", +] + +[[package]] +name = "dstack-types" +version = "0.5.8" +source = "git+https://github.com/Dstack-TEE/dstack.git#4f602dddc0542cd34da031c90ac0b3a560f316ed" +dependencies = [ + "parity-scale-codec", + "serde", + "serde-human-bytes", + "sha3", + "size-parser", +] + [[package]] name = "dunce" version = "1.0.5" @@ -1463,6 +1775,12 @@ dependencies = [ "zeroize", ] +[[package]] +name = "encode_unicode" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34aa73646ffb006b8f5147f3dc182bd4bcb190227ce861fc4a4844bf8e3cb2c0" + [[package]] name = "enum-as-inner" version = "0.6.1" @@ -1521,6 +1839,28 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" +[[package]] +name = "errify" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4bb818c3c01af9cdeb367f7e92e290b9a080935cdc5fb6cc0c1193ae17032849" +dependencies = [ + "anyhow", + "errify-macros", +] + +[[package]] +name = "errify-macros" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e87afa19e6030c2cf5514b00d5a242a3ea9492a2aa618635076914f5d15e7af" +dependencies = [ + "proc-macro2", + "proc-macro2-diagnostics", + "quote", + "syn 2.0.108", +] + [[package]] name = "errno" version = "0.3.14" @@ -1531,6 +1871,21 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "ez-hash" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42b3b3adc5fbbc9e21416d5b721b1bccb501a87d7b32ac89f2c7cea229d40772" +dependencies = [ + "blake2", + "blake3", + "digest 0.10.7", + "md-5", + "sha1", + "sha2", + "sha3", +] + [[package]] name = "fastrand" version = "2.3.0" @@ -1599,6 +1954,16 @@ version = "0.4.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b7ac824320a75a52197e8f2d787f6a38b6718bb6897a35142d749af3c0e8f4fe" +[[package]] +name = "flate2" +version = "1.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "843fba2746e448b37e26a819579957415c8cef339bf08564fe8b7ddbd959573c" +dependencies = [ + "crc32fast", + "miniz_oxide", +] + [[package]] name = "fnv" version = "1.0.7" @@ -1635,6 +2000,15 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "fs-err" +version = "3.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73fde052dbfc920003cfd2c8e2c6e6d4cc7c1091538c3a24226cec0665ab08c0" +dependencies = [ + "autocfg", +] + [[package]] name = "funty" version = "2.0.0" @@ -1826,6 +2200,12 @@ version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" +[[package]] +name = "hex_fmt" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b07f60793ff0a4d9cef0f18e63b5357e06209987153a64648c972c1e5aff336f" + [[package]] name = "hickory-proto" version = "0.25.2" @@ -1872,6 +2252,15 @@ dependencies = [ "tracing", ] +[[package]] +name = "hkdf" +version = "0.12.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b5f8eb2ad728638ea2c7d47a21db23b7b58a72ed6a38256b8a1849f15fbbdf7" +dependencies = [ + "hmac", +] + [[package]] name = "hmac" version = "0.12.1" @@ -2102,6 +2491,12 @@ dependencies = [ "zerovec", ] +[[package]] +name = "ident_case" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" + [[package]] name = "idna" version = "1.1.0" @@ -2155,6 +2550,18 @@ dependencies = [ "serde_core", ] +[[package]] +name = "insta" +version = "1.46.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e82db8c87c7f1ccecb34ce0c24399b8a73081427f3c7c50a5d597925356115e4" +dependencies = [ + "console", + "once_cell", + "similar", + "tempfile", +] + [[package]] name = "iocuddle" version = "0.1.1" @@ -2374,9 +2781,9 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.177" +version = "0.2.183" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2874a2af47a2325c2001a6e6fad9b16a53b802102b528163885171cf92b15976" +checksum = "b5b646652bf6661599e1da8901b3b9522896f01e736bad5f723fe7a3a27f899d" [[package]] name = "libm" @@ -2463,6 +2870,16 @@ dependencies = [ "stable_deref_trait", ] +[[package]] +name = "md-5" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d89e7ee0cfbedfc4da3340218492196241d89eefb6dab27de5df917a6d2e78cf" +dependencies = [ + "cfg-if", + "digest 0.10.7", +] + [[package]] name = "memchr" version = "2.7.6" @@ -2500,6 +2917,16 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" +[[package]] +name = "miniz_oxide" +version = "0.8.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fa76a2c86f704bdb222d66965fb3d63269ce38518b83cb0575fca855ebb6316" +dependencies = [ + "adler2", + "simd-adler32", +] + [[package]] name = "mio" version = "1.1.0" @@ -2546,6 +2973,30 @@ dependencies = [ "tempfile", ] +[[package]] +name = "nested-tls" +version = "0.0.1" +source = "git+https://github.com/flashbots/attested-tls?branch=peg%2Fattested-tls-crate#5c109dba74d4f9de58b4b846f480599752dfb1f9" +dependencies = [ + "rustls", + "tokio", + "tokio-rustls", + "tracing", +] + +[[package]] +name = "nix" +version = "0.31.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d6d0705320c1e6ba1d912b5e37cf18071b6c2e9b7fa8215a1e8a7651966f5d3" +dependencies = [ + "bitflags 2.10.0", + "cfg-if", + "cfg_aliases", + "libc", + "memoffset", +] + [[package]] name = "nom" version = "7.1.3" @@ -2647,13 +3098,22 @@ dependencies = [ "serde", ] +[[package]] +name = "oid-registry" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8d8034d9489cdaf79228eb9f6a3b8d7bb32ba00d6645ebd48eef4077ceb5bd9" +dependencies = [ + "asn1-rs 0.6.2", +] + [[package]] name = "oid-registry" version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "12f40cff3dde1b6087cc5d5f5d4d65712f34016a03ed60e9c08dcc392736b5b7" dependencies = [ - "asn1-rs", + "asn1-rs 0.7.1", ] [[package]] @@ -2773,6 +3233,12 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" +[[package]] +name = "or-panic" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "596a79faf55e869e7bc0c2162cf2f18a54d4d1112876bceae587ad954fcbd574" + [[package]] name = "p256" version = "0.13.2" @@ -3026,6 +3492,16 @@ dependencies = [ "zerocopy", ] +[[package]] +name = "prettyplease" +version = "0.2.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b" +dependencies = [ + "proc-macro2", + "syn 2.0.108", +] + [[package]] name = "primeorder" version = "0.13.6" @@ -3086,6 +3562,19 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "proc-macro2-diagnostics" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af066a9c399a26e020ada66a034357a868728e72cd426f3adcd35f80d88d88c8" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.108", + "version_check", + "yansi", +] + [[package]] name = "proptest" version = "1.10.0" @@ -3181,6 +3670,43 @@ version = "5.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" +[[package]] +name = "ra-tls" +version = "0.5.8" +source = "git+https://github.com/Dstack-TEE/dstack.git#4f602dddc0542cd34da031c90ac0b3a560f316ed" +dependencies = [ + "anyhow", + "bon", + "cc-eventlog", + "dcap-qvl 0.3.12 (registry+https://github.com/rust-lang/crates.io-index)", + "dstack-attest", + "dstack-types", + "elliptic-curve", + "errify", + "ez-hash", + "flate2", + "fs-err", + "hex", + "hex_fmt", + "hkdf", + "or-panic", + "p256", + "parity-scale-codec", + "rand 0.8.5", + "rcgen 0.13.2", + "ring", + "rmp-serde", + "rustls-pki-types", + "serde", + "serde-human-bytes", + "serde_json", + "sha2", + "sha3", + "tracing", + "x509-parser 0.16.0", + "yasna 0.5.2", +] + [[package]] name = "radium" version = "0.7.0" @@ -3268,14 +3794,29 @@ dependencies = [ [[package]] name = "rcgen" -version = "0.14.5" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75e669e5202259b5314d1ea5397316ad400819437857b90861765f24c4cf80a2" +dependencies = [ + "pem", + "ring", + "rustls-pki-types", + "time", + "x509-parser 0.16.0", + "yasna 0.5.2", +] + +[[package]] +name = "rcgen" +version = "0.14.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5fae430c6b28f1ad601274e78b7dffa0546de0b73b4cd32f46723c0c2a16f7a5" +checksum = "10b99e0098aa4082912d4c649628623db6aba77335e4f4569ff5083a6448b32e" dependencies = [ "pem", "ring", "rustls-pki-types", "time", + "x509-parser 0.18.1", "yasna 0.5.2", ] @@ -3422,6 +3963,25 @@ dependencies = [ "rustc-hex", ] +[[package]] +name = "rmp" +version = "0.8.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ba8be72d372b2c9b35542551678538b562e7cf86c3315773cae48dfbfe7790c" +dependencies = [ + "num-traits", +] + +[[package]] +name = "rmp-serde" +version = "1.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72f81bee8c8ef9b577d1681a70ebbc962c232461e397b22c208c43c04b67a155" +dependencies = [ + "rmp", + "serde", +] + [[package]] name = "route-recognizer" version = "0.3.1" @@ -3537,10 +4097,12 @@ dependencies = [ [[package]] name = "rustls" -version = "0.23.34" +version = "0.23.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a9586e9ee2b4f8fab52a0048ca7334d7024eef48e2cb9407e3497bb7cab7fa7" +checksum = "758025cb5fccfd3bc2fd74708fd4682be41d99e5dff73c377c0646c6012c73a4" dependencies = [ + "brotli", + "brotli-decompressor", "once_cell", "ring", "rustls-pki-types", @@ -3560,9 +4122,9 @@ dependencies = [ [[package]] name = "rustls-pki-types" -version = "1.13.0" +version = "1.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94182ad936a0c91c324cd46c6511b9510ed16af436d7b5bab34beab0afd55f7a" +checksum = "be040f8b0a225e40375822a563fa9524378b9d63112f53e19ffff34df5d33fdd" dependencies = [ "web-time", "zeroize", @@ -3725,10 +4287,11 @@ dependencies = [ [[package]] name = "serde-human-bytes" -version = "0.1.1" +version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78ef65cb41f3f9cef63c431193229067e8b98b53c4d4c4ed38a8ca87c4d07676" +checksum = "6a091af6294712930d01e375cce513e4ac416f823e033e8991ec4e5d6e6ef4c0" dependencies = [ + "base64 0.13.1", "hex", "serde", ] @@ -3902,6 +4465,28 @@ dependencies = [ "rand_core 0.6.4", ] +[[package]] +name = "simd-adler32" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e320a6c5ad31d271ad523dcf3ad13e2767ad8b1cb8f047f75a8aeaf8da139da2" + +[[package]] +name = "similar" +version = "2.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbbb5d9659141646ae647b42fe094daf6c6192d1620870b449d9557f748b2daa" + +[[package]] +name = "size-parser" +version = "0.5.8" +source = "git+https://github.com/Dstack-TEE/dstack.git#4f602dddc0542cd34da031c90ac0b3a560f316ed" +dependencies = [ + "anyhow", + "serde", + "thiserror 2.0.17", +] + [[package]] name = "slab" version = "0.4.11" @@ -4083,6 +4668,25 @@ version = "0.12.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" +[[package]] +name = "tdx-attest" +version = "0.5.8" +source = "git+https://github.com/Dstack-TEE/dstack.git#4f602dddc0542cd34da031c90ac0b3a560f316ed" +dependencies = [ + "anyhow", + "cc-eventlog", + "fs-err", + "hex", + "libc", + "parity-scale-codec", + "serde", + "serde-human-bytes", + "serde_json", + "sha2", + "thiserror 2.0.17", + "vsock", +] + [[package]] name = "tdx-quote" version = "0.0.5" @@ -4217,9 +4821,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.48.0" +version = "1.50.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ff360e02eab121e0bc37a2d3b4d4dc622e6eda3a8e5253d5435ecf5bd4c68408" +checksum = "27ad5e34374e03cfffefc301becb44e9dc3c17584f414349ebe29ed26661822d" dependencies = [ "bytes", "libc", @@ -4389,9 +4993,9 @@ checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" [[package]] name = "tracing" -version = "0.1.41" +version = "0.1.44" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "784e0ac535deb450455cbfa28a6f0df145ea1bb7ae51b821cf5e7927fdcfbdd0" +checksum = "63e71662fa4b2a2c3a26f570f037eb95bb1f85397f3cd8076caed2f026a6d100" dependencies = [ "log", "pin-project-lite", @@ -4401,9 +5005,9 @@ dependencies = [ [[package]] name = "tracing-attributes" -version = "0.1.30" +version = "0.1.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81383ab64e72a7a8b8e13130c49e3dab29def6d0c7d76a03087b3cf71c5c6903" +checksum = "7490cfa5ec963746568740651ac6781f701c9c5ea257c58e057f3ba8cf69e8da" dependencies = [ "proc-macro2", "quote", @@ -4412,9 +5016,9 @@ dependencies = [ [[package]] name = "tracing-core" -version = "0.1.34" +version = "0.1.36" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b9d12581f227e93f094d3af2ae690a574abb8a2b9b7a96e7cfe9647b2b617678" +checksum = "db97caf9d906fbde555dd62fa95ddba9eecfd14cb388e4f491a66d74cd5fb79a" dependencies = [ "once_cell", "valuable", @@ -4689,6 +5293,16 @@ version = "0.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "051eb1abcf10076295e815102942cc58f9d5e3b4560e46e53c21e8ff6f3af7b1" +[[package]] +name = "vsock" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b82aeb12ad864eb8cd26a6c21175d0bdc66d398584ee6c93c76964c3bcfc78ff" +dependencies = [ + "libc", + "nix", +] + [[package]] name = "wait-timeout" version = "0.2.1" @@ -4889,6 +5503,15 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "windows-sys" +version = "0.59.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" +dependencies = [ + "windows-targets 0.52.6", +] + [[package]] name = "windows-sys" version = "0.60.2" @@ -5158,16 +5781,34 @@ dependencies = [ [[package]] name = "x509-parser" -version = "0.18.0" +version = "0.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eb3e137310115a65136898d2079f003ce33331a6c4b0d51f1531d1be082b6425" +checksum = "fcbc162f30700d6f3f82a24bf7cc62ffe7caea42c0b2cba8bf7f3ae50cf51f69" dependencies = [ - "asn1-rs", + "asn1-rs 0.6.2", "data-encoding", - "der-parser", + "der-parser 9.0.0", "lazy_static", "nom", - "oid-registry", + "oid-registry 0.7.1", + "ring", + "rusticata-macros", + "thiserror 1.0.69", + "time", +] + +[[package]] +name = "x509-parser" +version = "0.18.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d43b0f71ce057da06bc0851b23ee24f3f86190b07203dd8f567d0b706a185202" +dependencies = [ + "asn1-rs 0.7.1", + "data-encoding", + "der-parser 10.0.0", + "lazy_static", + "nom", + "oid-registry 0.8.1", "ring", "rusticata-macros", "thiserror 2.0.17", @@ -5195,6 +5836,12 @@ dependencies = [ "x509-ocsp", ] +[[package]] +name = "yansi" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cfe53a6657fd280eaa890a3bc59152892ffa3e30101319d168b781ed6529b049" + [[package]] name = "yasna" version = "0.4.0" diff --git a/Cargo.toml b/Cargo.toml index 40a3ab3..c285277 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,8 +11,10 @@ repository = "https://github.com/flashbots/attested-tls-proxy" keywords = ["attested-TLS", "CVM", "TDX"] [dependencies] -attested-tls = { path = "attested-tls", default-features = false } -tokio = { version = "1.48.0", features = ["full"] } +attested-tls = { git = "https://github.com/flashbots/attested-tls", branch = "peg/attested-tls-crate" } +nested-tls = { git = "https://github.com/flashbots/attested-tls", branch = "peg/attested-tls-crate" } +attestation = { git = "https://github.com/flashbots/attested-tls", branch = "peg/attested-tls-crate" } +tokio = { version = "1.50.0", features = ["full"] } tokio-rustls = { version = "0.26.4", default-features = false } x509-parser = { version = "0.18.0", features = ["verify"] } thiserror = "2.0.17" @@ -45,14 +47,15 @@ pin-project-lite = "0.2.16" [dev-dependencies] tempfile = "3.23.0" tdx-quote = { version = "0.0.5", features = ["mock"] } -attested-tls = { path = "attested-tls", features = ["test-helpers", "mock"] } +attestation = { git = "https://github.com/flashbots/attested-tls", branch = "peg/attested-tls-crate", features = ["mock"] } +tokio = { version = "1.48.0", features = ["full"] } jsonrpsee = { version = "0.26.0", features = ["server"] } [features] default = [] # Adds support for Microsoft Azure attestation generation and verification -azure = ["attested-tls/azure"] +azure = ["attestation/azure"] [package.metadata.deb] maintainer = "Flashbots Team " diff --git a/attested-tls/Cargo.toml b/attested-tls/Cargo.toml index 38a810e..d40e7b6 100644 --- a/attested-tls/Cargo.toml +++ b/attested-tls/Cargo.toml @@ -18,7 +18,7 @@ http = "1.3.1" serde_json = "1.0.145" tracing = "0.1.41" parity-scale-codec = "3.7.5" -attestation = { git = "https://github.com/flashbots/attested-tls", branch = "peg/add-attestation-crate" } +attestation = { git = "https://github.com/flashbots/attested-tls", branch = "peg/attested-tls-crate" } # Used for websocket support tokio-tungstenite = { version = "0.28.0", optional = true } @@ -40,7 +40,7 @@ rcgen = { version = "0.14.5", optional = true } [dev-dependencies] rcgen = "0.14.5" tempfile = "3.23.0" -attestation = { git = "https://github.com/flashbots/attested-tls", branch = "peg/add-attestation-crate", features = ["mock"] } +attestation = { git = "https://github.com/flashbots/attested-tls", branch = "peg/attested-tls-crate", features = ["mock"] } [features] default = ["ws", "rpc"] diff --git a/src/attested_get.rs b/src/attested_get.rs index 27e7ad3..7fc40b9 100644 --- a/src/attested_get.rs +++ b/src/attested_get.rs @@ -9,30 +9,16 @@ pub async fn attested_get( url_path: &str, attestation_verifier: AttestationVerifier, remote_certificate: Option>, - allow_self_signed: bool, ) -> Result { - let proxy_client = if allow_self_signed { - let client_config = crate::self_signed::client_tls_config_allow_self_signed()?; - ProxyClient::new_with_tls_config( - client_config, - "127.0.0.1:0".to_string(), - target_addr, - AttestationGenerator::with_no_attestation(), - attestation_verifier, - None, - ) - .await? - } else { - ProxyClient::new( - None, - "127.0.0.1:0".to_string(), - target_addr, - AttestationGenerator::with_no_attestation(), - attestation_verifier, - remote_certificate, - ) - .await? - }; + let proxy_client = ProxyClient::new( + None, + "127.0.0.1:0".to_string(), + target_addr, + AttestationGenerator::with_no_attestation(), + attestation_verifier, + remote_certificate, + ) + .await?; attested_get_with_client(proxy_client, url_path).await } @@ -69,14 +55,14 @@ async fn attested_get_with_client( mod tests { use super::*; use crate::{ - ProxyServer, + OuterTlsConfig, OuterTlsMode, ProxyServer, attestation::AttestationType, file_server::static_file_server, - test_helpers::{generate_certificate_chain, generate_tls_config}, + test_helpers::{generate_certificate_chain_for_host, generate_tls_config}, }; use tempfile::tempdir; - #[tokio::test] + #[tokio::test(flavor = "multi_thread")] async fn test_attested_get() { // Create a temporary directory with a file to serve let dir = tempdir().unwrap(); @@ -87,17 +73,23 @@ mod tests { let target_addr = static_file_server(dir.path().to_path_buf()).await.unwrap(); // Create TLS configuration - let (cert_chain, private_key) = generate_certificate_chain("127.0.0.1".parse().unwrap()); + let (cert_chain, private_key) = generate_certificate_chain_for_host("localhost"); let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key); // Setup a proxy server targetting the static file server - let proxy_server = ProxyServer::new_with_tls_config( - cert_chain, - server_config, - "127.0.0.1:0", + let proxy_server = ProxyServer::new( + Some(OuterTlsConfig { + listen_addr: "127.0.0.1:0", + tls: OuterTlsMode::Preconfigured { + server_config, + certificate_name: "localhost".to_string(), + }, + }), + Some("127.0.0.1:0"), target_addr.to_string(), AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), AttestationVerifier::expect_none(), + false, ) .await .unwrap(); @@ -113,7 +105,7 @@ mod tests { let proxy_client = ProxyClient::new_with_tls_config( client_config, "127.0.0.1:0".to_string(), - proxy_addr.to_string(), + format!("localhost:{}", proxy_addr.port()), AttestationGenerator::with_no_attestation(), AttestationVerifier::mock(), None, diff --git a/src/file_server.rs b/src/file_server.rs index bce4804..4c5c9bb 100644 --- a/src/file_server.rs +++ b/src/file_server.rs @@ -1,5 +1,8 @@ //! Static HTTP file server provided by an attested TLS proxy server -use crate::{AttestationGenerator, AttestationVerifier, ProxyError, ProxyServer, TlsCertAndKey}; +use crate::{ + AttestationGenerator, AttestationVerifier, OuterTlsConfig, OuterTlsMode, ProxyError, + ProxyServer, TlsCertAndKey, +}; use std::{net::SocketAddr, path::PathBuf}; use tokio::net::ToSocketAddrs; use tower_http::services::ServeDir; @@ -7,17 +10,28 @@ use tower_http::services::ServeDir; /// Setup a static file server serving the given directory, and a proxy server targetting it pub async fn attested_file_server( path_to_serve: PathBuf, - cert_and_key: TlsCertAndKey, - listen_addr: impl ToSocketAddrs, + outer_cert_and_key: Option, + outer_listen_addr: Option, + inner_listen_addr: Option, attestation_generator: AttestationGenerator, attestation_verifier: AttestationVerifier, client_auth: bool, ) -> Result<(), ProxyError> { let target_addr = static_file_server(path_to_serve).await?; + let outer_session = match (outer_cert_and_key, outer_listen_addr) { + (Some(cert_and_key), Some(listen_addr)) => Some(OuterTlsConfig { + listen_addr, + tls: OuterTlsMode::CertAndKey(cert_and_key), + }), + (Some(_), None) | (None, Some(_)) => { + return Err(ProxyError::NoListenersConfigured); + } + (None, None) => None, + }; let server = ProxyServer::new( - cert_and_key, - listen_addr, + outer_session, + inner_listen_addr, target_addr.to_string(), attestation_generator, attestation_verifier, @@ -52,10 +66,10 @@ pub(crate) async fn static_file_server(path: PathBuf) -> Resultfoo"); - let (body, content_type) = get_body_and_content_type( - format!("http://{}/data.bin", proxy_client_addr.to_string()), - &client, - ) - .await; + let (body, content_type) = + get_body_and_content_type(format!("http://{}/data.bin", proxy_client_addr), &client) + .await; assert_eq!(content_type, "application/octet-stream"); assert_eq!(body, [0u8; 32]); } diff --git a/src/http_version.rs b/src/http_version.rs index bef817c..901df66 100644 --- a/src/http_version.rs +++ b/src/http_version.rs @@ -7,6 +7,9 @@ use std::task::{Context, Poll}; pub const ALPN_H2: &[u8] = b"h2"; pub const ALPN_HTTP11: &[u8] = b"http/1.1"; +type ProxyClientTlsStream = + tokio_rustls::client::TlsStream>; + /// Supported HTTP versions #[derive(Debug)] pub enum HttpVersion { @@ -55,13 +58,11 @@ impl HttpVersion { type Http1Sender = hyper::client::conn::http1::SendRequest; type Http2Sender = hyper::client::conn::http2::SendRequest; -type Http1Connection = hyper::client::conn::http1::Connection< - TokioIo>, - hyper::body::Incoming, ->; +type Http1Connection = + hyper::client::conn::http1::Connection, hyper::body::Incoming>; type Http2Connection = hyper::client::conn::http2::Connection< - TokioIo>, + TokioIo, hyper::body::Incoming, crate::TokioExecutor, >; diff --git a/src/lib.rs b/src/lib.rs index 8e0ce80..88aa200 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,46 +3,38 @@ pub mod attested_get; pub mod file_server; pub mod health_check; pub mod normalize_pem; -pub mod self_signed; -pub use attested_tls; -pub use attested_tls::attestation; -pub use attested_tls::attestation::AttestationGenerator; +pub use attestation; +pub use attestation::AttestationGenerator; mod http_version; #[cfg(test)] mod test_helpers; +use attestation::{AttestationError, AttestationVerifier}; +use attested_tls::{AttestedCertificateResolver, AttestedCertificateVerifier, AttestedTlsError}; use bytes::Bytes; use http::{HeaderMap, HeaderName, HeaderValue}; use http_body_util::{BodyExt, combinators::BoxBody}; use hyper::{Response, service::service_fn}; use hyper_util::rt::TokioIo; +use nested_tls::server::NestingTlsStream; +use nested_tls::{client::NestingTlsConnector, server::NestingTlsAcceptor}; use std::{net::SocketAddr, num::TryFromIntError, sync::Arc, time::Duration}; use thiserror::Error; -use tokio::io; +use tokio::io::{self, AsyncWriteExt}; use tokio::net::{TcpListener, TcpStream, ToSocketAddrs}; use tokio::sync::{mpsc, oneshot}; +use tokio_rustls::TlsAcceptor; use tokio_rustls::rustls::server::{VerifierBuilderError, WebPkiClientVerifier}; use tokio_rustls::rustls::{ - self, ClientConfig, RootCertStore, ServerConfig, pki_types::CertificateDer, + self, ClientConfig, RootCertStore, ServerConfig, + pki_types::{CertificateDer, PrivateKeyDer, ServerName}, }; use tracing::{debug, error, warn}; use crate::http_version::{ALPN_H2, ALPN_HTTP11, HttpConnection, HttpSender, HttpVersion}; -use attested_tls::{ - AttestedTlsClient, AttestedTlsError, AttestedTlsServer, TlsCertAndKey, - attestation::{ - AttestationError, AttestationType, AttestationVerifier, measurements::MultiMeasurements, - }, -}; - -/// The header name for giving attestation type -const ATTESTATION_TYPE_HEADER: &str = "X-Flashbots-Attestation-Type"; - -/// The header name for giving measurements -const MEASUREMENT_HEADER: &str = "X-Flashbots-Measurement"; /// The header name for giving the forwarded for IP static X_FORWARDED_FOR: HeaderName = HeaderName::from_static("x-forwarded-for"); @@ -55,12 +47,97 @@ const SERVER_RECONNECT_MAX_BACKOFF_SECS: u64 = 120; const KEEP_ALIVE_INTERVAL: u64 = 30; const KEEP_ALIVE_TIMEOUT: u64 = 10; - type RequestWithResponseSender = ( http::Request, oneshot::Sender>, hyper::Error>>, ); +type OuterProxySession = (Arc, NestingTlsAcceptor); +type InnerProxySession = (Arc, TlsAcceptor); + +/// TLS Credentials +pub struct TlsCertAndKey { + /// Der-encoded TLS certificate chain + pub cert_chain: Vec>, + /// Der-encoded TLS private key + pub key: PrivateKeyDer<'static>, +} + +/// Configuration for the optional outer nested-TLS listener. +pub struct OuterTlsConfig { + /// The socket address to bind for the outer listener. + pub listen_addr: A, + /// How the outer TLS server configuration should be constructed. + pub tls: OuterTlsMode, +} + +/// TLS configuration sources for the outer nested-TLS listener. +pub enum OuterTlsMode { + /// Build the outer TLS server config from certificate and key material. + CertAndKey(TlsCertAndKey), + /// Use an already-constructed outer TLS server config. + Preconfigured { + /// The outer TLS server configuration to expose on the listener. + server_config: ServerConfig, + /// The server identity to embed into the inner attested certificate. + certificate_name: String, + }, +} + +impl OuterTlsConfig +where + A: ToSocketAddrs, +{ + fn certificate_name(&self) -> Result { + match &self.tls { + OuterTlsMode::CertAndKey(cert_and_key) => { + Ok(certificate_identity_from_chain(&cert_and_key.cert_chain)?) + } + OuterTlsMode::Preconfigured { + certificate_name, .. + } => Ok(certificate_name.clone()), + } + } + + async fn into_listener_and_acceptor( + self, + inner_server_config: Arc, + client_auth: bool, + ) -> Result<(Arc, NestingTlsAcceptor), ProxyError> { + let listen_addr = self.listen_addr; + let outer_server_config = match self.tls { + OuterTlsMode::CertAndKey(cert_and_key) => { + if client_auth { + let root_store = + RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); + let verifier = WebPkiClientVerifier::builder(Arc::new(root_store)).build()?; + + ServerConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) + .with_client_cert_verifier(verifier) + .with_single_cert( + cert_and_key.cert_chain.clone(), + cert_and_key.key.clone_key(), + )? + } else { + ServerConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) + .with_no_client_auth() + .with_single_cert( + cert_and_key.cert_chain.clone(), + cert_and_key.key.clone_key(), + )? + } + } + OuterTlsMode::Preconfigured { server_config, .. } => server_config, + }; + + let outer_listener = Arc::new(TcpListener::bind(listen_addr).await?); + let outer_tls_acceptor = + NestingTlsAcceptor::new(Arc::new(outer_server_config), inner_server_config); + + Ok((outer_listener, outer_tls_acceptor)) + } +} + /// Adds HTTP 1 and 2 to the list of allowed protocols fn ensure_proxy_alpn_protocols(alpn_protocols: &mut Vec>) { for protocol in [ALPN_H2, ALPN_HTTP11] { @@ -72,108 +149,126 @@ fn ensure_proxy_alpn_protocols(alpn_protocols: &mut Vec>) { } } -/// Retrieve the attested remote TLS certificate. -pub async fn get_tls_cert( +/// Retrieve the inner attested remote TLS certificate. +pub async fn get_inner_tls_cert( server_name: String, attestation_verifier: AttestationVerifier, - remote_certificate: Option>, - allow_self_signed: bool, -) -> Result<(Vec>, Option), AttestedTlsError> { - let (cert, measurements) = if allow_self_signed { - let client_tls_config = self_signed::client_tls_config_allow_self_signed()?; - attested_tls::get_tls_cert_with_config( - &server_name, - attestation_verifier, - client_tls_config, - ) - .await? - } else { - attested_tls::get_tls_cert(server_name, attestation_verifier, remote_certificate).await? + remote_outer_certificate: Option>, +) -> Result>, ProxyError> { + let root_store = match remote_outer_certificate.as_ref() { + Some(remote_certificate) => { + let mut root_store = RootCertStore::empty(); + root_store.add(remote_certificate.clone())?; + root_store + } + None => RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()), }; - debug!("[get-tls-cert] Connected to proxy server with measurements: {measurements:?}"); - Ok((cert, measurements)) + let outer_client_config = + ClientConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) + .with_root_certificates(root_store) + .with_no_client_auth(); + + get_inner_tls_cert_with_config(server_name, attestation_verifier, outer_client_config).await +} + +pub async fn get_inner_tls_cert_with_config( + server_name: String, + attestation_verifier: AttestationVerifier, + outer_client_config: ClientConfig, +) -> Result>, ProxyError> { + let outbound_stream = tokio::net::TcpStream::connect(&server_name).await?; + + let domain = server_name_from_host(&server_name)?; + + let attested_cert_verifier = AttestedCertificateVerifier::new(None, attestation_verifier)?; + let inner_client_config = + ClientConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) + .dangerous() + .with_custom_certificate_verifier(Arc::new(attested_cert_verifier)) + .with_no_client_auth(); + + let nested_tls_connector = + NestingTlsConnector::new(Arc::new(outer_client_config), Arc::new(inner_client_config)); + + let mut tls_stream = nested_tls_connector + .connect(domain, outbound_stream) + .await?; + debug!("[get-tls-cert] Connected to proxy server"); + + let (_io, server_connection) = tls_stream.get_ref(); + + let remote_cert_chain = server_connection + .peer_certificates() + .ok_or(ProxyError::NoCertificate)? + .to_owned(); + + tls_stream.shutdown().await?; + + Ok(remote_cert_chain) } /// A TLS over TCP server which provides an attestation before forwarding traffic to a given target address pub struct ProxyServer { - /// The underlying attested TLS server - attested_tls_server: AttestedTlsServer, - /// The underlying TCP listener - listener: Arc, + outer: Option, + inner: Option, /// The address/hostname of the target service we are proxying to target: String, } impl ProxyServer { - pub async fn new( - cert_and_key: TlsCertAndKey, - local: impl ToSocketAddrs, + /// Start with dual listeners. The outer nested-TLS listener is optional. + pub async fn new( + outer_session: Option>, + inner_local: Option, target: String, attestation_generator: AttestationGenerator, attestation_verifier: AttestationVerifier, client_auth: bool, - ) -> Result { - let mut server_config = if client_auth { - let root_store = - RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); - let verifier = WebPkiClientVerifier::builder(Arc::new(root_store)).build()?; - - ServerConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) - .with_client_cert_verifier(verifier) - .with_single_cert( - cert_and_key.cert_chain.clone(), - cert_and_key.key.clone_key(), - )? - } else { - ServerConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) - .with_no_client_auth() - .with_single_cert( - cert_and_key.cert_chain.clone(), - cert_and_key.key.clone_key(), - )? - }; - ensure_proxy_alpn_protocols(&mut server_config.alpn_protocols); - - let attested_tls_server = AttestedTlsServer::new_with_tls_config( - cert_and_key.cert_chain, - server_config, - attestation_generator, - attestation_verifier, - )?; - - let listener = TcpListener::bind(local).await?; - - Ok(Self { - attested_tls_server, - listener: listener.into(), - target, - }) - } - - /// Start with preconfigured TLS - pub async fn new_with_tls_config( - cert_chain: Vec>, - mut server_config: ServerConfig, - local: impl ToSocketAddrs, - target: String, - attestation_generator: AttestationGenerator, - attestation_verifier: AttestationVerifier, - ) -> Result { - ensure_proxy_alpn_protocols(&mut server_config.alpn_protocols); + ) -> Result + where + O: ToSocketAddrs, + I: ToSocketAddrs, + { + if outer_session.is_none() && inner_local.is_none() { + return Err(ProxyError::NoListenersConfigured); + } - let attested_tls_server = AttestedTlsServer::new_with_tls_config( - cert_chain, - server_config, - attestation_generator, - attestation_verifier, - )?; + let certificate_name = outer_session + .as_ref() + .map(OuterTlsConfig::certificate_name) + .transpose()?; + let inner_server_config = Arc::new( + build_inner_server_config( + attestation_generator, + attestation_verifier, + client_auth, + certificate_name, + ) + .await?, + ); + let inner = match inner_local { + Some(inner_local) => { + let inner_listener = Arc::new(TcpListener::bind(inner_local).await?); + let inner_tls_acceptor = TlsAcceptor::from(inner_server_config.clone()); + Some((inner_listener, inner_tls_acceptor)) + } + None => None, + }; - let listener = TcpListener::bind(local).await?; + let outer = match outer_session { + Some(outer_session) => { + let (outer_listener, outer_tls_acceptor) = outer_session + .into_listener_and_acceptor(inner_server_config.clone(), client_auth) + .await?; + Some((outer_listener, outer_tls_acceptor)) + } + None => None, + }; Ok(Self { - attested_tls_server, - listener: listener.into(), + outer, + inner, target, }) } @@ -183,50 +278,151 @@ impl ProxyServer { /// Returns the handle for the task handling the connection pub async fn accept(&self) -> Result, ProxyError> { let target = self.target.clone(); - let (inbound, client_addr) = self.listener.accept().await?; - let attested_tls_server = self.attested_tls_server.clone(); - - let join_handle = tokio::spawn(async move { - match attested_tls_server.handle_connection(inbound).await { - Ok((tls_stream, measurements, attestation_type)) => { - if let Err(err) = Self::handle_connection( - tls_stream, - measurements, - attestation_type, - target, - client_addr, - ) - .await - { - warn!("Failed to handle connection: {err}"); + let outer = self.outer.clone(); + let inner = self.inner.clone(); + + let join_handle = match (outer, inner) { + ( + Some((outer_listener, outer_tls_acceptor)), + Some((inner_listener, inner_tls_acceptor)), + ) => { + let ((inbound, client_addr), use_outer) = tokio::select! { + accepted = outer_listener.accept() => (accepted?, true), + accepted = inner_listener.accept() => (accepted?, false), + }; + + tokio::spawn(async move { + if use_outer { + match outer_tls_acceptor.accept(inbound).await { + Ok(tls_stream) => { + if let Err(err) = + Self::handle_outer_connection(tls_stream, target, client_addr) + .await + { + warn!("Failed to handle outer connection: {err}"); + } + } + Err(err) => { + warn!("Outer attestation exchange failed: {err}"); + } + } + } else { + match inner_tls_acceptor.accept(inbound).await { + Ok(tls_stream) => { + if let Err(err) = + Self::handle_inner_connection(tls_stream, target, client_addr) + .await + { + warn!("Failed to handle inner connection: {err}"); + } + } + Err(err) => { + warn!("Inner attestation exchange failed: {err}"); + } + } } - } - Err(err) => { - warn!("Attestation exchange failed: {err}"); - } + }) } - }); + (None, Some((inner_listener, inner_tls_acceptor))) => { + let (inbound, client_addr) = inner_listener.accept().await?; + tokio::spawn(async move { + match inner_tls_acceptor.accept(inbound).await { + Ok(tls_stream) => { + if let Err(err) = + Self::handle_inner_connection(tls_stream, target, client_addr).await + { + warn!("Failed to handle inner connection: {err}"); + } + } + Err(err) => { + warn!("Inner attestation exchange failed: {err}"); + } + } + }) + } + (Some((outer_listener, outer_tls_acceptor)), None) => { + let (inbound, client_addr) = outer_listener.accept().await?; + tokio::spawn(async move { + match outer_tls_acceptor.accept(inbound).await { + Ok(tls_stream) => { + if let Err(err) = + Self::handle_outer_connection(tls_stream, target, client_addr).await + { + warn!("Failed to handle outer connection: {err}"); + } + } + Err(err) => { + warn!("Outer attestation exchange failed: {err}"); + } + } + }) + } + _ => return Err(ProxyError::NoListenersConfigured), + }; Ok(join_handle) } - /// Helper to get the socket address of the underlying TCP listener + /// Helper to get the socket address of either underlying TCP listener pub fn local_addr(&self) -> std::io::Result { - self.listener.local_addr() + match &self.outer { + Some((listener, _)) => listener.local_addr(), + None => self + .inner + .as_ref() + .map(|(listener, _)| listener) + .ok_or_else(|| std::io::Error::other("no listeners configured"))? + .local_addr(), + } } - /// Handle an incoming connection from a proxy-client - async fn handle_connection( + /// Helper to get the socket address of the underlying outer TCP listener if present + pub fn outer_local_addr(&self) -> std::io::Result> { + self.outer + .as_ref() + .map(|(listener, _)| listener.local_addr()) + .transpose() + } + + /// Helper to get the socket address of the underlying inner TCP listener if present + pub fn inner_local_addr(&self) -> std::io::Result> { + self.inner + .as_ref() + .map(|(listener, _)| listener.local_addr()) + .transpose() + } + + async fn handle_outer_connection( + tls_stream: NestingTlsStream, + target: String, + client_addr: SocketAddr, + ) -> Result<(), ProxyError> { + debug!("[proxy-server] accepted connection"); + + let http_version = HttpVersion::from_negotiated_protocol_server(&tls_stream); + Self::serve_tls_stream(tls_stream, http_version, target, client_addr).await + } + + async fn handle_inner_connection( tls_stream: tokio_rustls::server::TlsStream, - measurements: Option, - remote_attestation_type: AttestationType, target: String, client_addr: SocketAddr, ) -> Result<(), ProxyError> { - debug!("[proxy-server] accepted connection with measurements: {measurements:?}"); + debug!("[proxy-server] accepted inner-only connection"); let http_version = HttpVersion::from_negotiated_protocol_server(&tls_stream); + Self::serve_tls_stream(tls_stream, http_version, target, client_addr).await + } + async fn serve_tls_stream( + tls_stream: IO, + http_version: HttpVersion, + target: String, + client_addr: SocketAddr, + ) -> Result<(), ProxyError> + where + IO: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static, + { // Setup a request handler let service = service_fn(move |mut req| { debug!("[proxy-server] Handling request {req:?}"); @@ -251,27 +447,6 @@ impl ProxyServer { update_header(headers, &X_FORWARDED_FOR, &new_x_forwarded_for); - // If we have measurements, from the remote peer, add them to the request header - let measurements = measurements.clone(); - if let Some(measurements) = measurements { - match measurements.to_header_format() { - Ok(header_value) => { - headers.insert(MEASUREMENT_HEADER, header_value); - } - Err(e) => { - // This error is highly unlikely - that the measurement values fail to - // encode to JSON or fit in an HTTP header - error!("Failed to encode measurement values: {e}"); - } - } - } - - update_header( - headers, - ATTESTATION_TYPE_HEADER, - remote_attestation_type.as_str(), - ); - let target = target.clone(); async move { match Self::handle_http_request(req, target).await { @@ -372,16 +547,16 @@ impl ProxyClient { attestation_verifier: AttestationVerifier, remote_certificate: Option>, ) -> Result { - let root_store = match remote_certificate { + let root_store = match remote_certificate.as_ref() { Some(remote_certificate) => { let mut root_store = RootCertStore::empty(); - root_store.add(remote_certificate)?; + root_store.add(remote_certificate.clone())?; root_store } None => RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()), }; - let mut client_config = if let Some(ref cert_and_key) = cert_and_key { + let outer_client_config = if let Some(ref cert_and_key) = cert_and_key { ClientConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) .with_root_certificates(root_store) .with_client_auth_cert( @@ -393,43 +568,64 @@ impl ProxyClient { .with_root_certificates(root_store) .with_no_client_auth() }; - ensure_proxy_alpn_protocols(&mut client_config.alpn_protocols); - let attested_tls_client = AttestedTlsClient::new_with_tls_config( - client_config, + Self::new_with_tls_config( + outer_client_config, + address, + server_name, attestation_generator, attestation_verifier, - cert_and_key.map(|c| c.cert_chain), - )?; - - Self::new_with_inner(address, attested_tls_client, &server_name).await + cert_and_key.map(|cert_and_key| cert_and_key.cert_chain), + ) + .await } /// Create a new proxy client with given TLS configuration pub async fn new_with_tls_config( - mut client_config: ClientConfig, + outer_client_config: ClientConfig, address: impl ToSocketAddrs, target_name: String, attestation_generator: AttestationGenerator, attestation_verifier: AttestationVerifier, cert_chain: Option>>, ) -> Result { - ensure_proxy_alpn_protocols(&mut client_config.alpn_protocols); + let outer_has_client_auth = outer_client_config.client_auth_cert_resolver.has_certs(); + let inner_has_client_auth = cert_chain.is_some(); - let attested_tls_client = AttestedTlsClient::new_with_tls_config( - client_config, - attestation_generator, - attestation_verifier, - cert_chain, - )?; + if outer_has_client_auth != inner_has_client_auth { + return Err(ProxyError::ClientAuthMisconfigured); + } + + let attested_cert_verifier = AttestedCertificateVerifier::new(None, attestation_verifier)?; + + let mut inner_client_config = if let Some(cert_chain) = cert_chain.as_ref() { + let inner_cert_resolver = build_attested_cert_resolver( + attestation_generator, + certificate_identity_from_chain(cert_chain)?, + ) + .await?; + ClientConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) + .dangerous() + .with_custom_certificate_verifier(Arc::new(attested_cert_verifier)) + .with_client_cert_resolver(Arc::new(inner_cert_resolver)) + } else { + ClientConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) + .dangerous() + .with_custom_certificate_verifier(Arc::new(attested_cert_verifier)) + .with_no_client_auth() + }; + ensure_proxy_alpn_protocols(&mut inner_client_config.alpn_protocols); + + let nesting_tls_connector = + NestingTlsConnector::new(Arc::new(outer_client_config), Arc::new(inner_client_config)); - Self::new_with_inner(address, attested_tls_client, &target_name).await + Self::new_with_inner(address, nesting_tls_connector, &target_name).await } /// Create a new proxy client with given [AttestedTlsClient] pub async fn new_with_inner( address: impl ToSocketAddrs, - attested_tls_client: AttestedTlsClient, + nesting_tls_connector: NestingTlsConnector, target_name: &str, ) -> Result { let listener = TcpListener::bind(address).await?; @@ -452,9 +648,9 @@ impl ProxyClient { let mut first = true; let mut ready_tx = Some(ready_tx); 'reconnect: loop { - let (mut sender, conn, measurements, remote_attestation_type) = + let (mut sender, conn) = // Connect to the proxy server and provide / verify attestation - match Self::setup_connection_with_backoff(&target, &attested_tls_client, first) + match Self::setup_connection_with_backoff(&target, &nesting_tls_connector, first) .await { Ok(output) => { @@ -494,29 +690,8 @@ impl ProxyClient { debug!("[proxy-client] Read incoming request from source client: {req:?}"); // Attempt to forward it to the proxy server let (response, should_reconnect) = match sender.send_request(req).await { - Ok(mut resp) => { + Ok(resp) => { debug!("[proxy-client] Read response from proxy-server: {resp:?}"); - // If we have measurements from the proxy-server, inject them into the - // response header - let headers = resp.headers_mut(); - if let Some(measurements) = measurements.clone() { - match measurements.to_header_format() { - Ok(header_value) => { - headers.insert(MEASUREMENT_HEADER, header_value); - } - Err(e) => { - // This error is highly unlikely - that the measurement values fail to - // encode to JSON or fit in an HTTP header - error!("Failed to encode measurement values: {e}"); - } - } - } - - update_header( - headers, - ATTESTATION_TYPE_HEADER, - remote_attestation_type.as_str(), - ); (Ok(resp.map(|b| b.boxed())), false) } Err(e) => { @@ -622,27 +797,19 @@ impl ProxyClient { // If it fails retry with a backoff (indefinately) async fn setup_connection_with_backoff( target: &str, - attested_tls_client: &AttestedTlsClient, + nesting_tls_connector: &NestingTlsConnector, should_bail: bool, - ) -> Result< - ( - HttpSender, - HttpConnection, - Option, - AttestationType, - ), - ProxyError, - > { + ) -> Result<(HttpSender, HttpConnection), ProxyError> { let mut delay = Duration::from_secs(1); let max_delay = Duration::from_secs(SERVER_RECONNECT_MAX_BACKOFF_SECS); loop { - match Self::setup_connection(attested_tls_client, target).await { + match Self::setup_connection(nesting_tls_connector, target).await { Ok(output) => { return Ok(output); } Err(e) => { - if matches!(e, ProxyError::Io(_)) || !should_bail { + if should_retry_setup_error(&e, should_bail) { warn!("Reconnect failed: {e}. Retrying in {:#?}...", delay); tokio::time::sleep(delay).await; @@ -659,19 +826,16 @@ impl ProxyClient { /// Connect to the proxy-server, do TLS handshake and remote attestation async fn setup_connection( - inner: &AttestedTlsClient, + nesting_tls_connector: &NestingTlsConnector, target: &str, - ) -> Result< - ( - HttpSender, - HttpConnection, - Option, - AttestationType, - ), - ProxyError, - > { - let (tls_stream, measurements, remote_attestation_type) = inner.connect_tcp(target).await?; - debug!("[proxy-client] Connected to proxy server with measurements: {measurements:?}"); + ) -> Result<(HttpSender, HttpConnection), ProxyError> { + let outbound_stream = tokio::net::TcpStream::connect(target).await?; + + let domain = server_name_from_host(target)?; + let tls_stream = nesting_tls_connector + .connect(domain, outbound_stream) + .await?; + debug!("[proxy-client] Connected to proxy server"); // The attestation exchange is now complete - setup an HTTP client let http_version = HttpVersion::from_negotiated_protocol_client(&tls_stream); @@ -697,7 +861,7 @@ impl ProxyClient { }; // Return the HTTP client, as well as remote measurements - Ok((sender, conn, measurements, remote_attestation_type)) + Ok((sender, conn)) } // Handle a request from the source client to the proxy server @@ -711,6 +875,25 @@ impl ProxyClient { } } +fn should_retry_setup_error(error: &ProxyError, should_bail: bool) -> bool { + if !should_bail { + return true; + } + + match error { + ProxyError::Io(io_error) => matches!( + io_error.kind(), + std::io::ErrorKind::ConnectionRefused + | std::io::ErrorKind::ConnectionReset + | std::io::ErrorKind::ConnectionAborted + | std::io::ErrorKind::NotConnected + | std::io::ErrorKind::TimedOut + | std::io::ErrorKind::UnexpectedEof + ), + _ => false, + } +} + /// Update a request/response header if we are able to encode the header value /// /// This avoids bailing on bad header values - the headers are simply not updated @@ -747,16 +930,26 @@ pub enum ProxyError { IntConversion(#[from] TryFromIntError), #[error("Bad host name: {0}")] BadDnsName(#[from] tokio_rustls::rustls::pki_types::InvalidDnsNameError), + #[error("Invalid certificate encoding")] + InvalidCertificateEncoding, + #[error("Missing common name in certificate subject")] + MissingCertificateName, + #[error("Certificate common name is not valid UTF-8")] + InvalidCertificateName, #[error("HTTP: {0}")] Hyper(#[from] hyper::Error), + #[error("Attested TLS: {0}")] + AttestedTls(#[from] AttestedTlsError), #[error("JSON: {0}")] Json(#[from] serde_json::Error), #[error("Could not forward response - sender was dropped")] OneShotRecv(#[from] oneshot::error::RecvError), #[error("Failed to send request, connection to proxy-server dropped")] MpscSend, - #[error("Attested TLS: {0}")] - AttestedTls(#[from] AttestedTlsError), + #[error("Client auth must be configured on both the inner and outer TLS sessions")] + ClientAuthMisconfigured, + #[error("At least one server listener must be configured")] + NoListenersConfigured, } impl From> for ProxyError { @@ -765,6 +958,66 @@ impl From> for ProxyError { } } +/// Given a certifcate, get the hostname +fn hostname_from_cert(cert: &CertificateDer<'static>) -> Result { + let cert = x509_parser::parse_x509_certificate(cert.as_ref()) + .map(|(_, parsed)| parsed) + .map_err(|_| ProxyError::InvalidCertificateEncoding)?; + + Ok(cert + .subject() + .iter_common_name() + .next() + .ok_or(ProxyError::MissingCertificateName)? + .as_str() + .map_err(|_| ProxyError::InvalidCertificateName)? + .to_string()) +} + +fn certificate_identity_from_chain( + cert_chain: &[CertificateDer<'static>], +) -> Result { + hostname_from_cert(cert_chain.first().ok_or(ProxyError::NoCertificate)?) +} + +async fn build_attested_cert_resolver( + attestation_generator: AttestationGenerator, + certificate_name: String, +) -> Result { + Ok( + AttestedCertificateResolver::new(attestation_generator, None, certificate_name, vec![]) + .await?, + ) +} + +async fn build_inner_server_config( + attestation_generator: AttestationGenerator, + attestation_verifier: AttestationVerifier, + client_auth: bool, + certificate_name: Option, +) -> Result { + let inner_cert_resolver = build_attested_cert_resolver( + attestation_generator, + certificate_name.unwrap_or_else(|| "localhost".to_string()), + ) + .await?; + + let mut inner_server_config = if client_auth { + let attested_cert_verifier = AttestedCertificateVerifier::new(None, attestation_verifier)?; + ServerConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) + .with_client_cert_verifier(Arc::new(attested_cert_verifier)) + .with_cert_resolver(Arc::new(inner_cert_resolver)) + } else { + ServerConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) + .with_no_client_auth() + .with_cert_resolver(Arc::new(inner_cert_resolver)) + }; + + ensure_proxy_alpn_protocols(&mut inner_server_config.alpn_protocols); + + Ok(inner_server_config) +} + /// If no port was provided, default to 443 pub(crate) fn host_to_host_with_port(host: &str) -> String { if host.contains(':') { @@ -774,6 +1027,16 @@ pub(crate) fn host_to_host_with_port(host: &str) -> String { } } +/// Given a hostname with or without port number, create a TLS [ServerName] with just the host part +fn server_name_from_host( + host: &str, +) -> Result, tokio_rustls::rustls::pki_types::InvalidDnsNameError> { + let host_part = host.rsplit_once(':').map(|(h, _)| h).unwrap_or(host); + let host_part = host_part.trim_matches(|c| c == '[' || c == ']'); + + ServerName::try_from(host_part.to_string()) +} + /// An Executor for hyper that uses the tokio runtime #[derive(Clone)] pub(crate) struct TokioExecutor; @@ -792,14 +1055,13 @@ where #[cfg(test)] mod tests { - use crate::{ - attestation::measurements::MeasurementPolicy, attested_tls::get_tls_cert_with_config, - }; + use attestation::{AttestationType, measurements::MeasurementPolicy}; + use tokio_rustls::TlsConnector; use super::*; use test_helpers::{ - example_http_service, generate_certificate_chain, generate_tls_config, - generate_tls_config_with_client_auth, init_tracing, mock_dcap_measurements, + example_http_service, generate_certificate_chain_for_host, generate_tls_config, + generate_tls_config_with_client_auth, init_tracing, }; #[test] @@ -818,19 +1080,74 @@ mod tests { assert_eq!(protocols, vec![ALPN_HTTP11.to_vec(), ALPN_H2.to_vec()]); } - #[tokio::test] - async fn http_proxy_default_constructors_work() { + #[tokio::test(flavor = "multi_thread")] + async fn proxy_server_requires_at_least_one_listener() { + let result = ProxyServer::new( + None::>, + None::<&str>, + "127.0.0.1:1".to_string(), + AttestationGenerator::with_no_attestation(), + AttestationVerifier::expect_none(), + false, + ) + .await; + + assert!(matches!(result, Err(ProxyError::NoListenersConfigured))); + } + + #[tokio::test(flavor = "multi_thread")] + async fn dual_listener_server_reports_expected_addresses() { let target_addr = example_http_service().await; - let (cert_chain, private_key) = generate_certificate_chain("127.0.0.1".parse().unwrap()); - let server_cert = cert_chain[0].clone(); + let (cert_chain, private_key) = generate_certificate_chain_for_host("localhost"); + let tls_cert_and_key = TlsCertAndKey { + cert_chain, + key: private_key, + }; + + let dual_listener_server = ProxyServer::new( + Some(OuterTlsConfig { + listen_addr: "127.0.0.1:0", + tls: OuterTlsMode::CertAndKey(tls_cert_and_key), + }), + Some("127.0.0.1:0"), + target_addr.to_string(), + AttestationGenerator::with_no_attestation(), + AttestationVerifier::expect_none(), + false, + ) + .await + .unwrap(); + + let outer_addr = dual_listener_server.outer_local_addr().unwrap().unwrap(); + let inner_addr = dual_listener_server.inner_local_addr().unwrap().unwrap(); + assert_eq!(dual_listener_server.local_addr().unwrap(), outer_addr); + assert_ne!(outer_addr, inner_addr); + + let inner_only_server = ProxyServer::new( + None::>, + Some("127.0.0.1:0"), + target_addr.to_string(), + AttestationGenerator::with_no_attestation(), + AttestationVerifier::expect_none(), + false, + ) + .await + .unwrap(); + + let inner_only_addr = inner_only_server.inner_local_addr().unwrap().unwrap(); + assert!(inner_only_server.outer_local_addr().unwrap().is_none()); + assert_eq!(inner_only_server.local_addr().unwrap(), inner_only_addr); + } + + #[tokio::test(flavor = "multi_thread")] + async fn inner_only_listener_negotiates_http2_by_default() { + let _ = rustls::crypto::ring::default_provider().install_default(); + let target_addr = example_http_service().await; let proxy_server = ProxyServer::new( - TlsCertAndKey { - cert_chain, - key: private_key, - }, - "127.0.0.1:0", + None::>, + Some("127.0.0.1:0"), target_addr.to_string(), AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), AttestationVerifier::expect_none(), @@ -839,68 +1156,113 @@ mod tests { .await .unwrap(); - let proxy_addr = proxy_server.local_addr().unwrap(); + let inner_addr = proxy_server.inner_local_addr().unwrap().unwrap(); tokio::spawn(async move { proxy_server.accept().await.unwrap(); }); - let proxy_client = ProxyClient::new( - None, - "127.0.0.1:0".to_string(), - proxy_addr.to_string(), - AttestationGenerator::with_no_attestation(), - AttestationVerifier::mock(), - Some(server_cert), + let attested_cert_verifier = + AttestedCertificateVerifier::new(None, AttestationVerifier::mock()).unwrap(); + let mut client_config = + ClientConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) + .dangerous() + .with_custom_certificate_verifier(Arc::new(attested_cert_verifier)) + .with_no_client_auth(); + ensure_proxy_alpn_protocols(&mut client_config.alpn_protocols); + + let tls_connector = TlsConnector::from(Arc::new(client_config)); + let outbound_stream = TcpStream::connect(inner_addr).await.unwrap(); + let domain = ServerName::try_from("localhost".to_string()).unwrap(); + let mut tls_stream = tls_connector + .connect(domain, outbound_stream) + .await + .unwrap(); + + assert!(matches!( + HttpVersion::from_negotiated_protocol_client(&tls_stream), + HttpVersion::Http2 + )); + + tls_stream.shutdown().await.unwrap(); + } + + #[tokio::test(flavor = "multi_thread")] + async fn http_proxy_negotiates_http2_by_default() { + let target_addr = example_http_service().await; + + let (cert_chain, private_key) = generate_certificate_chain_for_host("localhost"); + let (server_config, outer_client_config) = + generate_tls_config(cert_chain.clone(), private_key); + + let proxy_server = ProxyServer::new( + Some(OuterTlsConfig { + listen_addr: "127.0.0.1:0", + tls: OuterTlsMode::Preconfigured { + server_config, + certificate_name: certificate_identity_from_chain(&cert_chain).unwrap(), + }, + }), + Some("127.0.0.1:0"), + target_addr.to_string(), + AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), + AttestationVerifier::expect_none(), + false, ) .await .unwrap(); - let proxy_client_addr = proxy_client.local_addr().unwrap(); + let proxy_addr = proxy_server.local_addr().unwrap(); tokio::spawn(async move { - proxy_client.accept().await.unwrap(); + proxy_server.accept().await.unwrap(); }); - let res = reqwest::get(format!("http://{}", proxy_client_addr)) - .await - .unwrap(); - - let headers = res.headers(); + let attested_cert_verifier = + AttestedCertificateVerifier::new(None, AttestationVerifier::mock()).unwrap(); + let mut inner_client_config = + ClientConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) + .dangerous() + .with_custom_certificate_verifier(Arc::new(attested_cert_verifier)) + .with_no_client_auth(); + ensure_proxy_alpn_protocols(&mut inner_client_config.alpn_protocols); - let attestation_type = headers - .get(ATTESTATION_TYPE_HEADER) - .unwrap() - .to_str() - .unwrap(); - assert_eq!(attestation_type, AttestationType::DcapTdx.as_str()); + let nesting_tls_connector = + NestingTlsConnector::new(Arc::new(outer_client_config), Arc::new(inner_client_config)); - let measurements_json = headers.get(MEASUREMENT_HEADER).unwrap().to_str().unwrap(); - let measurements = - MultiMeasurements::from_header_format(measurements_json, AttestationType::DcapTdx) - .unwrap(); - assert_eq!(measurements, mock_dcap_measurements()); + let (sender, conn) = ProxyClient::setup_connection( + &nesting_tls_connector, + &format!("localhost:{}", proxy_addr.port()), + ) + .await + .unwrap(); - let res_body = res.text().await.unwrap(); - assert_eq!(res_body, "No measurements"); + assert!(matches!(sender, HttpSender::Http2(_))); + assert!(matches!(conn, HttpConnection::Http2 { .. })); } // Server has mock DCAP, client has no attestation and no client auth - #[tokio::test] + #[tokio::test(flavor = "multi_thread")] async fn http_proxy_with_server_attestation() { let _ = tracing_subscriber::fmt::try_init(); let target_addr = example_http_service().await; - let (cert_chain, private_key) = generate_certificate_chain("127.0.0.1".parse().unwrap()); + let (cert_chain, private_key) = generate_certificate_chain_for_host("localhost"); let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key); - let proxy_server = ProxyServer::new_with_tls_config( - cert_chain, - server_config, - "127.0.0.1:0", + let proxy_server = ProxyServer::new( + Some(OuterTlsConfig { + listen_addr: "127.0.0.1:0", + tls: OuterTlsMode::Preconfigured { + server_config, + certificate_name: certificate_identity_from_chain(&cert_chain).unwrap(), + }, + }), + Some("127.0.0.1:0"), target_addr.to_string(), AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), AttestationVerifier::expect_none(), + false, ) .await .unwrap(); @@ -914,7 +1276,7 @@ mod tests { let proxy_client = ProxyClient::new_with_tls_config( client_config, "127.0.0.1:0".to_string(), - proxy_addr.to_string(), + format!("localhost:{}", proxy_addr.port()), AttestationGenerator::with_no_attestation(), AttestationVerifier::mock(), None, @@ -928,38 +1290,23 @@ mod tests { proxy_client.accept().await.unwrap(); }); - let res = reqwest::get(format!("http://{}", proxy_client_addr.to_string())) + let res = reqwest::get(format!("http://{}", proxy_client_addr)) .await .unwrap(); - let headers = res.headers(); - - let attestation_type = headers - .get(ATTESTATION_TYPE_HEADER) - .unwrap() - .to_str() - .unwrap(); - assert_eq!(attestation_type, AttestationType::DcapTdx.as_str()); - - let measurements_json = headers.get(MEASUREMENT_HEADER).unwrap().to_str().unwrap(); - let measurements = - MultiMeasurements::from_header_format(measurements_json, AttestationType::DcapTdx) - .unwrap(); - assert_eq!(measurements, mock_dcap_measurements()); - let res_body = res.text().await.unwrap(); assert_eq!(res_body, "No measurements"); } // Server has no attestation, client has mock DCAP and client auth - #[tokio::test] + #[tokio::test(flavor = "multi_thread")] async fn http_proxy_client_attestation() { let target_addr = example_http_service().await; let (server_cert_chain, server_private_key) = - generate_certificate_chain("127.0.0.1".parse().unwrap()); + generate_certificate_chain_for_host("localhost"); let (client_cert_chain, client_private_key) = - generate_certificate_chain("127.0.0.1".parse().unwrap()); + generate_certificate_chain_for_host("localhost"); let ( (_client_tls_server_config, client_tls_client_config), @@ -971,13 +1318,19 @@ mod tests { server_private_key, ); - let proxy_server = ProxyServer::new_with_tls_config( - server_cert_chain, - server_tls_server_config, - "127.0.0.1:0", + let proxy_server = ProxyServer::new( + Some(OuterTlsConfig { + listen_addr: "127.0.0.1:0", + tls: OuterTlsMode::Preconfigured { + server_config: server_tls_server_config, + certificate_name: certificate_identity_from_chain(&server_cert_chain).unwrap(), + }, + }), + Some("127.0.0.1:0"), target_addr.to_string(), AttestationGenerator::with_no_attestation(), AttestationVerifier::mock(), + true, ) .await .unwrap(); @@ -985,14 +1338,13 @@ mod tests { let proxy_addr = proxy_server.local_addr().unwrap(); tokio::spawn(async move { - // Accept one connection, then finish proxy_server.accept().await.unwrap(); }); let proxy_client = ProxyClient::new_with_tls_config( client_tls_client_config, "127.0.0.1:0", - proxy_addr.to_string(), + format!("localhost:{}", proxy_addr.port()), AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), AttestationVerifier::expect_none(), Some(client_cert_chain), @@ -1003,52 +1355,40 @@ mod tests { let proxy_client_addr = proxy_client.local_addr().unwrap(); tokio::spawn(async move { - // Accept two connections, then finish - proxy_client.accept().await.unwrap(); proxy_client.accept().await.unwrap(); }); - let res = reqwest::get(format!("http://{}", proxy_client_addr.to_string())) + let res = reqwest::get(format!("http://{}", proxy_client_addr)) .await .unwrap(); - // We expect no measurements from the server - let headers = res.headers(); - assert!(headers.get(MEASUREMENT_HEADER).is_none()); - - let attestation_type = headers - .get(ATTESTATION_TYPE_HEADER) - .unwrap() - .to_str() - .unwrap(); - assert_eq!(attestation_type, AttestationType::None.as_str()); - let res_body = res.text().await.unwrap(); - - // The response body shows us what was in the request header (as the test http server - // handler puts them there) - let measurements = - MultiMeasurements::from_header_format(&res_body, AttestationType::DcapTdx).unwrap(); - assert_eq!(measurements, mock_dcap_measurements()); + assert_eq!(res_body, "No measurements"); } // Server has no attestation, client has mock DCAP but no client auth - #[tokio::test] + #[tokio::test(flavor = "multi_thread")] async fn http_proxy_client_attestation_no_client_auth() { let target_addr = example_http_service().await; let (server_cert_chain, server_private_key) = - generate_certificate_chain("127.0.0.1".parse().unwrap()); + generate_certificate_chain_for_host("localhost"); let (server_config, client_config) = generate_tls_config(server_cert_chain.clone(), server_private_key); - let proxy_server = ProxyServer::new_with_tls_config( - server_cert_chain, - server_config, - "127.0.0.1:0", + let proxy_server = ProxyServer::new( + Some(OuterTlsConfig { + listen_addr: "127.0.0.1:0", + tls: OuterTlsMode::Preconfigured { + server_config, + certificate_name: certificate_identity_from_chain(&server_cert_chain).unwrap(), + }, + }), + Some("127.0.0.1:0"), target_addr.to_string(), AttestationGenerator::with_no_attestation(), AttestationVerifier::mock(), + false, ) .await .unwrap(); @@ -1063,7 +1403,7 @@ mod tests { let proxy_client = ProxyClient::new_with_tls_config( client_config, "127.0.0.1:0", - proxy_addr.to_string(), + format!("localhost:{}", proxy_addr.port()), AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), AttestationVerifier::expect_none(), None, @@ -1079,39 +1419,22 @@ mod tests { proxy_client.accept().await.unwrap(); }); - let res = reqwest::get(format!("http://{}", proxy_client_addr.to_string())) + let res = reqwest::get(format!("http://{}", proxy_client_addr)) .await .unwrap(); - // We expect no measurements from the server - let headers = res.headers(); - assert!(headers.get(MEASUREMENT_HEADER).is_none()); - - let attestation_type = headers - .get(ATTESTATION_TYPE_HEADER) - .unwrap() - .to_str() - .unwrap(); - assert_eq!(attestation_type, AttestationType::None.as_str()); - - let res_body = res.text().await.unwrap(); - - // The response body shows us what was in the request header (as the test http server - // handler puts them there) - let measurements = - MultiMeasurements::from_header_format(&res_body, AttestationType::DcapTdx).unwrap(); - assert_eq!(measurements, mock_dcap_measurements()); + let _res_body = res.text().await.unwrap(); } // Server has mock DCAP, client has mock DCAP and client auth - #[tokio::test] + #[tokio::test(flavor = "multi_thread")] async fn http_proxy_mutual_attestation() { let target_addr = example_http_service().await; let (server_cert_chain, server_private_key) = - generate_certificate_chain("127.0.0.1".parse().unwrap()); + generate_certificate_chain_for_host("localhost"); let (client_cert_chain, client_private_key) = - generate_certificate_chain("127.0.0.1".parse().unwrap()); + generate_certificate_chain_for_host("localhost"); let ( (_client_tls_server_config, client_tls_client_config), @@ -1123,13 +1446,19 @@ mod tests { server_private_key, ); - let proxy_server = ProxyServer::new_with_tls_config( - server_cert_chain, - server_tls_server_config, - "127.0.0.1:0", + let proxy_server = ProxyServer::new( + Some(OuterTlsConfig { + listen_addr: "127.0.0.1:0", + tls: OuterTlsMode::Preconfigured { + server_config: server_tls_server_config, + certificate_name: certificate_identity_from_chain(&server_cert_chain).unwrap(), + }, + }), + Some("127.0.0.1:0"), target_addr.to_string(), AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), AttestationVerifier::mock(), + true, ) .await .unwrap(); @@ -1137,14 +1466,13 @@ mod tests { let proxy_addr = proxy_server.local_addr().unwrap(); tokio::spawn(async move { - // Accept one connection, then finish proxy_server.accept().await.unwrap(); }); let proxy_client = ProxyClient::new_with_tls_config( client_tls_client_config, "127.0.0.1:0", - proxy_addr.to_string(), + format!("localhost:{}", proxy_addr.port()), AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), AttestationVerifier::mock(), Some(client_cert_chain), @@ -1155,80 +1483,42 @@ mod tests { let proxy_client_addr = proxy_client.local_addr().unwrap(); tokio::spawn(async move { - // Accept two connections, then finish proxy_client.accept().await.unwrap(); proxy_client.accept().await.unwrap(); }); - let res = reqwest::get(format!("http://{}", proxy_client_addr.to_string())) + let res = reqwest::get(format!("http://{}", proxy_client_addr)) .await .unwrap(); + assert_eq!(res.text().await.unwrap(), "No measurements"); - let headers = res.headers(); - let measurements_json = headers.get(MEASUREMENT_HEADER).unwrap().to_str().unwrap(); - let measurements = - MultiMeasurements::from_header_format(measurements_json, AttestationType::DcapTdx) - .unwrap(); - assert_eq!(measurements, mock_dcap_measurements()); - - let attestation_type = headers - .get(ATTESTATION_TYPE_HEADER) - .unwrap() - .to_str() - .unwrap(); - assert_eq!(attestation_type, AttestationType::DcapTdx.as_str()); - - let res_body = res.text().await.unwrap(); - - // The response body shows us what was in the request header (as the test http server - // handler puts them there) - let measurements = - MultiMeasurements::from_header_format(&res_body, AttestationType::DcapTdx).unwrap(); - assert_eq!(measurements, mock_dcap_measurements()); - - // Now do another request - to check that the connection has stayed open - let res = reqwest::get(format!("http://{}", proxy_client_addr.to_string())) + let res = reqwest::get(format!("http://{}", proxy_client_addr)) .await .unwrap(); - - let headers = res.headers(); - let measurements_json = headers.get(MEASUREMENT_HEADER).unwrap().to_str().unwrap(); - let measurements = - MultiMeasurements::from_header_format(measurements_json, AttestationType::DcapTdx) - .unwrap(); - assert_eq!(measurements, mock_dcap_measurements()); - - let attestation_type = headers - .get(ATTESTATION_TYPE_HEADER) - .unwrap() - .to_str() - .unwrap(); - assert_eq!(attestation_type, AttestationType::DcapTdx.as_str()); - - let res_body = res.text().await.unwrap(); - - // The response body shows us what was in the request header (as the test http server - // handler puts them there) - let measurements = - MultiMeasurements::from_header_format(&res_body, AttestationType::DcapTdx).unwrap(); - assert_eq!(measurements, mock_dcap_measurements()); + assert_eq!(res.text().await.unwrap(), "No measurements"); } // Server has mock DCAP, client no attestation - just get the server certificate - #[tokio::test] + #[tokio::test(flavor = "multi_thread")] async fn test_get_tls_cert() { let target_addr = example_http_service().await; - let (cert_chain, private_key) = generate_certificate_chain("127.0.0.1".parse().unwrap()); + let (cert_chain, private_key) = generate_certificate_chain_for_host("localhost"); let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key); - let proxy_server = ProxyServer::new_with_tls_config( - cert_chain.clone(), - server_config, - "127.0.0.1:0", + let proxy_server = ProxyServer::new( + Some(OuterTlsConfig { + listen_addr: "127.0.0.1:0", + tls: OuterTlsMode::Preconfigured { + server_config, + certificate_name: certificate_identity_from_chain(&cert_chain).unwrap(), + }, + }), + Some("127.0.0.1:0"), target_addr.to_string(), AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), AttestationVerifier::expect_none(), + false, ) .await .unwrap(); @@ -1239,33 +1529,44 @@ mod tests { proxy_server.accept().await.unwrap(); }); - let (retrieved_chain, _measurements) = get_tls_cert_with_config( - &proxy_server_addr.to_string(), + let retrieved_chain = get_inner_tls_cert_with_config( + format!("localhost:{}", proxy_server_addr.port()), AttestationVerifier::mock(), client_config, ) .await .unwrap(); - assert_eq!(retrieved_chain, cert_chain); + assert_eq!(retrieved_chain.len(), 1); + assert_eq!( + hostname_from_cert(&retrieved_chain[0]).unwrap(), + "localhost" + ); + assert_ne!(retrieved_chain, cert_chain); } // Negative test - server does not provide attestation but client requires it // Server has no attestaion, client has no attestation and no client auth - #[tokio::test] + #[tokio::test(flavor = "multi_thread")] async fn fails_on_no_attestation_when_expected() { let target_addr = example_http_service().await; - let (cert_chain, private_key) = generate_certificate_chain("127.0.0.1".parse().unwrap()); + let (cert_chain, private_key) = generate_certificate_chain_for_host("localhost"); let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key); - let proxy_server = ProxyServer::new_with_tls_config( - cert_chain, - server_config, - "127.0.0.1:0", + let proxy_server = ProxyServer::new( + Some(OuterTlsConfig { + listen_addr: "127.0.0.1:0", + tls: OuterTlsMode::Preconfigured { + server_config, + certificate_name: certificate_identity_from_chain(&cert_chain).unwrap(), + }, + }), + Some("127.0.0.1:0"), target_addr.to_string(), AttestationGenerator::with_no_attestation(), AttestationVerifier::expect_none(), + false, ) .await .unwrap(); @@ -1279,37 +1580,39 @@ mod tests { let proxy_client_result = ProxyClient::new_with_tls_config( client_config, "127.0.0.1:0".to_string(), - proxy_addr.to_string(), + format!("localhost:{}", proxy_addr.port()), AttestationGenerator::with_no_attestation(), AttestationVerifier::mock(), None, ) .await; - assert!(matches!( - proxy_client_result.unwrap_err(), - ProxyError::AttestedTls(AttestedTlsError::Attestation( - AttestationError::AttestationTypeNotAccepted - )) - )); + let err = proxy_client_result.unwrap_err().to_string(); + assert!(err.contains("ApplicationVerificationFailure"), "{err}"); } // Negative test - server does not provide attestation but client requires it // Server has no attestaion, client has no attestation and no client auth - #[tokio::test] + #[tokio::test(flavor = "multi_thread")] async fn fails_on_bad_measurements() { let target_addr = example_http_service().await; - let (cert_chain, private_key) = generate_certificate_chain("127.0.0.1".parse().unwrap()); + let (cert_chain, private_key) = generate_certificate_chain_for_host("localhost"); let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key); - let proxy_server = ProxyServer::new_with_tls_config( - cert_chain, - server_config, - "127.0.0.1:0", + let proxy_server = ProxyServer::new( + Some(OuterTlsConfig { + listen_addr: "127.0.0.1:0", + tls: OuterTlsMode::Preconfigured { + server_config, + certificate_name: certificate_identity_from_chain(&cert_chain).unwrap(), + }, + }), + Some("127.0.0.1:0"), target_addr.to_string(), AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), AttestationVerifier::expect_none(), + false, ) .await .unwrap(); @@ -1348,37 +1651,39 @@ mod tests { let proxy_client_result = ProxyClient::new_with_tls_config( client_config, "127.0.0.1:0".to_string(), - proxy_addr.to_string(), + format!("localhost:{}", proxy_addr.port()), AttestationGenerator::with_no_attestation(), attestation_verifier, None, ) .await; - assert!(matches!( - proxy_client_result.unwrap_err(), - ProxyError::AttestedTls(AttestedTlsError::Attestation( - AttestationError::MeasurementsNotAccepted - )) - )); + let err = proxy_client_result.unwrap_err().to_string(); + assert!(err.contains("ApplicationVerificationFailure"), "{err}"); } - #[tokio::test] + #[tokio::test(flavor = "multi_thread")] async fn http_proxy_client_reconnects_on_lost_connection() { init_tracing(); let target_addr = example_http_service().await; - let (cert_chain, private_key) = generate_certificate_chain("127.0.0.1".parse().unwrap()); + let (cert_chain, private_key) = generate_certificate_chain_for_host("localhost"); let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key); - let proxy_server = ProxyServer::new_with_tls_config( - cert_chain, - server_config, - "127.0.0.1:0", + let proxy_server = ProxyServer::new( + Some(OuterTlsConfig { + listen_addr: "127.0.0.1:0", + tls: OuterTlsMode::Preconfigured { + server_config, + certificate_name: certificate_identity_from_chain(&cert_chain).unwrap(), + }, + }), + Some("127.0.0.1:0"), target_addr.to_string(), AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), AttestationVerifier::expect_none(), + false, ) .await .unwrap(); @@ -1403,7 +1708,7 @@ mod tests { let proxy_client = ProxyClient::new_with_tls_config( client_config, "127.0.0.1:0".to_string(), - proxy_addr.to_string(), + format!("localhost:{}", proxy_addr.port()), AttestationGenerator::with_no_attestation(), AttestationVerifier::mock(), None, @@ -1418,7 +1723,7 @@ mod tests { proxy_client.accept().await.unwrap(); }); - let _initial_response = reqwest::get(format!("http://{}", proxy_client_addr.to_string())) + let _initial_response = reqwest::get(format!("http://{}", proxy_client_addr)) .await .unwrap(); @@ -1426,56 +1731,42 @@ mod tests { connection_breaker_tx.send(()).unwrap(); // Make another request - let res = reqwest::get(format!("http://{}", proxy_client_addr.to_string())) + let res = reqwest::get(format!("http://{}", proxy_client_addr)) .await .unwrap(); - let headers = res.headers(); - - let attestation_type = headers - .get(ATTESTATION_TYPE_HEADER) - .unwrap() - .to_str() - .unwrap(); - assert_eq!(attestation_type, AttestationType::DcapTdx.as_str()); - - let measurements_json = headers.get(MEASUREMENT_HEADER).unwrap().to_str().unwrap(); - let measurements = - MultiMeasurements::from_header_format(measurements_json, AttestationType::DcapTdx) - .unwrap(); - assert_eq!(measurements, mock_dcap_measurements()); - let res_body = res.text().await.unwrap(); assert_eq!(res_body, "No measurements"); } // Use HTTP 1.1 - #[tokio::test] + #[tokio::test(flavor = "multi_thread")] async fn http_proxy_with_http1() { let target_addr = example_http_service().await; - let (cert_chain, private_key) = generate_certificate_chain("127.0.0.1".parse().unwrap()); + let (cert_chain, private_key) = generate_certificate_chain_for_host("localhost"); let (mut server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key); server_config.alpn_protocols.push(ALPN_HTTP11.to_vec()); - let attested_tls_server = AttestedTlsServer::new_with_tls_config( - cert_chain, - server_config, + let proxy_server = ProxyServer::new( + Some(OuterTlsConfig { + listen_addr: "127.0.0.1:0", + tls: OuterTlsMode::Preconfigured { + server_config, + certificate_name: certificate_identity_from_chain(&cert_chain).unwrap(), + }, + }), + Some("127.0.0.1:0"), + target_addr.to_string(), AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), AttestationVerifier::expect_none(), + false, ) + .await .unwrap(); - let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); - - let proxy_server = ProxyServer { - attested_tls_server, - listener: listener.into(), - target: target_addr.to_string(), - }; - let proxy_addr = proxy_server.local_addr().unwrap(); tokio::spawn(async move { @@ -1485,7 +1776,7 @@ mod tests { let proxy_client = ProxyClient::new_with_tls_config( client_config, "127.0.0.1:0".to_string(), - proxy_addr.to_string(), + format!("localhost:{}", proxy_addr.port()), AttestationGenerator::with_no_attestation(), AttestationVerifier::mock(), None, @@ -1499,25 +1790,10 @@ mod tests { proxy_client.accept().await.unwrap(); }); - let res = reqwest::get(format!("http://{}", proxy_client_addr.to_string())) + let res = reqwest::get(format!("http://{}", proxy_client_addr)) .await .unwrap(); - let headers = res.headers(); - - let attestation_type = headers - .get(ATTESTATION_TYPE_HEADER) - .unwrap() - .to_str() - .unwrap(); - assert_eq!(attestation_type, AttestationType::DcapTdx.as_str()); - - let measurements_json = headers.get(MEASUREMENT_HEADER).unwrap().to_str().unwrap(); - let measurements = - MultiMeasurements::from_header_format(measurements_json, AttestationType::DcapTdx) - .unwrap(); - assert_eq!(measurements, mock_dcap_measurements()); - let res_body = res.text().await.unwrap(); assert_eq!(res_body, "No measurements"); } diff --git a/src/main.rs b/src/main.rs index d929778..a80a54b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,25 +1,15 @@ use anyhow::{anyhow, ensure}; -use attested_tls::attestation::measurements::MultiMeasurements; +use attestation::{AttestationType, AttestationVerifier, measurements::MeasurementPolicy}; use clap::{Parser, Subcommand}; -use std::{ - fs::File, - net::{IpAddr, SocketAddr}, - path::PathBuf, -}; +use std::{fs::File, net::SocketAddr, path::PathBuf}; use tokio::io::AsyncWriteExt; use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer}; use tracing::level_filters::LevelFilter; use attested_tls_proxy::{ - AttestationGenerator, ProxyClient, ProxyServer, - attested_get::attested_get, - attested_tls::{ - TlsCertAndKey, - attestation::{AttestationType, AttestationVerifier, measurements::MeasurementPolicy}, - }, - file_server::attested_file_server, - get_tls_cert, health_check, - normalize_pem::normalize_private_key_pem_to_pkcs8, + AttestationGenerator, OuterTlsConfig, OuterTlsMode, ProxyClient, ProxyServer, TlsCertAndKey, + attested_get::attested_get, file_server::attested_file_server, get_inner_tls_cert, + health_check, normalize_pem::normalize_private_key_pem_to_pkcs8, }; const GIT_REV: &str = match option_env!("GIT_REV") { @@ -84,25 +74,25 @@ enum CliCommand { // Address to listen on for health checks #[arg(long)] listen_addr_healthcheck: Option, - /// Enables verification of self-signed TLS certificates - #[arg(long)] - allow_self_signed: bool, }, /// Run a proxy server Server { - /// Socket address to listen on - #[arg(short, long, default_value = "0.0.0.0:0", env = "LISTEN_ADDR")] - listen_addr: SocketAddr, + /// Socket address to listen on for the outer nested-TLS listener, if enabled + #[arg(long)] + outer_listen_addr: Option, + /// Socket address to listen on for the inner-only attested TLS listener + #[arg(long)] + inner_listen_addr: Option, /// The hostname:port or ip:port of the target service to forward traffic to target_addr: String, /// Type of attestation to present (dafaults to 'auto' for automatic detection) - /// If other than None, a TLS key and certicate must also be given + /// This configures the inner attested TLS listener and does not require outer TLS certs. #[arg(long, env = "SERVER_ATTESTATION_TYPE")] server_attestation_type: Option, - /// The path to a PEM encoded private key + /// The path to a PEM encoded private key for the optional outer nested-TLS listener #[arg(long, env = "TLS_PRIVATE_KEY_PATH")] tls_private_key_path: Option, - /// Additional CA certificate to verify against (PEM) Defaults to no additional TLS certs. + /// PEM certificate chain for the optional outer nested-TLS listener #[arg(long, env = "TLS_CERTIFICATE_PATH")] tls_certificate_path: Option, /// Whether to use client authentication. If the client is running in a CVM this must be @@ -124,9 +114,6 @@ enum CliCommand { /// Additional CA certificate to verify against (PEM) Defaults to no additional TLS certs. #[arg(long)] tls_ca_certificate: Option, - /// Enables verification of self-signed TLS certificates - #[arg(long)] - allow_self_signed: bool, /// Filename to write measurements as JSON to #[arg(long)] out_measurements: Option, @@ -135,19 +122,22 @@ enum CliCommand { AttestedFileServer { /// Filesystem path to statically serve path_to_serve: PathBuf, - /// Socket address to listen on - #[arg(short, long, default_value = "0.0.0.0:0", env = "LISTEN_ADDR")] - listen_addr: SocketAddr, + /// Socket address to listen on for the outer nested-TLS listener, if enabled + #[arg(long)] + outer_listen_addr: Option, + /// Socket address to listen on for the inner-only attested TLS listener + #[arg(long)] + inner_listen_addr: Option, /// Type of attestation to present (dafaults to none) - /// If other than None, a TLS key and certicate must also be given + /// This configures the inner attested TLS listener and does not require outer TLS certs. #[arg(long, env = "SERVER_ATTESTATION_TYPE")] server_attestation_type: Option, - /// The path to a PEM encoded private key + /// The path to a PEM encoded private key for the optional outer nested-TLS listener #[arg(long, env = "TLS_PRIVATE_KEY_PATH")] - tls_private_key_path: PathBuf, - /// Additional CA certificate to verify against (PEM) Defaults to no additional TLS certs. + tls_private_key_path: Option, + /// PEM certificate chain for the optional outer nested-TLS listener #[arg(long, env = "TLS_CERTIFICATE_PATH")] - tls_certificate_path: PathBuf, + tls_certificate_path: Option, /// URL of the remote dummy attestation service. Only use with --server-attestation-type /// dummy #[arg(long)] @@ -164,9 +154,6 @@ enum CliCommand { /// Additional CA certificate to verify against (PEM) Defaults to no additional TLS certs. #[arg(long)] tls_ca_certificate: Option, - /// Enables verification of self-signed TLS certificates - #[arg(long)] - allow_self_signed: bool, }, } @@ -241,7 +228,6 @@ async fn main() -> anyhow::Result<()> { tls_ca_certificate, dev_dummy_dcap, listen_addr_healthcheck, - allow_self_signed, } => { let target_addr = target_addr .strip_prefix("https://") @@ -280,29 +266,15 @@ async fn main() -> anyhow::Result<()> { AttestationGenerator::new_with_detection(client_attestation_type, dev_dummy_dcap) .await?; - let client = if allow_self_signed { - let client_tls_config = - attested_tls_proxy::self_signed::client_tls_config_allow_self_signed()?; - ProxyClient::new_with_tls_config( - client_tls_config, - listen_addr, - target_addr, - client_attestation_generator, - attestation_verifier, - None, - ) - .await? - } else { - ProxyClient::new( - tls_cert_and_chain, - listen_addr, - target_addr, - client_attestation_generator, - attestation_verifier, - remote_tls_cert, - ) - .await? - }; + let client = ProxyClient::new( + tls_cert_and_chain, + listen_addr, + target_addr, + client_attestation_generator, + attestation_verifier, + remote_tls_cert, + ) + .await?; loop { if let Err(err) = client.accept().await { @@ -311,7 +283,8 @@ async fn main() -> anyhow::Result<()> { } } CliCommand::Server { - listen_addr, + outer_listen_addr, + inner_listen_addr, target_addr, tls_private_key_path, tls_certificate_path, @@ -324,10 +297,12 @@ async fn main() -> anyhow::Result<()> { health_check::server(listen_addr_healthcheck).await?; } - let tls_cert_and_chain = load_tls_cert_and_key_server( - tls_certificate_path, - tls_private_key_path, - listen_addr.ip(), + let tls_cert_and_chain = + load_tls_cert_and_key_server(tls_certificate_path, tls_private_key_path)?; + validate_listener_args( + inner_listen_addr, + outer_listen_addr, + tls_cert_and_chain.is_some(), )?; let local_attestation_generator = @@ -335,8 +310,13 @@ async fn main() -> anyhow::Result<()> { .await?; let server = ProxyServer::new( - tls_cert_and_chain, - listen_addr, + tls_cert_and_chain + .zip(outer_listen_addr) + .map(|(cert_and_key, listen_addr)| OuterTlsConfig { + listen_addr, + tls: OuterTlsMode::CertAndKey(cert_and_key), + }), + inner_listen_addr, target_addr, local_attestation_generator, attestation_verifier, @@ -353,8 +333,7 @@ async fn main() -> anyhow::Result<()> { CliCommand::GetTlsCert { server, tls_ca_certificate, - allow_self_signed, - out_measurements, + out_measurements: _, // TODO } => { let remote_tls_cert = match tls_ca_certificate { Some(remote_cert_filename) => Some( @@ -365,36 +344,37 @@ async fn main() -> anyhow::Result<()> { ), None => None, }; - let (cert_chain, measurements) = get_tls_cert( - server, - attestation_verifier, - remote_tls_cert, - allow_self_signed, - ) - .await?; - - // If the user chose to write measurements to a file as JSON - if let Some(path_to_write_measurements) = out_measurements { - std::fs::write( - path_to_write_measurements, - measurements - .unwrap_or(MultiMeasurements::NoAttestation) - .to_header_format()? - .as_bytes(), - )?; - } + let cert_chain = + get_inner_tls_cert(server, attestation_verifier, remote_tls_cert).await?; + + // // If the user chose to write measurements to a file as JSON + // if let Some(path_to_write_measurements) = out_measurements { + // std::fs::write( + // path_to_write_measurements, + // measurements + // .unwrap_or(MultiMeasurements::NoAttestation) + // .to_header_format()? + // .as_bytes(), + // )?; + // } println!("{}", certs_to_pem_string(&cert_chain)?); } CliCommand::AttestedFileServer { path_to_serve, - listen_addr, + outer_listen_addr, + inner_listen_addr, server_attestation_type, tls_private_key_path, tls_certificate_path, dev_dummy_dcap, } => { let tls_cert_and_chain = - load_tls_cert_and_key(tls_certificate_path, tls_private_key_path)?; + load_tls_cert_and_key_server(tls_certificate_path, tls_private_key_path)?; + validate_listener_args( + inner_listen_addr, + outer_listen_addr, + tls_cert_and_chain.is_some(), + )?; let server_attestation_type: AttestationType = serde_json::from_value( serde_json::Value::String(server_attestation_type.unwrap_or("none".to_string())), @@ -406,7 +386,8 @@ async fn main() -> anyhow::Result<()> { attested_file_server( path_to_serve, tls_cert_and_chain, - listen_addr, + outer_listen_addr, + inner_listen_addr, attestation_generator, attestation_verifier, false, @@ -417,7 +398,6 @@ async fn main() -> anyhow::Result<()> { target_addr, url_path, tls_ca_certificate, - allow_self_signed, } => { let remote_tls_cert = match tls_ca_certificate { Some(remote_cert_filename) => Some( @@ -434,7 +414,6 @@ async fn main() -> anyhow::Result<()> { &url_path.unwrap_or_default(), attestation_verifier, remote_tls_cert, - allow_self_signed, ) .await?; @@ -455,22 +434,41 @@ async fn main() -> anyhow::Result<()> { fn load_tls_cert_and_key_server( cert_chain: Option, private_key: Option, - ip: IpAddr, -) -> anyhow::Result { - if let Some(private_key) = private_key { - load_tls_cert_and_key( - cert_chain.ok_or(anyhow!("Private key given but no certificate chain"))?, - private_key, - ) - } else { - if cert_chain.is_some() { - return Err(anyhow!("Certificate chain provided but no private key")); +) -> anyhow::Result> { + match (cert_chain, private_key) { + (Some(cert_chain), Some(private_key)) => { + Ok(Some(load_tls_cert_and_key(cert_chain, private_key)?)) } - tracing::warn!("No TLS ceritifcate provided - generating self-signed"); - Ok(attested_tls_proxy::self_signed::generate_self_signed_cert( - ip, - )?) + (Some(_), None) => Err(anyhow!("Certificate chain provided but no private key")), + (None, Some(_)) => Err(anyhow!("Private key given but no certificate chain")), + (None, None) => Ok(None), + } +} + +fn validate_listener_args( + inner_listen_addr: Option, + outer_listen_addr: Option, + has_outer_tls: bool, +) -> anyhow::Result<()> { + if inner_listen_addr.is_none() && outer_listen_addr.is_none() { + return Err(anyhow!( + "At least one of --inner-listen-addr or --outer-listen-addr must be provided" + )); + } + + if has_outer_tls && outer_listen_addr.is_none() { + return Err(anyhow!( + "--outer-listen-addr is required when TLS certificate and key are provided" + )); } + + if !has_outer_tls && outer_listen_addr.is_some() { + return Err(anyhow!( + "--outer-listen-addr requires TLS certificate and key" + )); + } + + Ok(()) } /// Load TLS details from storage diff --git a/src/self_signed.rs b/src/self_signed.rs deleted file mode 100644 index e41e826..0000000 --- a/src/self_signed.rs +++ /dev/null @@ -1,323 +0,0 @@ -use std::{net::IpAddr, sync::Arc}; -use tokio_rustls::rustls::{ - self, - crypto::CryptoProvider, - pki_types::{self, CertificateDer, PrivatePkcs8KeyDer}, -}; -use x509_parser::prelude::{FromDer, X509Certificate}; - -use crate::attested_tls::{AttestedTlsError, TlsCertAndKey}; - -/// Generate a self signed certifcate -pub fn generate_self_signed_cert(ip_address: IpAddr) -> Result { - let keypair = rcgen::KeyPair::generate()?; - let mut params = rcgen::CertificateParams::default(); - params - .subject_alt_names - .push(rcgen::SanType::IpAddress(ip_address)); - - let cert = params.self_signed(&keypair)?; - Ok(TlsCertAndKey { - cert_chain: vec![cert.der().clone()], - key: PrivatePkcs8KeyDer::from(keypair.serialize_der()).into(), - }) -} - -/// Client TLS configuration which accepts self-signed remote certificates -pub fn client_tls_config_allow_self_signed() -> Result { - Ok(rustls::ClientConfig::builder() - .dangerous() - .with_custom_certificate_verifier(SkipServerVerification::new()?) - .with_no_client_auth()) -} - -/// Used to allow verification of self-signed certificates -#[derive(Debug, Clone)] -pub struct SkipServerVerification { - supported_algs: rustls::crypto::WebPkiSupportedAlgorithms, -} - -impl SkipServerVerification { - pub fn new() -> Result, AttestedTlsError> { - Ok(Arc::new(Self { - supported_algs: Arc::new( - CryptoProvider::get_default().ok_or(AttestedTlsError::NoCryptoProvider)?, - ) - .clone() - .signature_verification_algorithms, - })) - } -} - -impl rustls::client::danger::ServerCertVerifier for SkipServerVerification { - fn verify_server_cert( - &self, - end_entity: &CertificateDer<'_>, - _intermediates: &[CertificateDer<'_>], - _server_name: &pki_types::ServerName<'_>, - _ocsp_response: &[u8], - _now: pki_types::UnixTime, - ) -> Result { - // Parse the certificate - let (_, cert) = X509Certificate::from_der(end_entity).map_err(|_| { - rustls::Error::InvalidCertificate(rustls::CertificateError::BadEncoding) - })?; - - // Verify signature - cert.verify_signature(None).map_err(|_| { - rustls::Error::InvalidCertificate(rustls::CertificateError::BadSignature) - })?; - - Ok(rustls::client::danger::ServerCertVerified::assertion()) - } - - fn verify_tls12_signature( - &self, - message: &[u8], - cert: &CertificateDer<'_>, - dss: &rustls::DigitallySignedStruct, - ) -> Result { - let provider = rustls::crypto::CryptoProvider::get_default() - .ok_or_else(|| rustls::Error::General("No crypto provider installed".into()))?; - - rustls::crypto::verify_tls12_signature( - message, - cert, - dss, - &provider.signature_verification_algorithms, - )?; - - Ok(rustls::client::danger::HandshakeSignatureValid::assertion()) - } - - fn verify_tls13_signature( - &self, - message: &[u8], - cert: &CertificateDer<'_>, - dss: &rustls::DigitallySignedStruct, - ) -> Result { - let provider = rustls::crypto::CryptoProvider::get_default() - .ok_or_else(|| rustls::Error::General("No crypto provider installed".into()))?; - - rustls::crypto::verify_tls13_signature( - message, - cert, - dss, - &provider.signature_verification_algorithms, - )?; - - Ok(rustls::client::danger::HandshakeSignatureValid::assertion()) - } - - fn supported_verify_schemes(&self) -> Vec { - self.supported_algs.supported_schemes() - } -} - -/// Used to allow verification of self-signed certificates during client authentication -#[derive(Debug)] -pub struct SkipClientVerification { - supported_algs: rustls::crypto::WebPkiSupportedAlgorithms, -} - -impl SkipClientVerification { - pub fn new() -> std::sync::Arc { - std::sync::Arc::new(Self { - supported_algs: Arc::new(CryptoProvider::get_default().unwrap()) - .clone() - .signature_verification_algorithms, - }) - } -} - -impl rustls::server::danger::ClientCertVerifier for SkipClientVerification { - fn verify_client_cert( - &self, - end_entity: &CertificateDer<'_>, - _intermediates: &[CertificateDer], - _now: rustls::pki_types::UnixTime, - ) -> Result { - // Parse the certificate - let (_, cert) = X509Certificate::from_der(end_entity).map_err(|_| { - rustls::Error::InvalidCertificate(rustls::CertificateError::BadEncoding) - })?; - - // Verify signature - cert.verify_signature(None).map_err(|_| { - rustls::Error::InvalidCertificate(rustls::CertificateError::BadSignature) - })?; - Ok(rustls::server::danger::ClientCertVerified::assertion()) - } - - fn root_hint_subjects(&self) -> &[rustls::DistinguishedName] { - &[] - } - - fn verify_tls12_signature( - &self, - message: &[u8], - cert: &CertificateDer<'_>, - dss: &rustls::DigitallySignedStruct, - ) -> Result { - let provider = rustls::crypto::CryptoProvider::get_default() - .ok_or_else(|| rustls::Error::General("No crypto provider installed".into()))?; - - rustls::crypto::verify_tls12_signature( - message, - cert, - dss, - &provider.signature_verification_algorithms, - )?; - - Ok(rustls::client::danger::HandshakeSignatureValid::assertion()) - } - - fn verify_tls13_signature( - &self, - message: &[u8], - cert: &CertificateDer<'_>, - dss: &rustls::DigitallySignedStruct, - ) -> Result { - let provider = rustls::crypto::CryptoProvider::get_default() - .ok_or_else(|| rustls::Error::General("No crypto provider installed".into()))?; - - rustls::crypto::verify_tls13_signature( - message, - cert, - dss, - &provider.signature_verification_algorithms, - )?; - - Ok(rustls::client::danger::HandshakeSignatureValid::assertion()) - } - - fn supported_verify_schemes(&self) -> Vec { - self.supported_algs.supported_schemes() - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::{ - AttestationGenerator, - attestation::{AttestationType, AttestationVerifier}, - attested_tls::{AttestedTlsClient, AttestedTlsServer}, - test_helpers::{generate_certificate_chain, generate_tls_config}, - }; - use tokio::net::TcpListener; - use tokio_rustls::rustls::pki_types::ServerName; - - #[tokio::test] - async fn self_signed_server_attestation() { - let cert_and_key = generate_self_signed_cert("127.0.0.1".parse().unwrap()).unwrap(); - - let server_config = rustls::ServerConfig::builder() - .with_no_client_auth() - .with_single_cert( - cert_and_key.cert_chain.clone().to_vec(), - cert_and_key.key.clone_key(), - ) - .unwrap(); - - let server = AttestedTlsServer::new_with_tls_config( - cert_and_key.cert_chain, - server_config.into(), - AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), - AttestationVerifier::expect_none(), - ) - .unwrap(); - - let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); - let server_addr = listener.local_addr().unwrap(); - - tokio::spawn(async move { - let (tcp_stream, _) = listener.accept().await.unwrap(); - let (_stream, _measurements, _attestation_type) = - server.handle_connection(tcp_stream).await.unwrap(); - }); - - let client_config = client_tls_config_allow_self_signed().unwrap(); - - let client = AttestedTlsClient::new_with_tls_config( - client_config.into(), - AttestationGenerator::with_no_attestation(), - AttestationVerifier::mock(), - None, - ) - .unwrap(); - - let (_stream, _measurements, _attestation_type) = - client.connect_tcp(&server_addr.to_string()).await.unwrap(); - } - - #[tokio::test] - async fn nested_tls_with_self_signed_server_attestation() { - // Outer TLS setup - let (cert_chain, private_key) = generate_certificate_chain("127.0.0.1".parse().unwrap()); - let (outer_server_config, outer_client_config) = - generate_tls_config(cert_chain.clone(), private_key); - - let outer_acceptor = tokio_rustls::TlsAcceptor::from(Arc::new(outer_server_config)); - let outer_connector = tokio_rustls::TlsConnector::from(Arc::new(outer_client_config)); - - // Inner TLS setup - let cert_and_key = generate_self_signed_cert("127.0.0.1".parse().unwrap()).unwrap(); - - let server_config = rustls::ServerConfig::builder() - .with_no_client_auth() - .with_single_cert( - cert_and_key.cert_chain.clone().to_vec(), - cert_and_key.key.clone_key(), - ) - .unwrap(); - - let server = AttestedTlsServer::new_with_tls_config( - cert_and_key.cert_chain, - server_config.into(), - AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), - AttestationVerifier::expect_none(), - ) - .unwrap(); - - let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); - let server_addr = listener.local_addr().unwrap(); - - tokio::spawn(async move { - let (tcp_stream, _) = listener.accept().await.unwrap(); - - // Do outer TLS handshake - let tls_stream = outer_acceptor.accept(tcp_stream).await.unwrap(); - - // Do inner (attested) TLS - let (_stream, _measurements, _attestation_type) = - server.handle_connection(tls_stream).await.unwrap(); - }); - - // Inner TLS config - let client_config = client_tls_config_allow_self_signed().unwrap(); - - let client = AttestedTlsClient::new_with_tls_config( - client_config.into(), - AttestationGenerator::with_no_attestation(), - AttestationVerifier::mock(), - None, - ) - .unwrap(); - - let client_tcp_stream = tokio::net::TcpStream::connect(&server_addr).await.unwrap(); - - // Outer TLS handshake - let server_name = ServerName::try_from(server_addr.ip().to_string()).unwrap(); - let tls_stream = outer_connector - .connect(server_name, client_tcp_stream) - .await - .unwrap(); - - // Inner (attested) TLS - let (_stream, _measurements, _attestation_type) = client - .connect(&server_addr.to_string(), tls_stream) - .await - .unwrap(); - } -} diff --git a/src/test_helpers.rs b/src/test_helpers.rs index 8990734..431c5f8 100644 --- a/src/test_helpers.rs +++ b/src/test_helpers.rs @@ -1,8 +1,7 @@ //! Helper functions used in tests use axum::response::IntoResponse; use std::{ - collections::HashMap, - net::{IpAddr, SocketAddr}, + net::SocketAddr, sync::{Arc, Once}, }; use tokio::net::TcpListener; @@ -15,20 +14,17 @@ use tracing_subscriber::{EnvFilter, fmt}; static INIT: Once = Once::new(); -use crate::{ - MEASUREMENT_HEADER, - attestation::measurements::{DcapMeasurementRegister, MultiMeasurements}, -}; - -/// Helper to generate a self-signed certificate for testing -pub fn generate_certificate_chain( - ip: IpAddr, +/// Helper to generate a self-signed certificate for testing with a DNS subject name +pub fn generate_certificate_chain_for_host( + host: &str, ) -> (Vec>, PrivateKeyDer<'static>) { - let mut params = rcgen::CertificateParams::new(vec![]).unwrap(); - params.subject_alt_names.push(rcgen::SanType::IpAddress(ip)); + let mut params = rcgen::CertificateParams::new(vec![host.to_string()]).unwrap(); + params + .subject_alt_names + .push(rcgen::SanType::DnsName(host.try_into().unwrap())); params .distinguished_name - .push(rcgen::DnType::CommonName, ip.to_string()); + .push(rcgen::DnType::CommonName, host); let keypair = rcgen::KeyPair::generate().unwrap(); let cert = params.self_signed(&keypair).unwrap(); @@ -131,23 +127,13 @@ pub async fn example_http_service() -> SocketAddr { addr } -async fn get_handler(headers: http::HeaderMap) -> impl IntoResponse { - headers - .get(MEASUREMENT_HEADER) - .and_then(|v| v.to_str().ok()) - .unwrap_or("No measurements") - .to_string() -} - -/// All-zero measurment values used in some tests -pub fn mock_dcap_measurements() -> MultiMeasurements { - MultiMeasurements::Dcap(HashMap::from([ - (DcapMeasurementRegister::MRTD, [0u8; 48]), - (DcapMeasurementRegister::RTMR0, [0u8; 48]), - (DcapMeasurementRegister::RTMR1, [0u8; 48]), - (DcapMeasurementRegister::RTMR2, [0u8; 48]), - (DcapMeasurementRegister::RTMR3, [0u8; 48]), - ])) +async fn get_handler(_headers: http::HeaderMap) -> impl IntoResponse { + // headers + // .get(MEASUREMENT_HEADER) + // .and_then(|v| v.to_str().ok()) + // .unwrap_or("No measurements") + // .to_string() + "No measurements".to_string() } pub fn init_tracing() {