Files
Compsys-2021-Assignments/A3/python/peer/peer.py
Nikolaj 84fbccb94d 🐛
2021-11-05 13:49:35 +01:00

436 lines
16 KiB
Python

#!/bin/python3
import argparse
import os
import hashlib
import struct
import math
import random
import socket
import time
import socketserver
import threading
from typing import Any, List, Union
HASH_LENGTH=32
# Helper that hides TCP fragmentation from the caller
def readbytes(s, count) -> bytes:
    """Read exactly ``count`` bytes from the socket-like object ``s``.

    ``s.recv`` is called repeatedly, each time asking only for the bytes that
    are still outstanding, until the requested amount has been collected.

    Args:
        s: Object exposing ``recv(n)``.
        count: Number of bytes to gather.

    Returns:
        The collected data as an immutable ``bytes`` object.

    Raises:
        Exception: If ``recv`` yields nothing before ``count`` bytes arrived.
    """
    chunks = bytearray()
    remaining = count
    while remaining > 0:
        piece = s.recv(remaining)
        if not piece:
            raise Exception('Server did not respond')
        chunks += piece
        remaining -= len(piece)
    return bytes(chunks)
class CascadeBlock(object):
    """Bookkeeping record for a single block of a cascade download.

    Attributes:
        index: Position of the block within the file.
        offset: Byte offset of the block in the target file.
        length: Number of bytes in the block.
        hash: Expected digest of the block data (``HASH_LENGTH`` bytes).
        completed: Starts out False; flipped by code that verifies the block.
    """

    def __init__(self, index: int, offset: int, length: int, hash: bytes) -> None:
        """Store the block description and validate the digest length.

        Raises:
            Exception: If ``hash`` is not exactly ``HASH_LENGTH`` bytes long.
        """
        self.index, self.offset = index, offset
        self.length, self.hash = length, hash

        actual = len(self.hash)
        if actual != HASH_LENGTH:
            raise Exception(f'Hash length should be {HASH_LENGTH} but was {actual} (offset: {self.offset})')

        self.completed = False
class CascadeFile(object):
    """Parsed cascade description file plus the download state of its target.

    Layout read from the description file:
        bytes 0-8    signature ``CASCADE1``
        bytes 16-24  target file size (big-endian unsigned 64 bit)
        bytes 24-32  block size (big-endian unsigned 64 bit)
        bytes 32-64  hash of the complete target file
        bytes 64-    one ``HASH_LENGTH``-byte hash per block
    (bytes 8-16 are part of the header but are not interpreted here.)

    Attributes:
        destination: Path of the target file, or None if not yet chosen.
        cascadehash: SHA-256 digest of the whole description file.
        header: The first 64 bytes of the description file.
        targetsize, blocksize, targethash: Values decoded from the header.
        trailblocksize: Size of the final block.
        blocks: List of ``CascadeBlock``, one per block, in file order.
        queue: Blocks that still have to be downloaded.
    """

    def __init__(self, path: str, destination: str = None) -> None:
        """Load and validate the description file at ``path``.

        Args:
            path: Location of the cascade description file.
            destination: Optional target file; when given, existing data in
                it is verified immediately via ``prepare_download``.

        Raises:
            Exception: On a bad signature, a truncated header, a zero block
                size, or too few block hashes.
        """
        self.destination = destination
        # Defined up front so the attribute exists even without a destination.
        self.queue = []

        # One read serves both the whole-file digest and the parsing below.
        with open(path, 'rb') as f:
            raw = f.read()

        self.cascadehash = hashlib.sha256(raw).digest()
        self.header = raw[:64]
        if bytes('CASCADE1', 'ascii') != self.header[:8]:
            raise Exception(f'Signature in file header is invalid: {path}')
        if len(self.header) != 64:
            # Without this check struct.unpack fails with an opaque struct.error.
            raise Exception(f'File header is truncated: {path}')

        self.targetsize = struct.unpack('!Q', self.header[16:24])[0]
        self.blocksize = struct.unpack('!Q', self.header[24:32])[0]
        self.targethash = self.header[32:64]
        if self.blocksize == 0:
            raise Exception(f'Block size in file header is zero: {path}')

        self.trailblocksize = self.targetsize % self.blocksize
        if self.trailblocksize == 0:  # In case the file size is evenly divisible with the block size
            self.trailblocksize = self.blocksize

        # Integer ceiling division: float division loses precision once the
        # sizes exceed 2**53.
        blockcount = (self.targetsize + self.blocksize - 1) // self.blocksize

        self.blocks = []
        for n in range(blockcount):
            start = 64 + n * HASH_LENGTH
            blockhash = raw[start:start + HASH_LENGTH]
            if not blockhash:
                raise Exception(f'Incorrect number of hashes in file: {path}, got {n}, but expected {blockcount}')
            # A partially present hash is rejected by CascadeBlock itself.
            self.blocks.append(
                CascadeBlock(
                    n,
                    n * self.blocksize,
                    self.trailblocksize if n == blockcount - 1 else self.blocksize,
                    blockhash
                )
            )

        if destination is not None:
            self.prepare_download(destination)

    def prepare_download(self, destination: str) -> None:
        """Verify what is already present in ``destination`` and build the queue.

        The target file is created empty if it does not exist. Blocks are then
        read in order; each one whose SHA-256 matches is marked completed.
        Scanning stops at the first block that cannot be read in full.

        Args:
            destination: Path of the target file.
        """
        # Keep the attribute in sync: the serving side opens file.destination.
        self.destination = destination

        # Make sure the target exists
        if not os.path.isfile(destination):
            with open(destination, 'wb'):
                pass

        # Check existing blocks for completion
        with open(destination, 'rb') as f:
            for block in self.blocks:
                data = f.read(block.length)
                if len(data) != block.length:
                    # File ends here; nothing further can be present.
                    break
                if hashlib.sha256(data).digest() == block.hash:
                    block.completed = True

        # Create a queue of missing blocks
        self.queue = [block for block in self.blocks if not block.completed]
