HTTP implementation

This commit is contained in:
NikolajDanger
2023-06-18 13:17:40 +02:00
parent 7534eed13a
commit 904471004d
3 changed files with 52 additions and 6 deletions

View File

@ -4,10 +4,13 @@ import threading
import tempfile import tempfile
import hashlib import hashlib
from os import unlink from os import unlink
from http.server import HTTPServer
from time import time from time import time
from typing import Any, Dict, List from typing import Any, Dict, List
from core.base_recipe import BaseRecipe
from http_parser.parser import HttpParser
from meow_base.functionality.validation import valid_string, valid_dict from meow_base.functionality.validation import valid_string, valid_dict
from meow_base.core.vars import VALID_RECIPE_NAME_CHARS, VALID_VARIABLE_NAME_CHARS, DEBUG_INFO from meow_base.core.vars import VALID_RECIPE_NAME_CHARS, VALID_VARIABLE_NAME_CHARS, DEBUG_INFO
@ -25,6 +28,8 @@ TRIGGERING_PORT = "triggering_port"
NETWORK_EVENT_KEYS = { NETWORK_EVENT_KEYS = {
TRIGGERING_PORT: int, TRIGGERING_PORT: int,
WATCHDOG_HASH: str,
WATCHDOG_BASE: str,
**EVENT_KEYS **EVENT_KEYS
} }
@ -91,6 +96,8 @@ class NetworkMonitor(BaseMonitor):
self._print_target, self.debug_level = setup_debugging(print, logging) self._print_target, self.debug_level = setup_debugging(print, logging)
self.ports = set() self.ports = set()
self.listeners = [] self.listeners = []
if not hasattr(self, "listener_type"):
self.listener_type = Listener
if autostart: if autostart:
self.start() self.start()
@ -104,7 +111,9 @@ class NetworkMonitor(BaseMonitor):
self.ports = set( self.ports = set(
rule.pattern.triggering_port for rule in self._rules.values() rule.pattern.triggering_port for rule in self._rules.values()
) )
self.listeners = [Listener("127.0.0.1",i,2048,self) for i in self.ports] self.listeners = [
self.listener_type("127.0.0.1",i,2048,self) for i in self.ports
]
for listener in self.listeners: for listener in self.listeners:
listener.start() listener.start()
@ -200,8 +209,7 @@ class Listener():
self.socket.close() self.socket.close()
def receive_data(self,conn):
def handle_event(self, conn, time_stamp):
with conn: with conn:
with tempfile.NamedTemporaryFile("wb", delete=False) as tmp: with tempfile.NamedTemporaryFile("wb", delete=False) as tmp:
while True: while True:
@ -212,6 +220,11 @@ class Listener():
tmp_name = tmp.name tmp_name = tmp.name
return tmp_name
def handle_event(self, conn, time_stamp):
tmp_name = self.receive_data(conn)
with open(tmp_name, "rb") as file_pointer: with open(tmp_name, "rb") as file_pointer:
file_hash = hashlib.sha256(file_pointer.read()).hexdigest() file_hash = hashlib.sha256(file_pointer.read()).hexdigest()
@ -225,3 +238,35 @@ class Listener():
def stop(self): def stop(self):
self._stopped = True self._stopped = True
class HTTPMonitor(NetworkMonitor):
def __init__(self, patterns: Dict[str, NetworkEventPattern],
recipes: Dict[str, BaseRecipe], autostart=False,
name: str = "", print: Any = sys.stdout, logging: int = 0) -> None:
self.listener_type = HTTPListener()
super().__init__(patterns, recipes, autostart, name, print, logging)
class HTTPListener(Listener):
def receive_data(self,conn):
parser = HttpParser()
with conn:
with tempfile.NamedTemporaryFile("wb", delete=False) as tmp:
while True:
data = conn.recv(self.buff_size)
if not data:
break
received = len(data)
parsed = parser.execute(data, received)
assert parsed == received
if parser.is_partial_body():
tmp.write(parser.recv_body())
if parser.is_message_complete():
break
tmp_name = tmp.name
return tmp_name

View File

@ -89,8 +89,8 @@ def sigfigs(num):
def main(): def main():
monitors = 1 monitors = 1
patterns = 1000 patterns = 1
events = 1 events = 10_000
n = 100 n = 100

View File

@ -1117,7 +1117,7 @@ class NetworkMonitorTests(unittest.TestCase):
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
event = create_network_event("path", rule) event = create_network_event("path", rule)
event = create_network_event("path", rule, time(), 8181) event = create_network_event("path", rule, time(), 8181, "hash")
self.assertEqual(type(event), dict) self.assertEqual(type(event), dict)
self.assertEqual(len(event.keys()), len(NETWORK_EVENT_KEYS)) self.assertEqual(len(event.keys()), len(NETWORK_EVENT_KEYS))
@ -1134,6 +1134,7 @@ class NetworkMonitorTests(unittest.TestCase):
rule, rule,
time(), time(),
8182, 8182,
"hash",
extras={"a":1} extras={"a":1}
) )