From 2bf781fcecd8c58b97168334f9d1bfd1442e2ba3 Mon Sep 17 00:00:00 2001 From: anuragShingare30 Date: Tue, 24 Feb 2026 17:51:27 +0530 Subject: [PATCH 1/4] feat: demo for node connectivity --- cli.py | 344 ++++++++++++++++++++++++++++++++++++ minichain/chain.py | 160 ++++++++++++++++- minichain/p2p.py | 429 ++++++++++++++++++++++++++++++++++++--------- 3 files changed, 849 insertions(+), 84 deletions(-) create mode 100644 cli.py diff --git a/cli.py b/cli.py new file mode 100644 index 0000000..5d8b73a --- /dev/null +++ b/cli.py @@ -0,0 +1,344 @@ +#!/usr/bin/env python3 +""" +MiniChain Node CLI + +Simple command-line interface for running a MiniChain testnet node. + +Usage: + # Start first node (no peers) + python cli.py --port 9000 + + # Start second node, connect to first + python cli.py --port 9001 --peers 127.0.0.1:9000 + + # Enable mining + python cli.py --port 9000 --mine + + # With custom miner address + python cli.py --port 9000 --mine --miner
+""" + +import argparse +import asyncio +import logging +import signal +import sys + +from minichain import Blockchain, Block, Mempool, P2PNetwork, Transaction, mine_block +from minichain.pow import calculate_hash + +logger = logging.getLogger(__name__) + + +class Node: + """MiniChain testnet node.""" + + def __init__(self, port: int, peers: list = None, mining: bool = False, miner_address: str = None): + self.port = port + self.initial_peers = peers or [] + self.mining = mining + self.miner_address = miner_address or "0" * 40 # Burn address if not set + + self.chain = Blockchain() + self.mempool = Mempool() + self.network = P2PNetwork() + self.network.configure(port=port) + + self._running = False + self._sync_complete = False + + async def start(self): + """Start the node.""" + self._running = True + + # Register message handler + self.network.register_handler(self._handle_message) + + # Start network + await self.network.start() + logger.info("Node ID: %s", self.network.node_id) + + # Connect to initial peers + for peer_addr in self.initial_peers: + await self.network.connect_to_peer(peer_addr) + + # Request chain sync from peers + await self._sync_chain() + + # Start main loop + await self._run() + + async def stop(self): + """Stop the node gracefully.""" + logger.info("Stopping node...") + self._running = False + await self.network.stop() + logger.info("Node stopped") + + async def _sync_chain(self): + """Sync chain from connected peers.""" + if not self.network.peers: + logger.info("No peers connected, starting with genesis chain") + self._sync_complete = True + return + + logger.info("Requesting chain from %d peer(s)...", len(self.network.peers)) + + for peer in list(self.network.peers.values()): + await self.network.request_chain(peer) + + # Wait a bit for responses + await asyncio.sleep(2) + self._sync_complete = True + logger.info("Chain sync complete. Height: %d", len(self.chain.chain)) + + async def _handle_message(self, message: dict, sender): + """Handle incoming P2P messages.""" + msg_type = message.get("type") + msg_data = message.get("data", {}) + + try: + if msg_type == "tx": + await self._handle_tx(msg_data, sender) + + elif msg_type == "block": + await self._handle_block(msg_data, sender) + + elif msg_type == "get_chain": + await self._handle_get_chain(sender) + + elif msg_type == "chain": + await self._handle_chain(msg_data) + + except Exception: + logger.exception("Error handling message: %s", message) + + async def _handle_tx(self, tx_data: dict, sender): + """Handle incoming transaction.""" + try: + tx = Transaction(**tx_data) + if self.mempool.add_transaction(tx): + logger.info("Added tx to mempool from %s...", tx.sender[:8]) + # Relay to other peers (exclude sender) + sender_id = sender.node_id if sender else None + await self.network.broadcast({"type": "tx", "data": tx_data}, exclude_peer=sender_id) + except Exception as e: + logger.warning("Invalid transaction: %s", e) + + async def _handle_block(self, block_data: dict, sender): + """Handle incoming block.""" + try: + transactions = [Transaction(**tx) for tx in block_data.get("transactions", [])] + + block = Block( + index=block_data.get("index"), + previous_hash=block_data.get("previous_hash"), + transactions=transactions, + timestamp=block_data.get("timestamp"), + difficulty=block_data.get("difficulty") + ) + block.nonce = block_data.get("nonce", 0) + block.hash = block_data.get("hash") + + if self.chain.add_block(block): + logger.info("Added block #%d to chain", block.index) + # Relay to other peers + sender_id = sender.node_id if sender else None + await self.network.broadcast({"type": "block", "data": block_data}, exclude_peer=sender_id) + else: + logger.warning("Rejected block #%d", block.index) + + except Exception as e: + logger.warning("Invalid block: %s", e) + + async def _handle_get_chain(self, sender): + """Handle chain request - send our chain to requester.""" + if sender is None: + return + + chain_data = [block.to_dict() for block in self.chain.chain] + await self.network.send_chain(sender, chain_data) + logger.info("Sent chain (%d blocks) to peer %s", len(chain_data), sender.node_id) + + async def _handle_chain(self, chain_data: list): + """Handle received chain - validate and potentially replace ours.""" + if not chain_data: + return + + logger.info("Received chain with %d blocks", len(chain_data)) + + # Only replace if longer + if len(chain_data) <= len(self.chain.chain): + logger.info("Received chain not longer than ours, ignoring") + return + + # Validate and replace + if self.chain.replace_chain(chain_data): + logger.info("Replaced chain with received chain (new height: %d)", len(self.chain.chain)) + else: + logger.warning("Received chain validation failed") + + async def _run(self): + """Main node loop.""" + logger.info("Node running. Press Ctrl+C to stop.") + logger.info("Connected to %d peer(s)", self.network.get_peer_count()) + + mine_interval = 10 # seconds between mining attempts + last_mine_time = 0 + + while self._running: + try: + # Mining loop + if self.mining and self._sync_complete: + import time + now = time.time() + if now - last_mine_time >= mine_interval: + await self._try_mine_block() + last_mine_time = now + + # Keep event loop responsive + await asyncio.sleep(1) + + except asyncio.CancelledError: + break + + async def _try_mine_block(self): + """Attempt to mine a new block.""" + pending_txs = self.mempool.get_transactions_for_block() + + block = Block( + index=self.chain.last_block.index + 1, + previous_hash=self.chain.last_block.hash, + transactions=pending_txs, + ) + + try: + # Mine with low difficulty for testnet + mined_block = mine_block(block, difficulty=4, timeout_seconds=5) + + if self.chain.add_block(mined_block): + logger.info("Mined block #%d! Hash: %s...", mined_block.index, mined_block.hash[:16]) + + # Credit mining reward + self.chain.state.credit_mining_reward(self.miner_address) + + # Broadcast to peers + await self.network.broadcast_block(mined_block) + + except Exception as e: + # Mining timeout or other error - this is normal + pass + + +def parse_args(): + """Parse command line arguments.""" + parser = argparse.ArgumentParser( + description="MiniChain Testnet Node", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python cli.py --port 9000 # Start first node + python cli.py --port 9001 --peers 127.0.0.1:9000 # Connect to peer + python cli.py --port 9000 --mine # Enable mining + """ + ) + + parser.add_argument( + "--port", "-p", + type=int, + default=9000, + help="Port to listen on (default: 9000)" + ) + + parser.add_argument( + "--peers", + type=str, + default="", + help="Comma-separated list of peer addresses (e.g., 192.168.1.10:9000,192.168.1.20:9001)" + ) + + parser.add_argument( + "--mine", + action="store_true", + help="Enable mining" + ) + + parser.add_argument( + "--miner", + type=str, + default=None, + help="Miner wallet address for rewards" + ) + + parser.add_argument( + "--debug", + action="store_true", + help="Enable debug logging" + ) + + return parser.parse_args() + + +async def run_node(args): + """Run the node with given arguments.""" + # Parse peers + peers = [] + if args.peers: + peers = [p.strip() for p in args.peers.split(",") if p.strip()] + + # Create and start node + node = Node( + port=args.port, + peers=peers, + mining=args.mine, + miner_address=args.miner + ) + + # Handle shutdown signals + loop = asyncio.get_event_loop() + + def shutdown_handler(): + logger.info("Shutdown signal received") + asyncio.create_task(node.stop()) + + for sig in (signal.SIGINT, signal.SIGTERM): + try: + loop.add_signal_handler(sig, shutdown_handler) + except NotImplementedError: + # Windows doesn't support add_signal_handler + pass + + try: + await node.start() + except KeyboardInterrupt: + pass + finally: + await node.stop() + + +def main(): + """Main entry point.""" + args = parse_args() + + # Setup logging + level = logging.DEBUG if args.debug else logging.INFO + logging.basicConfig( + level=level, + format="%(asctime)s [%(levelname)s] %(message)s", + datefmt="%H:%M:%S" + ) + + print(f""" +╔══════════════════════════════════════════════╗ +║ MiniChain Testnet Node ║ +╚══════════════════════════════════════════════╝ + Port: {args.port} + Mining: {'Enabled' if args.mine else 'Disabled'} + Peers: {args.peers if args.peers else 'None'} +""") + + asyncio.run(run_node(args)) + + +if __name__ == "__main__": + main() diff --git a/minichain/chain.py b/minichain/chain.py index 78ac73f..9b201ce 100644 --- a/minichain/chain.py +++ b/minichain/chain.py @@ -1,4 +1,5 @@ from .block import Block +from .transaction import Transaction from .state import State from .pow import calculate_hash import logging @@ -12,6 +13,9 @@ class Blockchain: Manages the blockchain, validates blocks, and commits state transitions. """ + # Expected genesis hash (all zeros) + GENESIS_HASH = "0" * 64 + def __init__(self): self.chain = [] self.state = State() @@ -27,7 +31,7 @@ def _create_genesis_block(self): previous_hash="0", transactions=[] ) - genesis_block.hash = "0" * 64 + genesis_block.hash = self.GENESIS_HASH self.chain.append(genesis_block) @property @@ -35,9 +39,15 @@ def last_block(self): """ Returns the most recent block in the chain. """ - with self._lock: # Acquire lock for thread-safe access + with self._lock: # Acquire lock for thread-safe access return self.chain[-1] + @property + def height(self): + """Returns the current chain height (number of blocks).""" + with self._lock: + return len(self.chain) + def add_block(self, block): """ Validates and adds a block to the chain if all transactions succeed. @@ -75,3 +85,149 @@ def add_block(self, block): self.state = temp_state self.chain.append(block) return True + + def validate_chain(self, chain_data: list) -> bool: + """ + Validate a chain received from a peer. + + Checks: + 1. Genesis block matches our expected genesis + 2. Each block's hash is valid + 3. Each block's previous_hash links correctly + 4. All transactions in each block are valid + + Args: + chain_data: List of block dictionaries + + Returns: + True if chain is valid, False otherwise + """ + if not chain_data: + return False + + # Validate genesis block + genesis = chain_data[0] + if genesis.get("hash") != self.GENESIS_HASH: + logger.warning("Chain validation failed: Invalid genesis hash") + return False + + if genesis.get("index") != 0: + logger.warning("Chain validation failed: Genesis index not 0") + return False + + # Validate each subsequent block + temp_state = State() # Fresh state for validation + + for i in range(1, len(chain_data)): + block_data = chain_data[i] + prev_block = chain_data[i - 1] + + # Check index linkage + if block_data.get("index") != prev_block.get("index") + 1: + logger.warning("Chain validation failed: Invalid index at block %d", i) + return False + + # Check previous hash linkage + if block_data.get("previous_hash") != prev_block.get("hash"): + logger.warning("Chain validation failed: Invalid previous_hash at block %d", i) + return False + + # Reconstruct block and verify hash + try: + transactions = [Transaction(**tx) for tx in block_data.get("transactions", [])] + block = Block( + index=block_data.get("index"), + previous_hash=block_data.get("previous_hash"), + transactions=transactions, + timestamp=block_data.get("timestamp"), + difficulty=block_data.get("difficulty") + ) + block.nonce = block_data.get("nonce", 0) + + # Verify hash matches + computed_hash = calculate_hash(block.to_header_dict()) + if block_data.get("hash") != computed_hash: + logger.warning("Chain validation failed: Invalid hash at block %d", i) + return False + + # Validate and apply transactions + for tx in transactions: + if not temp_state.validate_and_apply(tx): + logger.warning("Chain validation failed: Invalid tx in block %d", i) + return False + + except Exception as e: + logger.warning("Chain validation failed at block %d: %s", i, e) + return False + + return True + + def replace_chain(self, chain_data: list) -> bool: + """ + Replace the current chain with a longer valid chain. + + Uses "longest valid chain wins" rule. + + Args: + chain_data: List of block dictionaries from peer + + Returns: + True if chain was replaced, False otherwise + """ + with self._lock: + # Only replace if longer + if len(chain_data) <= len(self.chain): + logger.info("Received chain not longer than ours (%d <= %d)", + len(chain_data), len(self.chain)) + return False + + # Validate the received chain + if not self.validate_chain(chain_data): + logger.warning("Received chain failed validation") + return False + + # Rebuild chain and state from scratch + logger.info("Replacing chain: %d -> %d blocks", len(self.chain), len(chain_data)) + + # Clear and rebuild + self.chain = [] + self.state = State() + + # Add genesis + genesis_data = chain_data[0] + genesis_block = Block( + index=0, + previous_hash="0", + transactions=[] + ) + genesis_block.hash = self.GENESIS_HASH + self.chain.append(genesis_block) + + # Add each subsequent block + for i in range(1, len(chain_data)): + block_data = chain_data[i] + transactions = [Transaction(**tx) for tx in block_data.get("transactions", [])] + + block = Block( + index=block_data.get("index"), + previous_hash=block_data.get("previous_hash"), + transactions=transactions, + timestamp=block_data.get("timestamp"), + difficulty=block_data.get("difficulty") + ) + block.nonce = block_data.get("nonce", 0) + block.hash = block_data.get("hash") + + # Apply transactions to state + for tx in transactions: + self.state.validate_and_apply(tx) + + self.chain.append(block) + + logger.info("Chain replaced successfully. New height: %d", len(self.chain)) + return True + + def to_dict_list(self) -> list: + """Export chain as list of block dictionaries.""" + with self._lock: + return [block.to_dict() for block in self.chain] diff --git a/minichain/p2p.py b/minichain/p2p.py index aacbf49..12a0305 100644 --- a/minichain/p2p.py +++ b/minichain/p2p.py @@ -1,118 +1,383 @@ import json import logging +import asyncio +import struct +import uuid +from typing import Dict, Set, Optional, Callable, Any logger = logging.getLogger(__name__) +# Message frame: 4-byte length prefix + JSON body +HEADER_SIZE = 4 +MAX_MESSAGE_SIZE = 10 * 1024 * 1024 # 10MB limit -class P2PNetwork: - """ - A minimal abstraction for Peer-to-Peer networking. - Expected incoming message interface for handle_message(): - msg must have attribute: - - data: bytes (JSON-encoded payload) +class Peer: + """Represents a connected peer.""" + + def __init__(self, node_id: str, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, address: str): + self.node_id = node_id + self.reader = reader + self.writer = writer + self.address = address + + async def send(self, message: dict): + """Send a JSON message with length prefix.""" + try: + data = json.dumps(message).encode('utf-8') + header = struct.pack('>I', len(data)) + self.writer.write(header + data) + await self.writer.drain() + except Exception as e: + logger.error("Failed to send to peer %s: %s", self.address, e) + raise + + async def receive(self) -> Optional[dict]: + """Receive a length-prefixed JSON message.""" + try: + header = await self.reader.readexactly(HEADER_SIZE) + length = struct.unpack('>I', header)[0] + + if length > MAX_MESSAGE_SIZE: + logger.warning("Message too large from %s: %d bytes", self.address, length) + return None + + data = await self.reader.readexactly(length) + return json.loads(data.decode('utf-8')) + except asyncio.IncompleteReadError: + return None + except Exception as e: + logger.error("Failed to receive from peer %s: %s", self.address, e) + return None + + def close(self): + """Close the connection.""" + try: + self.writer.close() + except Exception: + pass + - JSON structure: - { - "type": "tx" | "block", - "data": {...} - } +class P2PNetwork: + """ + Real TCP-based P2P networking for MiniChain. + + Features: + - TCP server listening on a port + - Connect to known peers by IP:port + - Length-prefixed JSON message framing + - Explicit node ID + - Message deduplication (broadcast guard) """ - def __init__(self, handler_callback=None): - self._handler_callback = None + def __init__(self, handler_callback: Optional[Callable] = None): + self.node_id: str = str(uuid.uuid4())[:8] # Short unique ID + self.host: str = "0.0.0.0" + self.port: int = 9000 + + self._handler_callback: Optional[Callable] = None if handler_callback is not None: self.register_handler(handler_callback) - self.pubsub = None # Will be set in real implementation + + self.peers: Dict[str, Peer] = {} # node_id -> Peer + self.seen_messages: Set[str] = set() # For broadcast deduplication + self._server: Optional[asyncio.Server] = None + self._running: bool = False + self._tasks: list = [] - def register_handler(self, handler_callback): + def register_handler(self, handler_callback: Callable): + """Register callback for incoming messages.""" if not callable(handler_callback): raise ValueError("handler_callback must be callable") self._handler_callback = handler_callback + def configure(self, host: str = "0.0.0.0", port: int = 9000): + """Configure network settings before starting.""" + self.host = host + self.port = port + async def start(self): - logger.info("Network: Listening on /ip4/0.0.0.0/tcp/0") - # In real libp2p, we would await host.start() here + """Start the TCP server.""" + self._running = True + self._server = await asyncio.start_server( + self._handle_incoming_connection, + self.host, + self.port + ) + logger.info("Node %s listening on %s:%d", self.node_id, self.host, self.port) async def stop(self): - """Clean up network resources cleanly upon shutdown.""" - logger.info("Network: Shutting down") - if self.pubsub: - try: - shutdown_meth = None - for method_name in ('close', 'stop', 'aclose', 'shutdown'): - if hasattr(self.pubsub, method_name): - shutdown_meth = getattr(self.pubsub, method_name) - break - - if shutdown_meth: - import asyncio - res = shutdown_meth() - if asyncio.iscoroutine(res): - await res - except Exception as e: - logger.error("Network: Error shutting down pubsub: %s", e) - finally: - self.pubsub = None + """Clean shutdown of network.""" + logger.info("Node %s shutting down...", self.node_id) + self._running = False + + # Close all peer connections + for peer in list(self.peers.values()): + peer.close() + self.peers.clear() + + # Cancel background tasks + for task in self._tasks: + task.cancel() + self._tasks.clear() + + # Close server + if self._server: + self._server.close() + await self._server.wait_closed() + self._server = None + + logger.info("Node %s shutdown complete", self.node_id) - async def _broadcast_message(self, topic, msg_type, payload): - msg = json.dumps({"type": msg_type, "data": payload}) - if self.pubsub: - try: - await self.pubsub.publish(topic, msg.encode()) - except Exception as e: - logger.error("Network: Publish failed: %s", e) - else: - logger.debug("Network: pubsub not initialized (mock mode)") + async def connect_to_peer(self, address: str) -> bool: + """ + Connect to a peer by address (ip:port). + Returns True if connection successful. + """ + try: + host, port_str = address.split(':') + port = int(port_str) + + logger.info("Node %s connecting to %s...", self.node_id, address) + reader, writer = await asyncio.open_connection(host, port) + + # Send HELLO + hello_msg = { + "type": "hello", + "data": { + "node_id": self.node_id, + "port": self.port + } + } + data = json.dumps(hello_msg).encode('utf-8') + header = struct.pack('>I', len(data)) + writer.write(header + data) + await writer.drain() + + # Wait for HELLO response + resp_header = await asyncio.wait_for(reader.readexactly(HEADER_SIZE), timeout=5.0) + length = struct.unpack('>I', resp_header)[0] + resp_data = await asyncio.wait_for(reader.readexactly(length), timeout=5.0) + response = json.loads(resp_data.decode('utf-8')) + + if response.get("type") != "hello": + logger.warning("Invalid response from %s", address) + writer.close() + return False + + peer_node_id = response["data"]["node_id"] + + # Check for self-connection + if peer_node_id == self.node_id: + logger.warning("Detected self-connection, closing") + writer.close() + return False + + # Check for duplicate connection + if peer_node_id in self.peers: + logger.info("Already connected to %s", peer_node_id) + writer.close() + return True + + peer = Peer(peer_node_id, reader, writer, address) + self.peers[peer_node_id] = peer + + # Start listening for messages from this peer + task = asyncio.create_task(self._listen_to_peer(peer)) + self._tasks.append(task) + + logger.info("Node %s connected to peer %s at %s", self.node_id, peer_node_id, address) + return True + + except Exception as e: + logger.error("Failed to connect to %s: %s", address, e) + return False - async def broadcast_transaction(self, tx): - sender = getattr(tx, "sender", "") - logger.info("Network: Broadcasting Tx from %s...", sender[:5]) + async def _handle_incoming_connection(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter): + """Handle a new incoming connection.""" + address = writer.get_extra_info('peername') + address_str = f"{address[0]}:{address[1]}" if address else "unknown" + try: - payload = tx.to_dict() - except (TypeError, ValueError) as e: - logger.error("Network: Failed to serialize tx: %s", e) + # Wait for HELLO + header = await asyncio.wait_for(reader.readexactly(HEADER_SIZE), timeout=10.0) + length = struct.unpack('>I', header)[0] + data = await asyncio.wait_for(reader.readexactly(length), timeout=10.0) + message = json.loads(data.decode('utf-8')) + + if message.get("type") != "hello": + logger.warning("Expected HELLO from %s, got %s", address_str, message.get("type")) + writer.close() + return + + peer_node_id = message["data"]["node_id"] + + # Check for self-connection + if peer_node_id == self.node_id: + logger.warning("Detected self-connection from %s, closing", address_str) + writer.close() + return + + # Check for duplicate + if peer_node_id in self.peers: + logger.info("Duplicate connection from %s, closing new one", peer_node_id) + writer.close() + return + + # Send HELLO response + hello_resp = { + "type": "hello", + "data": { + "node_id": self.node_id, + "port": self.port + } + } + resp_data = json.dumps(hello_resp).encode('utf-8') + resp_header = struct.pack('>I', len(resp_data)) + writer.write(resp_header + resp_data) + await writer.drain() + + peer = Peer(peer_node_id, reader, writer, address_str) + self.peers[peer_node_id] = peer + + logger.info("Node %s accepted connection from peer %s", self.node_id, peer_node_id) + + # Listen for messages + await self._listen_to_peer(peer) + + except asyncio.TimeoutError: + logger.warning("Timeout waiting for HELLO from %s", address_str) + writer.close() + except Exception as e: + logger.error("Error handling connection from %s: %s", address_str, e) + writer.close() + + async def _listen_to_peer(self, peer: Peer): + """Listen for messages from a peer.""" + while self._running: + message = await peer.receive() + if message is None: + break + + await self._process_message(message, peer) + + # Peer disconnected + if peer.node_id in self.peers: + del self.peers[peer.node_id] + peer.close() + logger.info("Peer %s disconnected", peer.node_id) + + async def _process_message(self, message: dict, sender: Peer): + """Process an incoming message.""" + msg_type = message.get("type") + msg_data = message.get("data", {}) + + # Generate message ID for deduplication + msg_id = self._get_message_id(message) + + # Broadcast guard: skip if already seen + if msg_id and msg_id in self.seen_messages: return - await self._broadcast_message("minichain-global", "tx", payload) + + if msg_id: + self.seen_messages.add(msg_id) + # Limit seen set size + if len(self.seen_messages) > 10000: + self.seen_messages = set(list(self.seen_messages)[-5000:]) + + # Handle internal message types + if msg_type == "get_chain": + # Respond with chain (handled by callback) + pass + + # Forward to registered handler + if self._handler_callback: + try: + # Pass sender info for response handling + await self._handler_callback(message, sender) + except Exception: + logger.exception("Error in message handler for: %s", message) - async def broadcast_block(self, block): - logger.info("Network: Broadcasting Block #%d", block.index) - await self._broadcast_message("minichain-global", "block", block.to_dict()) + def _get_message_id(self, message: dict) -> Optional[str]: + """Get unique ID for message deduplication.""" + msg_type = message.get("type") + msg_data = message.get("data", {}) + + if msg_type == "tx": + # Use transaction hash + return msg_data.get("signature") or json.dumps(msg_data, sort_keys=True) + elif msg_type == "block": + # Use block hash + return msg_data.get("hash") + + return None - async def handle_message(self, msg): + async def broadcast(self, message: dict, exclude_peer: Optional[str] = None): """ - Callback when a p2p message is received. + Broadcast a message to all connected peers. + exclude_peer: node_id to exclude (typically the sender) """ + msg_id = self._get_message_id(message) + if msg_id: + self.seen_messages.add(msg_id) + + for node_id, peer in list(self.peers.items()): + if node_id == exclude_peer: + continue + try: + await peer.send(message) + except Exception as e: + logger.error("Failed to broadcast to %s: %s", node_id, e) + + async def broadcast_transaction(self, tx): + """Broadcast a transaction to all peers.""" + logger.info("Node %s broadcasting tx from %s...", self.node_id, tx.sender[:8]) + message = {"type": "tx", "data": tx.to_dict()} + await self.broadcast(message) + async def broadcast_block(self, block): + """Broadcast a block to all peers.""" + logger.info("Node %s broadcasting block #%d", self.node_id, block.index) + message = {"type": "block", "data": block.to_dict()} + await self.broadcast(message) + + async def request_chain(self, peer: Peer) -> Optional[list]: + """Request the full chain from a peer.""" try: - if not hasattr(msg, "data"): - raise TypeError("Incoming message missing 'data' attribute") + await peer.send({"type": "get_chain", "data": {}}) + # Response will be handled by the message callback + return None + except Exception as e: + logger.error("Failed to request chain from %s: %s", peer.node_id, e) + return None - if not isinstance(msg.data, (bytes, bytearray)): - raise TypeError("msg.data must be bytes") + async def send_chain(self, peer: Peer, chain_data: list): + """Send the full chain to a peer.""" + try: + await peer.send({"type": "chain", "data": chain_data}) + except Exception as e: + logger.error("Failed to send chain to %s: %s", peer.node_id, e) - if len(msg.data) > 1024 * 1024: # 1MB limit - logger.warning("Network: Message too large") - return + def get_peer_count(self) -> int: + """Return number of connected peers.""" + return len(self.peers) - try: - decoded = msg.data.decode('utf-8') - except UnicodeDecodeError as e: - logger.warning("Network Error: UnicodeDecodeError during message decode: %s", e) - return - data = json.loads(decoded) + def get_peer_ids(self) -> list: + """Return list of connected peer node_ids.""" + return list(self.peers.keys()) - if not isinstance(data, dict) or "type" not in data or "data" not in data: - raise ValueError("Invalid message format") - except (TypeError, ValueError, json.JSONDecodeError) as e: - logger.warning("Network Error parsing message: %s", e) - return +# Legacy compatibility - handle_message interface for tests +async def _legacy_handle_message(self, msg): + """Legacy callback interface for backward compatibility.""" + if not hasattr(msg, "data"): + raise TypeError("Incoming message missing 'data' attribute") + + if not isinstance(msg.data, (bytes, bytearray)): + raise TypeError("msg.data must be bytes") + + data = json.loads(msg.data.decode('utf-8')) + + if self._handler_callback: + await self._handler_callback(data, None) - try: - if self._handler_callback: - await self._handler_callback(data) - else: - logger.warning("Network Error: No handler_callback registered") - except Exception: - logger.exception("Error in network handler callback for data: %s", data) From 69f864c6f760f1aa356aa9cc43428221d0525b45 Mon Sep 17 00:00:00 2001 From: anuragShingare30 Date: Wed, 25 Feb 2026 19:10:25 +0530 Subject: [PATCH 2/4] update: functions and cli commands --- cli.py | 230 ++++++++++++++++++++++++++++++++++++--- minichain/mempool.py | 13 +++ minichain/state.py | 11 ++ minichain/transaction.py | 11 +- 4 files changed, 248 insertions(+), 17 deletions(-) diff --git a/cli.py b/cli.py index 5d8b73a..9cadbdb 100644 --- a/cli.py +++ b/cli.py @@ -46,6 +46,7 @@ def __init__(self, port: int, peers: list = None, mining: bool = False, miner_ad self._running = False self._sync_complete = False + self._chain_received_event = asyncio.Event() async def start(self): """Start the node.""" @@ -87,8 +88,12 @@ async def _sync_chain(self): for peer in list(self.network.peers.values()): await self.network.request_chain(peer) - # Wait a bit for responses - await asyncio.sleep(2) + # Wait for chain response with timeout + try: + await asyncio.wait_for(self._chain_received_event.wait(), timeout=5.0) + except asyncio.TimeoutError: + logger.warning("Chain sync timeout - no response from peers") + self._sync_complete = True logger.info("Chain sync complete. Height: %d", len(self.chain.chain)) @@ -156,7 +161,7 @@ async def _handle_get_chain(self, sender): if sender is None: return - chain_data = [block.to_dict() for block in self.chain.chain] + chain_data = self.chain.to_dict_list() await self.network.send_chain(sender, chain_data) logger.info("Sent chain (%d blocks) to peer %s", len(chain_data), sender.node_id) @@ -167,9 +172,10 @@ async def _handle_chain(self, chain_data: list): logger.info("Received chain with %d blocks", len(chain_data)) - # Only replace if longer - if len(chain_data) <= len(self.chain.chain): - logger.info("Received chain not longer than ours, ignoring") + # Only replace if longer or equal (let replace_chain validate) + if len(chain_data) < len(self.chain.chain): + logger.info("Received chain shorter than ours, ignoring") + self._chain_received_event.set() return # Validate and replace @@ -177,17 +183,45 @@ async def _handle_chain(self, chain_data: list): logger.info("Replaced chain with received chain (new height: %d)", len(self.chain.chain)) else: logger.warning("Received chain validation failed") + + # Signal that we received a chain response + self._chain_received_event.set() async def _run(self): """Main node loop.""" - logger.info("Node running. Press Ctrl+C to stop.") + logger.info("Node running. Type 'help' for commands.") logger.info("Connected to %d peer(s)", self.network.get_peer_count()) mine_interval = 10 # seconds between mining attempts last_mine_time = 0 + + # Start input reader task + input_task = asyncio.create_task(self._read_input()) while self._running: try: + # Check for user input + if input_task.done(): + try: + cmd = input_task.result() + if cmd is not None: + result = self._handle_command(cmd) + if result == False: + break + elif result == "sync": + self._chain_received_event.clear() + await self._sync_chain() + elif isinstance(result, tuple) and result[0] == "connect": + success = await self.network.connect_to_peer(result[1]) + if success: + print(f"Connected to {result[1]}") + else: + print(f"Failed to connect to {result[1]}") + except Exception: + pass + # Start new input reader + input_task = asyncio.create_task(self._read_input()) + # Mining loop if self.mining and self._sync_complete: import time @@ -197,19 +231,45 @@ async def _run(self): last_mine_time = now # Keep event loop responsive - await asyncio.sleep(1) + await asyncio.sleep(0.1) except asyncio.CancelledError: break + + # Cancel input task on exit + if not input_task.done(): + input_task.cancel() + + async def _read_input(self): + """Read a line from stdin asynchronously.""" + loop = asyncio.get_event_loop() + try: + return await loop.run_in_executor(None, sys.stdin.readline) + except Exception: + return None async def _try_mine_block(self): """Attempt to mine a new block.""" pending_txs = self.mempool.get_transactions_for_block() + # Create coinbase transaction for mining reward + coinbase_tx = Transaction( + sender="0" * 40, # Coinbase has no sender + receiver=self.miner_address, + amount=50, # Mining reward + nonce=0, + data=None, + signature=None, # Coinbase doesn't need a signature + timestamp=None # Will be set to current time + ) + + # Insert coinbase transaction at the beginning + all_txs = [coinbase_tx] + pending_txs + block = Block( index=self.chain.last_block.index + 1, previous_hash=self.chain.last_block.hash, - transactions=pending_txs, + transactions=all_txs, ) try: @@ -219,15 +279,138 @@ async def _try_mine_block(self): if self.chain.add_block(mined_block): logger.info("Mined block #%d! Hash: %s...", mined_block.index, mined_block.hash[:16]) - # Credit mining reward - self.chain.state.credit_mining_reward(self.miner_address) - # Broadcast to peers await self.network.broadcast_block(mined_block) except Exception as e: - # Mining timeout or other error - this is normal - pass + # Mining timeout or other error - return transactions to mempool + for tx in pending_txs: + self.mempool.add_transaction(tx) + + def _print_help(self): + """Print available commands.""" + print(""" +Available commands: + status - Show node status + peers - List connected peers + chain - Show chain status + balance - Show balance of an address + mempool - Show pending transactions + block - Show block details + connect - Connect to a new peer + mine - Toggle mining on/off + sync - Request chain sync from peers + help - Show this help + exit / quit - Stop the node +""") + + def _handle_command(self, cmd: str): + """Handle interactive command.""" + parts = cmd.strip().split() + if not parts: + return True # Continue running + + command = parts[0].lower() + args = parts[1:] + + if command in ("exit", "quit"): + return False # Stop running + + elif command == "help": + self._print_help() + + elif command == "status": + print(f"Node ID: {self.network.node_id}") + print(f"Port: {self.port}") + print(f"Peers: {self.network.get_peer_count()}") + print(f"Chain height: {len(self.chain.chain)}") + print(f"Mempool: {self.mempool.size()} txns") + print(f"Mining: {'ON' if self.mining else 'OFF'}") + print(f"Synced: {'Yes' if self._sync_complete else 'No'}") + + elif command == "peers": + peers = self.network.get_peer_ids() + if peers: + print(f"Connected peers ({len(peers)}):") + for pid in peers: + peer = self.network.peers.get(pid) + addr = peer.address if peer else "unknown" + print(f" {pid} @ {addr}") + else: + print("No peers connected") + + elif command == "chain": + height = len(self.chain.chain) + last = self.chain.last_block + print(f"Chain height: {height}") + print(f"Last block: #{last.index}") + print(f" Hash: {last.hash[:32]}...") + print(f" Txns: {len(last.transactions)}") + print(f"Mining: {'ON' if self.mining else 'OFF'}") + print(f"Synced: {'Yes' if self._sync_complete else 'No'}") + + elif command == "balance": + if not args: + print("Usage: balance
") + else: + addr = args[0] + account = self.chain.state.get_account(addr) + print(f"Address: {addr[:16]}...") + print(f" Balance: {account['balance']}") + print(f" Nonce: {account['nonce']}") + + elif command == "mempool": + pending = self.mempool.get_pending_transactions() + if pending: + print(f"Pending transactions ({len(pending)}):") + for tx in pending[:10]: # Show first 10 + print(f" {tx.sender[:8]}... -> {tx.receiver[:8] if tx.receiver else 'deploy'}... : {tx.amount}") + if len(pending) > 10: + print(f" ... and {len(pending) - 10} more") + else: + print("Mempool is empty") + + elif command == "block": + if not args: + print("Usage: block ") + else: + try: + idx = int(args[0]) + if 0 <= idx < len(self.chain.chain): + block = self.chain.chain[idx] + print(f"Block #{block.index}") + print(f" Hash: {block.hash}") + print(f" Prev: {block.previous_hash[:32]}...") + print(f" Time: {block.timestamp}") + print(f" Txns: {len(block.transactions)}") + print(f" Difficulty: {block.difficulty}") + print(f" Nonce: {block.nonce}") + else: + print(f"Block {idx} not found (height: {len(self.chain.chain)})") + except ValueError: + print("Invalid block index") + + elif command == "mine": + self.mining = not self.mining + print(f"Mining: {'ON' if self.mining else 'OFF'}") + + elif command == "connect": + if not args: + print("Usage: connect ") + else: + addr = args[0] + print(f"Connecting to {addr}...") + return ("connect", addr) + + elif command == "sync": + print("Requesting chain sync...") + # Will be handled in async context + return "sync" + + else: + print(f"Unknown command: {command}. Type 'help' for commands.") + + return True # Continue running def parse_args(): @@ -279,8 +462,13 @@ def parse_args(): return parser.parse_args() +# Module-level variable to store shutdown task +shutdown_task = None + async def run_node(args): """Run the node with given arguments.""" + global shutdown_task + # Parse peers peers = [] if args.peers: @@ -298,8 +486,9 @@ async def run_node(args): loop = asyncio.get_event_loop() def shutdown_handler(): + global shutdown_task logger.info("Shutdown signal received") - asyncio.create_task(node.stop()) + shutdown_task = asyncio.create_task(node.stop()) for sig in (signal.SIGINT, signal.SIGTERM): try: @@ -313,7 +502,14 @@ def shutdown_handler(): except KeyboardInterrupt: pass finally: - await node.stop() + # If shutdown was already initiated by signal handler, wait for it + if shutdown_task is not None: + try: + await asyncio.wait_for(shutdown_task, timeout=5.0) + except asyncio.TimeoutError: + logger.warning("Shutdown task timed out") + else: + await node.stop() def main(): @@ -335,6 +531,8 @@ def main(): Port: {args.port} Mining: {'Enabled' if args.mine else 'Disabled'} Peers: {args.peers if args.peers else 'None'} + + Type 'help' for available commands. """) asyncio.run(run_node(args)) diff --git a/minichain/mempool.py b/minichain/mempool.py index 06a60d0..8ddba2c 100644 --- a/minichain/mempool.py +++ b/minichain/mempool.py @@ -60,3 +60,16 @@ def get_transactions_for_block(self): self._seen_tx_ids.difference_update(confirmed_ids) return txs + + def get_pending_transactions(self): + """ + Returns a copy of pending transactions without clearing the pool. + Used for inspection/display purposes. + """ + with self._lock: + return self._pending_txs[:] + + def size(self): + """Returns the number of pending transactions.""" + with self._lock: + return len(self._pending_txs) diff --git a/minichain/state.py b/minichain/state.py index ce9a6f0..6736c05 100644 --- a/minichain/state.py +++ b/minichain/state.py @@ -14,6 +14,7 @@ def __init__(self): self.contract_machine = ContractMachine(self) DEFAULT_MINING_REWARD = 50 + COINBASE_ADDRESS = "0" * 40 # Special address for coinbase transactions def get_account(self, address): if address not in self.accounts: @@ -26,6 +27,10 @@ def get_account(self, address): return self.accounts[address] def verify_transaction_logic(self, tx): + # Coinbase transactions don't need signature/balance/nonce validation + if tx.sender == self.COINBASE_ADDRESS: + return True + if not tx.verify(): logger.error(f"Error: Invalid signature for tx from {tx.sender[:8]}...") return False @@ -74,6 +79,12 @@ def apply_transaction(self, tx): if not self.verify_transaction_logic(tx): return False + # Coinbase transactions only credit the receiver, no deduction from sender + if tx.sender == self.COINBASE_ADDRESS: + receiver = self.get_account(tx.receiver) + receiver['balance'] += tx.amount + return True + sender = self.accounts[tx.sender] # Deduct funds and increment nonce diff --git a/minichain/transaction.py b/minichain/transaction.py index cdf4d99..5e17521 100644 --- a/minichain/transaction.py +++ b/minichain/transaction.py @@ -12,7 +12,16 @@ def __init__(self, sender, receiver, amount, nonce, data=None, signature=None, t self.amount = amount self.nonce = nonce self.data = data # Preserve None (do NOT normalize to "") - self.timestamp = round(timestamp * 1000) if timestamp is not None else round(time.time() * 1000) # Integer milliseconds for determinism + # Handle timestamp: if already in milliseconds (large int), use as-is + # Otherwise convert from seconds to milliseconds + if timestamp is None: + self.timestamp = round(time.time() * 1000) + elif isinstance(timestamp, int) and timestamp > 1e12: + # Already in milliseconds (timestamps after year 2001 in ms are > 1e12) + self.timestamp = timestamp + else: + # Timestamp in seconds, convert to milliseconds + self.timestamp = round(timestamp * 1000) self.signature = signature # Hex str def to_dict(self): From b4fee1c4bc96fb4f8cdfd0ee91ec5174191fadcb Mon Sep 17 00:00:00 2001 From: anuragShingare30 Date: Wed, 25 Feb 2026 19:10:46 +0530 Subject: [PATCH 3/4] small changes --- minichain/chain.py | 55 +++++++++++++++++++++++++++++++++++----------- 1 file changed, 42 insertions(+), 13 deletions(-) diff --git a/minichain/chain.py b/minichain/chain.py index 9b201ce..9b974eb 100644 --- a/minichain/chain.py +++ b/minichain/chain.py @@ -66,10 +66,19 @@ def add_block(self, block): return False # Verify block hash - if block.hash != calculate_hash(block.to_header_dict()): + computed_hash = calculate_hash(block.to_header_dict()) + if block.hash != computed_hash: logger.warning("Block %s rejected: Invalid hash %s", block.index, block.hash) return False + # Verify proof-of-work meets difficulty target + difficulty = block.difficulty or 0 + if difficulty > 0: + required_prefix = "0" * difficulty + if not computed_hash.startswith(required_prefix): + logger.warning("Block %s rejected: Hash does not meet difficulty %d", block.index, difficulty) + return False + # Validate transactions on a temporary state copy temp_state = self.state.copy() @@ -150,6 +159,14 @@ def validate_chain(self, chain_data: list) -> bool: logger.warning("Chain validation failed: Invalid hash at block %d", i) return False + # Verify proof-of-work meets difficulty target + difficulty = block_data.get("difficulty", 0) or 0 + if difficulty > 0: + required_prefix = "0" * difficulty + if not computed_hash.startswith(required_prefix): + logger.warning("Chain validation failed: Hash does not meet difficulty %d at block %d", difficulty, i) + return False + # Validate and apply transactions for tx in transactions: if not temp_state.validate_and_apply(tx): @@ -175,33 +192,39 @@ def replace_chain(self, chain_data: list) -> bool: True if chain was replaced, False otherwise """ with self._lock: - # Only replace if longer - if len(chain_data) <= len(self.chain): - logger.info("Received chain not longer than ours (%d <= %d)", + # Only replace if longer (or equal during initial sync) + if len(chain_data) < len(self.chain): + logger.info("Received chain shorter than ours (%d < %d)", len(chain_data), len(self.chain)) return False + + # If equal length, only replace if it validates (essentially a no-op for same chain) + if len(chain_data) == len(self.chain): + # Validate but don't bother replacing if identical + if self.validate_chain(chain_data): + logger.debug("Received chain same length as ours and valid") + return True # Consider it a successful sync + return False # Validate the received chain if not self.validate_chain(chain_data): logger.warning("Received chain failed validation") return False - # Rebuild chain and state from scratch + # Build new chain and state locally for atomic replacement logger.info("Replacing chain: %d -> %d blocks", len(self.chain), len(chain_data)) - # Clear and rebuild - self.chain = [] - self.state = State() + new_chain = [] + new_state = State() # Add genesis - genesis_data = chain_data[0] genesis_block = Block( index=0, previous_hash="0", transactions=[] ) genesis_block.hash = self.GENESIS_HASH - self.chain.append(genesis_block) + new_chain.append(genesis_block) # Add each subsequent block for i in range(1, len(chain_data)): @@ -218,11 +241,17 @@ def replace_chain(self, chain_data: list) -> bool: block.nonce = block_data.get("nonce", 0) block.hash = block_data.get("hash") - # Apply transactions to state + # Apply transactions to new state for tx in transactions: - self.state.validate_and_apply(tx) + if not new_state.validate_and_apply(tx): + logger.warning("Chain rebuild failed: Invalid tx in block %d", i) + return False + + new_chain.append(block) - self.chain.append(block) + # Atomically assign new chain and state + self.chain = new_chain + self.state = new_state logger.info("Chain replaced successfully. New height: %d", len(self.chain)) return True From 0347e7fd4e4f54e776e42916a7c27e235f5b6b27 Mon Sep 17 00:00:00 2001 From: anuragShingare30 Date: Wed, 25 Feb 2026 19:13:48 +0530 Subject: [PATCH 4/4] small refactor in p2p file --- minichain/p2p.py | 135 +++++++++++++++++++++++------------------------ 1 file changed, 65 insertions(+), 70 deletions(-) diff --git a/minichain/p2p.py b/minichain/p2p.py index 12a0305..239ea84 100644 --- a/minichain/p2p.py +++ b/minichain/p2p.py @@ -3,7 +3,8 @@ import asyncio import struct import uuid -from typing import Dict, Set, Optional, Callable, Any +from collections import OrderedDict +from typing import Dict, Optional, Callable, Any logger = logging.getLogger(__name__) @@ -39,7 +40,12 @@ async def receive(self) -> Optional[dict]: length = struct.unpack('>I', header)[0] if length > MAX_MESSAGE_SIZE: - logger.warning("Message too large from %s: %d bytes", self.address, length) + logger.warning("Message too large from %s: %d bytes, closing connection", self.address, length) + self.writer.close() + try: + await self.writer.wait_closed() + except Exception: + pass return None data = await self.reader.readexactly(length) @@ -80,7 +86,7 @@ def __init__(self, handler_callback: Optional[Callable] = None): self.register_handler(handler_callback) self.peers: Dict[str, Peer] = {} # node_id -> Peer - self.seen_messages: Set[str] = set() # For broadcast deduplication + self.seen_messages: OrderedDict[str, None] = OrderedDict() # For broadcast deduplication (ordered for proper eviction) self._server: Optional[asyncio.Server] = None self._running: bool = False self._tasks: list = [] @@ -141,7 +147,10 @@ async def connect_to_peer(self, address: str) -> bool: logger.info("Node %s connecting to %s...", self.node_id, address) reader, writer = await asyncio.open_connection(host, port) - # Send HELLO + # Create temporary peer for handshake + temp_peer = Peer(node_id="pending", reader=reader, writer=writer, address=address) + + # Send HELLO using Peer.send hello_msg = { "type": "hello", "data": { @@ -149,20 +158,19 @@ async def connect_to_peer(self, address: str) -> bool: "port": self.port } } - data = json.dumps(hello_msg).encode('utf-8') - header = struct.pack('>I', len(data)) - writer.write(header + data) - await writer.drain() + await temp_peer.send(hello_msg) - # Wait for HELLO response - resp_header = await asyncio.wait_for(reader.readexactly(HEADER_SIZE), timeout=5.0) - length = struct.unpack('>I', resp_header)[0] - resp_data = await asyncio.wait_for(reader.readexactly(length), timeout=5.0) - response = json.loads(resp_data.decode('utf-8')) + # Wait for HELLO response using Peer.receive with timeout + try: + response = await asyncio.wait_for(temp_peer.receive(), timeout=5.0) + except asyncio.TimeoutError: + logger.warning("Timeout waiting for HELLO response from %s", address) + temp_peer.close() + return False - if response.get("type") != "hello": + if response is None or response.get("type") != "hello": logger.warning("Invalid response from %s", address) - writer.close() + temp_peer.close() return False peer_node_id = response["data"]["node_id"] @@ -170,20 +178,21 @@ async def connect_to_peer(self, address: str) -> bool: # Check for self-connection if peer_node_id == self.node_id: logger.warning("Detected self-connection, closing") - writer.close() + temp_peer.close() return False # Check for duplicate connection if peer_node_id in self.peers: logger.info("Already connected to %s", peer_node_id) - writer.close() + temp_peer.close() return True - peer = Peer(peer_node_id, reader, writer, address) - self.peers[peer_node_id] = peer + # Update peer with actual node_id and register + temp_peer.node_id = peer_node_id + self.peers[peer_node_id] = temp_peer # Start listening for messages from this peer - task = asyncio.create_task(self._listen_to_peer(peer)) + task = asyncio.create_task(self._listen_to_peer(temp_peer)) self._tasks.append(task) logger.info("Node %s connected to peer %s at %s", self.node_id, peer_node_id, address) @@ -198,16 +207,21 @@ async def _handle_incoming_connection(self, reader: asyncio.StreamReader, writer address = writer.get_extra_info('peername') address_str = f"{address[0]}:{address[1]}" if address else "unknown" + # Create temporary peer for handshake + temp_peer = Peer(node_id="pending", reader=reader, writer=writer, address=address_str) + try: - # Wait for HELLO - header = await asyncio.wait_for(reader.readexactly(HEADER_SIZE), timeout=10.0) - length = struct.unpack('>I', header)[0] - data = await asyncio.wait_for(reader.readexactly(length), timeout=10.0) - message = json.loads(data.decode('utf-8')) + # Wait for HELLO using Peer.receive with timeout + try: + message = await asyncio.wait_for(temp_peer.receive(), timeout=10.0) + except asyncio.TimeoutError: + logger.warning("Timeout waiting for HELLO from %s", address_str) + temp_peer.close() + return - if message.get("type") != "hello": - logger.warning("Expected HELLO from %s, got %s", address_str, message.get("type")) - writer.close() + if message is None or message.get("type") != "hello": + logger.warning("Expected HELLO from %s, got %s", address_str, message.get("type") if message else None) + temp_peer.close() return peer_node_id = message["data"]["node_id"] @@ -215,16 +229,16 @@ async def _handle_incoming_connection(self, reader: asyncio.StreamReader, writer # Check for self-connection if peer_node_id == self.node_id: logger.warning("Detected self-connection from %s, closing", address_str) - writer.close() + temp_peer.close() return # Check for duplicate if peer_node_id in self.peers: logger.info("Duplicate connection from %s, closing new one", peer_node_id) - writer.close() + temp_peer.close() return - # Send HELLO response + # Send HELLO response using Peer.send hello_resp = { "type": "hello", "data": { @@ -232,25 +246,20 @@ async def _handle_incoming_connection(self, reader: asyncio.StreamReader, writer "port": self.port } } - resp_data = json.dumps(hello_resp).encode('utf-8') - resp_header = struct.pack('>I', len(resp_data)) - writer.write(resp_header + resp_data) - await writer.drain() + await temp_peer.send(hello_resp) - peer = Peer(peer_node_id, reader, writer, address_str) - self.peers[peer_node_id] = peer + # Update peer with actual node_id and register + temp_peer.node_id = peer_node_id + self.peers[peer_node_id] = temp_peer logger.info("Node %s accepted connection from peer %s", self.node_id, peer_node_id) # Listen for messages - await self._listen_to_peer(peer) + await self._listen_to_peer(temp_peer) - except asyncio.TimeoutError: - logger.warning("Timeout waiting for HELLO from %s", address_str) - writer.close() except Exception as e: logger.error("Error handling connection from %s: %s", address_str, e) - writer.close() + temp_peer.close() async def _listen_to_peer(self, peer: Peer): """Listen for messages from a peer.""" @@ -270,7 +279,6 @@ async def _listen_to_peer(self, peer: Peer): async def _process_message(self, message: dict, sender: Peer): """Process an incoming message.""" msg_type = message.get("type") - msg_data = message.get("data", {}) # Generate message ID for deduplication msg_id = self._get_message_id(message) @@ -280,14 +288,14 @@ async def _process_message(self, message: dict, sender: Peer): return if msg_id: - self.seen_messages.add(msg_id) - # Limit seen set size - if len(self.seen_messages) > 10000: - self.seen_messages = set(list(self.seen_messages)[-5000:]) + self.seen_messages[msg_id] = None + # Limit seen dict size - evict oldest entries + while len(self.seen_messages) > 10000: + self.seen_messages.popitem(last=False) # Remove oldest # Handle internal message types if msg_type == "get_chain": - # Respond with chain (handled by callback) + # No-op here: actual chain response is handled by _handler_callback pass # Forward to registered handler @@ -319,7 +327,7 @@ async def broadcast(self, message: dict, exclude_peer: Optional[str] = None): """ msg_id = self._get_message_id(message) if msg_id: - self.seen_messages.add(msg_id) + self.seen_messages[msg_id] = None for node_id, peer in list(self.peers.items()): if node_id == exclude_peer: @@ -341,15 +349,18 @@ async def broadcast_block(self, block): message = {"type": "block", "data": block.to_dict()} await self.broadcast(message) - async def request_chain(self, peer: Peer) -> Optional[list]: - """Request the full chain from a peer.""" + async def request_chain(self, peer: Peer) -> None: + """ + Request the full chain from a peer. + + This method only sends a "get_chain" request. The actual chain + is delivered asynchronously via the message handler callback + that processes incoming "chain" messages. + """ try: await peer.send({"type": "get_chain", "data": {}}) - # Response will be handled by the message callback - return None except Exception as e: logger.error("Failed to request chain from %s: %s", peer.node_id, e) - return None async def send_chain(self, peer: Peer, chain_data: list): """Send the full chain to a peer.""" @@ -365,19 +376,3 @@ def get_peer_count(self) -> int: def get_peer_ids(self) -> list: """Return list of connected peer node_ids.""" return list(self.peers.keys()) - - -# Legacy compatibility - handle_message interface for tests -async def _legacy_handle_message(self, msg): - """Legacy callback interface for backward compatibility.""" - if not hasattr(msg, "data"): - raise TypeError("Incoming message missing 'data' attribute") - - if not isinstance(msg.data, (bytes, bytearray)): - raise TypeError("msg.data must be bytes") - - data = json.loads(msg.data.decode('utf-8')) - - if self._handler_callback: - await self._handler_callback(data, None) -