diff --git a/api/__init__.py b/api/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/api/main.py b/api/main.py new file mode 100644 index 0000000..2a20b46 --- /dev/null +++ b/api/main.py @@ -0,0 +1,174 @@ +from contextlib import asynccontextmanager +from fastapi import FastAPI, HTTPException, Query +from pydantic import BaseModel +from typing import List, Optional, Union, Dict + +from core import Blockchain, Block, State, Transaction +from core.merkle import MerkleTree +from node import Mempool +from core.mining import mine_and_process_block + + +blockchain: Optional[Blockchain] = None +mempool: Optional[Mempool] = None +pending_nonce_map: Dict[str, int] = {} + + +@asynccontextmanager +async def lifespan(app: FastAPI): + global blockchain, mempool + blockchain = Blockchain() + mempool = Mempool() + yield + blockchain.save_to_file() + + +app = FastAPI(title="MiniChain API", description="SPV-enabled blockchain API", lifespan=lifespan) + + +class TransactionResponse(BaseModel): + sender: str + receiver: Optional[str] = None + amount: int + nonce: int + data: Optional[Union[dict, str]] = None + timestamp: int + signature: Optional[str] = None + hash: Optional[str] = None + + +class BlockResponse(BaseModel): + index: int + previous_hash: str + merkle_root: Optional[str] + timestamp: int + difficulty: Optional[int] + nonce: int + hash: Optional[str] = None + transactions: List[TransactionResponse] + merkle_proofs: Optional[dict] = None + + +class VerifyTransactionResponse(BaseModel): + tx_hash: str + block_index: int + merkle_root: str + proof: List[dict] + verification_status: bool + message: str + + +class ChainInfo(BaseModel): + length: int + blocks: List[dict] + + +@app.get("/") +def root(): + return {"message": "MiniChain API with SPV Support"} + + +@app.get("/chain", response_model=ChainInfo) +def get_chain(): + chain_copy = blockchain.get_chain_copy() + + return { + "length": len(chain_copy), + "blocks": [block.to_dict() for block in chain_copy] + } + + +@app.get("/block/{block_index}", response_model=BlockResponse) +def get_block(block_index: int): + chain_copy = blockchain.get_chain_copy() + + if block_index < 0 or block_index >= len(chain_copy): + raise HTTPException(status_code=404, detail="Block not found") + + block = chain_copy[block_index] + + block_dict = block.to_dict() + + merkle_proofs = {} + for i, _ in enumerate(block.transactions): + tx_hash = block.get_tx_hash(i) + if tx_hash: + proof = block.get_merkle_proof(i) + if proof is not None: + merkle_proofs[tx_hash] = proof + + return { + **block_dict, + "merkle_proofs": merkle_proofs + } + + +@app.get("/verify_transaction", response_model=VerifyTransactionResponse) +def verify_transaction( + tx_hash: str = Query(..., description="Transaction hash to verify"), + block_index: int = Query(..., description="Block index to verify against") +): + chain_copy = blockchain.get_chain_copy() + + if block_index < 0 or block_index >= len(chain_copy): + raise HTTPException(status_code=404, detail="Block not found") + + block = chain_copy[block_index] + + tx_found = False + tx_index = -1 + for i, _ in enumerate(block.transactions): + tx_hash_computed = block.get_tx_hash(i) + if tx_hash_computed == tx_hash: + tx_found = True + tx_index = i + break + + if not tx_found: + return { + "tx_hash": tx_hash, + "block_index": block_index, + "merkle_root": block.merkle_root or "", + "proof": [], + "verification_status": False, + "message": "Transaction not found in block" + } + + proof = block.get_merkle_proof(tx_index) + merkle_root = block.merkle_root or "" + + if proof is None: + return { + "tx_hash": tx_hash, + "block_index": block_index, + "merkle_root": merkle_root, + "proof": [], + "verification_status": False, + "message": "Failed to generate Merkle proof" + } + + verification_status = MerkleTree.verify_proof(tx_hash, proof, merkle_root) + + return { + "tx_hash": tx_hash, + "block_index": block_index, + "merkle_root": merkle_root, + "proof": proof, + "verification_status": verification_status, + "message": "Transaction verified successfully" if verification_status else "Verification failed" + } + + +@app.post("/mine") +def mine_block_endpoint(): + block, *_ = mine_and_process_block(blockchain, mempool, pending_nonce_map) + + if block: + return {"message": "Block mined successfully", "block": block.to_dict()} + else: + raise HTTPException(status_code=400, detail="Failed to mine block") + + +if __name__ == "__main__": + import uvicorn + uvicorn.run(app, host="127.0.0.1", port=8000) diff --git a/core/block.py b/core/block.py index 23f7536..01937e0 100644 --- a/core/block.py +++ b/core/block.py @@ -1,37 +1,9 @@ import time -import hashlib import json from typing import List, Optional from core.transaction import Transaction - - -def _sha256(data: str) -> str: - return hashlib.sha256(data.encode()).hexdigest() - - -def _calculate_merkle_root(transactions: List[Transaction]) -> Optional[str]: - if not transactions: - return None - - # Hash each transaction deterministically - tx_hashes = [ - _sha256(json.dumps(tx.to_dict(), sort_keys=True)) - for tx in transactions - ] - - # Build Merkle tree - while len(tx_hashes) > 1: - if len(tx_hashes) % 2 != 0: - tx_hashes.append(tx_hashes[-1]) # duplicate last if odd - - new_level = [] - for i in range(0, len(tx_hashes), 2): - combined = tx_hashes[i] + tx_hashes[i + 1] - new_level.append(_sha256(combined)) - - tx_hashes = new_level - - return tx_hashes[0] +from core.merkle import MerkleTree +from core.utils import _sha256 class Block: @@ -58,8 +30,8 @@ def __init__( self.nonce: int = 0 self.hash: Optional[str] = None - # NEW: compute merkle root once - self.merkle_root: Optional[str] = _calculate_merkle_root(self.transactions) + self._merkle_tree = MerkleTree([tx.to_dict() for tx in self.transactions]) + self.merkle_root: Optional[str] = self._merkle_tree.get_merkle_root() # ------------------------- # HEADER (used for mining) @@ -103,3 +75,14 @@ def compute_hash(self) -> str: sort_keys=True ) return _sha256(header_string) + + # ------------------------- + # MERKLE PROOF + # ------------------------- + def get_merkle_proof(self, tx_index: int) -> Optional[List[dict]]: + return self._merkle_tree.get_proof(tx_index) + + def get_tx_hash(self, tx_index: int) -> Optional[str]: + if tx_index < 0 or tx_index >= len(self._merkle_tree.tx_hashes): + return None + return self._merkle_tree.tx_hashes[tx_index] diff --git a/core/chain.py b/core/chain.py index 9545864..60e27e7 100644 --- a/core/chain.py +++ b/core/chain.py @@ -1,27 +1,157 @@ from core.block import Block from core.state import State +from core.transaction import Transaction from consensus import calculate_hash import logging import threading +import json +import os +from typing import Optional logger = logging.getLogger(__name__) +DEFAULT_CHAIN_FILE = "chain_data.json" + class Blockchain: """ Manages the blockchain, validates blocks, and commits state transitions. """ - def __init__(self): + def __init__(self, chain_file: Optional[str] = None): self.chain = [] self.state = State() self._lock = threading.RLock() - self._create_genesis_block() + self._chain_file = chain_file or DEFAULT_CHAIN_FILE + self._load_from_file() + + def get_chain_copy(self): + with self._lock: + return list(self.chain) + + def _load_from_file(self): + if not os.path.exists(self._chain_file): + self._create_genesis_block() + return + + try: + with open(self._chain_file, 'r') as f: + data = json.load(f) + + self.chain = [] + for block_data in data.get("chain", []): + transactions = [] + for tx in block_data.get("transactions", []): + t = Transaction( + sender=tx["sender"], + receiver=tx.get("receiver"), + amount=tx["amount"], + nonce=tx["nonce"], + data=tx.get("data"), + signature=tx.get("signature") + ) + t.timestamp = tx.get("timestamp", t.timestamp) + transactions.append(t) + + block = Block( + index=block_data["index"], + previous_hash=block_data["previous_hash"], + transactions=transactions, + timestamp=block_data.get("timestamp"), + difficulty=block_data.get("difficulty") + ) + block.nonce = block_data.get("nonce", 0) + block.hash = block_data.get("hash") + self.chain.append(block) + +if len(self.chain) == 0: + self._create_genesis_block() + logger.info("Chain file %s contained no blocks; created genesis block", self._chain_file) + return + + genesis = self.chain[0] + if genesis.hash != "0" * 64 or genesis.previous_hash != "0": + logger.warning("Loaded chain has invalid genesis block. Rejecting loaded chain.") + self._create_genesis_block() + logger.info("Created new genesis block after rejecting invalid chain") + return + + for i in range(1, len(self.chain)): + prev_block = self.chain[i - 1] + curr_block = self.chain[i] + curr_data = data["chain"][i] + + if curr_block.previous_hash != prev_block.hash: + logger.warning(f"Loaded chain has invalid previous_hash at block {i}. Rejecting loaded chain.") + self.chain = [] + break + + if curr_block.hash != calculate_hash(curr_block.to_header_dict()): + logger.warning(f"Loaded chain has invalid hash at block {i}. Rejecting loaded chain.") + self.chain = [] + break + + stored_merkle = curr_data.get("merkle_root") + computed_merkle = curr_block.merkle_root + + if stored_merkle != computed_merkle: + logger.warning(f"Loaded chain has invalid merkle_root at block {i}. Rejecting loaded chain.") + self.chain = [] + break + else: + if data.get("state"): + self.state = State.from_dict(data["state"]) + logger.info(f"Loaded chain with {len(self.chain)} blocks from {self._chain_file}") + return + +if not self.chain: + self._create_genesis_block() + logger.info("Created new genesis block after rejecting invalid chain") + + except Exception as e: + logger.warning(f"Failed to load chain from {self._chain_file}: {e}. Creating new genesis block.") + self._create_genesis_block() + + def _serialize_chain_data(self): + return { + "chain": [ + { + "index": block.index, + "previous_hash": block.previous_hash, + "merkle_root": block.merkle_root, + "timestamp": block.timestamp, + "difficulty": block.difficulty, + "nonce": block.nonce, + "hash": block.hash, + "transactions": [tx.to_dict() for tx in block.transactions] + } + for block in self.chain + ], + "state": self.state.to_dict() if hasattr(self.state, 'to_dict') else {} + } + +def _save_to_file_unlocked(self, data, block_count): + temp_file = None + try: + temp_file = self._chain_file + ".tmp" + with open(temp_file, 'w') as f: + json.dump(data, f, indent=2) + + os.replace(temp_file, self._chain_file) + logger.info("Saved chain with %s blocks to %s", block_count, self._chain_file) + + except Exception as e: + logger.error(f"Failed to save chain to {self._chain_file}: {e}") + if temp_file is not None and os.path.exists(temp_file): + os.remove(temp_file) + +def save_to_file(self): + with self._lock: + data = self._serialize_chain_data() + block_count = len(self.chain) + self._save_to_file_unlocked(data, block_count) def _create_genesis_block(self): - """ - Creates the genesis block with a fixed hash. - """ genesis_block = Block( index=0, previous_hash="0", @@ -32,46 +162,37 @@ def _create_genesis_block(self): @property def last_block(self): - """ - Returns the most recent block in the chain. - """ - with self._lock: # Acquire lock for thread-safe access + with self._lock: return self.chain[-1] def add_block(self, block): - """ - Validates and adds a block to the chain if all transactions succeed. - Uses a copied State to ensure atomic validation. - """ - with self._lock: - # Check previous hash linkage if block.previous_hash != self.last_block.hash: logger.warning("Block %s rejected: Invalid previous hash %s != %s", block.index, block.previous_hash, self.last_block.hash) return False - # Check index linkage if block.index != self.last_block.index + 1: logger.warning("Block %s rejected: Invalid index %s != %s", block.index, block.index, self.last_block.index + 1) return False - # Verify block hash if block.hash != calculate_hash(block.to_header_dict()): logger.warning("Block %s rejected: Invalid hash %s", block.index, block.hash) return False - # Validate transactions on a temporary state copy temp_state = self.state.copy() for tx in block.transactions: result = temp_state.validate_and_apply(tx) - # Reject block if any transaction fails if not result: logger.warning("Block %s rejected: Transaction failed validation", block.index) return False - # All transactions valid → commit state and append block - self.state = temp_state +self.state = temp_state self.chain.append(block) - return True + + data = self._serialize_chain_data() + block_count = len(self.chain) + + self._save_to_file_unlocked(data, block_count) + return True diff --git a/core/merkle.py b/core/merkle.py new file mode 100644 index 0000000..b8a6e83 --- /dev/null +++ b/core/merkle.py @@ -0,0 +1,96 @@ +import json +from core.utils import _sha256 + + +class MerkleTree: + LEAF_PREFIX = "leaf:" + NODE_PREFIX = "node:" + + def __init__(self, transactions: list[dict]): + self.transactions = transactions + self.tx_hashes = self._hash_transactions() + self.tree = self._build_tree() + self.root = self._get_root() + + def _hash_transactions(self) -> list[str]: + return [ + _sha256(self.LEAF_PREFIX + json.dumps(tx, sort_keys=True)) + for tx in self.transactions + ] + + def _build_tree(self) -> list[list[str]]: + if not self.tx_hashes: + return [] + + tree = [self.tx_hashes[:]] + + while len(tree[-1]) > 1: + current_level = list(tree[-1]) + if len(current_level) % 2 != 0: + current_level.append(current_level[-1]) + + new_level = [] + for i in range(0, len(current_level), 2): + combined = current_level[i] + current_level[i + 1] + new_level.append(_sha256(self.NODE_PREFIX + combined)) + + tree.append(new_level) + + return tree + + def _get_root(self) -> str | None: + if not self.tree: + return None + return self.tree[-1][0] if self.tree[-1] else None + + def get_merkle_root(self) -> str | None: + return self.root + + def get_proof(self, index: int) -> list[dict] | None: + if index < 0 or index >= len(self.tx_hashes): + return None + + proof = [] + for level_idx in range(len(self.tree) - 1): + level = self.tree[level_idx] + is_right = index % 2 == 1 + sibling_idx = index - 1 if is_right else index + 1 + + if sibling_idx < len(level): + proof.append({ + "hash": level[sibling_idx], + "position": "left" if is_right else "right" + }) + else: + proof.append({ + "hash": level[index], + "position": "left" if is_right else "right" + }) + + index //= 2 + + return proof + + @staticmethod + def verify_proof(tx_hash: str, proof: list[dict], merkle_root: str) -> bool: + current_hash = tx_hash + + for item in proof: + sibling_hash = item["hash"] + position = item["position"] + + if position == "left": + combined = sibling_hash + current_hash + else: + combined = current_hash + sibling_hash + + current_hash = _sha256(MerkleTree.NODE_PREFIX + combined) + + return current_hash == merkle_root + + +def calculate_merkle_root(transactions: list[dict]) -> str | None: + if not transactions: + return None + tree = MerkleTree(transactions) + return tree.get_merkle_root() diff --git a/core/mining.py b/core/mining.py new file mode 100644 index 0000000..b2eb455 --- /dev/null +++ b/core/mining.py @@ -0,0 +1,62 @@ +import logging +import re +from core import Transaction, Block +from consensus import mine_block, MiningExceededError + + +logger = logging.getLogger(__name__) + +BURN_ADDRESS = "0" * 40 + + +def mine_and_process_block(chain, mempool, pending_nonce_map): + pending_txs = mempool.get_transactions_for_block() + tx_hashes = [mempool._get_tx_id(tx) for tx in pending_txs] + + last_block = chain.last_block + block = Block( + index=last_block.index + 1, + previous_hash=last_block.hash, + transactions=pending_txs, + ) + + try: + mined_block = mine_block(block) +except MiningExceededError: + mempool.return_transactions(pending_txs) + logger.warning("Mining failed, transactions returned to mempool") + return None, [] + + if not hasattr(mined_block, "miner"): + mined_block.miner = BURN_ADDRESS + + deployed_contracts: list[str] = [] + + if chain.add_block(mined_block): + logger.info("Block #%s added", mined_block.index) + + miner_attr = getattr(mined_block, "miner", None) + if isinstance(miner_attr, str) and re.match(r'^[0-9a-fA-F]{40}$', miner_attr): + miner_address = miner_attr + else: + logger.warning("Invalid miner address. Crediting burn address.") + miner_address = BURN_ADDRESS + + chain.state.credit_mining_reward(miner_address) + +for tx in mined_block.transactions: + sync_nonce(chain.state, pending_nonce_map, tx.sender) + + return mined_block, deployed_contracts +else: + mempool.return_transactions(pending_txs) + logger.error("Block rejected by chain, transactions returned to mempool") + return None, [] + + +def sync_nonce(state, pending_nonce_map, address): + account = state.get_account(address) + if account and "nonce" in account: + pending_nonce_map[address] = account["nonce"] + else: + pending_nonce_map[address] = 0 diff --git a/core/state.py b/core/state.py index 17bc68c..1ad73d9 100644 --- a/core/state.py +++ b/core/state.py @@ -161,3 +161,16 @@ def credit_mining_reward(self, miner_address, reward=None): reward = reward if reward is not None else self.DEFAULT_MINING_REWARD account = self.get_account(miner_address) account['balance'] += reward + + def to_dict(self): + return { + "accounts": self.accounts, + "default_mining_reward": self.DEFAULT_MINING_REWARD + } + + @classmethod + def from_dict(cls, data): + new_state = cls() + new_state.accounts = data.get("accounts", {}) + new_state.DEFAULT_MINING_REWARD = data.get("default_mining_reward", 50) + return new_state diff --git a/core/utils.py b/core/utils.py new file mode 100644 index 0000000..9384b71 --- /dev/null +++ b/core/utils.py @@ -0,0 +1,5 @@ +import hashlib + + +def _sha256(data: str) -> str: + return hashlib.sha256(data.encode()).hexdigest() diff --git a/main.py b/main.py index d9670c0..26adeae 100644 --- a/main.py +++ b/main.py @@ -1,19 +1,16 @@ import asyncio import logging -import re from nacl.signing import SigningKey from nacl.encoding import HexEncoder from core import Transaction, Blockchain, Block, State +from core.mining import mine_and_process_block, sync_nonce from node import Mempool from network import P2PNetwork -from consensus import mine_block logger = logging.getLogger(__name__) -BURN_ADDRESS = "0" * 40 - def create_wallet(): sk = SigningKey.generate() @@ -21,62 +18,6 @@ def create_wallet(): return sk, pk -def mine_and_process_block(chain, mempool, pending_nonce_map): - """ - Mine block and let Blockchain handle validation + state updates. - DO NOT manually apply transactions again. - """ - - pending_txs = mempool.get_transactions_for_block() - - block = Block( - index=chain.last_block.index + 1, - previous_hash=chain.last_block.hash, - transactions=pending_txs, - ) - - mined_block = mine_block(block) - - if not hasattr(mined_block, "miner"): - mined_block.miner = BURN_ADDRESS - - deployed_contracts: list[str] = [] - - if chain.add_block(mined_block): - logger.info("Block #%s added", mined_block.index) - - miner_attr = getattr(mined_block, "miner", None) - if isinstance(miner_attr, str) and re.match(r'^[0-9a-fA-F]{40}$', miner_attr): - miner_address = miner_attr - else: - logger.warning("Invalid miner address. Crediting burn address.") - miner_address = BURN_ADDRESS - - # Reward must go through chain.state - chain.state.credit_mining_reward(miner_address) - - for tx in mined_block.transactions: - sync_nonce(chain.state, pending_nonce_map, tx.sender) - - # Track deployed contracts if your state.apply_transaction returns address - result = chain.state.get_account(tx.receiver) if tx.receiver else None - if isinstance(result, dict): - deployed_contracts.append(tx.receiver) - - return mined_block, deployed_contracts - else: - logger.error("Block rejected by chain") - return None, [] - - -def sync_nonce(state, pending_nonce_map, address): - account = state.get_account(address) - if account and "nonce" in account: - pending_nonce_map[address] = account["nonce"] - else: - pending_nonce_map[address] = 0 - - async def node_loop(): logger.info("Starting MiniChain Node with Smart Contracts") diff --git a/node/mempool.py b/node/mempool.py index 8bb941a..2429bae 100644 --- a/node/mempool.py +++ b/node/mempool.py @@ -18,7 +18,7 @@ def _get_tx_id(self, tx): """ return calculate_hash(tx.to_dict()) - def add_transaction(self, tx): +def add_transaction(self, tx): """ Adds a transaction to the pool if: - Signature is valid @@ -38,14 +38,23 @@ def add_transaction(self, tx): if len(self._pending_txs) >= self.max_size: # Simple eviction: drop oldest or reject. Here we reject. - logger.warning("Mempool: Full, rejecting transaction") + logger.warning(f"Mempool: Pool full, transaction rejected") return False self._pending_txs.append(tx) self._seen_tx_ids.add(tx_id) - + logger.info(f"Mempool: Added transaction {tx_id}") return True + def return_transactions(self, transactions): + """ + Return transactions to the pool after failed mining attempt. + """ + tx_ids = {self._get_tx_id(tx) for tx in transactions} + with self._lock: + self._pending_txs.extend(transactions) + self._seen_tx_ids.update(tx_ids) + def get_transactions_for_block(self): """ Returns pending transactions and clears the pool. diff --git a/requirements.txt b/requirements.txt index 819e170..2410f74 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,5 @@ pynacl==1.6.2 libp2p==0.5.0 +fastapi==0.129.2 +uvicorn==0.41.0 +pydantic==2.12.5 diff --git a/tests/test_merkle.py b/tests/test_merkle.py new file mode 100644 index 0000000..dd5b4b6 --- /dev/null +++ b/tests/test_merkle.py @@ -0,0 +1,137 @@ +import unittest +from core.merkle import MerkleTree, calculate_merkle_root + + +class TestMerkleTree(unittest.TestCase): + def test_empty_transactions(self): + root = calculate_merkle_root([]) + self.assertIsNone(root) + + def test_single_transaction(self): + tx = {"sender": "alice", "receiver": "bob", "amount": 10} + tree = MerkleTree([tx]) + root = tree.get_merkle_root() + self.assertIsNotNone(root) + self.assertEqual(len(root), 64) + + proof = tree.get_proof(0) + self.assertEqual(proof, []) + + tx_hash = tree.tx_hashes[0] + result = MerkleTree.verify_proof(tx_hash, proof, root) + self.assertTrue(result) + + def test_two_transactions(self): + txs = [ + {"sender": "alice", "receiver": "bob", "amount": 10}, + {"sender": "bob", "receiver": "charlie", "amount": 5} + ] + tree = MerkleTree(txs) + root = tree.get_merkle_root() + self.assertIsNotNone(root) + + proof0 = tree.get_proof(0) + proof1 = tree.get_proof(1) + self.assertIsNotNone(proof0) + self.assertIsNotNone(proof1) + + result0 = MerkleTree.verify_proof(tree.tx_hashes[0], proof0, root) + result1 = MerkleTree.verify_proof(tree.tx_hashes[1], proof1, root) + self.assertTrue(result0) + self.assertTrue(result1) + + def test_odd_transaction_count(self): + txs = [ + {"sender": "alice", "receiver": "bob", "amount": 10}, + {"sender": "bob", "receiver": "charlie", "amount": 5}, + {"sender": "charlie", "receiver": "dave", "amount": 3} + ] + tree = MerkleTree(txs) + root = tree.get_merkle_root() + self.assertIsNotNone(root) + + for i in range(len(txs)): + proof = tree.get_proof(i) + result = MerkleTree.verify_proof(tree.tx_hashes[i], proof, root) + self.assertTrue(result) + + def test_proof_generation(self): + txs = [ + {"sender": "alice", "receiver": "bob", "amount": 10}, + {"sender": "bob", "receiver": "charlie", "amount": 5}, + {"sender": "charlie", "receiver": "dave", "amount": 3}, + {"sender": "dave", "receiver": "eve", "amount": 1} + ] + tree = MerkleTree(txs) + + for i in range(len(txs)): + proof = tree.get_proof(i) + self.assertIsNotNone(proof) + self.assertTrue(len(proof) > 0) + + def test_proof_verification(self): + txs = [ + {"sender": "alice", "receiver": "bob", "amount": 10}, + {"sender": "bob", "receiver": "charlie", "amount": 5}, + {"sender": "charlie", "receiver": "dave", "amount": 3}, + {"sender": "dave", "receiver": "eve", "amount": 1} + ] + tree = MerkleTree(txs) + root = tree.get_merkle_root() + + for i, tx_hash in enumerate(tree.tx_hashes): + proof = tree.get_proof(i) + result = MerkleTree.verify_proof(tx_hash, proof, root) + self.assertTrue(result) + + def test_proof_verification_fails_wrong_root(self): + txs = [ + {"sender": "alice", "receiver": "bob", "amount": 10}, + {"sender": "bob", "receiver": "charlie", "amount": 5} + ] + tree = MerkleTree(txs) + + wrong_root = "0" * 64 + tx_hash = tree.tx_hashes[0] + proof = tree.get_proof(0) + + result = MerkleTree.verify_proof(tx_hash, proof, wrong_root) + self.assertFalse(result) + + def test_proof_verification_fails_wrong_tx_hash(self): + txs = [ + {"sender": "alice", "receiver": "bob", "amount": 10}, + {"sender": "bob", "receiver": "charlie", "amount": 5} + ] + tree = MerkleTree(txs) + + root = tree.get_merkle_root() + proof = tree.get_proof(0) + + tampered_tx_hash = "a" * 64 + result = MerkleTree.verify_proof(tampered_tx_hash, proof, root) + self.assertFalse(result) + + def test_calculate_merkle_root_matches_tree(self): + txs = [ + {"sender": "alice", "receiver": "bob", "amount": 10}, + {"sender": "bob", "receiver": "charlie", "amount": 5}, + {"sender": "charlie", "receiver": "dave", "amount": 3} + ] + root1 = calculate_merkle_root(txs) + root2 = MerkleTree(txs).get_merkle_root() + self.assertEqual(root1, root2) + + def test_invalid_index(self): + txs = [ + {"sender": "alice", "receiver": "bob", "amount": 10}, + {"sender": "bob", "receiver": "charlie", "amount": 5} + ] + tree = MerkleTree(txs) + + self.assertIsNone(tree.get_proof(10)) + self.assertIsNone(tree.get_proof(-1)) + + +if __name__ == '__main__': + unittest.main()