Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 49 additions & 21 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
logger = logging.getLogger(__name__)

BURN_ADDRESS = "0" * 40
TRUSTED_PEERS = set()
LOCALHOST_PEERS = {"127.0.0.1", "::1", "localhost", "0:0:0:0:0:0:0:1"}


# ──────────────────────────────────────────────
Expand Down Expand Up @@ -68,6 +70,11 @@ def mine_and_process_block(chain, mempool, miner_pk):
return mined_block
else:
logger.error("❌ Block rejected by chain")
restored = 0
for tx in pending_txs:
if mempool.add_transaction(tx):
restored += 1
logger.info("Mempool: Restored %d/%d txs after rejection", restored, len(pending_txs))
return None


Expand All @@ -81,15 +88,31 @@ def make_network_handler(chain, mempool):
async def handler(data):
msg_type = data.get("type")
payload = data.get("data")
peer_addr = data.get("_peer_addr", "unknown")

if msg_type == "sync":
peer_host = peer_addr.rsplit(":", 1)[0] if ":" in peer_addr else peer_addr
peer_host = peer_host.strip("[]")
is_trusted = peer_addr in TRUSTED_PEERS or peer_host in TRUSTED_PEERS
is_localhost = peer_host in LOCALHOST_PEERS
if chain.state.accounts and not (is_trusted or is_localhost):
logger.warning("🔒 Rejected sync from untrusted peer %s", peer_addr)
return

# Merge remote state into local state (for accounts we don't have yet)
remote_accounts = payload.get("accounts", {})
remote_accounts = payload.get("accounts") if isinstance(payload, dict) else None
if not isinstance(remote_accounts, dict):
logger.warning("🔒 Rejected sync from %s with invalid accounts payload", peer_addr)
return

for addr, acc in remote_accounts.items():
if not isinstance(acc, dict):
logger.warning("🔒 Skipping malformed account %r from %s", addr, peer_addr)
continue
if addr not in chain.state.accounts:
chain.state.accounts[addr] = acc
logger.info("🔄 Synced account %s... (balance=%d)", addr[:12], acc.get("balance", 0))
logger.info("🔄 State sync complete — %d accounts", len(chain.state.accounts))
logger.info("🔄 Accepted state sync from %s — %d accounts", peer_addr, len(chain.state.accounts))