class Peer(object):
    """A remote peer that can be asked for blocks of a cascade file.

    Attributes:
        ip: Address of the peer as a dotted string.
        port: TCP port of the peer.
        timestamp: Timestamp value supplied with the peer record.
        good: Flag supplied with the peer record.
    """

    def __init__(self, ip: str, port: int, timestamp: int, good: bool):
        self.ip = ip
        self.port = port
        self.timestamp = timestamp
        self.good = good

    def download_block(self, cascadehash, blockno, blocksize) -> Union[bytes, None]:
        """Request one block from this peer over a fresh TCP connection.

        The request is ``CASCADE1``, two zeroed 64-bit fields, the block number
        as a big-endian 64-bit integer, and the cascade hash. The reply starts
        with a one-byte status and a big-endian 64-bit payload length.

        Args:
            cascadehash: Digest identifying the cascade file.
            blockno: Index of the wanted block.
            blocksize: Number of bytes the block is expected to have.

        Returns:
            The block data, or None when the peer answers with status 2
            (it does not hold that block).

        Raises:
            Exception: When the connection closes early, the reported size
                does not match ``blocksize``, or the peer reports any status
                other than 0 or 2.
        """
        req = bytearray(bytes('CASCADE1', 'ascii'))
        req.extend(struct.pack('!QQQ', 0, 0, blockno))
        req.extend(cascadehash)

        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.connect((self.ip, self.port))
            s.sendall(req)

            # A single recv() may return fewer than 9 bytes (or none at all);
            # readbytes keeps reading until the full header has arrived.
            head = readbytes(s, 9)
            status = head[0]
            size = struct.unpack('!Q', head[1:9])[0]

            if status == 0:
                if size != blocksize:
                    raise Exception(f'The reported block size was {size} but {blocksize} was expected')
                return readbytes(s, size)

            if status == 2:
                msg = readbytes(s, size).decode('utf-8')
                print(f'Client does not have block {blockno}: {msg}')
                return None

            # Remaining known statuses all carry a message and are fatal.
            fatal = {
                1: 'Client does not have file',
                3: 'Client reported blockno was out-of-bounds',
                4: 'Client reported invalid request',
            }
            if status in fatal:
                msg = readbytes(s, size).decode('utf-8')
                raise Exception(f'{fatal[status]}: {msg}')

            raise Exception(f'Client gave unsupported error code: {status}')
