Skip to content

Commit

Permalink
updated function in ote_sdk/ote_sdk/utils/argument_checks.py
Browse files Browse the repository at this point in the history
  • Loading branch information
saltykox committed Mar 25, 2022
1 parent ede7332 commit fb0fd0e
Showing 1 changed file with 129 additions and 80 deletions.
209 changes: 129 additions & 80 deletions ote_sdk/ote_sdk/utils/argument_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,30 +7,84 @@
#

import inspect
import itertools
import typing
from abc import ABC, abstractmethod
from collections.abc import Sequence
from functools import wraps
from os.path import exists
from os.path import exists, splitext

import yaml
from numpy import floating
from omegaconf import DictConfig

IMAGE_FILE_EXTENSIONS = [
".bmp",
".dib",
".jpeg",
".jpg",
".jpe",
".jp2",
".png",
".webp",
".pbm",
".pgm",
".ppm",
".pxm",
".pnm",
".sr",
".ras",
".tiff",
".tif",
".exr",
".hdr",
".pic",
]


def get_bases(parameter) -> set:
"""Function to get set of all base classes of parameter"""

def __get_bases(parameter_type):
return [parameter_type.__name__] + list(
itertools.chain.from_iterable(
__get_bases(t1) for t1 in parameter_type.__bases__
)
)

return set(__get_bases(type(parameter)))


def get_parameter_repr(parameter) -> str:
"""Function to get parameter representation"""
try:
parameter_str = repr(parameter)
# pylint: disable=broad-except
except Exception:
parameter_str = "<unable to get parameter repr>"
return parameter_str


def raise_value_error_if_parameter_has_unexpected_type(
parameter, parameter_name, expected_type
):
"""Function raises ValueError exception if parameter has unexpected type"""
if isinstance(expected_type, typing.ForwardRef):
expected_type = expected_type.__forward_arg__
if isinstance(expected_type, str):
parameter_types = get_bases(parameter)
if not any(t == expected_type for t in parameter_types):
parameter_str = get_parameter_repr(parameter)
raise ValueError(
f"Unexpected type of '{parameter_name}' parameter, expected: {expected_type}, "
f"actual value: {parameter_str}"
)
return
if expected_type == float:
expected_type = (int, float, floating)
if not isinstance(parameter, expected_type):
parameter_type = type(parameter)
try:
parameter_str = repr(parameter)
# pylint: disable=broad-except
except Exception:
parameter_str = "<unable to get parameter repr>"
parameter_str = get_parameter_repr(parameter)
raise ValueError(
f"Unexpected type of '{parameter_name}' parameter, expected: {expected_type}, actual: {parameter_type}, "
f"actual value: {parameter_str}"
Expand Down Expand Up @@ -67,21 +121,10 @@ def check_dictionary_keys_values_type(
)


def check_parameter_type(parameter, parameter_name, expected_type):
"""Function extracts nested expected types and raises ValueError exception if parameter has unexpected type"""
# pylint: disable=W0212
if expected_type in [typing.Any, inspect._empty]: # type: ignore
return
if not isinstance(expected_type, typing._GenericAlias): # type: ignore
raise_value_error_if_parameter_has_unexpected_type(
parameter=parameter,
parameter_name=parameter_name,
expected_type=expected_type,
)
return
expected_type_dict = expected_type.__dict__
origin_class = expected_type_dict.get("__origin__")
nested_elements_class = expected_type_dict.get("__args__")
def check_nested_classes_parameters(
parameter, parameter_name, origin_class, nested_elements_class
):
"""Function to check type of parameters with nested elements"""
if origin_class == dict:
if len(nested_elements_class) != 2:
raise TypeError(
Expand All @@ -100,18 +143,53 @@ def check_parameter_type(parameter, parameter_name, expected_type):
parameter_name=parameter_name,
expected_type=origin_class,
)
if len(nested_elements_class) != 1:
raise TypeError(
"length of nested expected types for Sequence should be equal to 1"
)
if origin_class == tuple:
tuple_length = len(nested_elements_class)
if tuple_length > 2:
raise NotImplementedError(
"length of nested expected types for Tuple should not exceed 2"
)
if tuple_length == 2:
if nested_elements_class[1] != Ellipsis:
raise NotImplementedError("expected homogeneous tuple annotation")
nested_elements_class = nested_elements_class[0]
else:
if len(nested_elements_class) != 1:
raise TypeError(
"length of nested expected types for Sequence should be equal to 1"
)
check_nested_elements_type(
iterable=parameter,
parameter_name=parameter_name,
expected_type=nested_elements_class,
)


def check_parameter_type(parameter, parameter_name, expected_type):
"""Function extracts nested expected types and raises ValueError exception if parameter has unexpected type"""
# pylint: disable=W0212
if expected_type in [typing.Any, inspect._empty]: # type: ignore
return
if not isinstance(expected_type, typing._GenericAlias): # type: ignore
raise_value_error_if_parameter_has_unexpected_type(
parameter=parameter,
parameter_name=parameter_name,
expected_type=expected_type,
)
return
# Checking parameters with nested elements
expected_type_dict = expected_type.__dict__
origin_class = expected_type_dict.get("__origin__")
nested_elements_class = expected_type_dict.get("__args__")
check_nested_classes_parameters(
parameter=parameter,
parameter_name=parameter_name,
origin_class=origin_class,
nested_elements_class=nested_elements_class,
)
# Union type with nested elements check
if origin_class == typing.Union:
expected_args = expected_type_dict.get("__args__")
# Union type with nested elements check
checks_counter = 0
errors_counter = 0
for expected_arg in expected_args:
Expand All @@ -128,10 +206,13 @@ def check_parameter_type(parameter, parameter_name, expected_type):
)


