436 lines
16 KiB
Python
436 lines
16 KiB
Python
#!/bin/python3
|
|
|
|
import argparse
|
|
import os
|
|
import hashlib
|
|
import struct
|
|
import math
|
|
import random
|
|
import socket
|
|
import time
|
|
import socketserver
|
|
import threading
|
|
from typing import Any, List, Union
|
|
|
|
# Length in bytes of every hash handled by this file (block hashes and the
# hash identifying a cascade file). 32 bytes is the size of the SHA-256
# digests that the code below computes with hashlib.sha256(...).digest().
HASH_LENGTH=32
|
|
|
|
# Helper function to mask network packet fragmentation
|
|
def readbytes(s, count) -> bytes:
    """Read exactly ``count`` bytes from the socket-like object ``s``.

    A single ``recv`` call may hand back fewer bytes than requested, so the
    call is repeated until the requested amount has been collected.

    Args:
        s: Any object with a ``recv(n)`` method returning a bytes-like chunk.
        count: Number of bytes to read; zero or less yields ``b''``.

    Returns:
        The received data as an immutable ``bytes`` object.

    Raises:
        Exception: If ``recv`` returns an empty chunk before ``count`` bytes
            have arrived.
    """
    chunks = []
    remaining = count
    while remaining > 0:
        chunk = s.recv(remaining)
        if not chunk:
            raise Exception('Server did not respond')
        chunks.append(chunk)
        remaining -= len(chunk)

    return b''.join(chunks)
|
|
|
|
|
|
class CascadeBlock(object):
    """Bookkeeping record for a single block of the target file.

    Attributes:
        index: Position of the block in the cascade file's block list.
        offset: Byte offset of the block inside the target file.
        length: Number of bytes in the block.
        hash: Expected hash of the block's data (HASH_LENGTH bytes).
        completed: Starts out False; flipped to True by code that has
            verified the block's data.
    """

    def __init__(self, index: int, offset: int, length: int, hash: bytes) -> None:
        self.index, self.offset, self.length = index, offset, length
        self.hash = hash
        self.completed = False

        # Reject hashes of the wrong size so a malformed cascade file is caught early.
        if len(self.hash) != HASH_LENGTH:
            raise Exception(f'Hash length should be {HASH_LENGTH} but was {len(self.hash)} (offset: {self.offset})')