class Tracker(object):
    """Client for the tracker that hands out peer lists for a cascade hash."""

    # Size in bytes of one peer record in a tracker reply.
    PEER_RECORD_SIZE = 12

    def __init__(self, ip: str, port: int):
        self.ip = ip
        self.port = port

    def list(self, hash: bytes, ip: Union[bytes, str], port: int) -> List["Peer"]:
        """Send command 1 to the tracker and return the peers it reports."""
        return self.send_to_server(self.build_request(1, hash, ip, port))

    def subscribe(self, hash: bytes, ip: Union[bytes, str], port: int) -> List["Peer"]:
        """Send command 2 to the tracker and return the peers it reports."""
        return self.send_to_server(self.build_request(2, hash, ip, port))

    def build_request(self, command: int, hash: bytes, ip: Union[bytes, str], port: int) -> bytes:
        """Serialize a tracker request.

        Layout: ``CASC``, the constant 1 and ``command`` and the body length as
        big-endian 32-bit integers, then the hash, four address bytes and the
        port as a big-endian 16-bit integer.

        Args:
            command: Command number placed in the request.
            hash: Cascade hash the request is about.
            ip: Own address, either four raw bytes or a dotted string.
            port: Own port.

        Raises:
            Exception: If the address does not consist of exactly four bytes.
        """
        if isinstance(ip, (bytes, bytearray)):
            ipbytes = ip
        else:
            ipbytes = bytes(map(int, ip.split('.')))
        if len(ipbytes) != 4:
            raise Exception('Incorrect IP address?')

        req = bytearray(bytes('CASC', 'ascii'))
        req.extend(struct.pack('!III', 1, command, len(hash) + 4 + 2))
        req.extend(hash)
        req.extend(ipbytes)
        req.extend(struct.pack('!H', port))
        return req

    def unparse_peer(self, data: bytes) -> "Peer":
        """Decode one peer record.

        Bytes 0-4 are the address, 4-6 the port, 6-10 the timestamp and byte
        10 is 1 for a peer flagged as good.
        """
        return Peer(
            '.'.join(str(octet) for octet in data[0:4]),
            struct.unpack('!H', data[4:6])[0],
            struct.unpack('!I', data[6:10])[0],
            data[10] == 1
        )

    def send_to_server(self, request: bytes) -> List["Peer"]:
        """Send ``request`` to the tracker and parse the reply into peers.

        The reply starts with a one-byte status and a big-endian 32-bit body
        length. Status 0 is followed by peer records; any other status is
        followed by a UTF-8 error message.

        Raises:
            Exception: If the tracker closes the connection early, announces
                more than 1 MiB, reports an error, or sends a body whose
                length is not a whole number of peer records.
        """
        record = self.PEER_RECORD_SIZE
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.connect((self.ip, self.port))
            s.sendall(request)

            # A lone recv(5) can legitimately return fewer bytes; readbytes
            # loops until the header is complete or the tracker hangs up.
            head = readbytes(s, 5)
            size = struct.unpack('!I', head[1:5])[0]
            if size > 1024 * 1024:
                raise Exception(f'Too big server response: {size}')

            if head[0] != 0:
                msg = readbytes(s, size).decode('utf-8')
                raise Exception(f'Tracker gave error: {msg}')

            if size % record != 0:
                raise Exception(f'Tracker gave a peer list with a size not a multiple of {record}')

            blob = readbytes(s, size)
            return [
                self.unparse_peer(blob[start:start + record])
                for start in range(0, size, record)
            ]
class CascadePeerServe(socketserver.StreamRequestHandler):
    """Request handler that serves completed blocks to other peers.

    Class attributes:
        activefiles: Maps a cascade hash to the file object being served.
        activefileslock: Guards every access to ``activefiles``.
        HEADER_LENGTH: Size of a block request (signature, 16 bytes that are
            not interpreted here, block number, cascade hash).
    """
    activefiles = {}
    activefileslock = threading.Lock()
    HEADER_LENGTH = 8+16+8+HASH_LENGTH

    def handle(self) -> None:
        """Answer block requests on this connection until it ends.

        A request for a block that is not yet completed is answered with
        error 2 and the connection stays open for another request. Every
        other error is reported and then ends the connection.
        """
        while True:
            try:
                # The buffered rfile keeps reading until the full header has
                # arrived or the peer closes; a bare recv() may stop early.
                header = self.rfile.read(self.HEADER_LENGTH)
            except OSError:
                return

            if not header:
                # Clean end-of-stream: the peer is gone, so there is nobody
                # left to send an error report to.
                return
            if len(header) != self.HEADER_LENGTH:
                self.reporterror(4, f'Invalid header length, got {len(header)} but expected {self.HEADER_LENGTH}')
                return
            if header[:8] != bytes('CASCADE1', 'ascii'):
                self.reporterror(4, f'Invalid header, should start with "CASCADE1", got {header[:8].hex()}')
                return

            blockno = struct.unpack('!Q', header[24:32])[0]
            filehash = header[32:64]

            with self.activefileslock:
                file = self.activefiles.get(filehash)

            if file is None:
                self.reporterror(1, f'Not serving file with hash: {filehash.hex()}')
                return

            if blockno >= len(file.blocks):
                self.reporterror(3, f'File has {len(file.blocks)} blocks, but {blockno} was requested')
                return

            block = file.blocks[blockno]
            if not block.completed:
                self.reporterror(2, f'{blockno} is not currently held by this peer')
                continue

            with open(file.destination, 'rb') as f:
                if f.seek(block.offset) != block.offset:
                    self.reporterror(2, f'{blockno} is not currently held by this peer (corrupt state)')
                    return
                data = f.read(block.length)

            if len(data) != block.length:
                self.reporterror(2, f'{blockno} is not currently held by this peer (corrupt state)')
                return

            # Status byte 0, payload length, then the block itself.
            resp = bytearray(struct.pack('!BQ', 0, block.length))
            resp.extend(data)
            self.request.sendall(resp)

    def reporterror(self, code: int, msg: str) -> None:
        """Send an error reply: status byte, message length, UTF-8 message.

        Args:
            code: Status byte placed at the start of the reply.
            msg: Human-readable explanation.
        """
        msgdata = bytes(msg, 'utf-8')
        data = bytearray(struct.pack('!BQ', code, len(msgdata)))
        data.extend(msgdata)
        self.request.sendall(data)
