Files
Compsys-2021-Assignments/A3/python/tracker/tracker.py
Nikolaj 57c3e9f626 :yay:
2021-11-05 13:30:55 +01:00

195 lines
7.1 KiB
Python

#!/bin/python3
import os
import argparse
import struct
import socketserver
import json
import time
import random
import threading
import ipaddress
# Size in bytes of the fixed request header: the 4-byte magic 'CASC' followed
# by three 4-byte big-endian unsigned ints (protocol version, command, and
# length of the data that follows).
HEADER_LENGTH=16
# Size in bytes of a raw (non-hex) hash as it appears in a request.
HASH_LENGTH=32
# Size in bytes of the request payload: hash + 4-byte IPv4 address + 2-byte port.
CMD_LENGTH=HASH_LENGTH + 4 + 2
# Size in bytes of a packed IPv4 address.
IP_LENGTH=4
# Size in bytes of a packed port number.
PORT_LENGTH=2
# Upper bound on the number of peers included in a single response.
MAX_PEERS_REPORTED=20
# Peers whose last registration is at least this many seconds old are not
# reported; a periodic cleanup pass also uses this value as its age threshold.
MAX_PEER_AGE_SECONDS=1800
def is_good_client(ip, goods):
    """Return True if *ip* equals one of the entries in *goods*.

    Args:
        ip: The address to look up (the callers pass 4 packed IPv4 bytes).
        goods: Collection of addresses that are considered "good" peers.

    Returns:
        True when at least one entry of *goods* compares equal to *ip*,
        otherwise False (including when *goods* is empty).
    """
    # Direct membership test; avoids building a temporary list just to
    # check whether it is non-empty.
    return ip in goods
class TrackerServer(socketserver.StreamRequestHandler):
    """Request handler for the 'CASC' tracker protocol.

    One instance is created per incoming connection, so all state that must
    survive between requests lives on the class and is guarded by ``lock``.

    Request layout (all integers big-endian):
        4 bytes   magic ``CASC``
        4 bytes   protocol version (must be 1)
        4 bytes   command (1 = list peers, 2 = subscribe and list peers)
        4 bytes   payload length (must be CMD_LENGTH)
        payload   hash, then 4-byte IPv4 address, then 2-byte port

    Response layout:
        1 byte    status (0 = peer list, 1 = error)
        4 bytes   body length
        body      for status 0: one 12-byte record per peer
                  (4-byte IP, 2-byte port, 4-byte timestamp, 1-byte good
                  flag, 1 reserved zero byte);
                  for status 1: a UTF-8 encoded error message
    """

    _MAGIC = b'CASC'

    # hash hex string -> {ip+port hex string -> unix time of last subscribe}
    active_clients = {}
    # Guards active_clients and last_cleanup across handler threads.
    lock = threading.Lock()
    # Unix time of the last cleanup pass; 0 means "not initialised yet".
    last_cleanup = 0
    # Defaults so the handler is usable even if the start-up script did not
    # assign these; the script overrides them from the configuration file.
    goodips = []
    banprivateips = False

    def handle(self):
        """Read one request from the connection and answer it.

        Every validation failure is reported to the client through
        respond_error() and ends the request.
        """
        # rfile is a buffered reader: read(n) keeps reading until n bytes
        # have arrived or the peer closes the connection, whereas a bare
        # socket recv(n) may return a short read for a perfectly valid
        # request that happens to be split across TCP segments.
        header = self.rfile.read(HEADER_LENGTH)
        if len(header) != HEADER_LENGTH:
            self.respond_error(f'Header must be at least {HEADER_LENGTH} bytes, got {len(header)}')
            return

        if header[:4] != self._MAGIC:
            self.respond_error(f'Header must start with CASC ({self._MAGIC.hex()}), got: {header[:4].hex()}')
            return

        protocol_version, command, datalen = struct.unpack('!III', header[4:16])

        if protocol_version != 1:
            self.respond_error(f'Protocol version must be 1, got: {protocol_version}')
            return

        if command not in (1, 2):
            self.respond_error(f'Only commands (list=1, subscribe=2) supported, got: {command}')
            return

        if datalen != CMD_LENGTH:
            # Report the command that was actually sent (the message used to
            # say command=1 unconditionally).
            self.respond_error(f'Data for command={command} must be of length {CMD_LENGTH}, got {datalen}')
            return

        data = self.rfile.read(CMD_LENGTH)
        if len(data) != CMD_LENGTH:
            self.respond_error(f'Data from socket had length {len(data)}, expected {CMD_LENGTH}')
            return

        hashhex = data[:HASH_LENGTH].hex()
        ip_bytes = data[HASH_LENGTH:HASH_LENGTH + IP_LENGTH]
        port_bytes = data[HASH_LENGTH + IP_LENGTH:HASH_LENGTH + IP_LENGTH + PORT_LENGTH]
        # The hex of "ip + port" identifies a peer within a hash's peer table.
        ipportkey = data[HASH_LENGTH:].hex()
        ip_text = str(ipaddress.IPv4Address(ip_bytes))
        port = int.from_bytes(port_bytes, 'big')

        with self.lock:
            hashfound = hashhex in self.active_clients

        if not hashfound:
            self.respond_error(f"The supplied hash '{hashhex}' is not tracked by this tracker")
            return

        if command == 2:
            # For local testing you may need to disable the private-IP ban
            # (the 'ban-private' option read by the start-up script).
            if self.banprivateips and ipaddress.IPv4Address(ip_bytes).is_private:
                self.respond_error(f'The IP reported to the tracker ({ip_text}) appears to be a private IP')
                return

            # Avoid peers registering as good peers without approval: a
            # "good" IP may only be registered from a connection that
            # really originates from that IP.
            if is_good_client(ip_bytes, self.goodips):
                if ip_text != self.client_address[0]:
                    self.respond_error(f'The IP reported to the tracker ({ip_text}) is a reserved IP')
                    return
                print(f"Registering a good client at {ip_text}:{port}")
            else:
                print("Not a good client")

            with self.lock:
                self.active_clients[hashhex][ipportkey] = int(time.time())

        print(f"Reporting clients to {ip_text}:{port}")
        self.report_clients(hashhex, ipportkey)

    def report_clients(self, hashhex, selfipport=None):
        """Send the list of known peers for *hashhex* to the client.

        Args:
            hashhex: Hex string of the hash whose peers are requested; must
                be a key of ``active_clients``.
            selfipport: Optional peer key (hex of ip + port) to leave out of
                the answer, so a peer is not told about itself.
        """
        now = int(time.time())

        with self.lock:
            # Perform cleanup, if required. The timestamp is written to the
            # class: a handler instance only lives for one request, so
            # storing it on ``self`` would lose it and cleanup would never
            # be triggered.
            if TrackerServer.last_cleanup == 0:
                TrackerServer.last_cleanup = now
            elif now - TrackerServer.last_cleanup > (MAX_PEER_AGE_SECONDS * 2):
                TrackerServer.last_cleanup = now
                for registered in self.active_clients.values():
                    # Collect first, delete afterwards: a dict must not
                    # change size while it is being iterated.
                    stale = [key for key, seen in registered.items()
                             if now - seen > MAX_PEER_AGE_SECONDS]
                    for key in stale:
                        del registered[key]

            # Build list of peers as (ip, port, timestamp, good flag) tuples
            peers = []
            for key, seen in self.active_clients[hashhex].items():
                if key == selfipport or now - seen >= MAX_PEER_AGE_SECONDS:
                    continue
                raw = bytes.fromhex(key)
                ip = raw[:IP_LENGTH]
                port = raw[IP_LENGTH:IP_LENGTH + PORT_LENGTH]
                peers.append((ip, port, seen, is_good_client(ip, self.goodips)))

        # Random order, then a stable sort so good peers come first while
        # the order inside each group stays random.
        random.shuffle(peers)
        peers.sort(key=lambda peer: not peer[3])

        # Build response
        result = bytearray()
        for ip, port, seen, good in peers[:MAX_PEERS_REPORTED]:
            result.extend(ip)
            result.extend(port)
            # timestamp (signed 32-bit), good flag, reserved zero byte
            result.extend(struct.pack('!lBB', seen, 1 if good else 0, 0))

        # Status byte 0 followed by the body length, then the body itself
        self.request.sendall(struct.pack('!BI', 0, len(result)) + bytes(result))

    def respond_error(self, msg):
        """Send an error response carrying *msg* to the client.

        Args:
            msg: Human-readable error text; sent UTF-8 encoded after a
                status byte of 1 and a 4-byte big-endian length.
        """
        msgdata = msg.encode('utf-8')
        self.request.sendall(struct.pack('!BI', 1, len(msgdata)) + msgdata)
