Skip to content
167 changes: 146 additions & 21 deletions crates/attestation/src/azure/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,13 @@ use serde::{Deserialize, Serialize};
use thiserror::Error;
use x509_parser::prelude::*;

use crate::{dcap::verify_dcap_attestation_with_given_timestamp, measurements::MultiMeasurements};
use crate::{
dcap::{
verify_dcap_attestation_with_given_timestamp,
verify_dcap_attestation_with_timestamp_sync,
},
measurements::MultiMeasurements,
};

/// The attestation evidence payload that gets sent over the channel
#[derive(Debug, Serialize, Deserialize)]
Expand Down Expand Up @@ -42,6 +48,16 @@ struct TpmAttest {
instance_info: Option<Vec<u8>>,
}

/// Used during verification to support both sync and async verification
/// paths without duplicating code
struct PreparedAzureAttestation {
tdx_quote_bytes: Vec<u8>,
hcl_report: hcl::HclReport,
var_data_hash: [u8; 32],
expected_tdx_input_data: [u8; 64],
tpm_attestation: TpmAttest,
}

/// Generate a TDX attestation on Azure
pub fn create_azure_attestation(input_data: [u8; 64]) -> Result<Vec<u8>, MaaError> {
let hcl_report_bytes = vtpm::get_report_with_report_data(&input_data)?;
Expand Down Expand Up @@ -100,6 +116,32 @@ pub async fn verify_azure_attestation(
.await
}

/// Verify a TDX attestation from Azure - synchronous version
///
/// This relies on having DCAP collateral already present in the cache
///
/// If possible, prefer the async version
pub fn verify_azure_attestation_sync(
input: Vec<u8>,
expected_input_data: [u8; 64],
pccs: Pccs,
override_azure_outdated_tcb: bool,
) -> Result<super::measurements::MultiMeasurements, MaaError> {
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.expect("Time went backwards")
.as_secs();

verify_azure_attestation_with_given_timestamp_sync(
input,
expected_input_data,
pccs,
None,
now,
override_azure_outdated_tcb,
)
}

/// Do the verification, passing in the current time
/// This allows us to test this function without time checks going out of
/// date
Expand All @@ -111,30 +153,102 @@ async fn verify_azure_attestation_with_given_timestamp(
now: u64,
override_azure_outdated_tcb: bool,
) -> Result<super::measurements::MultiMeasurements, MaaError> {
let PreparedAzureAttestation {
tdx_quote_bytes,
hcl_report,
var_data_hash,
expected_tdx_input_data,
tpm_attestation,
} = prepare_azure_attestation(input)?;

let _dcap_measurements = verify_dcap_attestation_with_given_timestamp(
tdx_quote_bytes,
expected_tdx_input_data,
pccs,
collateral,
now,
override_azure_outdated_tcb,
)
.await?;

finish_azure_attestation_verification(
hcl_report,
var_data_hash,
tpm_attestation,
expected_input_data,
now,
)
}

/// Synchronous version of the verifier
fn verify_azure_attestation_with_given_timestamp_sync(
input: Vec<u8>,
expected_input_data: [u8; 64],
pccs: Pccs,
collateral: Option<QuoteCollateralV3>,
now: u64,
override_azure_outdated_tcb: bool,
) -> Result<super::measurements::MultiMeasurements, MaaError> {
let PreparedAzureAttestation {
tdx_quote_bytes,
hcl_report,
var_data_hash,
expected_tdx_input_data,
tpm_attestation,
} = prepare_azure_attestation(input)?;

let _dcap_measurements = verify_dcap_attestation_with_timestamp_sync(
tdx_quote_bytes,
expected_tdx_input_data,
pccs,
collateral,
now,
override_azure_outdated_tcb,
)?;

finish_azure_attestation_verification(
hcl_report,
var_data_hash,
tpm_attestation,
expected_input_data,
now,
)
}

/// Parses the attestation during verification
fn prepare_azure_attestation(input: Vec<u8>) -> Result<PreparedAzureAttestation, MaaError> {
let attestation_document: AttestationDocument = serde_json::from_slice(&input)?;
tracing::info!("Attempting to verifiy azure attestation: {attestation_document:?}");

let hcl_report_bytes = BASE64_URL_SAFE.decode(attestation_document.hcl_report_base64)?;
let AttestationDocument { tdx_quote_base64, hcl_report_base64, tpm_attestation } =
attestation_document;

let hcl_report_bytes = BASE64_URL_SAFE.decode(hcl_report_base64)?;
let hcl_report = hcl::HclReport::new(hcl_report_bytes)?;
let var_data_hash = hcl_report.var_data_sha256();

// Check that HCL var data hash matches TDX quote report data
let mut expected_tdx_input_data = [0u8; 64];
expected_tdx_input_data[..32].copy_from_slice(&var_data_hash);

// Do DCAP verification
let tdx_quote_bytes = BASE64_URL_SAFE.decode(attestation_document.tdx_quote_base64)?;
let _dcap_measurements = verify_dcap_attestation_with_given_timestamp(
let tdx_quote_bytes = BASE64_URL_SAFE.decode(tdx_quote_base64)?;

Ok(PreparedAzureAttestation {
tdx_quote_bytes,
hcl_report,
var_data_hash,
expected_tdx_input_data,
pccs,
collateral,
now,
override_azure_outdated_tcb,
)
.await?;
tpm_attestation,
})
}

/// The final part of vTPM verification, after verifying DCAP
fn finish_azure_attestation_verification(
hcl_report: hcl::HclReport,
var_data_hash: [u8; 32],
tpm_attestation: TpmAttest,
expected_input_data: [u8; 64],
now: u64,
) -> Result<super::measurements::MultiMeasurements, MaaError> {
let hcl_ak_pub = hcl_report.ak_pub()?;

// Get attestation key from runtime claims
Expand Down Expand Up @@ -166,17 +280,16 @@ async fn verify_azure_attestation_with_given_timestamp(
}

// Verify the vTPM quote
let vtpm_quote = attestation_document.tpm_attestation.quote;
let vtpm_quote = tpm_attestation.quote;
let hcl_ak_pub_der = hcl_ak_pub.key.try_to_der().map_err(|_| MaaError::JwkConversion)?;
let pub_key = PKey::public_key_from_der(&hcl_ak_pub_der)?;
vtpm_quote.verify(&pub_key, &expected_input_data[..32])?;

let pcrs = vtpm_quote.pcrs_sha256();

// Parse AK certificate
let (_type_label, ak_certificate_der) = pem_rfc7468::decode_vec(
attestation_document.tpm_attestation.ak_certificate_pem.as_bytes(),
)?;
let (_type_label, ak_certificate_der) =
pem_rfc7468::decode_vec(tpm_attestation.ak_certificate_pem.as_bytes())?;

let (remaining_bytes, ak_certificate) = X509Certificate::from_der(&ak_certificate_der)?;

Expand Down Expand Up @@ -354,7 +467,7 @@ mod tests {

let hcl_report = hcl::HclReport::new(hcl_bytes.to_vec()).unwrap();
let hcl_var_data = hcl_report.var_data();
let var_data_values: serde_json::Value = serde_json::from_slice(&hcl_var_data).unwrap();
let var_data_values: serde_json::Value = serde_json::from_slice(hcl_var_data).unwrap();

// Check that it contains 64 byte user data
assert_eq!(hex::decode(var_data_values["user-data"].as_str().unwrap()).unwrap().len(), 64);
Expand Down Expand Up @@ -394,20 +507,32 @@ mod tests {
let collateral_bytes: &'static [u8] =
include_bytes!("../../test-assets/azure-collateral02.json");

let collateral = serde_json::from_slice(collateral_bytes).unwrap();
let async_collateral = serde_json::from_slice(collateral_bytes).unwrap();
let sync_collateral = serde_json::from_slice(collateral_bytes).unwrap();

let measurements = verify_azure_attestation_with_given_timestamp(
let async_measurements = verify_azure_attestation_with_given_timestamp(
attestation_bytes.to_vec(),
[0; 64], // Input data
None,
collateral,
async_collateral,
now,
false,
)
.await
.unwrap();

measurement_policy.check_measurement(&measurements).unwrap();
let sync_measurements = verify_azure_attestation_with_given_timestamp_sync(
attestation_bytes.to_vec(),
[0; 64], // Input data
Pccs::new_without_prewarm(None),
sync_collateral,
now,
false,
)
.unwrap();

assert_eq!(async_measurements, sync_measurements);
measurement_policy.check_measurement(&async_measurements).unwrap();
}

#[tokio::test]
Expand Down
Loading
Loading