diff --git a/patterns/network_event_pattern.py b/patterns/network_event_pattern.py index 500eb38..11932a6 100644 --- a/patterns/network_event_pattern.py +++ b/patterns/network_event_pattern.py @@ -2,6 +2,8 @@ import sys import socket import threading import tempfile +from os import unlink + from time import time from typing import Any, Dict, List @@ -92,6 +94,8 @@ class NetworkMonitor(BaseMonitor): implemented by any child process. Depending on the nature of the monitor, this may wish to directly call apply_retroactive_rules before starting.""" + self.temp_files = [] + self.ports = set( rule.pattern.triggering_port for rule in self._rules.values() ) @@ -105,6 +109,7 @@ class NetworkMonitor(BaseMonitor): self._rules_lock.acquire() try: + self.temp_files.append(event["tmp file"]) for rule in self._rules.values(): # Match event port against rule ports hit = event["triggering port"] @@ -129,12 +134,19 @@ class NetworkMonitor(BaseMonitor): self._rules_lock.release() + def _delete_temp_files(self): + for file in self.temp_files: + unlink(file) + self.temp_files = [] + def stop(self)->None: """Function to stop the monitor as an ongoing process/thread. Must be implemented by any child process""" for listener in self.listeners: listener.stop() + self._delete_temp_files() + def _is_valid_recipes(self, recipes:Dict[str,BaseRecipe])->None: """Validation check for 'recipes' variable from main constructor. Is automatically called during initialisation.""" diff --git a/tests/test_patterns.py b/tests/test_patterns.py index b3fd90e..309a3c5 100644 --- a/tests/test_patterns.py +++ b/tests/test_patterns.py @@ -1164,6 +1164,7 @@ class NetworkMonitorTests(unittest.TestCase): def testNetworkMonitorEventIdentification(self)->None: localhost = "127.0.0.1" port = 8181 + test_packet = b'test' from_monitor_reader, from_monitor_writer = Pipe() @@ -1191,7 +1192,7 @@ class NetworkMonitorTests(unittest.TestCase): sender = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sender.connect((localhost,port)) - sender.sendall(b'test') + sender.sendall(test_packet) sender.close() if from_monitor_reader.poll(3): @@ -1213,6 +1214,11 @@ class NetworkMonitorTests(unittest.TestCase): self.assertEqual(event[TRIGGERING_PORT], port) self.assertEqual(event[EVENT_RULE].name, rule.name) + with open(event[EVENT_PATH], "rb") as file_pointer: + received_packet = file_pointer.read() + + self.assertEqual(received_packet, test_packet) + monitor.stop() # Test NetworkMonitor get_patterns function