|
|
|
@ -17,7 +17,7 @@ import re
|
|
|
|
|
from enum import Enum
|
|
|
|
|
from functools import reduce
|
|
|
|
|
from itertools import repeat
|
|
|
|
|
from collections import Iterable
|
|
|
|
|
from collections.abc import Iterable
|
|
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
from mindspore import log as logger
|
|
|
|
@ -98,7 +98,7 @@ class Validator:
|
|
|
|
|
"""validator for checking input parameters"""
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def check(arg_name, arg_value, value_name, value, rel=Rel.EQ, prim_name=None):
|
|
|
|
|
def check(arg_name, arg_value, value_name, value, rel=Rel.EQ, prim_name=None, excp_cls=ValueError):
|
|
|
|
|
"""
|
|
|
|
|
Method for judging relation between two int values or list/tuple made up of ints.
|
|
|
|
|
|
|
|
|
@ -108,8 +108,8 @@ class Validator:
|
|
|
|
|
rel_fn = Rel.get_fns(rel)
|
|
|
|
|
if not rel_fn(arg_value, value):
|
|
|
|
|
rel_str = Rel.get_strs(rel).format(f'{value_name}: {value}')
|
|
|
|
|
msg_prefix = f'For {prim_name} the' if prim_name else "The"
|
|
|
|
|
raise ValueError(f'{msg_prefix} `{arg_name}` should be {rel_str}, but got {arg_value}.')
|
|
|
|
|
msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The"
|
|
|
|
|
raise excp_cls(f'{msg_prefix} `{arg_name}` should be {rel_str}, but got {arg_value}.')
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def check_integer(arg_name, arg_value, value, rel, prim_name):
|
|
|
|
@ -118,8 +118,17 @@ class Validator:
|
|
|
|
|
type_mismatch = not isinstance(arg_value, int) or isinstance(arg_value, bool)
|
|
|
|
|
if type_mismatch or not rel_fn(arg_value, value):
|
|
|
|
|
rel_str = Rel.get_strs(rel).format(value)
|
|
|
|
|
raise ValueError(f'For {prim_name} the `{arg_name}` should be an int and must {rel_str},'
|
|
|
|
|
f' but got {arg_value}.')
|
|
|
|
|
msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The"
|
|
|
|
|
raise ValueError(f'{msg_prefix} `{arg_name}` should be an int and must {rel_str}, but got {arg_value}.')
|
|
|
|
|
return arg_value
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def check_number(arg_name, arg_value, value, rel, prim_name):
|
|
|
|
|
"""Integer value judgment."""
|
|
|
|
|
rel_fn = Rel.get_fns(rel)
|
|
|
|
|
if not rel_fn(arg_value, value):
|
|
|
|
|
rel_str = Rel.get_strs(rel).format(value)
|
|
|
|
|
raise ValueError(f'For \'{prim_name}\' the `{arg_name}` must {rel_str}, but got {arg_value}.')
|
|
|
|
|
return arg_value
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
@ -133,9 +142,46 @@ class Validator:
|
|
|
|
|
f' but got {arg_value}.')
|
|
|
|
|
return arg_value
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def check_number_range(arg_name, arg_value, lower_limit, upper_limit, rel, prim_name):
|
|
|
|
|
"""Method for checking whether a numeric value is in some range."""
|
|
|
|
|
rel_fn = Rel.get_fns(rel)
|
|
|
|
|
if not rel_fn(arg_value, lower_limit, upper_limit):
|
|
|
|
|
rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit)
|
|
|
|
|
raise ValueError(f'For \'{prim_name}\' the `{arg_name}` should be in range {rel_str}, but got {arg_value}.')
|
|
|
|
|
return arg_value
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def check_string(arg_name, arg_value, valid_values, prim_name):
|
|
|
|
|
"""Checks whether a string is in some value list"""
|
|
|
|
|
if isinstance(arg_value, str) and arg_value in valid_values:
|
|
|
|
|
return arg_value
|
|
|
|
|
if len(valid_values) == 1:
|
|
|
|
|
raise ValueError(f'For \'{prim_name}\' the `{arg_name}` should be str and must be {valid_values[0]},'
|
|
|
|
|
f' but got {arg_value}.')
|
|
|
|
|
raise ValueError(f'For \'{prim_name}\' the `{arg_name}` should be str and must be one of {valid_values},'
|
|
|
|
|
f' but got {arg_value}.')
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def check_pad_value_by_mode(pad_mode, padding, prim_name):
|
|
|
|
|
"""Validates value of padding according to pad_mode"""
|
|
|
|
|
if pad_mode != 'pad' and padding != 0:
|
|
|
|
|
raise ValueError(f"For '{prim_name}', padding must be zero when pad_mode is '{pad_mode}'.")
|
|
|
|
|
return padding
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def check_float_positive(arg_name, arg_value, prim_name):
|
|
|
|
|
"""Float type judgment."""
|
|
|
|
|
msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The"
|
|
|
|
|
if isinstance(arg_value, float):
|
|
|
|
|
if arg_value > 0:
|
|
|
|
|
return arg_value
|
|
|
|
|
raise ValueError(f"{msg_prefix} `{arg_name}` must be positive, but got {arg_value}.")
|
|
|
|
|
raise TypeError(f"{msg_prefix} `{arg_name}` must be float.")
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def check_subclass(arg_name, type_, template_type, prim_name):
|
|
|
|
|
"""Check whether some type is sublcass of another type"""
|
|
|
|
|
"""Checks whether some type is sublcass of another type"""
|
|
|
|
|
if not isinstance(template_type, Iterable):
|
|
|
|
|
template_type = (template_type,)
|
|
|
|
|
if not any([mstype.issubclass_(type_, x) for x in template_type]):
|
|
|
|
@ -143,16 +189,44 @@ class Validator:
|
|
|
|
|
raise TypeError(f'For \'{prim_name}\' the type of `{arg_name}` should be subclass'
|
|
|
|
|
f' of {",".join((str(x) for x in template_type))}, but got {type_str}.')
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def check_const_input(arg_name, arg_value, prim_name):
|
|
|
|
|
"""Check valid value."""
|
|
|
|
|
if arg_value is None:
|
|
|
|
|
raise ValueError(f'For \'{prim_name}\' the `{arg_name}` must be a const input, but got {arg_value}.')
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def check_scalar_type_same(args, valid_values, prim_name):
|
|
|
|
|
"""check whether the types of inputs are the same."""
|
|
|
|
|
def _check_tensor_type(arg):
|
|
|
|
|
arg_key, arg_val = arg
|
|
|
|
|
elem_type = arg_val
|
|
|
|
|
if not elem_type in valid_values:
|
|
|
|
|
raise TypeError(f'For \'{prim_name}\' type of `{arg_key}` should be in {valid_values},'
|
|
|
|
|
f' but `{arg_key}` is {elem_type}.')
|
|
|
|
|
return (arg_key, elem_type)
|
|
|
|
|
|
|
|
|
|
def _check_types_same(arg1, arg2):
|
|
|
|
|
arg1_name, arg1_type = arg1
|
|
|
|
|
arg2_name, arg2_type = arg2
|
|
|
|
|
if arg1_type != arg2_type:
|
|
|
|
|
raise TypeError(f'For \'{prim_name}\' type of `{arg2_name}` should be same as `{arg1_name}`,'
|
|
|
|
|
f' but `{arg1_name}` is {arg1_type} and `{arg2_name}` is {arg2_type}.')
|
|
|
|
|
return arg1
|
|
|
|
|
|
|
|
|
|
elem_types = map(_check_tensor_type, args.items())
|
|
|
|
|
reduce(_check_types_same, elem_types)
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def check_tensor_type_same(args, valid_values, prim_name):
|
|
|
|
|
"""check whether the element types of input tensors are the same."""
|
|
|
|
|
"""Checks whether the element types of input tensors are the same."""
|
|
|
|
|
def _check_tensor_type(arg):
|
|
|
|
|
arg_key, arg_val = arg
|
|
|
|
|
Validator.check_subclass(arg_key, arg_val, mstype.tensor, prim_name)
|
|
|
|
|
elem_type = arg_val.element_type()
|
|
|
|
|
if not elem_type in valid_values:
|
|
|
|
|
raise TypeError(f'For \'{prim_name}\' element type of `{arg_key}` should be in {valid_values},'
|
|
|
|
|
f' but `{arg_key}` is {elem_type}.')
|
|
|
|
|
f' but element type of `{arg_key}` is {elem_type}.')
|
|
|
|
|
return (arg_key, elem_type)
|
|
|
|
|
|
|
|
|
|
def _check_types_same(arg1, arg2):
|
|
|
|
@ -168,8 +242,13 @@ class Validator:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def check_scalar_or_tensor_type_same(args, valid_values, prim_name):
|
|
|
|
|
"""check whether the types of inputs are the same. if the input args are tensors, check their element types"""
|
|
|
|
|
def check_scalar_or_tensor_type_same(args, valid_values, prim_name, allow_mix=False):
|
|
|
|
|
"""
|
|
|
|
|
Checks whether the types of inputs are the same. If the input args are tensors, checks their element types.
|
|
|
|
|
|
|
|
|
|
If `allow_mix` is True, Tensor(float32) and float32 are type compatible, otherwise an exception will be raised.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def _check_argument_type(arg):
|
|
|
|
|
arg_key, arg_val = arg
|
|
|
|
|
if isinstance(arg_val, type(mstype.tensor)):
|
|
|
|
@ -188,6 +267,9 @@ class Validator:
|
|
|
|
|
arg2_type = arg2_type.element_type()
|
|
|
|
|
elif not (isinstance(arg1_type, type(mstype.tensor)) or isinstance(arg2_type, type(mstype.tensor))):
|
|
|
|
|
pass
|
|
|
|
|
elif allow_mix:
|
|
|
|
|
arg1_type = arg1_type.element_type() if isinstance(arg1_type, type(mstype.tensor)) else arg1_type
|
|
|
|
|
arg2_type = arg2_type.element_type() if isinstance(arg2_type, type(mstype.tensor)) else arg2_type
|
|
|
|
|
else:
|
|
|
|
|
excp_flag = True
|
|
|
|
|
|
|
|
|
@ -199,13 +281,14 @@ class Validator:
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def check_value_type(arg_name, arg_value, valid_types, prim_name):
|
|
|
|
|
"""Check whether a values is instance of some types."""
|
|
|
|
|
"""Checks whether a value is instance of some types."""
|
|
|
|
|
valid_types = valid_types if isinstance(valid_types, Iterable) else (valid_types,)
|
|
|
|
|
def raise_error_msg():
|
|
|
|
|
"""func for raising error message when check failed"""
|
|
|
|
|
type_names = [t.__name__ for t in valid_types]
|
|
|
|
|
num_types = len(valid_types)
|
|
|
|
|
raise TypeError(f'For \'{prim_name}\' the type of `{arg_name}` should be '
|
|
|
|
|
f'{"one of " if num_types > 1 else ""}'
|
|
|
|
|
msg_prefix = f'For \'{prim_name}\' the' if prim_name else 'The'
|
|
|
|
|
raise TypeError(f'{msg_prefix} type of `{arg_name}` should be {"one of " if num_types > 1 else ""}'
|
|
|
|
|
f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.')
|
|
|
|
|
|
|
|
|
|
# Notice: bool is subclass of int, so `check_value_type('x', True, [int])` will check fail, and
|
|
|
|
@ -216,6 +299,23 @@ class Validator:
|
|
|
|
|
return arg_value
|
|
|
|
|
raise_error_msg()
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def check_type_name(arg_name, arg_type, valid_types, prim_name):
|
|
|
|
|
"""Checks whether a type in some specified types"""
|
|
|
|
|
valid_types = valid_types if isinstance(valid_types, Iterable) else (valid_types,)
|
|
|
|
|
def get_typename(t):
|
|
|
|
|
return t.__name__ if hasattr(t, '__name__') else str(t)
|
|
|
|
|
|
|
|
|
|
if arg_type in valid_types:
|
|
|
|
|
return arg_type
|
|
|
|
|
type_names = [get_typename(t) for t in valid_types]
|
|
|
|
|
msg_prefix = f'For \'{prim_name}\' the' if prim_name else 'The'
|
|
|
|
|
if len(valid_types) == 1:
|
|
|
|
|
raise ValueError(f'{msg_prefix} type of `{arg_name}` should be {type_names[0]},'
|
|
|
|
|
f' but got {get_typename(arg_type)}.')
|
|
|
|
|
raise ValueError(f'{msg_prefix} type of `{arg_name}` should be one of {type_names},'
|
|
|
|
|
f' but got {get_typename(arg_type)}.')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ParamValidator:
|
|
|
|
|
"""Parameter validator. NOTICE: this class will be replaced by `class Validator`"""
|
|
|
|
|