26 changes: 13 additions & 13 deletions packages/testing/src/consensus_testing/keys.py
@@ -314,16 +314,16 @@ def sign_attestation_data(
Raises:
ValueError: If slot exceeds key lifetime.
"""
epoch = attestation_data.slot
slot = attestation_data.slot
kp = self[validator_id]
sk = kp.secret

# Advance key state until epoch is in prepared interval
# Advance key state until slot is in prepared interval
prepared = self.scheme.get_prepared_interval(sk)
while int(epoch) not in prepared:
while int(slot) not in prepared:
activation = self.scheme.get_activation_interval(sk)
if prepared.stop >= activation.stop:
raise ValueError(f"Epoch {epoch} exceeds key lifetime {activation.stop}")
raise ValueError(f"Slot {slot} exceeds key lifetime {activation.stop}")
sk = self.scheme.advance_preparation(sk)
prepared = self.scheme.get_prepared_interval(sk)

@@ -332,7 +332,7 @@ def sign_attestation_data(

# Sign hash tree root of the attestation data
message = attestation_data.data_root_bytes()
return self.scheme.sign(sk, epoch, message)
return self.scheme.sign(sk, slot, message)

def build_attestation_signatures(
self,
@@ -351,7 +351,7 @@ def build_attestation_signatures(
for agg in aggregated_attestations:
validator_ids = agg.aggregation_bits.to_validator_indices()
message = agg.data.data_root_bytes()
epoch = agg.data.slot
slot = agg.data.slot

public_keys: list[PublicKey] = [self.get_public_key(vid) for vid in validator_ids]
signatures: list[Signature] = [
@@ -370,19 +370,19 @@
public_keys=public_keys,
signatures=signatures,
message=message,
epoch=epoch,
slot=slot,
)
proofs.append(proof)

return AttestationSignatures(data=proofs)


def _generate_single_keypair(
scheme: GeneralizedXmssScheme, num_epochs: int, index: int
scheme: GeneralizedXmssScheme, num_slots: int, index: int
) -> dict[str, str]:
"""Generate one key pair (module-level for pickling in ProcessPoolExecutor)."""
print(f"Starting key #{index} generation...")
pk, sk = scheme.key_gen(Uint64(0), Uint64(num_epochs))
pk, sk = scheme.key_gen(Slot(0), Uint64(num_slots))
return KeyPair(public=pk, secret=sk).to_dict()


@@ -397,20 +397,20 @@ def _generate_keys(lean_env: str, count: int, max_slot: int) -> None:
Args:
lean_env: Name of the XMSS signature scheme to use (e.g. "test" or "prod").
count: Number of validators.
max_slot: Maximum slot (key lifetime = max_slot + 1 epochs).
max_slot: Maximum slot (key lifetime = max_slot + 1 slots).
"""
scheme = LEAN_ENV_TO_SCHEMES[lean_env]
keys_dir = get_keys_dir(lean_env)
num_epochs = max_slot + 1
num_slots = max_slot + 1
num_workers = os.cpu_count() or 1

print(
f"Generating {count} XMSS key pairs for {lean_env} environment "
f"({num_epochs} epochs) using {num_workers} cores..."
f"({num_slots} slots) using {num_workers} cores..."
)

with ProcessPoolExecutor(max_workers=num_workers) as executor:
worker_func = partial(_generate_single_keypair, scheme, num_epochs)
worker_func = partial(_generate_single_keypair, scheme, num_slots)
key_pairs = list(executor.map(worker_func, range(count)))

# Create keys directory (remove old one if it exists)
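For readers following the epoch-to-slot rename in `sign_attestation_data` above: the loop only signs once the target slot falls inside the key's prepared window, and fails once that window would pass the key's activation lifetime. A minimal sketch of that control flow, with the `GeneralizedXmssScheme` internals faked as plain Python ranges (only the loop shape mirrors the real code):

```python
def advance_until_prepared(slot: int, prepared: range, lifetime_end: int) -> range:
    """Slide the prepared window forward until it covers `slot`."""
    while slot not in prepared:
        if prepared.stop >= lifetime_end:
            raise ValueError(f"Slot {slot} exceeds key lifetime {lifetime_end}")
        # Stand-in for scheme.advance_preparation(sk): shift the window by one.
        prepared = range(prepared.start + 1, prepared.stop + 1)
    return prepared

assert 10 in advance_until_prepared(10, range(0, 4), lifetime_end=16)
```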
@@ -11,6 +11,7 @@

from pydantic import model_validator

from lean_spec.subspecs.chain.clock import Interval
from lean_spec.subspecs.chain.config import (
INTERVALS_PER_SLOT,
MILLISECONDS_PER_INTERVAL,
@@ -246,7 +247,7 @@ def make_fixture(self) -> Self:
# TickStep.time is a Unix timestamp in seconds.
# Convert to intervals since genesis for the store.
delta_ms = (Uint64(step.time) - store.config.genesis_time) * Uint64(1000)
target_interval = delta_ms // MILLISECONDS_PER_INTERVAL
target_interval = Interval(delta_ms // MILLISECONDS_PER_INTERVAL)
store, _ = store.on_tick(
target_interval, has_proposal=False, is_aggregator=True
)
@@ -276,7 +277,7 @@ def make_fixture(self) -> Self:
# Store rejects blocks from the future.
# This tick includes a block (has proposal).
# Always act as aggregator to ensure gossip signatures are aggregated
target_interval = block.slot * INTERVALS_PER_SLOT
target_interval = Interval(block.slot * INTERVALS_PER_SLOT)
store, _ = store.on_tick(
target_interval, has_proposal=True, is_aggregator=True
)
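Both conversions in this fixture (wall-clock seconds to intervals since genesis, and block slot to its first interval) are plain integer arithmetic now wrapped in `Interval`. A worked sketch with placeholder constants; the real `MILLISECONDS_PER_INTERVAL` and `INTERVALS_PER_SLOT` come from the chain config and may differ:

```python
MILLISECONDS_PER_INTERVAL = 1_000  # assumed value for this example
INTERVALS_PER_SLOT = 4             # assumed value for this example

def tick_to_interval(tick_time_s: int, genesis_time_s: int) -> int:
    """Intervals elapsed since genesis at a given Unix timestamp."""
    delta_ms = (tick_time_s - genesis_time_s) * 1_000
    return delta_ms // MILLISECONDS_PER_INTERVAL

def slot_start_interval(slot: int) -> int:
    """First interval of a slot, as used when a tick carries a block."""
    return slot * INTERVALS_PER_SLOT

assert tick_to_interval(tick_time_s=108, genesis_time_s=100) == 8
assert slot_start_interval(2) == 8
```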
@@ -242,7 +242,7 @@ def _build_block_from_spec(
public_keys=signer_public_keys,
signatures=signer_signatures,
message=data_root,
epoch=attestation_data.slot,
slot=attestation_data.slot,
)
# Replace participants with claimed validator_ids (mismatch!)
invalid_proof = AggregatedSignatureProof(
3 changes: 2 additions & 1 deletion src/lean_spec/__main__.py
@@ -35,6 +35,7 @@
from lean_spec.subspecs.containers import Block, BlockBody, Checkpoint, State
from lean_spec.subspecs.containers.block.types import AggregatedAttestations
from lean_spec.subspecs.containers.slot import Slot
from lean_spec.subspecs.containers.validator import SubnetId
from lean_spec.subspecs.forkchoice import Store
from lean_spec.subspecs.genesis import GenesisConfig
from lean_spec.subspecs.networking.client import LiveNetworkEventSource
@@ -489,7 +490,7 @@ async def run_node(
# Subscribe to attestation subnet topics based on local validator id.
validator_id = validator_registry.primary_index() if validator_registry else None
if validator_id is None:
subnet_id = 0
subnet_id = SubnetId(0)
logger.info("No local validator id; subscribing to attestation subnet %d", subnet_id)
else:
subnet_id = validator_id.compute_subnet_id(ATTESTATION_COMMITTEE_COUNT)
9 changes: 5 additions & 4 deletions src/lean_spec/subspecs/chain/clock.py
@@ -17,8 +17,9 @@

from .config import MILLISECONDS_PER_INTERVAL, MILLISECONDS_PER_SLOT, SECONDS_PER_SLOT

Interval = Uint64
"""Interval count since genesis (matches ``Store.time``)."""

class Interval(Uint64):
"""Interval count since genesis (matches ``Store.time``)."""


@dataclass(frozen=True, slots=True)
@@ -57,15 +58,15 @@ def current_slot(self) -> Slot:
def current_interval(self) -> Interval:
"""Get the current interval within the slot (0-4)."""
milliseconds_into_slot = self._milliseconds_since_genesis() % MILLISECONDS_PER_SLOT
return milliseconds_into_slot // MILLISECONDS_PER_INTERVAL
return Interval(milliseconds_into_slot // MILLISECONDS_PER_INTERVAL)

def total_intervals(self) -> Interval:
"""
Get total intervals elapsed since genesis.

This is the value expected by our store time type.
"""
return self._milliseconds_since_genesis() // MILLISECONDS_PER_INTERVAL
return Interval(self._milliseconds_since_genesis() // MILLISECONDS_PER_INTERVAL)

def current_time(self) -> Uint64:
"""Get current wall-clock time as Uint64 (Unix timestamp in seconds)."""
3 changes: 1 addition & 2 deletions src/lean_spec/subspecs/containers/block/block.py
@@ -164,12 +164,11 @@ def verify_signatures(
public_keys = [validators[vid].get_pubkey() for vid in validator_ids]

# Verify the aggregated signature against all public keys.
# Uses slot as epoch for XMSS one-time signature indexing.
try:
aggregated_signature.verify(
public_keys=public_keys,
message=attestation_data_root,
epoch=aggregated_attestation.data.slot,
slot=aggregated_attestation.data.slot,
)
except AggregationError as exc:
raise AssertionError(
4 changes: 2 additions & 2 deletions src/lean_spec/subspecs/containers/state/state.py
@@ -867,7 +867,7 @@ def aggregate_gossip_signatures(
public_keys=gossip_keys,
signatures=gossip_sigs,
message=data_root,
epoch=data.slot,
slot=data.slot,
)
attestation = AggregatedAttestation(aggregation_bits=participants, data=data)
results.append((attestation, proof))
@@ -915,7 +915,7 @@ def select_aggregated_proofs(
) # validators contributed to this attestation

# Validators that are missing in the current aggregation are put into remaining.
remaining: set[Uint64] = set(validator_ids)
remaining: set[ValidatorIndex] = set(validator_ids)

# Fallback to existing proofs
#
10 changes: 7 additions & 3 deletions src/lean_spec/subspecs/containers/validator.py
@@ -9,6 +9,10 @@
from .slot import Slot


class SubnetId(Uint64):
"""Subnet identifier (0-63) for attestation subnet partitioning."""


class ValidatorIndex(Uint64):
"""Represents a validator's unique index as a 64-bit unsigned integer."""

@@ -24,16 +28,16 @@ def is_valid(self, num_validators: int) -> bool:
"""Check if this index is within valid bounds for a registry of given size."""
return int(self) < num_validators

def compute_subnet_id(self, num_committees: "int | Uint64") -> int:
def compute_subnet_id(self, num_committees: int) -> SubnetId:
"""Compute the attestation subnet id for this validator.

Args:
num_committees: Positive number of committees.

Returns:
An integer subnet id in 0..(num_committees-1).
A SubnetId in 0..(num_committees-1).
"""
return int(self) % int(num_committees)
return SubnetId(int(self) % int(num_committees))


class ValidatorIndices(SSZList[ValidatorIndex]):
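A short usage sketch for the new `SubnetId` return type. The committee count of 64 is an assumed example value chosen to match the 0-63 range in the `SubnetId` docstring, not necessarily the configured `ATTESTATION_COMMITTEE_COUNT`:

```python
from lean_spec.subspecs.containers.validator import SubnetId, ValidatorIndex

validator = ValidatorIndex(130)
subnet = validator.compute_subnet_id(64)  # 64 committees assumed for the example

assert subnet == SubnetId(130 % 64)       # == SubnetId(2)
assert isinstance(subnet, SubnetId)       # typed result instead of a bare int
```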
16 changes: 9 additions & 7 deletions src/lean_spec/subspecs/forkchoice/store.py
@@ -9,6 +9,7 @@
import copy
from collections import defaultdict

from lean_spec.subspecs.chain.clock import Interval
from lean_spec.subspecs.chain.config import (
ATTESTATION_COMMITTEE_COUNT,
INTERVALS_PER_SLOT,
@@ -62,7 +63,7 @@ class Store(Container):
- or when the head is recomputed.
"""

time: Uint64
time: Interval
"""Current time in intervals since genesis."""

config: Config
@@ -214,7 +215,7 @@ def get_forkchoice_store(
# but the Store must treat the anchor block as the justified/finalized point.

return cls(
time=Uint64(anchor_slot * INTERVALS_PER_SLOT),
time=Interval(anchor_slot * INTERVALS_PER_SLOT),
config=anchor_state.config,
head=anchor_root,
safe_target=anchor_root,
@@ -476,7 +477,7 @@ def on_gossip_aggregated_attestation(
proof.verify(
public_keys=public_keys,
message=data.data_root_bytes(),
epoch=data.slot,
slot=data.slot,
)
except AggregationError as exc:
raise AssertionError(
@@ -1117,7 +1118,7 @@ def tick_interval(
Tuple of (new store with advanced time, list of new signed aggregated attestations).
"""
# Advance time by one interval
store = self.model_copy(update={"time": self.time + Uint64(1)})
store = self.model_copy(update={"time": Interval(int(self.time) + 1)})
current_interval = store.time % INTERVALS_PER_SLOT
new_aggregates: list[SignedAggregatedAttestation] = []

@@ -1139,7 +1140,7 @@
return store, new_aggregates

def on_tick(
self, target_interval: Uint64, has_proposal: bool, is_aggregator: bool = False
self, target_interval: Interval, has_proposal: bool, is_aggregator: bool = False
) -> tuple["Store", list[SignedAggregatedAttestation]]:
"""
Advance forkchoice store time to given interval count.
@@ -1163,7 +1164,7 @@
# Tick forward one interval at a time
while store.time < target_interval:
# Check if proposal should be signaled for next interval
should_signal_proposal = has_proposal and (store.time + Uint64(1)) == target_interval
next_interval = Interval(int(store.time) + 1)
should_signal_proposal = has_proposal and next_interval == target_interval

# Advance by one interval with appropriate signaling
store, new_aggregates = store.tick_interval(should_signal_proposal, is_aggregator)
@@ -1193,7 +1195,7 @@ def get_proposal_head(self, slot: Slot) -> tuple["Store", Bytes32]:
Tuple of (new Store with updated time, head root for building).
"""
# Advance time to this slot's first interval
target_interval = Uint64(slot * INTERVALS_PER_SLOT)
target_interval = Interval(slot * INTERVALS_PER_SLOT)
store, _ = self.on_tick(target_interval, True)

# Process any pending attestations before proposal
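The `on_tick` change keeps the same one-interval-at-a-time loop but now builds each step as an explicit `Interval`. A schematic of just the control flow, with the `Store` state and attestation aggregation elided (this is not the real API):

```python
def tick_plan(current: int, target_interval: int, has_proposal: bool) -> list[tuple[int, bool]]:
    """Return (interval, signal_proposal) pairs for each single-interval step."""
    steps: list[tuple[int, bool]] = []
    while current < target_interval:
        next_interval = current + 1
        # The proposal is only signalled on the step that reaches the target.
        steps.append((next_interval, has_proposal and next_interval == target_interval))
        current = next_interval
    return steps

assert tick_plan(current=1, target_interval=4, has_proposal=True) == [
    (2, False),
    (3, False),
    (4, True),
]
```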
21 changes: 12 additions & 9 deletions src/lean_spec/subspecs/networking/discovery/handshake.py
@@ -31,7 +31,7 @@
from threading import Lock

from lean_spec.subspecs.networking.enr import ENR
from lean_spec.subspecs.networking.types import NodeId
from lean_spec.subspecs.networking.types import NodeId, SeqNumber
from lean_spec.types import Bytes32, Bytes33, Bytes64

from .config import HANDSHAKE_TIMEOUT_SECS
@@ -41,7 +41,7 @@
verify_id_nonce_signature,
)
from .keys import derive_keys_from_pubkey
from .messages import PacketFlag
from .messages import PacketFlag, Port
from .packet import (
HandshakeAuthdata,
WhoAreYouAuthdata,
@@ -52,6 +52,9 @@
)
from .session import Session, SessionCache

_DEFAULT_PORT = Port(0)
"""Default port value for optional port parameters."""

MAX_PENDING_HANDSHAKES = 100
"""Hard cap on concurrent pending handshakes to prevent resource exhaustion."""

@@ -97,7 +100,7 @@ class PendingHandshake:
challenge_nonce: bytes | None = None
"""12-byte nonce from the packet that triggered WHOAREYOU."""

remote_enr_seq: int = 0
remote_enr_seq: SeqNumber = SeqNumber(0)
"""ENR seq we sent in WHOAREYOU. If 0, remote MUST include their ENR in HANDSHAKE."""

started_at: float = field(default_factory=time.time)
@@ -144,7 +147,7 @@ def __init__(
local_node_id: NodeId,
local_private_key: bytes,
local_enr_rlp: bytes,
local_enr_seq: int,
local_enr_seq: SeqNumber,
session_cache: SessionCache,
timeout_secs: float = HANDSHAKE_TIMEOUT_SECS,
):
@@ -204,7 +207,7 @@ def create_whoareyou(
self,
remote_node_id: NodeId,
request_nonce: bytes,
remote_enr_seq: int,
remote_enr_seq: SeqNumber,
masking_iv: bytes,
) -> tuple[bytes, bytes, bytes, bytes]:
"""
@@ -255,7 +258,7 @@ def create_handshake_response(
remote_pubkey: bytes,
challenge_data: bytes,
remote_ip: str = "",
remote_port: int = 0,
remote_port: Port = _DEFAULT_PORT,
) -> tuple[bytes, bytes, bytes]:
"""
Create a HANDSHAKE packet in response to WHOAREYOU.
@@ -293,7 +296,7 @@ def create_handshake_response(

# Include our ENR if the remote's known seq is stale.
record = None
if int(whoareyou.enr_seq) < self._local_enr_seq:
if whoareyou.enr_seq < self._local_enr_seq:
record = self._local_enr_rlp

# Build authdata.
@@ -338,7 +341,7 @@ def handle_handshake(
remote_node_id: NodeId,
handshake: HandshakeAuthdata,
remote_ip: str = "",
remote_port: int = 0,
remote_port: Port = _DEFAULT_PORT,
) -> HandshakeResult:
"""
Process a received HANDSHAKE packet.
@@ -383,7 +386,7 @@
# If we sent enr_seq=0, we signaled that we don't know the remote's ENR.
# Per spec, the remote MUST include their ENR in the HANDSHAKE response
# so we can verify their identity.
if remote_enr_seq == 0 and handshake.record is None:
if remote_enr_seq == SeqNumber(0) and handshake.record is None:
raise HandshakeError(
f"ENR required in HANDSHAKE from unknown node {remote_node_id.hex()[:16]}"
)
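The two `SeqNumber` comparisons in this file encode complementary rules: attach our own ENR when the peer's view of it is stale, and require the peer's ENR when we advertised seq 0 (meaning we do not know them at all). A sketch of just that decision logic, with sequence numbers treated as plain integers:

```python
def should_include_own_enr(remote_known_seq: int, local_enr_seq: int) -> bool:
    """Attach our ENR to the HANDSHAKE only if the remote's copy is stale."""
    return remote_known_seq < local_enr_seq

def remote_record_required(whoareyou_seq_we_sent: int) -> bool:
    """Seq 0 in our WHOAREYOU means the remote MUST send its ENR back."""
    return whoareyou_seq_we_sent == 0

assert should_include_own_enr(remote_known_seq=3, local_enr_seq=5)
assert not should_include_own_enr(remote_known_seq=5, local_enr_seq=5)
assert remote_record_required(0) and not remote_record_required(7)
```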