|
|
|
|
class CascadeFile(object):
    """Parsed cascade description file plus download state for its target.

    File layout as read by this class (integers are unsigned 64-bit in
    network byte order):

        bytes  0-7    signature, must equal b'CASCADE1'
        bytes  8-15   not interpreted here
        bytes 16-23   size of the target file
        bytes 24-31   block size
        bytes 32-63   hash of the complete target file
        bytes 64-     one HASH_LENGTH-byte hash per block

    Attributes:
        destination: Path of the (partially) downloaded target file, or None.
        cascadehash: SHA-256 digest of the whole cascade file.
        header: The raw 64-byte header.
        targetsize, blocksize, targethash: Values decoded from the header.
        trailblocksize: Size of the final block.
        blocks: List of CascadeBlock, one per block of the target file.
        queue: Blocks still missing; only set once prepare_download() has run.
    """

    def __init__(self, path: str, destination: str = None) -> None:
        """Parse the cascade file at ``path``.

        Args:
            path: Location of the cascade file.
            destination: Optional target file; when given, existing data in
                it is verified via prepare_download().

        Raises:
            Exception: If the signature is wrong, the header is truncated,
                the block size is zero, or block hashes are missing/short.
        """
        self.destination = destination

        # Read once: the same bytes feed both the identifying hash and the parser.
        with open(path, 'rb') as f:
            raw = f.read()
        self.cascadehash = hashlib.sha256(raw).digest()

        self.header = raw[:64]
        if self.header[:8] != b'CASCADE1':
            raise Exception(f'Signature in file header is invalid: {path}')
        if len(self.header) != 64:
            raise Exception(f'File header is truncated: {path}')

        self.targetsize, self.blocksize = struct.unpack('!QQ', self.header[16:32])
        self.targethash = self.header[32:64]
        if self.blocksize == 0:
            raise Exception(f'Block size in file header is zero: {path}')

        # When the size divides evenly, the last block is a full block.
        self.trailblocksize = self.targetsize % self.blocksize or self.blocksize

        # Integer ceiling division; float division loses precision for huge sizes.
        blockcount = (self.targetsize + self.blocksize - 1) // self.blocksize

        self.blocks = []
        for n in range(blockcount):
            start = 64 + n * HASH_LENGTH
            digest = raw[start:start + HASH_LENGTH]
            if not digest:
                raise Exception(f'Incorrect number of hashes in file: {path}, got {n}, but expected {blockcount}')

            is_last = n == blockcount - 1
            self.blocks.append(
                CascadeBlock(
                    n,
                    n * self.blocksize,
                    self.trailblocksize if is_last else self.blocksize,
                    digest,  # a short (partial) digest is rejected by CascadeBlock
                )
            )

        if destination is not None:
            self.prepare_download(destination)

    def prepare_download(self, destination: str) -> None:
        """Verify existing data in ``destination`` and build the missing-block queue.

        Creates ``destination`` as an empty file when it does not exist,
        marks every block whose on-disk data matches its hash as completed,
        and stores the remaining blocks in ``self.queue``.
        """
        # Remember which file the completion flags describe.
        self.destination = destination

        # Make sure the target exists
        if not os.path.isfile(destination):
            with open(destination, 'wb'):
                pass

        # Check existing blocks for completion
        with open(destination, 'rb') as f:
            for block in self.blocks:
                data = f.read(block.length)
                if len(data) != block.length:
                    # The file ends here; nothing further can be complete.
                    break
                if hashlib.sha256(data).digest() == block.hash:
                    block.completed = True

        # Create a queue of missing blocks
        self.queue = [block for block in self.blocks if not block.completed]
|
|
|
|
|
|
class Peer(object):
    """A remote peer that can be asked for blocks of a cascade file.

    Attributes:
        ip: Dotted IPv4 address of the peer.
        port: TCP port of the peer.
        timestamp: Timestamp value supplied with the peer entry.
        good: Flag supplied with the peer entry.
    """

    # Largest error message accepted from a peer; the length field comes from
    # the remote side and must not be trusted blindly.
    MAX_MESSAGE_LENGTH = 1024 * 1024

    def __init__(self, ip: str, port: int, timestamp: int, good: bool):
        self.ip = ip
        self.port = port
        self.timestamp = timestamp
        self.good = good

    def download_block(self, cascadehash, blockno, blocksize) -> Union[bytes, None]:
        """Request one block from this peer.

        The request is b'CASCADE1', two zero 64-bit fields, the block number
        (64-bit, network byte order) and the cascade file hash. The reply
        starts with a 1-byte status and an 8-byte payload length.

        Args:
            cascadehash: Hash identifying the cascade file.
            blockno: Index of the requested block.
            blocksize: Expected size of the block in bytes.

        Returns:
            The block data, or None when the peer reports (status 2) that it
            does not hold the block.

        Raises:
            Exception: On any other error status, a size mismatch, an
                oversized error message, or a connection that closes early.
        """
        req = b'CASCADE1' + struct.pack('!QQQ', 0, 0, blockno) + bytes(cascadehash)

        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.connect((self.ip, self.port))
            s.sendall(req)

            # recv() may deliver the 9-byte reply header in pieces.
            status, size = struct.unpack('!BQ', readbytes(s, 9))

            if status == 0:
                if size != blocksize:
                    raise Exception(f'The reported block size was {size} but {blocksize} was expected')
                return readbytes(s, size)

            if status not in (1, 2, 3, 4):
                raise Exception(f'Client gave unsupported error code: {status}')

            if size > self.MAX_MESSAGE_LENGTH:
                raise Exception(f'Too big error message from client: {size}')

            # Undecodable bytes must not hide the peer's status code.
            msg = readbytes(s, size).decode('utf-8', errors='replace')

            if status == 1:
                raise Exception(f'Client does not have file: {msg}')
            if status == 2:
                print(f'Client does not have block {blockno}: {msg}')
                return None
            if status == 3:
                raise Exception(f'Client reported blockno was out-of-bounds: {msg}')
            raise Exception(f'Client reported invalid request: {msg}')