class P2PServer(object):
    """Runs the block-serving TCP server and the tracker subscription loop.

    Both run in daemon threads that are started by the constructor.

    Attributes:
        tracker: Object whose ``subscribe`` method is used to register.
        ip, port: Address the server binds to and reports to the tracker.
        ipbytes: ``ip`` converted to four raw bytes.
        stopped: Set by ``stop`` to end the subscription loop.
        peers: Maps a cascade hash to the peers returned by the last
            successful subscription for it.
    """

    def __init__(self, tracker, ip, port) -> None:
        """Start the server and subscription threads.

        Args:
            tracker: Tracker client used for subscriptions.
            ip: Own address as a dotted string.
            port: Own TCP port.
        """
        self.tracker = tracker
        self.ip = ip
        # Derived from the ``ip`` argument rather than from a script-level
        # global, so the class also works when it is imported and used
        # outside the command-line entry point.
        self.ipbytes = bytes(map(int, ip.split('.')))
        self.port = port
        self.stopped = False
        self.refreshsemaphore = threading.Semaphore(0)
        self.peers = {}

        self.serverthread = threading.Thread(target=self.run_peer_server, daemon=True)
        self.refreshthread = threading.Thread(target=self.run_peer_subscribe, daemon=True)
        self.serverthread.start()
        self.refreshthread.start()

    def resubscribe(self) -> None:
        """Wake the subscription loop so it registers again right away."""
        self.refreshsemaphore.release()

    def join(self) -> None:
        """Block until the server thread ends."""
        self.serverthread.join()

    def stop(self) -> None:
        """Shut the server down and wait for both threads to finish."""
        self.stopped = True
        self.server.shutdown()
        # Wake the subscription loop so it notices ``stopped``.
        self.refreshsemaphore.release()
        self.refreshthread.join()
        self.serverthread.join()

    def run_peer_server(self) -> None:
        """Thread body: serve block requests until ``shutdown`` is called."""
        print(f"Running peer server on {self.ip}:{self.port}")
        with socketserver.ThreadingTCPServer((self.ip, self.port), CascadePeerServe) as server:
            self.server = server
            server.serve_forever(poll_interval=10)

    def run_peer_subscribe(self) -> None:
        """Thread body: subscribe every active file with the tracker.

        After each pass the loop waits up to ten minutes, or until
        ``resubscribe``/``stop`` releases the semaphore.
        """
        while not self.stopped:
            # Copy the keys so the lock is not held during network traffic.
            with CascadePeerServe.activefileslock:
                hashes = list(CascadePeerServe.activefiles)

            for h in hashes:
                try:
                    self.peers[h] = self.tracker.subscribe(h, self.ipbytes, self.port)
                except Exception as e:
                    print(f"Tracker register failed: {e}")

            self.refreshsemaphore.acquire(timeout=60*10)
