From a7b02248c0b4fc8cb75fdbd1eaf6b13af80ad273 Mon Sep 17 00:00:00 2001 From: John Driscoll Date: Wed, 11 Mar 2026 17:13:35 -0500 Subject: [PATCH] feat(wasm-mps): fix MPS DSG interface for 2-of-3 instead of 3-of-3 Ticket: HSM-1163 --- packages/wasm-mps/src/lib.rs | 147 +++++++--------------------------- packages/wasm-mps/test/mps.ts | 45 ++++------- 2 files changed, 45 insertions(+), 147 deletions(-) diff --git a/packages/wasm-mps/src/lib.rs b/packages/wasm-mps/src/lib.rs index 0b3b190..1ff5f44 100644 --- a/packages/wasm-mps/src/lib.rs +++ b/packages/wasm-mps/src/lib.rs @@ -253,20 +253,15 @@ mod mps { /// Process round 1 of DSG protocol. /// round1_messages: Public messages from other parties. /// state: Private state result from round 0. - pub fn dsg_round1_process( - round1_messages: &[Vec; 2], - state: &[u8], - ) -> Result { + pub fn dsg_round1_process(round1_message: &[u8], state: &[u8]) -> Result { // Parse state let state: DsgStateR1 = bincode::deserialize(state).map_err(|_| MpsError::DeserializationError)?; // Parse messages - let i0_msg1: SignMsg1 = bincode::deserialize(round1_messages[0].as_slice()) - .map_err(|_| MpsError::DeserializationError)?; - let i1_msg1: SignMsg1 = bincode::deserialize(round1_messages[1].as_slice()) - .map_err(|_| MpsError::DeserializationError)?; - let msgs = vec![i0_msg1, i1_msg1, state.msg]; + let i0_msg1: SignMsg1 = + bincode::deserialize(round1_message).map_err(|_| MpsError::DeserializationError)?; + let msgs = vec![i0_msg1, state.msg]; // Process all round1 messages together let (p2, msg2) = state @@ -289,20 +284,15 @@ mod mps { /// Process round 2 of DSG protocol. /// round2_messages: Public messages from other parties. /// state: Private state result from round 1. - pub fn dsg_round2_process( - round2_messages: &[Vec; 2], - state: &[u8], - ) -> Result { + pub fn dsg_round2_process(round2_message: &[u8], state: &[u8]) -> Result { // Parse state let state: DsgStateR2 = bincode::deserialize(state).map_err(|_| MpsError::DeserializationError)?; // Parse messages - let i0_msg2: SignMsg2 = bincode::deserialize(round2_messages[0].as_slice()) - .map_err(|_| MpsError::DeserializationError)?; - let i1_msg2: SignMsg2 = bincode::deserialize(round2_messages[1].as_slice()) - .map_err(|_| MpsError::DeserializationError)?; - let msgs = vec![i0_msg2, i1_msg2, state.msg]; + let i0_msg2: SignMsg2 = + bincode::deserialize(round2_message).map_err(|_| MpsError::DeserializationError)?; + let msgs = vec![i0_msg2, state.msg]; // Process all round2 messages together let party = state @@ -328,20 +318,15 @@ mod mps { /// Process round 3 of DSG protocol. /// round3_messages: Public messages from other parties. /// state: Private state result from round 2. - pub fn dsg_round3_process( - round3_messages: &[Vec; 2], - state: &[u8], - ) -> Result, MpsError> { + pub fn dsg_round3_process(round3_message: &[u8], state: &[u8]) -> Result, MpsError> { // Parse state let state: DsgStateR3 = bincode::deserialize(state).map_err(|_| MpsError::DeserializationError)?; // Parse messages - let i0_msg3: SignMsg3 = bincode::deserialize(round3_messages[0].as_slice()) - .map_err(|_| MpsError::DeserializationError)?; - let i1_msg3: SignMsg3 = bincode::deserialize(round3_messages[1].as_slice()) - .map_err(|_| MpsError::DeserializationError)?; - let msgs = vec![i0_msg3, i1_msg3, state.msg]; + let i0_msg3: SignMsg3 = + bincode::deserialize(round3_message).map_err(|_| MpsError::DeserializationError)?; + let msgs = vec![i0_msg3, state.msg]; // Process all round2 messages together let (signature, _) = state @@ -512,11 +497,6 @@ mod tests { dkg_p0_1.state.as_slice(), ) .unwrap(); - let dkg_p1_share = mps::dkg_round2_process( - &[dkg_p0_1.msg.clone(), dkg_p2_1.msg.clone()], - dkg_p1_1.state.as_slice(), - ) - .unwrap(); let dkg_p2_share = mps::dkg_round2_process( &[dkg_p0_1.msg.clone(), dkg_p1_1.msg.clone()], dkg_p2_1.state.as_slice(), @@ -529,70 +509,31 @@ mod tests { // Process DSG round 0 let dsg_p0_0 = mps::dsg_round0_process(dkg_p0_share.share.as_slice(), "m".to_string(), msg).unwrap(); - let dsg_p1_0 = - mps::dsg_round0_process(dkg_p1_share.share.as_slice(), "m".to_string(), msg).unwrap(); let dsg_p2_0 = mps::dsg_round0_process(dkg_p2_share.share.as_slice(), "m".to_string(), msg).unwrap(); // Process DSG round 1 - let dsg_p0_1 = mps::dsg_round1_process( - &[dsg_p1_0.msg.clone(), dsg_p2_0.msg.clone()], - dsg_p0_0.state.as_slice(), - ) - .unwrap(); - let dsg_p1_1 = mps::dsg_round1_process( - &[dsg_p0_0.msg.clone(), dsg_p2_0.msg.clone()], - dsg_p1_0.state.as_slice(), - ) - .unwrap(); - let dsg_p2_1 = mps::dsg_round1_process( - &[dsg_p0_0.msg.clone(), dsg_p1_0.msg.clone()], - dsg_p2_0.state.as_slice(), - ) - .unwrap(); + let dsg_p0_1 = + mps::dsg_round1_process(dsg_p2_0.msg.as_slice(), dsg_p0_0.state.as_slice()).unwrap(); + let dsg_p2_1 = + mps::dsg_round1_process(dsg_p0_0.msg.as_slice(), dsg_p2_0.state.as_slice()).unwrap(); // Process DSG round 2 - let dsg_p0_2 = mps::dsg_round2_process( - &[dsg_p1_1.msg.clone(), dsg_p2_1.msg.clone()], - dsg_p0_1.state.as_slice(), - ) - .unwrap(); - let dsg_p1_2 = mps::dsg_round2_process( - &[dsg_p0_1.msg.clone(), dsg_p2_1.msg.clone()], - dsg_p1_1.state.as_slice(), - ) - .unwrap(); - let dsg_p2_2 = mps::dsg_round2_process( - &[dsg_p0_1.msg.clone(), dsg_p1_1.msg.clone()], - dsg_p2_1.state.as_slice(), - ) - .unwrap(); + let dsg_p0_2 = + mps::dsg_round2_process(dsg_p2_1.msg.as_slice(), dsg_p0_1.state.as_slice()).unwrap(); + let dsg_p2_2 = + mps::dsg_round2_process(dsg_p0_1.msg.as_slice(), dsg_p2_1.state.as_slice()).unwrap(); // Process DSG round 3 - let dsg_p0_sig = mps::dsg_round3_process( - &[dsg_p1_2.msg.clone(), dsg_p2_2.msg.clone()], - dsg_p0_2.state.as_slice(), - ) - .unwrap(); - let dsg_p1_sig = mps::dsg_round3_process( - &[dsg_p0_2.msg.clone(), dsg_p2_2.msg.clone()], - dsg_p1_2.state.as_slice(), - ) - .unwrap(); - let dsg_p2_sig = mps::dsg_round3_process( - &[dsg_p0_2.msg.clone(), dsg_p1_2.msg.clone()], - dsg_p2_2.state.as_slice(), - ) - .unwrap(); + let dsg_p0_sig = + mps::dsg_round3_process(dsg_p2_2.msg.as_slice(), dsg_p0_2.state.as_slice()).unwrap(); + let dsg_p2_sig = + mps::dsg_round3_process(dsg_p0_2.msg.as_slice(), dsg_p2_2.state.as_slice()).unwrap(); assert_eq!( dsg_p2_sig, dsg_p0_sig, "Party 0 signature differs from party 2 signature" ); - assert_eq!( - dsg_p2_sig, dsg_p1_sig, - "Party 1 signature differs from party 2 signature" - ); // Verify signature VerifyingKey::from_bytes(&dkg_p0_share.pk) @@ -602,13 +543,6 @@ mod tests { &Signature::from_bytes(dsg_p0_sig.as_slice().try_into().unwrap()), ) .unwrap(); - VerifyingKey::from_bytes(&dkg_p1_share.pk) - .unwrap() - .verify( - msg, - &Signature::from_bytes(dsg_p1_sig.as_slice().try_into().unwrap()), - ) - .unwrap(); VerifyingKey::from_bytes(&dkg_p2_share.pk) .unwrap() .verify( @@ -760,15 +694,8 @@ pub fn dsg_round0_process( } #[wasm_bindgen] -pub fn dsg_round1_process(round1_messages: Array, state: &[u8]) -> Result { - let result = mps::dsg_round1_process( - &[ - js_sys::Uint8Array::from(round1_messages.get(0)).to_vec(), - js_sys::Uint8Array::from(round1_messages.get(1)).to_vec(), - ], - state, - ) - .map_err(|e| e.to_string())?; +pub fn dsg_round1_process(round1_message: &[u8], state: &[u8]) -> Result { + let result = mps::dsg_round1_process(round1_message, state).map_err(|e| e.to_string())?; Ok(MsgState { msg: result.msg, @@ -777,15 +704,8 @@ pub fn dsg_round1_process(round1_messages: Array, state: &[u8]) -> Result Result { - let result = mps::dsg_round2_process( - &[ - js_sys::Uint8Array::from(round2_messages.get(0)).to_vec(), - js_sys::Uint8Array::from(round2_messages.get(1)).to_vec(), - ], - state, - ) - .map_err(|e| e.to_string())?; +pub fn dsg_round2_process(round2_message: &[u8], state: &[u8]) -> Result { + let result = mps::dsg_round2_process(round2_message, state).map_err(|e| e.to_string())?; Ok(MsgState { msg: result.msg, @@ -794,15 +714,8 @@ pub fn dsg_round2_process(round2_messages: Array, state: &[u8]) -> Result Result, String> { - let result = mps::dsg_round3_process( - &[ - js_sys::Uint8Array::from(round2_messages.get(0)).to_vec(), - js_sys::Uint8Array::from(round2_messages.get(1)).to_vec(), - ], - state, - ) - .map_err(|e| e.to_string())?; +pub fn dsg_round3_process(round2_message: &[u8], state: &[u8]) -> Result, String> { + let result = mps::dsg_round3_process(round2_message, state).map_err(|e| e.to_string())?; Ok(result.to_vec()) } diff --git a/packages/wasm-mps/test/mps.ts b/packages/wasm-mps/test/mps.ts index 9817d95..5b9a3b2 100644 --- a/packages/wasm-mps/test/mps.ts +++ b/packages/wasm-mps/test/mps.ts @@ -193,6 +193,7 @@ describe("mps", function () { }); describe("dsg", function () { + const otherIndex = [1, 0]; let shares: Array; before("performs dkg", function () { @@ -223,7 +224,7 @@ describe("mps", function () { ); it("performs round 0", function () { - for (let i = 0; i < 3; i++) { + for (const i of [0, 2]) { mps.dsg_round0_process(shares[i].share, "m", message); } }); @@ -231,59 +232,43 @@ describe("mps", function () { let results1: Array; before("performs round 0", function () { - results1 = [0, 1, 2].map((i) => mps.dsg_round0_process(shares[i].share, "m", message)); + results1 = [0, 2].map((i) => mps.dsg_round0_process(shares[i].share, "m", message)); }); it("performs round 1", function () { - for (let i = 0; i < 3; i++) { - mps.dsg_round1_process( - otherIndices[i].map((i) => results1[i].msg), - results1[i].state, - ); + for (let i = 0; i < 2; i++) { + mps.dsg_round1_process(results1[otherIndex[i]].msg, results1[i].state); } }); let results2: Array; before("performs round 1", function () { - results2 = [0, 1, 2].map((i) => - mps.dsg_round1_process( - otherIndices[i].map((i) => results1[i].msg), - results1[i].state, - ), + results2 = [0, 1].map((i) => + mps.dsg_round1_process(results1[otherIndex[i]].msg, results1[i].state), ); }); it("performs round 2", function () { - for (let i = 0; i < 3; i++) { - mps.dsg_round2_process( - otherIndices[i].map((i) => results2[i].msg), - results2[i].state, - ); + for (let i = 0; i < 2; i++) { + mps.dsg_round2_process(results2[otherIndex[i]].msg, results2[i].state); } }); let results3: Array; before("performs round 2", function () { - results3 = [0, 1, 2].map((i) => - mps.dsg_round2_process( - otherIndices[i].map((i) => results2[i].msg), - results2[i].state, - ), + results3 = [0, 1].map((i) => + mps.dsg_round2_process(results2[otherIndex[i]].msg, results2[i].state), ); }); it("performs round 3", function () { - const signatures = [0, 1, 2].map((i) => - mps.dsg_round3_process( - otherIndices[i].map((i) => results3[i].msg), - results3[i].state, - ), + const signatures = [0, 1].map((i) => + mps.dsg_round3_process(results3[otherIndex[i]].msg, results3[i].state), ); - for (let i = 0; i < 3; i++) { - assert(sodium.crypto_sign_verify_detached(signatures[i], message, shares[i].pk)); - } + assert(sodium.crypto_sign_verify_detached(signatures[0], message, shares[0].pk)); + assert(sodium.crypto_sign_verify_detached(signatures[1], message, shares[2].pk)); }); }); });