|
|
|
@ -118,10 +118,12 @@ class Validator:
|
|
|
|
|
"""Integer value judgment."""
|
|
|
|
|
rel_fn = Rel.get_fns(rel)
|
|
|
|
|
type_mismatch = not isinstance(arg_value, int) or isinstance(arg_value, bool)
|
|
|
|
|
excp_cls = TypeError if type_mismatch else ValueError
|
|
|
|
|
if type_mismatch or not rel_fn(arg_value, value):
|
|
|
|
|
rel_str = Rel.get_strs(rel).format(value)
|
|
|
|
|
msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The"
|
|
|
|
|
raise ValueError(f'{msg_prefix} `{arg_name}` should be an int and must {rel_str}, but got {arg_value}.')
|
|
|
|
|
raise excp_cls(f'{msg_prefix} `{arg_name}` should be an int and must {rel_str}, but got `{arg_value}`'
|
|
|
|
|
f' with type `{type(arg_value).__name__}`.')
|
|
|
|
|
return arg_value
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
@ -138,10 +140,11 @@ class Validator:
|
|
|
|
|
"""Method for checking whether an int value is in some range."""
|
|
|
|
|
rel_fn = Rel.get_fns(rel)
|
|
|
|
|
type_mismatch = not isinstance(arg_value, int)
|
|
|
|
|
excp_cls = TypeError if type_mismatch else ValueError
|
|
|
|
|
if type_mismatch or not rel_fn(arg_value, lower_limit, upper_limit):
|
|
|
|
|
rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit)
|
|
|
|
|
raise ValueError(f'For \'{prim_name}\' the `{arg_name}` should be an int in range {rel_str},'
|
|
|
|
|
f' but got {arg_value}.')
|
|
|
|
|
raise excp_cls(f'For \'{prim_name}\' the `{arg_name}` should be an int in range {rel_str},'
|
|
|
|
|
f' but got `{arg_value}` with type `{type(arg_value).__name__}`.')
|
|
|
|
|
return arg_value
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
@ -193,19 +196,23 @@ class Validator:
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def check_const_input(arg_name, arg_value, prim_name):
|
|
|
|
|
"""Check valid value."""
|
|
|
|
|
"""Checks valid value."""
|
|
|
|
|
if arg_value is None:
|
|
|
|
|
raise ValueError(f'For \'{prim_name}\' the `{arg_name}` must be a const input, but got {arg_value}.')
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def check_scalar_type_same(args, valid_values, prim_name):
|
|
|
|
|
"""check whether the types of inputs are the same."""
|
|
|
|
|
def check_type_same(args, valid_values, prim_name):
|
|
|
|
|
"""Checks whether the types of inputs are the same."""
|
|
|
|
|
def _check_tensor_type(arg):
|
|
|
|
|
arg_key, arg_val = arg
|
|
|
|
|
elem_type = arg_val
|
|
|
|
|
type_names = []
|
|
|
|
|
if not elem_type in valid_values:
|
|
|
|
|
raise TypeError(f'For \'{prim_name}\' type of `{arg_key}` should be in {valid_values},'
|
|
|
|
|
f' but `{arg_key}` is {elem_type}.')
|
|
|
|
|
for t in valid_values:
|
|
|
|
|
type_names.append(str(t))
|
|
|
|
|
types_info = '[' + ", ".join(type_names) + ']'
|
|
|
|
|
raise TypeError(f'For \'{prim_name}\' type of `{arg_key}` should be in {types_info},'
|
|
|
|
|
f' but got {elem_type}.')
|
|
|
|
|
return (arg_key, elem_type)
|
|
|
|
|
|
|
|
|
|
def _check_types_same(arg1, arg2):
|
|
|
|
@ -213,7 +220,7 @@ class Validator:
|
|
|
|
|
arg2_name, arg2_type = arg2
|
|
|
|
|
if arg1_type != arg2_type:
|
|
|
|
|
raise TypeError(f'For \'{prim_name}\' type of `{arg2_name}` should be same as `{arg1_name}`,'
|
|
|
|
|
f' but `{arg1_name}` is {arg1_type} and `{arg2_name}` is {arg2_type}.')
|
|
|
|
|
f' but `{arg1_name}` with type {arg1_type} and `{arg2_name}` with type {arg2_type}.')
|
|
|
|
|
return arg1
|
|
|
|
|
|
|
|
|
|
elem_types = map(_check_tensor_type, args.items())
|
|
|
|
@ -222,25 +229,8 @@ class Validator:
|
|
|
|
|
@staticmethod
|
|
|
|
|
def check_tensor_type_same(args, valid_values, prim_name):
|
|
|
|
|
"""Checks whether the element types of input tensors are the same."""
|
|
|
|
|
def _check_tensor_type(arg):
|
|
|
|
|
arg_key, arg_val = arg
|
|
|
|
|
Validator.check_subclass(arg_key, arg_val, mstype.tensor, prim_name)
|
|
|
|
|
elem_type = arg_val.element_type()
|
|
|
|
|
if not elem_type in valid_values:
|
|
|
|
|
raise TypeError(f'For \'{prim_name}\' element type of `{arg_key}` should be in {valid_values},'
|
|
|
|
|
f' but element type of `{arg_key}` is {elem_type}.')
|
|
|
|
|
return (arg_key, elem_type)
|
|
|
|
|
|
|
|
|
|
def _check_types_same(arg1, arg2):
|
|
|
|
|
arg1_name, arg1_type = arg1
|
|
|
|
|
arg2_name, arg2_type = arg2
|
|
|
|
|
if arg1_type != arg2_type:
|
|
|
|
|
raise TypeError(f'For \'{prim_name}\' element type of `{arg2_name}` should be same as `{arg1_name}`,'
|
|
|
|
|
f' but `{arg1_name}` is {arg1_type} and `{arg2_name}` is {arg2_type}.')
|
|
|
|
|
return arg1
|
|
|
|
|
|
|
|
|
|
elem_types = map(_check_tensor_type, args.items())
|
|
|
|
|
reduce(_check_types_same, elem_types)
|
|
|
|
|
tensor_types = [mstype.tensor_type(t) for t in valid_values]
|
|
|
|
|
Validator.check_type_same(args, tensor_types, prim_name)
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def check_scalar_or_tensor_type_same(args, valid_values, prim_name, allow_mix=False):
|
|
|
|
|