diff --git a/pyproject.toml b/pyproject.toml index 8c657e6..2a6109f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ nox = [ "nox-uv>=0.6.0", ] tests = [ + "hypothesis>=6.135.21", "maturin>=1.9.0", "pytest>=8.4.1", ] diff --git a/src/lib.rs b/src/lib.rs index 0894407..0984e7c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -17,6 +17,7 @@ fn my_module(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; diff --git a/tests/account_test.py b/tests/account_test.py index 4e9b177..1632b5e 100644 --- a/tests/account_test.py +++ b/tests/account_test.py @@ -1,64 +1,49 @@ -import vodozemac import pytest - -from vodozemac import Account, PickleException, SignatureException - -PICKLE_KEY = b"DEFAULT_PICKLE_KEY_1234567890___" - -class TestClass(object): - def test_account_creation(self): - account = Account() - - assert account.ed25519_key - assert account.curve25519_key - - def test_generating_onet_time_keys(self): - account = Account() - - assert len(account.one_time_keys) == 0 - - account.generate_one_time_keys(10) - assert len(account.one_time_keys) == 10 - - def test_pickling(self): - alice = Account() - pickle = alice.pickle(PICKLE_KEY) - unpickled = Account.from_pickle(pickle, PICKLE_KEY) - assert (alice.ed25519_key == unpickled.ed25519_key) - - def test_libolm_pickling(self): - pickle = ( - "3wpPcPT4xsRYCYF34NcnozxE5bN2E6qwBXQYuoovt/TX//8Dnd8gaKsxN9En/" - "7Hkh5XemuGUo3dXHVTl76G2pjf9ehfryhITMbeBrE/XuxmNvS2aB9KU4mOKXl" - "AWhCEsE7JW9fUkRhHWWkFwTvSC3eDthd6eNx3VKZlmGR270vIpIG5/Ho4YK9/" - "03lPGpil0cuEuGTTjKHXGRu9kpnQe99QGCB4KBuP5IJjFeWbtSgJ4ZrajZdlTew" - ) - - unpickled = Account.from_libolm_pickle(pickle, b"It's a secret to everybody") - - assert unpickled.ed25519_key.to_base64() == "MEQCwaTE/gcrHaxwv06WEVy5xDA30FboFzCAtYhzmoc" - - def test_invalid_pickle(self): - with pytest.raises(PickleException): - Account.from_pickle("", PICKLE_KEY) - - def test_max_one_time_keys(self): - alice = Account() - assert isinstance(alice.max_number_of_one_time_keys, int) - - def test_publish_one_time_keys(self): - alice = Account() - alice.generate_one_time_keys(10) - - assert len(alice.one_time_keys) == 10 - - alice.mark_keys_as_published() - assert not alice.one_time_keys - - def test_signing(self): - alice = Account() - signature = alice.sign(b"This is a test") - - alice.ed25519_key.verify_signature(b"This is a test", signature) - with pytest.raises(SignatureException): - alice.ed25519_key.verify_signature(b"This should fail", signature) +from hypothesis import given +from vodozemac import Account, PickleException, SignatureException, Ed25519PublicKey, Curve25519PublicKey + +@pytest.fixture(scope="module") +def account() -> Account: + return Account() + +def test_creation(account: Account): + assert isinstance(account.ed25519_key, Ed25519PublicKey) + assert isinstance(account.curve25519_key, Curve25519PublicKey) + assert isinstance(account.max_number_of_one_time_keys, int) + +def test_generate_and_publish_one_time_keys(account: Account): + assert len(account.one_time_keys) == 0 + account.generate_one_time_keys(10) + assert len(account.one_time_keys) == 10 + account.mark_keys_as_published() + assert not account.one_time_keys + +def test_pickling(account: Account, pickle_key: bytes): + pickle = account.pickle(pickle_key) + unpickled = Account.from_pickle(pickle, pickle_key) + assert account.ed25519_key == unpickled.ed25519_key + assert account.curve25519_key == unpickled.curve25519_key + assert account.one_time_keys == unpickled.one_time_keys + +def test_libolm_pickling(): + pickle = ( + "3wpPcPT4xsRYCYF34NcnozxE5bN2E6qwBXQYuoovt/TX//8Dnd8gaKsxN9En/" + "7Hkh5XemuGUo3dXHVTl76G2pjf9ehfryhITMbeBrE/XuxmNvS2aB9KU4mOKXl" + "AWhCEsE7JW9fUkRhHWWkFwTvSC3eDthd6eNx3VKZlmGR270vIpIG5/Ho4YK9/" + "03lPGpil0cuEuGTTjKHXGRu9kpnQe99QGCB4KBuP5IJjFeWbtSgJ4ZrajZdlTew" + ) + + unpickled = Account.from_libolm_pickle(pickle, b"It's a secret to everybody") + + assert unpickled.ed25519_key.to_base64() == "MEQCwaTE/gcrHaxwv06WEVy5xDA30FboFzCAtYhzmoc" + +def test_invalid_pickle(pickle_key: bytes): + with pytest.raises(PickleException): + Account.from_pickle("", pickle_key) + +@given(message=...) +def test_signing(account: Account, message: bytes): + signature = account.sign(message) + account.ed25519_key.verify_signature(message, signature) + with pytest.raises(SignatureException): + account.ed25519_key.verify_signature(b"This should fail", signature) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..382ed60 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,6 @@ +import pytest + + +@pytest.fixture(scope="session") +def pickle_key(): + return b"DEFAULT_PICKLE_KEY_1234567890___" diff --git a/tests/group_session_test.py b/tests/group_session_test.py index 5fe160b..c647aa4 100644 --- a/tests/group_session_test.py +++ b/tests/group_session_test.py @@ -1,114 +1,89 @@ import pytest - +from hypothesis import given from vodozemac import ( - InboundGroupSession, GroupSession, - PickleException, - DecodeException, + InboundGroupSession, MegolmDecryptionException, + PickleException, ) -PICKLE_KEY = b"DEFAULT_PICKLE_KEY_1234567890___" +@pytest.fixture(scope="module") +def group_session() -> GroupSession: + return GroupSession() + +@pytest.fixture(scope="module") +def inbound_group_session(group_session: GroupSession) -> InboundGroupSession: + return InboundGroupSession(group_session.session_key) + + +def test_create(group_session: GroupSession, inbound_group_session: InboundGroupSession): + assert isinstance(group_session.session_id, str) + assert isinstance(group_session.message_index, int) + assert group_session.message_index == 0 + assert isinstance(inbound_group_session.first_known_index, int) + assert inbound_group_session.first_known_index == 0 -class TestClass(object): - def test_session_create(self): - GroupSession() + assert group_session.session_id == inbound_group_session.session_id - def test_session_id(self): - session = GroupSession() - assert isinstance(session.session_id, str) +def test_outbound_pickle(group_session: GroupSession, pickle_key: bytes): + pickle = group_session.pickle(pickle_key) + unpickled = GroupSession.from_pickle(pickle, pickle_key) - def test_session_index(self): - session = GroupSession() - assert isinstance(session.message_index, int) - assert session.message_index == 0 + assert group_session.session_id == unpickled.session_id - def test_outbound_pickle(self): - session = GroupSession() - pickle = session.pickle(PICKLE_KEY) - unpickled = GroupSession.from_pickle(pickle, PICKLE_KEY) +def test_outbound_pickle_fail(group_session: GroupSession, pickle_key: bytes): + wrong_pickle_key = b"It's a secret to everybody 12345" + pickle = group_session.pickle(wrong_pickle_key) - assert session.session_id == unpickled.session_id + with pytest.raises(ValueError): + GroupSession.from_pickle(pickle, pickle_key) - def test_invalid_unpickle(self): - with pytest.raises(PickleException): - GroupSession.from_pickle("", PICKLE_KEY) +@pytest.mark.parametrize("cls", (GroupSession, InboundGroupSession)) +def test_invalid_pickle(cls: type, pickle_key: bytes): + with pytest.raises(PickleException): + cls.from_pickle("", pickle_key) - with pytest.raises(PickleException): - InboundGroupSession.from_pickle("", PICKLE_KEY) - def test_inbound_create(self): - outbound = GroupSession() - InboundGroupSession(outbound.session_key) +def test_inbound_create(inbound_group_session: InboundGroupSession, pickle_key: bytes): + pickle = inbound_group_session.pickle(pickle_key) + unpickled = InboundGroupSession.from_pickle(pickle, pickle_key) + assert unpickled.session_id == inbound_group_session.session_id - def test_inbound_pickle(self): - outbound = GroupSession() - inbound = InboundGroupSession(outbound.session_key) - pickle = inbound.pickle(PICKLE_KEY) - InboundGroupSession.from_pickle(pickle, PICKLE_KEY) +@given(message1=..., message2=...) +def test_encrypt_twice(group_session: GroupSession, inbound_group_session: InboundGroupSession, message1: bytes, message2: bytes): + decrypted1 = inbound_group_session.decrypt(group_session.encrypt(message1)) + assert decrypted1.plaintext == message1 - def test_inbound_export(self): - outbound = GroupSession() - inbound = InboundGroupSession(outbound.session_key) - imported = InboundGroupSession.import_session( - inbound.export_at(inbound.first_known_index) + decrypted2 = inbound_group_session.decrypt(group_session.encrypt(message2)) + assert decrypted2.plaintext == message2 + + assert decrypted2.message_index == decrypted1.message_index + 1 + +def test_decrypt_failure(inbound_group_session: InboundGroupSession): + wrong_group_session = GroupSession() + with pytest.raises(MegolmDecryptionException): + inbound_group_session.decrypt(wrong_group_session.encrypt(b"Test")) + + +@given(message=...) +def test_inbound_export(group_session: GroupSession, inbound_group_session: InboundGroupSession, message: bytes): + imported = InboundGroupSession.import_session( + session_key=inbound_group_session.export_at( + index=inbound_group_session.first_known_index ) - message = imported.decrypt(outbound.encrypt(b"Test")) - assert message.plaintext == b"Test" - assert message.message_index == 0 - - def test_first_index(self): - outbound = GroupSession() - inbound = InboundGroupSession(outbound.session_key) - index = inbound.first_known_index - assert index == 0 - assert isinstance(index, int) - - def test_encrypt(self): - outbound = GroupSession() - inbound = InboundGroupSession(outbound.session_key) - message = inbound.decrypt(outbound.encrypt(b"Test")) - assert b"Test", 0 == inbound.decrypt(outbound.encrypt(b"Test")) - - def test_decrypt_twice(self): - outbound = GroupSession() - inbound = InboundGroupSession(outbound.session_key) - outbound.encrypt(b"Test 1") - message = inbound.decrypt(outbound.encrypt(b"Test 2")) - assert isinstance(message.message_index, int) - assert message.message_index == 1 - assert message.plaintext == b"Test 2" - - def test_decrypt_failure(self): - outbound = GroupSession() - inbound = InboundGroupSession(outbound.session_key) - eve_outbound = GroupSession() - with pytest.raises(MegolmDecryptionException): - inbound.decrypt(eve_outbound.encrypt(b"Test")) - - def test_id(self): - outbound = GroupSession() - inbound = InboundGroupSession(outbound.session_key) - assert outbound.session_id == inbound.session_id - - def test_inbound_fail(self): - with pytest.raises(TypeError): - InboundGroupSession() - - def test_outbound_pickle_fail(self): - outbound = GroupSession() - pickle_key = b"It's a secret to everybody 12345" - pickle = outbound.pickle(pickle_key) - - with pytest.raises(ValueError): - GroupSession.from_pickle(pickle, PICKLE_KEY) - - def test_outbound_clear(self): - session = GroupSession() - del session - - def test_inbound_clear(self): - outbound = GroupSession() - inbound = InboundGroupSession(outbound.session_key) - del inbound + ) + index = group_session.message_index + decrypted = imported.decrypt(group_session.encrypt(message)) + + assert decrypted.plaintext == message + assert decrypted.message_index == index + +def test_outbound_clear(): + session = GroupSession() + del session + +def test_inbound_clear(): + outbound = GroupSession() + inbound = InboundGroupSession(outbound.session_key) + del inbound diff --git a/tests/pk_encryption_test.py b/tests/pk_encryption_test.py index 81bf26b..40eefca 100644 --- a/tests/pk_encryption_test.py +++ b/tests/pk_encryption_test.py @@ -1,76 +1,82 @@ -import base64 - import pytest +from base64 import b64encode +from hypothesis import given from vodozemac import ( + Message, Curve25519PublicKey, Curve25519SecretKey, - Message, PkDecodeException, PkDecryption, PkEncryption, ) -CLEARTEXT = b"test" - - -class TestClass(object): - def test_encrypt_decrypt(self): - d = PkDecryption() - e = PkEncryption.from_key(d.public_key) - - decoded = d.decrypt(e.encrypt(CLEARTEXT)) - assert decoded == CLEARTEXT - - def test_encrypt_decrypt_with_wrong_key(self): - wrong_e = PkEncryption.from_key(PkDecryption().public_key) - with pytest.raises(PkDecodeException, match="MAC tag mismatch"): - PkDecryption().decrypt(wrong_e.encrypt(CLEARTEXT)) - - def test_encrypt_decrypt_with_serialized_keys(self): - secret_key = Curve25519SecretKey() - secret_key_bytes = secret_key.to_bytes() - public_key_bytes = secret_key.public_key().to_bytes() - - d = PkDecryption.from_key(Curve25519SecretKey.from_bytes(secret_key_bytes)) - e = PkEncryption.from_key(Curve25519PublicKey.from_bytes(public_key_bytes)) - - decoded = d.decrypt(e.encrypt(CLEARTEXT)) - assert decoded == CLEARTEXT - - def test_encrypt_message_attr(self): - """Test that the Message object has accessible Python attributes (mac, ciphertext, ephemeral_key).""" - decryption = PkDecryption() - encryption = PkEncryption.from_key(decryption.public_key) - - message = encryption.encrypt(CLEARTEXT) - - assert message.mac is not None - assert message.ciphertext is not None - assert message.ephemeral_key is not None - - - def test_message_from_invalid_base64(self): - """Test that invalid base64 input raises PkDecodeException.""" - # Test invalid ciphertext - with pytest.raises(PkDecodeException, match="Invalid symbol"): - Message.from_base64( - "not-valid-base64!@#", # Invalid base64 for ciphertext - base64.b64encode(b"some_mac").decode(), # Valid base64 - base64.b64encode(b"some_key").decode() # Valid base64 - ) - - # Test invalid mac - with pytest.raises(PkDecodeException, match="Invalid symbol"): - Message.from_base64( - base64.b64encode(b"some_text").decode(), - "not-valid-base64!@#", # Invalid base64 for mac - base64.b64encode(b"some_key").decode() - ) - - # Test invalid ephemeral key - with pytest.raises(PkDecodeException, match="Invalid symbol"): - Message.from_base64( - base64.b64encode(b"some_text").decode(), - base64.b64encode(b"some_mac").decode(), - "not-valid-base64!@#" # Invalid base64 for ephemeral key - ) + +@pytest.fixture(scope="module") +def pk_decryption() -> PkDecryption: + return PkDecryption() + +@pytest.fixture(scope="module") +def pk_encryption(pk_decryption: PkDecryption) -> PkEncryption: + return PkEncryption.from_key(pk_decryption.public_key) + +@pytest.fixture(scope="module") +def secret_key() -> Curve25519SecretKey: + return Curve25519SecretKey() + +@given(cleartext=...) +def test_round_trip(pk_decryption: PkDecryption, pk_encryption: PkEncryption, cleartext: bytes): + assert cleartext == pk_decryption.decrypt(pk_encryption.encrypt(cleartext)) + +@given(cleartext=...) +def test_wrong_key(pk_decryption: PkDecryption, pk_encryption: PkEncryption, cleartext: bytes): + with pytest.raises(PkDecodeException, match="MAC tag mismatch"): + PkDecryption().decrypt(pk_encryption.encrypt(cleartext)) + +@given(cleartext=...) +def test_serialized_keys(secret_key: Curve25519SecretKey, cleartext: bytes): + secret_key_bytes = secret_key.to_bytes() + public_key_bytes = secret_key.public_key().to_bytes() + + d = PkDecryption.from_key(Curve25519SecretKey.from_bytes(secret_key_bytes)) + e = PkEncryption.from_key(Curve25519PublicKey.from_bytes(public_key_bytes)) + + assert cleartext == d.decrypt(e.encrypt(cleartext)) + +@given(cleartext=...) +def test_encrypt_message_attr(cleartext: bytes): + """Test that the Message object has accessible Python attributes (mac, ciphertext, ephemeral_key).""" + decryption = PkDecryption() + encryption = PkEncryption.from_key(decryption.public_key) + + message = encryption.encrypt(cleartext) + + assert message.mac is not None + assert message.ciphertext is not None + assert message.ephemeral_key is not None + + +def test_message_from_invalid_base64(): + """Test that invalid base64 input raises PkDecodeException.""" + # Test invalid ciphertext + with pytest.raises(PkDecodeException, match="Invalid symbol"): + Message.from_base64( + "not-valid-base64!@#", # Invalid base64 for ciphertext + b64encode(b"some_mac").decode(), # Valid base64 + b64encode(b"some_key").decode() # Valid base64 + ) + + # Test invalid mac + with pytest.raises(PkDecodeException, match="Invalid symbol"): + Message.from_base64( + b64encode(b"some_text").decode(), + "not-valid-base64!@#", # Invalid base64 for mac + b64encode(b"some_key").decode() + ) + + # Test invalid ephemeral key + with pytest.raises(PkDecodeException, match="Invalid symbol"): + Message.from_base64( + b64encode(b"some_text").decode(), + b64encode(b"some_mac").decode(), + "not-valid-base64!@#" # Invalid base64 for ephemeral key + ) diff --git a/tests/sas_test.py b/tests/sas_test.py index 6ee5b77..206ee12 100644 --- a/tests/sas_test.py +++ b/tests/sas_test.py @@ -1,48 +1,44 @@ -import pytest - -from vodozemac import SasException, Sas - -MESSAGE = "Test message" -EXTRA_INFO = "extra_info" +from typing import Final +import pytest +from vodozemac import Sas, EstablishedSas +from vodozemac.vodozemac import Curve25519PublicKey -class TestClass(object): - def test_sas_creation(self): - sas = Sas() - assert sas.public_key - - def test_other_key_setting(self): - sas_alice = Sas() - sas_bob = Sas() +EXTRA_INFO: Final[str] = "extra_info" - established = sas_alice.diffie_hellman(sas_bob.public_key) +@pytest.fixture +def alice_sas() -> Sas: + return Sas() - def test_bytes_generating(self): - sas_alice = Sas() - sas_bob = Sas() +@pytest.fixture +def bob_sas() -> Sas: + return Sas() - bob_public_key = sas_bob.public_key - sas_bob = sas_bob.diffie_hellman(sas_alice.public_key) - sas_alice = sas_alice.diffie_hellman(bob_public_key) +@pytest.fixture +def alice_established_sas(alice_sas: Sas, bob_sas: Sas) -> EstablishedSas: + return alice_sas.diffie_hellman(bob_sas.public_key) - alice_bytes = sas_alice.bytes(EXTRA_INFO) - bob_bytes = sas_bob.bytes(EXTRA_INFO) +@pytest.fixture +def bob_established_sas(alice_sas: Sas, bob_sas: Sas) -> EstablishedSas: + return bob_sas.diffie_hellman(alice_sas.public_key) - assert alice_bytes.emoji_indices == bob_bytes.emoji_indices - assert alice_bytes.decimals == bob_bytes.decimals +def test_creation(alice_sas: Sas, alice_established_sas): + assert isinstance(alice_sas.public_key, Curve25519PublicKey) + assert isinstance(alice_established_sas, EstablishedSas) - def test_mac_generating(self): - sas_alice = Sas() - sas_bob = Sas() +def test_bytes_generating(alice_sas: Sas, bob_sas: Sas): + alice_bytes = alice_sas.diffie_hellman(bob_sas.public_key).bytes(info=EXTRA_INFO) + bob_bytes = bob_sas.diffie_hellman(alice_sas.public_key).bytes(info=EXTRA_INFO) - bob_public_key = sas_bob.public_key - sas_bob = sas_bob.diffie_hellman(sas_alice.public_key) - sas_alice = sas_alice.diffie_hellman(bob_public_key) + assert alice_bytes.emoji_indices == bob_bytes.emoji_indices + assert alice_bytes.decimals == bob_bytes.decimals - alice_mac = sas_alice.calculate_mac(MESSAGE, EXTRA_INFO) - bob_mac = sas_bob.calculate_mac(MESSAGE, EXTRA_INFO) +def test_mac_generating(alice_established_sas: EstablishedSas, bob_established_sas: EstablishedSas): + message = "Test message" + alice_mac = alice_established_sas.calculate_mac(message, EXTRA_INFO) + bob_mac = bob_established_sas.calculate_mac(message, EXTRA_INFO) - sas_alice.verify_mac(MESSAGE, EXTRA_INFO, bob_mac) - sas_bob.verify_mac(MESSAGE, EXTRA_INFO, alice_mac) + assert alice_established_sas.verify_mac(message, EXTRA_INFO, bob_mac) is None + assert bob_established_sas.verify_mac(message, EXTRA_INFO, alice_mac) is None - assert alice_mac == bob_mac + assert alice_mac == bob_mac diff --git a/tests/session_test.py b/tests/session_test.py index a31479c..f77faeb 100644 --- a/tests/session_test.py +++ b/tests/session_test.py @@ -1,145 +1,115 @@ -import pytest +from collections.abc import Generator +import pytest from vodozemac import ( Account, AnyOlmMessage, DecodeException, Session, PickleException, - KeyException, + PreKeyMessage ) -PICKLE_KEY = b"DEFAULT_PICKLE_KEY_1234567890___" - - -class TestClass(object): - def _create_session(self): - alice = Account() - bob = Account() - bob.generate_one_time_keys(1) - - identity_key = bob.curve25519_key - one_time_key = list(bob.one_time_keys.values())[0] - - session = alice.create_outbound_session(identity_key, one_time_key) - - return alice, bob, session - - def test_session_create(self): - _, _, session_1 = self._create_session() - _, _, session_2 = self._create_session() - assert session_1 - assert session_2 - assert session_1.session_id != session_2.session_id - assert isinstance(session_1.session_id, str) - - def test_session_clear(self): - _, _, session = self._create_session() - del session - - def test_session_pickle(self): - alice, bob, session = self._create_session() - unpickled = Session.from_pickle(session.pickle(PICKLE_KEY), PICKLE_KEY) - assert unpickled.session_id == session.session_id +@pytest.fixture(scope="module") +def alice() -> Account: + return Account() - def test_session_invalid_pickle(self): - with pytest.raises(PickleException): - Session.from_pickle("", PICKLE_KEY) +@pytest.fixture(scope="module") +def bob() -> Account: + return Account() - def test_wrong_passphrase_pickle(self): - alice, bob, session = self._create_session() - pickle_key = b"It's a secret to everybody 12345" - pickle = session.pickle(pickle_key) +type SessionGenerator = Generator[Session] - with pytest.raises(PickleException): - Session.from_pickle(pickle, PICKLE_KEY) +@pytest.fixture +def alice_session_gen(alice: Account, bob: Account) -> SessionGenerator: + def session_generator() -> SessionGenerator: + while True: + bob.generate_one_time_keys(1) + identity_key = bob.curve25519_key + one_time_key = next(iter(bob.one_time_keys.values())) + yield alice.create_outbound_session(identity_key, one_time_key) + return session_generator() - def test_encrypt(self): - plaintext = b"It's a secret to everybody" - alice, bob, session = self._create_session() - message = session.encrypt(plaintext) +@pytest.fixture +def alice_session(alice_session_gen: SessionGenerator) -> Session: + return next(alice_session_gen) - message = message.to_pre_key() - assert message != None +def test_create(alice_session_gen: SessionGenerator): + session1, session2 = next(alice_session_gen), next(alice_session_gen) + assert session1.session_id != session2.session_id + for session in (session1, session2): + assert isinstance(session, Session) + assert isinstance(session.session_id, str) - (bob_session, decrypted) = bob.create_inbound_session( - alice.curve25519_key, message - ) - assert plaintext == decrypted +def test_clear(alice_session: Session): + del alice_session - def test_empty_message(self): - with pytest.raises(DecodeException): - AnyOlmMessage.from_parts(0, b"x") +def test_pickle(alice_session: Session, pickle_key: bytes): + unpickled = Session.from_pickle(alice_session.pickle(pickle_key), pickle_key) + assert unpickled.session_id == alice_session.session_id - def test_two_messages(self): - plaintext = b"It's a secret to everybody" - alice, bob, session = self._create_session() - message = session.encrypt(plaintext) - message = message.to_pre_key() +def test_wrong_pickle_key(alice_session: Session, pickle_key: bytes): + pickle = alice_session.pickle(pickle_key) + with pytest.raises(PickleException): + Session.from_pickle(pickle, b"Definitely wrong key") - (bob_session, decrypted) = bob.create_inbound_session( - alice.curve25519_key, message - ) - assert plaintext == decrypted +def test_invalid_pickle(pickle_key: bytes): + with pytest.raises(PickleException): + Session.from_pickle("", pickle_key) - bob_plaintext = b"Grumble, Grumble" - bob_message = bob_session.encrypt(bob_plaintext) +def test_two_messages(alice: Account, bob: Account, alice_session: Session): + alice_plaintext = b"It's a secret to everybody" + alice_message = alice_session.encrypt(alice_plaintext).to_pre_key() - assert bob_plaintext == session.decrypt(bob_message) + assert isinstance(alice_message, PreKeyMessage) - def test_matches(self): - plaintext = b"It's a secret to everybody" - alice, bob, session = self._create_session() - message = session.encrypt(plaintext) - message = message.to_pre_key() + bob_session, alice_decrypted = bob.create_inbound_session(alice.curve25519_key, alice_message) + assert alice_plaintext == alice_decrypted - (bob_session, decrypted) = bob.create_inbound_session( - alice.curve25519_key, message - ) - assert plaintext == decrypted + bob_plaintext = b"Grumble, Grumble" + bob_message = bob_session.encrypt(bob_plaintext) - message2 = session.encrypt(b"Hey! Listen!") - message2 = message2.to_pre_key() + assert bob_plaintext == alice_session.decrypt(bob_message) - assert bob_session.session_matches(message2) is True +def test_empty_message(): + with pytest.raises(DecodeException): + AnyOlmMessage.from_parts(0, b"x") - def test_invalid(self): - alice, bob, session = self._create_session() - _, _, another_session = self._create_session() +def test_matches(alice: Account, bob: Account, alice_session_gen: SessionGenerator): + alice_plaintext = b"It's a secret to everybody" + alice_session1, alice_session2 = next(alice_session_gen), next(alice_session_gen) + alice_message1 = alice_session1.encrypt(alice_plaintext).to_pre_key() - message = another_session.encrypt(b"It's a secret to everybody") - message = message.to_pre_key() + bob_session, alice_decrypted = bob.create_inbound_session(alice.curve25519_key, alice_message1) + assert alice_plaintext == alice_decrypted - assert not session.session_matches(message) + alice_message2 = alice_session2.encrypt(alice_plaintext).to_pre_key() + assert bob_session.session_matches(alice_message2) is False - def test_does_not_match(self): - plaintext = b"It's a secret to everybody" - alice, bob, session = self._create_session() - message = session.encrypt(plaintext) - message = message.to_pre_key() +def test_does_not_match(alice: Account, bob: Account, alice_session: Session): + alice_plaintext = b"It's a secret to everybody" + alice_message1 = alice_session.encrypt(alice_plaintext).to_pre_key() - (bob_session, decrypted) = bob.create_inbound_session( - alice.curve25519_key, message - ) + bob_session, alice_decrypted = bob.create_inbound_session(alice.curve25519_key, alice_message1) + assert alice_plaintext == alice_decrypted - _, _, new_session = self._create_session() + alice_message2 = alice_session.encrypt(b"Hey! Listen!").to_pre_key() + assert bob_session.session_matches(alice_message2) is True - new_message = new_session.encrypt(plaintext) - new_message = new_message.to_pre_key() - assert bob_session.session_matches(new_message) is False +def test_invalid(alice_session_gen: SessionGenerator): + session1, session2 = next(alice_session_gen), next(alice_session_gen) + message = session1.encrypt(b"It's a secret to everybody").to_pre_key() - def test_message_to_parts(self): - plaintext = b"It's a secret to everybody" - alice, bob, session = self._create_session() - message = session.encrypt(plaintext) + assert not session2.session_matches(message) - (message_type, ciphertext) = message.to_parts() +def test_message_to_parts(alice: Account, bob: Account, alice_session: Session): + alice_plaintext = b"It's a secret to everybody" + encrypted = alice_session.encrypt(alice_plaintext) - message = AnyOlmMessage.from_parts(message_type, ciphertext) - message = message.to_pre_key() + alice_message = encrypted.to_pre_key() + alice_message_from_parts = AnyOlmMessage.from_parts(*encrypted.to_parts()).to_pre_key() - (bob_session, decrypted) = bob.create_inbound_session( - alice.curve25519_key, message - ) + assert alice_message.session_id() == alice_message_from_parts.session_id() - assert plaintext == decrypted + bob_session, alice_decrypted = bob.create_inbound_session(alice.curve25519_key, alice_message) + assert alice_plaintext == alice_decrypted diff --git a/uv.lock b/uv.lock index 332087f..ba215e0 100644 --- a/uv.lock +++ b/uv.lock @@ -84,6 +84,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/4d/36/2a115987e2d8c300a974597416d9de88f2444426de9571f4b59b2cca3acc/filelock-3.18.0-py3-none-any.whl", hash = "sha256:c401f4f8377c4464e6db25fff06205fd89bdd83b65eb0488ed1b160f780e21de", size = 16215, upload-time = "2025-03-14T07:11:39.145Z" }, ] +[[package]] +name = "hypothesis" +version = "6.135.21" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "attrs" }, + { name = "exceptiongroup", marker = "python_full_version < '3.11'" }, + { name = "sortedcontainers" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/4f/ca/987c8cd4bb248c1d03e0dbf2c67b174c32517ea3016adee1e927cc8e21b7/hypothesis-6.135.21.tar.gz", hash = "sha256:8f3d3dfc75248a6c0177310b083aaf4ab82fb46d7741b19cb39e2f46f5015fbb", size = 452871, upload-time = "2025-07-02T06:44:26.355Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d8/9b/092a5a1f47fc197ef3c5db0b5b6cfea51d9c5bbf4273b4e91ac21e4dc215/hypothesis-6.135.21-py3-none-any.whl", hash = "sha256:fe04ecc4c077da4f24f81ebf2bc0a07ada70a04d6466169b37bd4644044e25b8", size = 519592, upload-time = "2025-07-02T06:44:22.711Z" }, +] + [[package]] name = "iniconfig" version = "2.1.0" @@ -200,6 +214,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/29/16/c8a903f4c4dffe7a12843191437d7cd8e32751d5de349d45d3fe69544e87/pytest-8.4.1-py3-none-any.whl", hash = "sha256:539c70ba6fcead8e78eebbf1115e8b589e7565830d7d006a8723f19ac8a0afb7", size = 365474, upload-time = "2025-06-18T05:48:03.955Z" }, ] +[[package]] +name = "sortedcontainers" +version = "2.4.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e8/c4/ba2f8066cceb6f23394729afe52f3bf7adec04bf9ed2c820b39e19299111/sortedcontainers-2.4.0.tar.gz", hash = "sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88", size = 30594, upload-time = "2021-05-16T22:03:42.897Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/32/46/9cb0e58b2deb7f82b84065f37f3bffeb12413f947f9388e4cac22c4621ce/sortedcontainers-2.4.0-py2.py3-none-any.whl", hash = "sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0", size = 29575, upload-time = "2021-05-16T22:03:41.177Z" }, +] + [[package]] name = "tomli" version = "2.2.1" @@ -268,6 +291,7 @@ source = { editable = "." } [package.dev-dependencies] dev = [ + { name = "hypothesis" }, { name = "maturin" }, { name = "nox-uv" }, { name = "pytest" }, @@ -276,6 +300,7 @@ nox = [ { name = "nox-uv" }, ] tests = [ + { name = "hypothesis" }, { name = "maturin" }, { name = "pytest" }, ] @@ -284,12 +309,14 @@ tests = [ [package.metadata.requires-dev] dev = [ + { name = "hypothesis", specifier = ">=6.135.21" }, { name = "maturin", specifier = ">=1.9.0" }, { name = "nox-uv", specifier = ">=0.6.0" }, { name = "pytest", specifier = ">=8.4.1" }, ] nox = [{ name = "nox-uv", specifier = ">=0.6.0" }] tests = [ + { name = "hypothesis", specifier = ">=6.135.21" }, { name = "maturin", specifier = ">=1.9.0" }, { name = "pytest", specifier = ">=8.4.1" }, ]