elif msg_type == "tx":
tx = Transaction(**payload)
Expand Down Expand Up @@ -134,15 +157,15 @@ async def handler(data):
╔════════════════════════════════════════════════╗
║ MiniChain Commands ║
╠════════════════════════════════════════════════╣
║ balance show all balances ║
║ send <to> <amount> send coins ║
║ mine mine a block ║
║ peers show connected peers ║
║ connect <host:port> connect to a peer ║
║ address show your public key ║
║ chain show chain summary ║
║ help show this help
║ quit shut down
║ balance - show all balances ║
║ send <to> <amount> - send coins ║
║ mine - mine a block ║
║ peers - show connected peers ║
║ connect <host>:<port> - connect to a peer ║
║ address - show your public key ║
║ chain - show chain summary ║
║ help - show this help ║
║ quit - shut down ║
╚════════════════════════════════════════════════╝
"""

Expand Down Expand Up @@ -220,7 +243,11 @@ async def cli_loop(sk, pk, chain, mempool, network, nonce_counter):
except ValueError:
print(" Invalid format. Use host:port")
continue
await network.connect_to_peer(host, port)
success = await network.connect_to_peer(host, port)
if success:
print(f" Connected to {host}:{port}")
else:
print(f" Failed to connect to {host}:{port}")

# ── address ──
elif cmd == "address":
Expand Down Expand Up @@ -249,7 +276,7 @@ async def cli_loop(sk, pk, chain, mempool, network, nonce_counter):
# Main entry point
# ──────────────────────────────────────────────

async def run_node(port: int, connect_to: str | None, fund: int):
async def run_node(port: int, connect_to: str | None, fund: int, host: str = "127.0.0.1"):
"""Boot the node, optionally connect to a peer, then enter the CLI."""
sk, pk = create_wallet()

Expand All @@ -271,14 +298,9 @@ async def on_peer_connected(writer):
await writer.drain()
logger.info("🔄 Sent state sync to new peer")

network._on_peer_connected = on_peer_connected
network.set_on_peer_connected(on_peer_connected)

await network.start(port=port)

# Fund this node's wallet so it can transact in the demo
if fund > 0:
chain.state.credit_mining_reward(pk, reward=fund)
logger.info("💰 Funded %s... with %d coins", pk[:12], fund)
await network.start(port=port, host=host)

# Connect to a seed peer if requested
if connect_to:
Expand All @@ -288,6 +310,11 @@ async def on_peer_connected(writer):
except ValueError:
logger.error("Invalid --connect format. Use host:port")

# Fund this node's wallet so it can transact in the demo
if fund > 0:
chain.state.credit_mining_reward(pk, reward=fund)
logger.info("💰 Funded %s... with %d coins", pk[:12], fund)

# Nonce counter kept as a mutable list so the CLI closure can mutate it
nonce_counter = [0]

Expand All @@ -299,6 +326,7 @@ async def on_peer_connected(writer):

def main():
parser = argparse.ArgumentParser(description="MiniChain Node — Testnet Demo")
parser.add_argument("--host", type=str, default="127.0.0.1", help="Host/IP to bind the P2P server (default: 127.0.0.1)")
parser.add_argument("--port", type=int, default=9000, help="TCP port to listen on (default: 9000)")
parser.add_argument("--connect", type=str, default=None, help="Peer address to connect to (host:port)")
parser.add_argument("--fund", type=int, default=100, help="Initial coins to fund this wallet (default: 100)")
Expand All @@ -311,7 +339,7 @@ def main():
)

try:
asyncio.run(run_node(args.port, args.connect, args.fund))
asyncio.run(run_node(args.port, args.connect, args.fund, args.host))
except KeyboardInterrupt:
print("\nNode shut down.")

Expand Down
29 changes: 25 additions & 4 deletions minichain/p2p.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,23 +37,31 @@ def register_handler(self, handler_callback):
raise ValueError("handler_callback must be callable")
self._handler_callback = handler_callback

def set_on_peer_connected(self, callback):
if callback is not None and not asyncio.iscoroutinefunction(callback):
raise ValueError("on_peer_connected callback must be an async callable")
self._on_peer_connected = callback

# ------------------------------------------------------------------
# Server lifecycle
# ------------------------------------------------------------------

async def start(self, port: int = 9000):
async def start(self, port: int = 9000, host: str = "127.0.0.1"):
"""Start listening for incoming peer connections on the given port."""
self._port = port
self._server = await asyncio.start_server(
self._handle_incoming, "0.0.0.0", port
self._handle_incoming, host, port
)
logger.info("Network: Listening on 0.0.0.0:%d", port)
logger.info("Network: Listening on %s:%d", host, port)

async def stop(self):
"""Gracefully shut down the server and disconnect all peers."""
logger.info("Network: Shutting down")
for task in self._listen_tasks:
task.cancel()
if self._listen_tasks:
await asyncio.gather(*self._listen_tasks, return_exceptions=True)
self._listen_tasks.clear()
for _, writer in self._peers:
try:
writer.close()
Expand All @@ -75,6 +83,11 @@ async def connect_to_peer(self, host: str, port: int) -> bool:
try:
reader, writer = await asyncio.open_connection(host, port)
self._peers.append((reader, writer))
if self._on_peer_connected:
try:
await self._on_peer_connected(writer)
except Exception:
logger.exception("Network: Error during peer sync")
task = asyncio.create_task(self._listen_to_peer(reader, writer, f"{host}:{port}"))
self._listen_tasks.append(task)
logger.info("Network: Connected to peer %s:%d", host, port)
Expand Down Expand Up @@ -110,6 +123,8 @@ async def _listen_to_peer(self, reader: asyncio.StreamReader, writer: asyncio.St
except (json.JSONDecodeError, UnicodeDecodeError):
logger.warning("Network: Malformed message from %s", addr)
continue
if isinstance(data, dict):
data["_peer_addr"] = addr

if self._handler_callback:
try:
Expand Down Expand Up @@ -144,7 +159,13 @@ async def _broadcast_raw(self, payload: dict):
await writer.drain()
except Exception:
disconnected.append((reader, writer))
for pair in disconnected:
for reader, writer in disconnected:
try:
writer.close()
await writer.wait_closed()
except Exception:
pass
pair = (reader, writer)
if pair in self._peers:
self._peers.remove(pair)

Expand Down