|
|
|
|
|
|
class Tracker(object):
    """Client for the tracker service that hands out peer lists.

    Each call opens a new TCP connection to ``ip:port``, sends one request
    and reads one reply.
    """

    # Size of one peer entry in a tracker reply.
    PEER_RECORD_SIZE = 12
    # Upper bound on the reply payload accepted from the tracker.
    MAX_RESPONSE_SIZE = 1024 * 1024

    def __init__(self, ip: str, port: int):
        self.ip = ip
        self.port = port

    def list(self, hash: bytes, ip: Union[bytes, str], port: int) -> List["Peer"]:
        """Send command 1 for ``hash`` and return the peers in the reply."""
        return self.send_to_server(self.build_request(1, hash, ip, port))

    def subscribe(self, hash: bytes, ip: Union[bytes, str], port: int) -> List["Peer"]:
        """Send command 2 for ``hash`` and return the peers in the reply."""
        return self.send_to_server(self.build_request(2, hash, ip, port))

    def build_request(self, command: int, hash: bytes, ip: Union[bytes, str], port: int) -> bytes:
        """Serialize a tracker request.

        Layout (network byte order): b'CASC', uint32 1, uint32 command,
        uint32 body length, then the body: hash, 4 address bytes, uint16 port.

        Args:
            command: Command number placed in the request.
            hash: Hash of the cascade file the request is about.
            ip: Own IPv4 address, either 4 raw bytes or a dotted string.
            port: Own port.

        Raises:
            Exception: If the address does not consist of exactly 4 bytes.
        """
        if isinstance(ip, (bytes, bytearray)):
            ipbytes = ip
        else:
            ipbytes = bytes(map(int, ip.split('.')))

        if len(ipbytes) != 4:
            raise Exception('Incorrect IP address?')

        req = bytearray(b'CASC')
        req.extend(struct.pack('!III', 1, command, len(hash) + 4 + 2))
        req.extend(hash)
        req.extend(ipbytes)
        req.extend(struct.pack('!H', port))

        return bytes(req)

    def unparse_peer(self, data: bytes) -> "Peer":
        """Decode one 12-byte peer entry.

        Bytes 0-3 are the address, 4-5 the port, 6-9 the timestamp and byte
        10 the "good" flag (1 means True); byte 11 is not interpreted.
        """
        port, timestamp = struct.unpack('!HI', data[4:10])
        return Peer(
            '.'.join(str(octet) for octet in data[0:4]),
            port,
            timestamp,
            data[10] == 1,
        )

    def send_to_server(self, request: bytes) -> List["Peer"]:
        """Send ``request`` to the tracker and decode the reply.

        The reply starts with a 1-byte status and a uint32 payload length.
        A non-zero status carries a UTF-8 error message; status 0 carries a
        sequence of 12-byte peer entries.

        Raises:
            Exception: If the connection closes early, the payload is too
                large, the tracker reports an error, or the payload length
                is not a whole number of peer entries.
        """
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.connect((self.ip, self.port))
            s.sendall(request)

            # recv() may deliver the 5-byte reply header in pieces.
            status, size = struct.unpack('!BI', readbytes(s, 5))

            if size > self.MAX_RESPONSE_SIZE:
                raise Exception(f'Too big server response: {size}')

            if status != 0:
                msg = readbytes(s, size).decode('utf-8')
                raise Exception(f'Tracker gave error: {msg}')

            record = self.PEER_RECORD_SIZE
            if size % record != 0:
                raise Exception(f'Tracker gave a peer list with a size that is not a multiple of {record}')

            blob = readbytes(s, size)
            return [self.unparse_peer(blob[pos:pos + record]) for pos in range(0, size, record)]