if __name__ == "__main__":
    # Script entry point: load the JSON configuration given on the command
    # line, validate it, publish the settings on TrackerServer and serve
    # requests until interrupted.
    parser = argparse.ArgumentParser()
    parser.add_argument("config", help="The configuration file")
    args = parser.parse_args()

    if not os.path.isfile(args.config):
        print(f"Config file not found: {args.config}")
        # Stop here; falling through would only crash in open() below.
        # SystemExit is raised directly because the exit() helper is
        # injected by the site module and is not guaranteed to exist.
        raise SystemExit(1)

    with open(args.config) as jf:
        config = json.load(jf)

    if "port" not in config:
        print("Config is missing port")
        raise SystemExit(1)
    if "allowedhashes" not in config:
        print("Config is missing allowedhashes")
        raise SystemExit(1)
    if len(config["allowedhashes"]) <= 0:
        print("Need at least one allowed hash")
        raise SystemExit(1)
    # 'good-ips' is optional and defaults to no good peers
    config.setdefault("good-ips", [])

    HOST, PORT = "0.0.0.0", int(config['port'])

    # Hashes are configured as hex strings and must decode to 32 bytes
    hashes = [bytes.fromhex(x) for x in config['allowedhashes']]
    invalids = [x for x in hashes if len(x) != 32]
    if invalids:
        print(f"Invalid hashes: {', '.join([x.hex() for x in invalids])}")
        raise SystemExit(1)

    # Good IPs are configured as dotted strings and stored as packed bytes
    goodips = [bytes(map(int, ip.split('.'))) for ip in config["good-ips"]]
    invalids = [x for x in goodips if len(x) != 4]
    if invalids:
        print(f"Invalid IPs: {', '.join(['.'.join([str(x) for x in ip]) for ip in invalids])}")
        raise SystemExit(1)

    # Every allowed hash starts out with its own empty peer table
    TrackerServer.active_clients = {x.hex(): {} for x in hashes}
    TrackerServer.goodips = goodips
    TrackerServer.banprivateips = config.get('ban-private', False)

    with socketserver.ThreadingTCPServer((HOST, PORT), TrackerServer) as server:
        server.serve_forever()