def run_peer_download(tracker: "Tracker", source: str, localip: str, localport: int, server: "P2PServer" = None, output: str = None, randomseq: bool = True) -> None:
    """Download every missing block of one cascade file.

    The file is registered for serving, then blocks are fetched one at a time
    from peers obtained from the tracker until nothing is left in the queue.
    A peer is dropped from the current list after any failed attempt.

    Args:
        tracker: Tracker client used to list peers.
        source: Path of the cascade description file.
        localip: Own address, passed to the tracker.
        localport: Own port, passed to the tracker.
        server: Optional local server; asked to resubscribe once the file is
            registered.
        output: Target path; defaults to ``source`` without its extension.
        randomseq: Pick block and peer at random when True, otherwise always
            take the first of each.

    Raises:
        SystemExit: With code 1 when ``source`` does not exist.
    """
    if not os.path.isfile(source):
        print(f"File not found: {source}")
        raise SystemExit(1)

    if output is None:
        output = os.path.splitext(source)[0]

    print(f"Preparing download of {source} to {output}")
    file = CascadeFile(source, output)
    if len(file.queue) != 0:
        print(f"Download will require {len(file.queue)} blocks of size {file.blocksize}")

    with CascadePeerServe.activefileslock:
        CascadePeerServe.activefiles[file.cascadehash] = file
    if server is not None:
        server.resubscribe()  # Trigger server subscription with new file

    peers = []
    last_peer_update = 0
    while len(file.queue) > 0:
        # Ensure we have peers; the list is also refreshed every five minutes
        while len(peers) == 0 or time.time() - last_peer_update > 60*5:
            if time.time() - last_peer_update < 30:
                print('Throttling peer update; wait for 30s')
                time.sleep(30)

            try:
                peers = tracker.list(file.cascadehash, localip, localport)
            except Exception as e:
                print(f"Tracker error: {e}")
            finally:
                # Failed attempts count too, so the throttle above applies.
                last_peer_update = time.time()

            if len(peers) == 0:
                print("We have no peers, sleeping 10s before trying again")
                time.sleep(10)
            else:
                print(f"We got {len(peers)} peer{'s' if len(peers) != 1 else ''}")

        # Pick a block
        blockid = random.randrange(0, len(file.queue)) if randomseq else 0
        peerid = random.randrange(0, len(peers)) if randomseq else 0
        block = file.queue[blockid]
        peer = peers[peerid]

        # Grab it
        print(f"Attempting to fetch block {block.index} from {peer.ip}:{peer.port}")
        print(f"Attempting to fetch block with hash {file.cascadehash}")
        try:
            data = peer.download_block(file.cascadehash, block.index, block.length)
            if data is None:
                raise Exception('Peer did not have the requested block')

            datahash = hashlib.sha256(data).digest()
            if datahash != block.hash:
                print(f"Invalid hash for block {block.index} from {peer.ip}:{peer.port}. Got {datahash.hex()} but expected {block.hash.hex()}")
                raise Exception('Downloaded block was incorrect')

            with open(output, 'r+b') as f:
                f.seek(block.offset)
                f.write(data)

            block.completed = True
            file.queue.remove(block)
            print(f"Retrieved block {block.index}")
        except Exception as e:
            peers.remove(peer)
            print(f"Download failure ignoring peer. {e}")
if __name__ == "__main__":
    def _parse_bool(text: str) -> bool:
        """Convert a command-line word to a bool.

        ``type=bool`` would treat every non-empty string, including
        "False", as True.
        """
        word = text.strip().lower()
        if word in ('1', 'true', 't', 'yes', 'y', 'on'):
            return True
        if word in ('', '0', 'false', 'f', 'no', 'n', 'off'):
            return False
        raise argparse.ArgumentTypeError(f'Expected a boolean value, got {text!r}')

    parser = argparse.ArgumentParser()
    parser.add_argument("source", help="The cascade source file(s) with information about the download")
    # nargs="?" is what makes the defaults of these positionals take effect.
    parser.add_argument("tracker", nargs="?", help="The tracker ip and port, eg: localhost:8888", default="127.0.0.1:8888")
    parser.add_argument("self", nargs="?", help="The address to report serving files from, e.g.: 1.2.3.4:5555", default="127.0.0.1:7777")
    parser.add_argument("-o", "--output", help="The target output file(s)", required=False)
    parser.add_argument("-r", "--random", help="Download blocks in random order from random peers", type=_parse_bool, required=False, default=True)
    parser.add_argument("-c", "--clientonly", help="Flag to set client-only mode (i.e. no serving)", type=_parse_bool, required=False, default=False)

    args = parser.parse_args()

    tracker_parts = args.tracker.split(':')
    tracker = Tracker(tracker_parts[0], int(tracker_parts[1]))

    selfaddr, selfport = args.self.split(':')
    selfport = int(selfport)

    server = None
    if not args.clientonly:
        server = P2PServer(tracker, selfaddr, selfport)

    # Several files may be given, separated by the platform's path separator.
    source_files = args.source.split(os.pathsep)
    if args.output is None:
        target_files = []
    else:
        target_files = args.output.split(os.pathsep)

    for ix, source_file in enumerate(source_files):
        # Sources without a matching output fall back to the default name.
        target = target_files[ix] if ix < len(target_files) else None
        run_peer_download(tracker, source_file, selfaddr, selfport, server, target, args.random)

    if args.clientonly:
        print("Download complete in client mode, stopping")
    else:
        print("Download complete serving forever ...")
        server.join()