added support for multi-type waiting plus some cleanup
This commit is contained in:
@ -1,10 +1,11 @@
|
||||
|
||||
from inspect import signature
|
||||
from os.path import sep
|
||||
from typing import Any, _SpecialForm
|
||||
from typing import Any, _SpecialForm, Union, get_origin, get_args
|
||||
|
||||
from core.correctness.vars import VALID_PATH_CHARS
|
||||
from core.correctness.vars import VALID_PATH_CHARS, get_not_imp_msg
|
||||
|
||||
def check_input(variable:Any, expected_type:type, alt_types:list[type]=[],
|
||||
def check_type(variable:Any, expected_type:type, alt_types:list[type]=[],
|
||||
or_none:bool=False)->None:
|
||||
"""
|
||||
Checks if a given variable is of the expected type. Raises TypeError or
|
||||
@ -24,23 +25,38 @@ def check_input(variable:Any, expected_type:type, alt_types:list[type]=[],
|
||||
"""
|
||||
|
||||
type_list = [expected_type]
|
||||
if get_origin(expected_type) is Union:
|
||||
type_list = list(get_args(expected_type))
|
||||
type_list = type_list + alt_types
|
||||
|
||||
if not or_none:
|
||||
if expected_type != Any \
|
||||
and type(variable) not in type_list:
|
||||
if variable is None:
|
||||
if or_none == False:
|
||||
raise TypeError(
|
||||
'Expected type was %s, got %s'
|
||||
% (expected_type, type(variable))
|
||||
)
|
||||
else:
|
||||
if expected_type != Any \
|
||||
and not type(variable) not in type_list \
|
||||
and not isinstance(variable, type(None)):
|
||||
raise TypeError(
|
||||
'Expected type was %s or None, got %s'
|
||||
% (expected_type, type(variable))
|
||||
f'Not allowed None for variable. Expected {expected_type}.'
|
||||
)
|
||||
else:
|
||||
return
|
||||
|
||||
if expected_type == Any:
|
||||
return
|
||||
|
||||
if not isinstance(variable, tuple(type_list)):
|
||||
raise TypeError(
|
||||
'Expected type(s) are %s, got %s'
|
||||
% (get_args(expected_type), type(variable))
|
||||
)
|
||||
|
||||
def check_implementation(child_func, parent_class):
|
||||
parent_func = getattr(parent_class, child_func.__name__)
|
||||
if (child_func == parent_func):
|
||||
msg = get_not_imp_msg(parent_class, parent_func)
|
||||
raise NotImplementedError(msg)
|
||||
child_sig = signature(child_func).parameters
|
||||
parent_sig = signature(parent_func).parameters
|
||||
|
||||
if child_sig.keys() != parent_sig.keys():
|
||||
msg = get_not_imp_msg(parent_class, parent_func)
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
def valid_string(variable:str, valid_chars:str, min_length:int=1)->None:
|
||||
"""
|
||||
@ -56,8 +72,8 @@ def valid_string(variable:str, valid_chars:str, min_length:int=1)->None:
|
||||
|
||||
:return: No return.
|
||||
"""
|
||||
check_input(variable, str)
|
||||
check_input(valid_chars, str)
|
||||
check_type(variable, str)
|
||||
check_type(valid_chars, str)
|
||||
|
||||
if len(variable) < min_length:
|
||||
raise ValueError (
|
||||
@ -74,12 +90,12 @@ def valid_string(variable:str, valid_chars:str, min_length:int=1)->None:
|
||||
def valid_dict(variable:dict[Any, Any], key_type:type, value_type:type,
|
||||
required_keys:list[Any]=[], optional_keys:list[Any]=[],
|
||||
strict:bool=True, min_length:int=1)->None:
|
||||
check_input(variable, dict)
|
||||
check_input(key_type, type, alt_types=[_SpecialForm])
|
||||
check_input(value_type, type, alt_types=[_SpecialForm])
|
||||
check_input(required_keys, list)
|
||||
check_input(optional_keys, list)
|
||||
check_input(strict, bool)
|
||||
check_type(variable, dict)
|
||||
check_type(key_type, type, alt_types=[_SpecialForm])
|
||||
check_type(value_type, type, alt_types=[_SpecialForm])
|
||||
check_type(required_keys, list)
|
||||
check_type(optional_keys, list)
|
||||
check_type(strict, bool)
|
||||
|
||||
if len(variable) < min_length:
|
||||
raise ValueError(f"Dictionary '{variable}' is below minimum length of "
|
||||
@ -106,18 +122,17 @@ def valid_dict(variable:dict[Any, Any], key_type:type, value_type:type,
|
||||
|
||||
def valid_list(variable:list[Any], entry_type:type,
|
||||
alt_types:list[type]=[], min_length:int=1)->None:
|
||||
check_input(variable, list)
|
||||
check_type(variable, list)
|
||||
if len(variable) < min_length:
|
||||
raise ValueError(f"List '{variable}' is too short. Should be at least "
|
||||
f"of length {min_length}")
|
||||
for entry in variable:
|
||||
check_input(entry, entry_type, alt_types=alt_types)
|
||||
check_type(entry, entry_type, alt_types=alt_types)
|
||||
|
||||
def valid_path(variable:str, allow_base=False, extension:str="", min_length=1):
|
||||
valid_string(variable, VALID_PATH_CHARS, min_length=min_length)
|
||||
if not allow_base and variable.startswith(sep):
|
||||
raise ValueError(f"Cannot accept path '{variable}'. Must be relative.")
|
||||
if min_length > 0 and extension and not variable.endswith(extension):
|
||||
if extension and not variable.endswith(extension):
|
||||
raise ValueError(f"Path '{variable}' does not have required "
|
||||
f"extension '{extension}'.")
|
||||
|
||||
|
Reference in New Issue
Block a user