|
|
|
@ -14,8 +14,9 @@
|
|
|
|
|
# ============================================================================
|
|
|
|
|
"""Check parameters."""
|
|
|
|
|
import re
|
|
|
|
|
import inspect
|
|
|
|
|
from enum import Enum
|
|
|
|
|
from functools import reduce
|
|
|
|
|
from functools import reduce, wraps
|
|
|
|
|
from itertools import repeat
|
|
|
|
|
from collections.abc import Iterable
|
|
|
|
|
|
|
|
|
@ -181,7 +182,7 @@ class Validator:
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def check_subclass(arg_name, type_, template_type, prim_name):
|
|
|
|
|
"""Checks whether some type is sublcass of another type"""
|
|
|
|
|
"""Checks whether some type is subclass of another type"""
|
|
|
|
|
if not isinstance(template_type, Iterable):
|
|
|
|
|
template_type = (template_type,)
|
|
|
|
|
if not any([mstype.issubclass_(type_, x) for x in template_type]):
|
|
|
|
@ -240,7 +241,6 @@ class Validator:
|
|
|
|
|
elem_types = map(_check_tensor_type, args.items())
|
|
|
|
|
reduce(_check_types_same, elem_types)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def check_scalar_or_tensor_type_same(args, valid_values, prim_name, allow_mix=False):
|
|
|
|
|
"""
|
|
|
|
@ -261,7 +261,7 @@ class Validator:
|
|
|
|
|
def _check_types_same(arg1, arg2):
|
|
|
|
|
arg1_name, arg1_type = arg1
|
|
|
|
|
arg2_name, arg2_type = arg2
|
|
|
|
|
excp_flag = False
|
|
|
|
|
except_flag = False
|
|
|
|
|
if isinstance(arg1_type, type(mstype.tensor)) and isinstance(arg2_type, type(mstype.tensor)):
|
|
|
|
|
arg1_type = arg1_type.element_type()
|
|
|
|
|
arg2_type = arg2_type.element_type()
|
|
|
|
@ -271,9 +271,9 @@ class Validator:
|
|
|
|
|
arg1_type = arg1_type.element_type() if isinstance(arg1_type, type(mstype.tensor)) else arg1_type
|
|
|
|
|
arg2_type = arg2_type.element_type() if isinstance(arg2_type, type(mstype.tensor)) else arg2_type
|
|
|
|
|
else:
|
|
|
|
|
excp_flag = True
|
|
|
|
|
except_flag = True
|
|
|
|
|
|
|
|
|
|
if excp_flag or arg1_type != arg2_type:
|
|
|
|
|
if except_flag or arg1_type != arg2_type:
|
|
|
|
|
raise TypeError(f'For \'{prim_name}\' type of `{arg2_name}` should be same as `{arg1_name}`,'
|
|
|
|
|
f' but `{arg1_name}` is {arg1_type} and `{arg2_name}` is {arg2_type}.')
|
|
|
|
|
return arg1
|
|
|
|
@ -283,11 +283,12 @@ class Validator:
|
|
|
|
|
def check_value_type(arg_name, arg_value, valid_types, prim_name):
|
|
|
|
|
"""Checks whether a value is instance of some types."""
|
|
|
|
|
valid_types = valid_types if isinstance(valid_types, Iterable) else (valid_types,)
|
|
|
|
|
|
|
|
|
|
def raise_error_msg():
|
|
|
|
|
"""func for raising error message when check failed"""
|
|
|
|
|
type_names = [t.__name__ for t in valid_types]
|
|
|
|
|
num_types = len(valid_types)
|
|
|
|
|
msg_prefix = f'For \'{prim_name}\' the' if prim_name else 'The'
|
|
|
|
|
msg_prefix = f'For \'{prim_name}\' the' if prim_name else 'The'
|
|
|
|
|
raise TypeError(f'{msg_prefix} type of `{arg_name}` should be {"one of " if num_types > 1 else ""}'
|
|
|
|
|
f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.')
|
|
|
|
|
|
|
|
|
@ -303,6 +304,7 @@ class Validator:
|
|
|
|
|
def check_type_name(arg_name, arg_type, valid_types, prim_name):
|
|
|
|
|
"""Checks whether a type in some specified types"""
|
|
|
|
|
valid_types = valid_types if isinstance(valid_types, Iterable) else (valid_types,)
|
|
|
|
|
|
|
|
|
|
def get_typename(t):
|
|
|
|
|
return t.__name__ if hasattr(t, '__name__') else str(t)
|
|
|
|
|
|
|
|
|
@ -368,9 +370,9 @@ class ParamValidator:
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def check_isinstance(arg_name, arg_value, classes):
|
|
|
|
|
"""Check arg isintance of classes"""
|
|
|
|
|
"""Check arg isinstance of classes"""
|
|
|
|
|
if not isinstance(arg_value, classes):
|
|
|
|
|
raise ValueError(f'The `{arg_name}` should be isintance of {classes}, but got {arg_value}.')
|
|
|
|
|
raise ValueError(f'The `{arg_name}` should be isinstance of {classes}, but got {arg_value}.')
|
|
|
|
|
return arg_value
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
@ -384,7 +386,7 @@ class ParamValidator:
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def check_subclass(arg_name, type_, template_type, with_type_of=True):
|
|
|
|
|
"""Check whether some type is sublcass of another type"""
|
|
|
|
|
"""Check whether some type is subclass of another type"""
|
|
|
|
|
if not isinstance(template_type, Iterable):
|
|
|
|
|
template_type = (template_type,)
|
|
|
|
|
if not any([mstype.issubclass_(type_, x) for x in template_type]):
|
|
|
|
@ -402,9 +404,9 @@ class ParamValidator:
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def check_bool(arg_name, arg_value):
|
|
|
|
|
"""Check arg isintance of bool"""
|
|
|
|
|
"""Check arg isinstance of bool"""
|
|
|
|
|
if not isinstance(arg_value, bool):
|
|
|
|
|
raise ValueError(f'The `{arg_name}` should be isintance of bool, but got {arg_value}.')
|
|
|
|
|
raise ValueError(f'The `{arg_name}` should be isinstance of bool, but got {arg_value}.')
|
|
|
|
|
return arg_value
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
@ -771,3 +773,30 @@ def _check_str_by_regular(target, reg=None, flag=re.ASCII):
|
|
|
|
|
if re.match(reg, target, flag) is None:
|
|
|
|
|
raise ValueError("'{}' is illegal, it should be match regular'{}' by flags'{}'".format(target, reg, flag))
|
|
|
|
|
return True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def args_type_check(*type_args, **type_kwargs):
|
|
|
|
|
"""Check whether input data type is correct."""
|
|
|
|
|
|
|
|
|
|
def type_check(func):
|
|
|
|
|
sig = inspect.signature(func)
|
|
|
|
|
bound_types = sig.bind_partial(*type_args, **type_kwargs).arguments
|
|
|
|
|
|
|
|
|
|
@wraps(func)
|
|
|
|
|
def wrapper(*args, **kwargs):
|
|
|
|
|
nonlocal bound_types
|
|
|
|
|
bound_values = sig.bind(*args, **kwargs)
|
|
|
|
|
argument_dict = bound_values.arguments
|
|
|
|
|
if "kwargs" in bound_types:
|
|
|
|
|
bound_types = bound_types["kwargs"]
|
|
|
|
|
if "kwargs" in argument_dict:
|
|
|
|
|
argument_dict = argument_dict["kwargs"]
|
|
|
|
|
for name, value in argument_dict.items():
|
|
|
|
|
if name in bound_types:
|
|
|
|
|
if value is not None and not isinstance(value, bound_types[name]):
|
|
|
|
|
raise TypeError('Argument {} must be {}'.format(name, bound_types[name]))
|
|
|
|
|
return func(*args, **kwargs)
|
|
|
|
|
|
|
|
|
|
return wrapper
|
|
|
|
|
|
|
|
|
|
return type_check
|
|
|
|
|