|
|
|
@ -320,6 +320,224 @@ class Validator:
|
|
|
|
|
raise TypeError(f"{msg_prefix} `{arg_name}` must be float.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ParamValidator:
|
|
|
|
|
"""Parameter validator. NOTICE: this class will be replaced by `class Validator`"""
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def equal(arg_name, arg_value, cond_str, cond):
|
|
|
|
|
"""Judging valid value."""
|
|
|
|
|
if not cond:
|
|
|
|
|
raise ValueError(f'The `{arg_name}` must be {cond_str}, but got {arg_value}.')
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def check(arg_name, arg_value, value_name, value, rel=Rel.EQ):
|
|
|
|
|
"""This method is only used for check int values, since when compare float values,
|
|
|
|
|
we need consider float error."""
|
|
|
|
|
rel_fn = Rel.get_fns(rel)
|
|
|
|
|
if not rel_fn(arg_value, value):
|
|
|
|
|
rel_str = Rel.get_strs(rel).format(f'{value_name}: {value}')
|
|
|
|
|
raise ValueError(f'The `{arg_name}` should be {rel_str}, but got {arg_value}.')
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def check_integer(arg_name, arg_value, value, rel):
|
|
|
|
|
"""Integer value judgment."""
|
|
|
|
|
rel_fn = Rel.get_fns(rel)
|
|
|
|
|
type_mismatch = not isinstance(arg_value, int) or isinstance(arg_value, bool)
|
|
|
|
|
if type_mismatch or not rel_fn(arg_value, value):
|
|
|
|
|
rel_str = Rel.get_strs(rel).format(value)
|
|
|
|
|
raise ValueError(f'The `{arg_name}` should be an int and must {rel_str}, but got {arg_value}.')
|
|
|
|
|
return arg_value
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def check_shape_length(arg_name, arg_value, value, rel):
|
|
|
|
|
"""Shape length judgment."""
|
|
|
|
|
rel_fn = Rel.get_fns(rel)
|
|
|
|
|
type_mismatch = not isinstance(arg_value, int)
|
|
|
|
|
if type_mismatch or not rel_fn(arg_value, value):
|
|
|
|
|
rel_str = Rel.get_strs(rel).format(value)
|
|
|
|
|
raise ValueError(f'The length of `{arg_name}` should be an int and must {rel_str}, but got {arg_value}')
|
|
|
|
|
return arg_value
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def check_int_range(arg_name, arg_value, lower_limit, upper_limit, rel):
|
|
|
|
|
"""This method is only used for check int values,
|
|
|
|
|
since when compare float values, we need consider float error."""
|
|
|
|
|
rel_fn = Rel.get_fns(rel)
|
|
|
|
|
type_mismatch = not isinstance(arg_value, int)
|
|
|
|
|
if type_mismatch or not rel_fn(arg_value, lower_limit, upper_limit):
|
|
|
|
|
rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit)
|
|
|
|
|
raise ValueError(f'The `{arg_name}` should be an int in range {rel_str}, but got {arg_value}.')
|
|
|
|
|
return arg_value
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def check_isinstance(arg_name, arg_value, classes):
|
|
|
|
|
"""Check arg isinstance of classes"""
|
|
|
|
|
if not isinstance(arg_value, classes):
|
|
|
|
|
raise ValueError(f'The `{arg_name}` should be isinstance of {classes}, but got {arg_value}.')
|
|
|
|
|
return arg_value
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def check_number_range(arg_name, arg_value, lower_limit, upper_limit, rel):
|
|
|
|
|
"""Is it necessary to consider error when comparing float values."""
|
|
|
|
|
rel_fn = Rel.get_fns(rel)
|
|
|
|
|
if not rel_fn(arg_value, lower_limit, upper_limit):
|
|
|
|
|
rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit)
|
|
|
|
|
raise ValueError(f'The `{arg_name}` should be in range {rel_str}, but got {arg_value}.')
|
|
|
|
|
return arg_value
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def check_subclass(arg_name, type_, template_type, with_type_of=True):
|
|
|
|
|
"""Check whether some type is subclass of another type"""
|
|
|
|
|
if not isinstance(template_type, Iterable):
|
|
|
|
|
template_type = (template_type,)
|
|
|
|
|
if not any([mstype.issubclass_(type_, x) for x in template_type]):
|
|
|
|
|
type_str = (type(type_).__name__ if isinstance(type_, (tuple, list)) else "") + str(type_)
|
|
|
|
|
raise TypeError(f'The {"type of" if with_type_of else ""} `{arg_name}` should be subclass'
|
|
|
|
|
f' of {",".join((str(x) for x in template_type))}, but got {type_str}.')
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def check_args_tensor(args):
|
|
|
|
|
"""Check whether args are all tensor."""
|
|
|
|
|
if not isinstance(args, dict):
|
|
|
|
|
raise TypeError("The args should be a dict.")
|
|
|
|
|
for arg, value in args.items():
|
|
|
|
|
ParamValidator.check_subclass(arg, value, mstype.tensor)
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def check_bool(arg_name, arg_value):
|
|
|
|
|
"""Check arg isinstance of bool"""
|
|
|
|
|
if not isinstance(arg_value, bool):
|
|
|
|
|
raise ValueError(f'The `{arg_name}` should be isinstance of bool, but got {arg_value}.')
|
|
|
|
|
return arg_value
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def check_type(arg_name, arg_value, valid_types):
|
|
|
|
|
"""Type checking."""
|
|
|
|
|
def raise_error_msg():
|
|
|
|
|
"""func for raising error message when check failed"""
|
|
|
|
|
type_names = [t.__name__ for t in valid_types]
|
|
|
|
|
num_types = len(valid_types)
|
|
|
|
|
raise TypeError(f'The type of `{arg_name}` should be {"one of " if num_types > 1 else ""}'
|
|
|
|
|
f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.')
|
|
|
|
|
|
|
|
|
|
if isinstance(arg_value, type(mstype.tensor)):
|
|
|
|
|
arg_value = arg_value.element_type()
|
|
|
|
|
# Notice: bool is subclass of int, so `check_type('x', True, [int])` will check fail, and
|
|
|
|
|
# `check_type('x', True, [bool, int])` will check pass
|
|
|
|
|
if isinstance(arg_value, bool) and bool not in tuple(valid_types):
|
|
|
|
|
raise_error_msg()
|
|
|
|
|
if isinstance(arg_value, tuple(valid_types)):
|
|
|
|
|
return arg_value
|
|
|
|
|
raise_error_msg()
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def check_typename(arg_name, arg_type, valid_types):
|
|
|
|
|
"""Does it contain the _name_ attribute."""
|
|
|
|
|
|
|
|
|
|
def get_typename(t):
|
|
|
|
|
return t.__name__ if hasattr(t, '__name__') else str(t)
|
|
|
|
|
|
|
|
|
|
if isinstance(arg_type, type(mstype.tensor)):
|
|
|
|
|
arg_type = arg_type.element_type()
|
|
|
|
|
|
|
|
|
|
if arg_type in valid_types:
|
|
|
|
|
return arg_type
|
|
|
|
|
type_names = [get_typename(t) for t in valid_types]
|
|
|
|
|
if len(valid_types) == 1:
|
|
|
|
|
raise ValueError(f'The type of `{arg_name}` should be {type_names[0]},'
|
|
|
|
|
f' but got {get_typename(arg_type)}.')
|
|
|
|
|
raise ValueError(f'The type of `{arg_name}` should be one of {type_names},'
|
|
|
|
|
f' but got {get_typename(arg_type)}.')
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def check_string(arg_name, arg_value, valid_values):
|
|
|
|
|
"""String type judgment."""
|
|
|
|
|
if isinstance(arg_value, str) and arg_value in valid_values:
|
|
|
|
|
return arg_value
|
|
|
|
|
if len(valid_values) == 1:
|
|
|
|
|
raise ValueError(f'The `{arg_name}` should be str and must be {valid_values[0]},'
|
|
|
|
|
f' but got {arg_value}.')
|
|
|
|
|
raise ValueError(f'The `{arg_name}` should be str and must be one of {valid_values},'
|
|
|
|
|
f' but got {arg_value}.')
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def check_type_same(args, valid_values):
|
|
|
|
|
"""Determine whether the types are the same."""
|
|
|
|
|
name = list(args.keys())[0]
|
|
|
|
|
value = list(args.values())[0]
|
|
|
|
|
if isinstance(value, type(mstype.tensor)):
|
|
|
|
|
value = value.element_type()
|
|
|
|
|
for arg_name, arg_value in args.items():
|
|
|
|
|
if isinstance(arg_value, type(mstype.tensor)):
|
|
|
|
|
arg_value = arg_value.element_type()
|
|
|
|
|
|
|
|
|
|
if arg_value not in valid_values:
|
|
|
|
|
raise TypeError(f'The `{arg_name}` should be in {valid_values},'
|
|
|
|
|
f' but `{arg_name}` is {arg_value}.')
|
|
|
|
|
if arg_value != value:
|
|
|
|
|
raise TypeError(f'`{arg_name}` should be same as `{name}`,'
|
|
|
|
|
f' but `{arg_name}` is {arg_value}, `{name}` is {value}.')
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def check_two_types_same(arg1_name, arg1_type, arg2_name, arg2_type):
|
|
|
|
|
"""Determine whether the types of two variables are the same."""
|
|
|
|
|
if arg1_type != arg2_type:
|
|
|
|
|
raise TypeError(f'The type of `{arg1_name}` and `{arg2_name}` should be same.')
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def check_value_on_integer(arg_name, arg_value, value, rel):
|
|
|
|
|
"""Judging integer type."""
|
|
|
|
|
rel_fn = Rel.get_fns(rel)
|
|
|
|
|
type_match = isinstance(arg_value, int)
|
|
|
|
|
if type_match and (not rel_fn(arg_value, value)):
|
|
|
|
|
rel_str = Rel.get_strs(rel).format(value)
|
|
|
|
|
raise ValueError(f'The `{arg_name}` should be an int and must {rel_str}, but got {arg_value}.')
|
|
|
|
|
return arg_value
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def check_param_equal(param1_name, param1_value, param2_name, param2_value):
|
|
|
|
|
"""Judging the equality of parameters."""
|
|
|
|
|
if param1_value != param2_value:
|
|
|
|
|
raise ValueError(f"`{param1_name}` must equal `{param2_name}`,"
|
|
|
|
|
f" but got `{param1_name}` = {param1_value},"
|
|
|
|
|
f" `{param2_name}` = {param2_value}.")
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def check_const_input(arg_name, arg_value):
|
|
|
|
|
"""Check valid value."""
|
|
|
|
|
if arg_value is None:
|
|
|
|
|
raise ValueError(f'The `{arg_name}` must be a const input, but got {arg_value}.')
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def check_float_positive(arg_name, arg_value):
|
|
|
|
|
"""Float type judgment."""
|
|
|
|
|
if isinstance(arg_value, float):
|
|
|
|
|
if arg_value > 0:
|
|
|
|
|
return arg_value
|
|
|
|
|
raise ValueError(f"The `{arg_name}` must be positive, but got {arg_value}.")
|
|
|
|
|
|
|
|
|
|
raise TypeError(f"`{arg_name}` must be float!")
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def check_pad_value_by_mode(op_name, pad_mode, padding):
|
|
|
|
|
"""Validate value of padding according to pad_mode"""
|
|
|
|
|
if pad_mode != 'pad' and padding != 0:
|
|
|
|
|
raise ValueError(f"For op '{op_name}', padding must be zero when pad_mode is '{pad_mode}'.")
|
|
|
|
|
return padding
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def check_empty_shape_input(arg_name, arg_value):
|
|
|
|
|
"""Check zeros value."""
|
|
|
|
|
if 0 in arg_value:
|
|
|
|
|
raise ValueError(f"Input `{arg_name}` cannot be empty.")
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def check_scalar_shape_input(arg_name, arg_value):
|
|
|
|
|
"""Check scalar shape input."""
|
|
|
|
|
if arg_value != []:
|
|
|
|
|
raise ValueError(f"Input `{arg_name}` shape should be (). got {arg_value}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def check_int(input_param):
|
|
|
|
|
"""Int type judgment."""
|
|
|
|
|
if isinstance(input_param, int) and not isinstance(input_param, bool):
|
|
|
|
|