diff --git a/boring/src/mlkem.rs b/boring/src/mlkem.rs index f5b84147e..d402ae64d 100644 --- a/boring/src/mlkem.rs +++ b/boring/src/mlkem.rs @@ -91,14 +91,14 @@ impl MlKemPrivateKey { pub fn generate(algorithm: Algorithm) -> Result<(MlKemPublicKey, MlKemPrivateKey), ErrorStack> { match algorithm { Algorithm::MlKem768 => { - let (pk, sk) = MlKem768PrivateKey::generate(); + let (pk, sk) = MlKem768PrivateKey::generate()?; Ok(( MlKemPublicKey(Either::MlKem768(pk)), MlKemPrivateKey(Either::MlKem768(sk)), )) } Algorithm::MlKem1024 => { - let (pk, sk) = MlKem1024PrivateKey::generate(); + let (pk, sk) = MlKem1024PrivateKey::generate()?; Ok(( MlKemPublicKey(Either::MlKem1024(pk)), MlKemPrivateKey(Either::MlKem1024(sk)), @@ -131,16 +131,20 @@ impl MlKemPublicKey { /// Encapsulates a shared secret to the given public key, returning /// `(ciphertext, shared_secret)`. pub fn encapsulate(&self) -> Result<(Vec, MlKemSharedSecret), ErrorStack> { - match &self.0 { + let mut ss = [0; SHARED_SECRET_BYTES]; + let ct = match &self.0 { Either::MlKem768(pk) => { - let (ct, ss) = pk.encapsulate(); - Ok((ct.to_vec(), ss)) + let mut ct = vec![0; MlKem768PrivateKey::CIPHERTEXT_BYTES]; + pk.encapsulate_into(ct.as_mut_slice().try_into().unwrap(), &mut ss); + ct } Either::MlKem1024(pk) => { - let (ct, ss) = pk.encapsulate(); - Ok((ct.to_vec(), ss)) + let mut ct = vec![0; MlKem1024PrivateKey::CIPHERTEXT_BYTES]; + pk.encapsulate_into(ct.as_mut_slice().try_into().unwrap(), &mut ss); + ct } - } + }; + Ok((ct, ss)) } /// Query public key and ciphertext length @@ -222,39 +226,27 @@ impl MlKem768PrivateKey { pub const CIPHERTEXT_BYTES: usize = ffi::MLKEM768_CIPHERTEXT_BYTES as usize; /// Generate a new key pair. - #[must_use] - fn generate() -> (Box, Box) { + fn generate() -> Result<(Box, Box), ErrorStack> { // SAFETY: all buffers are out parameters, correctly sized unsafe { ffi::init(); - let mut public_key_bytes: MaybeUninit<[u8; MlKem768PublicKey::PUBLIC_KEY_BYTES]> = - MaybeUninit::uninit(); - let mut seed: MaybeUninit = MaybeUninit::uninit(); + let mut bytes = [0; MlKem768PublicKey::PUBLIC_KEY_BYTES]; + let mut seed = [0; PRIVATE_KEY_SEED_BYTES]; let mut expanded: MaybeUninit = MaybeUninit::uninit(); ffi::MLKEM768_generate_key( - public_key_bytes.as_mut_ptr().cast(), - seed.as_mut_ptr().cast(), + bytes.as_mut_ptr().cast(), + seed.as_mut_ptr(), expanded.as_mut_ptr(), ); - let bytes = public_key_bytes.assume_init(); - - // Parse the public key bytes to get the parsed struct - let mut cbs = cbs_init(&bytes); - let mut parsed: MaybeUninit = MaybeUninit::uninit(); - ffi::MLKEM768_parse_public_key(parsed.as_mut_ptr(), &mut cbs); - - ( - Box::new(MlKem768PublicKey { - bytes, - parsed: parsed.assume_init(), - }), + Ok(( + Box::new(MlKem768PublicKey::from_slice(&bytes)?), Box::new(MlKem768PrivateKey { - seed: seed.assume_init(), + seed, expanded: expanded.assume_init(), }), - ) + )) } } @@ -388,25 +380,19 @@ impl MlKem768PublicKey { } /// Encapsulate: returns (ciphertext, shared_secret). - fn encapsulate( + fn encapsulate_into( &self, - ) -> ( - [u8; MlKem768PrivateKey::CIPHERTEXT_BYTES], - MlKemSharedSecret, + ciphertext: &mut [u8; MlKem768PrivateKey::CIPHERTEXT_BYTES], + shared_secret: &mut MlKemSharedSecret, ) { // SAFETY: buffers correctly sized, parsed key is valid unsafe { ffi::init(); - let mut ciphertext = [0u8; MlKem768PrivateKey::CIPHERTEXT_BYTES]; - let mut shared_secret = [0u8; SHARED_SECRET_BYTES]; - ffi::MLKEM768_encap( ciphertext.as_mut_ptr(), shared_secret.as_mut_ptr(), &self.parsed, ); - - (ciphertext, shared_secret) } } } @@ -439,39 +425,27 @@ impl MlKem1024PrivateKey { pub const CIPHERTEXT_BYTES: usize = ffi::MLKEM1024_CIPHERTEXT_BYTES as usize; /// Generate a new key pair. - #[must_use] - fn generate() -> (Box, Box) { + fn generate() -> Result<(Box, Box), ErrorStack> { // SAFETY: all buffers are out parameters, correctly sized unsafe { ffi::init(); - let mut public_key_bytes: MaybeUninit<[u8; MlKem1024PublicKey::PUBLIC_KEY_BYTES]> = - MaybeUninit::uninit(); - let mut seed: MaybeUninit = MaybeUninit::uninit(); + let mut bytes = [0; MlKem1024PublicKey::PUBLIC_KEY_BYTES]; + let mut seed = [0; PRIVATE_KEY_SEED_BYTES]; let mut expanded: MaybeUninit = MaybeUninit::uninit(); ffi::MLKEM1024_generate_key( - public_key_bytes.as_mut_ptr().cast(), - seed.as_mut_ptr().cast(), + bytes.as_mut_ptr().cast(), + seed.as_mut_ptr(), expanded.as_mut_ptr(), ); - let bytes = public_key_bytes.assume_init(); - - // Parse the public key bytes to get the parsed struct - let mut cbs = cbs_init(&bytes); - let mut parsed: MaybeUninit = MaybeUninit::uninit(); - ffi::MLKEM1024_parse_public_key(parsed.as_mut_ptr(), &mut cbs); - - ( - Box::new(MlKem1024PublicKey { - bytes, - parsed: parsed.assume_init(), - }), + Ok(( + Box::new(MlKem1024PublicKey::from_slice(&bytes)?), Box::new(MlKem1024PrivateKey { - seed: seed.assume_init(), + seed, expanded: expanded.assume_init(), }), - ) + )) } } @@ -607,25 +581,19 @@ impl MlKem1024PublicKey { } /// Encapsulate: returns (ciphertext, shared_secret). - fn encapsulate( + fn encapsulate_into( &self, - ) -> ( - [u8; MlKem1024PrivateKey::CIPHERTEXT_BYTES], - [u8; SHARED_SECRET_BYTES], + ciphertext: &mut [u8; MlKem1024PrivateKey::CIPHERTEXT_BYTES], + shared_secret: &mut [u8; SHARED_SECRET_BYTES], ) { // SAFETY: buffers correctly sized, parsed key is valid unsafe { ffi::init(); - let mut ciphertext = [0u8; MlKem1024PrivateKey::CIPHERTEXT_BYTES]; - let mut shared_secret = [0u8; SHARED_SECRET_BYTES]; - ffi::MLKEM1024_encap( ciphertext.as_mut_ptr(), shared_secret.as_mut_ptr(), &self.parsed, ); - - (ciphertext, shared_secret) } } } @@ -649,24 +617,28 @@ mod tests { #[test] fn roundtrip() { - let (pk, sk) = <$priv>::generate(); - let (ct, ss1) = pk.encapsulate(); + let (pk, sk) = <$priv>::generate().unwrap(); + let mut ct = [0; _]; + let mut ss1 = [0; _]; + pk.encapsulate_into(&mut ct, &mut ss1); let ss2 = sk.decapsulate(&ct); assert_eq!(ss1, ss2); } #[test] fn seed_roundtrip() { - let (pk, sk) = <$priv>::generate(); + let (pk, sk) = <$priv>::generate().unwrap(); let sk2 = <$priv>::from_seed(&sk.seed).unwrap(); - let (ct, ss1) = pk.encapsulate(); + let mut ct = [0; _]; + let mut ss1 = [0; _]; + pk.encapsulate_into(&mut ct, &mut ss1); let ss2 = sk2.decapsulate(&ct); assert_eq!(ss1, ss2); } #[test] fn derive_pubkey() { - let (pk, sk) = <$priv>::generate(); + let (pk, sk) = <$priv>::generate().unwrap(); assert_eq!(pk.bytes, sk.public_key().unwrap().bytes); } @@ -678,14 +650,14 @@ mod tests { #[test] fn from_slice_roundtrip() { - let (pk, _) = <$priv>::generate(); + let (pk, _) = <$priv>::generate().unwrap(); let pk2 = <$pub>::from_slice(&pk.bytes).unwrap(); assert_eq!(pk.bytes, pk2.bytes); } #[test] fn implicit_rejection() { - let (_, sk) = <$priv>::generate(); + let (_, sk) = <$priv>::generate().unwrap(); let bad_ct = [0x42u8; $ct_len]; // bad ciphertext still "works", just returns deterministic garbage let ss1 = sk.decapsulate(&bad_ct); @@ -695,7 +667,7 @@ mod tests { #[test] fn debug_redacts_seed() { - let (_, sk) = <$priv>::generate(); + let (_, sk) = <$priv>::generate().unwrap(); let dbg = format!("{:?}", sk); assert!(dbg.contains("redacted")); }