|
|
|
|
|
|
class CascadePeerServe(socketserver.StreamRequestHandler):
    """Request handler that serves completed blocks to other peers.

    A request is a fixed-size header: b'CASCADE1' (8 bytes), 16 bytes that
    are not interpreted here, the block number (uint64, network byte order)
    and the hash identifying the cascade file. A reply is a 1-byte status, a
    uint64 payload length and the payload: block data for status 0, a UTF-8
    message otherwise.
    """

    # Files currently offered, keyed by hash; shared by all handler instances
    # and therefore only touched while holding activefileslock.
    activefiles = {}
    activefileslock = threading.Lock()
    HEADER_LENGTH = 8+16+8+HASH_LENGTH

    def _read_header(self) -> bytes:
        """Read one request header, tolerating fragmented delivery.

        Returns fewer than HEADER_LENGTH bytes (possibly none) when the
        remote side closes the connection first.
        """
        buf = bytearray()
        while len(buf) < self.HEADER_LENGTH:
            chunk = self.request.recv(self.HEADER_LENGTH - len(buf))
            if not chunk:
                break
            buf.extend(chunk)
        return bytes(buf)

    def handle(self) -> None:
        """Serve requests on this connection until it closes or a request fails.

        Only the "block not held" condition (status 2 for an incomplete
        block) keeps the connection open; every other error ends it.
        """
        while True:
            try:
                header = self._read_header()
            except OSError:
                # Connection reset or similar: nothing left to answer.
                return

            if not header:
                # recv() signals a closed connection with b'', so this is the
                # normal end of a session rather than a malformed request.
                return

            if len(header) != self.HEADER_LENGTH:
                self.reporterror(4, f'Invalid header length, got {len(header)} but expected {self.HEADER_LENGTH}')
                return

            if header[:8] != b'CASCADE1':
                self.reporterror(4, f'Invalid header, should start with "CASCADE1", got {header[:8].hex()}')
                return

            blockno = struct.unpack('!Q', header[24:32])[0]
            filehash = header[32:64]

            with self.activefileslock:
                file = self.activefiles.get(filehash)

            if file is None:
                self.reporterror(1, f'Not serving file with hash: {filehash.hex()}')
                return

            if blockno >= len(file.blocks):
                self.reporterror(3, f'File has {len(file.blocks)} blocks, but {blockno} was requested')
                return

            block = file.blocks[blockno]
            if not block.completed:
                self.reporterror(2, f'{blockno} is not currently held by this peer')
                continue

            with open(file.destination, 'rb') as f:
                if f.seek(block.offset) != block.offset:
                    self.reporterror(2, f'{blockno} is not currently held by this peer (corrupt state)')
                    return

                data = f.read(block.length)
                if len(data) != block.length:
                    self.reporterror(2, f'{blockno} is not currently held by this peer (corrupt state)')
                    return

            resp = bytearray(struct.pack('!BQ', 0, block.length))
            resp.extend(data)

            self.request.sendall(resp)

    def reporterror(self, code: int, msg: str) -> None:
        """Send an error reply: status ``code`` followed by the UTF-8 ``msg``."""
        msgdata = bytes(msg, 'utf-8')
        data = bytearray(struct.pack('!B', code))
        data.extend(struct.pack('!Q', len(msgdata)))
        data.extend(msgdata)

        self.request.sendall(data)
