From ea9a689b2661bb083a2c52a31291154a5c129e50 Mon Sep 17 00:00:00 2001 From: PatchOfScotland Date: Thu, 15 Dec 2022 11:31:51 +0100 Subject: [PATCH] added support for multi-type waiting plus some cleanup --- core/correctness/validation.py | 71 +++++--- core/correctness/vars.py | 6 + core/functionality.py | 29 +++- core/meow.py | 123 +++++++------ patterns/file_event_pattern.py | 22 +-- recipes/jupyter_notebook_recipe.py | 74 ++++++-- rules/file_event_jupyter_notebook_rule.py | 6 +- tests/test_functionality.py | 202 +++++++++++++++++++++- tests/test_meow.py | 16 +- tests/test_patterns.py | 18 +- tests/test_recipes.py | 38 +++- tests/test_validation.py | 84 +++++++-- 12 files changed, 516 insertions(+), 173 deletions(-) diff --git a/core/correctness/validation.py b/core/correctness/validation.py index 7d1ef1e..21143cc 100644 --- a/core/correctness/validation.py +++ b/core/correctness/validation.py @@ -1,10 +1,11 @@ +from inspect import signature from os.path import sep -from typing import Any, _SpecialForm +from typing import Any, _SpecialForm, Union, get_origin, get_args -from core.correctness.vars import VALID_PATH_CHARS +from core.correctness.vars import VALID_PATH_CHARS, get_not_imp_msg -def check_input(variable:Any, expected_type:type, alt_types:list[type]=[], +def check_type(variable:Any, expected_type:type, alt_types:list[type]=[], or_none:bool=False)->None: """ Checks if a given variable is of the expected type. Raises TypeError or @@ -24,23 +25,38 @@ def check_input(variable:Any, expected_type:type, alt_types:list[type]=[], """ type_list = [expected_type] + if get_origin(expected_type) is Union: + type_list = list(get_args(expected_type)) type_list = type_list + alt_types - if not or_none: - if expected_type != Any \ - and type(variable) not in type_list: + if variable is None: + if or_none == False: raise TypeError( - 'Expected type was %s, got %s' - % (expected_type, type(variable)) - ) - else: - if expected_type != Any \ - and not type(variable) not in type_list \ - and not isinstance(variable, type(None)): - raise TypeError( - 'Expected type was %s or None, got %s' - % (expected_type, type(variable)) + f'Not allowed None for variable. Expected {expected_type}.' ) + else: + return + + if expected_type == Any: + return + + if not isinstance(variable, tuple(type_list)): + raise TypeError( + 'Expected type(s) are %s, got %s' + % (get_args(expected_type), type(variable)) + ) + +def check_implementation(child_func, parent_class): + parent_func = getattr(parent_class, child_func.__name__) + if (child_func == parent_func): + msg = get_not_imp_msg(parent_class, parent_func) + raise NotImplementedError(msg) + child_sig = signature(child_func).parameters + parent_sig = signature(parent_func).parameters + + if child_sig.keys() != parent_sig.keys(): + msg = get_not_imp_msg(parent_class, parent_func) + raise NotImplementedError(msg) def valid_string(variable:str, valid_chars:str, min_length:int=1)->None: """ @@ -56,8 +72,8 @@ def valid_string(variable:str, valid_chars:str, min_length:int=1)->None: :return: No return. """ - check_input(variable, str) - check_input(valid_chars, str) + check_type(variable, str) + check_type(valid_chars, str) if len(variable) < min_length: raise ValueError ( @@ -74,12 +90,12 @@ def valid_string(variable:str, valid_chars:str, min_length:int=1)->None: def valid_dict(variable:dict[Any, Any], key_type:type, value_type:type, required_keys:list[Any]=[], optional_keys:list[Any]=[], strict:bool=True, min_length:int=1)->None: - check_input(variable, dict) - check_input(key_type, type, alt_types=[_SpecialForm]) - check_input(value_type, type, alt_types=[_SpecialForm]) - check_input(required_keys, list) - check_input(optional_keys, list) - check_input(strict, bool) + check_type(variable, dict) + check_type(key_type, type, alt_types=[_SpecialForm]) + check_type(value_type, type, alt_types=[_SpecialForm]) + check_type(required_keys, list) + check_type(optional_keys, list) + check_type(strict, bool) if len(variable) < min_length: raise ValueError(f"Dictionary '{variable}' is below minimum length of " @@ -106,18 +122,17 @@ def valid_dict(variable:dict[Any, Any], key_type:type, value_type:type, def valid_list(variable:list[Any], entry_type:type, alt_types:list[type]=[], min_length:int=1)->None: - check_input(variable, list) + check_type(variable, list) if len(variable) < min_length: raise ValueError(f"List '{variable}' is too short. Should be at least " f"of length {min_length}") for entry in variable: - check_input(entry, entry_type, alt_types=alt_types) + check_type(entry, entry_type, alt_types=alt_types) def valid_path(variable:str, allow_base=False, extension:str="", min_length=1): valid_string(variable, VALID_PATH_CHARS, min_length=min_length) if not allow_base and variable.startswith(sep): raise ValueError(f"Cannot accept path '{variable}'. Must be relative.") - if min_length > 0 and extension and not variable.endswith(extension): + if extension and not variable.endswith(extension): raise ValueError(f"Path '{variable}' does not have required " f"extension '{extension}'.") - diff --git a/core/correctness/vars.py b/core/correctness/vars.py index a532b6d..9df02d3 100644 --- a/core/correctness/vars.py +++ b/core/correctness/vars.py @@ -1,8 +1,12 @@ import os +from multiprocessing import Queue +from multiprocessing.connection import Connection from inspect import signature +from typing import Union + CHAR_LOWERCASE = 'abcdefghijklmnopqrstuvwxyz' CHAR_UPPERCASE = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ' CHAR_NUMERIC = '0123456789' @@ -20,6 +24,8 @@ VALID_JUPYTER_NOTEBOOK_EXTENSIONS = [".ipynb"] VALID_PATH_CHARS = VALID_NAME_CHARS + "." + os.path.sep VALID_TRIGGERING_PATH_CHARS = VALID_NAME_CHARS + ".*" + os.path.sep +VALID_CHANNELS = Union[Connection,Queue] + BAREBONES_NOTEBOOK = { "cells": [], "metadata": {}, diff --git a/core/functionality.py b/core/functionality.py index 124d7af..6aaacc9 100644 --- a/core/functionality.py +++ b/core/functionality.py @@ -2,15 +2,15 @@ import sys import inspect +from multiprocessing.connection import Connection, wait as multi_wait +from multiprocessing.queues import Queue from typing import Union from random import SystemRandom from core.meow import BasePattern, BaseRecipe, BaseRule -from core.correctness.validation import check_input, valid_dict, valid_list -from core.correctness.vars import CHAR_LOWERCASE, CHAR_UPPERCASE -from patterns import * -from recipes import * -from rules import * +from core.correctness.validation import check_type, valid_dict, valid_list +from core.correctness.vars import CHAR_LOWERCASE, CHAR_UPPERCASE, \ + VALID_CHANNELS def check_pattern_dict(patterns, min_length=1): valid_dict(patterns, str, BasePattern, strict=False, min_length=min_length) @@ -42,8 +42,8 @@ def generate_id(prefix:str="", length:int=16, existing_ids:list[str]=[], def create_rules(patterns:Union[dict[str,BasePattern],list[BasePattern]], recipes:Union[dict[str,BaseRecipe],list[BaseRecipe]], new_rules:list[BaseRule]=[])->dict[str,BaseRule]: - check_input(patterns, dict, alt_types=[list]) - check_input(recipes, dict, alt_types=[list]) + check_type(patterns, dict, alt_types=[list]) + check_type(recipes, dict, alt_types=[list]) valid_list(new_rules, BaseRule, min_length=0) if isinstance(patterns, list): @@ -58,8 +58,9 @@ def create_rules(patterns:Union[dict[str,BasePattern],list[BasePattern]], else: check_recipe_dict(recipes, min_length=0) + # Imported here to avoid circular imports at top of file + import rules rules = {} - all_rules ={(r.pattern_type, r.recipe_type):r for r in [r[1] \ for r in inspect.getmembers(sys.modules["rules"], inspect.isclass) \ if (issubclass(r[1], BaseRule))]} @@ -75,4 +76,14 @@ def create_rules(patterns:Union[dict[str,BasePattern],list[BasePattern]], recipes[pattern.recipe] ) rules[rule.name] = rule - return rules \ No newline at end of file + return rules + +def wait(inputs:list[VALID_CHANNELS])->list[VALID_CHANNELS]: + all_connections = [i for i in inputs if type(i) is Connection] \ + + [i._reader for i in inputs if type(i) is Queue] + + ready = multi_wait(all_connections) + ready_inputs = [i for i in inputs if \ + (type(i) is Connection and i in ready) \ + or (type(i) is Queue and i._reader in ready)] + return ready_inputs diff --git a/core/meow.py b/core/meow.py index 66465a2..a5769d2 100644 --- a/core/meow.py +++ b/core/meow.py @@ -1,11 +1,11 @@ -from multiprocessing.connection import Connection from typing import Any from core.correctness.vars import VALID_RECIPE_NAME_CHARS, \ - VALID_PATTERN_NAME_CHARS, VALID_RULE_NAME_CHARS, \ - get_not_imp_msg, get_drt_imp_msg -from core.correctness.validation import valid_string + VALID_PATTERN_NAME_CHARS, VALID_RULE_NAME_CHARS, VALID_CHANNELS, \ + get_drt_imp_msg +from core.correctness.validation import valid_string, check_type, \ + check_implementation class BaseRecipe: @@ -15,15 +15,9 @@ class BaseRecipe: requirements:dict[str, Any] def __init__(self, name:str, recipe:Any, parameters:dict[str,Any]={}, requirements:dict[str,Any]={}): - if (type(self)._is_valid_recipe == BaseRecipe._is_valid_recipe): - msg = get_not_imp_msg(BaseRecipe, BaseRecipe._is_valid_recipe) - raise NotImplementedError(msg) - if (type(self)._is_valid_parameters == BaseRecipe._is_valid_parameters): - msg = get_not_imp_msg(BaseRecipe, BaseRecipe._is_valid_parameters) - raise NotImplementedError(msg) - if (type(self)._is_valid_requirements == BaseRecipe._is_valid_requirements): - msg = get_not_imp_msg(BaseRecipe, BaseRecipe._is_valid_requirements) - raise NotImplementedError(msg) + check_implementation(type(self)._is_valid_recipe, BaseRecipe) + check_implementation(type(self)._is_valid_parameters, BaseRecipe) + check_implementation(type(self)._is_valid_requirements, BaseRecipe) self._is_valid_name(name) self.name = name self._is_valid_recipe(recipe) @@ -55,19 +49,13 @@ class BaseRecipe: class BasePattern: name:str recipe:str - parameters:dict[str, Any] - outputs:dict[str, Any] + parameters:dict[str,Any] + outputs:dict[str,Any] def __init__(self, name:str, recipe:str, parameters:dict[str,Any]={}, outputs:dict[str,Any]={}): - if (type(self)._is_valid_recipe == BasePattern._is_valid_recipe): - msg = get_not_imp_msg(BasePattern, BasePattern._is_valid_recipe) - raise NotImplementedError(msg) - if (type(self)._is_valid_parameters == BasePattern._is_valid_parameters): - msg = get_not_imp_msg(BasePattern, BasePattern._is_valid_parameters) - raise NotImplementedError(msg) - if (type(self)._is_valid_output == BasePattern._is_valid_output): - msg = get_not_imp_msg(BasePattern, BasePattern._is_valid_output) - raise NotImplementedError(msg) + check_implementation(type(self)._is_valid_recipe, BasePattern) + check_implementation(type(self)._is_valid_parameters, BasePattern) + check_implementation(type(self)._is_valid_output, BasePattern) self._is_valid_name(name) self.name = name self._is_valid_recipe(recipe) @@ -103,13 +91,8 @@ class BaseRule: pattern_type:str="" recipe_type:str="" def __init__(self, name:str, pattern:BasePattern, recipe:BaseRecipe): - if (type(self)._is_valid_pattern == BaseRule._is_valid_pattern): - msg = get_not_imp_msg(BaseRule, BaseRule._is_valid_pattern) - raise NotImplementedError(msg) - if (type(self)._is_valid_recipe == BaseRule._is_valid_recipe): - msg = get_not_imp_msg(BaseRule, BaseRule._is_valid_recipe) - raise NotImplementedError(msg) - + check_implementation(type(self)._is_valid_pattern, BaseRule) + check_implementation(type(self)._is_valid_recipe, BaseRule) self._is_valid_name(name) self.name = name self._is_valid_pattern(pattern) @@ -144,29 +127,15 @@ class BaseRule: class BaseMonitor: rules: dict[str, BaseRule] - report: Connection - listen: Connection - def __init__(self, rules:dict[str, BaseRule], report:Connection, - listen:Connection) -> None: - if (type(self).start == BaseMonitor.start): - msg = get_not_imp_msg(BaseMonitor, BaseMonitor.start) - raise NotImplementedError(msg) - if (type(self).stop == BaseMonitor.stop): - msg = get_not_imp_msg(BaseMonitor, BaseMonitor.stop) - raise NotImplementedError(msg) - if (type(self)._is_valid_report == BaseMonitor._is_valid_report): - msg = get_not_imp_msg(BaseMonitor, BaseMonitor._is_valid_report) - raise NotImplementedError(msg) + report: VALID_CHANNELS + def __init__(self, rules:dict[str,BaseRule], + report:VALID_CHANNELS)->None: + check_implementation(type(self).start, BaseMonitor) + check_implementation(type(self).stop, BaseMonitor) + check_implementation(type(self)._is_valid_report, BaseMonitor) + check_implementation(type(self)._is_valid_rules, BaseMonitor) self._is_valid_report(report) self.report = report - if (type(self)._is_valid_listen == BaseMonitor._is_valid_listen): - msg = get_not_imp_msg(BaseMonitor, BaseMonitor._is_valid_listen) - raise NotImplementedError(msg) - self._is_valid_listen(listen) - self.listen = listen - if (type(self)._is_valid_rules == BaseMonitor._is_valid_rules): - msg = get_not_imp_msg(BaseMonitor, BaseMonitor._is_valid_rules) - raise NotImplementedError(msg) self._is_valid_rules(rules) self.rules = rules @@ -176,13 +145,10 @@ class BaseMonitor: raise TypeError(msg) return object.__new__(cls) - def _is_valid_report(self, report:Connection)->None: + def _is_valid_report(self, report:VALID_CHANNELS)->None: pass - def _is_valid_listen(self, listen:Connection)->None: - pass - - def _is_valid_rules(self, rules:dict[str, BaseRule])->None: + def _is_valid_rules(self, rules:dict[str,BaseRule])->None: pass def start(self)->None: @@ -194,13 +160,11 @@ class BaseMonitor: class BaseHandler: inputs:Any - def __init__(self, inputs:Any) -> None: - if (type(self).handle == BaseHandler.handle): - msg = get_not_imp_msg(BaseHandler, BaseHandler.handle) - raise NotImplementedError(msg) - if (type(self)._is_valid_inputs == BaseHandler._is_valid_inputs): - msg = get_not_imp_msg(BaseHandler, BaseHandler._is_valid_inputs) - raise NotImplementedError(msg) + def __init__(self, inputs:list[VALID_CHANNELS]) -> None: + check_implementation(type(self).start, BaseHandler) + check_implementation(type(self).stop, BaseHandler) + check_implementation(type(self).handle, BaseHandler) + check_implementation(type(self)._is_valid_inputs, BaseHandler) self._is_valid_inputs(inputs) self.inputs = inputs @@ -213,5 +177,34 @@ class BaseHandler: def _is_valid_inputs(self, inputs:Any)->None: pass - def handle()->None: + def handle(self, event:Any, rule:BaseRule)->None: pass + + def start(self)->None: + pass + + def stop(self)->None: + pass + + +# TODO test me +class MeowRunner: + monitor:BaseMonitor + handler:BaseHandler + def __init__(self, monitor:BaseMonitor, handler:BaseHandler) -> None: + self._is_valid_monitor(monitor) + self.monitor = monitor + self._is_valid_handler(handler) + self.handler = handler + + def start(self)->None: + self.monitor.start() + + def stop(self)->None: + self.monitor.stop() + + def _is_valid_monitor(self, monitor:BaseMonitor)->None: + check_type(monitor, BaseMonitor) + + def _is_valid_handler(self, handler:BaseHandler)->None: + check_type(handler, BaseHandler) diff --git a/patterns/file_event_pattern.py b/patterns/file_event_pattern.py index 051befe..bbc4982 100644 --- a/patterns/file_event_pattern.py +++ b/patterns/file_event_pattern.py @@ -3,7 +3,6 @@ import threading import os from fnmatch import translate -from multiprocessing.connection import Connection from re import match from time import time, sleep from typing import Any @@ -12,13 +11,13 @@ from watchdog.events import PatternMatchingEventHandler, FileCreatedEvent, \ FileModifiedEvent, FileMovedEvent, FileClosedEvent, FileDeletedEvent, \ DirCreatedEvent, DirDeletedEvent, DirModifiedEvent, DirMovedEvent -from core.correctness.validation import check_input, valid_string, \ +from core.correctness.validation import check_type, valid_string, \ valid_dict, valid_list, valid_path from core.correctness.vars import VALID_RECIPE_NAME_CHARS, \ VALID_VARIABLE_NAME_CHARS, FILE_EVENTS, FILE_CREATE_EVENT, \ FILE_MODIFY_EVENT, FILE_MOVED_EVENT, FILE_CLOSED_EVENT, \ FILE_DELETED_EVENT, DIR_CREATE_EVENT, DIR_DELETED_EVENT, \ - DIR_MODIFY_EVENT, DIR_MOVED_EVENT + DIR_MODIFY_EVENT, DIR_MOVED_EVENT, VALID_CHANNELS from core.meow import BasePattern, BaseMonitor, BaseRule _EVENT_TRANSLATIONS = { @@ -88,14 +87,14 @@ class WatchdogMonitor(BaseMonitor): _rules_lock:threading.Lock def __init__(self, base_dir:str, rules:dict[str, BaseRule], - report:Connection, listen:Connection, autostart=False, + report:VALID_CHANNELS, autostart=False, settletime:int=1)->None: - super().__init__(rules, report, listen) + super().__init__(rules, report) self._is_valid_base_dir(base_dir) self.base_dir = base_dir - check_input(settletime, int) + check_type(settletime, int) self._rules_lock = threading.Lock() - self.event_handler = MEOWEventHandler(self, settletime=settletime) + self.event_handler = WatchdogEventHandler(self, settletime=settletime) self.monitor = Observer() self.monitor.schedule( self.event_handler, @@ -147,17 +146,14 @@ class WatchdogMonitor(BaseMonitor): def _is_valid_base_dir(self, base_dir:str)->None: valid_path(base_dir) - def _is_valid_report(self, report:Connection)->None: - check_input(report, Connection) - - def _is_valid_listen(self, listen:Connection)->None: - check_input(listen, Connection) + def _is_valid_report(self, report:VALID_CHANNELS)->None: + check_type(report, VALID_CHANNELS) def _is_valid_rules(self, rules:dict[str, BaseRule])->None: valid_dict(rules, str, BaseRule, min_length=0, strict=False) -class MEOWEventHandler(PatternMatchingEventHandler): +class WatchdogEventHandler(PatternMatchingEventHandler): monitor:WatchdogMonitor _settletime:int _recent_jobs:dict[str, Any] diff --git a/recipes/jupyter_notebook_recipe.py b/recipes/jupyter_notebook_recipe.py index c7b5dec..f3a22ff 100644 --- a/recipes/jupyter_notebook_recipe.py +++ b/recipes/jupyter_notebook_recipe.py @@ -1,13 +1,15 @@ import nbformat +import threading +from multiprocessing import Pipe from typing import Any -from core.correctness.validation import check_input, valid_string, \ - valid_dict, valid_path -from core.correctness.vars import VALID_JUPYTER_NOTEBOOK_FILENAME_CHARS, \ - VALID_JUPYTER_NOTEBOOK_EXTENSIONS, VALID_VARIABLE_NAME_CHARS -from core.meow import BaseRecipe +from core.correctness.validation import check_type, valid_string, \ + valid_dict, valid_path, valid_list +from core.correctness.vars import VALID_VARIABLE_NAME_CHARS, VALID_CHANNELS +from core.functionality import wait +from core.meow import BaseRecipe, BaseHandler class JupyterNotebookRecipe(BaseRecipe): source:str @@ -18,21 +20,11 @@ class JupyterNotebookRecipe(BaseRecipe): self.source = source def _is_valid_source(self, source:str)->None: - valid_path(source, extension=".ipynb", min_length=0) - - if not source: - return - - matched = False - for i in VALID_JUPYTER_NOTEBOOK_EXTENSIONS: - if source.endswith(i): - matched = True - if not matched: - raise ValueError(f"source '{source}' does not end with a valid " - "jupyter notebook extension.") + if source: + valid_path(source, extension=".ipynb", min_length=0) def _is_valid_recipe(self, recipe:dict[str,Any])->None: - check_input(recipe, dict) + check_type(recipe, dict) nbformat.validate(recipe) def _is_valid_parameters(self, parameters:dict[str,Any])->None: @@ -44,3 +36,49 @@ class JupyterNotebookRecipe(BaseRecipe): valid_dict(requirements, str, Any, strict=False, min_length=0) for k in requirements.keys(): valid_string(k, VALID_VARIABLE_NAME_CHARS) + +class PapermillHandler(BaseHandler): + _worker:threading.Thread + _stop_pipe:Pipe + def __init__(self, inputs:list[VALID_CHANNELS])->None: + super().__init__(inputs) + self._worker = None + self._stop_pipe = Pipe() + + def run(self)->None: + all_inputs = self.inputs + [self._stop_pipe[0]] + while True: + ready = wait(all_inputs) + + if self._stop_pipe[0] in ready: + return + else: + for input in self.inputs: + if input in ready: + message = input.recv() + event, rule = message + self.handle(event, rule) + + def start(self)->None: + if self._worker is None: + self._worker = threading.Thread( + target=self.run, + args=[]) + self._worker.daemon = True + self._worker.start() + else: + raise RuntimeWarning("Repeated calls to start have no effect.") + + def stop(self)->None: + if self._worker is None: + raise RuntimeWarning("Cannot stop thread that is not started.") + else: + self._stop_pipe[1].send(1) + self._worker.join() + + def handle(self, event, rule)->None: + # TODO finish implementation and test + pass + + def _is_valid_inputs(self, inputs:list[VALID_CHANNELS])->None: + valid_list(inputs, VALID_CHANNELS) diff --git a/rules/file_event_jupyter_notebook_rule.py b/rules/file_event_jupyter_notebook_rule.py index 0a6b748..fc5c405 100644 --- a/rules/file_event_jupyter_notebook_rule.py +++ b/rules/file_event_jupyter_notebook_rule.py @@ -1,5 +1,5 @@ -from core.correctness.validation import check_input +from core.correctness.validation import check_type from core.meow import BaseRule from patterns.file_event_pattern import FileEventPattern from recipes.jupyter_notebook_recipe import JupyterNotebookRecipe @@ -16,7 +16,7 @@ class FileEventJupyterNotebookRule(BaseRule): f"uses {pattern.recipe}") def _is_valid_pattern(self, pattern:FileEventPattern) -> None: - check_input(pattern, FileEventPattern) + check_type(pattern, FileEventPattern) def _is_valid_recipe(self, recipe:JupyterNotebookRecipe) -> None: - check_input(recipe, JupyterNotebookRecipe) + check_type(recipe, JupyterNotebookRecipe) diff --git a/tests/test_functionality.py b/tests/test_functionality.py index 76215bd..8f6b12c 100644 --- a/tests/test_functionality.py +++ b/tests/test_functionality.py @@ -1,9 +1,13 @@ import unittest +from multiprocessing import Pipe, Queue +from time import sleep + from core.correctness.vars import CHAR_LOWERCASE, CHAR_UPPERCASE, \ BAREBONES_NOTEBOOK -from core.functionality import create_rules, generate_id +from core.functionality import create_rules, generate_id, wait, \ + check_pattern_dict, check_recipe_dict from core.meow import BaseRule from patterns.file_event_pattern import FileEventPattern from recipes.jupyter_notebook_recipe import JupyterNotebookRecipe @@ -87,3 +91,199 @@ class CorrectnessTests(unittest.TestCase): } with self.assertRaises(KeyError): create_rules({}, recipes) + + def testCheckPatternDictValid(self)->None: + fep1 = FileEventPattern("name_one", "path", "recipe", "file") + fep2 = FileEventPattern("name_two", "path", "recipe", "file") + + patterns = { + fep1.name: fep1, + fep2.name: fep2 + } + + check_pattern_dict(patterns=patterns) + + def testCheckPatternDictNoEntries(self)->None: + with self.assertRaises(ValueError): + check_pattern_dict(patterns={}) + + check_pattern_dict(patterns={}, min_length=0) + + def testCheckPatternDictMissmatchedName(self)->None: + fep1 = FileEventPattern("name_one", "path", "recipe", "file") + fep2 = FileEventPattern("name_two", "path", "recipe", "file") + + patterns = { + fep2.name: fep1, + fep1.name: fep2 + } + + with self.assertRaises(KeyError): + check_pattern_dict(patterns=patterns) + + def testCheckRecipeDictValid(self)->None: + jnr1 = JupyterNotebookRecipe("recipe_one", BAREBONES_NOTEBOOK) + jnr2 = JupyterNotebookRecipe("recipe_two", BAREBONES_NOTEBOOK) + + recipes = { + jnr1.name: jnr1, + jnr2.name: jnr2 + } + + check_recipe_dict(recipes=recipes) + + def testCheckRecipeDictNoEntires(self)->None: + with self.assertRaises(ValueError): + check_recipe_dict(recipes={}) + + check_recipe_dict(recipes={}, min_length=0) + + def testCheckRecipeDictMismatchedName(self)->None: + jnr1 = JupyterNotebookRecipe("recipe_one", BAREBONES_NOTEBOOK) + jnr2 = JupyterNotebookRecipe("recipe_two", BAREBONES_NOTEBOOK) + + recipes = { + jnr2.name: jnr1, + jnr1.name: jnr2 + } + + with self.assertRaises(KeyError): + check_recipe_dict(recipes=recipes) + + def testWaitPipes(self)->None: + pipe_one_reader, pipe_one_writer = Pipe() + pipe_two_reader, pipe_two_writer = Pipe() + + inputs = [ + pipe_one_reader, pipe_two_reader + ] + + pipe_one_writer.send(1) + readables = wait(inputs) + + self.assertIn(pipe_one_reader, readables) + self.assertEqual(len(readables), 1) + msg = readables[0].recv() + self.assertEqual(msg, 1) + + pipe_one_writer.send(1) + pipe_two_writer.send(2) + readables = wait(inputs) + + self.assertIn(pipe_one_reader, readables) + self.assertIn(pipe_two_reader, readables) + self.assertEqual(len(readables), 2) + for readable in readables: + if readable == pipe_one_reader: + msg = readable.recv() + self.assertEqual(msg, 1) + elif readable == pipe_two_reader: + msg = readable.recv() + self.assertEqual(msg, 2) + + def testWaitQueues(self)->None: + queue_one = Queue() + queue_two = Queue() + + inputs = [ + queue_one, queue_two + ] + + queue_one.put(1) + readables = wait(inputs) + + self.assertIn(queue_one, readables) + self.assertEqual(len(readables), 1) + msg = readables[0].get() + self.assertEqual(msg, 1) + + queue_one.put(1) + queue_two.put(2) + sleep(0.1) + readables = wait(inputs) + + self.assertIn(queue_one, readables) + self.assertIn(queue_two, readables) + self.assertEqual(len(readables), 2) + for readable in readables: + if readable == queue_one: + msg = readable.get() + self.assertEqual(msg, 1) + elif readable == queue_two: + msg = readable.get() + self.assertEqual(msg, 2) + + + def testWaitPipesAndQueues(self)->None: + pipe_one_reader, pipe_one_writer = Pipe() + pipe_two_reader, pipe_two_writer = Pipe() + queue_one = Queue() + queue_two = Queue() + + inputs = [ + pipe_one_reader, pipe_two_reader, queue_one, queue_two + ] + + pipe_one_writer.send(1) + readables = wait(inputs) + + self.assertIn(pipe_one_reader, readables) + self.assertEqual(len(readables), 1) + msg = readables[0].recv() + self.assertEqual(msg, 1) + + pipe_one_writer.send(1) + pipe_two_writer.send(2) + readables = wait(inputs) + + self.assertIn(pipe_one_reader, readables) + self.assertIn(pipe_two_reader, readables) + self.assertEqual(len(readables), 2) + for readable in readables: + if readable == pipe_one_reader: + msg = readable.recv() + self.assertEqual(msg, 1) + if readable == pipe_two_reader: + msg = readable.recv() + self.assertEqual(msg, 2) + + queue_one.put(1) + readables = wait(inputs) + + self.assertIn(queue_one, readables) + self.assertEqual(len(readables), 1) + msg = readables[0].get() + self.assertEqual(msg, 1) + + queue_one.put(1) + queue_two.put(2) + sleep(0.1) + readables = wait(inputs) + + self.assertIn(queue_one, readables) + self.assertIn(queue_two, readables) + self.assertEqual(len(readables), 2) + for readable in readables: + if readable == queue_one: + msg = readable.get() + self.assertEqual(msg, 1) + elif readable == queue_two: + msg = readable.get() + self.assertEqual(msg, 2) + + queue_one.put(1) + pipe_one_writer.send(1) + sleep(0.1) + readables = wait(inputs) + + self.assertIn(queue_one, readables) + self.assertIn(pipe_one_reader, readables) + self.assertEqual(len(readables), 2) + for readable in readables: + if readable == queue_one: + msg = readable.get() + self.assertEqual(msg, 1) + elif readable == pipe_one_reader: + msg = readable.recv() + self.assertEqual(msg, 1) + diff --git a/tests/test_meow.py b/tests/test_meow.py index e1dc354..6ad95e5 100644 --- a/tests/test_meow.py +++ b/tests/test_meow.py @@ -71,13 +71,13 @@ class MeowTests(unittest.TestCase): def testBaseMonitor(self)->None: with self.assertRaises(TypeError): - BaseMonitor("", "", "") + BaseMonitor("", "") class TestMonitor(BaseMonitor): pass with self.assertRaises(NotImplementedError): - TestMonitor("", "", "") + TestMonitor("", "") class FullTestMonitor(BaseMonitor): def start(self): @@ -86,11 +86,9 @@ class MeowTests(unittest.TestCase): pass def _is_valid_report(self, report:Any)->None: pass - def _is_valid_listen(self, listen:Any)->None: - pass def _is_valid_rules(self, rules:Any)->None: pass - FullTestMonitor("", "", "") + FullTestMonitor("", "") def testBaseHandler(self)->None: with self.assertRaises(TypeError): @@ -103,11 +101,13 @@ class MeowTests(unittest.TestCase): TestHandler("") class FullTestHandler(BaseHandler): - def handle(self): + def handle(self, event, rule): + pass + def start(self): + pass + def stop(self): pass def _is_valid_inputs(self, inputs:Any)->None: pass FullTestHandler("") - - diff --git a/tests/test_patterns.py b/tests/test_patterns.py index f858137..9e0f4df 100644 --- a/tests/test_patterns.py +++ b/tests/test_patterns.py @@ -112,14 +112,11 @@ class CorrectnessTests(unittest.TestCase): self.assertEqual(fep.event_mask, FILE_EVENTS) def testWatchdogMonitorMinimum(self)->None: - to_monitor = Pipe() from_monitor = Pipe() - WatchdogMonitor(TEST_BASE, {}, from_monitor[PIPE_WRITE], - to_monitor[PIPE_READ]) + WatchdogMonitor(TEST_BASE, {}, from_monitor[PIPE_WRITE]) def testWatchdogMonitorEventIdentificaion(self)->None: - to_monitor = Pipe() - from_monitor = Pipe() + from_monitor_reader, from_monitor_writer = Pipe() pattern_one = FileEventPattern( "pattern_one", "A", "recipe_one", "file_one") @@ -134,14 +131,13 @@ class CorrectnessTests(unittest.TestCase): } rules = create_rules(patterns, recipes) - wm = WatchdogMonitor(TEST_BASE, rules, from_monitor[PIPE_WRITE], - to_monitor[PIPE_READ]) + wm = WatchdogMonitor(TEST_BASE, rules, from_monitor_writer) wm.start() open(os.path.join(TEST_BASE, "A"), "w") - if from_monitor[PIPE_READ].poll(3): - message = from_monitor[PIPE_READ].recv() + if from_monitor_reader.poll(3): + message = from_monitor_reader.recv() self.assertIsNotNone(message) event, rule = message @@ -150,8 +146,8 @@ class CorrectnessTests(unittest.TestCase): self.assertEqual(event.src_path, os.path.join(TEST_BASE, "A")) open(os.path.join(TEST_BASE, "B"), "w") - if from_monitor[PIPE_READ].poll(3): - new_message = from_monitor[PIPE_READ].recv() + if from_monitor_reader.poll(3): + new_message = from_monitor_reader.recv() else: new_message = None self.assertIsNone(new_message) diff --git a/tests/test_recipes.py b/tests/test_recipes.py index 53facf8..85d2c9a 100644 --- a/tests/test_recipes.py +++ b/tests/test_recipes.py @@ -2,7 +2,10 @@ import jsonschema import unittest -from recipes.jupyter_notebook_recipe import JupyterNotebookRecipe +from multiprocessing import Pipe + +from recipes.jupyter_notebook_recipe import JupyterNotebookRecipe, \ + PapermillHandler from core.correctness.vars import BAREBONES_NOTEBOOK class CorrectnessTests(unittest.TestCase): @@ -73,3 +76,36 @@ class CorrectnessTests(unittest.TestCase): jnr = JupyterNotebookRecipe( "name", BAREBONES_NOTEBOOK, source=source) self.assertEqual(jnr.source, source) + + def testPapermillHanderMinimum(self)->None: + monitor_to_handler_reader, _ = Pipe() + + PapermillHandler([monitor_to_handler_reader]) + + def testPapermillHanderStartStop(self)->None: + monitor_to_handler_reader, _ = Pipe() + + ph = PapermillHandler([monitor_to_handler_reader]) + + ph.start() + ph.stop() + + def testPapermillHanderRepeatedStarts(self)->None: + monitor_to_handler_reader, _ = Pipe() + + ph = PapermillHandler([monitor_to_handler_reader]) + + ph.start() + with self.assertRaises(RuntimeWarning): + ph.start() + ph.stop() + + def testPapermillHanderStopBeforeStart(self)->None: + monitor_to_handler_reader, _ = Pipe() + + ph = PapermillHandler([monitor_to_handler_reader]) + + with self.assertRaises(RuntimeWarning): + ph.stop() + + diff --git a/tests/test_validation.py b/tests/test_validation.py index 7bdbbd7..1129905 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -1,10 +1,10 @@ import unittest -from typing import Any +from typing import Any, Union -from core.correctness.validation import check_input, valid_string, \ - valid_dict, valid_list +from core.correctness.validation import check_type, check_implementation, \ + valid_string, valid_dict, valid_list from core.correctness.vars import VALID_NAME_CHARS @@ -15,21 +15,28 @@ class CorrectnessTests(unittest.TestCase): def tearDown(self)->None: return super().tearDown() - def testCheckInputValid(self)->None: - check_input(1, int) - check_input(0, int) - check_input(False, bool) - check_input(True, bool) - check_input(1, Any) - - def testCheckInputMistyped(self)->None: - with self.assertRaises(TypeError): - check_input(1, str) + def testCheckTypeValid(self)->None: + check_type(1, int) + check_type(0, int) + check_type(False, bool) + check_type(True, bool) - def testCheckInputOrNone(self)->None: - check_input(None, int, or_none=True) + def testCheckTypeValidAny(self)->None: + check_type(1, Any) + + def testCheckTypeValidUnion(self)->None: + check_type(1, Union[int,str]) with self.assertRaises(TypeError): - check_input(None, int, or_none=False) + check_type(Union[int, str], Union[int,str]) + + def testCheckTypeMistyped(self)->None: + with self.assertRaises(TypeError): + check_type(1, str) + + def testCheckTypeOrNone(self)->None: + check_type(None, int, or_none=True) + with self.assertRaises(TypeError): + check_type(None, int, or_none=False) def testValidStringValid(self)->None: valid_string("David_Marchant", VALID_NAME_CHARS) @@ -106,3 +113,48 @@ class CorrectnessTests(unittest.TestCase): def testValidListMinLength(self)->None: with self.assertRaises(ValueError): valid_list([1, 2, 3], str, min_length=10) + + def testCheckImplementationMinimum(self)->None: + class Parent: + def func(): + pass + + class Child(Parent): + def func(): + pass + + check_implementation(Child.func, Parent) + + def testCheckImplementationUnaltered(self)->None: + class Parent: + def func(): + pass + + class Child(Parent): + pass + + with self.assertRaises(NotImplementedError): + check_implementation(Child.func, Parent) + + def testCheckImplementationDifferingSig(self)->None: + class Parent: + def func(): + pass + + class Child(Parent): + def func(var): + pass + + with self.assertRaises(NotImplementedError): + check_implementation(Child.func, Parent) + + def testCheckImplementationAnyType(self)->None: + class Parent: + def func(var:Any): + pass + + class Child(Parent): + def func(var:str): + pass + + check_implementation(Child.func, Parent)