|
|
|
@ -15,6 +15,7 @@
|
|
|
|
|
"""Check parameters."""
|
|
|
|
|
import re
|
|
|
|
|
from enum import Enum
|
|
|
|
|
from functools import reduce
|
|
|
|
|
from itertools import repeat
|
|
|
|
|
from collections import Iterable
|
|
|
|
|
|
|
|
|
@ -93,8 +94,131 @@ rel_strs = {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Validator:
|
|
|
|
|
"""validator for checking input parameters"""
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def check(arg_name, arg_value, value_name, value, rel=Rel.EQ, prim_name=None):
|
|
|
|
|
"""
|
|
|
|
|
Method for judging relation between two int values or list/tuple made up of ints.
|
|
|
|
|
|
|
|
|
|
This method is not suitable for judging relation between floats, since it does not consider float error.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
rel_fn = Rel.get_fns(rel)
|
|
|
|
|
if not rel_fn(arg_value, value):
|
|
|
|
|
rel_str = Rel.get_strs(rel).format(f'{value_name}: {value}')
|
|
|
|
|
msg_prefix = f'For {prim_name} the' if prim_name else "The"
|
|
|
|
|
raise ValueError(f'{msg_prefix} `{arg_name}` should be {rel_str}, but got {arg_value}.')
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def check_integer(arg_name, arg_value, value, rel, prim_name):
|
|
|
|
|
"""Integer value judgment."""
|
|
|
|
|
rel_fn = Rel.get_fns(rel)
|
|
|
|
|
type_mismatch = not isinstance(arg_value, int) or isinstance(arg_value, bool)
|
|
|
|
|
if type_mismatch or not rel_fn(arg_value, value):
|
|
|
|
|
rel_str = Rel.get_strs(rel).format(value)
|
|
|
|
|
raise ValueError(f'For {prim_name} the `{arg_name}` should be an int and must {rel_str},'
|
|
|
|
|
f' but got {arg_value}.')
|
|
|
|
|
return arg_value
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def check_int_range(arg_name, arg_value, lower_limit, upper_limit, rel, prim_name):
|
|
|
|
|
"""Method for checking whether an int value is in some range."""
|
|
|
|
|
rel_fn = Rel.get_fns(rel)
|
|
|
|
|
type_mismatch = not isinstance(arg_value, int)
|
|
|
|
|
if type_mismatch or not rel_fn(arg_value, lower_limit, upper_limit):
|
|
|
|
|
rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit)
|
|
|
|
|
raise ValueError(f'For \'{prim_name}\' the `{arg_name}` should be an int in range {rel_str},'
|
|
|
|
|
f' but got {arg_value}.')
|
|
|
|
|
return arg_value
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def check_subclass(arg_name, type_, template_type, prim_name):
|
|
|
|
|
"""Check whether some type is sublcass of another type"""
|
|
|
|
|
if not isinstance(template_type, Iterable):
|
|
|
|
|
template_type = (template_type,)
|
|
|
|
|
if not any([mstype.issubclass_(type_, x) for x in template_type]):
|
|
|
|
|
type_str = (type(type_).__name__ if isinstance(type_, (tuple, list)) else "") + str(type_)
|
|
|
|
|
raise TypeError(f'For \'{prim_name}\' the type of `{arg_name}` should be subclass'
|
|
|
|
|
f' of {",".join((str(x) for x in template_type))}, but got {type_str}.')
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def check_tensor_type_same(args, valid_values, prim_name):
|
|
|
|
|
"""check whether the element types of input tensors are the same."""
|
|
|
|
|
def _check_tensor_type(arg):
|
|
|
|
|
arg_key, arg_val = arg
|
|
|
|
|
Validator.check_subclass(arg_key, arg_val, mstype.tensor, prim_name)
|
|
|
|
|
elem_type = arg_val.element_type()
|
|
|
|
|
if not elem_type in valid_values:
|
|
|
|
|
raise TypeError(f'For \'{prim_name}\' element type of `{arg_key}` should be in {valid_values},'
|
|
|
|
|
f' but `{arg_key}` is {elem_type}.')
|
|
|
|
|
return (arg_key, elem_type)
|
|
|
|
|
|
|
|
|
|
def _check_types_same(arg1, arg2):
|
|
|
|
|
arg1_name, arg1_type = arg1
|
|
|
|
|
arg2_name, arg2_type = arg2
|
|
|
|
|
if arg1_type != arg2_type:
|
|
|
|
|
raise TypeError(f'For \'{prim_name}\' element type of `{arg2_name}` should be same as `{arg1_name}`,'
|
|
|
|
|
f' but `{arg1_name}` is {arg1_type} and `{arg2_name}` is {arg2_type}.')
|
|
|
|
|
return arg1
|
|
|
|
|
|
|
|
|
|
elem_types = map(_check_tensor_type, args.items())
|
|
|
|
|
reduce(_check_types_same, elem_types)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def check_scalar_or_tensor_type_same(args, valid_values, prim_name):
|
|
|
|
|
"""check whether the types of inputs are the same. if the input args are tensors, check their element types"""
|
|
|
|
|
def _check_argument_type(arg):
|
|
|
|
|
arg_key, arg_val = arg
|
|
|
|
|
if isinstance(arg_val, type(mstype.tensor)):
|
|
|
|
|
arg_val = arg_val.element_type()
|
|
|
|
|
if not arg_val in valid_values:
|
|
|
|
|
raise TypeError(f'For \'{prim_name}\' the `{arg_key}` should be in {valid_values},'
|
|
|
|
|
f' but `{arg_key}` is {arg_val}.')
|
|
|
|
|
return arg
|
|
|
|
|
|
|
|
|
|
def _check_types_same(arg1, arg2):
|
|
|
|
|
arg1_name, arg1_type = arg1
|
|
|
|
|
arg2_name, arg2_type = arg2
|
|
|
|
|
excp_flag = False
|
|
|
|
|
if isinstance(arg1_type, type(mstype.tensor)) and isinstance(arg2_type, type(mstype.tensor)):
|
|
|
|
|
arg1_type = arg1_type.element_type()
|
|
|
|
|
arg2_type = arg2_type.element_type()
|
|
|
|
|
elif not (isinstance(arg1_type, type(mstype.tensor)) or isinstance(arg2_type, type(mstype.tensor))):
|
|
|
|
|
pass
|
|
|
|
|
else:
|
|
|
|
|
excp_flag = True
|
|
|
|
|
|
|
|
|
|
if excp_flag or arg1_type != arg2_type:
|
|
|
|
|
raise TypeError(f'For \'{prim_name}\' type of `{arg2_name}` should be same as `{arg1_name}`,'
|
|
|
|
|
f' but `{arg1_name}` is {arg1_type} and `{arg2_name}` is {arg2_type}.')
|
|
|
|
|
return arg1
|
|
|
|
|
reduce(_check_types_same, map(_check_argument_type, args.items()))
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def check_value_type(arg_name, arg_value, valid_types, prim_name):
|
|
|
|
|
"""Check whether a values is instance of some types."""
|
|
|
|
|
def raise_error_msg():
|
|
|
|
|
"""func for raising error message when check failed"""
|
|
|
|
|
type_names = [t.__name__ for t in valid_types]
|
|
|
|
|
num_types = len(valid_types)
|
|
|
|
|
raise TypeError(f'For \'{prim_name}\' the type of `{arg_name}` should be '
|
|
|
|
|
f'{"one of " if num_types > 1 else ""}'
|
|
|
|
|
f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.')
|
|
|
|
|
|
|
|
|
|
# Notice: bool is subclass of int, so `check_value_type('x', True, [int])` will check fail, and
|
|
|
|
|
# `check_value_type('x', True, [bool, int])` will check pass
|
|
|
|
|
if isinstance(arg_value, bool) and bool not in tuple(valid_types):
|
|
|
|
|
raise_error_msg()
|
|
|
|
|
if isinstance(arg_value, tuple(valid_types)):
|
|
|
|
|
return arg_value
|
|
|
|
|
raise_error_msg()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ParamValidator:
|
|
|
|
|
"""Parameter validator."""
|
|
|
|
|
"""Parameter validator. NOTICE: this class will be replaced by `class Validator`"""
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def equal(arg_name, arg_value, cond_str, cond):
|
|
|
|
|