|
|
|
|
class P2PServer(object):
    """Runs the block-serving TCP server and keeps the tracker subscription fresh.

    Two daemon threads are started by the constructor: one runs a
    ThreadingTCPServer using CascadePeerServe as handler, the other
    subscribes every active file hash at the tracker, repeating every ten
    minutes or whenever resubscribe() is called.
    """

    def __init__(self, tracker, ip, port) -> None:
        """Start serving on ``ip:port``.

        Args:
            tracker: Object whose ``subscribe(hash, ipbytes, port)`` is used
                to announce this peer.
            ip: Dotted IPv4 address to bind to and to report to the tracker.
            port: TCP port to bind to and to report to the tracker.
        """
        self.tracker = tracker
        self.ip = ip
        # Derived from the constructor argument, not from any module global.
        self.ipbytes = bytes(map(int, ip.split('.')))
        self.port = port
        self.stopped = False
        self.refreshsemaphore = threading.Semaphore(0)
        self.peers = {}

        # The server object is created inside the server thread; the event
        # tells stop() when self.server has its final value (which stays None
        # if the server could not be created).
        self.server = None
        self._serverready = threading.Event()

        self.serverthread = threading.Thread(target=self.run_peer_server, daemon=True)
        self.refreshthread = threading.Thread(target=self.run_peer_subscribe, daemon=True)
        self.serverthread.start()
        self.refreshthread.start()

    def resubscribe(self) -> None:
        """Wake the subscription thread so it announces the active files now."""
        self.refreshsemaphore.release()

    def join(self) -> None:
        """Block until the server thread ends."""
        self.serverthread.join()

    def stop(self) -> None:
        """Shut down the server and the subscription thread and wait for both."""
        self.stopped = True
        self._serverready.wait()
        if self.server is not None:
            self.server.shutdown()
        self.refreshsemaphore.release()
        self.refreshthread.join()
        self.serverthread.join()

    def run_peer_server(self) -> None:
        """Thread body: serve block requests until shutdown() is requested."""
        print(f"Running peer server on {self.ip}:{self.port}")
        try:
            with socketserver.ThreadingTCPServer((self.ip, self.port), CascadePeerServe) as server:
                self.server = server
                self._serverready.set()
                server.serve_forever(poll_interval=10)
        finally:
            # Also release stop() when the server could not be created.
            self._serverready.set()

    def run_peer_subscribe(self) -> None:
        """Thread body: periodically subscribe all active file hashes at the tracker."""
        while not self.stopped:
            # Copy the keys so the lock is not held during network I/O.
            with CascadePeerServe.activefileslock:
                hashes = list(CascadePeerServe.activefiles)

            for h in hashes:
                try:
                    self.peers[h] = self.tracker.subscribe(h, self.ipbytes, self.port)
                except Exception as e:
                    print(f"Tracker register failed: {e}")

            # Sleep up to ten minutes; resubscribe() and stop() cut this short.
            self.refreshsemaphore.acquire(timeout=60*10)
