Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,8 @@ pub enum PkEncryptionError {
InvalidKeySize(usize),
#[error(transparent)]
Decode(#[from] vodozemac::pk_encryption::Error),
#[error(transparent)]
Mac(#[from] vodozemac::Base64DecodeError),
}

pyo3::create_exception!(module, PkInvalidKeySizeException, pyo3::exceptions::PyValueError);
Expand All @@ -148,6 +150,7 @@ impl From<PkEncryptionError> for PyErr {
PkInvalidKeySizeException::new_err(e.to_string())
}
PkEncryptionError::Decode(_) => PkDecodeException::new_err(e.to_string()),
PkEncryptionError::Mac(_) => PkDecodeException::new_err(e.to_string()),
}
}
}
60 changes: 60 additions & 0 deletions src/pk_encryption.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,77 @@ use crate::{
#[pyclass]
pub struct Message {
/// The ciphertext of the message.
#[pyo3(get)]
ciphertext: Vec<u8>,
/// The message authentication code of the message.
///
/// *Warning*: As stated in the module description, this does not
/// authenticate the message.
#[pyo3(get)]
mac: Vec<u8>,
/// The ephemeral Curve25519PublicKey of the message which was used to
/// derive the individual message key.
#[pyo3(get)]
ephemeral_key: Vec<u8>,
}

#[pymethods]
impl Message {
/// Create a new Message object from its components.
///
/// This constructor creates a Message object that represents an encrypted
/// message using the `m.megolm_backup.v1.curve25519-aes-sha2`
/// algorithm.
///
/// # Arguments
/// * `ciphertext` - The encrypted content of the message
/// * `mac` - The message authentication code
/// * `ephemeral_key` - The ephemeral public key used during encryption
#[new]
fn new(ciphertext: Vec<u8>, mac: Vec<u8>, ephemeral_key: Vec<u8>) -> Self {
Message { ciphertext, mac, ephemeral_key }
}

/// Create a new Message object from unpadded Base64-encoded components.
///
/// This function decodes the given Base64 strings and returns a `Message`
/// with the resulting byte vectors.
///
/// # Arguments
/// * `ciphertext` - Unpadded Base64-encoded ciphertext
/// * `mac` - Unpadded Base64-encoded message authentication code
/// * `ephemeral_key` - Unpadded Base64-encoded ephemeral key
#[classmethod]
fn from_base64(
_cls: &Bound<'_, PyType>,
ciphertext: &str,
mac: &str,
ephemeral_key: &str,
) -> Result<Self, PkEncryptionError> {
let decoded_ciphertext = vodozemac::base64_decode(ciphertext)?;
let decoded_mac = vodozemac::base64_decode(mac)?;
let decoded_ephemeral_key = vodozemac::base64_decode(ephemeral_key)?;

Ok(Self {
ciphertext: decoded_ciphertext,
mac: decoded_mac,
ephemeral_key: decoded_ephemeral_key,
})
}

/// Convert the message components to unpadded Base64-encoded strings.
///
/// Returns a tuple of (ciphertext, mac, ephemeral_key) as unpadded Base64
/// strings.
fn to_base64(&self) -> Result<(String, String, String), PkEncryptionError> {
let ciphertext_b64 = vodozemac::base64_encode(&self.ciphertext);
let mac_b64 = vodozemac::base64_encode(&self.mac);
let ephemeral_key_b64 = vodozemac::base64_encode(&self.ephemeral_key);

Ok((ephemeral_key_b64, mac_b64, ciphertext_b64))
}
}

/// ☣️ Compat support for libolm's PkDecryption.
///
/// This implements the `m.megolm_backup.v1.curve25519-aes-sha2` described in
Expand Down
52 changes: 49 additions & 3 deletions tests/pk_encryption_test.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
import importlib
import pytest
import base64

from vodozemac import Curve25519SecretKey, Curve25519PublicKey, PkEncryption, PkDecryption, PkDecodeException
import pytest
from vodozemac import (
Curve25519PublicKey,
Curve25519SecretKey,
Message,
PkDecodeException,
PkDecryption,
PkEncryption,
)

CLEARTEXT = b"test"


class TestClass(object):
def test_encrypt_decrypt(self):
d = PkDecryption()
Expand All @@ -28,3 +36,41 @@ def test_encrypt_decrypt_with_serialized_keys(self):

decoded = d.decrypt(e.encrypt(CLEARTEXT))
assert decoded == CLEARTEXT

def test_encrypt_message_attr(self):
"""Test that the Message object has accessible Python attributes (mac, ciphertext, ephemeral_key)."""
decryption = PkDecryption()
encryption = PkEncryption.from_key(decryption.public_key)

message = encryption.encrypt(CLEARTEXT)

assert message.mac is not None
assert message.ciphertext is not None
assert message.ephemeral_key is not None


def test_message_from_invalid_base64(self):
"""Test that invalid base64 input raises PkDecodeException."""
# Test invalid ciphertext
with pytest.raises(PkDecodeException, match="Invalid symbol"):
Message.from_base64(
"not-valid-base64!@#", # Invalid base64 for ciphertext
base64.b64encode(b"some_mac").decode(), # Valid base64
base64.b64encode(b"some_key").decode() # Valid base64
)

# Test invalid mac
with pytest.raises(PkDecodeException, match="Invalid symbol"):
Message.from_base64(
base64.b64encode(b"some_text").decode(),
"not-valid-base64!@#", # Invalid base64 for mac
base64.b64encode(b"some_key").decode()
)

# Test invalid ephemeral key
with pytest.raises(PkDecodeException, match="Invalid symbol"):
Message.from_base64(
base64.b64encode(b"some_text").decode(),
base64.b64encode(b"some_mac").decode(),
"not-valid-base64!@#" # Invalid base64 for ephemeral key
)