From b677b77e48372350ae7a31325c150f4f03c53be7 Mon Sep 17 00:00:00 2001 From: Thomas Coratger Date: Fri, 20 Feb 2026 00:08:58 +0100 Subject: [PATCH 1/5] chore: stronger typing --- .../testing/src/consensus_testing/keys.py | 26 ++-- .../test_fixtures/fork_choice.py | 5 +- .../test_fixtures/verify_signatures.py | 2 +- src/lean_spec/__main__.py | 3 +- src/lean_spec/subspecs/chain/clock.py | 9 +- .../subspecs/containers/block/block.py | 3 +- .../subspecs/containers/state/state.py | 4 +- .../subspecs/containers/validator.py | 10 +- src/lean_spec/subspecs/forkchoice/store.py | 16 +- .../networking/discovery/handshake.py | 21 +-- .../subspecs/networking/discovery/packet.py | 10 +- .../subspecs/networking/discovery/session.py | 14 +- .../networking/discovery/transport.py | 23 +-- src/lean_spec/subspecs/networking/enr/enr.py | 2 +- src/lean_spec/subspecs/networking/enr/eth2.py | 5 +- .../subspecs/networking/gossipsub/topic.py | 8 +- .../subspecs/networking/service/service.py | 5 +- src/lean_spec/subspecs/networking/types.py | 7 +- src/lean_spec/subspecs/node/node.py | 3 +- src/lean_spec/subspecs/validator/registry.py | 4 +- src/lean_spec/subspecs/validator/service.py | 49 +++--- src/lean_spec/subspecs/xmss/aggregation.py | 17 ++- src/lean_spec/subspecs/xmss/constants.py | 2 +- src/lean_spec/subspecs/xmss/containers.py | 34 +++-- src/lean_spec/subspecs/xmss/interface.py | 139 +++++++++--------- src/lean_spec/subspecs/xmss/utils.py | 28 ++-- .../devnet/ssz/test_xmss_containers.py | 5 +- tests/interop/helpers/node_runner.py | 23 +-- tests/interop/test_consensus_lifecycle.py | 8 +- tests/lean_spec/helpers/builders.py | 20 ++- .../containers/test_state_aggregation.py | 66 ++++++--- .../containers/test_state_justified_slots.py | 9 +- .../forkchoice/test_attestation_target.py | 3 +- .../forkchoice/test_store_attestations.py | 14 +- .../forkchoice/test_time_management.py | 11 +- .../subspecs/forkchoice/test_validator.py | 7 +- .../subspecs/networking/discovery/conftest.py | 6 +- .../networking/discovery/test_handshake.py | 46 +++--- .../networking/discovery/test_integration.py | 24 +-- .../networking/discovery/test_messages.py | 8 +- .../networking/discovery/test_packet.py | 18 +-- .../networking/discovery/test_routing.py | 14 +- .../networking/discovery/test_service.py | 32 ++-- .../networking/discovery/test_session.py | 15 +- .../networking/discovery/test_transport.py | 48 +++--- .../networking/discovery/test_vectors.py | 14 +- .../subspecs/networking/enr/test_enr.py | 123 ++++++++-------- .../subspecs/networking/enr/test_eth2.py | 5 +- .../networking/gossipsub/test_gossipsub.py | 5 +- .../subspecs/networking/test_peer.py | 8 +- .../subspecs/validator/test_registry.py | 4 +- .../subspecs/validator/test_service.py | 29 ++-- .../lean_spec/subspecs/xmss/test_interface.py | 97 ++++++------ .../subspecs/xmss/test_ssz_serialization.py | 49 +++--- tests/lean_spec/subspecs/xmss/test_utils.py | 20 +-- 55 files changed, 626 insertions(+), 554 deletions(-) diff --git a/packages/testing/src/consensus_testing/keys.py b/packages/testing/src/consensus_testing/keys.py index 5709cebe..9e91bc19 100755 --- a/packages/testing/src/consensus_testing/keys.py +++ b/packages/testing/src/consensus_testing/keys.py @@ -314,16 +314,16 @@ def sign_attestation_data( Raises: ValueError: If slot exceeds key lifetime. """ - epoch = attestation_data.slot + slot = attestation_data.slot kp = self[validator_id] sk = kp.secret - # Advance key state until epoch is in prepared interval + # Advance key state until slot is in prepared interval prepared = self.scheme.get_prepared_interval(sk) - while int(epoch) not in prepared: + while int(slot) not in prepared: activation = self.scheme.get_activation_interval(sk) if prepared.stop >= activation.stop: - raise ValueError(f"Epoch {epoch} exceeds key lifetime {activation.stop}") + raise ValueError(f"Slot {slot} exceeds key lifetime {activation.stop}") sk = self.scheme.advance_preparation(sk) prepared = self.scheme.get_prepared_interval(sk) @@ -332,7 +332,7 @@ def sign_attestation_data( # Sign hash tree root of the attestation data message = attestation_data.data_root_bytes() - return self.scheme.sign(sk, epoch, message) + return self.scheme.sign(sk, slot, message) def build_attestation_signatures( self, @@ -351,7 +351,7 @@ def build_attestation_signatures( for agg in aggregated_attestations: validator_ids = agg.aggregation_bits.to_validator_indices() message = agg.data.data_root_bytes() - epoch = agg.data.slot + slot = agg.data.slot public_keys: list[PublicKey] = [self.get_public_key(vid) for vid in validator_ids] signatures: list[Signature] = [ @@ -370,7 +370,7 @@ def build_attestation_signatures( public_keys=public_keys, signatures=signatures, message=message, - epoch=epoch, + slot=slot, ) proofs.append(proof) @@ -378,11 +378,11 @@ def build_attestation_signatures( def _generate_single_keypair( - scheme: GeneralizedXmssScheme, num_epochs: int, index: int + scheme: GeneralizedXmssScheme, num_slots: int, index: int ) -> dict[str, str]: """Generate one key pair (module-level for pickling in ProcessPoolExecutor).""" print(f"Starting key #{index} generation...") - pk, sk = scheme.key_gen(Uint64(0), Uint64(num_epochs)) + pk, sk = scheme.key_gen(Slot(0), Uint64(num_slots)) return KeyPair(public=pk, secret=sk).to_dict() @@ -397,20 +397,20 @@ def _generate_keys(lean_env: str, count: int, max_slot: int) -> None: Args: lean_env: Name of the XMSS signature scheme to use (e.g. "test" or "prod"). count: Number of validators. - max_slot: Maximum slot (key lifetime = max_slot + 1 epochs). + max_slot: Maximum slot (key lifetime = max_slot + 1 slots). """ scheme = LEAN_ENV_TO_SCHEMES[lean_env] keys_dir = get_keys_dir(lean_env) - num_epochs = max_slot + 1 + num_slots = max_slot + 1 num_workers = os.cpu_count() or 1 print( f"Generating {count} XMSS key pairs for {lean_env} environment " - f"({num_epochs} epochs) using {num_workers} cores..." + f"({num_slots} slots) using {num_workers} cores..." ) with ProcessPoolExecutor(max_workers=num_workers) as executor: - worker_func = partial(_generate_single_keypair, scheme, num_epochs) + worker_func = partial(_generate_single_keypair, scheme, num_slots) key_pairs = list(executor.map(worker_func, range(count))) # Create keys directory (remove old one if it exists) diff --git a/packages/testing/src/consensus_testing/test_fixtures/fork_choice.py b/packages/testing/src/consensus_testing/test_fixtures/fork_choice.py index 736adf85..4f67712b 100644 --- a/packages/testing/src/consensus_testing/test_fixtures/fork_choice.py +++ b/packages/testing/src/consensus_testing/test_fixtures/fork_choice.py @@ -11,6 +11,7 @@ from pydantic import model_validator +from lean_spec.subspecs.chain.clock import Interval from lean_spec.subspecs.chain.config import ( INTERVALS_PER_SLOT, MILLISECONDS_PER_INTERVAL, @@ -246,7 +247,7 @@ def make_fixture(self) -> Self: # TickStep.time is a Unix timestamp in seconds. # Convert to intervals since genesis for the store. delta_ms = (Uint64(step.time) - store.config.genesis_time) * Uint64(1000) - target_interval = delta_ms // MILLISECONDS_PER_INTERVAL + target_interval = Interval(delta_ms // MILLISECONDS_PER_INTERVAL) store, _ = store.on_tick( target_interval, has_proposal=False, is_aggregator=True ) @@ -276,7 +277,7 @@ def make_fixture(self) -> Self: # Store rejects blocks from the future. # This tick includes a block (has proposal). # Always act as aggregator to ensure gossip signatures are aggregated - target_interval = block.slot * INTERVALS_PER_SLOT + target_interval = Interval(block.slot * INTERVALS_PER_SLOT) store, _ = store.on_tick( target_interval, has_proposal=True, is_aggregator=True ) diff --git a/packages/testing/src/consensus_testing/test_fixtures/verify_signatures.py b/packages/testing/src/consensus_testing/test_fixtures/verify_signatures.py index faa84c52..0bbc614b 100644 --- a/packages/testing/src/consensus_testing/test_fixtures/verify_signatures.py +++ b/packages/testing/src/consensus_testing/test_fixtures/verify_signatures.py @@ -242,7 +242,7 @@ def _build_block_from_spec( public_keys=signer_public_keys, signatures=signer_signatures, message=data_root, - epoch=attestation_data.slot, + slot=attestation_data.slot, ) # Replace participants with claimed validator_ids (mismatch!) invalid_proof = AggregatedSignatureProof( diff --git a/src/lean_spec/__main__.py b/src/lean_spec/__main__.py index 8edde6ce..5869fefa 100644 --- a/src/lean_spec/__main__.py +++ b/src/lean_spec/__main__.py @@ -35,6 +35,7 @@ from lean_spec.subspecs.containers import Block, BlockBody, Checkpoint, State from lean_spec.subspecs.containers.block.types import AggregatedAttestations from lean_spec.subspecs.containers.slot import Slot +from lean_spec.subspecs.containers.validator import SubnetId from lean_spec.subspecs.forkchoice import Store from lean_spec.subspecs.genesis import GenesisConfig from lean_spec.subspecs.networking.client import LiveNetworkEventSource @@ -489,7 +490,7 @@ async def run_node( # Subscribe to attestation subnet topics based on local validator id. validator_id = validator_registry.primary_index() if validator_registry else None if validator_id is None: - subnet_id = 0 + subnet_id = SubnetId(0) logger.info("No local validator id; subscribing to attestation subnet %d", subnet_id) else: subnet_id = validator_id.compute_subnet_id(ATTESTATION_COMMITTEE_COUNT) diff --git a/src/lean_spec/subspecs/chain/clock.py b/src/lean_spec/subspecs/chain/clock.py index aaa6410e..51a371e0 100644 --- a/src/lean_spec/subspecs/chain/clock.py +++ b/src/lean_spec/subspecs/chain/clock.py @@ -17,8 +17,9 @@ from .config import MILLISECONDS_PER_INTERVAL, MILLISECONDS_PER_SLOT, SECONDS_PER_SLOT -Interval = Uint64 -"""Interval count since genesis (matches ``Store.time``).""" + +class Interval(Uint64): + """Interval count since genesis (matches ``Store.time``).""" @dataclass(frozen=True, slots=True) @@ -57,7 +58,7 @@ def current_slot(self) -> Slot: def current_interval(self) -> Interval: """Get the current interval within the slot (0-4).""" milliseconds_into_slot = self._milliseconds_since_genesis() % MILLISECONDS_PER_SLOT - return milliseconds_into_slot // MILLISECONDS_PER_INTERVAL + return Interval(milliseconds_into_slot // MILLISECONDS_PER_INTERVAL) def total_intervals(self) -> Interval: """ @@ -65,7 +66,7 @@ def total_intervals(self) -> Interval: This is the value expected by our store time type. """ - return self._milliseconds_since_genesis() // MILLISECONDS_PER_INTERVAL + return Interval(self._milliseconds_since_genesis() // MILLISECONDS_PER_INTERVAL) def current_time(self) -> Uint64: """Get current wall-clock time as Uint64 (Unix timestamp in seconds).""" diff --git a/src/lean_spec/subspecs/containers/block/block.py b/src/lean_spec/subspecs/containers/block/block.py index 7cb3ae95..7900ceba 100644 --- a/src/lean_spec/subspecs/containers/block/block.py +++ b/src/lean_spec/subspecs/containers/block/block.py @@ -164,12 +164,11 @@ def verify_signatures( public_keys = [validators[vid].get_pubkey() for vid in validator_ids] # Verify the aggregated signature against all public keys. - # Uses slot as epoch for XMSS one-time signature indexing. try: aggregated_signature.verify( public_keys=public_keys, message=attestation_data_root, - epoch=aggregated_attestation.data.slot, + slot=aggregated_attestation.data.slot, ) except AggregationError as exc: raise AssertionError( diff --git a/src/lean_spec/subspecs/containers/state/state.py b/src/lean_spec/subspecs/containers/state/state.py index a494f0fe..0d94fd85 100644 --- a/src/lean_spec/subspecs/containers/state/state.py +++ b/src/lean_spec/subspecs/containers/state/state.py @@ -867,7 +867,7 @@ def aggregate_gossip_signatures( public_keys=gossip_keys, signatures=gossip_sigs, message=data_root, - epoch=data.slot, + slot=data.slot, ) attestation = AggregatedAttestation(aggregation_bits=participants, data=data) results.append((attestation, proof)) @@ -915,7 +915,7 @@ def select_aggregated_proofs( ) # validators contributed to this attestation # Validators that are missing in the current aggregation are put into remaining. - remaining: set[Uint64] = set(validator_ids) + remaining: set[ValidatorIndex] = set(validator_ids) # Fallback to existing proofs # diff --git a/src/lean_spec/subspecs/containers/validator.py b/src/lean_spec/subspecs/containers/validator.py index e2f2492c..440a7e53 100644 --- a/src/lean_spec/subspecs/containers/validator.py +++ b/src/lean_spec/subspecs/containers/validator.py @@ -9,6 +9,10 @@ from .slot import Slot +class SubnetId(Uint64): + """Subnet identifier (0-63) for attestation subnet partitioning.""" + + class ValidatorIndex(Uint64): """Represents a validator's unique index as a 64-bit unsigned integer.""" @@ -24,16 +28,16 @@ def is_valid(self, num_validators: int) -> bool: """Check if this index is within valid bounds for a registry of given size.""" return int(self) < num_validators - def compute_subnet_id(self, num_committees: "int | Uint64") -> int: + def compute_subnet_id(self, num_committees: int) -> SubnetId: """Compute the attestation subnet id for this validator. Args: num_committees: Positive number of committees. Returns: - An integer subnet id in 0..(num_committees-1). + A SubnetId in 0..(num_committees-1). """ - return int(self) % int(num_committees) + return SubnetId(int(self) % int(num_committees)) class ValidatorIndices(SSZList[ValidatorIndex]): diff --git a/src/lean_spec/subspecs/forkchoice/store.py b/src/lean_spec/subspecs/forkchoice/store.py index 76d31ed2..b60bea00 100644 --- a/src/lean_spec/subspecs/forkchoice/store.py +++ b/src/lean_spec/subspecs/forkchoice/store.py @@ -9,6 +9,7 @@ import copy from collections import defaultdict +from lean_spec.subspecs.chain.clock import Interval from lean_spec.subspecs.chain.config import ( ATTESTATION_COMMITTEE_COUNT, INTERVALS_PER_SLOT, @@ -62,7 +63,7 @@ class Store(Container): - or when the head is recomputed. """ - time: Uint64 + time: Interval """Current time in intervals since genesis.""" config: Config @@ -214,7 +215,7 @@ def get_forkchoice_store( # but the Store must treat the anchor block as the justified/finalized point. return cls( - time=Uint64(anchor_slot * INTERVALS_PER_SLOT), + time=Interval(anchor_slot * INTERVALS_PER_SLOT), config=anchor_state.config, head=anchor_root, safe_target=anchor_root, @@ -476,7 +477,7 @@ def on_gossip_aggregated_attestation( proof.verify( public_keys=public_keys, message=data.data_root_bytes(), - epoch=data.slot, + slot=data.slot, ) except AggregationError as exc: raise AssertionError( @@ -1117,7 +1118,7 @@ def tick_interval( Tuple of (new store with advanced time, list of new signed aggregated attestation). """ # Advance time by one interval - store = self.model_copy(update={"time": self.time + Uint64(1)}) + store = self.model_copy(update={"time": Interval(int(self.time) + 1)}) current_interval = store.time % INTERVALS_PER_SLOT new_aggregates: list[SignedAggregatedAttestation] = [] @@ -1139,7 +1140,7 @@ def tick_interval( return store, new_aggregates def on_tick( - self, target_interval: Uint64, has_proposal: bool, is_aggregator: bool = False + self, target_interval: Interval, has_proposal: bool, is_aggregator: bool = False ) -> tuple["Store", list[SignedAggregatedAttestation]]: """ Advance forkchoice store time to given interval count. @@ -1163,7 +1164,8 @@ def on_tick( # Tick forward one interval at a time while store.time < target_interval: # Check if proposal should be signaled for next interval - should_signal_proposal = has_proposal and (store.time + Uint64(1)) == target_interval + next_interval = Interval(int(store.time) + 1) + should_signal_proposal = has_proposal and next_interval == target_interval # Advance by one interval with appropriate signaling store, new_aggregates = store.tick_interval(should_signal_proposal, is_aggregator) @@ -1193,7 +1195,7 @@ def get_proposal_head(self, slot: Slot) -> tuple["Store", Bytes32]: Tuple of (new Store with updated time, head root for building). """ # Advance time to this slot's first interval - target_interval = Uint64(slot * INTERVALS_PER_SLOT) + target_interval = Interval(slot * INTERVALS_PER_SLOT) store, _ = self.on_tick(target_interval, True) # Process any pending attestations before proposal diff --git a/src/lean_spec/subspecs/networking/discovery/handshake.py b/src/lean_spec/subspecs/networking/discovery/handshake.py index 23f10c5c..2848ac9e 100644 --- a/src/lean_spec/subspecs/networking/discovery/handshake.py +++ b/src/lean_spec/subspecs/networking/discovery/handshake.py @@ -31,7 +31,7 @@ from threading import Lock from lean_spec.subspecs.networking.enr import ENR -from lean_spec.subspecs.networking.types import NodeId +from lean_spec.subspecs.networking.types import NodeId, SeqNumber from lean_spec.types import Bytes32, Bytes33, Bytes64 from .config import HANDSHAKE_TIMEOUT_SECS @@ -41,7 +41,7 @@ verify_id_nonce_signature, ) from .keys import derive_keys_from_pubkey -from .messages import PacketFlag +from .messages import PacketFlag, Port from .packet import ( HandshakeAuthdata, WhoAreYouAuthdata, @@ -52,6 +52,9 @@ ) from .session import Session, SessionCache +_DEFAULT_PORT = Port(0) +"""Default port value for optional port parameters.""" + MAX_PENDING_HANDSHAKES = 100 """Hard cap on concurrent pending handshakes to prevent resource exhaustion.""" @@ -97,7 +100,7 @@ class PendingHandshake: challenge_nonce: bytes | None = None """12-byte nonce from the packet that triggered WHOAREYOU.""" - remote_enr_seq: int = 0 + remote_enr_seq: SeqNumber = SeqNumber(0) """ENR seq we sent in WHOAREYOU. If 0, remote MUST include their ENR in HANDSHAKE.""" started_at: float = field(default_factory=time.time) @@ -144,7 +147,7 @@ def __init__( local_node_id: NodeId, local_private_key: bytes, local_enr_rlp: bytes, - local_enr_seq: int, + local_enr_seq: SeqNumber, session_cache: SessionCache, timeout_secs: float = HANDSHAKE_TIMEOUT_SECS, ): @@ -204,7 +207,7 @@ def create_whoareyou( self, remote_node_id: NodeId, request_nonce: bytes, - remote_enr_seq: int, + remote_enr_seq: SeqNumber, masking_iv: bytes, ) -> tuple[bytes, bytes, bytes, bytes]: """ @@ -255,7 +258,7 @@ def create_handshake_response( remote_pubkey: bytes, challenge_data: bytes, remote_ip: str = "", - remote_port: int = 0, + remote_port: Port = _DEFAULT_PORT, ) -> tuple[bytes, bytes, bytes]: """ Create a HANDSHAKE packet in response to WHOAREYOU. @@ -293,7 +296,7 @@ def create_handshake_response( # Include our ENR if the remote's known seq is stale. record = None - if int(whoareyou.enr_seq) < self._local_enr_seq: + if whoareyou.enr_seq < self._local_enr_seq: record = self._local_enr_rlp # Build authdata. @@ -338,7 +341,7 @@ def handle_handshake( remote_node_id: NodeId, handshake: HandshakeAuthdata, remote_ip: str = "", - remote_port: int = 0, + remote_port: Port = _DEFAULT_PORT, ) -> HandshakeResult: """ Process a received HANDSHAKE packet. @@ -383,7 +386,7 @@ def handle_handshake( # If we sent enr_seq=0, we signaled that we don't know the remote's ENR. # Per spec, the remote MUST include their ENR in the HANDSHAKE response # so we can verify their identity. - if remote_enr_seq == 0 and handshake.record is None: + if remote_enr_seq == SeqNumber(0) and handshake.record is None: raise HandshakeError( f"ENR required in HANDSHAKE from unknown node {remote_node_id.hex()[:16]}" ) diff --git a/src/lean_spec/subspecs/networking/discovery/packet.py b/src/lean_spec/subspecs/networking/discovery/packet.py index b52cca1e..45896c7e 100644 --- a/src/lean_spec/subspecs/networking/discovery/packet.py +++ b/src/lean_spec/subspecs/networking/discovery/packet.py @@ -33,8 +33,8 @@ import struct from dataclasses import dataclass -from lean_spec.subspecs.networking.types import NodeId -from lean_spec.types import Bytes12, Bytes16, Uint64 +from lean_spec.subspecs.networking.types import NodeId, SeqNumber +from lean_spec.types import Bytes12, Bytes16 from .config import MAX_PACKET_SIZE, MIN_PACKET_SIZE from .crypto import ( @@ -92,7 +92,7 @@ class WhoAreYouAuthdata: id_nonce: IdNonce """16-byte identity challenge nonce.""" - enr_seq: Uint64 + enr_seq: SeqNumber """Sender's last known ENR sequence for the target. 0 if unknown.""" @@ -273,7 +273,7 @@ def decode_whoareyou_authdata(authdata: bytes) -> WhoAreYouAuthdata: raise ValueError(f"Invalid WHOAREYOU authdata size: {len(authdata)}") id_nonce = IdNonce(authdata[:16]) - enr_seq = Uint64(struct.unpack(">Q", authdata[16:24])[0]) + enr_seq = SeqNumber(struct.unpack(">Q", authdata[16:24])[0]) return WhoAreYouAuthdata(id_nonce=id_nonce, enr_seq=enr_seq) @@ -342,7 +342,7 @@ def encode_message_authdata(src_id: NodeId) -> bytes: return src_id -def encode_whoareyou_authdata(id_nonce: bytes, enr_seq: int) -> bytes: +def encode_whoareyou_authdata(id_nonce: bytes, enr_seq: SeqNumber) -> bytes: """Encode WHOAREYOU packet authdata.""" if len(id_nonce) != 16: raise ValueError(f"ID nonce must be 16 bytes, got {len(id_nonce)}") diff --git a/src/lean_spec/subspecs/networking/discovery/session.py b/src/lean_spec/subspecs/networking/discovery/session.py index 9f9581e2..6708d556 100644 --- a/src/lean_spec/subspecs/networking/discovery/session.py +++ b/src/lean_spec/subspecs/networking/discovery/session.py @@ -27,6 +27,10 @@ from lean_spec.subspecs.networking.types import NodeId from .config import BOND_EXPIRY_SECS +from .messages import Port + +_DEFAULT_PORT = Port(0) +"""Default port value for optional port parameters.""" DEFAULT_SESSION_TIMEOUT_SECS = 86400 """Default session timeout (24 hours).""" @@ -71,7 +75,7 @@ def touch(self) -> None: self.last_seen = time.time() -type SessionKey = tuple[NodeId, str, int] +type SessionKey = tuple[NodeId, str, Port] """Session cache key: (node_id, ip, port). Per spec, sessions are tied to a specific UDP endpoint. @@ -101,7 +105,7 @@ class SessionCache: _lock: Lock = field(default_factory=Lock) """Thread safety lock.""" - def get(self, node_id: NodeId, ip: str = "", port: int = 0) -> Session | None: + def get(self, node_id: NodeId, ip: str = "", port: Port = _DEFAULT_PORT) -> Session | None: """ Get an active session for a node at a specific endpoint. @@ -134,7 +138,7 @@ def create( recv_key: bytes, is_initiator: bool, ip: str = "", - port: int = 0, + port: Port = _DEFAULT_PORT, ) -> Session: """ Create and store a new session. @@ -180,7 +184,7 @@ def create( return session - def remove(self, node_id: NodeId, ip: str = "", port: int = 0) -> bool: + def remove(self, node_id: NodeId, ip: str = "", port: Port = _DEFAULT_PORT) -> bool: """ Remove a session. @@ -199,7 +203,7 @@ def remove(self, node_id: NodeId, ip: str = "", port: int = 0) -> bool: return True return False - def touch(self, node_id: NodeId, ip: str = "", port: int = 0) -> bool: + def touch(self, node_id: NodeId, ip: str = "", port: Port = _DEFAULT_PORT) -> bool: """ Update the last_seen timestamp for a session. diff --git a/src/lean_spec/subspecs/networking/discovery/transport.py b/src/lean_spec/subspecs/networking/discovery/transport.py index 16c443ba..02f9aba8 100644 --- a/src/lean_spec/subspecs/networking/discovery/transport.py +++ b/src/lean_spec/subspecs/networking/discovery/transport.py @@ -24,8 +24,8 @@ from cryptography.exceptions import InvalidTag from lean_spec.subspecs.networking.enr import ENR -from lean_spec.subspecs.networking.types import NodeId -from lean_spec.types import Bytes16, Uint64 +from lean_spec.subspecs.networking.types import NodeId, SeqNumber +from lean_spec.types import Bytes16 from .codec import ( DiscoveryMessage, @@ -44,6 +44,7 @@ PacketFlag, Ping, Pong, + Port, TalkReq, TalkResp, ) @@ -186,7 +187,7 @@ def __init__( local_node_id=local_node_id, local_private_key=local_private_key, local_enr_rlp=local_enr.to_rlp(), - local_enr_seq=int(local_enr.seq), + local_enr_seq=SeqNumber(local_enr.seq), session_cache=self._session_cache, ) @@ -302,7 +303,7 @@ async def send_ping(self, dest_node_id: NodeId, dest_addr: tuple[str, int]) -> P request_id = generate_request_id() ping = Ping( request_id=request_id, - enr_seq=Uint64(self._local_enr.seq), + enr_seq=SeqNumber(self._local_enr.seq), ) response = await self._send_request(dest_node_id, dest_addr, ping) @@ -538,7 +539,7 @@ def _build_message_packet( Encoded packet bytes. """ ip, port = dest_addr - session = self._session_cache.get(dest_node_id, ip, port) + session = self._session_cache.get(dest_node_id, ip, Port(port)) authdata = encode_message_authdata(self._local_node_id) @@ -670,7 +671,7 @@ async def _handle_whoareyou( remote_pubkey=remote_pubkey, challenge_data=challenge_data, remote_ip=ip, - remote_port=port, + remote_port=Port(port), ) # Re-send the original message, now encrypted with the new session key. @@ -716,7 +717,7 @@ async def _handle_handshake( try: ip, port = addr result = self._handshake_manager.handle_handshake( - remote_node_id, handshake_authdata, remote_ip=ip, remote_port=port + remote_node_id, handshake_authdata, remote_ip=ip, remote_port=Port(port) ) logger.debug("Handshake completed with %s", remote_node_id.hex()[:16]) @@ -753,7 +754,7 @@ async def _handle_message( # Get session keyed by (node_id, ip, port). ip, port = addr - session = self._session_cache.get(remote_node_id, ip, port) + session = self._session_cache.get(remote_node_id, ip, Port(port)) if session is None: # Can't decrypt - send WHOAREYOU. await self._send_whoareyou(remote_node_id, header.nonce, addr) @@ -784,7 +785,7 @@ async def _handle_decoded_message( """Process a successfully decoded message.""" # Update session activity. ip, port = addr - self._session_cache.touch(remote_node_id, ip, port) + self._session_cache.touch(remote_node_id, ip, Port(port)) # Check if this is a response to a pending request. request_id = bytes(message.request_id) @@ -821,7 +822,7 @@ async def _send_whoareyou( # including the full ENR in the handshake response, saving bandwidth. # Fall back to 0 if unknown, which forces the remote to include their ENR. cached_enr = self._handshake_manager.get_cached_enr(remote_node_id) - remote_enr_seq = int(cached_enr.seq) if cached_enr is not None else 0 + remote_enr_seq = SeqNumber(cached_enr.seq) if cached_enr is not None else SeqNumber(0) # Generate masking IV for the WHOAREYOU packet. # @@ -879,7 +880,7 @@ async def send_response( # The requester initiated the handshake. # By the time we respond, session keys must exist. ip, port = dest_addr - session = self._session_cache.get(dest_node_id, ip, port) + session = self._session_cache.get(dest_node_id, ip, Port(port)) if session is None: logger.debug("No session for response to %s", dest_node_id.hex()[:16]) return False diff --git a/src/lean_spec/subspecs/networking/enr/enr.py b/src/lean_spec/subspecs/networking/enr/enr.py index 326ff24f..3afc188c 100644 --- a/src/lean_spec/subspecs/networking/enr/enr.py +++ b/src/lean_spec/subspecs/networking/enr/enr.py @@ -388,7 +388,7 @@ def from_rlp(cls, rlp_data: bytes) -> Self: enr = cls( signature=signature, - seq=Uint64(seq), + seq=SeqNumber(seq), pairs=pairs, ) diff --git a/src/lean_spec/subspecs/networking/enr/eth2.py b/src/lean_spec/subspecs/networking/enr/eth2.py index e8763361..5df7ab12 100644 --- a/src/lean_spec/subspecs/networking/enr/eth2.py +++ b/src/lean_spec/subspecs/networking/enr/eth2.py @@ -18,6 +18,7 @@ from typing import ClassVar +from lean_spec.subspecs.containers.validator import SubnetId from lean_spec.subspecs.networking.types import ForkDigest, Version from lean_spec.types import StrictBaseModel, Uint64 from lean_spec.types.bitfields import BaseBitvector @@ -89,9 +90,9 @@ def is_subscribed(self, subnet_id: int) -> bool: raise ValueError(f"Subnet ID must be 0-63, got {subnet_id}") return bool(self.data[subnet_id]) - def subscribed_subnets(self) -> list[int]: + def subscribed_subnets(self) -> list[SubnetId]: """List of subscribed subnet IDs.""" - return [i for i in range(self.LENGTH) if self.data[i]] + return [SubnetId(i) for i in range(self.LENGTH) if self.data[i]] def subscription_count(self) -> int: """Number of subscribed subnets.""" diff --git a/src/lean_spec/subspecs/networking/gossipsub/topic.py b/src/lean_spec/subspecs/networking/gossipsub/topic.py index c4b58ea1..6b1c5560 100644 --- a/src/lean_spec/subspecs/networking/gossipsub/topic.py +++ b/src/lean_spec/subspecs/networking/gossipsub/topic.py @@ -61,6 +61,8 @@ from dataclasses import dataclass from enum import Enum +from lean_spec.subspecs.containers.validator import SubnetId + class ForkMismatchError(ValueError): """Raised when a topic's fork_digest does not match the expected value.""" @@ -151,7 +153,7 @@ class GossipTopic: Peers must match on fork digest to exchange messages on a topic. """ - subnet_id: int | None = None + subnet_id: SubnetId | None = None """Subnet id for attestation subnet topics (required for ATTESTATION_SUBNET).""" def __str__(self) -> str: @@ -227,7 +229,7 @@ def from_string(cls, topic_str: str) -> GossipTopic: try: # Validate the subnet ID is a valid integer subnet_part = topic_name[len("attestation_") :] - subnet_id = int(subnet_part) + subnet_id = SubnetId(int(subnet_part)) return cls( kind=TopicKind.ATTESTATION_SUBNET, fork_digest=fork_digest, @@ -290,7 +292,7 @@ def committee_aggregation(cls, fork_digest: str) -> GossipTopic: return cls(kind=TopicKind.AGGREGATED_ATTESTATION, fork_digest=fork_digest) @classmethod - def attestation_subnet(cls, fork_digest: str, subnet_id: int) -> GossipTopic: + def attestation_subnet(cls, fork_digest: str, subnet_id: SubnetId) -> GossipTopic: """Create an attestation subnet topic for the given fork and subnet. Args: diff --git a/src/lean_spec/subspecs/networking/service/service.py b/src/lean_spec/subspecs/networking/service/service.py index 5fc59902..3bcd5bbd 100644 --- a/src/lean_spec/subspecs/networking/service/service.py +++ b/src/lean_spec/subspecs/networking/service/service.py @@ -29,6 +29,7 @@ from lean_spec.snappy import frame_compress from lean_spec.subspecs.containers import SignedBlockWithAttestation from lean_spec.subspecs.containers.attestation import SignedAggregatedAttestation, SignedAttestation +from lean_spec.subspecs.containers.validator import SubnetId from lean_spec.subspecs.networking.client.event_source import EventSource from lean_spec.subspecs.networking.gossipsub.topic import GossipTopic from lean_spec.subspecs.networking.peer import PeerInfo @@ -222,7 +223,9 @@ async def publish_block(self, block: SignedBlockWithAttestation) -> None: await self.event_source.publish(str(topic), compressed) logger.debug("Published block at slot %s", block.message.block.slot) - async def publish_attestation(self, attestation: SignedAttestation, subnet_id: int) -> None: + async def publish_attestation( + self, attestation: SignedAttestation, subnet_id: SubnetId + ) -> None: """ Publish an attestation to the attestation subnet gossip topic. diff --git a/src/lean_spec/subspecs/networking/types.py b/src/lean_spec/subspecs/networking/types.py index b004a3fe..eb3283bf 100644 --- a/src/lean_spec/subspecs/networking/types.py +++ b/src/lean_spec/subspecs/networking/types.py @@ -23,11 +23,10 @@ Version = Bytes4 """4-byte fork version number (e.g., 0x01000000 for Phase0).""" -SeqNumber = Uint64 -"""Sequence number used in ENR records, metadata, and ping messages.""" -SubnetId = Uint64 -"""Subnet identifier (0-63) for attestation subnet partitioning.""" +class SeqNumber(Uint64): + """Sequence number used in ENR records, metadata, and ping messages.""" + ProtocolId = str """Libp2p protocol identifier, e.g. ``/eth2/beacon_chain/req/status/1/ssz_snappy``.""" diff --git a/src/lean_spec/subspecs/node/node.py b/src/lean_spec/subspecs/node/node.py index a6a4c2ca..3372b02c 100644 --- a/src/lean_spec/subspecs/node/node.py +++ b/src/lean_spec/subspecs/node/node.py @@ -19,6 +19,7 @@ from lean_spec.subspecs.api import ApiServer, ApiServerConfig from lean_spec.subspecs.chain import SlotClock +from lean_spec.subspecs.chain.clock import Interval from lean_spec.subspecs.chain.config import ( ATTESTATION_COMMITTEE_COUNT, INTERVALS_PER_SLOT, @@ -367,7 +368,7 @@ def _try_load_from_database( # The store starts with just the head block and state. # Additional blocks can be loaded on demand or via sync. return Store( - time=store_time, + time=Interval(int(store_time)), config=head_state.config, head=head_root, safe_target=head_root, diff --git a/src/lean_spec/subspecs/validator/registry.py b/src/lean_spec/subspecs/validator/registry.py index 3eac698e..7a291486 100644 --- a/src/lean_spec/subspecs/validator/registry.py +++ b/src/lean_spec/subspecs/validator/registry.py @@ -297,7 +297,7 @@ def from_yaml( return registry @classmethod - def from_secret_keys(cls, keys: dict[int, SecretKey]) -> ValidatorRegistry: + def from_secret_keys(cls, keys: dict[ValidatorIndex, SecretKey]) -> ValidatorRegistry: """ Create registry from a dictionary of secret keys. @@ -311,5 +311,5 @@ def from_secret_keys(cls, keys: dict[int, SecretKey]) -> ValidatorRegistry: """ registry = cls() for index, secret_key in keys.items(): - registry.add(ValidatorEntry(index=ValidatorIndex(index), secret_key=secret_key)) + registry.add(ValidatorEntry(index=index, secret_key=secret_key)) return registry diff --git a/src/lean_spec/subspecs/validator/service.py b/src/lean_spec/subspecs/validator/service.py index e077bb4e..610ab092 100644 --- a/src/lean_spec/subspecs/validator/service.py +++ b/src/lean_spec/subspecs/validator/service.py @@ -111,7 +111,7 @@ class ValidatorService: _attestations_produced: int = field(default=0, repr=False) """Counter for produced attestations.""" - _attested_slots: set[int] = field(default_factory=set, repr=False) + _attested_slots: set[Slot] = field(default_factory=set, repr=False) """Slots for which we've already produced attestations (prevents duplicates).""" async def run(self) -> None: @@ -188,14 +188,13 @@ async def run(self) -> None: # we should still attest as soon as we can within the same slot. # # We track attested slots to prevent duplicate attestations. - slot_int = int(slot) logger.debug( - "ValidatorService: attestation check interval=%d slot_int=%d attested=%s", + "ValidatorService: attestation check interval=%d slot=%d attested=%s", interval, - slot_int, - slot_int in self._attested_slots, + slot, + slot in self._attested_slots, ) - if interval >= Uint64(1) and slot_int not in self._attested_slots: + if interval >= Uint64(1) and slot not in self._attested_slots: logger.debug( "ValidatorService: producing attestations for slot %d (interval %d)", slot, @@ -203,13 +202,13 @@ async def run(self) -> None: ) await self._produce_attestations(slot) logger.debug("ValidatorService: done producing attestations for slot %d", slot) - self._attested_slots.add(slot_int) + self._attested_slots.add(slot) # Prune old entries to prevent unbounded growth. # # Keep only recent slots (current slot - 4) to bound memory usage. # We never need to attest for slots that far in the past. - prune_threshold = max(0, slot_int - 4) + prune_threshold = Slot(max(0, int(slot) - 4)) self._attested_slots = {s for s in self._attested_slots if s >= prune_threshold} # Intervals 2-4 have no additional validator duties. @@ -427,8 +426,8 @@ def _sign_block( if entry is None: raise ValueError(f"No secret key for validator {validator_index}") - # Ensure the XMSS secret key is prepared for this epoch. - entry = self._ensure_prepared_for_epoch(entry, block.slot) + # Ensure the XMSS secret key is prepared for this slot. + entry = self._ensure_prepared_for_slot(entry, block.slot) proposer_signature = TARGET_SIGNATURE_SCHEME.sign( entry.secret_key, @@ -479,12 +478,12 @@ def _sign_attestation( if entry is None: raise ValueError(f"No secret key for validator {validator_index}") - # Ensure the XMSS secret key is prepared for this epoch. - entry = self._ensure_prepared_for_epoch(entry, attestation_data.slot) + # Ensure the XMSS secret key is prepared for this slot. + entry = self._ensure_prepared_for_slot(entry, attestation_data.slot) # Sign the attestation data root. # - # Uses XMSS one-time signature for the current epoch (slot). + # Uses XMSS one-time signature for the current slot. signature = TARGET_SIGNATURE_SCHEME.sign( entry.secret_key, attestation_data.slot, @@ -546,36 +545,36 @@ def _store_proposer_attestation_signature( } ) - def _ensure_prepared_for_epoch( + def _ensure_prepared_for_slot( self, entry: ValidatorEntry, - epoch: Slot, + slot: Slot, ) -> ValidatorEntry: """ - Ensure the secret key is prepared for signing at the given epoch. + Ensure the secret key is prepared for signing at the given slot. - XMSS uses a sliding window of prepared epochs. If the requested epoch + XMSS uses a sliding window of prepared slots. If the requested slot is outside this window, we advance the preparation by computing - additional bottom trees until the epoch is covered. + additional bottom trees until the slot is covered. Args: entry: Validator entry containing the secret key. - epoch: The epoch (slot) at which we need to sign. + slot: The slot at which we need to sign. Returns: The entry, possibly with an updated secret key. """ scheme = cast(GeneralizedXmssScheme, TARGET_SIGNATURE_SCHEME) - get_prepared_interval = scheme.get_prepared_interval(entry.secret_key) + prepared_interval = scheme.get_prepared_interval(entry.secret_key) - # If epoch is already in the prepared interval, no action needed. - epoch_int = int(epoch) - if epoch_int in get_prepared_interval: + # If slot is already in the prepared interval, no action needed. + slot_int = int(slot) + if slot_int in prepared_interval: return entry - # Advance preparation until the epoch is covered. + # Advance preparation until the slot is covered. secret_key = entry.secret_key - while epoch_int not in scheme.get_prepared_interval(secret_key): + while slot_int not in scheme.get_prepared_interval(secret_key): secret_key = scheme.advance_preparation(secret_key) # Update the registry with the new secret key. diff --git a/src/lean_spec/subspecs/xmss/aggregation.py b/src/lean_spec/subspecs/xmss/aggregation.py index acd553f4..809a52e7 100644 --- a/src/lean_spec/subspecs/xmss/aggregation.py +++ b/src/lean_spec/subspecs/xmss/aggregation.py @@ -13,8 +13,9 @@ ) from lean_spec.subspecs.containers.attestation import AggregationBits +from lean_spec.subspecs.containers.slot import Slot from lean_spec.subspecs.containers.validator import ValidatorIndex -from lean_spec.types import Bytes32, Uint64 +from lean_spec.types import Bytes32 from lean_spec.types.byte_arrays import ByteListMiB from lean_spec.types.container import Container @@ -61,7 +62,7 @@ class AggregatedSignatureProof(Container): it covers. The proof can verify that all participants signed the same message in the - same epoch, using a single verification operation instead of checking + same slot, using a single verification operation instead of checking each signature individually. """ @@ -78,7 +79,7 @@ def aggregate( public_keys: Sequence[PublicKey], signatures: Sequence[Signature], message: Bytes32, - epoch: Uint64, + slot: Slot, mode: str | None = None, ) -> Self: """ @@ -89,7 +90,7 @@ def aggregate( public_keys: Public keys of the signers (must match signatures order). signatures: Individual XMSS signatures to aggregate. message: The 32-byte message that was signed. - epoch: The epoch in which the signatures were created. + slot: The slot in which the signatures were created. mode: The mode to use for the aggregation (test or prod). Returns: @@ -105,7 +106,7 @@ def aggregate( [pk.encode_bytes() for pk in public_keys], [sig.encode_bytes() for sig in signatures], message, - epoch, + slot, mode=mode, ) return cls( @@ -119,7 +120,7 @@ def verify( self, public_keys: Sequence[PublicKey], message: Bytes32, - epoch: Uint64, + slot: Slot, mode: str | None = None, ) -> None: """ @@ -128,7 +129,7 @@ def verify( Args: public_keys: Public keys of the participants (order must match participants bitfield). message: The 32-byte message that was signed. - epoch: The epoch in which the signatures were created. + slot: The slot in which the signatures were created. mode: The mode to use for the verification (test or prod). Raises: @@ -141,7 +142,7 @@ def verify( [pk.encode_bytes() for pk in public_keys], message, self.proof_data.encode_bytes(), - epoch, + slot, mode=mode, ) except Exception as exc: diff --git a/src/lean_spec/subspecs/xmss/constants.py b/src/lean_spec/subspecs/xmss/constants.py index 79dd2289..38d5c4f5 100644 --- a/src/lean_spec/subspecs/xmss/constants.py +++ b/src/lean_spec/subspecs/xmss/constants.py @@ -31,7 +31,7 @@ class XmssConfig(StrictBaseModel): @property def LIFETIME(self) -> Uint64: # noqa: N802 """ - The maximum number of epochs supported by this configuration. + The maximum number of slots supported by this configuration. An individual key pair can be active for a smaller sub-range. """ diff --git a/src/lean_spec/subspecs/xmss/containers.py b/src/lean_spec/subspecs/xmss/containers.py index d2d0ea6a..c9ad8060 100644 --- a/src/lean_spec/subspecs/xmss/containers.py +++ b/src/lean_spec/subspecs/xmss/containers.py @@ -11,6 +11,8 @@ from pydantic import model_serializer +from lean_spec.subspecs.containers.slot import Slot + from ...types import Bytes32, Uint64 from ...types.container import Container from .subtree import HashSubTree @@ -52,7 +54,7 @@ class Signature(Container): A signature produced by the `sign` function. It contains all the necessary components for a verifier to confirm that a - specific message was signed by the owner of a `PublicKey` for a specific epoch. + specific message was signed by the owner of a `PublicKey` for a specific slot. SSZ Container with fields: - path: HashTreeOpening (container with siblings list) @@ -77,7 +79,7 @@ def _serialize_as_bytes(self) -> str: def verify( self, public_key: PublicKey, - epoch: "Uint64", + slot: "Slot", message: "Bytes32", scheme: "GeneralizedXmssScheme", ) -> bool: @@ -89,13 +91,13 @@ def verify( Invalid or malformed signatures return `False`. Expected exceptions: - - `ValueError` for invalid epochs, + - `ValueError` for invalid slots, - `IndexError` for malformed signatures are caught and converted to `False`. Args: public_key: The public key to verify against. - epoch: The epoch the signature corresponds to. + slot: The slot the signature corresponds to. message: The message that was supposedly signed. scheme: The XMSS scheme instance to use for verification. @@ -103,7 +105,7 @@ def verify( `True` if the signature is valid, `False` otherwise. """ try: - return scheme.verify(public_key, epoch, message, self) + return scheme.verify(public_key, slot, message, self) except (ValueError, IndexError): return False @@ -113,13 +115,13 @@ class SecretKey(Container): The private component of a key pair. **MUST BE KEPT CONFIDENTIAL.** This object contains all the secret material and pre-computed data needed to - generate signatures for any epoch within its active lifetime. + generate signatures for any slot within its active lifetime. SSZ Container with fields: - prf_key: Bytes[PRF_KEY_LENGTH] - parameter: Vector[Fp, PARAMETER_LEN] - - activation_epoch: uint64 - - num_active_epochs: uint64 + - activation_slot: uint64 + - num_active_slots: uint64 - top_tree: HashSubTree - left_bottom_tree_index: uint64 - left_bottom_tree: HashSubTree @@ -134,17 +136,17 @@ class SecretKey(Container): parameter: Parameter """The public parameter `P`, stored for convenience during signing.""" - activation_epoch: Uint64 + activation_slot: Slot """ - The first epoch for which this secret key is valid. + The first slot for which this secret key is valid. Note: With top-bottom trees, this is aligned to a multiple of `sqrt(LIFETIME)` to ensure efficient tree partitioning. """ - num_active_epochs: Uint64 + num_active_slots: Uint64 """ - The number of consecutive epochs this key can be used for. + The number of consecutive slots this key can be used for. Note: With top-bottom trees, this is rounded up to be a multiple of `sqrt(LIFETIME)`, with a minimum of `2 * sqrt(LIFETIME)`. @@ -162,7 +164,7 @@ class SecretKey(Container): """ The index of the left bottom tree in the sliding window. - Bottom trees are numbered 0, 1, 2, ... where tree `i` covers epochs + Bottom trees are numbered 0, 1, 2, ... where tree `i` covers slots `[i * sqrt(LIFETIME), (i+1) * sqrt(LIFETIME))`. The prepared interval is: @@ -174,7 +176,7 @@ class SecretKey(Container): """ The left bottom tree in the sliding window. - This covers epochs: + This covers slots: [left_bottom_tree_index * sqrt(LIFETIME), (left_bottom_tree_index + 1) * sqrt(LIFETIME)) """ @@ -182,11 +184,11 @@ class SecretKey(Container): """ The right bottom tree in the sliding window. - This covers epochs: + This covers slots: [(left_bottom_tree_index + 1) * sqrt(LIFETIME), (left_bottom_tree_index + 2) * sqrt(LIFETIME)) Together with `left_bottom_tree`, this provides a prepared interval of - exactly `2 * sqrt(LIFETIME)` consecutive epochs. + exactly `2 * sqrt(LIFETIME)` consecutive slots. """ diff --git a/src/lean_spec/subspecs/xmss/interface.py b/src/lean_spec/subspecs/xmss/interface.py index 93eb7245..ceef26a8 100644 --- a/src/lean_spec/subspecs/xmss/interface.py +++ b/src/lean_spec/subspecs/xmss/interface.py @@ -11,6 +11,7 @@ from pydantic import model_validator from lean_spec.config import LEAN_ENV +from lean_spec.subspecs.containers.slot import Slot from lean_spec.subspecs.xmss.target_sum import ( PROD_TARGET_SUM_ENCODER, TEST_TARGET_SUM_ENCODER, @@ -73,9 +74,9 @@ def _validate_strict_types(self) -> "GeneralizedXmssScheme": ) return self - def key_gen(self, activation_epoch: Uint64, num_active_epochs: Uint64) -> KeyPair: + def key_gen(self, activation_slot: Slot, num_active_slots: Uint64) -> KeyPair: """ - Generates a new cryptographic key pair for a specified range of epochs. + Generates a new cryptographic key pair for a specified range of slots. This is a **randomized** algorithm that establishes a signer's identity using the memory-efficient Top-Bottom Tree Traversal approach. @@ -85,14 +86,14 @@ def key_gen(self, activation_epoch: Uint64, num_active_epochs: Uint64) -> KeyPai 1. **Expand Activation Time**: Align the requested activation interval to `sqrt(LIFETIME)` boundaries to enable efficient tree partitioning. This ensures the interval starts at a multiple of `sqrt(LIFETIME)` and - has a minimum duration of `2 * sqrt(LIFETIME)` epochs. + has a minimum duration of `2 * sqrt(LIFETIME)` slots. 2. **Generate Master Secrets**: Generate PRF key and public parameter `P`. The PRF key allows deterministic on-demand regeneration of one-time keys. 3. **Generate First Two Bottom Trees**: Create the first two bottom trees - (covering the initial `2 * sqrt(LIFETIME)` epochs) and keep them in memory. - Each bottom tree covers `sqrt(LIFETIME)` consecutive epochs. + (covering the initial `2 * sqrt(LIFETIME)` slots) and keep them in memory. + Each bottom tree covers `sqrt(LIFETIME)` consecutive slots. 4. **Generate Remaining Bottom Tree Roots**: For all other bottom trees in the range, generate only their roots (not the full trees). This saves @@ -107,21 +108,21 @@ def key_gen(self, activation_epoch: Uint64, num_active_epochs: Uint64) -> KeyPai Traditional approach: O(LIFETIME) memory Top-Bottom approach: O(sqrt(LIFETIME)) memory - For LOG_LIFETIME=32 (2^32 epochs): + For LOG_LIFETIME=32 (2^32 slots): - Traditional: ~hundreds of GiB - Top-Bottom: much more reasonable Args: - activation_epoch: The starting epoch for which this key is valid. + activation_slot: The starting slot for which this key is valid. - Will be aligned downward to `sqrt(LIFETIME)` boundary. - num_active_epochs: The number of consecutive epochs the key can be used for. + num_active_slots: The number of consecutive slots the key can be used for. - Will be rounded up to at least `2 * sqrt(LIFETIME)`. Returns: A `KeyPair` containing the public and secret keys. Note: - The actual activation epoch and num_active_epochs in the returned SecretKey + The actual activation slot and num_active_slots in the returned SecretKey may be larger than requested due to alignment requirements. For the formal specification of this process, please refer to: @@ -133,7 +134,7 @@ def key_gen(self, activation_epoch: Uint64, num_active_epochs: Uint64) -> KeyPai config = self.config # Ensure the requested activation range is within the scheme's total supported lifetime. - if activation_epoch + num_active_epochs > config.LIFETIME: + if activation_slot + num_active_slots > config.LIFETIME: raise ValueError("Activation range exceeds the key's lifetime.") # Generate the random public parameter `P` and the master PRF key. @@ -144,15 +145,15 @@ def key_gen(self, activation_epoch: Uint64, num_active_epochs: Uint64) -> KeyPai # Step 1: Expand and align activation time to sqrt(LIFETIME) boundaries. start_bottom_tree_index, end_bottom_tree_index = expand_activation_time( - config.LOG_LIFETIME, int(activation_epoch), int(num_active_epochs) + config.LOG_LIFETIME, int(activation_slot), int(num_active_slots) ) num_bottom_trees = end_bottom_tree_index - start_bottom_tree_index leaves_per_bottom_tree = 1 << (config.LOG_LIFETIME // 2) - # Calculate the actual (expanded) activation epoch and count. - actual_activation_epoch = start_bottom_tree_index * leaves_per_bottom_tree - actual_num_active_epochs = num_bottom_trees * leaves_per_bottom_tree + # Calculate the actual (expanded) activation slot and count. + actual_activation_slot = start_bottom_tree_index * leaves_per_bottom_tree + actual_num_active_slots = num_bottom_trees * leaves_per_bottom_tree # Step 2: Generate the first two bottom trees (kept in memory). left_bottom_tree = HashSubTree.from_prf_key( @@ -211,8 +212,8 @@ def key_gen(self, activation_epoch: Uint64, num_active_epochs: Uint64) -> KeyPai sk = SecretKey( prf_key=prf_key, parameter=parameter, - activation_epoch=Uint64(actual_activation_epoch), - num_active_epochs=Uint64(actual_num_active_epochs), + activation_slot=Slot(actual_activation_slot), + num_active_slots=Uint64(actual_num_active_slots), top_tree=top_tree, left_bottom_tree_index=Uint64(start_bottom_tree_index), left_bottom_tree=left_bottom_tree, @@ -220,14 +221,14 @@ def key_gen(self, activation_epoch: Uint64, num_active_epochs: Uint64) -> KeyPai ) return KeyPair(public=pk, secret=sk) - def sign(self, sk: SecretKey, epoch: Uint64, message: Bytes32) -> Signature: + def sign(self, sk: SecretKey, slot: Slot, message: Bytes32) -> Signature: """ - Produces a digital signature for a given message at a specific epoch. + Produces a digital signature for a given message at a specific slot. This is a **deterministic** algorithm. Calling `sign` twice with the same - (sk, epoch, message) triple produces the same signature. + (sk, slot, message) triple produces the same signature. - **CRITICAL SECURITY WARNING**: A secret key for a given epoch must **NEVER** be used + **CRITICAL SECURITY WARNING**: A secret key for a given slot must **NEVER** be used to sign two different messages. Doing so would reveal parts of the secret key and allow an attacker to forge signatures. This is the fundamental security property of a synchronized (stateful) signature scheme. @@ -248,12 +249,12 @@ def sign(self, sk: SecretKey, epoch: Uint64, message: Bytes32) -> Signature: The collection of these intermediate hashes forms the one-time signature. 3. **Merkle Path**: The signer retrieves the Merkle authentication path for the leaf - corresponding to the current `epoch`. This path proves that the one-time public key - for this epoch is part of the main public key (the Merkle root). + corresponding to the current `slot`. This path proves that the one-time public key + for this slot is part of the main public key (the Merkle root). Args: sk: The secret key to use for signing. - epoch: The epoch for which the signature is being created. + slot: The slot for which the signature is being created. message: The message to be signed. Returns: @@ -267,24 +268,24 @@ def sign(self, sk: SecretKey, epoch: Uint64, message: Bytes32) -> Signature: # Retrieve the scheme's configuration parameters. config = self.config - # Verify that the secret key is currently active for the requested signing epoch. - epoch_int = int(epoch) - activation_int = int(sk.activation_epoch) - if not (activation_int <= epoch_int < activation_int + int(sk.num_active_epochs)): - raise ValueError("Key is not active for the specified epoch.") + # Verify that the secret key is currently active for the requested signing slot. + slot_int = int(slot) + activation_int = int(sk.activation_slot) + if not (activation_int <= slot_int < activation_int + int(sk.num_active_slots)): + raise ValueError("Key is not active for the specified slot.") - # Verify that the epoch is within the prepared interval (covered by loaded bottom trees). + # Verify that the slot is within the prepared interval (covered by loaded bottom trees). # - # With top-bottom tree traversal, only epochs within the prepared interval can be + # With top-bottom tree traversal, only slots within the prepared interval can be # signed without computing additional bottom trees. # - # If the epoch is outside this range, we need to slide the window forward. + # If the slot is outside this range, we need to slide the window forward. leaves_per_bottom_tree = 1 << (config.LOG_LIFETIME // 2) prepared_start = int(sk.left_bottom_tree_index) * leaves_per_bottom_tree prepared_end = prepared_start + 2 * leaves_per_bottom_tree - if not (prepared_start <= epoch_int < prepared_end): + if not (prepared_start <= slot_int < prepared_end): raise ValueError( - f"Epoch {epoch} is outside the prepared interval " + f"Slot {slot} is outside the prepared interval " f"[{prepared_start}, {prepared_end}). " f"Call advance_preparation() to slide the window forward." ) @@ -295,12 +296,12 @@ def sign(self, sk: SecretKey, epoch: Uint64, message: Bytes32) -> Signature: # produces a valid codeword (i.e., one that meets the target sum constraint). # # The randomness is deterministically derived from the PRF to ensure - # that signing is reproducible for the same (sk, epoch, message). + # that signing is reproducible for the same (sk, slot, message). for attempts in range(config.MAX_TRIES): # Derive deterministic randomness `rho` from PRF using the attempt counter. - rho = self.prf.get_randomness(sk.prf_key, epoch, message, Uint64(attempts)) + rho = self.prf.get_randomness(sk.prf_key, slot, message, Uint64(attempts)) # Attempt to encode the message with the deterministic `rho`. - codeword = self.encoder.encode(sk.parameter, message, rho, epoch) + codeword = self.encoder.encode(sk.parameter, message, rho, slot) # If encoding is successful, we've found our `rho` and `codeword`. # # We can exit the loop. @@ -322,14 +323,14 @@ def sign(self, sk: SecretKey, epoch: Uint64, message: Bytes32) -> Signature: ots_hashes: list[HashDigestVector] = [] for chain_index, steps in enumerate(codeword): # Derive the secret start of the current chain using the master PRF key. - start_digest = self.prf.apply(sk.prf_key, epoch, Uint64(chain_index)) + start_digest = self.prf.apply(sk.prf_key, slot, Uint64(chain_index)) # Walk the hash chain for the number of `steps` specified by the # corresponding digit in the codeword. # # The result is one component of the OTS. ots_digest = self.hasher.hash_chain( parameter=sk.parameter, - epoch=epoch, + epoch=slot, chain_index=chain_index, start_step=0, num_steps=steps, @@ -337,24 +338,24 @@ def sign(self, sk: SecretKey, epoch: Uint64, message: Bytes32) -> Signature: ) ots_hashes.append(ots_digest) - # Retrieve the Merkle authentication path for the current epoch's leaf. + # Retrieve the Merkle authentication path for the current slot's leaf. # With top-bottom tree traversal, we use combined_path to merge paths from # the bottom tree and top tree. - # Determine which bottom tree contains this epoch (reuse leaves_per_bottom_tree from above). + # Determine which bottom tree contains this slot (reuse leaves_per_bottom_tree from above). boundary = (int(sk.left_bottom_tree_index) + 1) * leaves_per_bottom_tree - bottom_tree = sk.left_bottom_tree if epoch_int < boundary else sk.right_bottom_tree + bottom_tree = sk.left_bottom_tree if slot_int < boundary else sk.right_bottom_tree # Ensure bottom tree exists if bottom_tree is None: raise ValueError( - f"Epoch {epoch} requires bottom tree but it is not available. " + f"Slot {slot} requires bottom tree but it is not available. " f"Prepared interval may have been exceeded. Call advance_preparation() " f"to slide the window forward." ) # Generate the combined authentication path - path = combined_path(sk.top_tree, bottom_tree, epoch) + path = combined_path(sk.top_tree, bottom_tree, slot) # Assemble and return the final signature, which contains: # - The OTS, @@ -362,9 +363,9 @@ def sign(self, sk: SecretKey, epoch: Uint64, message: Bytes32) -> Signature: # - The randomness `rho` needed for verification. return Signature(path=path, rho=rho, hashes=HashDigestList(data=ots_hashes)) - def verify(self, pk: PublicKey, epoch: Uint64, message: Bytes32, sig: Signature) -> bool: + def verify(self, pk: PublicKey, slot: Slot, message: Bytes32, sig: Signature) -> bool: r""" - Verifies a digital signature against a public key, message, and epoch. + Verifies a digital signature against a public key, message, and slot. This is a **deterministic** algorithm. @@ -381,7 +382,7 @@ def verify(self, pk: PublicKey, epoch: Uint64, message: Bytes32, sig: Signature) chain's public endpoint, which is one component of the one-time public key. 3. **Compute Merkle Leaf**: The verifier hashes the full set of reconstructed - chain endpoints to compute the expected Merkle leaf for the given `epoch`. + chain endpoints to compute the expected Merkle leaf for the given slot. 4. **Verify Merkle Path**: The verifier uses the authentication `path` from the signature to compute a candidate Merkle root, starting from the leaf computed @@ -390,7 +391,7 @@ def verify(self, pk: PublicKey, epoch: Uint64, message: Bytes32, sig: Signature) Args: pk: The public key to verify against. - epoch: The epoch the signature corresponds to. + slot: The slot the signature corresponds to. message: The message that was supposedly signed. sig: The signature object to be verified. @@ -405,17 +406,17 @@ def verify(self, pk: PublicKey, epoch: Uint64, message: Bytes32, sig: Signature) # Retrieve the scheme's configuration parameters. config = self.config - # Validate epoch bounds. + # Validate slot bounds. # # Return False instead of raising to avoid panic on invalid signatures. - # The epoch is attacker-controlled input. - if epoch > self.config.LIFETIME: + # The slot is attacker-controlled input. + if slot > self.config.LIFETIME: return False # Re-encode the message using the randomness `rho` from the signature. # # If the encoding is invalid (e.g., fails the target sum check), the signature is invalid. - codeword = self.encoder.encode(pk.parameter, message, sig.rho, epoch) + codeword = self.encoder.encode(pk.parameter, message, sig.rho, slot) if codeword is None: return False @@ -429,7 +430,7 @@ def verify(self, pk: PublicKey, epoch: Uint64, message: Bytes32, sig: Signature) num_steps_remaining = config.BASE - 1 - xi end_digest = self.hasher.hash_chain( parameter=pk.parameter, - epoch=epoch, + epoch=slot, chain_index=chain_index, start_step=xi, num_steps=num_steps_remaining, @@ -440,41 +441,41 @@ def verify(self, pk: PublicKey, epoch: Uint64, message: Bytes32, sig: Signature) # Verify the Merkle path. # # This function internally: - # - Hashes the `chain_ends` to get the leaf node for the epoch, + # - Hashes the `chain_ends` to get the leaf node for the slot, # - Uses the `opening` path from the signature to compute a candidate root. # - It returns true if and only if this candidate root matches the public key's root. return verify_path( hasher=self.hasher, parameter=pk.parameter, root=pk.root, - position=epoch, + position=slot, leaf_parts=chain_ends, opening=sig.path, ) def get_activation_interval(self, sk: SecretKey) -> range: """ - Returns the epoch range for which this secret key is active. + Returns the slot range for which this secret key is active. - The activation interval is `[activation_epoch, activation_epoch + num_active_epochs)`. - A signature can only be created for an epoch within this range. + The activation interval is `[activation_slot, activation_slot + num_active_slots)`. + A signature can only be created for a slot within this range. Args: sk: The secret key to query. Returns: - A Python range object representing the valid epoch range. + A Python range object representing the valid slot range. """ - start = int(sk.activation_epoch) - end = start + int(sk.num_active_epochs) + start = int(sk.activation_slot) + end = start + int(sk.num_active_slots) return range(start, end) def get_prepared_interval(self, sk: SecretKey) -> range: """ - Returns the epoch range currently prepared (covered by loaded bottom trees). + Returns the slot range currently prepared (covered by loaded bottom trees). With top-bottom tree traversal, a secret key maintains a sliding window of - two consecutive bottom trees. This method returns the range of epochs that + two consecutive bottom trees. This method returns the range of slots that can be signed with the currently loaded trees, without needing to compute additional bottom trees. @@ -485,7 +486,7 @@ def get_prepared_interval(self, sk: SecretKey) -> range: sk: The secret key to query. Returns: - A Python range object representing the prepared epoch range. + A Python range object representing the prepared slot range. Raises: ValueError: If the secret key is missing top-bottom tree structures. @@ -505,10 +506,10 @@ def advance_preparation(self, sk: SecretKey) -> SecretKey: 3. The newly computed tree becomes the new right tree 4. Increments `left_bottom_tree_index` - After this operation, the prepared interval moves forward by `sqrt(LIFETIME)` epochs. + After this operation, the prepared interval moves forward by `sqrt(LIFETIME)` slots. - **When to call**: Call this method after signing with an epoch that is in the - right half of the prepared interval, to ensure the next epoch range is ready. + **When to call**: Call this method after signing with a slot that is in the + right half of the prepared interval, to ensure the next slot range is ready. Args: sk: The secret key to advance. @@ -523,9 +524,9 @@ def advance_preparation(self, sk: SecretKey) -> SecretKey: left_index = int(sk.left_bottom_tree_index) # Check if advancing would exceed the activation interval - next_prepared_end_epoch = (left_index + 3) * leaves_per_bottom_tree - activation_end = int(sk.activation_epoch) + int(sk.num_active_epochs) - if next_prepared_end_epoch > activation_end: + next_prepared_end_slot = (left_index + 3) * leaves_per_bottom_tree + activation_end = int(sk.activation_slot) + int(sk.num_active_slots) + if next_prepared_end_slot > activation_end: # Nothing to do - we're already at the end of the activation interval return sk diff --git a/src/lean_spec/subspecs/xmss/utils.py b/src/lean_spec/subspecs/xmss/utils.py index 35bdd745..cdde32ed 100644 --- a/src/lean_spec/subspecs/xmss/utils.py +++ b/src/lean_spec/subspecs/xmss/utils.py @@ -72,7 +72,7 @@ def int_to_base_p(value: int, num_limbs: int) -> list[Fp]: def expand_activation_time( - log_lifetime: int, desired_activation_epoch: int, desired_num_active_epochs: int + log_lifetime: int, desired_activation_slot: int, desired_num_active_slots: int ) -> tuple[int, int]: """ Expands and aligns the activation time to top-bottom tree boundaries. @@ -81,9 +81,9 @@ def expand_activation_time( `sqrt(LIFETIME)` boundaries. This function takes the user's desired activation interval and expands it to meet the following requirements: - 1. **Start alignment**: Start epoch is rounded down to a multiple of `sqrt(LIFETIME)` - 2. **End alignment**: End epoch is rounded up to a multiple of `sqrt(LIFETIME)` - 3. **Minimum duration**: At least `2 * sqrt(LIFETIME)` epochs (two bottom trees) + 1. **Start alignment**: Start slot is rounded down to a multiple of `sqrt(LIFETIME)` + 2. **End alignment**: End slot is rounded up to a multiple of `sqrt(LIFETIME)` + 3. **Minimum duration**: At least `2 * sqrt(LIFETIME)` slots (two bottom trees) 4. **Lifetime bounds**: Clamped to `[0, LIFETIME)` ### Algorithm @@ -98,32 +98,32 @@ def expand_activation_time( ### Example For `LOG_LIFETIME = 32` (LIFETIME = 2^32, C = 2^16 = 65536): - - Request: epochs [10000, 80000) → 70000 epochs - - Aligned: epochs [0, 131072) → 131072 epochs = 2 bottom trees + - Request: slots [10000, 80000) → 70000 slots + - Aligned: slots [0, 131072) → 131072 slots = 2 bottom trees Args: log_lifetime: The logarithm (base 2) of the total lifetime. - desired_activation_epoch: The user's requested first epoch. - desired_num_active_epochs: The user's requested number of epochs. + desired_activation_slot: The user's requested first slot. + desired_num_active_slots: The user's requested number of slots. Returns: A tuple `(start_bottom_tree_index, end_bottom_tree_index)` where: - `start_bottom_tree_index`: Index of the first bottom tree (0, 1, 2, ...) - `end_bottom_tree_index`: Index past the last bottom tree (exclusive) - - Actual epochs: `[start_index * C, end_index * C)` + - Actual slots: `[start_index * C, end_index * C)` """ # Calculate sqrt(LIFETIME) and the alignment mask. c = 1 << (log_lifetime // 2) # C = 2^(LOG_LIFETIME/2) c_mask = ~(c - 1) # Mask for rounding to multiples of C - # Calculate the desired end epoch. - desired_end_epoch = desired_activation_epoch + desired_num_active_epochs + # Calculate the desired end slot. + desired_end_slot = desired_activation_slot + desired_num_active_slots # Step 1: Align start downward to a multiple of C. - start = desired_activation_epoch & c_mask + start = desired_activation_slot & c_mask # Step 2: Round end upward to a multiple of C. - end = (desired_end_epoch + c - 1) & c_mask + end = (desired_end_slot + c - 1) & c_mask # Step 3: Enforce minimum duration of 2*C. if end - start < 2 * c: @@ -146,7 +146,7 @@ def expand_activation_time( start = (lifetime - duration) & c_mask # Keep alignment # Convert to bottom tree indices. - # Bottom tree i covers epochs [i*C, (i+1)*C). + # Bottom tree i covers slots [i*C, (i+1)*C). start_bottom_tree_index = start // c end_bottom_tree_index = end // c diff --git a/tests/consensus/devnet/ssz/test_xmss_containers.py b/tests/consensus/devnet/ssz/test_xmss_containers.py index 16b0cd02..06564b78 100644 --- a/tests/consensus/devnet/ssz/test_xmss_containers.py +++ b/tests/consensus/devnet/ssz/test_xmss_containers.py @@ -4,6 +4,7 @@ from consensus_testing import SSZTestFiller from lean_spec.subspecs.containers.attestation import AggregationBits +from lean_spec.subspecs.containers.slot import Slot from lean_spec.subspecs.koalabear import Fp from lean_spec.subspecs.xmss import PublicKey, SecretKey, Signature from lean_spec.subspecs.xmss.aggregation import AggregatedSignatureProof @@ -87,8 +88,8 @@ def test_secret_key_minimal(ssz: SSZTestFiller) -> None: value=SecretKey( prf_key=PRFKey.zero(), parameter=_zero_parameter(), - activation_epoch=Uint64(0), - num_active_epochs=Uint64(1), + activation_slot=Slot(0), + num_active_slots=Uint64(1), top_tree=empty_subtree, left_bottom_tree_index=Uint64(0), left_bottom_tree=empty_subtree, diff --git a/tests/interop/helpers/node_runner.py b/tests/interop/helpers/node_runner.py index c0c9e5d5..5ee9dc1e 100644 --- a/tests/interop/helpers/node_runner.py +++ b/tests/interop/helpers/node_runner.py @@ -14,6 +14,7 @@ from lean_spec.subspecs.chain.config import ATTESTATION_COMMITTEE_COUNT from lean_spec.subspecs.containers import Checkpoint, Validator +from lean_spec.subspecs.containers.slot import Slot from lean_spec.subspecs.containers.state import Validators from lean_spec.subspecs.containers.validator import ValidatorIndex from lean_spec.subspecs.forkchoice import Store @@ -187,7 +188,7 @@ class NodeCluster: _validators: Validators | None = field(default=None, repr=False) """Shared validator set.""" - _secret_keys: dict[int, SecretKey] = field(default_factory=dict, repr=False) + _secret_keys: dict[ValidatorIndex, SecretKey] = field(default_factory=dict, repr=False) """Secret keys by validator index.""" _genesis_time: int = field(default=0, repr=False) @@ -208,15 +209,15 @@ def _generate_validators(self) -> None: validators: list[Validator] = [] scheme = TARGET_SIGNATURE_SCHEME - # Use a number of active epochs within the scheme's lifetime. + # Use a number of active slots within the scheme's lifetime. # TEST_CONFIG has LOG_LIFETIME=8 -> lifetime=256. # PROD_CONFIG has LOG_LIFETIME=32 -> lifetime=2^32. - # Use the full lifetime to avoid exhausting prepared epochs during tests. - num_active_epochs = int(scheme.config.LIFETIME) + # Use the full lifetime to avoid exhausting prepared slots during tests. + num_active_slots = int(scheme.config.LIFETIME) for i in range(self.num_validators): - keypair = scheme.key_gen(Uint64(0), Uint64(num_active_epochs)) - self._secret_keys[i] = keypair.secret + keypair = scheme.key_gen(Slot(0), Uint64(num_active_slots)) + self._secret_keys[ValidatorIndex(i)] = keypair.secret pubkey_bytes = keypair.public.encode_bytes()[:52] pubkey = Bytes52(pubkey_bytes.ljust(52, b"\x00")) @@ -233,7 +234,7 @@ def _generate_validators(self) -> None: async def start_node( self, node_index: int, - validator_indices: list[int] | None = None, + validator_indices: list[ValidatorIndex] | None = None, is_aggregator: bool = False, bootnodes: list[str] | None = None, *, @@ -395,7 +396,7 @@ async def start_node( async def start_all( self, topology: list[tuple[int, int]], - validators_per_node: list[list[int]] | None = None, + validators_per_node: list[list[ValidatorIndex]] | None = None, ) -> None: """ Start multiple nodes with given topology. @@ -480,7 +481,7 @@ async def start_all( for node in self.nodes: await node.start() - def _distribute_validators(self, num_nodes: int) -> list[list[int]]: + def _distribute_validators(self, num_nodes: int) -> list[list[ValidatorIndex]]: """ Distribute validators evenly across nodes. @@ -493,9 +494,9 @@ def _distribute_validators(self, num_nodes: int) -> list[list[int]]: if num_nodes == 0: return [] - distribution: list[list[int]] = [[] for _ in range(num_nodes)] + distribution: list[list[ValidatorIndex]] = [[] for _ in range(num_nodes)] for i in range(self.num_validators): - distribution[i % num_nodes].append(i) + distribution[i % num_nodes].append(ValidatorIndex(i)) return distribution diff --git a/tests/interop/test_consensus_lifecycle.py b/tests/interop/test_consensus_lifecycle.py index 661be186..ab0671a1 100644 --- a/tests/interop/test_consensus_lifecycle.py +++ b/tests/interop/test_consensus_lifecycle.py @@ -19,6 +19,8 @@ import pytest +from lean_spec.subspecs.containers.validator import ValidatorIndex + from .helpers import ( NodeCluster, PipelineDiagnostics, @@ -77,7 +79,11 @@ async def test_consensus_lifecycle(node_cluster: NodeCluster) -> None: # One validator per node. Isolating validators ensures each node # proposes independently and attestations travel over the network. - validators_per_node = [[0], [1], [2]] + validators_per_node = [ + [ValidatorIndex(0)], + [ValidatorIndex(1)], + [ValidatorIndex(2)], + ] await node_cluster.start_all(topology, validators_per_node) diff --git a/tests/lean_spec/helpers/builders.py b/tests/lean_spec/helpers/builders.py index e03c315f..2e3dd12d 100644 --- a/tests/lean_spec/helpers/builders.py +++ b/tests/lean_spec/helpers/builders.py @@ -11,7 +11,7 @@ from consensus_testing.keys import XmssKeyManager, get_shared_key_manager -from lean_spec.subspecs.chain.clock import SlotClock +from lean_spec.subspecs.chain.clock import Interval, SlotClock from lean_spec.subspecs.chain.config import INTERVALS_PER_SLOT from lean_spec.subspecs.containers import ( Attestation, @@ -263,7 +263,7 @@ def make_signed_block( def make_aggregated_attestation( - participant_ids: list[int], + participant_ids: list[ValidatorIndex], attestation_slot: Slot, source: Checkpoint, target: Checkpoint, @@ -281,9 +281,7 @@ def make_aggregated_attestation( ) return AggregatedAttestation( - aggregation_bits=AggregationBits.from_validator_indices( - [ValidatorIndex(i) for i in participant_ids] - ), + aggregation_bits=AggregationBits.from_validator_indices(participant_ids), data=data, ) @@ -433,16 +431,16 @@ def make_store_with_gossip_signatures( def make_attestation_data_simple( - slot: int, + slot: Slot, head_root: Bytes32, target_root: Bytes32, source: Checkpoint, ) -> AttestationData: """Create attestation data with head/target roots and a source checkpoint.""" return AttestationData( - slot=Slot(slot), - head=Checkpoint(root=head_root, slot=Slot(slot)), - target=Checkpoint(root=target_root, slot=Slot(slot)), + slot=slot, + head=Checkpoint(root=head_root, slot=slot), + target=Checkpoint(root=target_root, slot=slot), source=source, ) @@ -472,7 +470,7 @@ def make_aggregated_proof( key_manager.sign_attestation_data(vid, attestation_data) for vid in participants ], message=data_root, - epoch=attestation_data.slot, + slot=attestation_data.slot, ) @@ -518,7 +516,7 @@ def make_signed_block_from_store( ), ) - target_interval = block.slot * INTERVALS_PER_SLOT + target_interval = Interval(block.slot * INTERVALS_PER_SLOT) advanced_store, _ = store.on_tick(target_interval, has_proposal=True) return advanced_store, signed_block diff --git a/tests/lean_spec/subspecs/containers/test_state_aggregation.py b/tests/lean_spec/subspecs/containers/test_state_aggregation.py index cb90dd58..aa4153ae 100644 --- a/tests/lean_spec/subspecs/containers/test_state_aggregation.py +++ b/tests/lean_spec/subspecs/containers/test_state_aggregation.py @@ -26,7 +26,9 @@ def test_aggregated_signatures_prefers_full_gossip_payload( ) -> None: state = make_keyed_genesis_state(2, container_key_manager) source = Checkpoint(root=make_bytes32(1), slot=Slot(0)) - att_data = make_attestation_data_simple(2, make_bytes32(3), make_bytes32(4), source=source) + att_data = make_attestation_data_simple( + Slot(2), make_bytes32(3), make_bytes32(4), source=source + ) attestations = [Attestation(validator_id=ValidatorIndex(i), data=att_data) for i in range(2)] data_root = att_data.data_root_bytes() gossip_signatures = { @@ -56,7 +58,7 @@ def test_aggregated_signatures_prefers_full_gossip_payload( aggregated_proofs[0].verify( public_keys=public_keys, message=data_root, - epoch=att_data.slot, + slot=att_data.slot, ) @@ -66,7 +68,9 @@ def test_aggregate_signatures_splits_when_needed( """Test that gossip and aggregated proofs are kept separate.""" state = make_keyed_genesis_state(3, container_key_manager) source = Checkpoint(root=make_bytes32(2), slot=Slot(0)) - att_data = make_attestation_data_simple(3, make_bytes32(5), make_bytes32(6), source=source) + att_data = make_attestation_data_simple( + Slot(3), make_bytes32(5), make_bytes32(6), source=source + ) attestations = [Attestation(validator_id=ValidatorIndex(i), data=att_data) for i in range(3)] data_root = att_data.data_root_bytes() gossip_signatures = { @@ -113,7 +117,7 @@ def test_aggregate_signatures_splits_when_needed( proof.verify( public_keys=[container_key_manager.get_public_key(ValidatorIndex(0))], message=data_root, - epoch=att_data.slot, + slot=att_data.slot, ) @@ -163,7 +167,7 @@ def test_build_block_collects_valid_available_attestations( aggregated_proofs[0].verify( public_keys=[container_key_manager.get_public_key(ValidatorIndex(0))], message=data_root, - epoch=att_data.slot, + slot=att_data.slot, ) @@ -222,8 +226,12 @@ def test_aggregated_signatures_with_multiple_data_groups( """Multiple attestation data groups should be processed independently.""" state = make_keyed_genesis_state(4, container_key_manager) source = Checkpoint(root=make_bytes32(22), slot=Slot(0)) - att_data1 = make_attestation_data_simple(9, make_bytes32(23), make_bytes32(24), source=source) - att_data2 = make_attestation_data_simple(10, make_bytes32(25), make_bytes32(26), source=source) + att_data1 = make_attestation_data_simple( + Slot(9), make_bytes32(23), make_bytes32(24), source=source + ) + att_data2 = make_attestation_data_simple( + Slot(10), make_bytes32(25), make_bytes32(26), source=source + ) attestations = [ Attestation(validator_id=ValidatorIndex(0), data=att_data1), @@ -268,7 +276,7 @@ def test_aggregated_signatures_with_multiple_data_groups( proof.verify( public_keys=public_keys, message=agg_att.data.data_root_bytes(), - epoch=agg_att.data.slot, + slot=agg_att.data.slot, ) @@ -278,7 +286,9 @@ def test_aggregated_signatures_falls_back_to_block_payload( """Should fall back to block payload when gossip is incomplete.""" state = make_keyed_genesis_state(2, container_key_manager) source = Checkpoint(root=make_bytes32(27), slot=Slot(0)) - att_data = make_attestation_data_simple(11, make_bytes32(28), make_bytes32(29), source=source) + att_data = make_attestation_data_simple( + Slot(11), make_bytes32(28), make_bytes32(29), source=source + ) attestations = [Attestation(validator_id=ValidatorIndex(i), data=att_data) for i in range(2)] data_root = att_data.data_root_bytes() @@ -320,7 +330,7 @@ def test_aggregated_signatures_falls_back_to_block_payload( proof.verify( public_keys=[container_key_manager.get_public_key(ValidatorIndex(0))], message=data_root, - epoch=att_data.slot, + slot=att_data.slot, ) @@ -433,7 +443,9 @@ def test_greedy_selects_proof_with_maximum_overlap( """ state = make_keyed_genesis_state(4, container_key_manager) source = Checkpoint(root=make_bytes32(60), slot=Slot(0)) - att_data = make_attestation_data_simple(12, make_bytes32(61), make_bytes32(62), source=source) + att_data = make_attestation_data_simple( + Slot(12), make_bytes32(61), make_bytes32(62), source=source + ) attestations = [Attestation(validator_id=ValidatorIndex(i), data=att_data) for i in range(4)] data_root = att_data.data_root_bytes() @@ -489,7 +501,9 @@ def test_greedy_stops_when_no_useful_proofs_remain( """ state = make_keyed_genesis_state(5, container_key_manager) source = Checkpoint(root=make_bytes32(70), slot=Slot(0)) - att_data = make_attestation_data_simple(13, make_bytes32(71), make_bytes32(72), source=source) + att_data = make_attestation_data_simple( + Slot(13), make_bytes32(71), make_bytes32(72), source=source + ) attestations = [Attestation(validator_id=ValidatorIndex(i), data=att_data) for i in range(5)] data_root = att_data.data_root_bytes() @@ -558,7 +572,9 @@ def test_greedy_handles_overlapping_proof_chains( """ state = make_keyed_genesis_state(5, container_key_manager) source = Checkpoint(root=make_bytes32(80), slot=Slot(0)) - att_data = make_attestation_data_simple(14, make_bytes32(81), make_bytes32(82), source=source) + att_data = make_attestation_data_simple( + Slot(14), make_bytes32(81), make_bytes32(82), source=source + ) attestations = [Attestation(validator_id=ValidatorIndex(i), data=att_data) for i in range(5)] data_root = att_data.data_root_bytes() @@ -626,7 +642,9 @@ def test_greedy_single_validator_proofs( """ state = make_keyed_genesis_state(3, container_key_manager) source = Checkpoint(root=make_bytes32(90), slot=Slot(0)) - att_data = make_attestation_data_simple(15, make_bytes32(91), make_bytes32(92), source=source) + att_data = make_attestation_data_simple( + Slot(15), make_bytes32(91), make_bytes32(92), source=source + ) attestations = [Attestation(validator_id=ValidatorIndex(i), data=att_data) for i in range(3)] data_root = att_data.data_root_bytes() @@ -683,7 +701,9 @@ def test_validator_in_both_gossip_and_fallback_proof( """ state = make_keyed_genesis_state(2, container_key_manager) source = Checkpoint(root=make_bytes32(100), slot=Slot(0)) - att_data = make_attestation_data_simple(16, make_bytes32(101), make_bytes32(102), source=source) + att_data = make_attestation_data_simple( + Slot(16), make_bytes32(101), make_bytes32(102), source=source + ) attestations = [Attestation(validator_id=ValidatorIndex(i), data=att_data) for i in range(2)] data_root = att_data.data_root_bytes() @@ -726,7 +746,7 @@ def test_validator_in_both_gossip_and_fallback_proof( for proof in aggregated_proofs: participants = proof.participants.to_validator_indices() public_keys = [container_key_manager.get_public_key(vid) for vid in participants] - proof.verify(public_keys=public_keys, message=data_root, epoch=att_data.slot) + proof.verify(public_keys=public_keys, message=data_root, slot=att_data.slot) def test_gossip_none_and_aggregated_payloads_none( @@ -741,7 +761,9 @@ def test_gossip_none_and_aggregated_payloads_none( """ state = make_keyed_genesis_state(2, container_key_manager) source = Checkpoint(root=make_bytes32(110), slot=Slot(0)) - att_data = make_attestation_data_simple(17, make_bytes32(111), make_bytes32(112), source=source) + att_data = make_attestation_data_simple( + Slot(17), make_bytes32(111), make_bytes32(112), source=source + ) attestations = [Attestation(validator_id=ValidatorIndex(i), data=att_data) for i in range(2)] results = state.aggregate_gossip_signatures( @@ -770,7 +792,9 @@ def test_aggregated_payloads_only_no_gossip( """ state = make_keyed_genesis_state(3, container_key_manager) source = Checkpoint(root=make_bytes32(120), slot=Slot(0)) - att_data = make_attestation_data_simple(18, make_bytes32(121), make_bytes32(122), source=source) + att_data = make_attestation_data_simple( + Slot(18), make_bytes32(121), make_bytes32(122), source=source + ) attestations = [Attestation(validator_id=ValidatorIndex(i), data=att_data) for i in range(3)] data_root = att_data.data_root_bytes() @@ -794,7 +818,7 @@ def test_aggregated_payloads_only_no_gossip( assert participants == {0, 1, 2} public_keys = [container_key_manager.get_public_key(ValidatorIndex(i)) for i in range(3)] - aggregated_proofs[0].verify(public_keys=public_keys, message=data_root, epoch=att_data.slot) + aggregated_proofs[0].verify(public_keys=public_keys, message=data_root, slot=att_data.slot) def test_proof_with_extra_validators_beyond_needed( @@ -818,7 +842,9 @@ def test_proof_with_extra_validators_beyond_needed( """ state = make_keyed_genesis_state(4, container_key_manager) source = Checkpoint(root=make_bytes32(130), slot=Slot(0)) - att_data = make_attestation_data_simple(19, make_bytes32(131), make_bytes32(132), source=source) + att_data = make_attestation_data_simple( + Slot(19), make_bytes32(131), make_bytes32(132), source=source + ) attestations = [Attestation(validator_id=ValidatorIndex(i), data=att_data) for i in range(2)] data_root = att_data.data_root_bytes() diff --git a/tests/lean_spec/subspecs/containers/test_state_justified_slots.py b/tests/lean_spec/subspecs/containers/test_state_justified_slots.py index 36ecc1ed..066b82c3 100644 --- a/tests/lean_spec/subspecs/containers/test_state_justified_slots.py +++ b/tests/lean_spec/subspecs/containers/test_state_justified_slots.py @@ -15,6 +15,7 @@ JustificationRoots, JustificationValidators, ) +from lean_spec.subspecs.containers.validator import ValidatorIndex from lean_spec.types import Boolean from tests.lean_spec.helpers import make_aggregated_attestation, make_block, make_genesis_state @@ -56,7 +57,7 @@ def test_justified_slots_rebases_when_finalization_advances() -> None: source_0 = Checkpoint(root=block_1.parent_root, slot=Slot(0)) target_1 = Checkpoint(root=block_2.parent_root, slot=Slot(1)) att_0_to_1 = make_aggregated_attestation( - participant_ids=[0, 1], + participant_ids=[ValidatorIndex(0), ValidatorIndex(1)], attestation_slot=Slot(2), source=source_0, target=target_1, @@ -72,7 +73,7 @@ def test_justified_slots_rebases_when_finalization_advances() -> None: source_1 = Checkpoint(root=block_2.parent_root, slot=Slot(1)) target_2 = Checkpoint(root=block_3.parent_root, slot=Slot(2)) att_1_to_2 = make_aggregated_attestation( - participant_ids=[0, 1], + participant_ids=[ValidatorIndex(0), ValidatorIndex(1)], attestation_slot=Slot(3), source=source_1, target=target_2, @@ -127,7 +128,7 @@ def test_pruning_keeps_pending_justifications() -> None: source_0 = Checkpoint(root=block_1.parent_root, slot=Slot(0)) target_1 = Checkpoint(root=block_2.parent_root, slot=Slot(1)) att_0_to_1 = make_aggregated_attestation( - participant_ids=[0, 1], + participant_ids=[ValidatorIndex(0), ValidatorIndex(1)], attestation_slot=Slot(2), source=source_0, target=target_1, @@ -178,7 +179,7 @@ def test_pruning_keeps_pending_justifications() -> None: source_1 = Checkpoint(root=state.historical_block_hashes[1], slot=Slot(1)) target_2 = Checkpoint(root=state.historical_block_hashes[2], slot=Slot(2)) att_1_to_2 = make_aggregated_attestation( - participant_ids=[0, 1], + participant_ids=[ValidatorIndex(0), ValidatorIndex(1)], attestation_slot=Slot(5), source=source_1, target=target_2, diff --git a/tests/lean_spec/subspecs/forkchoice/test_attestation_target.py b/tests/lean_spec/subspecs/forkchoice/test_attestation_target.py index 7ed8f968..4a06e3af 100644 --- a/tests/lean_spec/subspecs/forkchoice/test_attestation_target.py +++ b/tests/lean_spec/subspecs/forkchoice/test_attestation_target.py @@ -5,6 +5,7 @@ import pytest from consensus_testing.keys import XmssKeyManager +from lean_spec.subspecs.chain.clock import Interval from lean_spec.subspecs.chain.config import ( INTERVALS_PER_SLOT, JUSTIFICATION_LOOKBACK_SLOTS, @@ -588,7 +589,7 @@ def test_attestation_target_after_on_block( # Process block via on_block on a fresh consumer store consumer_store = observer_store - target_interval = block.slot * INTERVALS_PER_SLOT + target_interval = Interval(block.slot * INTERVALS_PER_SLOT) consumer_store, _ = consumer_store.on_tick(target_interval, has_proposal=True) consumer_store = consumer_store.on_block(signed_block) diff --git a/tests/lean_spec/subspecs/forkchoice/test_store_attestations.py b/tests/lean_spec/subspecs/forkchoice/test_store_attestations.py index ec1a5868..e2cd1e28 100644 --- a/tests/lean_spec/subspecs/forkchoice/test_store_attestations.py +++ b/tests/lean_spec/subspecs/forkchoice/test_store_attestations.py @@ -342,7 +342,7 @@ def test_valid_proof_stored_correctly(self, key_manager: XmssKeyManager) -> None key_manager.sign_attestation_data(vid, attestation_data) for vid in participants ], message=data_root, - epoch=attestation_data.slot, + slot=attestation_data.slot, ) signed_aggregated = SignedAggregatedAttestation( @@ -383,7 +383,7 @@ def test_attestation_data_stored_by_root(self, key_manager: XmssKeyManager) -> N key_manager.sign_attestation_data(vid, attestation_data) for vid in participants ], message=data_root, - epoch=attestation_data.slot, + slot=attestation_data.slot, ) signed_aggregated = SignedAggregatedAttestation( @@ -419,7 +419,7 @@ def test_invalid_proof_rejected(self, key_manager: XmssKeyManager) -> None: key_manager.sign_attestation_data(vid, attestation_data) for vid in actual_signers ], message=data_root, - epoch=attestation_data.slot, + slot=attestation_data.slot, ) signed_aggregated = SignedAggregatedAttestation( @@ -452,7 +452,7 @@ def test_multiple_proofs_accumulate(self, key_manager: XmssKeyManager) -> None: key_manager.sign_attestation_data(vid, attestation_data) for vid in participants_1 ], message=data_root, - epoch=attestation_data.slot, + slot=attestation_data.slot, ) # Second proof: validators 1 and 3 (validator 1 overlaps) @@ -464,7 +464,7 @@ def test_multiple_proofs_accumulate(self, key_manager: XmssKeyManager) -> None: key_manager.sign_attestation_data(vid, attestation_data) for vid in participants_2 ], message=data_root, - epoch=attestation_data.slot, + slot=attestation_data.slot, ) store = store.on_gossip_aggregated_attestation( @@ -551,7 +551,7 @@ def test_aggregated_proof_is_valid(self, key_manager: XmssKeyManager) -> None: proof.verify( public_keys=public_keys, message=data_root, - epoch=attestation_data.slot, + slot=attestation_data.slot, ) def test_empty_gossip_signatures_produces_no_proofs(self, key_manager: XmssKeyManager) -> None: @@ -829,5 +829,5 @@ def test_gossip_to_aggregation_to_storage(self, key_manager: XmssKeyManager) -> proof.verify( public_keys=public_keys, message=data_root, - epoch=attestation_data.slot, + slot=attestation_data.slot, ) diff --git a/tests/lean_spec/subspecs/forkchoice/test_time_management.py b/tests/lean_spec/subspecs/forkchoice/test_time_management.py index f2c6a2c1..2adbd852 100644 --- a/tests/lean_spec/subspecs/forkchoice/test_time_management.py +++ b/tests/lean_spec/subspecs/forkchoice/test_time_management.py @@ -3,6 +3,7 @@ from hypothesis import given, settings from hypothesis import strategies as st +from lean_spec.subspecs.chain.clock import Interval from lean_spec.subspecs.chain.config import ( INTERVALS_PER_SLOT, MILLISECONDS_PER_INTERVAL, @@ -48,7 +49,7 @@ def test_store_time_from_anchor_slot(self, anchor_slot: int) -> None: validator_id=TEST_VALIDATOR_ID, ) - assert store.time == INTERVALS_PER_SLOT * Uint64(anchor_slot) + assert store.time == Interval(int(INTERVALS_PER_SLOT) * anchor_slot) class TestOnTick: @@ -58,7 +59,7 @@ def test_on_tick_basic(self, sample_store: Store) -> None: """Test basic on_tick.""" initial_time = sample_store.time # 200 seconds = 200*1000/800 = 250 intervals - target_interval = Uint64(200) * Uint64(1000) // MILLISECONDS_PER_INTERVAL + target_interval = Interval(200 * 1000 // int(MILLISECONDS_PER_INTERVAL)) sample_store, _ = sample_store.on_tick(target_interval, has_proposal=True) @@ -69,7 +70,7 @@ def test_on_tick_no_proposal(self, sample_store: Store) -> None: """Test on_tick without proposal.""" initial_time = sample_store.time # 100 seconds = 125 intervals - target_interval = Uint64(100) * Uint64(1000) // MILLISECONDS_PER_INTERVAL + target_interval = Interval(100 * 1000 // int(MILLISECONDS_PER_INTERVAL)) sample_store, _ = sample_store.on_tick(target_interval, has_proposal=False) @@ -80,7 +81,7 @@ def test_on_tick_already_current(self, sample_store: Store) -> None: """Test on_tick when already at target time (should be no-op).""" initial_time = sample_store.time - sample_store, _ = sample_store.on_tick(initial_time, has_proposal=True) + sample_store, _ = sample_store.on_tick(Interval(initial_time), has_proposal=True) # No-op: target equals current time assert sample_store.time == initial_time @@ -88,7 +89,7 @@ def test_on_tick_already_current(self, sample_store: Store) -> None: def test_on_tick_small_increment(self, sample_store: Store) -> None: """Test on_tick with small interval increment.""" initial_time = sample_store.time - target_interval = initial_time + Uint64(1) + target_interval = Interval(int(initial_time) + 1) sample_store, _ = sample_store.on_tick(target_interval, has_proposal=False) diff --git a/tests/lean_spec/subspecs/forkchoice/test_validator.py b/tests/lean_spec/subspecs/forkchoice/test_validator.py index cb5902c0..57824518 100644 --- a/tests/lean_spec/subspecs/forkchoice/test_validator.py +++ b/tests/lean_spec/subspecs/forkchoice/test_validator.py @@ -3,6 +3,7 @@ import pytest from consensus_testing.keys import XmssKeyManager +from lean_spec.subspecs.chain.clock import Interval from lean_spec.subspecs.containers import ( Attestation, AttestationData, @@ -131,7 +132,7 @@ def test_produce_block_with_attestations( proof.verify( public_keys=public_keys, message=agg_att.data.data_root_bytes(), - epoch=agg_att.data.slot, + slot=agg_att.data.slot, ) def test_produce_block_sequential_slots(self, sample_store: Store) -> None: @@ -248,7 +249,7 @@ def test_produce_block_state_consistency( proof.verify( public_keys=public_keys, message=agg_att.data.data_root_bytes(), - epoch=agg_att.data.slot, + slot=agg_att.data.slot, ) @@ -386,7 +387,7 @@ def test_produce_block_missing_parent_state(self) -> None: # Create store with missing parent state store = Store( - time=Uint64(100), + time=Interval(100), config=config, head=Bytes32(b"nonexistent" + b"\x00" * 21), safe_target=Bytes32(b"nonexistent" + b"\x00" * 21), diff --git a/tests/lean_spec/subspecs/networking/discovery/conftest.py b/tests/lean_spec/subspecs/networking/discovery/conftest.py index 293624ae..84aa2325 100644 --- a/tests/lean_spec/subspecs/networking/discovery/conftest.py +++ b/tests/lean_spec/subspecs/networking/discovery/conftest.py @@ -5,8 +5,8 @@ import pytest from lean_spec.subspecs.networking.enr import ENR -from lean_spec.subspecs.networking.types import NodeId -from lean_spec.types import Bytes64, Uint64 +from lean_spec.subspecs.networking.types import NodeId, SeqNumber +from lean_spec.types import Bytes64 # From devp2p test vectors NODE_A_PRIVKEY = bytes.fromhex("eef77acb6c6a6eebc5b363a475ac583ec7eccdb42b6481424c60f59aa326547f") @@ -46,7 +46,7 @@ def local_enr() -> ENR: """Minimal local ENR for testing.""" return ENR( signature=Bytes64(bytes(64)), - seq=Uint64(1), + seq=SeqNumber(1), pairs={ "id": b"v4", "secp256k1": NODE_B_PUBKEY, diff --git a/tests/lean_spec/subspecs/networking/discovery/test_handshake.py b/tests/lean_spec/subspecs/networking/discovery/test_handshake.py index 8e9ece32..63c19809 100644 --- a/tests/lean_spec/subspecs/networking/discovery/test_handshake.py +++ b/tests/lean_spec/subspecs/networking/discovery/test_handshake.py @@ -26,8 +26,8 @@ ) from lean_spec.subspecs.networking.discovery.session import SessionCache from lean_spec.subspecs.networking.enr import ENR -from lean_spec.subspecs.networking.types import NodeId -from lean_spec.types import Bytes32, Bytes33, Bytes64, Uint64 +from lean_spec.subspecs.networking.types import NodeId, SeqNumber +from lean_spec.types import Bytes32, Bytes33, Bytes64 from tests.lean_spec.subspecs.networking.discovery.conftest import NODE_B_PUBKEY @@ -62,7 +62,7 @@ def manager(local_keypair, session_cache): local_node_id=node_id, local_private_key=priv, local_enr_rlp=b"mock_enr", - local_enr_seq=1, + local_enr_seq=SeqNumber(1), session_cache=session_cache, ) @@ -144,7 +144,7 @@ def test_create_whoareyou(self, manager): """Test creating a WHOAREYOU challenge.""" remote_node_id = bytes(32) request_nonce = bytes(12) - remote_enr_seq = 0 + remote_enr_seq = SeqNumber(0) masking_iv = bytes(16) id_nonce, authdata, nonce, challenge_data = manager.create_whoareyou( @@ -157,7 +157,7 @@ def test_create_whoareyou(self, manager): # Verify authdata decodes correctly decoded = decode_whoareyou_authdata(authdata) assert bytes(decoded.id_nonce) == id_nonce - assert int(decoded.enr_seq) == remote_enr_seq + assert decoded.enr_seq == remote_enr_seq # Verify challenge_data structure: masking-iv || static-header || authdata # masking-iv (16) + static-header (23) + authdata (24) = 63 bytes @@ -196,7 +196,7 @@ def test_invalid_local_node_id_raises(self): local_node_id=bytes(31), # type: ignore[arg-type] local_private_key=bytes(32), local_enr_rlp=b"enr", - local_enr_seq=1, + local_enr_seq=SeqNumber(1), session_cache=SessionCache(), ) @@ -207,7 +207,7 @@ def test_invalid_local_private_key_raises(self): local_node_id=NodeId(bytes(32)), local_private_key=bytes(31), local_enr_rlp=b"enr", - local_enr_seq=1, + local_enr_seq=SeqNumber(1), session_cache=SessionCache(), ) @@ -271,7 +271,7 @@ def test_create_whoareyou_transitions_to_sent_whoareyou(self, manager): """ remote_node_id = bytes(32) request_nonce = bytes(12) - remote_enr_seq = 0 + remote_enr_seq = SeqNumber(0) masking_iv = bytes(16) id_nonce, authdata, nonce, challenge_data = manager.create_whoareyou( @@ -289,7 +289,7 @@ def test_sent_whoareyou_state_has_challenge_data(self, manager): """In SENT_WHOAREYOU state, all challenge data is stored.""" remote_node_id = bytes(32) request_nonce = bytes(12) - remote_enr_seq = 5 + remote_enr_seq = SeqNumber(5) masking_iv = bytes(16) manager.create_whoareyou(remote_node_id, request_nonce, remote_enr_seq, masking_iv) @@ -367,7 +367,7 @@ def test_handle_handshake_rejects_src_id_mismatch(self, manager, remote_keypair) manager.create_whoareyou( NodeId(remote_node_id), bytes(12), - 0, + SeqNumber(0), bytes(16), ) @@ -398,7 +398,7 @@ def test_handle_handshake_requires_enr_when_seq_zero(self, manager, remote_keypa manager.create_whoareyou( NodeId(remote_node_id), bytes(12), - 0, # enr_seq = 0 means we don't know remote's ENR + SeqNumber(0), # enr_seq = 0 means we don't know remote's ENR bytes(16), ) @@ -428,7 +428,7 @@ def test_successful_handshake_with_signature_verification( # Node A (manager) creates WHOAREYOU for remote. masking_iv = bytes(16) id_nonce, authdata, nonce, challenge_data = manager.create_whoareyou( - NodeId(remote_node_id), bytes(12), 0, masking_iv + NodeId(remote_node_id), bytes(12), SeqNumber(0), masking_iv ) # Remote creates handshake response. @@ -446,7 +446,7 @@ def test_successful_handshake_with_signature_verification( # Remote includes their ENR since enr_seq=0. remote_enr = ENR( signature=Bytes64(bytes(64)), - seq=Uint64(1), + seq=SeqNumber(1), pairs={"id": b"v4", "secp256k1": bytes(remote_pub)}, ) @@ -476,7 +476,7 @@ def test_handle_handshake_rejects_invalid_signature( # Set up WHOAREYOU state. masking_iv = bytes(16) - manager.create_whoareyou(NodeId(remote_node_id), bytes(12), 0, masking_iv) + manager.create_whoareyou(NodeId(remote_node_id), bytes(12), SeqNumber(0), masking_iv) # Generate ephemeral key. _eph_priv, eph_pub = generate_secp256k1_keypair() @@ -484,7 +484,7 @@ def test_handle_handshake_rejects_invalid_signature( # Create authdata with INVALID signature (all-zero 64 bytes). remote_enr = ENR( signature=Bytes64(bytes(64)), - seq=Uint64(1), + seq=SeqNumber(1), pairs={"id": b"v4", "secp256k1": bytes(remote_pub)}, ) @@ -515,7 +515,7 @@ def test_multiple_handshakes_independent(self, manager): manager.start_handshake(remote2) # Create WHOAREYOU for third remote. - manager.create_whoareyou(remote3, bytes(12), 0, bytes(16)) + manager.create_whoareyou(remote3, bytes(12), SeqNumber(0), bytes(16)) # All should have independent state. assert manager.get_pending(remote1).state == HandshakeState.SENT_ORDINARY @@ -581,8 +581,8 @@ def test_id_nonce_uniqueness_across_challenges(self, manager): remote1 = bytes.fromhex("01" + "00" * 31) remote2 = bytes.fromhex("02" + "00" * 31) - id_nonce1, _, _, _ = manager.create_whoareyou(remote1, bytes(12), 0, bytes(16)) - id_nonce2, _, _, _ = manager.create_whoareyou(remote2, bytes(12), 0, bytes(16)) + id_nonce1, _, _, _ = manager.create_whoareyou(remote1, bytes(12), SeqNumber(0), bytes(16)) + id_nonce2, _, _, _ = manager.create_whoareyou(remote2, bytes(12), SeqNumber(0), bytes(16)) # Each challenge should have unique id_nonce. assert id_nonce1 != id_nonce2 @@ -605,14 +605,14 @@ def test_enr_included_when_remote_seq_is_stale(self, local_keypair, remote_keypa local_node_id=local_node_id, local_private_key=local_priv, local_enr_rlp=b"mock_enr_data", - local_enr_seq=5, + local_enr_seq=SeqNumber(5), session_cache=session_cache, ) # Remote creates WHOAREYOU with enr_seq=0 (stale). whoareyou = WhoAreYouAuthdata( id_nonce=IdNonce(bytes(16)), - enr_seq=Uint64(0), + enr_seq=SeqNumber(0), ) challenge_data = bytes(63) @@ -641,14 +641,14 @@ def test_enr_excluded_when_remote_seq_is_current(self, local_keypair, remote_key local_node_id=local_node_id, local_private_key=local_priv, local_enr_rlp=b"mock_enr_data", - local_enr_seq=5, + local_enr_seq=SeqNumber(5), session_cache=session_cache, ) # Remote creates WHOAREYOU with enr_seq=5 (current). whoareyou = WhoAreYouAuthdata( id_nonce=IdNonce(bytes(16)), - enr_seq=Uint64(5), + enr_seq=SeqNumber(5), ) challenge_data = bytes(63) @@ -673,7 +673,7 @@ def test_register_enr_stores_in_cache(self, manager): enr = ENR( signature=Bytes64(bytes(64)), - seq=Uint64(1), + seq=SeqNumber(1), pairs={ "id": b"v4", "secp256k1": NODE_B_PUBKEY, diff --git a/tests/lean_spec/subspecs/networking/discovery/test_integration.py b/tests/lean_spec/subspecs/networking/discovery/test_integration.py index 8a836d9b..bff0d39c 100644 --- a/tests/lean_spec/subspecs/networking/discovery/test_integration.py +++ b/tests/lean_spec/subspecs/networking/discovery/test_integration.py @@ -41,7 +41,7 @@ from lean_spec.subspecs.networking.discovery.session import Session, SessionCache from lean_spec.subspecs.networking.enr import ENR from lean_spec.subspecs.networking.types import NodeId, SeqNumber -from lean_spec.types import Bytes12, Bytes16, Bytes32, Bytes64, Uint64 +from lean_spec.types import Bytes12, Bytes16, Bytes32, Bytes64 @pytest.fixture @@ -86,7 +86,7 @@ def test_message_packet_encryption_roundtrip(self, node_a_keys, node_b_keys): # Create a PING message. ping = Ping( request_id=RequestId(data=b"\x01"), - enr_seq=Uint64(1), + enr_seq=SeqNumber(1), ) message_bytes = encode_message(ping) @@ -218,7 +218,7 @@ def test_whoareyou_generation(self, node_a_keys, node_b_keys): cache = SessionCache() enr = ENR( signature=Bytes64(bytes(64)), - seq=Uint64(1), + seq=SeqNumber(1), pairs={"id": b"v4"}, ) @@ -226,7 +226,7 @@ def test_whoareyou_generation(self, node_a_keys, node_b_keys): local_node_id=node_a_keys["node_id"], local_private_key=node_a_keys["private_key"], local_enr_rlp=enr.to_rlp(), - local_enr_seq=1, + local_enr_seq=SeqNumber(1), session_cache=cache, ) @@ -236,7 +236,7 @@ def test_whoareyou_generation(self, node_a_keys, node_b_keys): id_nonce, authdata, nonce, challenge_data = manager.create_whoareyou( remote_node_id=node_b_keys["node_id"], request_nonce=request_nonce, - remote_enr_seq=0, + remote_enr_seq=SeqNumber(0), masking_iv=masking_iv, ) @@ -252,7 +252,7 @@ def test_start_and_cancel_handshake(self, node_a_keys, node_b_keys): cache = SessionCache() enr = ENR( signature=Bytes64(bytes(64)), - seq=Uint64(1), + seq=SeqNumber(1), pairs={"id": b"v4"}, ) @@ -260,7 +260,7 @@ def test_start_and_cancel_handshake(self, node_a_keys, node_b_keys): local_node_id=node_a_keys["node_id"], local_private_key=node_a_keys["private_key"], local_enr_rlp=enr.to_rlp(), - local_enr_seq=1, + local_enr_seq=SeqNumber(1), session_cache=cache, ) @@ -298,12 +298,12 @@ def test_handshake_key_agreement(self, node_a_keys, node_b_keys): enr_a = ENR( signature=Bytes64(bytes(64)), - seq=Uint64(1), + seq=SeqNumber(1), pairs={"id": b"v4", "secp256k1": node_a_keys["public_key"]}, ) enr_b = ENR( signature=Bytes64(bytes(64)), - seq=Uint64(1), + seq=SeqNumber(1), pairs={"id": b"v4", "secp256k1": node_b_keys["public_key"]}, ) @@ -311,7 +311,7 @@ def test_handshake_key_agreement(self, node_a_keys, node_b_keys): local_node_id=node_a_keys["node_id"], local_private_key=node_a_keys["private_key"], local_enr_rlp=enr_a.to_rlp(), - local_enr_seq=1, + local_enr_seq=SeqNumber(1), session_cache=cache_a, ) @@ -319,7 +319,7 @@ def test_handshake_key_agreement(self, node_a_keys, node_b_keys): local_node_id=node_b_keys["node_id"], local_private_key=node_b_keys["private_key"], local_enr_rlp=enr_b.to_rlp(), - local_enr_seq=1, + local_enr_seq=SeqNumber(1), session_cache=cache_b, ) @@ -332,7 +332,7 @@ def test_handshake_key_agreement(self, node_a_keys, node_b_keys): id_nonce, whoareyou_authdata, _, challenge_data = manager_b.create_whoareyou( remote_node_id=node_a_keys["node_id"], request_nonce=request_nonce, - remote_enr_seq=0, + remote_enr_seq=SeqNumber(0), masking_iv=masking_iv, ) diff --git a/tests/lean_spec/subspecs/networking/discovery/test_messages.py b/tests/lean_spec/subspecs/networking/discovery/test_messages.py index d655ff42..46039c29 100644 --- a/tests/lean_spec/subspecs/networking/discovery/test_messages.py +++ b/tests/lean_spec/subspecs/networking/discovery/test_messages.py @@ -315,10 +315,10 @@ class TestWhoAreYouAuthdataConstruction: def test_creation(self): authdata = WhoAreYouAuthdata( id_nonce=IdNonce(b"\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10"), - enr_seq=Uint64(0), + enr_seq=SeqNumber(0), ) assert len(authdata.id_nonce) == 16 - assert authdata.enr_seq == Uint64(0) + assert authdata.enr_seq == SeqNumber(0) class TestMessageConstructionFromTestVectors: @@ -338,10 +338,10 @@ def test_ping_message_construction(self): def test_whoareyou_authdata_construction(self): authdata = WhoAreYouAuthdata( id_nonce=IdNonce(SPEC_ID_NONCE), - enr_seq=Uint64(0), + enr_seq=SeqNumber(0), ) assert authdata.id_nonce == IdNonce(SPEC_ID_NONCE) - assert authdata.enr_seq == Uint64(0) + assert authdata.enr_seq == SeqNumber(0) def test_plaintext_message_type(self): plaintext = bytes.fromhex("01c20101") diff --git a/tests/lean_spec/subspecs/networking/discovery/test_packet.py b/tests/lean_spec/subspecs/networking/discovery/test_packet.py index b306582d..7ea230dc 100644 --- a/tests/lean_spec/subspecs/networking/discovery/test_packet.py +++ b/tests/lean_spec/subspecs/networking/discovery/test_packet.py @@ -21,7 +21,7 @@ generate_id_nonce, generate_nonce, ) -from lean_spec.subspecs.networking.types import NodeId +from lean_spec.subspecs.networking.types import NodeId, SeqNumber from lean_spec.types import Bytes16 @@ -76,7 +76,7 @@ class TestWhoAreYouAuthdata: def test_encode_whoareyou_authdata(self): """Test WHOAREYOU authdata encoding.""" id_nonce = bytes(16) - enr_seq = 42 + enr_seq = SeqNumber(42) authdata = encode_whoareyou_authdata(id_nonce, enr_seq) @@ -85,24 +85,24 @@ def test_encode_whoareyou_authdata(self): def test_decode_whoareyou_authdata(self): """Test WHOAREYOU authdata decoding.""" id_nonce = bytes.fromhex("aa" * 16) - enr_seq = 12345 + enr_seq = SeqNumber(12345) authdata = encode_whoareyou_authdata(id_nonce, enr_seq) decoded = decode_whoareyou_authdata(authdata) assert bytes(decoded.id_nonce) == id_nonce - assert int(decoded.enr_seq) == enr_seq + assert decoded.enr_seq == enr_seq def test_roundtrip(self): """Test encoding then decoding preserves values.""" id_nonce = bytes.fromhex("01" * 16) - enr_seq = 2**63 - 1 # Max uint64 + enr_seq = SeqNumber(2**63 - 1) # Max uint64 authdata = encode_whoareyou_authdata(id_nonce, enr_seq) decoded = decode_whoareyou_authdata(authdata) assert bytes(decoded.id_nonce) == id_nonce - assert int(decoded.enr_seq) == enr_seq + assert decoded.enr_seq == enr_seq def test_invalid_size_raises(self): """Test that invalid authdata size raises ValueError.""" @@ -202,7 +202,7 @@ def test_encode_whoareyou_packet(self): dest_node_id = NodeId(bytes(32)) nonce = bytes(12) id_nonce = bytes(16) - authdata = encode_whoareyou_authdata(id_nonce, 0) + authdata = encode_whoareyou_authdata(id_nonce, SeqNumber(0)) packet = encode_packet( dest_node_id=dest_node_id, @@ -221,7 +221,7 @@ def test_decode_packet_header(self): """Test packet header decoding.""" local_node_id = NodeId(bytes(32)) nonce = bytes(12) - authdata = encode_whoareyou_authdata(bytes(16), 42) + authdata = encode_whoareyou_authdata(bytes(16), SeqNumber(42)) packet = encode_packet( dest_node_id=local_node_id, @@ -422,7 +422,7 @@ class TestAuthdataInvalidLengths: def test_encode_whoareyou_authdata_wrong_id_nonce_length(self): """WHOAREYOU authdata rejects id_nonce that is not 16 bytes.""" with pytest.raises(ValueError, match="ID nonce must be 16 bytes"): - encode_whoareyou_authdata(bytes(15), 0) + encode_whoareyou_authdata(bytes(15), SeqNumber(0)) def test_encode_message_authdata_wrong_src_id_length(self): """MESSAGE authdata rejects src_id that is not 32 bytes.""" diff --git a/tests/lean_spec/subspecs/networking/discovery/test_routing.py b/tests/lean_spec/subspecs/networking/discovery/test_routing.py index 68fd7588..dff72956 100644 --- a/tests/lean_spec/subspecs/networking/discovery/test_routing.py +++ b/tests/lean_spec/subspecs/networking/discovery/test_routing.py @@ -20,7 +20,7 @@ from lean_spec.subspecs.networking.enr import ENR from lean_spec.subspecs.networking.enr.eth2 import FAR_FUTURE_EPOCH from lean_spec.subspecs.networking.types import NodeId, SeqNumber -from lean_spec.types import Bytes64, Uint64 +from lean_spec.types import Bytes64 from lean_spec.types.byte_arrays import Bytes4 @@ -111,7 +111,7 @@ def test_create_full_entry(self): node_id = NodeId(bytes(32)) enr = ENR( signature=Bytes64(bytes(64)), - seq=Uint64(1), + seq=SeqNumber(1), pairs={"id": b"v4"}, ) @@ -462,7 +462,7 @@ def test_fork_filter_rejects_without_eth2_data(self, local_node_id, remote_node_ enr = ENR( signature=Bytes64(bytes(64)), - seq=Uint64(1), + seq=SeqNumber(1), pairs={"id": b"v4"}, ) entry = NodeEntry(node_id=remote_node_id, enr=enr) @@ -480,7 +480,7 @@ def test_fork_filter_rejects_mismatched_fork(self, local_node_id, remote_node_id eth2_bytes = remote_digest + remote_digest + int(FAR_FUTURE_EPOCH).to_bytes(8, "little") enr = ENR( signature=Bytes64(bytes(64)), - seq=Uint64(1), + seq=SeqNumber(1), pairs={"eth2": eth2_bytes, "id": b"v4"}, ) entry = NodeEntry(node_id=remote_node_id, enr=enr) @@ -502,7 +502,7 @@ def test_fork_filter_accepts_matching_fork(self, local_node_id, remote_node_id): ) enr = ENR( signature=Bytes64(bytes(64)), - seq=Uint64(1), + seq=SeqNumber(1), pairs={"eth2": eth2_bytes, "id": b"v4"}, ) entry = NodeEntry(node_id=remote_node_id, enr=enr) @@ -524,7 +524,7 @@ def test_is_fork_compatible_method(self, local_node_id): ) compatible_enr = ENR( signature=Bytes64(bytes(64)), - seq=Uint64(1), + seq=SeqNumber(1), pairs={"eth2": eth2_match, "id": b"v4"}, ) compatible_entry = NodeEntry(node_id=NodeId(b"\x01" * 32), enr=compatible_enr) @@ -538,7 +538,7 @@ def test_is_fork_compatible_method(self, local_node_id): ) incompatible_enr = ENR( signature=Bytes64(bytes(64)), - seq=Uint64(1), + seq=SeqNumber(1), pairs={"eth2": eth2_mismatch, "id": b"v4"}, ) incompatible_entry = NodeEntry(node_id=NodeId(b"\x02" * 32), enr=incompatible_enr) diff --git a/tests/lean_spec/subspecs/networking/discovery/test_service.py b/tests/lean_spec/subspecs/networking/discovery/test_service.py index 39b529a4..9bef4a42 100644 --- a/tests/lean_spec/subspecs/networking/discovery/test_service.py +++ b/tests/lean_spec/subspecs/networking/discovery/test_service.py @@ -27,7 +27,7 @@ ) from lean_spec.subspecs.networking.enr import ENR from lean_spec.subspecs.networking.types import NodeId, SeqNumber -from lean_spec.types import Bytes64, Uint64 +from lean_spec.types import Bytes64 from tests.lean_spec.subspecs.networking.discovery.conftest import NODE_B_PUBKEY @@ -64,7 +64,7 @@ def test_init_with_bootnodes(self, local_enr, local_private_key): """Service accepts bootnodes list.""" bootnode = ENR( signature=Bytes64(bytes(64)), - seq=Uint64(1), + seq=SeqNumber(1), pairs={ "id": b"v4", "secp256k1": bytes.fromhex( @@ -85,7 +85,7 @@ def test_init_requires_public_key_in_enr(self, local_private_key): """Service requires ENR to have a public key.""" enr_without_pubkey = ENR( signature=Bytes64(bytes(64)), - seq=Uint64(1), + seq=SeqNumber(1), pairs={"id": b"v4"}, ) @@ -305,7 +305,7 @@ def test_enr_ip4_extraction(self, local_private_key): """Extract IPv4 address from ENR.""" enr = ENR( signature=Bytes64(bytes(64)), - seq=Uint64(1), + seq=SeqNumber(1), pairs={ "id": b"v4", "secp256k1": NODE_B_PUBKEY, @@ -325,7 +325,7 @@ def test_enr_ip6_extraction(self, local_private_key): enr = ENR( signature=Bytes64(bytes(64)), - seq=Uint64(1), + seq=SeqNumber(1), pairs={ "id": b"v4", "secp256k1": NODE_B_PUBKEY, @@ -344,7 +344,7 @@ def test_enr_dual_stack_has_both(self, local_private_key): enr = ENR( signature=Bytes64(bytes(64)), - seq=Uint64(1), + seq=SeqNumber(1), pairs={ "id": b"v4", "secp256k1": NODE_B_PUBKEY, @@ -363,7 +363,7 @@ def test_enr_missing_ip_returns_none(self, local_private_key): """ENR without IP returns None for ip4.""" enr = ENR( signature=Bytes64(bytes(64)), - seq=Uint64(1), + seq=SeqNumber(1), pairs={ "id": b"v4", "secp256k1": NODE_B_PUBKEY, @@ -485,7 +485,7 @@ def test_service_accepts_bootnodes(self, local_enr, local_private_key): """Service accepts bootnodes in constructor.""" bootnode = ENR( signature=Bytes64(bytes(64)), - seq=Uint64(1), + seq=SeqNumber(1), pairs={ "id": b"v4", "secp256k1": bytes.fromhex( @@ -510,7 +510,7 @@ def test_service_accepts_multiple_bootnodes(self, local_enr, local_private_key): for i in range(5): bootnode = ENR( signature=Bytes64(bytes(64)), - seq=Uint64(i + 1), + seq=SeqNumber(i + 1), pairs={ "id": b"v4", "secp256k1": bytes.fromhex( @@ -562,7 +562,7 @@ async def test_handle_ping_sends_pong(self, local_enr, local_private_key, remote private_key=local_private_key, ) - ping = Ping(request_id=RequestId(data=b"\x01\x02"), enr_seq=Uint64(1)) + ping = Ping(request_id=RequestId(data=b"\x01\x02"), enr_seq=SeqNumber(1)) addr = ("192.168.1.1", 30303) with patch.object( @@ -583,7 +583,7 @@ async def test_handle_ping_establishes_bond(self, local_enr, local_private_key, private_key=local_private_key, ) - ping = Ping(request_id=RequestId(data=b"\x01"), enr_seq=Uint64(1)) + ping = Ping(request_id=RequestId(data=b"\x01"), enr_seq=SeqNumber(1)) addr = ("192.168.1.1", 30303) with patch.object(service._transport, "send_response", new=AsyncMock(return_value=True)): @@ -601,7 +601,7 @@ async def test_handle_ping_no_bond_when_send_fails( private_key=local_private_key, ) - ping = Ping(request_id=RequestId(data=b"\x01"), enr_seq=Uint64(1)) + ping = Ping(request_id=RequestId(data=b"\x01"), enr_seq=SeqNumber(1)) addr = ("192.168.1.1", 30303) with patch.object(service._transport, "send_response", new=AsyncMock(return_value=False)): @@ -619,7 +619,7 @@ async def test_handle_ping_pong_includes_recipient_endpoint( private_key=local_private_key, ) - ping = Ping(request_id=RequestId(data=b"\x01"), enr_seq=Uint64(1)) + ping = Ping(request_id=RequestId(data=b"\x01"), enr_seq=SeqNumber(1)) addr = ("10.0.0.5", 9001) with patch.object( @@ -855,7 +855,7 @@ async def test_bootstrap_registers_bootnode_addresses(self, local_enr, local_pri ) bootnode = ENR( signature=Bytes64(bytes(64)), - seq=Uint64(1), + seq=SeqNumber(1), pairs={ "id": b"v4", "secp256k1": node_a_pubkey, @@ -884,7 +884,7 @@ async def test_bootstrap_skips_bootnodes_without_ip(self, local_enr, local_priva ) bootnode = ENR( signature=Bytes64(bytes(64)), - seq=Uint64(1), + seq=SeqNumber(1), pairs={ "id": b"v4", "secp256k1": node_a_pubkey, @@ -929,7 +929,7 @@ def test_enr_with_wrong_distance_is_dropped(self, local_enr, local_private_key): # Create a valid ENR. enr = ENR( signature=Bytes64(bytes(64)), - seq=Uint64(1), + seq=SeqNumber(1), pairs={ "id": b"v4", "secp256k1": bytes.fromhex( diff --git a/tests/lean_spec/subspecs/networking/discovery/test_session.py b/tests/lean_spec/subspecs/networking/discovery/test_session.py index 327e6fa8..bda15a64 100644 --- a/tests/lean_spec/subspecs/networking/discovery/test_session.py +++ b/tests/lean_spec/subspecs/networking/discovery/test_session.py @@ -4,6 +4,7 @@ import pytest +from lean_spec.subspecs.networking.discovery.messages import Port from lean_spec.subspecs.networking.discovery.session import ( BondCache, Session, @@ -207,12 +208,16 @@ def test_endpoint_keying_separates_sessions(self): send_key_2 = bytes([0x02] * 16) # Create sessions for same node at different endpoints. - cache.create(node_id, send_key_1, bytes(16), is_initiator=True, ip="10.0.0.1", port=9000) - cache.create(node_id, send_key_2, bytes(16), is_initiator=True, ip="10.0.0.2", port=9000) + cache.create( + node_id, send_key_1, bytes(16), is_initiator=True, ip="10.0.0.1", port=Port(9000) + ) + cache.create( + node_id, send_key_2, bytes(16), is_initiator=True, ip="10.0.0.2", port=Port(9000) + ) # Each endpoint retrieves its own session. - session_1 = cache.get(node_id, "10.0.0.1", 9000) - session_2 = cache.get(node_id, "10.0.0.2", 9000) + session_1 = cache.get(node_id, "10.0.0.1", Port(9000)) + session_2 = cache.get(node_id, "10.0.0.2", Port(9000)) assert session_1 is not None assert session_2 is not None @@ -220,7 +225,7 @@ def test_endpoint_keying_separates_sessions(self): assert session_2.send_key == send_key_2 # Different port for same IP is also separate. - assert cache.get(node_id, "10.0.0.1", 9001) is None + assert cache.get(node_id, "10.0.0.1", Port(9001)) is None class TestBondCache: diff --git a/tests/lean_spec/subspecs/networking/discovery/test_transport.py b/tests/lean_spec/subspecs/networking/discovery/test_transport.py index 3aa74e3c..95f6a613 100644 --- a/tests/lean_spec/subspecs/networking/discovery/test_transport.py +++ b/tests/lean_spec/subspecs/networking/discovery/test_transport.py @@ -29,8 +29,8 @@ PendingRequest, ) from lean_spec.subspecs.networking.enr import ENR -from lean_spec.subspecs.networking.types import NodeId -from lean_spec.types import Bytes64, Uint64 +from lean_spec.subspecs.networking.types import NodeId, SeqNumber +from lean_spec.types import Bytes64 from lean_spec.types.uint import Uint8 @@ -171,7 +171,7 @@ def test_register_enr(self, local_node_id, local_private_key, local_enr, remote_ remote_enr = ENR( signature=Bytes64(bytes(64)), - seq=Uint64(1), + seq=SeqNumber(1), pairs={"id": b"v4"}, ) @@ -312,7 +312,7 @@ def test_create_pending_request(self): loop = asyncio.new_event_loop() future: asyncio.Future = loop.create_future() - message = Ping(request_id=RequestId(data=b"\x01"), enr_seq=Uint64(1)) + message = Ping(request_id=RequestId(data=b"\x01"), enr_seq=SeqNumber(1)) pending = PendingRequest( request_id=b"\x01\x02\x03\x04", @@ -359,7 +359,7 @@ async def test_send_response_without_session_returns_false( pong = Pong( request_id=RequestId(data=b"\x01"), - enr_seq=Uint64(1), + enr_seq=SeqNumber(1), recipient_ip=IPv4(b"\x7f\x00\x00\x01"), recipient_port=Port(9000), ) @@ -383,7 +383,7 @@ async def test_send_response_without_transport_returns_false( pong = Pong( request_id=RequestId(data=b"\x01"), - enr_seq=Uint64(1), + enr_seq=SeqNumber(1), recipient_ip=IPv4(b"\x7f\x00\x00\x01"), recipient_port=Port(9000), ) @@ -477,9 +477,9 @@ async def test_queue(): ) # Simulate receiving 3 messages. - ping1 = Ping(request_id=RequestId(data=b"\x01"), enr_seq=Uint64(1)) - ping2 = Ping(request_id=RequestId(data=b"\x02"), enr_seq=Uint64(2)) - ping3 = Ping(request_id=RequestId(data=b"\x03"), enr_seq=Uint64(3)) + ping1 = Ping(request_id=RequestId(data=b"\x01"), enr_seq=SeqNumber(1)) + ping2 = Ping(request_id=RequestId(data=b"\x02"), enr_seq=SeqNumber(2)) + ping3 = Ping(request_id=RequestId(data=b"\x03"), enr_seq=SeqNumber(3)) await pending.response_queue.put(ping1) await pending.response_queue.put(ping2) await pending.response_queue.put(ping3) @@ -569,7 +569,7 @@ def test_pending_request_stores_request_id(self): loop = asyncio.new_event_loop() future: asyncio.Future = loop.create_future() - message = Ping(request_id=RequestId(data=b"\x01\x02\x03\x04"), enr_seq=Uint64(1)) + message = Ping(request_id=RequestId(data=b"\x01\x02\x03\x04"), enr_seq=SeqNumber(1)) pending = PendingRequest( request_id=b"\x01\x02\x03\x04", @@ -594,7 +594,7 @@ def test_pending_request_future_completion(self): async def test_future(): future: asyncio.Future = loop.create_future() - message = Ping(request_id=RequestId(data=b"\x01"), enr_seq=Uint64(1)) + message = Ping(request_id=RequestId(data=b"\x01"), enr_seq=SeqNumber(1)) pending = PendingRequest( request_id=b"\x01", dest_node_id=NodeId(bytes(32)), @@ -610,7 +610,7 @@ async def test_future(): # Complete the future with a response. response = Pong( request_id=RequestId(data=b"\x01"), - enr_seq=Uint64(2), + enr_seq=SeqNumber(2), recipient_ip=IPv4(b"\x7f\x00\x00\x01"), recipient_port=Port(9000), ) @@ -630,7 +630,7 @@ def test_pending_request_future_cancellation(self): future: asyncio.Future = loop.create_future() - message = Ping(request_id=RequestId(data=b"\x01"), enr_seq=Uint64(1)) + message = Ping(request_id=RequestId(data=b"\x01"), enr_seq=SeqNumber(1)) pending = PendingRequest( request_id=b"\x01", dest_node_id=NodeId(bytes(32)), @@ -659,8 +659,8 @@ def test_request_id_bytes_for_dict_lookup(self): future1: asyncio.Future = loop.create_future() future2: asyncio.Future = loop.create_future() - message1 = Ping(request_id=RequestId(data=request_id_1), enr_seq=Uint64(1)) - message2 = Ping(request_id=RequestId(data=request_id_2), enr_seq=Uint64(2)) + message1 = Ping(request_id=RequestId(data=request_id_1), enr_seq=SeqNumber(1)) + message2 = Ping(request_id=RequestId(data=request_id_2), enr_seq=SeqNumber(2)) pending1 = PendingRequest( request_id=request_id_1, @@ -701,7 +701,7 @@ def test_pending_request_stores_nonce_for_whoareyou_matching(self): nonce = b"\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c" - message = Ping(request_id=RequestId(data=b"\x01"), enr_seq=Uint64(1)) + message = Ping(request_id=RequestId(data=b"\x01"), enr_seq=SeqNumber(1)) pending = PendingRequest( request_id=b"\x01", dest_node_id=NodeId(bytes(32)), @@ -722,7 +722,7 @@ def test_pending_request_stores_message_for_retransmission(self): loop = asyncio.new_event_loop() future: asyncio.Future = loop.create_future() - message = Ping(request_id=RequestId(data=b"\x01"), enr_seq=Uint64(42)) + message = Ping(request_id=RequestId(data=b"\x01"), enr_seq=SeqNumber(42)) pending = PendingRequest( request_id=b"\x01", dest_node_id=NodeId(bytes(32)), @@ -1100,7 +1100,7 @@ async def test_response_completes_pending_request_future( pong = Pong( request_id=RequestId(data=request_id), - enr_seq=Uint64(1), + enr_seq=SeqNumber(1), recipient_ip=IPv4(b"\x7f\x00\x00\x01"), recipient_port=Port(9000), ) @@ -1163,7 +1163,7 @@ async def test_unmatched_message_dispatched_to_handler( ping = Ping( request_id=RequestId(data=b"\xff\xff"), - enr_seq=Uint64(1), + enr_seq=SeqNumber(1), ) await transport._handle_decoded_message(remote_node_id, ping, ("192.168.1.1", 30303)) @@ -1183,7 +1183,7 @@ async def test_unmatched_message_without_handler_is_silent( ping = Ping( request_id=RequestId(data=b"\xff\xff"), - enr_seq=Uint64(1), + enr_seq=SeqNumber(1), ) # Should not raise. @@ -1203,11 +1203,11 @@ async def test_decoded_message_touches_session( with patch.object(transport._session_cache, "touch") as mock_touch: ping = Ping( request_id=RequestId(data=b"\xff"), - enr_seq=Uint64(1), + enr_seq=SeqNumber(1), ) await transport._handle_decoded_message(remote_node_id, ping, ("192.168.1.1", 30303)) - mock_touch.assert_called_once_with(remote_node_id, "192.168.1.1", 30303) + mock_touch.assert_called_once_with(remote_node_id, "192.168.1.1", Port(30303)) class TestHandlePacketDispatch: @@ -1333,7 +1333,7 @@ async def test_send_whoareyou_uses_cached_enr_seq( # Register a remote ENR with seq=42. remote_enr = ENR( signature=Bytes64(bytes(64)), - seq=Uint64(42), + seq=SeqNumber(42), pairs={"id": b"v4"}, ) transport.register_enr(remote_node_id, remote_enr) @@ -1357,6 +1357,6 @@ async def test_send_whoareyou_uses_cached_enr_seq( # Verify enr_seq=42 was passed, not 0. call_kwargs = mock_create.call_args - assert call_kwargs[1]["remote_enr_seq"] == 42 + assert call_kwargs[1]["remote_enr_seq"] == SeqNumber(42) await transport.stop() diff --git a/tests/lean_spec/subspecs/networking/discovery/test_vectors.py b/tests/lean_spec/subspecs/networking/discovery/test_vectors.py index eb584773..6b9c9a20 100644 --- a/tests/lean_spec/subspecs/networking/discovery/test_vectors.py +++ b/tests/lean_spec/subspecs/networking/discovery/test_vectors.py @@ -53,8 +53,8 @@ encode_whoareyou_authdata, ) from lean_spec.subspecs.networking.discovery.routing import log2_distance, xor_distance -from lean_spec.subspecs.networking.types import NodeId -from lean_spec.types import Bytes12, Bytes16, Bytes32, Bytes33, Bytes64, Uint64 +from lean_spec.subspecs.networking.types import NodeId, SeqNumber +from lean_spec.types import Bytes12, Bytes16, Bytes32, Bytes33, Bytes64 from lean_spec.types.uint import Uint8 from tests.lean_spec.helpers import make_challenge_data from tests.lean_spec.subspecs.networking.discovery.conftest import ( @@ -428,7 +428,7 @@ def test_whoareyou_packet_roundtrip(self): """WHOAREYOU packet encodes and decodes correctly.""" nonce = bytes.fromhex("0102030405060708090a0b0c") id_nonce = bytes.fromhex("0102030405060708090a0b0c0d0e0f10") - enr_seq = 0 + enr_seq = SeqNumber(0) authdata = encode_whoareyou_authdata(id_nonce, enr_seq) @@ -449,7 +449,7 @@ def test_whoareyou_packet_roundtrip(self): decoded_authdata = decode_whoareyou_authdata(header.authdata) assert bytes(decoded_authdata.id_nonce) == id_nonce - assert int(decoded_authdata.enr_seq) == enr_seq + assert decoded_authdata.enr_seq == enr_seq def test_handshake_packet_roundtrip(self): """HANDSHAKE packet encodes and decodes correctly.""" @@ -506,7 +506,7 @@ def test_official_ping_message_rlp_encoding(self): # PING with request ID [0x00, 0x00, 0x00, 0x01] and enr_seq = 1 ping = Ping( request_id=RequestId(data=b"\x00\x00\x00\x01"), - enr_seq=Uint64(1), + enr_seq=SeqNumber(1), ) encoded = encode_message(ping) @@ -526,7 +526,7 @@ def test_official_pong_message_rlp_encoding(self): """ pong = Pong( request_id=RequestId(data=b"\x00\x00\x00\x01"), - enr_seq=Uint64(1), + enr_seq=SeqNumber(1), recipient_ip=IPv4(b"\x7f\x00\x00\x01"), # 127.0.0.1 recipient_port=Port(30303), ) @@ -615,7 +615,7 @@ def test_whoareyou_packet_header_structure(self): """ nonce = bytes(12) id_nonce = bytes(16) - enr_seq = 0 + enr_seq = SeqNumber(0) authdata = encode_whoareyou_authdata(id_nonce, enr_seq) diff --git a/tests/lean_spec/subspecs/networking/enr/test_enr.py b/tests/lean_spec/subspecs/networking/enr/test_enr.py index c59ceccb..e45470f9 100644 --- a/tests/lean_spec/subspecs/networking/enr/test_enr.py +++ b/tests/lean_spec/subspecs/networking/enr/test_enr.py @@ -22,6 +22,7 @@ from lean_spec.subspecs.networking.enr import ENR, keys from lean_spec.subspecs.networking.enr.enr import ENR_PREFIX +from lean_spec.subspecs.networking.types import SeqNumber from lean_spec.types import Bytes64, SSZValueError, Uint64 from lean_spec.types.byte_arrays import Bytes4 from lean_spec.types.rlp import encode_rlp @@ -261,7 +262,7 @@ def test_identity_scheme_returns_v4(self) -> None: """identity_scheme property returns 'v4' for valid ENR.""" enr = ENR( signature=Bytes64(b"\x00" * 64), - seq=Uint64(1), + seq=SeqNumber(1), pairs={keys.ID: b"v4", keys.SECP256K1: b"\x02" + b"\x00" * 32}, ) assert enr.identity_scheme == "v4" @@ -270,7 +271,7 @@ def test_identity_scheme_returns_none_when_missing(self) -> None: """identity_scheme returns None when 'id' key is absent.""" enr = ENR( signature=Bytes64(b"\x00" * 64), - seq=Uint64(1), + seq=SeqNumber(1), pairs={keys.SECP256K1: b"\x02" + b"\x00" * 32}, ) assert enr.identity_scheme is None @@ -280,7 +281,7 @@ def test_public_key_returns_33_bytes(self) -> None: expected_key = b"\x03" + b"\xab" * 32 enr = ENR( signature=Bytes64(b"\x00" * 64), - seq=Uint64(1), + seq=SeqNumber(1), pairs={keys.ID: b"v4", keys.SECP256K1: expected_key}, ) public_key = enr.public_key @@ -292,7 +293,7 @@ def test_public_key_returns_none_when_missing(self) -> None: """public_key returns None when secp256k1 key is absent.""" enr = ENR( signature=Bytes64(b"\x00" * 64), - seq=Uint64(1), + seq=SeqNumber(1), pairs={keys.ID: b"v4"}, ) assert enr.public_key is None @@ -301,7 +302,7 @@ def test_ip4_formats_address_correctly(self) -> None: """ip4 property formats IPv4 address as dotted string.""" enr = ENR( signature=Bytes64(b"\x00" * 64), - seq=Uint64(1), + seq=SeqNumber(1), pairs={ keys.ID: b"v4", keys.SECP256K1: b"\x02" + b"\x00" * 32, @@ -321,7 +322,7 @@ def test_ip4_various_addresses(self) -> None: for ip_bytes, expected in test_cases: enr = ENR( signature=Bytes64(b"\x00" * 64), - seq=Uint64(1), + seq=SeqNumber(1), pairs={keys.ID: b"v4", keys.IP: ip_bytes}, ) assert enr.ip4 == expected @@ -330,7 +331,7 @@ def test_ip4_returns_none_when_missing(self) -> None: """ip4 returns None when 'ip' key is absent.""" enr = ENR( signature=Bytes64(b"\x00" * 64), - seq=Uint64(1), + seq=SeqNumber(1), pairs={keys.ID: b"v4"}, ) assert enr.ip4 is None @@ -339,7 +340,7 @@ def test_ip4_returns_none_for_wrong_length(self) -> None: """ip4 returns None when IP bytes are not 4 bytes.""" enr = ENR( signature=Bytes64(b"\x00" * 64), - seq=Uint64(1), + seq=SeqNumber(1), pairs={keys.ID: b"v4", keys.IP: b"\x7f\x00\x00"}, # Only 3 bytes ) assert enr.ip4 is None @@ -350,7 +351,7 @@ def test_ip6_formats_address_correctly(self) -> None: ipv6_bytes = b"\x00" * 15 + b"\x01" enr = ENR( signature=Bytes64(b"\x00" * 64), - seq=Uint64(1), + seq=SeqNumber(1), pairs={keys.ID: b"v4", keys.IP6: ipv6_bytes}, ) assert enr.ip6 == "0000:0000:0000:0000:0000:0000:0000:0001" @@ -359,7 +360,7 @@ def test_ip6_returns_none_when_missing(self) -> None: """ip6 returns None when 'ip6' key is absent.""" enr = ENR( signature=Bytes64(b"\x00" * 64), - seq=Uint64(1), + seq=SeqNumber(1), pairs={keys.ID: b"v4"}, ) assert enr.ip6 is None @@ -368,7 +369,7 @@ def test_ip6_returns_none_for_wrong_length(self) -> None: """ip6 returns None when IP bytes are not 16 bytes.""" enr = ENR( signature=Bytes64(b"\x00" * 64), - seq=Uint64(1), + seq=SeqNumber(1), pairs={keys.ID: b"v4", keys.IP6: b"\x00" * 8}, # Only 8 bytes ) assert enr.ip6 is None @@ -377,7 +378,7 @@ def test_udp_port_extracts_correctly(self) -> None: """udp_port extracts port number from big-endian bytes.""" enr = ENR( signature=Bytes64(b"\x00" * 64), - seq=Uint64(1), + seq=SeqNumber(1), pairs={keys.ID: b"v4", keys.UDP: (30303).to_bytes(2, "big")}, ) assert enr.udp_port == 30303 @@ -393,7 +394,7 @@ def test_udp_port_various_values(self) -> None: for port_bytes, expected in test_cases: enr = ENR( signature=Bytes64(b"\x00" * 64), - seq=Uint64(1), + seq=SeqNumber(1), pairs={keys.ID: b"v4", keys.UDP: port_bytes}, ) assert enr.udp_port == expected @@ -402,7 +403,7 @@ def test_udp_port_returns_none_when_missing(self) -> None: """udp_port returns None when 'udp' key is absent.""" enr = ENR( signature=Bytes64(b"\x00" * 64), - seq=Uint64(1), + seq=SeqNumber(1), pairs={keys.ID: b"v4"}, ) assert enr.udp_port is None @@ -415,7 +416,7 @@ def test_is_valid_returns_true_for_complete_v4_enr(self) -> None: """is_valid() returns True for complete v4 ENR.""" enr = ENR( signature=Bytes64(b"\x00" * 64), - seq=Uint64(1), + seq=SeqNumber(1), pairs={keys.ID: b"v4", keys.SECP256K1: b"\x02" + b"\x00" * 32}, ) assert enr.is_valid() @@ -424,7 +425,7 @@ def test_is_valid_returns_false_for_missing_public_key(self) -> None: """is_valid() returns False when secp256k1 key is missing.""" enr = ENR( signature=Bytes64(b"\x00" * 64), - seq=Uint64(1), + seq=SeqNumber(1), pairs={keys.ID: b"v4"}, ) assert not enr.is_valid() @@ -433,7 +434,7 @@ def test_is_valid_returns_false_for_wrong_identity_scheme(self) -> None: """is_valid() returns False for non-v4 identity scheme.""" enr = ENR( signature=Bytes64(b"\x00" * 64), - seq=Uint64(1), + seq=SeqNumber(1), pairs={keys.ID: b"v5", keys.SECP256K1: b"\x02" + b"\x00" * 32}, ) assert not enr.is_valid() @@ -442,7 +443,7 @@ def test_is_valid_returns_false_for_missing_identity_scheme(self) -> None: """is_valid() returns False when 'id' key is missing.""" enr = ENR( signature=Bytes64(b"\x00" * 64), - seq=Uint64(1), + seq=SeqNumber(1), pairs={keys.SECP256K1: b"\x02" + b"\x00" * 32}, ) assert not enr.is_valid() @@ -452,7 +453,7 @@ def test_is_valid_returns_false_for_wrong_pubkey_length(self) -> None: # 32 bytes (uncompressed prefix missing) enr = ENR( signature=Bytes64(b"\x00" * 64), - seq=Uint64(1), + seq=SeqNumber(1), pairs={keys.ID: b"v4", keys.SECP256K1: b"\x00" * 32}, ) assert not enr.is_valid() @@ -460,7 +461,7 @@ def test_is_valid_returns_false_for_wrong_pubkey_length(self) -> None: # 65 bytes (uncompressed format) enr = ENR( signature=Bytes64(b"\x00" * 64), - seq=Uint64(1), + seq=SeqNumber(1), pairs={keys.ID: b"v4", keys.SECP256K1: b"\x04" + b"\x00" * 64}, ) assert not enr.is_valid() @@ -471,7 +472,7 @@ def test_construction_fails_for_wrong_signature_length(self) -> None: with pytest.raises(SSZValueError, match="requires exactly 64 bytes"): ENR( signature=Bytes64(b"\x00" * 63), - seq=Uint64(1), + seq=SeqNumber(1), pairs={keys.ID: b"v4", keys.SECP256K1: b"\x02" + b"\x00" * 32}, ) @@ -483,7 +484,7 @@ def test_multiaddr_with_ipv4_and_udp(self) -> None: """multiaddr() generates QUIC format with IPv4 and UDP.""" enr = ENR( signature=Bytes64(b"\x00" * 64), - seq=Uint64(1), + seq=SeqNumber(1), pairs={ keys.ID: b"v4", keys.IP: b"\xc0\xa8\x01\x01", # 192.168.1.1 @@ -497,7 +498,7 @@ def test_multiaddr_with_ipv6_and_udp(self) -> None: ipv6_bytes = b"\x00" * 15 + b"\x01" # ::1 enr = ENR( signature=Bytes64(b"\x00" * 64), - seq=Uint64(1), + seq=SeqNumber(1), pairs={ keys.ID: b"v4", keys.IP6: ipv6_bytes, @@ -510,7 +511,7 @@ def test_multiaddr_returns_none_without_udp(self) -> None: """multiaddr() returns None when UDP port is absent.""" enr = ENR( signature=Bytes64(b"\x00" * 64), - seq=Uint64(1), + seq=SeqNumber(1), pairs={ keys.ID: b"v4", keys.IP: b"\xc0\xa8\x01\x01", @@ -522,7 +523,7 @@ def test_multiaddr_returns_none_without_ip(self) -> None: """multiaddr() returns None when no IP address is present.""" enr = ENR( signature=Bytes64(b"\x00" * 64), - seq=Uint64(1), + seq=SeqNumber(1), pairs={keys.ID: b"v4", keys.UDP: (9000).to_bytes(2, "big")}, ) assert enr.multiaddr() is None @@ -531,7 +532,7 @@ def test_multiaddr_prefers_ipv4_over_ipv6(self) -> None: """multiaddr() uses IPv4 when both IPv4 and IPv6 are present.""" enr = ENR( signature=Bytes64(b"\x00" * 64), - seq=Uint64(1), + seq=SeqNumber(1), pairs={ keys.ID: b"v4", keys.IP: b"\xc0\xa8\x01\x01", # 192.168.1.1 @@ -549,7 +550,7 @@ def test_str_includes_seq(self) -> None: """__str__() includes sequence number.""" enr = ENR( signature=Bytes64(b"\x00" * 64), - seq=Uint64(42), + seq=SeqNumber(42), pairs={keys.ID: b"v4"}, ) result = str(enr) @@ -559,7 +560,7 @@ def test_str_includes_ip(self) -> None: """__str__() includes IP address when present.""" enr = ENR( signature=Bytes64(b"\x00" * 64), - seq=Uint64(1), + seq=SeqNumber(1), pairs={keys.ID: b"v4", keys.IP: b"\xc0\xa8\x01\x01"}, ) result = str(enr) @@ -569,7 +570,7 @@ def test_str_includes_udp_port(self) -> None: """__str__() includes UDP port when present.""" enr = ENR( signature=Bytes64(b"\x00" * 64), - seq=Uint64(1), + seq=SeqNumber(1), pairs={keys.ID: b"v4", keys.UDP: (30303).to_bytes(2, "big")}, ) result = str(enr) @@ -579,7 +580,7 @@ def test_str_minimal_enr(self) -> None: """__str__() works for minimal ENR.""" enr = ENR( signature=Bytes64(b"\x00" * 64), - seq=Uint64(1), + seq=SeqNumber(1), pairs={}, ) result = str(enr) @@ -595,7 +596,7 @@ def test_get_existing_key(self) -> None: """get() returns value for existing key.""" enr = ENR( signature=Bytes64(b"\x00" * 64), - seq=Uint64(1), + seq=SeqNumber(1), pairs={keys.ID: b"v4"}, ) assert enr.get(keys.ID) == b"v4" @@ -604,7 +605,7 @@ def test_get_missing_key(self) -> None: """get() returns None for missing key.""" enr = ENR( signature=Bytes64(b"\x00" * 64), - seq=Uint64(1), + seq=SeqNumber(1), pairs={keys.ID: b"v4"}, ) assert enr.get(keys.IP) is None @@ -613,7 +614,7 @@ def test_has_existing_key(self) -> None: """has() returns True for existing key.""" enr = ENR( signature=Bytes64(b"\x00" * 64), - seq=Uint64(1), + seq=SeqNumber(1), pairs={keys.ID: b"v4", keys.IP: b"\x7f\x00\x00\x01"}, ) assert enr.has(keys.ID) @@ -623,7 +624,7 @@ def test_has_missing_key(self) -> None: """has() returns False for missing key.""" enr = ENR( signature=Bytes64(b"\x00" * 64), - seq=Uint64(1), + seq=SeqNumber(1), pairs={keys.ID: b"v4"}, ) assert not enr.has(keys.IP) @@ -765,7 +766,7 @@ def test_eth2_data_parses_from_enr(self) -> None: eth2_bytes = b"\x12\x34\x56\x78" + b"\x02\x00\x00\x00" + b"\x00\x00\x00\x00\x00\x00\x00\x01" enr = ENR( signature=Bytes64(b"\x00" * 64), - seq=Uint64(1), + seq=SeqNumber(1), pairs={keys.ID: b"v4", keys.ETH2: eth2_bytes}, ) @@ -780,7 +781,7 @@ def test_eth2_data_returns_none_when_missing(self) -> None: """eth2_data returns None when eth2 key is absent.""" enr = ENR( signature=Bytes64(b"\x00" * 64), - seq=Uint64(1), + seq=SeqNumber(1), pairs={keys.ID: b"v4"}, ) assert enr.eth2_data is None @@ -789,7 +790,7 @@ def test_eth2_data_returns_none_for_short_data(self) -> None: """eth2_data returns None when eth2 key is too short.""" enr = ENR( signature=Bytes64(b"\x00" * 64), - seq=Uint64(1), + seq=SeqNumber(1), pairs={keys.ID: b"v4", keys.ETH2: b"\x12\x34\x56\x78"}, # Only 4 bytes ) assert enr.eth2_data is None @@ -804,7 +805,7 @@ def test_attestation_subnets_parses_from_enr(self) -> None: attnets_bytes = b"\xff" * 8 enr = ENR( signature=Bytes64(b"\x00" * 64), - seq=Uint64(1), + seq=SeqNumber(1), pairs={keys.ID: b"v4", keys.ATTNETS: attnets_bytes}, ) @@ -816,7 +817,7 @@ def test_attestation_subnets_returns_none_when_missing(self) -> None: """attestation_subnets returns None when attnets key is absent.""" enr = ENR( signature=Bytes64(b"\x00" * 64), - seq=Uint64(1), + seq=SeqNumber(1), pairs={keys.ID: b"v4"}, ) assert enr.attestation_subnets is None @@ -825,7 +826,7 @@ def test_attestation_subnets_returns_none_for_wrong_length(self) -> None: """attestation_subnets returns None when attnets key is not 8 bytes.""" enr = ENR( signature=Bytes64(b"\x00" * 64), - seq=Uint64(1), + seq=SeqNumber(1), pairs={keys.ID: b"v4", keys.ATTNETS: b"\xff\xff\xff\xff"}, # Only 4 bytes ) assert enr.attestation_subnets is None @@ -840,7 +841,7 @@ def test_sync_committee_subnets_parses_from_enr(self) -> None: syncnets_bytes = b"\x0f" enr = ENR( signature=Bytes64(b"\x00" * 64), - seq=Uint64(1), + seq=SeqNumber(1), pairs={keys.ID: b"v4", keys.SYNCNETS: syncnets_bytes}, ) @@ -853,7 +854,7 @@ def test_sync_committee_subnets_returns_none_when_missing(self) -> None: """sync_committee_subnets returns None when syncnets key is absent.""" enr = ENR( signature=Bytes64(b"\x00" * 64), - seq=Uint64(1), + seq=SeqNumber(1), pairs={keys.ID: b"v4"}, ) assert enr.sync_committee_subnets is None @@ -862,7 +863,7 @@ def test_sync_committee_subnets_returns_none_for_wrong_length(self) -> None: """sync_committee_subnets returns None when syncnets key is not 1 byte.""" enr = ENR( signature=Bytes64(b"\x00" * 64), - seq=Uint64(1), + seq=SeqNumber(1), pairs={keys.ID: b"v4", keys.SYNCNETS: b"\x0f\x00"}, # 2 bytes ) assert enr.sync_committee_subnets is None @@ -877,12 +878,12 @@ def test_compatible_with_same_fork_digest(self) -> None: enr1 = ENR( signature=Bytes64(b"\x00" * 64), - seq=Uint64(1), + seq=SeqNumber(1), pairs={keys.ID: b"v4", keys.ETH2: eth2_bytes}, ) enr2 = ENR( signature=Bytes64(b"\x00" * 64), - seq=Uint64(2), + seq=SeqNumber(2), pairs={keys.ID: b"v4", keys.ETH2: eth2_bytes}, ) @@ -895,12 +896,12 @@ def test_incompatible_with_different_fork_digest(self) -> None: enr1 = ENR( signature=Bytes64(b"\x00" * 64), - seq=Uint64(1), + seq=SeqNumber(1), pairs={keys.ID: b"v4", keys.ETH2: eth2_bytes1}, ) enr2 = ENR( signature=Bytes64(b"\x00" * 64), - seq=Uint64(2), + seq=SeqNumber(2), pairs={keys.ID: b"v4", keys.ETH2: eth2_bytes2}, ) @@ -912,12 +913,12 @@ def test_incompatible_when_self_missing_eth2(self) -> None: enr1 = ENR( signature=Bytes64(b"\x00" * 64), - seq=Uint64(1), + seq=SeqNumber(1), pairs={keys.ID: b"v4"}, # No eth2 ) enr2 = ENR( signature=Bytes64(b"\x00" * 64), - seq=Uint64(2), + seq=SeqNumber(2), pairs={keys.ID: b"v4", keys.ETH2: eth2_bytes}, ) @@ -929,12 +930,12 @@ def test_incompatible_when_other_missing_eth2(self) -> None: enr1 = ENR( signature=Bytes64(b"\x00" * 64), - seq=Uint64(1), + seq=SeqNumber(1), pairs={keys.ID: b"v4", keys.ETH2: eth2_bytes}, ) enr2 = ENR( signature=Bytes64(b"\x00" * 64), - seq=Uint64(2), + seq=SeqNumber(2), pairs={keys.ID: b"v4"}, # No eth2 ) @@ -944,12 +945,12 @@ def test_incompatible_when_both_missing_eth2(self) -> None: """ENRs are incompatible when both lack eth2 key.""" enr1 = ENR( signature=Bytes64(b"\x00" * 64), - seq=Uint64(1), + seq=SeqNumber(1), pairs={keys.ID: b"v4"}, ) enr2 = ENR( signature=Bytes64(b"\x00" * 64), - seq=Uint64(2), + seq=SeqNumber(2), pairs={keys.ID: b"v4"}, ) @@ -1188,7 +1189,7 @@ def test_to_string_produces_valid_enr_format(self) -> None: """to_string() produces valid 'enr:' prefixed string.""" enr = ENR( signature=Bytes64(b"\x00" * 64), - seq=Uint64(1), + seq=SeqNumber(1), pairs={keys.ID: b"v4", keys.SECP256K1: b"\x02" + b"\x00" * 32}, ) result = enr.to_string() @@ -1241,7 +1242,7 @@ def test_self_signed_enr_verifies(self) -> None: # Create ENR. enr = ENR( signature=Bytes64(sig_64), - seq=Uint64(1), + seq=SeqNumber(1), pairs={keys.ID: b"v4", keys.SECP256K1: compressed_pubkey}, ) @@ -1268,7 +1269,7 @@ def test_tampered_content_fails_verification(self) -> None: # Create ENR with different sequence number (content mismatch) tampered_enr = ENR( signature=enr.signature, - seq=Uint64(int(enr.seq) + 1), # Different sequence + seq=SeqNumber(int(enr.seq) + 1), # Different sequence pairs=enr.pairs, ) @@ -1278,7 +1279,7 @@ def test_missing_public_key_fails_verification(self) -> None: """ENR without public key fails verification.""" enr = ENR( signature=Bytes64(b"\x00" * 64), - seq=Uint64(1), + seq=SeqNumber(1), pairs={keys.ID: b"v4"}, # No secp256k1 key ) @@ -1300,7 +1301,7 @@ def test_node_id_none_without_public_key(self) -> None: """compute_node_id() returns None when public key is missing.""" enr = ENR( signature=Bytes64(b"\x00" * 64), - seq=Uint64(1), + seq=SeqNumber(1), pairs={keys.ID: b"v4"}, ) @@ -1314,7 +1315,7 @@ def test_udp6_port_extracts_correctly(self) -> None: """udp6_port extracts IPv6-specific UDP port.""" enr = ENR( signature=Bytes64(b"\x00" * 64), - seq=Uint64(1), + seq=SeqNumber(1), pairs={ keys.ID: b"v4", keys.UDP6: (30304).to_bytes(2, "big"), @@ -1326,7 +1327,7 @@ def test_udp6_port_returns_none_when_missing(self) -> None: """udp6_port returns None when udp6 key is absent.""" enr = ENR( signature=Bytes64(b"\x00" * 64), - seq=Uint64(1), + seq=SeqNumber(1), pairs={keys.ID: b"v4"}, ) assert enr.udp6_port is None @@ -1335,7 +1336,7 @@ def test_ipv6_udp_port_independent_of_ipv4(self) -> None: """IPv6 UDP port is independent from IPv4 UDP port.""" enr = ENR( signature=Bytes64(b"\x00" * 64), - seq=Uint64(1), + seq=SeqNumber(1), pairs={ keys.ID: b"v4", keys.UDP: (30303).to_bytes(2, "big"), diff --git a/tests/lean_spec/subspecs/networking/enr/test_eth2.py b/tests/lean_spec/subspecs/networking/enr/test_eth2.py index b04737b3..fb437a79 100644 --- a/tests/lean_spec/subspecs/networking/enr/test_eth2.py +++ b/tests/lean_spec/subspecs/networking/enr/test_eth2.py @@ -3,6 +3,7 @@ import pytest from pydantic import ValidationError +from lean_spec.subspecs.containers.validator import SubnetId from lean_spec.subspecs.networking.enr import Eth2Data from lean_spec.subspecs.networking.enr.eth2 import ( FAR_FUTURE_EPOCH, @@ -82,7 +83,7 @@ def test_subscribed_subnets_list(self) -> None: subnets = AttestationSubnets.from_subnet_ids([10, 20, 30]) result = subnets.subscribed_subnets() - assert result == [10, 20, 30] + assert result == [SubnetId(10), SubnetId(20), SubnetId(30)] def test_invalid_subnet_id_in_from_subnet_ids(self) -> None: """from_subnet_ids() raises for invalid subnet IDs.""" @@ -112,7 +113,7 @@ def test_from_subnet_ids_with_duplicates(self) -> None: """from_subnet_ids handles duplicates correctly.""" subnets = AttestationSubnets.from_subnet_ids([5, 5, 5, 10]) assert subnets.subscription_count() == 2 - assert subnets.subscribed_subnets() == [5, 10] + assert subnets.subscribed_subnets() == [SubnetId(5), SubnetId(10)] def test_encode_bytes_empty(self) -> None: """Empty subscriptions serialize to 8 zero bytes.""" diff --git a/tests/lean_spec/subspecs/networking/gossipsub/test_gossipsub.py b/tests/lean_spec/subspecs/networking/gossipsub/test_gossipsub.py index 18184710..fa19d20a 100644 --- a/tests/lean_spec/subspecs/networking/gossipsub/test_gossipsub.py +++ b/tests/lean_spec/subspecs/networking/gossipsub/test_gossipsub.py @@ -2,6 +2,7 @@ import pytest +from lean_spec.subspecs.containers.validator import SubnetId from lean_spec.subspecs.networking import PeerId from lean_spec.subspecs.networking.client.event_source import GossipHandler from lean_spec.subspecs.networking.gossipsub import ( @@ -140,8 +141,8 @@ def test_gossip_topic_factory_methods(self) -> None: assert GossipTopic.block("0xabcd1234") == GossipTopic( kind=TopicKind.BLOCK, fork_digest="0xabcd1234" ) - assert GossipTopic.attestation_subnet("0xabcd1234", 0) == GossipTopic( - kind=TopicKind.ATTESTATION_SUBNET, fork_digest="0xabcd1234", subnet_id=0 + assert GossipTopic.attestation_subnet("0xabcd1234", SubnetId(0)) == GossipTopic( + kind=TopicKind.ATTESTATION_SUBNET, fork_digest="0xabcd1234", subnet_id=SubnetId(0) ) def test_format_topic_string(self) -> None: diff --git a/tests/lean_spec/subspecs/networking/test_peer.py b/tests/lean_spec/subspecs/networking/test_peer.py index 83131195..2edf5ea4 100644 --- a/tests/lean_spec/subspecs/networking/test_peer.py +++ b/tests/lean_spec/subspecs/networking/test_peer.py @@ -9,8 +9,8 @@ from lean_spec.subspecs.networking.enr.eth2 import FAR_FUTURE_EPOCH from lean_spec.subspecs.networking.peer import PeerInfo from lean_spec.subspecs.networking.reqresp import Status -from lean_spec.subspecs.networking.types import ConnectionState, Direction, GoodbyeReason -from lean_spec.types import Bytes32, Bytes64, Uint64 +from lean_spec.subspecs.networking.types import ConnectionState, Direction, GoodbyeReason, SeqNumber +from lean_spec.types import Bytes32, Bytes64 def peer(name: str) -> PeerId: @@ -111,7 +111,7 @@ def _make_enr_with_eth2(self, fork_digest_bytes: bytes) -> ENR: ) return ENR( signature=Bytes64(b"\x00" * 64), - seq=Uint64(1), + seq=SeqNumber(1), pairs={"eth2": eth2_bytes, "id": b"v4"}, ) @@ -125,7 +125,7 @@ def test_fork_digest_none_without_eth2(self) -> None: # ENR without eth2 key enr = ENR( signature=Bytes64(b"\x00" * 64), - seq=Uint64(1), + seq=SeqNumber(1), pairs={"id": b"v4"}, ) info = PeerInfo(peer_id=peer("test"), enr=enr) diff --git a/tests/lean_spec/subspecs/validator/test_registry.py b/tests/lean_spec/subspecs/validator/test_registry.py index 1967fb88..e0b8ac06 100644 --- a/tests/lean_spec/subspecs/validator/test_registry.py +++ b/tests/lean_spec/subspecs/validator/test_registry.py @@ -92,7 +92,9 @@ def test_from_secret_keys(self) -> None: key_0 = MagicMock(name="key_0") key_2 = MagicMock(name="key_2") - registry = ValidatorRegistry.from_secret_keys({0: key_0, 2: key_2}) + registry = ValidatorRegistry.from_secret_keys( + {ValidatorIndex(0): key_0, ValidatorIndex(2): key_2} + ) assert registry_state(registry) == {ValidatorIndex(0): key_0, ValidatorIndex(2): key_2} diff --git a/tests/lean_spec/subspecs/validator/test_service.py b/tests/lean_spec/subspecs/validator/test_service.py index 5ae5bcfc..b9c0d394 100644 --- a/tests/lean_spec/subspecs/validator/test_service.py +++ b/tests/lean_spec/subspecs/validator/test_service.py @@ -536,7 +536,7 @@ async def capture_block(block: SignedBlockWithAttestation) -> None: is_valid = TARGET_SIGNATURE_SCHEME.verify( pk=proposer_public_key, - epoch=signed_block.message.block.slot, + slot=signed_block.message.block.slot, message=message_bytes, sig=signed_block.signature.proposer_signature, ) @@ -580,7 +580,7 @@ async def capture_attestation(attestation: SignedAttestation) -> None: is_valid = TARGET_SIGNATURE_SCHEME.verify( pk=public_key, - epoch=signed_att.message.slot, + slot=signed_att.message.slot, message=message_bytes, sig=signed_att.signature, ) @@ -676,7 +676,7 @@ async def capture_block(block: SignedBlockWithAttestation) -> None: is_valid = TARGET_SIGNATURE_SCHEME.verify( pk=public_key, - epoch=signed_block.message.block.slot, + slot=signed_block.message.block.slot, message=message_bytes, sig=signed_block.signature.proposer_signature, ) @@ -715,7 +715,7 @@ async def test_block_includes_pending_attestations( public_keys=public_keys, signatures=signatures, message=data_root, - epoch=attestation_data.slot, + slot=attestation_data.slot, ) aggregated_payloads = {SignatureKey(vid, data_root): [proof] for vid in participants} @@ -889,17 +889,16 @@ async def capture_block(block: SignedBlockWithAttestation) -> None: computed_state_root = hash_tree_root(stored_state) assert produced_block.state_root == computed_state_root - async def test_signature_uses_correct_slot_as_epoch( + async def test_signature_uses_correct_slot( self, key_manager: XmssKeyManager, real_sync_service: SyncService, real_registry: ValidatorRegistry, ) -> None: """ - Verify signatures use the correct slot as the XMSS epoch parameter. + Verify signatures use the correct slot as the XMSS slot parameter. - XMSS is stateful and uses epochs for one-time signature keys. - The slot number serves as the epoch in the lean protocol. + XMSS is stateful and uses slots for one-time signature keys. """ clock = SlotClock(genesis_time=Uint64(0)) attestations_produced: list[SignedAttestation] = [] @@ -918,27 +917,27 @@ async def capture_attestation(attestation: SignedAttestation) -> None: await service._produce_attestations(test_slot) - # Verify each signature was created with the correct epoch (slot) + # Verify each signature was created with the correct slot for signed_att in attestations_produced: validator_id = signed_att.validator_id public_key = key_manager.get_public_key(validator_id) message_bytes = signed_att.message.data_root_bytes() - # Verification must use the same epoch that was used for signing + # Verification must use the same slot that was used for signing is_valid = TARGET_SIGNATURE_SCHEME.verify( pk=public_key, - epoch=test_slot, # Must match the signing slot + slot=test_slot, # Must match the signing slot message=message_bytes, sig=signed_att.signature, ) assert is_valid, f"Signature for validator {validator_id} at slot {test_slot} failed" - # Verify with wrong epoch should fail - wrong_epoch = test_slot + Slot(1) + # Verify with wrong slot should fail + wrong_slot = test_slot + Slot(1) is_invalid = TARGET_SIGNATURE_SCHEME.verify( pk=public_key, - epoch=wrong_epoch, + slot=wrong_slot, message=message_bytes, sig=signed_att.signature, ) - assert not is_invalid, "Signature should fail with wrong epoch" + assert not is_invalid, "Signature should fail with wrong slot" diff --git a/tests/lean_spec/subspecs/xmss/test_interface.py b/tests/lean_spec/subspecs/xmss/test_interface.py index dd4efd70..c1efa105 100644 --- a/tests/lean_spec/subspecs/xmss/test_interface.py +++ b/tests/lean_spec/subspecs/xmss/test_interface.py @@ -4,6 +4,7 @@ import pytest +from lean_spec.subspecs.containers.slot import Slot from lean_spec.subspecs.xmss.interface import ( TEST_SIGNATURE_SCHEME, GeneralizedXmssScheme, @@ -13,34 +14,34 @@ def _test_correctness_roundtrip( scheme: GeneralizedXmssScheme, - activation_epoch: int, - num_active_epochs: int, + activation_slot: int, + num_active_slots: int, ) -> None: """ A helper to perform a full key_gen -> sign -> verify roundtrip. - It generates a key pair, signs a message at a specific epoch, and + It generates a key pair, signs a message at a specific slot, and verifies the signature. It also checks that verification fails for - an incorrect message or epoch. + an incorrect message or slot. """ # KEY GENERATION # # Generate a new key pair for the specified active range. - pk, sk = scheme.key_gen(Uint64(activation_epoch), Uint64(num_active_epochs)) + pk, sk = scheme.key_gen(Slot(activation_slot), Uint64(num_active_slots)) # SIGN & VERIFY # - # Pick a sample epoch within the active range to test signing. - test_epoch = Uint64(activation_epoch + num_active_epochs // 2) + # Pick a sample slot within the active range to test signing. + test_slot = Slot(activation_slot + num_active_slots // 2) message = Bytes32(b"\x42" * 32) - # Sign the message at the chosen epoch. + # Sign the message at the chosen slot. # # This might take a moment as it may try multiple `rho` values. - signature = scheme.sign(sk, test_epoch, message) + signature = scheme.sign(sk, test_slot, message) # Verification of the valid signature must succeed. - is_valid = scheme.verify(pk, test_epoch, message, signature) + is_valid = scheme.verify(pk, test_slot, message, signature) assert is_valid, "Verification of a valid signature failed" # TEST INVALID CASES @@ -55,52 +56,52 @@ def _test_correctness_roundtrip( # In that case, verification will succeed, which is expected behavior for identical codewords. # # We detect this by checking if both messages encode to the same codeword. - original_codeword = scheme.encoder.encode(pk.parameter, message, signature.rho, test_epoch) + original_codeword = scheme.encoder.encode(pk.parameter, message, signature.rho, test_slot) tampered_codeword = scheme.encoder.encode( - pk.parameter, tampered_message, signature.rho, test_epoch + pk.parameter, tampered_message, signature.rho, test_slot ) if tampered_codeword != original_codeword: # Different codewords: verification must fail - is_invalid_msg = scheme.verify(pk, test_epoch, tampered_message, signature) + is_invalid_msg = scheme.verify(pk, test_slot, tampered_message, signature) assert not is_invalid_msg, "Verification succeeded for a tampered message" else: # Codeword collision: verification succeeds (expected with small test parameters) - is_collision_valid = scheme.verify(pk, test_epoch, tampered_message, signature) + is_collision_valid = scheme.verify(pk, test_slot, tampered_message, signature) assert is_collision_valid, "Verification failed despite identical codewords" - # Verification must fail if the epoch is incorrect. - if num_active_epochs > 1: - wrong_epoch = Uint64(int(test_epoch) + 1) - is_invalid_epoch = scheme.verify(pk, wrong_epoch, message, signature) - assert not is_invalid_epoch, "Verification succeeded for an incorrect epoch" + # Verification must fail if the slot is incorrect. + if num_active_slots > 1: + wrong_slot = Slot(int(test_slot) + 1) + is_invalid_slot = scheme.verify(pk, wrong_slot, message, signature) + assert not is_invalid_slot, "Verification succeeded for an incorrect slot" @pytest.mark.parametrize( - "activation_epoch, num_active_epochs", + "activation_slot, num_active_slots", [ pytest.param( 4, 4, id="Standard case with a short, active lifetime", marks=pytest.mark.slow ), - pytest.param(0, 8, id="Lifetime starting at epoch 0", marks=pytest.mark.slow), - pytest.param(7, 5, id="Lifetime starting at an odd-numbered epoch", marks=pytest.mark.slow), - pytest.param(12, 1, id="Lifetime with only a single active epoch"), + pytest.param(0, 8, id="Lifetime starting at slot 0", marks=pytest.mark.slow), + pytest.param(7, 5, id="Lifetime starting at an odd-numbered slot", marks=pytest.mark.slow), + pytest.param(12, 1, id="Lifetime with only a single active slot"), ], ) -def test_signature_scheme_correctness(activation_epoch: int, num_active_epochs: int) -> None: +def test_signature_scheme_correctness(activation_slot: int, num_active_slots: int) -> None: """Runs an end-to-end test of the signature scheme.""" _test_correctness_roundtrip( scheme=TEST_SIGNATURE_SCHEME, - activation_epoch=activation_epoch, - num_active_epochs=num_active_epochs, + activation_slot=activation_slot, + num_active_slots=num_active_slots, ) def test_get_activation_interval() -> None: """Tests that get_activation_interval returns the correct range.""" scheme = TEST_SIGNATURE_SCHEME - # Use 8 epochs (half of LIFETIME=16) - pk, sk = scheme.key_gen(Uint64(4), Uint64(8)) + # Use 8 slots (half of LIFETIME=16) + pk, sk = scheme.key_gen(Slot(4), Uint64(8)) interval = scheme.get_activation_interval(sk) @@ -116,14 +117,14 @@ def test_get_prepared_interval() -> None: """Tests that get_prepared_interval returns the correct range.""" scheme = TEST_SIGNATURE_SCHEME # Use full lifetime - pk, sk = scheme.key_gen(Uint64(0), Uint64(16)) + pk, sk = scheme.key_gen(Slot(0), Uint64(16)) interval = scheme.get_prepared_interval(sk) # Verify it's a range assert isinstance(interval, range) - # Verify it has at least 2 * sqrt(LIFETIME) epochs + # Verify it has at least 2 * sqrt(LIFETIME) slots leafs_per_bottom_tree = 1 << (scheme.config.LOG_LIFETIME // 2) min_prepared = 2 * leafs_per_bottom_tree assert len(interval) >= min_prepared @@ -132,9 +133,9 @@ def test_get_prepared_interval() -> None: def test_advance_preparation() -> None: """Tests that advance_preparation correctly slides the window.""" scheme = TEST_SIGNATURE_SCHEME - # Request 3 bottom trees' worth of epochs to ensure room to advance + # Request 3 bottom trees' worth of slots to ensure room to advance leafs_per_bottom_tree = 1 << (scheme.config.LOG_LIFETIME // 2) - pk, sk = scheme.key_gen(Uint64(0), Uint64(3 * leafs_per_bottom_tree)) + pk, sk = scheme.key_gen(Slot(0), Uint64(3 * leafs_per_bottom_tree)) # Get initial prepared interval initial_interval = scheme.get_prepared_interval(sk) @@ -157,11 +158,11 @@ def test_advance_preparation() -> None: def test_sign_requires_prepared_interval() -> None: - """Tests that sign raises an error if epoch is outside prepared interval.""" + """Tests that sign raises an error if slot is outside prepared interval.""" scheme = TEST_SIGNATURE_SCHEME - # Request 3 bottom trees' worth of epochs to have room for testing + # Request 3 bottom trees' worth of slots to have room for testing leafs_per_bottom_tree = 1 << (scheme.config.LOG_LIFETIME // 2) - pk, sk = scheme.key_gen(Uint64(0), Uint64(3 * leafs_per_bottom_tree)) + pk, sk = scheme.key_gen(Slot(0), Uint64(3 * leafs_per_bottom_tree)) # Get the prepared interval prepared_interval = scheme.get_prepared_interval(sk) @@ -169,7 +170,7 @@ def test_sign_requires_prepared_interval() -> None: # Try to sign outside the prepared interval (but inside activation interval) activation_interval = scheme.get_activation_interval(sk) # Pick an epoch just beyond the prepared interval - outside_epoch = Uint64(prepared_interval.stop) + outside_epoch = Slot(prepared_interval.stop) # Verify it's inside activation but outside prepared assert int(outside_epoch) in activation_interval @@ -185,10 +186,10 @@ def test_deterministic_signing() -> None: """Tests that signing the same message with the same key produces the same signature.""" scheme = TEST_SIGNATURE_SCHEME # Use full lifetime - pk, sk = scheme.key_gen(Uint64(0), Uint64(16)) + pk, sk = scheme.key_gen(Slot(0), Uint64(16)) # Use epoch within prepared interval - epoch = Uint64(4) + epoch = Slot(4) message = Bytes32(b"\x42" * 32) # Sign twice @@ -209,36 +210,36 @@ class TestVerifySecurityBounds: This prevents denial-of-service via malformed signatures. """ - def test_rejects_epoch_beyond_lifetime(self) -> None: - """verify returns False when epoch exceeds scheme LIFETIME.""" + def test_rejects_slot_beyond_lifetime(self) -> None: + """verify returns False when slot exceeds scheme LIFETIME.""" scheme = TEST_SIGNATURE_SCHEME # Generate valid keys. - pk, sk = scheme.key_gen(Uint64(0), Uint64(scheme.config.LIFETIME)) + pk, sk = scheme.key_gen(Slot(0), Uint64(int(scheme.config.LIFETIME))) # Sign a valid message at a valid epoch. - valid_epoch = Uint64(4) + valid_epoch = Slot(4) message = Bytes32(b"\x42" * 32) signature = scheme.sign(sk, valid_epoch, message) # Verify with an epoch beyond LIFETIME. - invalid_epoch = Uint64(int(scheme.config.LIFETIME) + 1) + invalid_epoch = Slot(int(scheme.config.LIFETIME) + 1) # Must return False, not raise. result = scheme.verify(pk, invalid_epoch, message, signature) assert result is False - def test_rejects_very_large_epoch(self) -> None: - """verify returns False for absurdly large epoch values.""" + def test_rejects_very_large_slot(self) -> None: + """verify returns False for absurdly large slot values.""" scheme = TEST_SIGNATURE_SCHEME - pk, sk = scheme.key_gen(Uint64(0), Uint64(scheme.config.LIFETIME)) + pk, sk = scheme.key_gen(Slot(0), Uint64(int(scheme.config.LIFETIME))) - valid_epoch = Uint64(4) + valid_epoch = Slot(4) message = Bytes32(b"\x42" * 32) signature = scheme.sign(sk, valid_epoch, message) # Try to verify with a huge epoch. - huge_epoch = Uint64(2**32) + huge_epoch = Slot(2**32) # Must return False, not raise. result = scheme.verify(pk, huge_epoch, message, signature) diff --git a/tests/lean_spec/subspecs/xmss/test_ssz_serialization.py b/tests/lean_spec/subspecs/xmss/test_ssz_serialization.py index 6ac400f4..481016c3 100644 --- a/tests/lean_spec/subspecs/xmss/test_ssz_serialization.py +++ b/tests/lean_spec/subspecs/xmss/test_ssz_serialization.py @@ -1,5 +1,6 @@ """Tests for SSZ serialization of XMSS types.""" +from lean_spec.subspecs.containers.slot import Slot from lean_spec.subspecs.xmss.constants import TEST_CONFIG from lean_spec.subspecs.xmss.containers import PublicKey, SecretKey, Signature from lean_spec.subspecs.xmss.interface import TEST_SIGNATURE_SCHEME @@ -9,9 +10,9 @@ def test_public_key_ssz_roundtrip() -> None: """Test that PublicKey can be SSZ serialized and deserialized.""" # Generate a key pair - activation_epoch = Uint64(0) - num_active_epochs = Uint64(32) - public_key, secret_key = TEST_SIGNATURE_SCHEME.key_gen(activation_epoch, num_active_epochs) + activation_slot = Slot(0) + num_active_slots = Uint64(32) + public_key, secret_key = TEST_SIGNATURE_SCHEME.key_gen(activation_slot, num_active_slots) # Serialize to bytes using SSZ pk_bytes = public_key.encode_bytes() @@ -28,12 +29,12 @@ def test_public_key_ssz_roundtrip() -> None: def test_signature_ssz_roundtrip() -> None: """Test that Signature can be SSZ serialized and deserialized.""" # Generate a key pair and sign a message - activation_epoch = Uint64(0) - num_active_epochs = Uint64(32) - public_key, secret_key = TEST_SIGNATURE_SCHEME.key_gen(activation_epoch, num_active_epochs) + activation_slot = Slot(0) + num_active_slots = Uint64(32) + public_key, secret_key = TEST_SIGNATURE_SCHEME.key_gen(activation_slot, num_active_slots) message = Bytes32(bytes([42] * 32)) - epoch = Uint64(0) + epoch = Slot(0) signature = TEST_SIGNATURE_SCHEME.sign(secret_key, epoch, message) # Serialize to bytes using SSZ @@ -55,9 +56,9 @@ def test_signature_ssz_roundtrip() -> None: def test_secret_key_ssz_roundtrip() -> None: """Test that SecretKey can be SSZ serialized and deserialized.""" # Generate a key pair - activation_epoch = Uint64(0) - num_active_epochs = Uint64(32) - public_key, secret_key = TEST_SIGNATURE_SCHEME.key_gen(activation_epoch, num_active_epochs) + activation_slot = Slot(0) + num_active_slots = Uint64(32) + public_key, secret_key = TEST_SIGNATURE_SCHEME.key_gen(activation_slot, num_active_slots) # Serialize to bytes using SSZ sk_bytes = secret_key.encode_bytes() @@ -68,8 +69,8 @@ def test_secret_key_ssz_roundtrip() -> None: # Verify the recovered secret key matches the original assert recovered_sk.prf_key == secret_key.prf_key assert recovered_sk.parameter == secret_key.parameter - assert recovered_sk.activation_epoch == secret_key.activation_epoch - assert recovered_sk.num_active_epochs == secret_key.num_active_epochs + assert recovered_sk.activation_slot == secret_key.activation_slot + assert recovered_sk.num_active_slots == secret_key.num_active_slots assert recovered_sk.top_tree == secret_key.top_tree assert recovered_sk.left_bottom_tree_index == secret_key.left_bottom_tree_index assert recovered_sk.left_bottom_tree == secret_key.left_bottom_tree @@ -78,7 +79,7 @@ def test_secret_key_ssz_roundtrip() -> None: # Verify the recovered secret key can still sign message = Bytes32(bytes([99] * 32)) - epoch = Uint64(1) + epoch = Slot(1) signature = TEST_SIGNATURE_SCHEME.sign(recovered_sk, epoch, message) assert TEST_SIGNATURE_SCHEME.verify(public_key, epoch, message, signature) @@ -86,9 +87,9 @@ def test_secret_key_ssz_roundtrip() -> None: def test_deterministic_serialization() -> None: """Test that serialization is deterministic.""" # Generate a key pair - activation_epoch = Uint64(0) - num_active_epochs = Uint64(32) - public_key, secret_key = TEST_SIGNATURE_SCHEME.key_gen(activation_epoch, num_active_epochs) + activation_slot = Slot(0) + num_active_slots = Uint64(32) + public_key, secret_key = TEST_SIGNATURE_SCHEME.key_gen(activation_slot, num_active_slots) # Serialize multiple times pk_bytes1 = public_key.encode_bytes() @@ -102,7 +103,7 @@ def test_deterministic_serialization() -> None: # Sign a message multiple times with deterministic randomness message = Bytes32(bytes([42] * 32)) - epoch = Uint64(0) + epoch = Slot(0) sig1 = TEST_SIGNATURE_SCHEME.sign(secret_key, epoch, message) sig2 = TEST_SIGNATURE_SCHEME.sign(secret_key, epoch, message) @@ -116,12 +117,12 @@ def test_deterministic_serialization() -> None: def test_signature_size_matches_config() -> None: """Verify SIGNATURE_LEN_BYTES matches actual SSZ-encoded size.""" - activation_epoch = Uint64(0) - num_active_epochs = Uint64(32) - public_key, secret_key = TEST_SIGNATURE_SCHEME.key_gen(activation_epoch, num_active_epochs) + activation_slot = Slot(0) + num_active_slots = Uint64(32) + public_key, secret_key = TEST_SIGNATURE_SCHEME.key_gen(activation_slot, num_active_slots) message = Bytes32(bytes([42] * 32)) - epoch = Uint64(0) + epoch = Slot(0) signature = TEST_SIGNATURE_SCHEME.sign(secret_key, epoch, message) encoded = signature.encode_bytes() @@ -130,9 +131,9 @@ def test_signature_size_matches_config() -> None: def test_public_key_size_matches_config() -> None: """Verify PUBLIC_KEY_LEN_BYTES matches actual SSZ-encoded size.""" - activation_epoch = Uint64(0) - num_active_epochs = Uint64(32) - public_key, _ = TEST_SIGNATURE_SCHEME.key_gen(activation_epoch, num_active_epochs) + activation_slot = Slot(0) + num_active_slots = Uint64(32) + public_key, _ = TEST_SIGNATURE_SCHEME.key_gen(activation_slot, num_active_slots) encoded = public_key.encode_bytes() assert len(encoded) == TEST_CONFIG.PUBLIC_KEY_LEN_BYTES diff --git a/tests/lean_spec/subspecs/xmss/test_utils.py b/tests/lean_spec/subspecs/xmss/test_utils.py index b7d9ae77..6a324ae7 100644 --- a/tests/lean_spec/subspecs/xmss/test_utils.py +++ b/tests/lean_spec/subspecs/xmss/test_utils.py @@ -92,21 +92,21 @@ def test_expand_activation_time( # Verify alignment c = 1 << (log_lifetime // 2) - actual_start_epoch = start_tree * c - actual_end_epoch = end_tree * c - assert actual_start_epoch % c == 0 - assert actual_end_epoch % c == 0 + actual_start_slot = start_tree * c + actual_end_slot = end_tree * c + assert actual_start_slot % c == 0 + assert actual_end_slot % c == 0 # Verify it covers the desired range (if the desired range fits within lifetime) lifetime = c * c - desired_end_epoch = desired_activation + desired_num - if desired_end_epoch <= lifetime: - assert actual_start_epoch <= desired_activation - assert actual_end_epoch >= desired_end_epoch + desired_end_slot = desired_activation + desired_num + if desired_end_slot <= lifetime: + assert actual_start_slot <= desired_activation + assert actual_end_slot >= desired_end_slot else: # If desired range exceeds lifetime, verify it's clamped to lifetime bounds - assert actual_start_epoch >= 0 - assert actual_end_epoch <= lifetime + assert actual_start_slot >= 0 + assert actual_end_slot <= lifetime def test_hash_subtree_from_prf_key() -> None: From 7552e0209efc9f9a087876e07226ccb3e44bd9d0 Mon Sep 17 00:00:00 2001 From: Thomas Coratger Date: Fri, 20 Feb 2026 00:11:04 +0100 Subject: [PATCH 2/5] cleanup --- packages/testing/src/consensus_testing/keys.py | 2 +- src/lean_spec/subspecs/xmss/prf.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/packages/testing/src/consensus_testing/keys.py b/packages/testing/src/consensus_testing/keys.py index 9e91bc19..5fdb04f6 100755 --- a/packages/testing/src/consensus_testing/keys.py +++ b/packages/testing/src/consensus_testing/keys.py @@ -320,7 +320,7 @@ def sign_attestation_data( # Advance key state until slot is in prepared interval prepared = self.scheme.get_prepared_interval(sk) - while int(slot) not in prepared: + while slot not in prepared: activation = self.scheme.get_activation_interval(sk) if prepared.stop >= activation.stop: raise ValueError(f"Slot {slot} exceeds key lifetime {activation.stop}") diff --git a/src/lean_spec/subspecs/xmss/prf.py b/src/lean_spec/subspecs/xmss/prf.py index 2f514e37..8dd455a9 100644 --- a/src/lean_spec/subspecs/xmss/prf.py +++ b/src/lean_spec/subspecs/xmss/prf.py @@ -164,7 +164,7 @@ def apply(self, key: PRFKey, epoch: Uint64, chain_index: Uint64) -> HashDigestVe PRF_DOMAIN_SEP + PRF_DOMAIN_SEP_DOMAIN_ELEMENT + key - + int(epoch).to_bytes(4, "big") + + epoch.to_bytes(4, "big") + chain_index.to_bytes(8, "big") ) @@ -222,9 +222,9 @@ def get_randomness( PRF_DOMAIN_SEP + PRF_DOMAIN_SEP_RANDOMNESS + key - + int(epoch).to_bytes(4, "big") + + epoch.to_bytes(4, "big") + message - + int(counter).to_bytes(8, "big") + + counter.to_bytes(8, "big") ) # Extract enough bytes for RAND_LEN_FE field elements From b3e871f6e7fd46fa60fd9d6b77f5c8b2b77d5b2e Mon Sep 17 00:00:00 2001 From: Thomas Coratger Date: Fri, 20 Feb 2026 00:15:33 +0100 Subject: [PATCH 3/5] small fix --- packages/testing/src/consensus_testing/keys.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/testing/src/consensus_testing/keys.py b/packages/testing/src/consensus_testing/keys.py index 5fdb04f6..9e91bc19 100755 --- a/packages/testing/src/consensus_testing/keys.py +++ b/packages/testing/src/consensus_testing/keys.py @@ -320,7 +320,7 @@ def sign_attestation_data( # Advance key state until slot is in prepared interval prepared = self.scheme.get_prepared_interval(sk) - while slot not in prepared: + while int(slot) not in prepared: activation = self.scheme.get_activation_interval(sk) if prepared.stop >= activation.stop: raise ValueError(f"Slot {slot} exceeds key lifetime {activation.stop}") From 1b23e705a45b80f357c8e21a0e23e2570859e5c8 Mon Sep 17 00:00:00 2001 From: Thomas Coratger Date: Fri, 20 Feb 2026 00:17:45 +0100 Subject: [PATCH 4/5] small fix --- tests/interop/helpers/node_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/interop/helpers/node_runner.py b/tests/interop/helpers/node_runner.py index 5ee9dc1e..e6cea22a 100644 --- a/tests/interop/helpers/node_runner.py +++ b/tests/interop/helpers/node_runner.py @@ -363,7 +363,7 @@ async def start_node( # This matches the Ethereum gossip specification. if validator_indices: for idx in validator_indices: - subnet_id = idx % int(ATTESTATION_COMMITTEE_COUNT) + subnet_id = int(idx) % int(ATTESTATION_COMMITTEE_COUNT) topic = f"/leanconsensus/{self.fork_digest}/attestation_{subnet_id}/ssz_snappy" event_source.subscribe_gossip_topic(topic) From 4075732deaff9506070d1e85dc408449040384a5 Mon Sep 17 00:00:00 2001 From: Thomas Coratger Date: Fri, 20 Feb 2026 00:27:43 +0100 Subject: [PATCH 5/5] small fix --- tests/interop/helpers/node_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/interop/helpers/node_runner.py b/tests/interop/helpers/node_runner.py index e6cea22a..3ce09e5c 100644 --- a/tests/interop/helpers/node_runner.py +++ b/tests/interop/helpers/node_runner.py @@ -428,7 +428,7 @@ async def start_all( # This allows the gossipsub mesh to form before validators start # producing blocks and attestations. Otherwise, early blocks/attestations # would be "Published message to 0 peers" because the mesh is empty. - aggregator_indices = set(range(int(ATTESTATION_COMMITTEE_COUNT))) + aggregator_indices = {ValidatorIndex(i) for i in range(int(ATTESTATION_COMMITTEE_COUNT))} for i in range(num_nodes): validator_indices = validators_per_node[i] if i < len(validators_per_node) else []