Extra touch

This commit is contained in:
NikolajDanger
2023-05-29 19:51:01 +02:00
parent afa764ad67
commit 4723482dbc
2 changed files with 19 additions and 1 deletions

View File

@ -2,6 +2,8 @@ import sys
import socket
import threading
import tempfile
from os import unlink
from time import time
from typing import Any, Dict, List
@ -92,6 +94,8 @@ class NetworkMonitor(BaseMonitor):
implemented by any child process. Depending on the nature of the
monitor, this may wish to directly call apply_retroactive_rules before
starting."""
self.temp_files = []
self.ports = set(
rule.pattern.triggering_port for rule in self._rules.values()
)
@ -105,6 +109,7 @@ class NetworkMonitor(BaseMonitor):
self._rules_lock.acquire()
try:
self.temp_files.append(event["tmp file"])
for rule in self._rules.values():
# Match event port against rule ports
hit = event["triggering port"]
@ -129,12 +134,19 @@ class NetworkMonitor(BaseMonitor):
self._rules_lock.release()
def _delete_temp_files(self):
for file in self.temp_files:
unlink(file)
self.temp_files = []
def stop(self)->None:
"""Function to stop the monitor as an ongoing process/thread. Must be
implemented by any child process"""
for listener in self.listeners:
listener.stop()
self._delete_temp_files()
def _is_valid_recipes(self, recipes:Dict[str,BaseRecipe])->None:
"""Validation check for 'recipes' variable from main constructor. Is
automatically called during initialisation."""

View File

@ -1164,6 +1164,7 @@ class NetworkMonitorTests(unittest.TestCase):
def testNetworkMonitorEventIdentification(self)->None:
localhost = "127.0.0.1"
port = 8181
test_packet = b'test'
from_monitor_reader, from_monitor_writer = Pipe()
@ -1191,7 +1192,7 @@ class NetworkMonitorTests(unittest.TestCase):
sender = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sender.connect((localhost,port))
sender.sendall(b'test')
sender.sendall(test_packet)
sender.close()
if from_monitor_reader.poll(3):
@ -1213,6 +1214,11 @@ class NetworkMonitorTests(unittest.TestCase):
self.assertEqual(event[TRIGGERING_PORT], port)
self.assertEqual(event[EVENT_RULE].name, rule.name)
with open(event[EVENT_PATH], "rb") as file_pointer:
received_packet = file_pointer.read()
self.assertEqual(received_packet, test_packet)
monitor.stop()
# Test NetworkMonitor get_patterns function