|
|
|
|
def run_peer_download(tracker: "Tracker", source: str, localip: str, localport: int, server: "P2PServer" = None, output: str = None, randomseq: bool = True) -> None:
    """Download every missing block of the file described by ``source``.

    The cascade file is registered in ``CascadePeerServe.activefiles`` so
    completed blocks can be served, then blocks are fetched one at a time
    from peers obtained via ``tracker.list`` until the queue is empty. A
    peer is dropped from the current list after any failed attempt.

    Args:
        tracker: Tracker used to obtain the peer list.
        source: Path of the cascade file.
        localip: Own address, reported to the tracker.
        localport: Own port, reported to the tracker.
        server: Optional server whose ``resubscribe()`` is triggered once the
            file has been registered.
        output: Target path; defaults to ``source`` without its extension.
        randomseq: Pick random blocks and peers when True, otherwise always
            the first queued block and the first peer.

    Raises:
        SystemExit: With code 1 when ``source`` is not an existing file.
    """
    if not os.path.isfile(source):
        print(f"File not found: {source}")
        raise SystemExit(1)

    if output is None:
        output = os.path.splitext(source)[0]

    print(f"Preparing download of {source} to {output}")
    file = CascadeFile(source, output)

    if len(file.queue) != 0:
        print(f"Download will require {len(file.queue)} blocks of size {file.blocksize}")

    with CascadePeerServe.activefileslock:
        CascadePeerServe.activefiles[file.cascadehash] = file

    if server is not None:
        server.resubscribe()  # Trigger server subscription with new file

    peers = []
    last_peer_update = 0
    while len(file.queue) > 0:

        # Ensure we have peers; refresh when the list is empty or older than 5 minutes
        while len(peers) == 0 or time.time() - last_peer_update > 60*5:
            if time.time() - last_peer_update < 30:
                print('Throttling peer update; wait for 30s')
                time.sleep(30)

            try:
                peers = tracker.list(file.cascadehash, localip, localport)
            except Exception as e:
                print(f"Tracker error: {e}")
            finally:
                last_peer_update = time.time()

            if len(peers) == 0:
                print("We have no peers, sleeping 10s before trying again")
                time.sleep(10)
            else:
                print(f"We got {len(peers)} peer{'' if len(peers) == 1 else 's'}")

        # Pick a block
        blockid = 0 if not randomseq else random.randrange(0, len(file.queue))
        peerid = 0 if not randomseq else random.randrange(0, len(peers))
        block = file.queue[blockid]
        peer = peers[peerid]

        # Grab it
        print(f"Attempting to fetch block {block.index} from {peer.ip}:{peer.port}")
        print(f"Attempting to fetch block with hash {file.cascadehash}")
        try:
            data = peer.download_block(file.cascadehash, block.index, block.length)
            if data is None:
                raise Exception('Peer did not have the requested block')

            datahash = hashlib.sha256(data).digest()
            if datahash != block.hash:
                print(f"Invalid hash for block {block.index} from {peer.ip}:{peer.port}. Got {datahash.hex()} but expected {block.hash.hex()}")
                raise Exception('Downloaded block was incorrect')

            with open(output, 'r+b') as f:
                f.seek(block.offset)
                f.write(data)

            # Flag the block only after the data is on disk, so it is never
            # offered to other peers before it can be read back.
            block.completed = True
            file.queue.remove(block)
            print(f"Retrieved block {block.index}")

        except Exception as e:
            peers.remove(peer)
            print(f"Download failure ignoring peer. {e}")
|
|
|
|
|
|
def str2bool(value: str) -> bool:
    """Convert a command-line string into a boolean.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is True
    because every non-empty string is truthy. The empty string maps to False,
    which is the one value that already produced False under ``type=bool``.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    text = value.strip().lower()
    if text in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if text in ('', '0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(f'Expected a boolean value but got: {value!r}')


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("source", help="The cascade source file(s) with information about the download")
    # nargs='?' is required for the defaults of positional arguments to take effect.
    parser.add_argument("tracker", nargs='?', help="The tracker ip and port, eg: localhost:8888", default="127.0.0.1:8888")
    parser.add_argument("self", nargs='?', help="The address to report serving files from, e.g.: 1.2.3.4:5555", default="127.0.0.1:7777")
    parser.add_argument("-o", "--output", help="The target output file(s)", required=False)
    parser.add_argument("-r", "--random", help="Download blocks in random order from random peers", type=str2bool, required=False, default=True)
    parser.add_argument("-c", "--clientonly", help="Flag to set client-only mode (i.e. no serving)", type=str2bool, required=False, default=False)
    args = parser.parse_args()

    tracker_parts = args.tracker.split(':')
    tracker = Tracker(tracker_parts[0], int(tracker_parts[1]))
    selfaddr, selfport = args.self.split(':')
    selfport = int(selfport)

    server = None
    if not args.clientonly:
        server = P2PServer(tracker, selfaddr, selfport)

    # Several sources/outputs may be given, separated by os.pathsep.
    source_files = args.source.split(os.pathsep)
    if args.output is None:
        target_files = []
    else:
        target_files = args.output.split(os.pathsep)

    for ix, f in enumerate(source_files):
        # Sources without a matching output entry use the default target name.
        target = target_files[ix] if ix < len(target_files) else None
        run_peer_download(tracker, f, selfaddr, selfport, server, target, args.random)

    if args.clientonly:
        print("Download complete in client mode, stopping")
    else:
        print("Download complete serving forever ...")
        server.join()
|
|
|
|
|
|
|
|
|
|
|