From 103129bd39f6161ae722d90fa581e7490599d04e Mon Sep 17 00:00:00 2001 From: Thomas Coratger Date: Wed, 24 Dec 2025 12:09:11 +0100 Subject: [PATCH] xmss: mv KeyPair as a container --- .../testing/src/consensus_testing/keys.py | 40 ++----------------- src/lean_spec/subspecs/xmss/containers.py | 30 +++++++++++++- src/lean_spec/subspecs/xmss/interface.py | 10 ++--- 3 files changed, 36 insertions(+), 44 deletions(-) diff --git a/packages/testing/src/consensus_testing/keys.py b/packages/testing/src/consensus_testing/keys.py index 4d24f065..de9746c7 100755 --- a/packages/testing/src/consensus_testing/keys.py +++ b/packages/testing/src/consensus_testing/keys.py @@ -33,10 +33,9 @@ import tempfile import urllib.request from concurrent.futures import ProcessPoolExecutor -from dataclasses import dataclass from functools import cache, partial from pathlib import Path -from typing import TYPE_CHECKING, Iterator, Self +from typing import TYPE_CHECKING, Iterator from lean_spec.config import LEAN_ENV from lean_spec.subspecs.containers import AttestationData @@ -46,7 +45,7 @@ AttestationSignatures, ) from lean_spec.subspecs.containers.slot import Slot -from lean_spec.subspecs.xmss.containers import PublicKey, SecretKey, Signature +from lean_spec.subspecs.xmss.containers import KeyPair, PublicKey, Signature from lean_spec.subspecs.xmss.interface import ( PROD_SIGNATURE_SCHEME, TEST_SIGNATURE_SCHEME, @@ -120,39 +119,6 @@ def get_shared_key_manager(max_slot: Slot = _DEFAULT_MAX_SLOT) -> XmssKeyManager """Key lifetime in epochs (derived from DEFAULT_MAX_SLOT).""" -@dataclass(frozen=True, slots=True) -class KeyPair: - """ - Immutable XMSS key pair for a validator. - - Attributes: - public: Public key for signature verification. - secret: Secret key containing Merkle tree structures. - """ - - public: PublicKey - secret: SecretKey - - @classmethod - def from_dict(cls, data: Mapping[str, str]) -> Self: - """Deserialize from JSON-compatible dict with hex-encoded SSZ.""" - return cls( - public=PublicKey.decode_bytes(bytes.fromhex(data["public"])), - secret=SecretKey.decode_bytes(bytes.fromhex(data["secret"])), - ) - - def to_dict(self) -> dict[str, str]: - """Serialize to JSON-compatible dict with hex-encoded SSZ.""" - return { - "public": self.public.encode_bytes().hex(), - "secret": self.secret.encode_bytes().hex(), - } - - def with_secret(self, secret: SecretKey) -> KeyPair: - """Return a new KeyPair with updated secret key (for state advancement).""" - return KeyPair(public=self.public, secret=secret) - - def _get_keys_dir(scheme_name: str) -> Path: """Get the keys directory path for the given scheme.""" return Path(__file__).parent / "test_keys" / f"{scheme_name}_scheme" @@ -298,7 +264,7 @@ def sign_attestation_data( prepared = self.scheme.get_prepared_interval(sk) # Cache advanced state - self._state[validator_id] = kp.with_secret(sk) + self._state[validator_id] = kp._replace(secret=sk) # Sign hash tree root of the attestation data message = attestation_data.data_root_bytes() diff --git a/src/lean_spec/subspecs/xmss/containers.py b/src/lean_spec/subspecs/xmss/containers.py index 662494fc..4ed01d97 100644 --- a/src/lean_spec/subspecs/xmss/containers.py +++ b/src/lean_spec/subspecs/xmss/containers.py @@ -7,7 +7,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Mapping, NamedTuple from ...types import Uint64 from ...types.container import Container @@ -181,3 +181,31 @@ class SecretKey(Container): Together with `left_bottom_tree`, this provides a prepared interval of exactly `2 * sqrt(LIFETIME)` consecutive epochs. """ + + +class KeyPair(NamedTuple): + """ + Immutable XMSS key pair for a validator. + + Attributes: + public: Public key for signature verification. + secret: Secret key containing Merkle tree structures. + """ + + public: PublicKey + secret: SecretKey + + @classmethod + def from_dict(cls, data: Mapping[str, str]) -> "KeyPair": + """Deserialize from JSON-compatible dict with hex-encoded SSZ.""" + return cls( + public=PublicKey.decode_bytes(bytes.fromhex(data["public"])), + secret=SecretKey.decode_bytes(bytes.fromhex(data["secret"])), + ) + + def to_dict(self) -> dict[str, str]: + """Serialize to JSON-compatible dict with hex-encoded SSZ.""" + return { + "public": self.public.encode_bytes().hex(), + "secret": self.secret.encode_bytes().hex(), + } diff --git a/src/lean_spec/subspecs/xmss/interface.py b/src/lean_spec/subspecs/xmss/interface.py index 23880fa2..5644f834 100644 --- a/src/lean_spec/subspecs/xmss/interface.py +++ b/src/lean_spec/subspecs/xmss/interface.py @@ -24,7 +24,7 @@ TEST_CONFIG, XmssConfig, ) -from .containers import PublicKey, SecretKey, Signature +from .containers import KeyPair, PublicKey, SecretKey, Signature from .prf import PROD_PRF, TEST_PRF, Prf from .rand import PROD_RAND, TEST_RAND, Rand from .subtree import HashSubTree, combined_path, verify_path @@ -73,9 +73,7 @@ def _validate_strict_types(self) -> "GeneralizedXmssScheme": ) return self - def key_gen( - self, activation_epoch: Uint64, num_active_epochs: Uint64 - ) -> tuple[PublicKey, SecretKey]: + def key_gen(self, activation_epoch: Uint64, num_active_epochs: Uint64) -> KeyPair: """ Generates a new cryptographic key pair for a specified range of epochs. @@ -120,7 +118,7 @@ def key_gen( - Will be rounded up to at least `2 * sqrt(LIFETIME)`. Returns: - A tuple containing the `PublicKey` and `SecretKey`. + A `KeyPair` containing the public and secret keys. Note: The actual activation epoch and num_active_epochs in the returned SecretKey @@ -220,7 +218,7 @@ def key_gen( left_bottom_tree=left_bottom_tree, right_bottom_tree=right_bottom_tree, ) - return pk, sk + return KeyPair(public=pk, secret=sk) def sign(self, sk: SecretKey, epoch: Uint64, message: bytes) -> Signature: """