Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 48 additions & 76 deletions boring/src/mlkem.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,14 +91,14 @@ impl MlKemPrivateKey {
pub fn generate(algorithm: Algorithm) -> Result<(MlKemPublicKey, MlKemPrivateKey), ErrorStack> {
match algorithm {
Algorithm::MlKem768 => {
let (pk, sk) = MlKem768PrivateKey::generate();
let (pk, sk) = MlKem768PrivateKey::generate()?;
Ok((
MlKemPublicKey(Either::MlKem768(pk)),
MlKemPrivateKey(Either::MlKem768(sk)),
))
}
Algorithm::MlKem1024 => {
let (pk, sk) = MlKem1024PrivateKey::generate();
let (pk, sk) = MlKem1024PrivateKey::generate()?;
Ok((
MlKemPublicKey(Either::MlKem1024(pk)),
MlKemPrivateKey(Either::MlKem1024(sk)),
Expand Down Expand Up @@ -131,16 +131,20 @@ impl MlKemPublicKey {
/// Encapsulates a shared secret to the given public key, returning
/// `(ciphertext, shared_secret)`.
pub fn encapsulate(&self) -> Result<(Vec<u8>, MlKemSharedSecret), ErrorStack> {
match &self.0 {
let mut ss = [0; SHARED_SECRET_BYTES];
let ct = match &self.0 {
Either::MlKem768(pk) => {
let (ct, ss) = pk.encapsulate();
Ok((ct.to_vec(), ss))
let mut ct = vec![0; MlKem768PrivateKey::CIPHERTEXT_BYTES];
pk.encapsulate_into(ct.as_mut_slice().try_into().unwrap(), &mut ss);
ct
}
Either::MlKem1024(pk) => {
let (ct, ss) = pk.encapsulate();
Ok((ct.to_vec(), ss))
let mut ct = vec![0; MlKem1024PrivateKey::CIPHERTEXT_BYTES];
pk.encapsulate_into(ct.as_mut_slice().try_into().unwrap(), &mut ss);
ct
}
}
};
Ok((ct, ss))
}

/// Query public key and ciphertext length
Expand Down Expand Up @@ -222,39 +226,27 @@ impl MlKem768PrivateKey {
pub const CIPHERTEXT_BYTES: usize = ffi::MLKEM768_CIPHERTEXT_BYTES as usize;

/// Generate a new key pair.
#[must_use]
fn generate() -> (Box<MlKem768PublicKey>, Box<MlKem768PrivateKey>) {
fn generate() -> Result<(Box<MlKem768PublicKey>, Box<MlKem768PrivateKey>), ErrorStack> {
// SAFETY: all buffers are out parameters, correctly sized
unsafe {
ffi::init();
let mut public_key_bytes: MaybeUninit<[u8; MlKem768PublicKey::PUBLIC_KEY_BYTES]> =
MaybeUninit::uninit();
let mut seed: MaybeUninit<MlKemPrivateKeySeed> = MaybeUninit::uninit();
let mut bytes = [0; MlKem768PublicKey::PUBLIC_KEY_BYTES];
let mut seed = [0; PRIVATE_KEY_SEED_BYTES];
let mut expanded: MaybeUninit<ffi::MLKEM768_private_key> = MaybeUninit::uninit();

ffi::MLKEM768_generate_key(
public_key_bytes.as_mut_ptr().cast(),
seed.as_mut_ptr().cast(),
bytes.as_mut_ptr().cast(),
seed.as_mut_ptr(),
expanded.as_mut_ptr(),
);

let bytes = public_key_bytes.assume_init();

// Parse the public key bytes to get the parsed struct
let mut cbs = cbs_init(&bytes);
let mut parsed: MaybeUninit<ffi::MLKEM768_public_key> = MaybeUninit::uninit();
ffi::MLKEM768_parse_public_key(parsed.as_mut_ptr(), &mut cbs);

(
Box::new(MlKem768PublicKey {
bytes,
parsed: parsed.assume_init(),
}),
Ok((
Box::new(MlKem768PublicKey::from_slice(&bytes)?),
Box::new(MlKem768PrivateKey {
seed: seed.assume_init(),
seed,
expanded: expanded.assume_init(),
}),
)
))
}
}

Expand Down Expand Up @@ -388,25 +380,19 @@ impl MlKem768PublicKey {
}

/// Encapsulate: returns (ciphertext, shared_secret).
fn encapsulate(
fn encapsulate_into(
&self,
) -> (
[u8; MlKem768PrivateKey::CIPHERTEXT_BYTES],
MlKemSharedSecret,
ciphertext: &mut [u8; MlKem768PrivateKey::CIPHERTEXT_BYTES],
shared_secret: &mut MlKemSharedSecret,
) {
// SAFETY: buffers correctly sized, parsed key is valid
unsafe {
ffi::init();
let mut ciphertext = [0u8; MlKem768PrivateKey::CIPHERTEXT_BYTES];
let mut shared_secret = [0u8; SHARED_SECRET_BYTES];

ffi::MLKEM768_encap(
ciphertext.as_mut_ptr(),
shared_secret.as_mut_ptr(),
&self.parsed,
);

(ciphertext, shared_secret)
}
}
}
Expand Down Expand Up @@ -439,39 +425,27 @@ impl MlKem1024PrivateKey {
pub const CIPHERTEXT_BYTES: usize = ffi::MLKEM1024_CIPHERTEXT_BYTES as usize;

/// Generate a new key pair.
#[must_use]
fn generate() -> (Box<MlKem1024PublicKey>, Box<MlKem1024PrivateKey>) {
fn generate() -> Result<(Box<MlKem1024PublicKey>, Box<MlKem1024PrivateKey>), ErrorStack> {
// SAFETY: all buffers are out parameters, correctly sized
unsafe {
ffi::init();
let mut public_key_bytes: MaybeUninit<[u8; MlKem1024PublicKey::PUBLIC_KEY_BYTES]> =
MaybeUninit::uninit();
let mut seed: MaybeUninit<MlKemPrivateKeySeed> = MaybeUninit::uninit();
let mut bytes = [0; MlKem1024PublicKey::PUBLIC_KEY_BYTES];
let mut seed = [0; PRIVATE_KEY_SEED_BYTES];
let mut expanded: MaybeUninit<ffi::MLKEM1024_private_key> = MaybeUninit::uninit();

ffi::MLKEM1024_generate_key(
public_key_bytes.as_mut_ptr().cast(),
seed.as_mut_ptr().cast(),
bytes.as_mut_ptr().cast(),
seed.as_mut_ptr(),
expanded.as_mut_ptr(),
);

let bytes = public_key_bytes.assume_init();

// Parse the public key bytes to get the parsed struct
let mut cbs = cbs_init(&bytes);
let mut parsed: MaybeUninit<ffi::MLKEM1024_public_key> = MaybeUninit::uninit();
ffi::MLKEM1024_parse_public_key(parsed.as_mut_ptr(), &mut cbs);

(
Box::new(MlKem1024PublicKey {
bytes,
parsed: parsed.assume_init(),
}),
Ok((
Box::new(MlKem1024PublicKey::from_slice(&bytes)?),
Box::new(MlKem1024PrivateKey {
seed: seed.assume_init(),
seed,
expanded: expanded.assume_init(),
}),
)
))
}
}

Expand Down Expand Up @@ -607,25 +581,19 @@ impl MlKem1024PublicKey {
}

/// Encapsulate: returns (ciphertext, shared_secret).
fn encapsulate(
fn encapsulate_into(
&self,
) -> (
[u8; MlKem1024PrivateKey::CIPHERTEXT_BYTES],
[u8; SHARED_SECRET_BYTES],
ciphertext: &mut [u8; MlKem1024PrivateKey::CIPHERTEXT_BYTES],
shared_secret: &mut [u8; SHARED_SECRET_BYTES],
) {
// SAFETY: buffers correctly sized, parsed key is valid
unsafe {
ffi::init();
let mut ciphertext = [0u8; MlKem1024PrivateKey::CIPHERTEXT_BYTES];
let mut shared_secret = [0u8; SHARED_SECRET_BYTES];

ffi::MLKEM1024_encap(
ciphertext.as_mut_ptr(),
shared_secret.as_mut_ptr(),
&self.parsed,
);

(ciphertext, shared_secret)
}
}
}
Expand All @@ -649,24 +617,28 @@ mod tests {

#[test]
fn roundtrip() {
let (pk, sk) = <$priv>::generate();
let (ct, ss1) = pk.encapsulate();
let (pk, sk) = <$priv>::generate().unwrap();
let mut ct = [0; _];
let mut ss1 = [0; _];
pk.encapsulate_into(&mut ct, &mut ss1);
let ss2 = sk.decapsulate(&ct);
assert_eq!(ss1, ss2);
}

#[test]
fn seed_roundtrip() {
let (pk, sk) = <$priv>::generate();
let (pk, sk) = <$priv>::generate().unwrap();
let sk2 = <$priv>::from_seed(&sk.seed).unwrap();
let (ct, ss1) = pk.encapsulate();
let mut ct = [0; _];
let mut ss1 = [0; _];
pk.encapsulate_into(&mut ct, &mut ss1);
let ss2 = sk2.decapsulate(&ct);
assert_eq!(ss1, ss2);
}

#[test]
fn derive_pubkey() {
let (pk, sk) = <$priv>::generate();
let (pk, sk) = <$priv>::generate().unwrap();
assert_eq!(pk.bytes, sk.public_key().unwrap().bytes);
}

Expand All @@ -678,14 +650,14 @@ mod tests {

#[test]
fn from_slice_roundtrip() {
let (pk, _) = <$priv>::generate();
let (pk, _) = <$priv>::generate().unwrap();
let pk2 = <$pub>::from_slice(&pk.bytes).unwrap();
assert_eq!(pk.bytes, pk2.bytes);
}

#[test]
fn implicit_rejection() {
let (_, sk) = <$priv>::generate();
let (_, sk) = <$priv>::generate().unwrap();
let bad_ct = [0x42u8; $ct_len];
// bad ciphertext still "works", just returns deterministic garbage
let ss1 = sk.decapsulate(&bad_ct);
Expand All @@ -695,7 +667,7 @@ mod tests {

#[test]
fn debug_redacts_seed() {
let (_, sk) = <$priv>::generate();
let (_, sk) = <$priv>::generate().unwrap();
let dbg = format!("{:?}", sk);
assert!(dbg.contains("redacted"));
}
Expand Down
Loading