diff --git a/mindspore/_checkparam.py b/mindspore/_checkparam.py index de3e598283..9c19d6833e 100644 --- a/mindspore/_checkparam.py +++ b/mindspore/_checkparam.py @@ -97,7 +97,7 @@ def check_number(arg_value, value, rel, arg_type=int, arg_name=None, prim_name=N Check argument integer. Usage: - - number = check_integer(number, 0, Rel.GE, "number", None) # number >= 0 + - number = check_int(number, 0, Rel.GE, "number", None) # number >= 0 """ rel_fn = Rel.get_fns(rel) type_mismatch = not isinstance(arg_value, arg_type) or isinstance(arg_value, bool) @@ -166,12 +166,12 @@ class Validator: return arg_value @staticmethod - def check_integer(arg_name, arg_value, value, rel, prim_name=None): + def check_int(arg_value, value, rel, arg_name=None, prim_name=None): """ Checks input integer value `arg_value` compare to `value`. Usage: - - number = check_integer(number, 0, Rel.GE, "number", None) # number >= 0 + - number = check_int(number, 0, Rel.GE, "number", None) # number >= 0 """ return check_number(arg_value, value, rel, int, arg_name, prim_name) @@ -187,6 +187,16 @@ class Validator: """ return check_is_number(arg_value, int, arg_name, prim_name) + @staticmethod + def check_equal_int(arg_value, value, arg_name=None, prim_name=None): + """ + Checks input integer value `arg_value` compare to `value`. + + Usage: + - number = check_int(number, 0, Rel.GE, "number", None) # number >= 0 + """ + return check_number(arg_value, value, Rel.EQ, int, arg_name, prim_name) + @staticmethod def check_positive_int(arg_value, arg_name=None, prim_name=None): """ @@ -365,6 +375,17 @@ class Validator: raise ValueError(f'{msg_prefix} `{arg_name}` should be str and must be in `{valid_values}`,' f' but got `{arg_value}`.') + @staticmethod + def check_str_by_regular(target, reg=None, flag=re.ASCII, prim_name=None): + if reg is None: + # Named string regular expression + reg = r"^\w+[0-9a-zA-Z\_\.]*$" + if re.match(reg, target, flag) is None: + prim_name = f'in `{prim_name}`' if prim_name else "" + raise ValueError("'{}' {} is illegal, it should be match regular'{}' by flags'{}'".format( + target, prim_name, reg, flag)) + return True + @staticmethod def check_pad_value_by_mode(pad_mode, padding, prim_name): """Validates value of padding according to pad_mode""" @@ -530,13 +551,6 @@ class Validator: f'{tuple(exp_shape)}, but got {shape}.') -def check_int_zero_one(input_param): - """Judge whether it is 0 or 1.""" - if input_param in (0, 1): - return input_param - raise ValueError("The data must be 0 or 1.") - - def check_input_format(input_param): """Judge input format.""" if input_param == "NCHW": @@ -544,27 +558,6 @@ def check_input_format(input_param): raise ValueError("The data format must be NCHW.") -def check_padding(padding): - """Check padding.""" - if padding >= 0: - return padding - raise ValueError("The padding must be at least 0,"" but got padding {}.".format(padding)) - - -def check_padmode(mode): - """Check padmode.""" - if mode in ("same", "valid", "pad"): - return mode - raise ValueError("The pad mode must be same or valid or pad,"" but got mode {}.".format(mode)) - - -def check_tensor_supported_type(dtype): - """Check tensor dtype.""" - if dtype in (mstype.int32, mstype.float32): - return dtype - raise ValueError("The dtype must be mstype.int32 or mstype.float32, but got mstype {}.".format(dtype)) - - def _expand_tuple(n_dimensions): """To expand a number to tuple.""" @@ -673,42 +666,6 @@ def check_typename(arg_name, arg_type, valid_types): f' but got {get_typename(arg_type)}.') -def check_shape(arg_name, arg_value): - """Check shape.""" - # First, check if shape is a tuple - if not isinstance(arg_value, tuple): - raise TypeError(f'The type of `{arg_name}` should be one of {tuple.__name__},' - f' but got {type(arg_value).__name__}.') - - # Second, wrap arg_value with numpy array so that it can be checked through numpy api - arg_value = np.array(arg_value) - - # shape can not be () - if arg_value.size == 0: - raise ValueError('Shape can not be empty.') - - # shape's dimension should be 1 - if arg_value.ndim != 1: - raise ValueError('Shape of tensor should be 1-dim vector, but got {}-dim.'.format(arg_value.ndim)) - - # Thirdly, check each element's type of the shape - valid_types = (int, np.int8, np.int16, np.int32, np.int64, - np.uint8, np.uint16, np.uint32, np.uint64) - for dim_size in arg_value: - if not isinstance(dim_size, valid_types) or dim_size <= 0: - raise ValueError('Every dimension size of the tensor shape should be a positive integer,' - ' but got {}.'.format(dim_size)) - - -def _check_str_by_regular(target, reg=None, flag=re.ASCII): - if reg is None: - # Named string regular expression - reg = r"^\w+[0-9a-zA-Z\_\.]*$" - if re.match(reg, target, flag) is None: - raise ValueError("'{}' is illegal, it should be match regular'{}' by flags'{}'".format(target, reg, flag)) - return True - - def args_type_check(*type_args, **type_kwargs): """Check whether input data type is correct.""" diff --git a/mindspore/common/parameter.py b/mindspore/common/parameter.py index 6ad4cb19d2..c41274dee2 100644 --- a/mindspore/common/parameter.py +++ b/mindspore/common/parameter.py @@ -19,7 +19,7 @@ from .._c_expression import ParamInfo from . import dtype as mstype from .initializer import initializer, Initializer from .tensor import Tensor, MetaTensor -from .._checkparam import _check_str_by_regular +from .._checkparam import Validator from ..parallel._tensor import _get_slice_index from ..parallel._auto_parallel_context import auto_parallel_context from ..parallel._ps_context import _is_role_worker, _is_role_pserver, _is_role_sched @@ -263,7 +263,7 @@ class Parameter(MetaTensor): Returns: Parameter, a new parameter. """ - _check_str_by_regular(prefix) + Validator.check_str_by_regular(prefix) x = copy(self) # pylint: disable=protected-access x._param_info = self._param_info.clone() @@ -446,7 +446,7 @@ class ParameterTuple(tuple): Returns: Tuple, the new Parameter tuple. """ - _check_str_by_regular(prefix) + Validator.check_str_by_regular(prefix) new = [] for x in self: x1 = x.clone(prefix, init) diff --git a/mindspore/context.py b/mindspore/context.py index 6532f89cea..44daf6d464 100644 --- a/mindspore/context.py +++ b/mindspore/context.py @@ -23,7 +23,7 @@ from collections import namedtuple from types import FunctionType from mindspore import log as logger from mindspore._c_expression import MSContext, ms_ctx_param -from mindspore._checkparam import args_type_check +from mindspore._checkparam import args_type_check, Validator from mindspore.parallel._auto_parallel_context import _set_auto_parallel_context, _get_auto_parallel_context, \ _reset_auto_parallel_context from mindspore.parallel._ps_context import _set_ps_context, _get_ps_context, _reset_ps_context @@ -35,9 +35,9 @@ __all__ = ['GRAPH_MODE', 'PYNATIVE_MODE', 'set_context', 'get_context', 'set_aut GRAPH_MODE = 0 PYNATIVE_MODE = 1 -# The max memory size of graph plus variable. -_DEVICE_APP_MEMORY_SIZE = 31 - +_DEVICE_APP_MEMORY_SIZE = 31 # The max memory size of graph plus variable. +_re_pattern = r'[1-9][0-9]*(\.)?[0-9]*GB|0\.[0-9]*GB' +_k_context = None def _make_directory(path): """Make directory.""" @@ -223,7 +223,7 @@ class _Context: def set_variable_memory_max_size(self, variable_memory_max_size): """set values of variable_memory_max_size and graph_memory_max_size""" - if not _check_input_format(variable_memory_max_size): + if not Validator.check_str_by_regular(variable_memory_max_size, _re_pattern): raise ValueError("Context param variable_memory_max_size should be in correct format! Such as \"5GB\"") if int(variable_memory_max_size[:-2]) >= _DEVICE_APP_MEMORY_SIZE: raise ValueError("Context param variable_memory_max_size should be less than 31GB.") @@ -235,7 +235,7 @@ class _Context: self.set_param(ms_ctx_param._graph_memory_max_size, graph_memory_max_size_) def set_max_device_memory(self, max_device_memory): - if not _check_input_format(max_device_memory): + if not Validator.check_str_by_regular(max_device_memory, _re_pattern): raise ValueError("Context param max_device_memory should be in correct format! Such as \"3.5GB\"") max_device_memory_value = float(max_device_memory[:-2]) if max_device_memory_value == 0: @@ -294,16 +294,6 @@ class _Context: thread_info.debug_runtime = enable -def _check_input_format(x): - import re - pattern = r'[1-9][0-9]*(\.)?[0-9]*GB|0\.[0-9]*GB' - result = re.match(pattern, x) - return result is not None - - -_k_context = None - - def _context(): """ Get the global _context, if context is not created, create a new one. diff --git a/mindspore/nn/cell.py b/mindspore/nn/cell.py index 1330fc3474..00f67bc791 100755 --- a/mindspore/nn/cell.py +++ b/mindspore/nn/cell.py @@ -23,7 +23,7 @@ from mindspore import log as logger from .. import context from ..common import dtype as mstype from ..common.api import _executor, _pynative_exec -from .._checkparam import _check_str_by_regular +from .._checkparam import Validator from ..common.parameter import Parameter, ParameterTuple from .._c_expression import init_backend, Cell_ from ..ops.primitive import Primitive @@ -715,7 +715,7 @@ class Cell(Cell_): recurse (bool): Whether contains the parameters of subcells. Default: True. """ - _check_str_by_regular(prefix) + Validator.check_str_by_regular(prefix) for name, param in self.parameters_and_names(expand=recurse): if prefix != '': param.is_init = False diff --git a/mindspore/nn/layer/basic.py b/mindspore/nn/layer/basic.py index 5277f37741..b9f18d5494 100644 --- a/mindspore/nn/layer/basic.py +++ b/mindspore/nn/layer/basic.py @@ -549,7 +549,7 @@ class Unfold(Cell): @constexpr def _get_matrix_diag_assist(x_shape, x_dtype): - Validator.check_integer("x rank", len(x_shape), 1, Rel.GE, "_get_matrix_diag_assist") + Validator.check_int(len(x_shape), 1, Rel.GE, "x rank", "_get_matrix_diag_assist") base_eye = np.eye(x_shape[-1], x_shape[-1]).reshape(-1) assist = np.tile(base_eye, x_shape[:-1]).reshape(x_shape + (x_shape[-1],)) return Tensor(assist, x_dtype) @@ -557,7 +557,7 @@ def _get_matrix_diag_assist(x_shape, x_dtype): @constexpr def _get_matrix_diag_part_assist(x_shape, x_dtype): - Validator.check_integer("x rank", len(x_shape), 2, Rel.GE, "_get_matrix_diag_part_assist") + Validator.check_int(len(x_shape), 2, Rel.GE, "x rank", "_get_matrix_diag_part_assist") base_eye = np.eye(x_shape[-2], x_shape[-1]).reshape(-1) assist = np.tile(base_eye, x_shape[:-2]).reshape(x_shape) return Tensor(assist, x_dtype) diff --git a/mindspore/nn/layer/conv.py b/mindspore/nn/layer/conv.py index 2e661a518d..793d338da4 100644 --- a/mindspore/nn/layer/conv.py +++ b/mindspore/nn/layer/conv.py @@ -239,8 +239,8 @@ class Conv2d(_Conv): """Initialize depthwise conv2d op""" if context.get_context("device_target") == "Ascend" and self.group > 1: self.dilation = self._dilation - Validator.check_integer('group', self.group, self.in_channels, Rel.EQ) - Validator.check_integer('group', self.group, self.out_channels, Rel.EQ) + Validator.check_equal_int(self.group, self.in_channels, 'group') + Validator.check_equal_int(self.group, self.out_channels, 'group') self.conv2d = P.DepthwiseConv2dNative(channel_multiplier=1, kernel_size=self.kernel_size, pad_mode=self.pad_mode, @@ -384,10 +384,10 @@ class Conv1d(_Conv): Validator.check_value_type("stride", stride, [int], self.cls_name) Validator.check_value_type("padding", padding, [int], self.cls_name) Validator.check_value_type("dilation", dilation, [int], self.cls_name) - Validator.check_integer('kernel_size', kernel_size, 1, Rel.GE, self.cls_name) - Validator.check_integer('stride', stride, 1, Rel.GE, self.cls_name) + Validator.check_int(kernel_size, 1, Rel.GE, 'kernel_size', self.cls_name) + Validator.check_int(stride, 1, Rel.GE, 'stride', self.cls_name) Validator.check_non_negative_int(padding, 'padding', self.cls_name) - Validator.check_integer('dilation', dilation, 1, Rel.GE, self.cls_name) + Validator.check_int(dilation, 1, Rel.GE, 'dilation', self.cls_name) kernel_size = (1, kernel_size) stride = (1, stride) dilation = (1, dilation) @@ -395,7 +395,7 @@ class Conv1d(_Conv): get_dtype = P.DType() if isinstance(weight_init, Tensor): weight_init_shape = get_shape(weight_init) - Validator.check_integer('weight_init_shape', len(weight_init_shape), 3, Rel.EQ, self.cls_name) + Validator.check_equal_int(len(weight_init_shape), 3, 'weight_init_shape', self.cls_name) weight_init_dtype = get_dtype(weight_init) weight_init_value = weight_init.asnumpy() weight_init_value = np.expand_dims(weight_init_value, 2) @@ -539,7 +539,7 @@ class Conv2dTranspose(_Conv): dilation = twice(dilation) Validator.check_value_type('padding', padding, (int, tuple), self.cls_name) if isinstance(padding, tuple): - Validator.check_integer('padding size', len(padding), 4, Rel.EQ, self.cls_name) + Validator.check_equal_int(len(padding), 4, 'padding size', self.cls_name) # out_channels and in_channels swap. # cause Conv2DBackpropInput's out_channel refers to Conv2D's out_channel, # then Conv2dTranspose's out_channel refers to Conv2DBackpropInput's in_channel. @@ -703,10 +703,10 @@ class Conv1dTranspose(_Conv): Validator.check_value_type("stride", stride, [int], self.cls_name) Validator.check_value_type("padding", padding, [int], self.cls_name) Validator.check_value_type("dilation", dilation, [int], self.cls_name) - Validator.check_integer('kernel_size', kernel_size, 1, Rel.GE, self.cls_name) - Validator.check_integer('stride', stride, 1, Rel.GE, self.cls_name) + Validator.check_int(kernel_size, 1, Rel.GE, 'kernel_size', self.cls_name) + Validator.check_int(stride, 1, Rel.GE, 'stride', self.cls_name) Validator.check_non_negative_int(padding, 'padding', self.cls_name) - Validator.check_integer('dilation', dilation, 1, Rel.GE, self.cls_name) + Validator.check_int(dilation, 1, Rel.GE, 'dilation', self.cls_name) kernel_size = (1, kernel_size) stride = (1, stride) dilation = (1, dilation) @@ -714,7 +714,7 @@ class Conv1dTranspose(_Conv): get_dtype = P.DType() if isinstance(weight_init, Tensor): weight_init_shape = get_shape(weight_init) - Validator.check_integer('weight_init_shape', len(weight_init_shape), 3, Rel.EQ, self.cls_name) + Validator.check_equal_int(len(weight_init_shape), 3, 'weight_init_shape', self.cls_name) weight_init_dtype = get_dtype(weight_init) weight_init_value = weight_init.asnumpy() weight_init_value = np.expand_dims(weight_init_value, 2) diff --git a/mindspore/nn/layer/image.py b/mindspore/nn/layer/image.py index cfafd8b5d4..097910e7c9 100644 --- a/mindspore/nn/layer/image.py +++ b/mindspore/nn/layer/image.py @@ -220,7 +220,7 @@ class SSIM(Cell): validator.check_value_type('max_val', max_val, [int, float], self.cls_name) validator.check_number('max_val', max_val, 0.0, Rel.GT, self.cls_name) self.max_val = max_val - self.filter_size = validator.check_integer('filter_size', filter_size, 1, Rel.GE, self.cls_name) + self.filter_size = validator.check_int(filter_size, 1, Rel.GE, 'filter_size', self.cls_name) self.filter_sigma = validator.check_positive_float(filter_sigma, 'filter_sigma', self.cls_name) self.k1 = validator.check_value_type('k1', k1, [float], self.cls_name) self.k2 = validator.check_value_type('k2', k2, [float], self.cls_name) @@ -298,7 +298,7 @@ class MSSSIM(Cell): validator.check_number('max_val', max_val, 0.0, Rel.GT, self.cls_name) self.max_val = max_val validator.check_value_type('power_factors', power_factors, [tuple, list], self.cls_name) - self.filter_size = validator.check_integer('filter_size', filter_size, 1, Rel.GE, self.cls_name) + self.filter_size = validator.check_int(filter_size, 1, Rel.GE, 'filter_size', self.cls_name) self.filter_sigma = validator.check_positive_float(filter_sigma, 'filter_sigma', self.cls_name) self.k1 = validator.check_value_type('k1', k1, [float], self.cls_name) self.k2 = validator.check_value_type('k2', k2, [float], self.cls_name) diff --git a/mindspore/nn/layer/pooling.py b/mindspore/nn/layer/pooling.py index 64aa70c07a..2c13957636 100644 --- a/mindspore/nn/layer/pooling.py +++ b/mindspore/nn/layer/pooling.py @@ -190,8 +190,8 @@ class MaxPool1d(_PoolNd): validator.check_value_type('kernel_size', kernel_size, [int], self.cls_name) validator.check_value_type('stride', stride, [int], self.cls_name) self.pad_mode = validator.check_string(pad_mode.upper(), ['VALID', 'SAME'], 'pad_mode', self.cls_name) - validator.check_integer("kernel_size", kernel_size, 1, Rel.GE, self.cls_name) - validator.check_integer("stride", stride, 1, Rel.GE, self.cls_name) + validator.check_int(kernel_size, 1, Rel.GE, "kernel_size", self.cls_name) + validator.check_int(stride, 1, Rel.GE, "stride", self.cls_name) self.kernel_size = (1, kernel_size) self.stride = (1, stride) self.max_pool = P.MaxPool(ksize=self.kernel_size, @@ -349,8 +349,8 @@ class AvgPool1d(_PoolNd): validator.check_value_type('kernel_size', kernel_size, [int], self.cls_name) validator.check_value_type('stride', stride, [int], self.cls_name) self.pad_mode = validator.check_string(pad_mode.upper(), ['VALID', 'SAME'], 'pad_mode', self.cls_name) - validator.check_integer("kernel_size", kernel_size, 1, Rel.GE, self.cls_name) - validator.check_integer("stride", stride, 1, Rel.GE, self.cls_name) + validator.check_int(kernel_size, 1, Rel.GE, "kernel_size", self.cls_name) + validator.check_int(stride, 1, Rel.GE, "stride", self.cls_name) self.kernel_size = (1, kernel_size) self.stride = (1, stride) self.avg_pool = P.AvgPool(ksize=self.kernel_size, diff --git a/mindspore/nn/layer/quant.py b/mindspore/nn/layer/quant.py index 8fa2ba8cc0..66545a71d1 100644 --- a/mindspore/nn/layer/quant.py +++ b/mindspore/nn/layer/quant.py @@ -323,7 +323,7 @@ class FakeQuantWithMinMax(Cell): Validator.check_type("min_init", min_init, [int, float]) Validator.check_type("max_init", max_init, [int, float]) Validator.check("min_init", min_init, "max_init", max_init, rel=Rel.LT) - Validator.check_integer('quant_delay', quant_delay, 0, Rel.GE) + Validator.check_non_negative_int(quant_delay, 'quant_delay') self.min_init = min_init self.max_init = max_init self.num_bits = num_bits @@ -489,8 +489,8 @@ class Conv2dBnFoldQuant(Cell): # initialize convolution op and Parameter if context.get_context('device_target') == "Ascend" and group > 1: - Validator.check_integer('group', group, in_channels, Rel.EQ) - Validator.check_integer('group', group, out_channels, Rel.EQ) + Validator.check_equal_int(group, in_channels, 'group') + Validator.check_equal_int(group, out_channels, 'group') self.conv = P.DepthwiseConv2dNative(channel_multiplier=1, kernel_size=self.kernel_size, pad_mode=pad_mode, @@ -674,8 +674,8 @@ class Conv2dBnWithoutFoldQuant(Cell): self.bias = None # initialize convolution op and Parameter if context.get_context('device_target') == "Ascend" and group > 1: - Validator.check_integer('group', group, in_channels, Rel.EQ) - Validator.check_integer('group', group, out_channels, Rel.EQ) + Validator.check_equal_int(group, in_channels, 'group') + Validator.check_equal_int(group, out_channels, 'group') self.conv = P.DepthwiseConv2dNative(channel_multiplier=1, kernel_size=self.kernel_size, pad_mode=pad_mode, diff --git a/mindspore/ops/operations/_grad_ops.py b/mindspore/ops/operations/_grad_ops.py index 150d471d69..48f405678d 100644 --- a/mindspore/ops/operations/_grad_ops.py +++ b/mindspore/ops/operations/_grad_ops.py @@ -931,19 +931,19 @@ class LSTMGradData(PrimitiveWithInfer): def infer_shape(self, y_shape, dy_shape, dhy_shape, dcy_shape, w_shape, hx_shape, cx_shape, reserve_shape, state_shape): # dhy and dcy should be same shape - validator.check_integer("h_shape", len(dhy_shape), 3, Rel.EQ, self.name) - validator.check_integer("h_shape", len(dhy_shape), len(dcy_shape), Rel.EQ, self.name) - validator.check_integer("h_shape[0]", dhy_shape[0], dcy_shape[0], Rel.EQ, self.name) - validator.check_integer("h_shape[1]", dhy_shape[1], dcy_shape[1], Rel.EQ, self.name) - validator.check_integer("h_shape[2]", dhy_shape[2], dcy_shape[2], Rel.EQ, self.name) + validator.check_equal_int(len(dhy_shape), 3, "h_shape", self.name) + validator.check_equal_int(len(dhy_shape), len(dcy_shape), "h_shape", self.name) + validator.check_equal_int(dhy_shape[0], dcy_shape[0], "h_shape[0]", self.name) + validator.check_equal_int(dhy_shape[1], dcy_shape[1], "h_shape[1]", self.name) + validator.check_equal_int(dhy_shape[2], dcy_shape[2], "h_shape[2]", self.name) - validator.check_integer("h_shape[0]", dhy_shape[0], self.num_layers * self.num_directions, Rel.EQ, self.name) - validator.check_integer("h_shape[2]", dhy_shape[2], self.hidden_size, Rel.EQ, self.name) + validator.check_int(dhy_shape[0], self.num_layers * self.num_directions, Rel.EQ, "h_shape[0]", self.name) + validator.check_equal_int(dhy_shape[2], self.hidden_size, "h_shape[2]", self.name) # dy: (seq_len, batch_size, hidden_size * num_directions) - validator.check_integer("dy_shape", len(dy_shape), 3, Rel.EQ, self.name) - validator.check_integer("dy[1]", dy_shape[1], dhy_shape[1], Rel.EQ, self.name) - validator.check_integer("dy[2]", dy_shape[2], self.hidden_size * self.num_directions, Rel.EQ, self.name) + validator.check_equal_int(len(dy_shape), 3, "dy_shape", self.name) + validator.check_equal_int(dy_shape[1], dhy_shape[1], "dy[1]", self.name) + validator.check_int(dy_shape[2], self.hidden_size * self.num_directions, Rel.EQ, "dy[2]", self.name) # (seq_len, batch_size, input_size) dx_shape = (y_shape[0], y_shape[1], self.input_size) @@ -1015,19 +1015,19 @@ class LSTMGrad(PrimitiveWithInfer): def infer_shape(self, x_shape, hx_shape, cx_shape, w_shape, y_shape, hy_shape, cy_shape, dy_shape, dhy_shape, dcy_shape, reserve_shape): # dhy and dcy should be same shape - validator.check_integer("h_shape", len(dhy_shape), 3, Rel.EQ, self.name) - validator.check_integer("h_shape", len(dhy_shape), len(dcy_shape), Rel.EQ, self.name) - validator.check_integer("h_shape[0]", dhy_shape[0], dcy_shape[0], Rel.EQ, self.name) - validator.check_integer("h_shape[1]", dhy_shape[1], dcy_shape[1], Rel.EQ, self.name) - validator.check_integer("h_shape[2]", dhy_shape[2], dcy_shape[2], Rel.EQ, self.name) + validator.check_equal_int(len(dhy_shape), 3, "h_shape", self.name) + validator.check_equal_int(len(dhy_shape), len(dcy_shape), "h_shape", self.name) + validator.check_equal_int(dhy_shape[0], dcy_shape[0], "h_shape[0]", self.name) + validator.check_equal_int(dhy_shape[1], dcy_shape[1], "h_shape[1]", self.name) + validator.check_equal_int(dhy_shape[2], dcy_shape[2], "h_shape[2]", self.name) - validator.check_integer("h_shape[0]", dhy_shape[0], self.num_layers * self.num_directions, Rel.EQ, self.name) - validator.check_integer("h_shape[2]", dhy_shape[2], self.hidden_size, Rel.EQ, self.name) + validator.check_int(dhy_shape[0], self.num_layers * self.num_directions, Rel.EQ, "h_shape[0]", self.name) + validator.check_equal_int(dhy_shape[2], self.hidden_size, "h_shape[2]", self.name) # dy: (seq_len, batch_size, hidden_size * num_directions) - validator.check_integer("dy_shape", len(dy_shape), 3, Rel.EQ, self.name) - validator.check_integer("dy[1]", dy_shape[1], dhy_shape[1], Rel.EQ, self.name) - validator.check_integer("dy[2]", dy_shape[2], self.hidden_size * self.num_directions, Rel.EQ, self.name) + validator.check_equal_int(len(dy_shape), 3, "dy_shape", self.name) + validator.check_equal_int(dy_shape[1], dhy_shape[1], "dy[1]", self.name) + validator.check_int(dy_shape[2], self.hidden_size * self.num_directions, Rel.EQ, "dy[2]", self.name) # (seq_len, batch_size, input_size) dx_shape = (y_shape[0], y_shape[1], self.input_size) @@ -1069,7 +1069,7 @@ class DynamicRNNGrad(PrimitiveWithInfer): def infer_shape(self, x_shape, w_shape, b_shape, y_shape, init_h_shape, init_c_shape, h_shape, c_shape, dy_shape, dh_shape, dc_shape, i_shape, j_shape, f_shape, o_shape, tanhc_shape): - validator.check_integer("x_shape", len(x_shape), 3, Rel.EQ, self.name) + validator.check_equal_int(len(x_shape), 3, "x_shape", self.name) num_step, batch_size, input_size = x_shape hidden_size = w_shape[-1] // 4 if w_shape[-1] % 4 != 0: @@ -1575,7 +1575,7 @@ class BasicLSTMCellCStateGrad(PrimitiveWithInfer): def infer_shape(self, c_shape, dht_shape, dct_shape, it_shape, jt_shape, ft_shape, ot_shape, tanhct_shape): # dhy and dcy should be same shape - validator.check_integer("c rank", len(c_shape), 2, Rel.EQ, self.name) + validator.check_equal_int(len(c_shape), 2, "c rank", self.name) validator.check("dht rank", len(dht_shape), "c rank", len(c_shape), Rel.EQ, self.name) validator.check("dct rank", len(dct_shape), "c rank", len(c_shape), Rel.EQ, self.name) validator.check("it rank", len(it_shape), "c rank", len(c_shape), Rel.EQ, self.name) @@ -1624,7 +1624,7 @@ class BasicLSTMCellWeightGrad(PrimitiveWithInfer): self.add_prim_attr("io_format", "HWCN") def infer_shape(self, x_shape, h_shape, dgate_shape): - validator.check_integer("x rank", len(x_shape), 2, Rel.EQ, self.name) + validator.check_equal_int(len(x_shape), 2, "x rank", self.name) validator.check("h rank", len(h_shape), " x rank", len(x_shape), Rel.EQ, self.name) validator.check("dgate rank", len(dgate_shape), "x rank", len(x_shape), Rel.EQ, self.name) validator.check("h_shape[0]", h_shape[0], "x_shape[0]", x_shape[0], Rel.EQ, self.name) @@ -1656,8 +1656,8 @@ class BasicLSTMCellInputGrad(PrimitiveWithInfer): self.add_prim_attr("io_format", "ND") def infer_shape(self, dgate_shape, w_shape): - validator.check_integer("dgate rank", len(dgate_shape), 2, Rel.EQ, self.name) - validator.check_integer("w rank", len(w_shape), 2, Rel.EQ, self.name) + validator.check_equal_int(len(dgate_shape), 2, "dgate rank", self.name) + validator.check_equal_int(len(w_shape), 2, "w rank", self.name) validator.check("dgate_shape[1]", dgate_shape[1], "w_shape[1]", w_shape[1], Rel.EQ, self.name) batch_size = dgate_shape[0] hidden_size = dgate_shape[1] // 4 diff --git a/mindspore/ops/operations/_inner_ops.py b/mindspore/ops/operations/_inner_ops.py index 753f55c033..17382bdd52 100644 --- a/mindspore/ops/operations/_inner_ops.py +++ b/mindspore/ops/operations/_inner_ops.py @@ -347,7 +347,7 @@ class MatrixDiag(PrimitiveWithInfer): return x_dtype def infer_shape(self, x_shape, assist_shape): - validator.check_integer("assist rank", len(assist_shape), 2, Rel.GE, self.name) + validator.check_int(len(assist_shape), 2, Rel.GE, "assist rank", self.name) validator.check('rank of x', len(x_shape)+1, 'rank of assist', len(assist_shape), Rel.LE, self.name) validator.check('assist\'s penultimate dimension', assist_shape[-2], 'assist\'s last dimension', @@ -395,7 +395,7 @@ class MatrixDiagPart(PrimitiveWithInfer): return x_dtype def infer_shape(self, x_shape, assist_shape): - validator.check_integer("x rank", len(x_shape), 2, Rel.GE, self.name) + validator.check_int(len(x_shape), 2, Rel.GE, "x rank", self.name) validator.check("x shape", x_shape, "assist shape", assist_shape, Rel.EQ, self.name) if assist_shape[-2] < assist_shape[-1]: @@ -438,7 +438,7 @@ class MatrixSetDiag(PrimitiveWithInfer): return x_dtype def infer_shape(self, x_shape, diagonal_shape, assist_shape): - validator.check_integer("x rank", len(x_shape), 2, Rel.GE, self.name) + validator.check_int(len(x_shape), 2, Rel.GE, "x rank", self.name) validator.check("x shape", x_shape, "assist shape", assist_shape, Rel.EQ, self.name) if x_shape[-2] < x_shape[-1]: diff --git a/mindspore/ops/operations/_quant_ops.py b/mindspore/ops/operations/_quant_ops.py index 863e214eec..9e21c9ddbe 100644 --- a/mindspore/ops/operations/_quant_ops.py +++ b/mindspore/ops/operations/_quant_ops.py @@ -81,11 +81,10 @@ class MinMaxUpdatePerLayer(PrimitiveWithInfer): outputs=['min_up', 'max_up']) def infer_shape(self, x_shape, min_shape, max_shape): - validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name) + validator.check_int(len(x_shape), 1, Rel.GE, "x rank", self.name) validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name) - validator.check_integer("min shape", len( - min_shape), 1, Rel.EQ, self.name) + validator.check_equal_int(len(min_shape), 1, "min shape", self.name) return min_shape, max_shape def infer_dtype(self, x_type, min_type, max_type): @@ -147,11 +146,10 @@ class MinMaxUpdatePerChannel(PrimitiveWithInfer): if self.is_ascend and len(x_shape) not in self.ascend_support_x_rank: raise ValueError(f"For '{self.name}' x rank should be in '{self.ascend_support_x_rank}'") if not self.is_ascend: - validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name) + validator.check_int(len(x_shape), 1, Rel.GE, "x rank", self.name) validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name) - validator.check_integer("min shape", len( - min_shape), 1, Rel.EQ, self.name) + validator.check_equal_int(len(min_shape), 1, "min shape", self.name) return min_shape, max_shape def infer_dtype(self, x_type, min_type, max_type): @@ -228,9 +226,9 @@ class FakeQuantPerLayer(PrimitiveWithInfer): outputs=['out']) def infer_shape(self, x_shape, min_shape, max_shape): - validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name) + validator.check_int(len(x_shape), 1, Rel.GE, "x rank", self.name) validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name) - validator.check_integer("min shape", len(min_shape), 1, Rel.EQ, self.name) + validator.check_equal_int(len(min_shape), 1, "min shape", self.name) return x_shape def infer_dtype(self, x_type, min_type, max_type): @@ -284,8 +282,7 @@ class FakeQuantPerLayerGrad(PrimitiveWithInfer): x_shape, Rel.EQ, self.name) validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name) - validator.check_integer("min shape", len( - min_shape), 1, Rel.EQ, self.name) + validator.check_equal_int(len(min_shape), 1, "min shape", self.name) return dout_shape def infer_dtype(self, dout_type, x_type, min_type, max_type): @@ -375,14 +372,12 @@ class FakeQuantPerChannel(PrimitiveWithInfer): if self.is_ascend and len(x_shape) not in self.ascend_support_x_rank: raise ValueError(f"For '{self.name}' x rank should be in '{self.ascend_support_x_rank}'") if not self.is_ascend: - validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name) + validator.check_int(len(x_shape), 1, Rel.GE, "x rank", self.name) if len(x_shape) == 1: self.channel_axis = 0 validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name) - validator.check_integer( - "min shape", min_shape[0], x_shape[self.channel_axis], Rel.EQ, self.name) - validator.check_integer( - "max shape", max_shape[0], x_shape[self.channel_axis], Rel.EQ, self.name) + validator.check_equal_int(min_shape[0], x_shape[self.channel_axis], "min shape", self.name) + validator.check_equal_int(max_shape[0], x_shape[self.channel_axis], "max shape", self.name) return x_shape def infer_dtype(self, x_type, min_type, max_type): @@ -501,7 +496,7 @@ class BatchNormFold(PrimitiveWithInfer): def infer_shape(self, x_shape, mean_shape, variance_shape, global_step_shape): validator.check("mean shape", mean_shape, "gamma_shape", variance_shape, Rel.EQ, self.name) validator.check("mean_shape[0]", mean_shape[0], "input channel", x_shape[self.channel_axis], Rel.EQ, self.name) - validator.check_integer("global step shape len", len(global_step_shape), 1, Rel.EQ, self.name) + validator.check_equal_int(len(global_step_shape), 1, "global step shape len", self.name) return mean_shape, mean_shape, mean_shape, mean_shape def infer_dtype(self, x_type, mean_type, variance_type, global_step_type): @@ -548,7 +543,7 @@ class BatchNormFoldGrad(PrimitiveWithInfer): "batch_std shape", batch_std_shape, Rel.EQ, self.name) validator.check("d_batch_mean_shape[0]", d_batch_mean_shape[0], "input channel", x_shape[self.channel_axis], Rel.EQ, self.name) - validator.check_integer("global step shape len", len(global_step_shape), 1, Rel.EQ, self.name) + validator.check_equal_int(len(global_step_shape), 1, "global step shape len", self.name) return x_shape def infer_dtype(self, d_batch_mean_type, d_batch_std_type, x_type, batch_mean_type, batch_std_type, @@ -723,7 +718,7 @@ class BatchNormFold2(PrimitiveWithInfer): validator.check("batch_std shape", batch_std_shape, "batch_mean shape", gamma_shape, Rel.EQ, self.name) validator.check("batch_std_shape[0]", batch_std_shape[0], "x_shape channel size", x_shape[self.channel_axis], Rel.EQ, self.name) - validator.check_integer("global step shape len", len(global_step_shape), 1, Rel.EQ, self.name) + validator.check_equal_int(len(global_step_shape), 1, "global step shape len", self.name) return x_shape def infer_dtype(self, x_type, beta_type, gamma_type, batch_std_type, running_std_type, batch_mean_type, @@ -771,7 +766,7 @@ class BatchNormFold2Grad(PrimitiveWithInfer): validator.check("batch_std shape", batch_std_shape, "gamma shape", gamma_shape, Rel.EQ, self.name) validator.check("batch_std size", batch_std_shape[0], "dout channel size", dout_shape[self.channel_axis], Rel.EQ, self.name) - validator.check_integer("global step shape len", len(global_step_shape), 1, Rel.EQ, self.name) + validator.check_equal_int(len(global_step_shape), 1, "global step shape len", self.name) return gamma_shape, gamma_shape, gamma_shape, gamma_shape, x_shape def infer_dtype(self, dout_type, x_type, gamma_type, diff --git a/mindspore/ops/operations/_thor_ops.py b/mindspore/ops/operations/_thor_ops.py index e54bbeb8d5..9cca988955 100644 --- a/mindspore/ops/operations/_thor_ops.py +++ b/mindspore/ops/operations/_thor_ops.py @@ -520,7 +520,7 @@ class Im2Col(PrimitiveWithInfer): self.add_prim_attr('data_format', "NCHW") def infer_shape(self, x_shape): - validator.check_integer("x rank", len(x_shape), 4, Rel.EQ, self.name) + validator.check_equal_int(len(x_shape), 4, "x rank", self.name) kernel_size_h = self.kernel_size[0] kernel_size_w = self.kernel_size[1] stride_h = self.stride[2] diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index 95af441570..a7e66ddf96 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -583,8 +583,8 @@ class Transpose(PrimitiveWithInfer): tmp = list(p_value) for i, dim in enumerate(p_value): - validator.check_integer("perm[%d]" % i, dim, 0, Rel.GE, self.name) - validator.check_integer("perm[%d]" % i, dim, len(p_value), Rel.LT, self.name) + validator.check_int(dim, 0, Rel.GE, f'perm[{i}]', self.name) + validator.check_int(dim, len(p_value), Rel.LT, f'perm[{i}]', self.name) tmp.remove(dim) if dim in tmp: raise ValueError('The value of perm is wrong.') @@ -725,8 +725,8 @@ class Padding(PrimitiveWithInfer): def __infer__(self, x): validator.check_subclass("x", x['dtype'], mstype.tensor, self.name) x_shape = list(x['shape']) - validator.check_integer("rank of x", len(x_shape), 1, Rel.GT, self.name) - validator.check_integer("last dim of x", x_shape[-1], 1, Rel.EQ, self.name) + validator.check_int(len(x_shape), 1, Rel.GT, "rank of x", self.name) + validator.check_int(x_shape[-1], 1, Rel.EQ, "last dim of x", self.name) out_shape = x_shape out_shape[-1] = self.pad_dim_size out = {'shape': out_shape, @@ -1575,7 +1575,7 @@ class UnsortedSegmentMin(PrimitiveWithInfer): valid_type = [mstype.float16, mstype.float32, mstype.int32] validator.check_tensor_type_same({"x": x['dtype']}, valid_type, self.name) validator.check_tensor_type_same({"segment_ids": segment_ids['dtype']}, [mstype.int32], self.name) - validator.check_integer("rank of segment_ids_shape", len(segment_ids_shape), 1, Rel.EQ, self.name) + validator.check_equal_int(len(segment_ids_shape), 1, "rank of segment_ids_shape", self.name) validator.check(f'first shape of input_x', x_shape[0], 'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name) num_segments_v = num_segments['value'] @@ -1628,7 +1628,7 @@ class UnsortedSegmentProd(PrimitiveWithInfer): valid_type = [mstype.float16, mstype.float32, mstype.int32] validator.check_tensor_type_same({"x": x['dtype']}, valid_type, self.name) validator.check_tensor_type_same({"segment_ids": segment_ids['dtype']}, [mstype.int32], self.name) - validator.check_integer("rank of segment_ids_shape", len(segment_ids_shape), 1, Rel.EQ, self.name) + validator.check_equal_int(len(segment_ids_shape), 1, "rank of segment_ids_shape", self.name) validator.check(f'first shape of input_x', x_shape[0], 'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name) num_segments_v = num_segments['value'] @@ -1730,7 +1730,7 @@ class ParallelConcat(PrimitiveWithInfer): x_shp = values['shape'] x_type = values['dtype'] - validator.check_integer(f'x_shp length', len(x_shp), 1, Rel.GE, self.name) + validator.check_int(len(x_shp), 1, Rel.GE, f'x_shp length', self.name) args = {f"x_type[{i}]": elem for i, elem in enumerate(x_type)} validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), self.name) @@ -1738,7 +1738,7 @@ class ParallelConcat(PrimitiveWithInfer): first_elem = x_shp[0] for i, elem in enumerate(x_shp[1:]): j = i + 1 - validator.check_integer(f'x_shp[{j}][0]', elem[0], 1, Rel.EQ, self.name) + validator.check_equal_int(elem[0], 1, f'x_shp[{j}][0]', self.name) validator.check(f"x_shp[0] shape", first_elem, f"x_shp[{j}] shape", elem, Rel.EQ, self.name) ret_shp = x_shp[0].copy() @@ -1755,7 +1755,7 @@ class ParallelConcat(PrimitiveWithInfer): def _get_pack_shape(x_shape, x_type, axis, prim_name): """for pack output shape""" validator.check_value_type("shape", x_shape, [tuple, list], prim_name) - validator.check_integer("len of input_x", len(x_shape), 1, Rel.GE, prim_name) + validator.check_int(len(x_shape), 1, Rel.GE, "len of input_x", prim_name) validator.check_subclass("input_x[0]", x_type[0], mstype.tensor, prim_name) rank_base = len(x_shape[0]) N = len(x_shape) @@ -1871,8 +1871,8 @@ class Unpack(PrimitiveWithInfer): validator.check_positive_int(output_num, "output_num", self.name) self.add_prim_attr('num', output_num) output_valid_check = x_shape[self.axis] - output_num - validator.check_integer("The dimension which to unpack divides output_num", output_valid_check, 0, Rel.EQ, - self.name) + validator.check_int(output_valid_check, 0, Rel.EQ, + "The dimension which to unpack divides output_num", self.name) out_shapes = [] out_dtypes = [] out_shape = x_shape[:self.axis] + x_shape[self.axis + 1:] @@ -2523,7 +2523,7 @@ class ResizeNearestNeighbor(PrimitiveWithInfer): """Initialize ResizeNearestNeighbor""" validator.check_value_type("size", size, [tuple, list], self.name) validator.check_value_type("align_corners", align_corners, [bool], self.name) - validator.check_integer("length of size", len(size), 2, Rel.EQ, self.name) + validator.check_equal_int(len(size), 2, "length of size", self.name) for i, value in enumerate(size): validator.check_non_negative_int(value, f'{i}th value of size', self.name) self.init_prim_io_names(inputs=['image_in'], outputs=['image_out']) @@ -3134,9 +3134,8 @@ class DepthToSpace(PrimitiveWithInfer): for i in range(2): out_shape[i + 2] *= self.block_size - validator.check_integer('x_shape[1] % (block_size*block_size)', - x_shape[1] % (self.block_size * self.block_size), - 0, Rel.EQ, self.name) + validator.check_int(x_shape[1] % (self.block_size * self.block_size), + 0, Rel.EQ, 'x_shape[1] % (block_size*block_size)', self.name) out_shape[1] //= self.block_size * self.block_size return out_shape @@ -3205,7 +3204,7 @@ class SpaceToBatch(PrimitiveWithInfer): return x_dtype def infer_shape(self, x_shape): - validator.check_integer('rank of input_x', len(x_shape), 4, Rel.EQ, self.name) + validator.check_equal_int(len(x_shape), 4, 'rank of input_x', self.name) out_shape = copy.deepcopy(x_shape) for i in range(2): padded = out_shape[i + 2] + self.paddings[i][0] + self.paddings[i][1] @@ -3367,7 +3366,7 @@ class SpaceToBatchND(PrimitiveWithInfer): def infer_shape(self, x_shape): x_rank = len(x_shape) - validator.check_integer('x_shape rank', x_rank, 4, Rel.EQ, self.name) + validator.check_equal_int(x_rank, 4, 'x_shape rank', self.name) out_shape = copy.deepcopy(x_shape) block_shape_prod = 1 @@ -3460,7 +3459,7 @@ class BatchToSpaceND(PrimitiveWithInfer): def infer_shape(self, x_shape): x_rank = len(x_shape) - validator.check_integer('x_shape rank', x_rank, 4, Rel.EQ, self.name) + validator.check_int(x_rank, 4, Rel.EQ, 'x_shape rank', self.name) out_shape = copy.deepcopy(x_shape) block_shape_prod = 1 @@ -3607,11 +3606,11 @@ class Meshgrid(PrimitiveWithInfer): def infer_shape(self, x_shape): validator.check_value_type("shape", x_shape, [tuple, list], self.name) - validator.check_integer("len of input_x", len(x_shape), 2, Rel.GE, self.name) + validator.check_int(len(x_shape), 2, Rel.GE, "len of input_x", self.name) n = len(x_shape) shape_0 = [] for s in x_shape: - validator.check_integer('each_input_rank', len(s), 1, Rel.EQ, self.name) + validator.check_int(len(s), 1, Rel.EQ, 'each_input_rank', self.name) shape_0.append(s[0]) if self.indexing == "xy": shape_0[0], shape_0[1] = shape_0[1], shape_0[0] diff --git a/mindspore/ops/operations/comm_ops.py b/mindspore/ops/operations/comm_ops.py index 37584cd3ba..4a06ed9b7f 100644 --- a/mindspore/ops/operations/comm_ops.py +++ b/mindspore/ops/operations/comm_ops.py @@ -204,7 +204,7 @@ class _HostAllGather(PrimitiveWithInfer): if group is None: raise ValueError(f"For '{self.name}' group must be set.") validator.check_value_type('group', group, (tuple, list), self.name) - validator.check_integer("group size", len(group), 2, Rel.GE, self.name) + validator.check_int(len(group), 2, Rel.GE, "group size", self.name) for r in group: validator.check_int_range(r, 0, 7, Rel.INC_BOTH, "rank_id", self.name) validator.check_value_type("rank_id", r, (int,), self.name) @@ -313,7 +313,7 @@ class _HostReduceScatter(PrimitiveWithInfer): raise ValueError(f"For '{self.name}' group must be set.") validator.check_value_type('op', op, (type(ReduceOp.SUM),), self.name) validator.check_value_type('group', group, (tuple, list), self.name) - validator.check_integer("group size", len(group), 2, Rel.GE, self.name) + validator.check_int(len(group), 2, Rel.GE, "group size", self.name) for r in group: validator.check_int_range(r, 0, 7, Rel.INC_BOTH, "rank_id", self.name) validator.check_value_type("rank_id", r, (int,), self.name) diff --git a/mindspore/ops/operations/control_ops.py b/mindspore/ops/operations/control_ops.py index 2600ea1d61..acdf8ff548 100644 --- a/mindspore/ops/operations/control_ops.py +++ b/mindspore/ops/operations/control_ops.py @@ -126,7 +126,7 @@ class GeSwitch(PrimitiveWithInfer): raise NotImplementedError def infer_shape(self, data, pred): - validator.check_integer("pred rank", len(pred), 0, Rel.EQ, self.name) + validator.check_equal_int(len(pred), 0, "pred rank", self.name) return (data, data) def infer_dtype(self, data_type, pred_type): diff --git a/mindspore/ops/operations/debug_ops.py b/mindspore/ops/operations/debug_ops.py index 6b7335d0a7..6dcf8dc8a4 100644 --- a/mindspore/ops/operations/debug_ops.py +++ b/mindspore/ops/operations/debug_ops.py @@ -374,9 +374,9 @@ class Assert(PrimitiveWithInfer): def infer_shape(self, condition, inputs): condition_len = len(condition) - validator.check_integer("condition's rank", condition_len, 1, Rel.LE, self.name) + validator.check_int(condition_len, 1, Rel.LE, "condition's rank", self.name) if condition_len == 1: - validator.check_integer("condition[0]", condition[0], 1, Rel.EQ, self.name) + validator.check_equal_int(condition[0], 1, "condition[0]", self.name) return [1] def infer_dtype(self, condition, inputs): diff --git a/mindspore/ops/operations/inner_ops.py b/mindspore/ops/operations/inner_ops.py index af399fdc53..fb1ef13bb2 100644 --- a/mindspore/ops/operations/inner_ops.py +++ b/mindspore/ops/operations/inner_ops.py @@ -17,7 +17,6 @@ import numbers from ..._checkparam import Validator as validator -from ..._checkparam import Rel from ...common.dtype import tensor, dtype_to_pytype from ..primitive import prim_attr_register, PrimitiveWithInfer @@ -43,7 +42,7 @@ class ScalarCast(PrimitiveWithInfer): pass def __infer__(self, x, t): - validator.check_integer('x shape', len(x['shape']), 0, Rel.EQ, self.name) + validator.check_equal_int(len(x['shape']), 0, 'x shape', self.name) value, to = x['value'], t['value'] if value is not None: validator.check_value_type("value", value, [numbers.Number, bool], self.name) diff --git a/mindspore/ops/operations/math_ops.py b/mindspore/ops/operations/math_ops.py index a368192f01..8a318f79bd 100644 --- a/mindspore/ops/operations/math_ops.py +++ b/mindspore/ops/operations/math_ops.py @@ -827,7 +827,7 @@ class AddN(PrimitiveWithInfer): def infer_shape(self, inputs): cls_name = self.name - validator.check_integer("inputs", len(inputs), 1, Rel.GE, cls_name) + validator.check_int(len(inputs), 1, Rel.GE, "inputs", cls_name) self.add_prim_attr('n', len(inputs)) shp0 = inputs[0] for i, shp in enumerate(inputs): @@ -837,7 +837,7 @@ class AddN(PrimitiveWithInfer): def infer_dtype(self, inputs): cls_name = self.name validator.check_value_type("inputs", inputs, [tuple, list], cls_name) - validator.check_integer("inputs", len(inputs), 1, Rel.GE, cls_name) + validator.check_int(len(inputs), 1, Rel.GE, "inputs", cls_name) args = {} contains_undetermined = False for i, dtype in enumerate(inputs): @@ -910,7 +910,7 @@ class AccumulateNV2(PrimitiveWithInfer): def infer_shape(self, inputs): cls_name = self.name - validator.check_integer("inputs", len(inputs), 1, Rel.GE, cls_name) + validator.check_int(len(inputs), 1, Rel.GE, "inputs", cls_name) self.add_prim_attr('n', len(inputs)) shp0 = inputs[0] for i, shp in enumerate(inputs): @@ -920,7 +920,7 @@ class AccumulateNV2(PrimitiveWithInfer): def infer_dtype(self, inputs): cls_name = self.name validator.check_value_type("inputs", inputs, [tuple, list], cls_name) - validator.check_integer("inputs", len(inputs), 1, Rel.GE, cls_name) + validator.check_int(len(inputs), 1, Rel.GE, "inputs", cls_name) args = {} for i, dtype in enumerate(inputs): args[f"inputs[{i}]"] = dtype @@ -1488,7 +1488,7 @@ class HistogramFixedWidth(PrimitiveWithInfer): @prim_attr_register def __init__(self, nbins, dtype='int32'): self.nbins = validator.check_value_type("nbins", nbins, [int], self.name) - validator.check_integer("nbins", nbins, 1, Rel.GE, self.name) + validator.check_int(nbins, 1, Rel.GE, "nbins", self.name) valid_values = ['int32', 'int64'] self.dtype = validator.check_string(dtype, valid_values, "dtype", self.name) self.init_prim_io_names(inputs=['x', 'range'], outputs=['y']) @@ -2810,8 +2810,8 @@ class NPUGetFloatStatus(PrimitiveWithInfer): def infer_shape(self, x_shape): cls_name = self.name - validator.check_integer("len(x_shape)", len(x_shape), 1, Rel.EQ, cls_name) - validator.check_integer("x_shape[0]", x_shape[0], 8, Rel.EQ, cls_name) + validator.check_equal_int(len(x_shape), 1, "len(x_shape)", cls_name) + validator.check_equal_int(x_shape[0], 8, "x_shape[0]", cls_name) return [8] def infer_dtype(self, x_dtype): @@ -2853,8 +2853,8 @@ class NPUClearFloatStatus(PrimitiveWithInfer): def infer_shape(self, x_shape): cls_name = self.name - validator.check_integer("len(x_shape)", len(x_shape), 1, Rel.EQ, cls_name) - validator.check_integer("x_shape[0]", x_shape[0], 8, Rel.EQ, cls_name) + validator.check_equal_int(len(x_shape), 1, "len(x_shape)", cls_name) + validator.check_equal_int(x_shape[0], 8, "x_shape[0]", cls_name) return [8] def infer_dtype(self, x_dtype): @@ -3023,9 +3023,9 @@ class NMSWithMask(PrimitiveWithInfer): def infer_shape(self, bboxes_shape): cls_name = self.name - validator.check_integer("bboxes rank", len(bboxes_shape), 2, Rel.EQ, cls_name) + validator.check_equal_int(len(bboxes_shape), 2, "bboxes rank", cls_name) validator.check_positive_int(bboxes_shape[0], "bboxes.shape[0]", cls_name) - validator.check_integer("bboxes.shape[1]", bboxes_shape[1], 5, Rel.EQ, cls_name) + validator.check_equal_int(bboxes_shape[1], 5, "bboxes.shape[1]", cls_name) num = bboxes_shape[0] return (bboxes_shape, (num,), (num,)) @@ -3572,11 +3572,11 @@ class IFMR(PrimitiveWithInfer): validator.check_value_type("offset_flag", with_offset, [bool], self.name) def infer_shape(self, data_shape, data_min_shape, data_max_shape, cumsum_shape): - validator.check_integer("dims of data_min", len(data_min_shape), 1, Rel.EQ, self.name) - validator.check_integer("data_min[0]", data_min_shape[0], 1, Rel.EQ, self.name) - validator.check_integer("dims of data_max", len(data_max_shape), 1, Rel.EQ, self.name) - validator.check_integer("data_max[0]", data_max_shape[0], 1, Rel.EQ, self.name) - validator.check_integer("dims of cumsum", len(cumsum_shape), 1, Rel.EQ, self.name) + validator.check_equal_int(len(data_min_shape), 1, "dims of data_min", self.name) + validator.check_equal_int(data_min_shape[0], 1, "data_min[0]", self.name) + validator.check_equal_int(len(data_max_shape), 1, "dims of data_max", self.name) + validator.check_equal_int(data_max_shape[0], 1, "data_max[0]", self.name) + validator.check_equal_int(len(cumsum_shape), 1, "dims of cumsum", self.name) return (1,), (1,) def infer_dtype(self, data_dtype, data_min_dtype, data_max_dtype, cumsum_dtype): diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index 2de7bd8a4e..633f336440 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -98,7 +98,7 @@ class Flatten(PrimitiveWithInfer): pass def infer_shape(self, input_x): - validator.check_integer('input_x rank', len(input_x), 1, Rel.GE, self.name) + validator.check_int(len(input_x), 1, Rel.GE, 'input_x rank', self.name) prod = 1 if len(input_x) == 1 else reduce(operator.mul, input_x[1:]) return input_x[0], prod @@ -146,7 +146,7 @@ class Softmax(PrimitiveWithInfer): validator.check_value_type("item of axis", item, [int], self.name) def infer_shape(self, logits): - validator.check_integer("length of axis", len(self.axis), 1, Rel.GE, self.name) + validator.check_int(len(self.axis), 1, Rel.GE, "length of axis", self.name) rank = len(logits) for axis_v in self.axis: validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name) @@ -636,7 +636,7 @@ class FusedBatchNorm(Primitive): def __init__(self, mode=0, epsilon=1e-5, momentum=0.1): self.init_prim_io_names(inputs=['x', 'scale', 'b', 'mean', 'variance'], outputs=['y', 'running_mean', 'running_variance', 'save_mean', 'save_inv_variance']) - self.mode = validator.check_integer('mode', mode, [0, 1], Rel.IN, self.name) + self.mode = validator.check_int(mode, [0, 1], Rel.IN, 'mode', self.name) self.epsilon = validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name) self.momentum = validator.check_float_range(momentum, 0, 1, Rel.INC_BOTH, 'momentum', self.name) self._update_parameter = True @@ -709,17 +709,17 @@ class FusedBatchNormEx(PrimitiveWithInfer): def __init__(self, mode=0, epsilon=1e-5, momentum=0.1): self.init_prim_io_names(inputs=['x', 'scale', 'b', 'mean', 'variance'], outputs=['y', 'save_scale', 'save_bias', 'save_mean', 'save_inv_variance', 'reserve']) - self.mode = validator.check_integer('mode', mode, [0, 1], Rel.IN, self.name) + self.mode = validator.check_int(mode, [0, 1], Rel.IN, 'mode', self.name) self.epsilon = validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name) self.momentum = validator.check_float_range(momentum, 0, 1, Rel.INC_BOTH, 'momentum', self.name) self._update_parameter = True self.add_prim_attr('data_format', "NCHW") def infer_shape(self, input_x, scale, bias, mean, variance): - validator.check_integer("scale rank", len(scale), 1, Rel.EQ, self.name) + validator.check_equal_int(len(scale), 1, "scale rank", self.name) validator.check("scale shape", scale, "bias shape", bias, Rel.EQ, self.name) validator.check("scale shape[0]", scale[0], "input_x shape[1]", input_x[1], Rel.EQ, self.name) - validator.check_integer("mean rank", len(mean), 1, Rel.EQ, self.name) + validator.check_equal_int(len(mean), 1, "mean rank", self.name) validator.check("mean shape", mean, "variance shape", variance, Rel.EQ, self.name) validator.check("mean shape", mean, "scale shape", scale, Rel.EQ, self.name) return (input_x, scale, scale, scale, scale, scale) @@ -757,7 +757,7 @@ class BNTrainingReduce(PrimitiveWithInfer): self.init_prim_io_names(inputs=['x'], outputs=['sum', 'square_sum']) def infer_shape(self, x_shape): - validator.check_integer("x rank", len(x_shape), 4, Rel.EQ, self.name) + validator.check_equal_int(len(x_shape), 4, "x rank", self.name) return ([x_shape[1]], [x_shape[1]]) def infer_dtype(self, x_type): @@ -822,13 +822,13 @@ class BNTrainingUpdate(PrimitiveWithInfer): self.factor = validator.check_float_range(factor, 0, 1, Rel.INC_BOTH, 'factor', 'BNTrainingUpdate') def infer_shape(self, x, sum, square_sum, scale, b, mean, variance): - validator.check_integer("x rank", len(x), 4, Rel.EQ, self.name) - validator.check_integer("sum rank", len(sum), 1, Rel.EQ, self.name) - validator.check_integer("square_sum rank", len(square_sum), 1, Rel.EQ, self.name) - validator.check_integer("scale rank", len(scale), 1, Rel.EQ, self.name) - validator.check_integer("b rank", len(b), 1, Rel.EQ, self.name) - validator.check_integer("mean rank", len(mean), 1, Rel.EQ, self.name) - validator.check_integer("variance rank", len(variance), 1, Rel.EQ, self.name) + validator.check_equal_int(len(x), 4, "x rank", self.name) + validator.check_equal_int(len(sum), 1, "sum rank", self.name) + validator.check_equal_int(len(square_sum), 1, "square_sum rank", self.name) + validator.check_equal_int(len(scale), 1, "scale rank", self.name) + validator.check_equal_int(len(b), 1, "b rank", self.name) + validator.check_equal_int(len(mean), 1, "mean rank", self.name) + validator.check_equal_int(len(variance), 1, "variance rank", self.name) validator.check("sum shape", sum, "x_shape[1]", x[1], Rel.EQ, self.name) validator.check("square_sum shape", square_sum, "sum", sum, Rel.EQ, self.name) validator.check("scale shape", scale, "x_shape[1]", x[1], Rel.EQ, self.name) @@ -904,11 +904,11 @@ class BatchNorm(PrimitiveWithInfer): outputs=['y', 'batch_mean', 'batch_variance', 'reserve_space_1', 'reserve_space_2']) def infer_shape(self, input_x, scale, bias, mean, variance): - validator.check_integer("scale rank", len(scale), 1, Rel.EQ, self.name) + validator.check_equal_int(len(scale), 1, "scale rank", self.name) validator.check("scale shape", scale, "bias shape", bias, Rel.EQ, self.name) validator.check("scale shape[0]", scale[0], "input_x shape[1]", input_x[1], Rel.EQ, self.name) if not self.is_training: - validator.check_integer("mean rank", len(mean), 1, Rel.EQ, self.name) + validator.check_equal_int(len(mean), 1, "mean rank", self.name) validator.check("mean shape", mean, "variance shape", variance, Rel.EQ, self.name) validator.check("mean shape", mean, "scale shape", scale, Rel.EQ, self.name) return (input_x, scale, scale, scale, scale) @@ -1010,7 +1010,7 @@ class Conv2D(PrimitiveWithInfer): if isinstance(pad, int): pad = (pad,) * 4 else: - validator.check_integer('pad size', len(pad), 4, Rel.EQ, self.name) + validator.check_equal_int(len(pad), 4, 'pad size', self.name) self.padding = pad self.pad_mode = validator.check_string(pad_mode, ['valid', 'same', 'pad'], 'pad_mode', self.name) @@ -1020,15 +1020,15 @@ class Conv2D(PrimitiveWithInfer): for item in pad: validator.check_non_negative_int(item, 'pad item', self.name) - self.mode = validator.check_integer('mode', mode, 1, Rel.EQ, self.name) + self.mode = validator.check_equal_int(mode, 1, 'mode', self.name) self.add_prim_attr('data_format', "NCHW") self.out_channel = validator.check_positive_int(out_channel, 'out_channel', self.name) self.group = validator.check_positive_int(group, 'group', self.name) self.add_prim_attr('offset_a', 0) def infer_shape(self, x_shape, w_shape, b_shape=None): - validator.check_integer("weight rank", len(w_shape), 4, Rel.EQ, self.name) - validator.check_integer("x rank", len(x_shape), 4, Rel.EQ, self.name) + validator.check_equal_int(len(w_shape), 4, "weight rank", self.name) + validator.check_equal_int(len(x_shape), 4, "x rank", self.name) validator.check(f"x_shape[1] / group", x_shape[1] // self.group, "w_shape[1]", w_shape[1], Rel.EQ, self.name) validator.check('out_channel', self.out_channel, 'w_shape[0]', w_shape[0], Rel.EQ, self.name) validator.check('kernel_size', self.kernel_size, 'w_shape[2:4]', tuple(w_shape[2:4]), Rel.EQ, self.name) @@ -1150,7 +1150,7 @@ class DepthwiseConv2dNative(PrimitiveWithInfer): if isinstance(pad, int): pad = (pad,) * 4 else: - validator.check_integer('pad size', len(pad), 4, Rel.EQ, self.name) + validator.check_equal_int(len(pad), 4, 'pad size', self.name) self.padding = pad self.pad_mode = validator.check_string(pad_mode, ['valid', 'same', 'pad'], 'pad_mode', self.name) if pad_mode != 'pad' and pad != (0, 0, 0, 0): @@ -1158,15 +1158,15 @@ class DepthwiseConv2dNative(PrimitiveWithInfer): if self.pad_mode == 'pad': for item in pad: validator.check_non_negative_int(item, 'pad item', self.name) - self.mode = validator.check_integer("mode", mode, 3, Rel.EQ, self.name) + self.mode = validator.check_equal_int(mode, 3, "mode", self.name) self.add_prim_attr('data_format', "NCHW") self.channel_multiplier = validator.check_positive_int(channel_multiplier, "channel_multiplier", self.name) self.group = validator.check_positive_int(group, "group", self.name) self.add_prim_attr('offset_a', 0) def infer_shape(self, x_shape, w_shape, b_shape=None): - validator.check_integer("weight rank", len(w_shape), 4, Rel.EQ, self.name) - validator.check_integer("x rank", len(x_shape), 4, Rel.EQ, self.name) + validator.check_equal_int(len(w_shape), 4, "weight rank", self.name) + validator.check_equal_int(len(x_shape), 4, "x rank", self.name) validator.check("x_shape[1]", x_shape[1], "w_shape[1]", w_shape[1], Rel.EQ, self.name) validator.check('kernel_size', self.kernel_size, 'w_shape[2:4]', tuple(w_shape[2:4]), Rel.EQ, self.name) @@ -1250,7 +1250,7 @@ class _Pool(PrimitiveWithInfer): self.add_prim_attr("strides", self.strides) def infer_shape(self, x_shape): - validator.check_integer("x rank", len(x_shape), 4, Rel.EQ, self.name) + validator.check_equal_int(len(x_shape), 4, "x rank", self.name) batch, channel, input_h, input_w = x_shape if self.is_maxpoolwithargmax: _, kernel_h, kernel_w, _ = self.ksize @@ -1536,7 +1536,7 @@ class Conv2DBackpropInput(PrimitiveWithInfer): if isinstance(pad, int): pad = (pad,) * 4 else: - validator.check_integer('pad size', len(pad), 4, Rel.EQ, self.name) + validator.check_equal_int(len(pad), 4, 'pad size', self.name) self.padding = pad self.pad_mode = validator.check_string(pad_mode, ['valid', 'same', 'pad'], 'pad_mode', self.name) if pad_mode != 'pad' and pad != (0, 0, 0, 0): @@ -1547,7 +1547,7 @@ class Conv2DBackpropInput(PrimitiveWithInfer): pad_mode = pad_mode.upper() self.add_prim_attr('pad_mode', pad_mode) - self.mode = validator.check_integer('mode', mode, 1, Rel.EQ, self.name) + self.mode = validator.check_equal_int(mode, 1, 'mode', self.name) self.group = validator.check_positive_int(group, 'group', self.name) self.add_prim_attr('data_format', "NCHW") if pad_list: @@ -1624,8 +1624,8 @@ class BiasAdd(PrimitiveWithInfer): self.add_prim_attr('data_format', 'NCHW') def infer_shape(self, x_shape, b_shape): - validator.check_integer("x rank", len(x_shape), 2, Rel.GE, self.name) - validator.check_integer("bias rank", len(b_shape), 1, Rel.EQ, self.name) + validator.check_int(len(x_shape), 2, Rel.GE, "x rank", self.name) + validator.check_equal_int(len(b_shape), 1, "bias rank", self.name) validator.check("b_shape[0]", b_shape[0], "x_shape[1]", x_shape[1], Rel.EQ, self.name) return x_shape @@ -2007,10 +2007,10 @@ class RNNTLoss(PrimitiveWithInfer): outputs=['costs', 'grads']) def infer_shape(self, acts_shape, labels_shape, input_length_shape, label_length_shape): - validator.check_integer('acts_rank', len(acts_shape), 4, Rel.EQ, self.name) - validator.check_integer('labels_rank', len(labels_shape), 2, Rel.EQ, self.name) - validator.check_integer('input_length_rank', len(input_length_shape), 1, Rel.EQ, self.name) - validator.check_integer('label_length_rank', len(label_length_shape), 1, Rel.EQ, self.name) + validator.check_equal_int(len(acts_shape), 4, 'acts_rank', self.name) + validator.check_equal_int(len(labels_shape), 2, 'labels_rank', self.name) + validator.check_equal_int(len(input_length_shape), 1, 'input_length_rank', self.name) + validator.check_equal_int(len(label_length_shape), 1, 'label_length_rank', self.name) validator.check('labels shape[0]', labels_shape[0], 'acts shape[0]', acts_shape[0], Rel.EQ, self.name) validator.check('labels shape[1]', labels_shape[1], 'acts shape[2]-1', acts_shape[2]-1, Rel.EQ, self.name) validator.check('input_length size', input_length_shape[0], 'acts shape[0]', acts_shape[0], Rel.EQ, self.name) @@ -2080,11 +2080,11 @@ class SGD(PrimitiveWithInfer): def infer_shape(self, parameters_shape, gradient_shape, learning_rate_shape, accum_shape, momentum_shape, stat_shape): validator.check_positive_int(len(parameters_shape), "parameters rank", self.name) - validator.check_integer(f'gradient rank', len(gradient_shape), 0, Rel.GE, self.name) - validator.check_integer(f'learning rate rank', len(learning_rate_shape), 0, Rel.GE, self.name) + validator.check_int(len(gradient_shape), 0, Rel.GE, f'gradient rank', self.name) + validator.check_int(len(learning_rate_shape), 0, Rel.GE, f'learning rate rank', self.name) validator.check_positive_int(len(accum_shape), "accumulation rank", self.name) - validator.check_integer(f'momentum rank', len(momentum_shape), 0, Rel.GE, self.name) - validator.check_integer(f'stat rank', len(stat_shape), 0, Rel.GE, self.name) + validator.check_int(len(momentum_shape), 0, Rel.GE, f'momentum rank', self.name) + validator.check_int(len(stat_shape), 0, Rel.GE, f'stat rank', self.name) validator.check("gradient shape", gradient_shape, "stat shape", stat_shape, Rel.EQ, self.name) return parameters_shape @@ -2780,17 +2780,17 @@ class LSTM(PrimitiveWithInfer): def infer_shape(self, x_shape, h_shape, c_shape, w_shape): # (seq, batch_size, feature) - validator.check_integer("x rank", len(x_shape), 3, Rel.EQ, self.name) - validator.check_integer("x[2]", x_shape[2], self.input_size, Rel.EQ, self.name) + validator.check_equal_int(len(x_shape), 3, "x rank", self.name) + validator.check_equal_int(x_shape[2], self.input_size, "x[2]", self.name) # h and c should be same shape - validator.check_integer("h rank", len(h_shape), 3, Rel.EQ, self.name) + validator.check_equal_int(len(h_shape), 3, "h rank", self.name) validator.check("h_shape", h_shape, "c_shape", c_shape, Rel.EQ, self.name) # (num_layers * num_directions, batch, hidden_size) - validator.check_integer("h[0]", h_shape[0], self.num_layers * self.num_directions, Rel.EQ, self.name) - validator.check_integer("h[1]", h_shape[1], x_shape[1], Rel.EQ, self.name) - validator.check_integer("h[2]", h_shape[2], self.hidden_size, Rel.EQ, self.name) + validator.check_int(h_shape[0], self.num_layers * self.num_directions, Rel.EQ, "h[0]", self.name) + validator.check_equal_int(h_shape[1], x_shape[1], "h[1]", self.name) + validator.check_int(h_shape[2], self.hidden_size, Rel.EQ, "h[2]", self.name) y_shape = (x_shape[0], x_shape[1], self.hidden_size * self.num_directions) @@ -2918,7 +2918,7 @@ class Pad(PrimitiveWithInfer): def infer_shape(self, x): paddings = np.array(self.paddings) - validator.check_integer('paddings.shape', paddings.size, len(x) * 2, Rel.EQ, self.name) + validator.check_int(paddings.size, len(x) * 2, Rel.EQ, 'paddings.shape', self.name) if not np.all(paddings >= 0): raise ValueError('All elements of paddings must be >= 0.') y_shape = () @@ -2992,7 +2992,7 @@ class MirrorPad(PrimitiveWithInfer): x_shape = list(input_x['shape']) paddings_value = paddings['value'].asnumpy() paddings_size = paddings_value.size - validator.check_integer('paddings.shape', paddings_size, len(x_shape) * 2, Rel.EQ, self.name) + validator.check_int(paddings_size, len(x_shape) * 2, Rel.EQ, 'paddings.shape', self.name) if not np.all(paddings_value >= 0): raise ValueError('All elements of paddings must be >= 0.') adjust = 0 @@ -3276,7 +3276,7 @@ class FusedSparseAdam(PrimitiveWithInfer): beta1_shape, beta2_shape, epsilon_shape, grad_shape, indices_shape): validator.check("var_shape", var_shape, "m_shape", m_shape, Rel.EQ, self.name) validator.check("var_shape", var_shape, "v_shape", v_shape, Rel.EQ, self.name) - validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name) + validator.check_int(len(indices_shape), 1, Rel.EQ, "indices rank", self.name) validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name) if len(var_shape) > 1 and grad_shape != indices_shape + var_shape[1:]: raise ValueError(f"For '{self.name}', the shape of updates should be [] or " @@ -3409,7 +3409,7 @@ class FusedSparseLazyAdam(PrimitiveWithInfer): beta1_shape, beta2_shape, epsilon_shape, grad_shape, indices_shape): validator.check("var_shape", var_shape, "m_shape", m_shape, Rel.EQ, self.name) validator.check("var_shape", var_shape, "v_shape", v_shape, Rel.EQ, self.name) - validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name) + validator.check_int(len(indices_shape), 1, Rel.EQ, "indices rank", self.name) validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name) if len(var_shape) > 1 and grad_shape != indices_shape + var_shape[1:]: raise ValueError(f"For '{self.name}', the shape of updates should be [] or " @@ -3513,7 +3513,7 @@ class FusedSparseFtrl(PrimitiveWithInfer): validator.check('var shape', var_shape, 'linear shape', linear_shape, Rel.EQ, self.name) if len(var_shape) > 1: validator.check('var_shape[1:]', var_shape[1:], 'grad_shape[1:]', grad_shape[1:], Rel.EQ, self.name) - validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name) + validator.check_int(len(indices_shape), 1, Rel.EQ, "indices rank", self.name) validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name) return [1], [1], [1] @@ -3602,7 +3602,7 @@ class FusedSparseProximalAdagrad(PrimitiveWithInfer): self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name) def infer_shape(self, var_shape, accum_shape, lr_shape, l1_shape, l2_shape, grad_shape, indices_shape): - validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name) + validator.check_int(len(indices_shape), 1, Rel.EQ, "indices rank", self.name) return [1], [1] def infer_dtype(self, var_dtype, accum_dtype, lr_dtype, l1_dtype, l2_dtype, grad_dtype, indices_dtype): @@ -3869,25 +3869,25 @@ class ApplyAdaMax(PrimitiveWithInfer): validator.check("v_shape", v_shape, "var_shape", var_shape, Rel.EQ, self.name) validator.check("grad_shape", grad_shape, "var_shape", var_shape, Rel.EQ, self.name) beta1_power_shp_len = len(beta1_power_shape) - validator.check_integer("beta1 power's rank", beta1_power_shp_len, 1, Rel.LE, self.name) + validator.check_int(beta1_power_shp_len, 1, Rel.LE, "beta1 power's rank", self.name) if beta1_power_shp_len == 1: - validator.check_integer("beta1_power_shape[0]", beta1_power_shape[0], 1, Rel.EQ, self.name) + validator.check_int(beta1_power_shape[0], 1, Rel.EQ, "beta1_power_shape[0]", self.name) lr_shp_len = len(lr_shape) - validator.check_integer("lr's rank", lr_shp_len, 1, Rel.LE, self.name) + validator.check_int(lr_shp_len, 1, Rel.LE, "lr's rank", self.name) if lr_shp_len == 1: - validator.check_integer("lr_shape[0]", lr_shape[0], 1, Rel.EQ, self.name) + validator.check_int(lr_shape[0], 1, Rel.EQ, "lr_shape[0]", self.name) beta1_shp_len = len(beta1_shape) - validator.check_integer("beta1's rank", beta1_shp_len, 1, Rel.LE, self.name) + validator.check_int(beta1_shp_len, 1, Rel.LE, "beta1's rank", self.name) if beta1_shp_len == 1: - validator.check_integer("beta1_shape[0]", beta1_shape[0], 1, Rel.EQ, self.name) + validator.check_int(beta1_shape[0], 1, Rel.EQ, "beta1_shape[0]", self.name) beta2_shp_len = len(beta2_shape) - validator.check_integer("beta2's rank", beta2_shp_len, 1, Rel.LE, self.name) + validator.check_int(beta2_shp_len, 1, Rel.LE, "beta2's rank", self.name) if beta2_shp_len == 1: - validator.check_integer("beta2_shape[0]", beta2_shape[0], 1, Rel.EQ, self.name) + validator.check_int(beta2_shape[0], 1, Rel.EQ, "beta2_shape[0]", self.name) epsilon_shp_len = len(epsilon_shape) - validator.check_integer("epsilon's rank", epsilon_shp_len, 1, Rel.LE, self.name) + validator.check_int(epsilon_shp_len, 1, Rel.LE, "epsilon's rank", self.name) if epsilon_shp_len == 1: - validator.check_integer("epsilon_shape[0]", epsilon_shape[0], 1, Rel.EQ, self.name) + validator.check_int(epsilon_shape[0], 1, Rel.EQ, "epsilon_shape[0]", self.name) return var_shape, m_shape, v_shape def infer_dtype(self, var_dtype, m_dtype, v_dtype, beta1_power_dtype, lr_dtype, @@ -3985,17 +3985,17 @@ class ApplyAdadelta(PrimitiveWithInfer): validator.check("accum_update_shape", accum_update_shape, "var_shape", var_shape, Rel.EQ, self.name) validator.check("grad_shape", grad_shape, "var_shape", var_shape, Rel.EQ, self.name) lr_shp_len = len(lr_shape) - validator.check_integer("lr's rank", lr_shp_len, 1, Rel.LE, self.name) + validator.check_int(lr_shp_len, 1, Rel.LE, "lr's rank", self.name) if lr_shp_len == 1: - validator.check_integer("lr_shape[0]", lr_shape[0], 1, Rel.EQ, self.name) + validator.check_int(lr_shape[0], 1, Rel.EQ, "lr_shape[0]", self.name) rho_shp_len = len(rho_shape) - validator.check_integer("rho's rank", rho_shp_len, 1, Rel.LE, self.name) + validator.check_int(rho_shp_len, 1, Rel.LE, "rho's rank", self.name) if rho_shp_len == 1: - validator.check_integer("rho_shape[0]", rho_shape[0], 1, Rel.EQ, self.name) + validator.check_int(rho_shape[0], 1, Rel.EQ, "rho_shape[0]", self.name) epsilon_shp_len = len(epsilon_shape) - validator.check_integer("lepsilon's rank", epsilon_shp_len, 1, Rel.LE, self.name) + validator.check_int(epsilon_shp_len, 1, Rel.LE, "lepsilon's rank", self.name) if epsilon_shp_len == 1: - validator.check_integer("epsilon_shape[0]", epsilon_shape[0], 1, Rel.EQ, self.name) + validator.check_int(epsilon_shape[0], 1, Rel.EQ, "epsilon_shape[0]", self.name) return var_shape, accum_shape, accum_update_shape def infer_dtype(self, var_dtype, accum_dtype, accum_update_dtype, lr_dtype, rho_dtype, @@ -4077,9 +4077,9 @@ class ApplyAdagrad(PrimitiveWithInfer): validator.check('accum shape', accum_shape, 'var shape', var_shape, Rel.EQ, self.name) validator.check('grad shape', grad_shape, 'var shape', var_shape, Rel.EQ, self.name) lr_shp_len = len(lr_shape) - validator.check_integer("lr's rank", lr_shp_len, 1, Rel.LE, self.name) + validator.check_int(lr_shp_len, 1, Rel.LE, "lr's rank", self.name) if lr_shp_len == 1: - validator.check_integer("lr_shape[0]", lr_shape[0], 1, Rel.EQ, self.name) + validator.check_int(lr_shape[0], 1, Rel.EQ, "lr_shape[0]", self.name) return var_shape, accum_shape def infer_dtype(self, var_dtype, accum_dtype, lr_dtype, grad_dtype): @@ -4161,9 +4161,9 @@ class ApplyAdagradV2(PrimitiveWithInfer): validator.check('var shape', var_shape, 'accum shape', accum_shape, Rel.EQ, self.name) validator.check('var shape', var_shape, 'grad shape', grad_shape, Rel.EQ, self.name) lr_shp_len = len(lr_shape) - validator.check_integer("lr's rank", lr_shp_len, 1, Rel.LE, self.name) + validator.check_int(lr_shp_len, 1, Rel.LE, "lr's rank", self.name) if lr_shp_len == 1: - validator.check_integer("lr_shape[0]", lr_shape[0], 1, Rel.EQ, self.name) + validator.check_int(lr_shape[0], 1, Rel.EQ, "lr_shape[0]", self.name) return var_shape, accum_shape def infer_dtype(self, var_dtype, accum_dtype, lr_dtype, grad_dtype): @@ -4249,7 +4249,7 @@ class SparseApplyAdagrad(PrimitiveWithInfer): validator.check('len of var shape', len(var_shape), 'len of grad shape', len(grad_shape), Rel.EQ, self.name) if len(var_shape) > 1: validator.check('var_shape[1:]', var_shape[1:], 'grad_shape[1:]', grad_shape[1:], Rel.EQ, self.name) - validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name) + validator.check_int(len(indices_shape), 1, Rel.EQ, "indices rank", self.name) validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name) return var_shape, accum_shape @@ -4338,7 +4338,7 @@ class SparseApplyAdagradV2(PrimitiveWithInfer): validator.check('len of var shape', len(var_shape), 'len of grad shape', len(grad_shape), Rel.EQ, self.name) if len(var_shape) > 1: validator.check('var_shape[1:]', var_shape[1:], 'grad_shape[1:]', grad_shape[1:], Rel.EQ, self.name) - validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name) + validator.check_int(len(indices_shape), 1, Rel.EQ, "indices rank", self.name) validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name) return var_shape, accum_shape @@ -4428,17 +4428,17 @@ class ApplyProximalAdagrad(PrimitiveWithInfer): validator.check('accum shape', accum_shape, 'var shape', var_shape, Rel.EQ, self.name) validator.check('grad shape', grad_shape, 'var shape', var_shape, Rel.EQ, self.name) lr_shp_len = len(lr_shape) - validator.check_integer("lr's rank", lr_shp_len, 1, Rel.LE, self.name) + validator.check_int(lr_shp_len, 1, Rel.LE, "lr's rank", self.name) if lr_shp_len == 1: - validator.check_integer("lr_shape[0]", lr_shape[0], 1, Rel.EQ, self.name) + validator.check_int(lr_shape[0], 1, Rel.EQ, "lr_shape[0]", self.name) l1_shp_len = len(l1_shape) - validator.check_integer("l1's rank", l1_shp_len, 1, Rel.LE, self.name) + validator.check_int(l1_shp_len, 1, Rel.LE, "l1's rank", self.name) if l1_shp_len == 1: - validator.check_integer("l1_shape[0]", l1_shape[0], 1, Rel.EQ, self.name) + validator.check_int(l1_shape[0], 1, Rel.EQ, "l1_shape[0]", self.name) l2_shp_len = len(l2_shape) - validator.check_integer("l2's rank", l2_shp_len, 1, Rel.LE, self.name) + validator.check_int(l2_shp_len, 1, Rel.LE, "l2's rank", self.name) if l2_shp_len == 1: - validator.check_integer("l2_shape[0]", l2_shape[0], 1, Rel.EQ, self.name) + validator.check_int(l2_shape[0], 1, Rel.EQ, "l2_shape[0]", self.name) return var_shape, accum_shape def infer_dtype(self, var_dtype, accum_dtype, lr_dtype, l1_dtype, l2_dtype, grad_dtype): @@ -4532,7 +4532,7 @@ class SparseApplyProximalAdagrad(PrimitiveWithCheck): self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name) def check_shape(self, var_shape, accum_shape, lr_shape, l1_shape, l2_shape, grad_shape, indices_shape): - validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name) + validator.check_int(len(indices_shape), 1, Rel.EQ, "indices rank", self.name) def check_dtype(self, var_dtype, accum_dtype, lr_dtype, l1_dtype, l2_dtype, grad_dtype, indices_dtype): args = {'var': var_dtype, 'accum': accum_dtype, 'grad': grad_dtype} @@ -4623,21 +4623,21 @@ class ApplyAddSign(PrimitiveWithInfer): validator.check('m_shape', m_shape, 'var_shape', var_shape, Rel.EQ, self.name) validator.check('grad_shape', grad_shape, 'var_shape', var_shape, Rel.EQ, self.name) lr_shape_len = len(lr_shape) - validator.check_integer("lr's rank", lr_shape_len, 1, Rel.LE, self.name) + validator.check_int(lr_shape_len, 1, Rel.LE, "lr's rank", self.name) if lr_shape_len == 1: - validator.check_integer("lr_shape[0]", lr_shape[0], 1, Rel.EQ, self.name) + validator.check_int(lr_shape[0], 1, Rel.EQ, "lr_shape[0]", self.name) alpha_shape_len = len(alpha_shape) - validator.check_integer("alpha's rank", alpha_shape_len, 1, Rel.LE, self.name) + validator.check_int(alpha_shape_len, 1, Rel.LE, "alpha's rank", self.name) if alpha_shape_len == 1: - validator.check_integer("alpha_shape[0]", alpha_shape[0], 1, Rel.EQ, self.name) + validator.check_int(alpha_shape[0], 1, Rel.EQ, "alpha_shape[0]", self.name) sign_decay_shape_len = len(sign_decay_shape) - validator.check_integer("sign_decay's rank", sign_decay_shape_len, 1, Rel.LE, self.name) + validator.check_int(sign_decay_shape_len, 1, Rel.LE, "sign_decay's rank", self.name) if sign_decay_shape_len == 1: - validator.check_integer("sign_decay_shape[0]", sign_decay_shape[0], 1, Rel.EQ, self.name) + validator.check_int(sign_decay_shape[0], 1, Rel.EQ, "sign_decay_shape[0]", self.name) beta_shape_len = len(beta_shape) - validator.check_integer("beta's rank", beta_shape_len, 1, Rel.LE, self.name) + validator.check_int(beta_shape_len, 1, Rel.LE, "beta's rank", self.name) if beta_shape_len == 1: - validator.check_integer("beta_shape[0]", beta_shape[0], 1, Rel.EQ, self.name) + validator.check_int(beta_shape[0], 1, Rel.EQ, "beta_shape[0]", self.name) return var_shape, m_shape def infer_dtype(self, var_dtype, m_dtype, lr_dtype, alpha_dtype, sign_decay_dtype, beta_dtype, grad_dtype): @@ -4732,21 +4732,21 @@ class ApplyPowerSign(PrimitiveWithInfer): validator.check('m_shape', m_shape, 'var_shape', var_shape, Rel.EQ, self.name) validator.check('grad_shape', grad_shape, 'var_shape', var_shape, Rel.EQ, self.name) lr_shape_len = len(lr_shape) - validator.check_integer("lr's rank", lr_shape_len, 1, Rel.LE, self.name) + validator.check_int(lr_shape_len, 1, Rel.LE, "lr's rank", self.name) if lr_shape_len == 1: - validator.check_integer("lr_shape[0]", lr_shape[0], 1, Rel.EQ, self.name) + validator.check_int(lr_shape[0], 1, Rel.EQ, "lr_shape[0]", self.name) logbase_shape_len = len(logbase_shape) - validator.check_integer("logbase's rank", logbase_shape_len, 1, Rel.LE, self.name) + validator.check_int(logbase_shape_len, 1, Rel.LE, "logbase's rank", self.name) if logbase_shape_len == 1: - validator.check_integer("logbase_shape[0]", logbase_shape[0], 1, Rel.EQ, self.name) + validator.check_int(logbase_shape[0], 1, Rel.EQ, "logbase_shape[0]", self.name) sign_decay_shape_len = len(sign_decay_shape) - validator.check_integer("sign_decay's rank", sign_decay_shape_len, 1, Rel.LE, self.name) + validator.check_int(sign_decay_shape_len, 1, Rel.LE, "sign_decay's rank", self.name) if sign_decay_shape_len == 1: - validator.check_integer("sign_decay_shape[0]", sign_decay_shape[0], 1, Rel.EQ, self.name) + validator.check_int(sign_decay_shape[0], 1, Rel.EQ, "sign_decay_shape[0]", self.name) beta_shape_len = len(beta_shape) - validator.check_integer("beta's rank", beta_shape_len, 1, Rel.LE, self.name) + validator.check_int(beta_shape_len, 1, Rel.LE, "beta's rank", self.name) if beta_shape_len == 1: - validator.check_integer("beta_shape[0]", beta_shape[0], 1, Rel.EQ, self.name) + validator.check_int(beta_shape[0], 1, Rel.EQ, "beta_shape[0]", self.name) return var_shape, m_shape def infer_dtype(self, var_dtype, m_dtype, lr_dtype, logbase_dtype, sign_decay_dtype, beta_dtype, grad_dtype): @@ -4812,9 +4812,9 @@ class ApplyGradientDescent(PrimitiveWithInfer): def infer_shape(self, var_shape, alpha_shape, delta_shape): validator.check('delta shape', delta_shape, 'var shape', var_shape, Rel.EQ, self.name) alpha_shape_len = len(alpha_shape) - validator.check_integer("alpha's rank", alpha_shape_len, 1, Rel.LE, self.name) + validator.check_int(alpha_shape_len, 1, Rel.LE, "alpha's rank", self.name) if alpha_shape_len == 1: - validator.check_integer("alpha_shape[0]", alpha_shape[0], 1, Rel.EQ, self.name) + validator.check_int(alpha_shape[0], 1, Rel.EQ, "alpha_shape[0]", self.name) return var_shape def infer_dtype(self, var_dtype, alpha_dtype, delta_dtype): @@ -4887,17 +4887,17 @@ class ApplyProximalGradientDescent(PrimitiveWithInfer): def infer_shape(self, var_shape, alpha_shape, l1_shape, l2_shape, delta_shape): validator.check('delta shape', delta_shape, 'var shape', var_shape, Rel.EQ, self.name) alpha_shape_len = len(alpha_shape) - validator.check_integer("alpha's rank", alpha_shape_len, 1, Rel.LE, self.name) + validator.check_int(alpha_shape_len, 1, Rel.LE, "alpha's rank", self.name) if alpha_shape_len == 1: - validator.check_integer("alpha_shape[0]", alpha_shape[0], 1, Rel.EQ, self.name) + validator.check_int(alpha_shape[0], 1, Rel.EQ, "alpha_shape[0]", self.name) l1_shape_len = len(l1_shape) - validator.check_integer("l1's rank", l1_shape_len, 1, Rel.LE, self.name) + validator.check_int(l1_shape_len, 1, Rel.LE, "l1's rank", self.name) if l1_shape_len == 1: - validator.check_integer("l1_shape[0]", l1_shape[0], 1, Rel.EQ, self.name) + validator.check_int(l1_shape[0], 1, Rel.EQ, "l1_shape[0]", self.name) l2_shape_len = len(l2_shape) - validator.check_integer("l2's rank", l2_shape_len, 1, Rel.LE, self.name) + validator.check_int(l2_shape_len, 1, Rel.LE, "l2's rank", self.name) if l2_shape_len == 1: - validator.check_integer("l2_shape[0]", l2_shape[0], 1, Rel.EQ, self.name) + validator.check_int(l2_shape[0], 1, Rel.EQ, "l2_shape[0]", self.name) return var_shape def infer_dtype(self, var_dtype, alpha_dtype, l1_dtype, l2_dtype, delta_dtype): @@ -4965,13 +4965,13 @@ class LARSUpdate(PrimitiveWithInfer): validator.check("norm weight shape", norm_weight_shape, "norm gradient shape", norm_gradient_shape, Rel.EQ, self.name) shp_len = len(weight_decay_shape) - validator.check_integer("weight decay's rank", shp_len, 1, Rel.LE, self.name) + validator.check_int(shp_len, 1, Rel.LE, "weight decay's rank", self.name) if shp_len == 1: - validator.check_integer("weight_decay_shape[0]", weight_decay_shape[0], 1, Rel.EQ, self.name) + validator.check_int(weight_decay_shape[0], 1, Rel.EQ, "weight_decay_shape[0]", self.name) shp_len = len(learning_rate_shape) - validator.check_integer("learning rate's rank", shp_len, 1, Rel.LE, self.name) + validator.check_int(shp_len, 1, Rel.LE, "learning rate's rank", self.name) if shp_len == 1: - validator.check_integer("learning_rate_shape[0]", learning_rate_shape[0], 1, Rel.EQ, self.name) + validator.check_int(learning_rate_shape[0], 1, Rel.EQ, "learning_rate_shape[0]", self.name) return weight_shape def infer_dtype(self, weight_dtype, gradient_dtype, norm_weight_dtype, norm_gradient_dtype, @@ -5155,7 +5155,7 @@ class SparseApplyFtrl(PrimitiveWithCheck): validator.check('var shape', var_shape, 'linear shape', linear_shape, Rel.EQ, self.name) if len(var_shape) > 1: validator.check('var_shape[1:]', var_shape[1:], 'grad_shape[1:]', grad_shape[1:], Rel.EQ, self.name) - validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name) + validator.check_int(len(indices_shape), 1, Rel.EQ, "indices rank", self.name) validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name) def check_dtype(self, var_dtype, accum_dtype, linear_dtype, grad_dtype, indices_dtype): @@ -5251,7 +5251,7 @@ class SparseApplyFtrlV2(PrimitiveWithInfer): validator.check('var shape', var_shape, 'linear shape', linear_shape, Rel.EQ, self.name) if len(var_shape) > 1: validator.check('var_shape[1:]', var_shape[1:], 'grad_shape[1:]', grad_shape[1:], Rel.EQ, self.name) - validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name) + validator.check_int(len(indices_shape), 1, Rel.EQ, "indices rank", self.name) validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name) return var_shape, accum_shape, linear_shape @@ -5288,7 +5288,7 @@ class Dropout(PrimitiveWithInfer): self.keep_prob = validator.check_float_range(keep_prob, 0, 1, Rel.INC_RIGHT, "keep_prob", self.name) def infer_shape(self, x_shape): - validator.check_integer("x_shape", len(x_shape), 1, Rel.GE, self.name) + validator.check_int(len(x_shape), 1, Rel.GE, "x_shape", self.name) mask_shape = x_shape return x_shape, mask_shape @@ -5352,11 +5352,11 @@ class CTCLoss(PrimitiveWithInfer): self.ignore_longer_outputs_than_inputs_ = ignore_longer_outputs_than_inputs def infer_shape(self, inputs, labels_indices, labels_values, sequence_length): - validator.check_integer("inputs rank", len(inputs), 3, Rel.EQ, self.name) - validator.check_integer("labels_indices rank", len(labels_indices), 2, Rel.EQ, self.name) - validator.check_integer("labels_indices dim one", labels_indices[1], 2, Rel.EQ, self.name) - validator.check_integer("labels_values rank", len(labels_values), 1, Rel.EQ, self.name) - validator.check_integer("sequence_length rank", len(sequence_length), 1, Rel.EQ, self.name) + validator.check_int(len(inputs), 3, Rel.EQ, "inputs rank", self.name) + validator.check_int(len(labels_indices), 2, Rel.EQ, "labels_indices rank", self.name) + validator.check_int(labels_indices[1], 2, Rel.EQ, "labels_indices dim one", self.name) + validator.check_int(len(labels_values), 1, Rel.EQ, "labels_values rank", self.name) + validator.check_int(len(sequence_length), 1, Rel.EQ, "sequence_length rank", self.name) validator.check('labels_indices size', labels_indices[0], 'labels_values size', labels_values[0], Rel.EQ, self.name) validator.check('inputs batch_size', inputs[1], 'sequence_length batch_size', @@ -5422,8 +5422,8 @@ class CTCGreedyDecoder(PrimitiveWithInfer): self.merge_repeated = validator.check_value_type("merge_repeated", merge_repeated, [bool], self.name) def infer_shape(self, inputs_shape, sequence_length_shape): - validator.check_integer("inputs rank", len(inputs_shape), 3, Rel.EQ, self.name) - validator.check_integer("sequence_length rank", len(sequence_length_shape), 1, Rel.EQ, self.name) + validator.check_int(len(inputs_shape), 3, Rel.EQ, "inputs rank", self.name) + validator.check_int(len(sequence_length_shape), 1, Rel.EQ, "sequence_length rank", self.name) validator.check('inputs batch_size', inputs_shape[1], 'sequence_length batch_size', sequence_length_shape[0], Rel.EQ, self.name) total_decoded_outputs = -1 @@ -5517,11 +5517,11 @@ class BasicLSTMCell(PrimitiveWithInfer): self.add_prim_attr("io_format", "ND") def infer_shape(self, x_shape, h_shape, c_shape, w_shape, b_shape): - validator.check_integer("x rank", len(x_shape), 2, Rel.EQ, self.name) - validator.check_integer("h rank", len(h_shape), 2, Rel.EQ, self.name) - validator.check_integer("c rank", len(c_shape), 2, Rel.EQ, self.name) - validator.check_integer("w rank", len(w_shape), 2, Rel.EQ, self.name) - validator.check_integer("b rank", len(b_shape), 1, Rel.EQ, self.name) + validator.check_int(len(x_shape), 2, Rel.EQ, "x rank", self.name) + validator.check_int(len(h_shape), 2, Rel.EQ, "h rank", self.name) + validator.check_int(len(c_shape), 2, Rel.EQ, "c rank", self.name) + validator.check_int(len(w_shape), 2, Rel.EQ, "w rank", self.name) + validator.check_int(len(b_shape), 1, Rel.EQ, "b rank", self.name) validator.check("x_shape[0]", x_shape[0], "h_shape[0]", h_shape[0], Rel.EQ, self.name) validator.check("c_shape[0]", c_shape[0], "h_shape[0]", h_shape[0], Rel.EQ, self.name) validator.check("c_shape[1]", c_shape[1], "h_shape[1]", h_shape[1], Rel.EQ, self.name) @@ -5637,11 +5637,11 @@ class DynamicRNN(PrimitiveWithInfer): self.add_prim_attr("io_format", "ND") def infer_shape(self, x_shape, w_shape, b_shape, seq_shape, h_shape, c_shape): - validator.check_integer("x_shape", len(x_shape), 3, Rel.EQ, self.name) - validator.check_integer("w rank", len(w_shape), 2, Rel.EQ, self.name) - validator.check_integer("b rank", len(b_shape), 1, Rel.EQ, self.name) - validator.check_integer("h_shape", len(h_shape), 3, Rel.EQ, self.name) - validator.check_integer("c_shape", len(c_shape), 3, Rel.EQ, self.name) + validator.check_int(len(x_shape), 3, Rel.EQ, "x_shape", self.name) + validator.check_int(len(w_shape), 2, Rel.EQ, "w rank", self.name) + validator.check_int(len(b_shape), 1, Rel.EQ, "b rank", self.name) + validator.check_int(len(h_shape), 3, Rel.EQ, "h_shape", self.name) + validator.check_int(len(c_shape), 3, Rel.EQ, "c_shape", self.name) if seq_shape is not None: raise ValueError(f"For {self.name}, seq_shape should be None.") @@ -5654,7 +5654,7 @@ class DynamicRNN(PrimitiveWithInfer): validator.check("w_shape[0]", w_shape[0], "input_size + hidden_size", input_size + hidden_size, Rel.EQ, self.name) validator.check("b_shape[0]", b_shape[0], "w_shape[1]", w_shape[1], Rel.EQ, self.name) - validator.check_integer("h_shape[0]", h_shape[0], 1, Rel.EQ, self.name) + validator.check_int(h_shape[0], 1, Rel.EQ, "h_shape[0]", self.name) validator.check("h_shape[1]", h_shape[1], "batch_size", batch_size, Rel.EQ, self.name) validator.check("h_shape[2]", h_shape[2], "hidden_size", hidden_size, Rel.EQ, self.name) validator.check("c_shape", c_shape, "h_shape", h_shape, Rel.EQ, self.name) @@ -5754,5 +5754,5 @@ class LRN(PrimitiveWithInfer): return x_dtype def infer_shape(self, x_shape): - validator.check_integer("x_shape", len(x_shape), 4, Rel.EQ, self.name) + validator.check_int(len(x_shape), 4, Rel.EQ, "x_shape", self.name) return x_shape diff --git a/mindspore/ops/operations/other_ops.py b/mindspore/ops/operations/other_ops.py index 47bf0d84f0..c7fba63528 100644 --- a/mindspore/ops/operations/other_ops.py +++ b/mindspore/ops/operations/other_ops.py @@ -98,16 +98,16 @@ class BoundingBoxEncode(PrimitiveWithInfer): validator.check_value_type("means[%d]" % i, value, [float], self.name) for i, value in enumerate(stds): validator.check_value_type("stds[%d]" % i, value, [float], self.name) - validator.check_integer("means len", len(means), 4, Rel.EQ, self.name) - validator.check_integer("stds len", len(stds), 4, Rel.EQ, self.name) + validator.check_equal_int(len(means), 4, "means len", self.name) + validator.check_equal_int(len(stds), 4, "stds len", self.name) def infer_shape(self, anchor_box, groundtruth_box): validator.check('anchor_box shape[0]', anchor_box[0], 'groundtruth_box shape[0]', groundtruth_box[0], Rel.EQ, self.name) validator.check("anchor_box rank", len(anchor_box), "", 2, Rel.EQ, self.name) validator.check("groundtruth_box rank", len(groundtruth_box), "", 2, Rel.EQ, self.name) - validator.check_integer('anchor_box shape[1]', anchor_box[1], 4, Rel.EQ, self.name) - validator.check_integer('groundtruth_box shape[1]', groundtruth_box[1], 4, Rel.EQ, self.name) + validator.check_equal_int(anchor_box[1], 4, 'anchor_box shape[1]', self.name) + validator.check_equal_int(groundtruth_box[1], 4, 'groundtruth_box shape[1]', self.name) return anchor_box def infer_dtype(self, anchor_box, groundtruth_box): @@ -153,18 +153,18 @@ class BoundingBoxDecode(PrimitiveWithInfer): for i, value in enumerate(stds): validator.check_value_type("stds[%d]" % i, value, [float], self.name) validator.check_value_type('wh_ratio_clip', wh_ratio_clip, [float], self.name) - validator.check_integer("means len", len(means), 4, Rel.EQ, self.name) - validator.check_integer("stds len", len(stds), 4, Rel.EQ, self.name) + validator.check_equal_int(len(means), 4, "means len", self.name) + validator.check_equal_int(len(stds), 4, "stds len", self.name) if max_shape is not None: validator.check_value_type('max_shape', max_shape, [tuple], self.name) - validator.check_integer("max_shape len", len(max_shape), 2, Rel.EQ, self.name) + validator.check_equal_int(len(max_shape), 2, "max_shape len", self.name) def infer_shape(self, anchor_box, deltas): validator.check('anchor_box shape[0]', anchor_box[0], 'deltas shape[0]', deltas[0], Rel.EQ, self.name) validator.check("anchor_box rank", len(anchor_box), "", 2, Rel.EQ, self.name) validator.check("deltas rank", len(deltas), "", 2, Rel.EQ, self.name) - validator.check_integer('anchor_box shape[1]', anchor_box[1], 4, Rel.EQ, self.name) - validator.check_integer('deltas shape[1]', deltas[1], 4, Rel.EQ, self.name) + validator.check_equal_int(anchor_box[1], 4, 'anchor_box shape[1]', self.name) + validator.check_equal_int(deltas[1], 4, 'deltas shape[1]', self.name) return anchor_box def infer_dtype(self, anchor_box, deltas): @@ -272,10 +272,10 @@ class IOU(PrimitiveWithInfer): self.init_prim_io_names(inputs=['anchor_boxes', 'gt_boxes'], outputs=['overlap']) def infer_shape(self, anchor_boxes, gt_boxes): - validator.check_integer('gt_boxes shape[1]', gt_boxes[1], 4, Rel.EQ, self.name) - validator.check_integer('anchor_boxes shape[1]', anchor_boxes[1], 4, Rel.EQ, self.name) - validator.check_integer('anchor_boxes rank', len(anchor_boxes), 2, Rel.EQ, self.name) - validator.check_integer('gt_boxes rank', len(gt_boxes), 2, Rel.EQ, self.name) + validator.check_equal_int(gt_boxes[1], 4, 'gt_boxes shape[1]', self.name) + validator.check_equal_int(anchor_boxes[1], 4, 'anchor_boxes shape[1]', self.name) + validator.check_equal_int(len(anchor_boxes), 2, 'anchor_boxes rank', self.name) + validator.check_equal_int(len(gt_boxes), 2, 'gt_boxes rank', self.name) iou = [gt_boxes[0], anchor_boxes[0]] return iou diff --git a/mindspore/ops/operations/random_ops.py b/mindspore/ops/operations/random_ops.py index 4a5e03accb..44b03d2acc 100644 --- a/mindspore/ops/operations/random_ops.py +++ b/mindspore/ops/operations/random_ops.py @@ -356,8 +356,8 @@ class RandomChoiceWithMask(PrimitiveWithInfer): Validator.check_value_type('seed2', seed2, [int], self.name) def infer_shape(self, x_shape): - Validator.check_integer("input_x rank", len(x_shape), 1, Rel.GE, self.name) - Validator.check_integer("input_x rank", len(x_shape), 5, Rel.LE, self.name) + Validator.check_int(len(x_shape), 1, Rel.GE, "input_x rank", self.name) + Validator.check_int(len(x_shape), 5, Rel.LE, "input_x rank", self.name) return ([self.count, len(x_shape)], [self.count]) def infer_dtype(self, x_dtype): diff --git a/mindspore/ops/primitive.py b/mindspore/ops/primitive.py index 2a91024b0d..a1e20e5a43 100644 --- a/mindspore/ops/primitive.py +++ b/mindspore/ops/primitive.py @@ -227,7 +227,7 @@ class PrimitiveWithCheck(Primitive): >>> def __init__(self): >>> pass >>> def check_shape(self, input_x): - >>> validator.check_integer('input_x rank', len(input_x), 1, Rel.GE, self.name) + >>> validator.check_int(len(input_x), 1, Rel.GE, 'input_x rank', self.name) >>> >>> def check_dtype(self, input_x): >>> validator.check_subclass("input_x", input_x, mstype.tensor, self.name) diff --git a/mindspore/train/quant/quant.py b/mindspore/train/quant/quant.py index 0d697cdef5..f9b6aa2d76 100644 --- a/mindspore/train/quant/quant.py +++ b/mindspore/train/quant/quant.py @@ -89,12 +89,12 @@ class ConvertToQuantNetwork: def __init__(self, **kwargs): self.network = Validator.check_isinstance('network', kwargs["network"], (nn.Cell,)) - self.weight_qdelay = Validator.check_integer("quant delay", kwargs["quant_delay"][0], 0, Rel.GE) - self.act_qdelay = Validator.check_integer("quant delay", kwargs["quant_delay"][-1], 0, Rel.GE) + self.weight_qdelay = Validator.check_non_negative_int(kwargs["quant_delay"][0], "quant delay") + self.act_qdelay = Validator.check_int(kwargs["quant_delay"][-1], 0, Rel.GE, "quant delay") self.bn_fold = Validator.check_bool(kwargs["bn_fold"], "bn fold") - self.freeze_bn = Validator.check_integer("freeze bn", kwargs["freeze_bn"], 0, Rel.GE) - self.weight_bits = Validator.check_integer("weights bit", kwargs["num_bits"][0], 0, Rel.GE) - self.act_bits = Validator.check_integer("activations bit", kwargs["num_bits"][-1], 0, Rel.GE) + self.freeze_bn = Validator.check_non_negative_int(kwargs["freeze_bn"], "freeze bn") + self.weight_bits = Validator.check_non_negative_int(kwargs["num_bits"][0], "weights bit") + self.act_bits = Validator.check_int(kwargs["num_bits"][-1], 0, Rel.GE, "activations bit") self.weight_channel = Validator.check_bool(kwargs["per_channel"][0], "per channel") self.act_channel = Validator.check_bool(kwargs["per_channel"][-1], "per channel") self.weight_symmetric = Validator.check_bool(kwargs["symmetric"][0], "symmetric") diff --git a/mindspore/train/summary/_summary_adapter.py b/mindspore/train/summary/_summary_adapter.py index 4cbb7b8fd5..5b917713fa 100644 --- a/mindspore/train/summary/_summary_adapter.py +++ b/mindspore/train/summary/_summary_adapter.py @@ -21,7 +21,7 @@ from PIL import Image from mindspore import log as logger -from ..._checkparam import _check_str_by_regular +from ..._checkparam import Validator from ..anf_ir_pb2 import DataType, ModelProto from ..summary_pb2 import Event @@ -47,8 +47,8 @@ def get_event_file_name(prefix, suffix): Returns: String, the name of event log file. """ - _check_str_by_regular(prefix) - _check_str_by_regular(suffix) + Validator.check_str_by_regular(prefix) + Validator.check_str_by_regular(suffix) file_name = "" time_second = str(int(time.time())) hostname = platform.node() diff --git a/mindspore/train/summary/summary_record.py b/mindspore/train/summary/summary_record.py index 9ffe6f1240..6f3cd7574d 100644 --- a/mindspore/train/summary/summary_record.py +++ b/mindspore/train/summary/summary_record.py @@ -21,7 +21,7 @@ import threading from mindspore import log as logger from ..._c_expression import Tensor -from ..._checkparam import _check_str_by_regular +from ..._checkparam import Validator from .._utils import _check_lineage_value, _check_to_numpy, _make_directory from ._summary_adapter import get_event_file_name, package_graph_event from ._writer_pool import WriterPool @@ -103,8 +103,8 @@ class SummaryRecord: self._closed, self._event_writer = False, None self._mode, self._data_pool = 'train', _dictlist() - _check_str_by_regular(file_prefix) - _check_str_by_regular(file_suffix) + Validator.check_str_by_regular(file_prefix) + Validator.check_str_by_regular(file_suffix) self.log_path = _make_directory(log_dir) diff --git a/tests/ut/python/nn/test_checkparameter.py b/tests/ut/python/nn/test_checkparameter.py index fc72a31bb6..6551083939 100644 --- a/tests/ut/python/nn/test_checkparameter.py +++ b/tests/ut/python/nn/test_checkparameter.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -""" test checkparameter """ +""" test check parameter """ import pytest import numpy as np -from mindspore._checkparam import twice, Validator +from mindspore._checkparam import Validator, twice kernel_size = 5 kernel_size1 = twice(kernel_size) diff --git a/tests/ut/python/nn/test_parameter.py b/tests/ut/python/nn/test_parameter.py index 69719a5b28..1acd295b8d 100644 --- a/tests/ut/python/nn/test_parameter.py +++ b/tests/ut/python/nn/test_parameter.py @@ -18,7 +18,7 @@ import numpy as np import pytest from mindspore import context, Tensor, Parameter, ParameterTuple, nn -from mindspore._checkparam import _check_str_by_regular +from mindspore._checkparam import Validator from mindspore.common import dtype as mstype from mindspore.common.initializer import initializer @@ -124,15 +124,15 @@ def test_check_str_by_regular(): str4 = ".12_sf.asdf" str5 = "12_sf.a$sdf." str6 = "12+sf.asdf" - _check_str_by_regular(str1) - _check_str_by_regular(str2) - _check_str_by_regular(str3) + Validator.check_str_by_regular(str1) + Validator.check_str_by_regular(str2) + Validator.check_str_by_regular(str3) with pytest.raises(ValueError): - _check_str_by_regular(str4) + Validator.check_str_by_regular(str4) with pytest.raises(ValueError): - _check_str_by_regular(str5) + Validator.check_str_by_regular(str5) with pytest.raises(ValueError): - _check_str_by_regular(str6) + Validator.check_str_by_regular(str6) def test_parameter_compute(): para_1 = Parameter(initializer('ones', [1, 2, 3], mstype.int32), 'test1')