def check_input_parameters_type(checks_types: dict = None):
"""Decorator to check input parameters type"""
if checks_types is None:
checks_types = {}
def check_input_parameters_type(custom_checks: typing.Optional[dict] = None):
"""
Decorator to check input parameters type
:param custom_checks: dictionary where key - name of parameter and value - custom check class
"""
if custom_checks is None:
custom_checks = {}

def _check_input_parameters_type(function):
@wraps(function)
Expand All @@ -150,21 +231,23 @@ def validate(*args, **kwargs):
)
input_parameters_values_map[key] = value
# Checking input parameters type
for parameter in expected_types_map:
input_parameter_actual = input_parameters_values_map.get(parameter)
if input_parameter_actual is None:
default_value = expected_types_map.get(parameter).default
for parameter_name in expected_types_map:
parameter = input_parameters_values_map.get(parameter_name)
if parameter is None:
default_value = expected_types_map.get(parameter_name).default
# pylint: disable=protected-access
if default_value != inspect._empty: # type: ignore
input_parameter_actual = default_value
custom_check = checks_types.get(parameter)
if custom_check:
custom_check(input_parameter_actual, parameter).check()
parameter = default_value
if parameter_name in custom_checks:
custom_check = custom_checks[parameter_name]
if custom_check is None:
continue
custom_check(parameter, parameter_name).check()
else:
check_parameter_type(
parameter=input_parameter_actual,
parameter_name=parameter,
expected_type=expected_types_map.get(parameter).annotation,
parameter=parameter,
parameter_name=parameter_name,
expected_type=expected_types_map.get(parameter_name).annotation,
)
return function(**input_parameters_values_map)

Expand All @@ -177,7 +260,7 @@ def check_file_extension(
file_path: str, file_path_name: str, expected_extensions: list
):
"""Function raises ValueError exception if file has unexpected extension"""
file_extension = file_path.split(".")[-1].lower()
file_extension = splitext(file_path)[1].lower()
if file_extension not in expected_extensions:
raise ValueError(
f"Unexpected extension of {file_path_name} file. expected: {expected_extensions} actual: {file_extension}"
Expand Down Expand Up @@ -314,7 +397,7 @@ def check(self):
check_file_extension(
file_path=self.parameter,
file_path_name=self.parameter_name,
expected_extensions=["yaml"],
expected_extensions=[".yaml"],
)
check_that_all_characters_printable(
parameter=self.parameter, parameter_name=self.parameter_name
Expand Down Expand Up @@ -363,54 +446,20 @@ def __init__(self, parameter, parameter_name):
self.parameter_name = parameter_name

def check(self):
"""Method raises ValueError exception if parameter is not equal to DataSet"""
"""Method raises ValueError exception if parameter is not equal to Dataset"""
check_is_parameter_like_dataset(
parameter=self.parameter, parameter_name=self.parameter_name
)


class OptionalDatasetParamTypeCheck(DatasetParamTypeCheck):
"""Class to check DatasetEntity-type parameters"""

def check(self):
"""Method raises ValueError exception if parameter is not equal to DataSet"""
if self.parameter is not None:
check_is_parameter_like_dataset(
parameter=self.parameter, parameter_name=self.parameter_name
)


class OptionalModelParamTypeCheck(BaseInputArgumentChecker):
"""Class to check ModelEntity-type parameters"""

def __init__(self, parameter, parameter_name):
self.parameter = parameter
self.parameter_name = parameter_name

def check(self):
"""Method raises ValueError exception if parameter is not equal to DataSet"""
if self.parameter is not None:
for expected_attribute in (
"__train_dataset__",
"__previous_trained_revision__",
"__model_format__",
):
if not hasattr(self.parameter, expected_attribute):
parameter_type = type(self.parameter)
raise ValueError(
f"parameter '{self.parameter_name}' is not like ModelEntity, actual type: {parameter_type} "
f"which does not have expected '{expected_attribute}' Model attribute"
)


class OptionalImageFilePathCheck(OptionalFilePathCheck):
"""Class to check optional image file path parameters"""

def __init__(self, parameter, parameter_name):
super().__init__(
parameter=parameter,
parameter_name=parameter_name,
expected_file_extension=["jpg", "png"],
expected_file_extension=IMAGE_FILE_EXTENSIONS,
)


Expand All @@ -421,7 +470,7 @@ def __init__(self, parameter, parameter_name):
super().__init__(
parameter=parameter,
parameter_name=parameter_name,
expected_file_extension=["yaml"],
expected_file_extension=[".yaml"],
)


Expand Down

0 comments on commit fb0fd0e

Please sign in to comment.