|
|
|
@ -13,6 +13,7 @@
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
# ============================================================================
|
|
|
|
|
"""Check parameters."""
|
|
|
|
|
|
|
|
|
|
import re
|
|
|
|
|
import inspect
|
|
|
|
|
import math
|
|
|
|
@ -20,10 +21,9 @@ from enum import Enum
|
|
|
|
|
from functools import reduce, wraps
|
|
|
|
|
from itertools import repeat
|
|
|
|
|
from collections.abc import Iterable
|
|
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
from mindspore import log as logger
|
|
|
|
|
from .common import dtype as mstype
|
|
|
|
|
from mindspore.common import dtype as mstype
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Named string regular expression
|
|
|
|
@ -103,18 +103,17 @@ class Validator:
|
|
|
|
|
def check(arg_name, arg_value, value_name, value, rel=Rel.EQ, prim_name=None, excp_cls=ValueError):
|
|
|
|
|
"""
|
|
|
|
|
Method for judging relation between two int values or list/tuple made up of ints.
|
|
|
|
|
|
|
|
|
|
This method is not suitable for judging relation between floats, since it does not consider float error.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
rel_fn = Rel.get_fns(rel)
|
|
|
|
|
if not rel_fn(arg_value, value):
|
|
|
|
|
rel_str = Rel.get_strs(rel).format(f'{value_name}: {value}')
|
|
|
|
|
msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The"
|
|
|
|
|
raise excp_cls(f'{msg_prefix} `{arg_name}` should be {rel_str}, but got {arg_value}.')
|
|
|
|
|
return arg_value
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def check_integer(arg_name, arg_value, value, rel, prim_name):
|
|
|
|
|
def check_integer(arg_name, arg_value, value, rel, prim_name=None):
|
|
|
|
|
"""Integer value judgment."""
|
|
|
|
|
rel_fn = Rel.get_fns(rel)
|
|
|
|
|
type_mismatch = not isinstance(arg_value, int) or isinstance(arg_value, bool)
|
|
|
|
@ -135,6 +134,20 @@ class Validator:
|
|
|
|
|
raise ValueError(f'For \'{prim_name}\' the `{arg_name}` must {rel_str}, but got {arg_value}.')
|
|
|
|
|
return arg_value
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def check_isinstance(arg_name, arg_value, classes):
|
|
|
|
|
"""Check arg isinstance of classes"""
|
|
|
|
|
if not isinstance(arg_value, classes):
|
|
|
|
|
raise ValueError(f'The `{arg_name}` should be isinstance of {classes}, but got {arg_value}.')
|
|
|
|
|
return arg_value
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def check_bool(arg_name, arg_value):
|
|
|
|
|
"""Check arg isinstance of bool"""
|
|
|
|
|
if not isinstance(arg_value, bool):
|
|
|
|
|
raise ValueError(f'The `{arg_name}` should be isinstance of bool, but got {arg_value}.')
|
|
|
|
|
return arg_value
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def check_int_range(arg_name, arg_value, lower_limit, upper_limit, rel, prim_name):
|
|
|
|
|
"""Method for checking whether an int value is in some range."""
|
|
|
|
@ -208,6 +221,27 @@ class Validator:
|
|
|
|
|
"""Checks valid value."""
|
|
|
|
|
if arg_value is None:
|
|
|
|
|
raise ValueError(f'For \'{prim_name}\' the `{arg_name}` must be a const input, but got {arg_value}.')
|
|
|
|
|
return arg_value
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def check_type(arg_name, arg_value, valid_types):
|
|
|
|
|
"""Type checking."""
|
|
|
|
|
def raise_error_msg():
|
|
|
|
|
"""func for raising error message when check failed"""
|
|
|
|
|
type_names = [t.__name__ for t in valid_types]
|
|
|
|
|
num_types = len(valid_types)
|
|
|
|
|
raise TypeError(f'The type of `{arg_name}` should be {"one of " if num_types > 1 else ""}'
|
|
|
|
|
f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.')
|
|
|
|
|
|
|
|
|
|
if isinstance(arg_value, type(mstype.tensor)):
|
|
|
|
|
arg_value = arg_value.element_type()
|
|
|
|
|
# Notice: bool is subclass of int, so `check_type('x', True, [int])` will check fail, and
|
|
|
|
|
# `check_type('x', True, [bool, int])` will check pass
|
|
|
|
|
if isinstance(arg_value, bool) and bool not in tuple(valid_types):
|
|
|
|
|
raise_error_msg()
|
|
|
|
|
if isinstance(arg_value, tuple(valid_types)):
|
|
|
|
|
return arg_value
|
|
|
|
|
raise_error_msg()
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def check_type_same(args, valid_values, prim_name):
|
|
|
|
@ -239,7 +273,6 @@ class Validator:
|
|
|
|
|
def check_scalar_or_tensor_type_same(args, valid_values, prim_name, allow_mix=False):
|
|
|
|
|
"""
|
|
|
|
|
Checks whether the types of inputs are the same. If the input args are tensors, checks their element types.
|
|
|
|
|
|
|
|
|
|
If `allow_mix` is True, Tensor(float32) and float32 are type compatible, otherwise an exception will be raised.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
@ -335,63 +368,6 @@ class Validator:
|
|
|
|
|
f'{tuple(exp_shape)}, but got {shape}.')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ParamValidator:
|
|
|
|
|
"""Parameter validator. NOTICE: this class will be replaced by `class Validator`"""
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def check(arg_name, arg_value, value_name, value, rel=Rel.EQ):
|
|
|
|
|
"""This method is only used for check int values, since when compare float values,
|
|
|
|
|
we need consider float error."""
|
|
|
|
|
rel_fn = Rel.get_fns(rel)
|
|
|
|
|
if not rel_fn(arg_value, value):
|
|
|
|
|
rel_str = Rel.get_strs(rel).format(f'{value_name}: {value}')
|
|
|
|
|
raise ValueError(f'The `{arg_name}` should be {rel_str}, but got {arg_value}.')
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def check_integer(arg_name, arg_value, value, rel):
|
|
|
|
|
"""Integer value judgment."""
|
|
|
|
|
rel_fn = Rel.get_fns(rel)
|
|
|
|
|
type_mismatch = not isinstance(arg_value, int) or isinstance(arg_value, bool)
|
|
|
|
|
if type_mismatch or not rel_fn(arg_value, value):
|
|
|
|
|
rel_str = Rel.get_strs(rel).format(value)
|
|
|
|
|
raise ValueError(f'The `{arg_name}` should be an int and must {rel_str}, but got {arg_value}.')
|
|
|
|
|
return arg_value
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def check_isinstance(arg_name, arg_value, classes):
|
|
|
|
|
"""Check arg isinstance of classes"""
|
|
|
|
|
if not isinstance(arg_value, classes):
|
|
|
|
|
raise ValueError(f'The `{arg_name}` should be isinstance of {classes}, but got {arg_value}.')
|
|
|
|
|
return arg_value
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def check_bool(arg_name, arg_value):
|
|
|
|
|
"""Check arg isinstance of bool"""
|
|
|
|
|
if not isinstance(arg_value, bool):
|
|
|
|
|
raise ValueError(f'The `{arg_name}` should be isinstance of bool, but got {arg_value}.')
|
|
|
|
|
return arg_value
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def check_type(arg_name, arg_value, valid_types):
|
|
|
|
|
"""Type checking."""
|
|
|
|
|
def raise_error_msg():
|
|
|
|
|
"""func for raising error message when check failed"""
|
|
|
|
|
type_names = [t.__name__ for t in valid_types]
|
|
|
|
|
num_types = len(valid_types)
|
|
|
|
|
raise TypeError(f'The type of `{arg_name}` should be {"one of " if num_types > 1 else ""}'
|
|
|
|
|
f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.')
|
|
|
|
|
|
|
|
|
|
if isinstance(arg_value, type(mstype.tensor)):
|
|
|
|
|
arg_value = arg_value.element_type()
|
|
|
|
|
# Notice: bool is subclass of int, so `check_type('x', True, [int])` will check fail, and
|
|
|
|
|
# `check_type('x', True, [bool, int])` will check pass
|
|
|
|
|
if isinstance(arg_value, bool) and bool not in tuple(valid_types):
|
|
|
|
|
raise_error_msg()
|
|
|
|
|
if isinstance(arg_value, tuple(valid_types)):
|
|
|
|
|
return arg_value
|
|
|
|
|
raise_error_msg()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def check_int(input_param):
|
|
|
|
|
"""Int type judgment."""
|
|
|
|
|
if isinstance(input_param, int) and not isinstance(input_param, bool):
|
|
|
|
@ -638,7 +614,6 @@ def args_type_check(*type_args, **type_kwargs):
|
|
|
|
|
if value is not None and not isinstance(value, bound_types[name]):
|
|
|
|
|
raise TypeError('Argument {} must be {}'.format(name, bound_types[name]))
|
|
|
|
|
return func(*args, **kwargs)
|
|
|
|
|
|
|
|
|
|
return wrapper
|
|
|
|
|
|
|
|
|
|
return type_check
|
|
|
|
|