[ME] delete redundant function in check_parameter

pull/7388/head
chenzomi 4 years ago
parent 5b769dfb20
commit acadb694aa
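In short, this commit moves the checked value to the first argument of the integer validators, makes the argument name an optional trailing parameter, adds `check_equal_int` for `Rel.EQ` comparisons, and folds module-level helpers such as `_check_str_by_regular` into `Validator`. A minimal before/after sketch of a call site (values are illustrative; MindSpore at this commit is assumed importable):

# Illustrative sketch only; kernel_size/group/in_channels are made-up values.
from mindspore._checkparam import Rel, Validator

kernel_size, group, in_channels = 3, 4, 4

# Before: Validator.check_integer('kernel_size', kernel_size, 1, Rel.GE, 'Conv1d')
# Before: Validator.check_integer('group', group, in_channels, Rel.EQ)

# After: value first, name and primitive name as optional trailing arguments.
Validator.check_int(kernel_size, 1, Rel.GE, 'kernel_size', 'Conv1d')
Validator.check_equal_int(group, in_channels, 'group')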

@ -97,7 +97,7 @@ def check_number(arg_value, value, rel, arg_type=int, arg_name=None, prim_name=N
Check argument integer.
Usage:
- number = check_integer(number, 0, Rel.GE, "number", None) # number >= 0
- number = check_int(number, 0, Rel.GE, "number", None) # number >= 0
"""
rel_fn = Rel.get_fns(rel)
type_mismatch = not isinstance(arg_value, arg_type) or isinstance(arg_value, bool)
@ -166,12 +166,12 @@ class Validator:
return arg_value
@staticmethod
def check_integer(arg_name, arg_value, value, rel, prim_name=None):
def check_int(arg_value, value, rel, arg_name=None, prim_name=None):
"""
Checks input integer value `arg_value` compare to `value`.
Usage:
- number = check_integer(number, 0, Rel.GE, "number", None) # number >= 0
- number = check_int(number, 0, Rel.GE, "number", None) # number >= 0
"""
return check_number(arg_value, value, rel, int, arg_name, prim_name)
@ -187,6 +187,16 @@ class Validator:
"""
return check_is_number(arg_value, int, arg_name, prim_name)
@staticmethod
def check_equal_int(arg_value, value, arg_name=None, prim_name=None):
"""
Checks input integer value `arg_value` compare to `value`.
Usage:
- number = check_int(number, 0, Rel.GE, "number", None) # number >= 0
"""
return check_number(arg_value, value, Rel.EQ, int, arg_name, prim_name)
@staticmethod
def check_positive_int(arg_value, arg_name=None, prim_name=None):
"""
@ -365,6 +375,17 @@ class Validator:
raise ValueError(f'{msg_prefix} `{arg_name}` should be str and must be in `{valid_values}`,'
f' but got `{arg_value}`.')
@staticmethod
def check_str_by_regular(target, reg=None, flag=re.ASCII, prim_name=None):
if reg is None:
# Named string regular expression
reg = r"^\w+[0-9a-zA-Z\_\.]*$"
if re.match(reg, target, flag) is None:
prim_name = f'in `{prim_name}`' if prim_name else ""
raise ValueError("'{}' {} is illegal, it should be match regular'{}' by flags'{}'".format(
target, prim_name, reg, flag))
return True
@staticmethod
def check_pad_value_by_mode(pad_mode, padding, prim_name):
"""Validates value of padding according to pad_mode"""
@ -530,13 +551,6 @@ class Validator:
f'{tuple(exp_shape)}, but got {shape}.')
def check_int_zero_one(input_param):
"""Judge whether it is 0 or 1."""
if input_param in (0, 1):
return input_param
raise ValueError("The data must be 0 or 1.")
def check_input_format(input_param):
"""Judge input format."""
if input_param == "NCHW":
@ -544,27 +558,6 @@ def check_input_format(input_param):
raise ValueError("The data format must be NCHW.")
def check_padding(padding):
"""Check padding."""
if padding >= 0:
return padding
raise ValueError("The padding must be at least 0,"" but got padding {}.".format(padding))
def check_padmode(mode):
"""Check padmode."""
if mode in ("same", "valid", "pad"):
return mode
raise ValueError("The pad mode must be same or valid or pad,"" but got mode {}.".format(mode))
def check_tensor_supported_type(dtype):
"""Check tensor dtype."""
if dtype in (mstype.int32, mstype.float32):
return dtype
raise ValueError("The dtype must be mstype.int32 or mstype.float32, but got mstype {}.".format(dtype))
def _expand_tuple(n_dimensions):
"""To expand a number to tuple."""
@ -673,42 +666,6 @@ def check_typename(arg_name, arg_type, valid_types):
f' but got {get_typename(arg_type)}.')
def check_shape(arg_name, arg_value):
"""Check shape."""
# First, check if shape is a tuple
if not isinstance(arg_value, tuple):
raise TypeError(f'The type of `{arg_name}` should be one of {tuple.__name__},'
f' but got {type(arg_value).__name__}.')
# Second, wrap arg_value with numpy array so that it can be checked through numpy api
arg_value = np.array(arg_value)
# shape can not be ()
if arg_value.size == 0:
raise ValueError('Shape can not be empty.')
# shape's dimension should be 1
if arg_value.ndim != 1:
raise ValueError('Shape of tensor should be 1-dim vector, but got {}-dim.'.format(arg_value.ndim))
# Thirdly, check each element's type of the shape
valid_types = (int, np.int8, np.int16, np.int32, np.int64,
np.uint8, np.uint16, np.uint32, np.uint64)
for dim_size in arg_value:
if not isinstance(dim_size, valid_types) or dim_size <= 0:
raise ValueError('Every dimension size of the tensor shape should be a positive integer,'
' but got {}.'.format(dim_size))
def _check_str_by_regular(target, reg=None, flag=re.ASCII):
if reg is None:
# Named string regular expression
reg = r"^\w+[0-9a-zA-Z\_\.]*$"
if re.match(reg, target, flag) is None:
raise ValueError("'{}' is illegal, it should be match regular'{}' by flags'{}'".format(target, reg, flag))
return True
def args_type_check(*type_args, **type_kwargs):
"""Check whether input data type is correct."""

@ -19,7 +19,7 @@ from .._c_expression import ParamInfo
from . import dtype as mstype
from .initializer import initializer, Initializer
from .tensor import Tensor, MetaTensor
from .._checkparam import _check_str_by_regular
from .._checkparam import Validator
from ..parallel._tensor import _get_slice_index
from ..parallel._auto_parallel_context import auto_parallel_context
from ..parallel._ps_context import _is_role_worker, _is_role_pserver, _is_role_sched
@ -263,7 +263,7 @@ class Parameter(MetaTensor):
Returns:
Parameter, a new parameter.
"""
_check_str_by_regular(prefix)
Validator.check_str_by_regular(prefix)
x = copy(self)
# pylint: disable=protected-access
x._param_info = self._param_info.clone()
@ -446,7 +446,7 @@ class ParameterTuple(tuple):
Returns:
Tuple, the new Parameter tuple.
"""
_check_str_by_regular(prefix)
Validator.check_str_by_regular(prefix)
new = []
for x in self:
x1 = x.clone(prefix, init)

@ -23,7 +23,7 @@ from collections import namedtuple
from types import FunctionType
from mindspore import log as logger
from mindspore._c_expression import MSContext, ms_ctx_param
from mindspore._checkparam import args_type_check
from mindspore._checkparam import args_type_check, Validator
from mindspore.parallel._auto_parallel_context import _set_auto_parallel_context, _get_auto_parallel_context, \
_reset_auto_parallel_context
from mindspore.parallel._ps_context import _set_ps_context, _get_ps_context, _reset_ps_context
@ -35,9 +35,9 @@ __all__ = ['GRAPH_MODE', 'PYNATIVE_MODE', 'set_context', 'get_context', 'set_aut
GRAPH_MODE = 0
PYNATIVE_MODE = 1
# The max memory size of graph plus variable.
_DEVICE_APP_MEMORY_SIZE = 31
_DEVICE_APP_MEMORY_SIZE = 31 # The max memory size of graph plus variable.
_re_pattern = r'[1-9][0-9]*(\.)?[0-9]*GB|0\.[0-9]*GB'
_k_context = None
def _make_directory(path):
"""Make directory."""
@ -223,7 +223,7 @@ class _Context:
def set_variable_memory_max_size(self, variable_memory_max_size):
"""set values of variable_memory_max_size and graph_memory_max_size"""
if not _check_input_format(variable_memory_max_size):
if not Validator.check_str_by_regular(variable_memory_max_size, _re_pattern):
raise ValueError("Context param variable_memory_max_size should be in correct format! Such as \"5GB\"")
if int(variable_memory_max_size[:-2]) >= _DEVICE_APP_MEMORY_SIZE:
raise ValueError("Context param variable_memory_max_size should be less than 31GB.")
@ -235,7 +235,7 @@ class _Context:
self.set_param(ms_ctx_param._graph_memory_max_size, graph_memory_max_size_)
def set_max_device_memory(self, max_device_memory):
if not _check_input_format(max_device_memory):
if not Validator.check_str_by_regular(max_device_memory, _re_pattern):
raise ValueError("Context param max_device_memory should be in correct format! Such as \"3.5GB\"")
max_device_memory_value = float(max_device_memory[:-2])
if max_device_memory_value == 0:
@ -294,16 +294,6 @@ class _Context:
thread_info.debug_runtime = enable
def _check_input_format(x):
import re
pattern = r'[1-9][0-9]*(\.)?[0-9]*GB|0\.[0-9]*GB'
result = re.match(pattern, x)
return result is not None
_k_context = None
def _context():
"""
Get the global _context, if context is not created, create a new one.
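In context.py, the private `_check_input_format` helper is dropped in favour of `Validator.check_str_by_regular` with the memory-size pattern `_re_pattern`; the numeric prefix is then compared with the 31 GB device limit. A standalone sketch of that validation (sizes are illustrative):

import re

_re_pattern = r'[1-9][0-9]*(\.)?[0-9]*GB|0\.[0-9]*GB'

for size in ("5GB", "3.5GB", "0.5GB"):
    # Format check, as Validator.check_str_by_regular(size, _re_pattern) does.
    assert re.match(_re_pattern, size, re.ASCII) is not None
    # Magnitude check against _DEVICE_APP_MEMORY_SIZE (31 GB).
    assert float(size[:-2]) < 31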

@ -23,7 +23,7 @@ from mindspore import log as logger
from .. import context
from ..common import dtype as mstype
from ..common.api import _executor, _pynative_exec
from .._checkparam import _check_str_by_regular
from .._checkparam import Validator
from ..common.parameter import Parameter, ParameterTuple
from .._c_expression import init_backend, Cell_
from ..ops.primitive import Primitive
@ -715,7 +715,7 @@ class Cell(Cell_):
recurse (bool): Whether contains the parameters of subcells. Default: True.
"""
_check_str_by_regular(prefix)
Validator.check_str_by_regular(prefix)
for name, param in self.parameters_and_names(expand=recurse):
if prefix != '':
param.is_init = False

@ -549,7 +549,7 @@ class Unfold(Cell):
@constexpr
def _get_matrix_diag_assist(x_shape, x_dtype):
Validator.check_integer("x rank", len(x_shape), 1, Rel.GE, "_get_matrix_diag_assist")
Validator.check_int(len(x_shape), 1, Rel.GE, "x rank", "_get_matrix_diag_assist")
base_eye = np.eye(x_shape[-1], x_shape[-1]).reshape(-1)
assist = np.tile(base_eye, x_shape[:-1]).reshape(x_shape + (x_shape[-1],))
return Tensor(assist, x_dtype)
@ -557,7 +557,7 @@ def _get_matrix_diag_assist(x_shape, x_dtype):
@constexpr
def _get_matrix_diag_part_assist(x_shape, x_dtype):
Validator.check_integer("x rank", len(x_shape), 2, Rel.GE, "_get_matrix_diag_part_assist")
Validator.check_int(len(x_shape), 2, Rel.GE, "x rank", "_get_matrix_diag_part_assist")
base_eye = np.eye(x_shape[-2], x_shape[-1]).reshape(-1)
assist = np.tile(base_eye, x_shape[:-2]).reshape(x_shape)
return Tensor(assist, x_dtype)

@ -239,8 +239,8 @@ class Conv2d(_Conv):
"""Initialize depthwise conv2d op"""
if context.get_context("device_target") == "Ascend" and self.group > 1:
self.dilation = self._dilation
Validator.check_integer('group', self.group, self.in_channels, Rel.EQ)
Validator.check_integer('group', self.group, self.out_channels, Rel.EQ)
Validator.check_equal_int(self.group, self.in_channels, 'group')
Validator.check_equal_int(self.group, self.out_channels, 'group')
self.conv2d = P.DepthwiseConv2dNative(channel_multiplier=1,
kernel_size=self.kernel_size,
pad_mode=self.pad_mode,
@ -384,10 +384,10 @@ class Conv1d(_Conv):
Validator.check_value_type("stride", stride, [int], self.cls_name)
Validator.check_value_type("padding", padding, [int], self.cls_name)
Validator.check_value_type("dilation", dilation, [int], self.cls_name)
Validator.check_integer('kernel_size', kernel_size, 1, Rel.GE, self.cls_name)
Validator.check_integer('stride', stride, 1, Rel.GE, self.cls_name)
Validator.check_int(kernel_size, 1, Rel.GE, 'kernel_size', self.cls_name)
Validator.check_int(stride, 1, Rel.GE, 'stride', self.cls_name)
Validator.check_non_negative_int(padding, 'padding', self.cls_name)
Validator.check_integer('dilation', dilation, 1, Rel.GE, self.cls_name)
Validator.check_int(dilation, 1, Rel.GE, 'dilation', self.cls_name)
kernel_size = (1, kernel_size)
stride = (1, stride)
dilation = (1, dilation)
@ -395,7 +395,7 @@ class Conv1d(_Conv):
get_dtype = P.DType()
if isinstance(weight_init, Tensor):
weight_init_shape = get_shape(weight_init)
Validator.check_integer('weight_init_shape', len(weight_init_shape), 3, Rel.EQ, self.cls_name)
Validator.check_equal_int(len(weight_init_shape), 3, 'weight_init_shape', self.cls_name)
weight_init_dtype = get_dtype(weight_init)
weight_init_value = weight_init.asnumpy()
weight_init_value = np.expand_dims(weight_init_value, 2)
@ -539,7 +539,7 @@ class Conv2dTranspose(_Conv):
dilation = twice(dilation)
Validator.check_value_type('padding', padding, (int, tuple), self.cls_name)
if isinstance(padding, tuple):
Validator.check_integer('padding size', len(padding), 4, Rel.EQ, self.cls_name)
Validator.check_equal_int(len(padding), 4, 'padding size', self.cls_name)
# out_channels and in_channels swap.
# cause Conv2DBackpropInput's out_channel refers to Conv2D's out_channel,
# then Conv2dTranspose's out_channel refers to Conv2DBackpropInput's in_channel.
@ -703,10 +703,10 @@ class Conv1dTranspose(_Conv):
Validator.check_value_type("stride", stride, [int], self.cls_name)
Validator.check_value_type("padding", padding, [int], self.cls_name)
Validator.check_value_type("dilation", dilation, [int], self.cls_name)
Validator.check_integer('kernel_size', kernel_size, 1, Rel.GE, self.cls_name)
Validator.check_integer('stride', stride, 1, Rel.GE, self.cls_name)
Validator.check_int(kernel_size, 1, Rel.GE, 'kernel_size', self.cls_name)
Validator.check_int(stride, 1, Rel.GE, 'stride', self.cls_name)
Validator.check_non_negative_int(padding, 'padding', self.cls_name)
Validator.check_integer('dilation', dilation, 1, Rel.GE, self.cls_name)
Validator.check_int(dilation, 1, Rel.GE, 'dilation', self.cls_name)
kernel_size = (1, kernel_size)
stride = (1, stride)
dilation = (1, dilation)
@ -714,7 +714,7 @@ class Conv1dTranspose(_Conv):
get_dtype = P.DType()
if isinstance(weight_init, Tensor):
weight_init_shape = get_shape(weight_init)
Validator.check_integer('weight_init_shape', len(weight_init_shape), 3, Rel.EQ, self.cls_name)
Validator.check_equal_int(len(weight_init_shape), 3, 'weight_init_shape', self.cls_name)
weight_init_dtype = get_dtype(weight_init)
weight_init_value = weight_init.asnumpy()
weight_init_value = np.expand_dims(weight_init_value, 2)

@ -220,7 +220,7 @@ class SSIM(Cell):
validator.check_value_type('max_val', max_val, [int, float], self.cls_name)
validator.check_number('max_val', max_val, 0.0, Rel.GT, self.cls_name)
self.max_val = max_val
self.filter_size = validator.check_integer('filter_size', filter_size, 1, Rel.GE, self.cls_name)
self.filter_size = validator.check_int(filter_size, 1, Rel.GE, 'filter_size', self.cls_name)
self.filter_sigma = validator.check_positive_float(filter_sigma, 'filter_sigma', self.cls_name)
self.k1 = validator.check_value_type('k1', k1, [float], self.cls_name)
self.k2 = validator.check_value_type('k2', k2, [float], self.cls_name)
@ -298,7 +298,7 @@ class MSSSIM(Cell):
validator.check_number('max_val', max_val, 0.0, Rel.GT, self.cls_name)
self.max_val = max_val
validator.check_value_type('power_factors', power_factors, [tuple, list], self.cls_name)
self.filter_size = validator.check_integer('filter_size', filter_size, 1, Rel.GE, self.cls_name)
self.filter_size = validator.check_int(filter_size, 1, Rel.GE, 'filter_size', self.cls_name)
self.filter_sigma = validator.check_positive_float(filter_sigma, 'filter_sigma', self.cls_name)
self.k1 = validator.check_value_type('k1', k1, [float], self.cls_name)
self.k2 = validator.check_value_type('k2', k2, [float], self.cls_name)

@ -190,8 +190,8 @@ class MaxPool1d(_PoolNd):
validator.check_value_type('kernel_size', kernel_size, [int], self.cls_name)
validator.check_value_type('stride', stride, [int], self.cls_name)
self.pad_mode = validator.check_string(pad_mode.upper(), ['VALID', 'SAME'], 'pad_mode', self.cls_name)
validator.check_integer("kernel_size", kernel_size, 1, Rel.GE, self.cls_name)
validator.check_integer("stride", stride, 1, Rel.GE, self.cls_name)
validator.check_int(kernel_size, 1, Rel.GE, "kernel_size", self.cls_name)
validator.check_int(stride, 1, Rel.GE, "stride", self.cls_name)
self.kernel_size = (1, kernel_size)
self.stride = (1, stride)
self.max_pool = P.MaxPool(ksize=self.kernel_size,
@ -349,8 +349,8 @@ class AvgPool1d(_PoolNd):
validator.check_value_type('kernel_size', kernel_size, [int], self.cls_name)
validator.check_value_type('stride', stride, [int], self.cls_name)
self.pad_mode = validator.check_string(pad_mode.upper(), ['VALID', 'SAME'], 'pad_mode', self.cls_name)
validator.check_integer("kernel_size", kernel_size, 1, Rel.GE, self.cls_name)
validator.check_integer("stride", stride, 1, Rel.GE, self.cls_name)
validator.check_int(kernel_size, 1, Rel.GE, "kernel_size", self.cls_name)
validator.check_int(stride, 1, Rel.GE, "stride", self.cls_name)
self.kernel_size = (1, kernel_size)
self.stride = (1, stride)
self.avg_pool = P.AvgPool(ksize=self.kernel_size,

@ -323,7 +323,7 @@ class FakeQuantWithMinMax(Cell):
Validator.check_type("min_init", min_init, [int, float])
Validator.check_type("max_init", max_init, [int, float])
Validator.check("min_init", min_init, "max_init", max_init, rel=Rel.LT)
Validator.check_integer('quant_delay', quant_delay, 0, Rel.GE)
Validator.check_non_negative_int(quant_delay, 'quant_delay')
self.min_init = min_init
self.max_init = max_init
self.num_bits = num_bits
@ -489,8 +489,8 @@ class Conv2dBnFoldQuant(Cell):
# initialize convolution op and Parameter
if context.get_context('device_target') == "Ascend" and group > 1:
Validator.check_integer('group', group, in_channels, Rel.EQ)
Validator.check_integer('group', group, out_channels, Rel.EQ)
Validator.check_equal_int(group, in_channels, 'group')
Validator.check_equal_int(group, out_channels, 'group')
self.conv = P.DepthwiseConv2dNative(channel_multiplier=1,
kernel_size=self.kernel_size,
pad_mode=pad_mode,
@ -674,8 +674,8 @@ class Conv2dBnWithoutFoldQuant(Cell):
self.bias = None
# initialize convolution op and Parameter
if context.get_context('device_target') == "Ascend" and group > 1:
Validator.check_integer('group', group, in_channels, Rel.EQ)
Validator.check_integer('group', group, out_channels, Rel.EQ)
Validator.check_equal_int(group, in_channels, 'group')
Validator.check_equal_int(group, out_channels, 'group')
self.conv = P.DepthwiseConv2dNative(channel_multiplier=1,
kernel_size=self.kernel_size,
pad_mode=pad_mode,

@ -931,19 +931,19 @@ class LSTMGradData(PrimitiveWithInfer):
def infer_shape(self, y_shape, dy_shape, dhy_shape, dcy_shape, w_shape,
hx_shape, cx_shape, reserve_shape, state_shape):
# dhy and dcy should be same shape
validator.check_integer("h_shape", len(dhy_shape), 3, Rel.EQ, self.name)
validator.check_integer("h_shape", len(dhy_shape), len(dcy_shape), Rel.EQ, self.name)
validator.check_integer("h_shape[0]", dhy_shape[0], dcy_shape[0], Rel.EQ, self.name)
validator.check_integer("h_shape[1]", dhy_shape[1], dcy_shape[1], Rel.EQ, self.name)
validator.check_integer("h_shape[2]", dhy_shape[2], dcy_shape[2], Rel.EQ, self.name)
validator.check_equal_int(len(dhy_shape), 3, "h_shape", self.name)
validator.check_equal_int(len(dhy_shape), len(dcy_shape), "h_shape", self.name)
validator.check_equal_int(dhy_shape[0], dcy_shape[0], "h_shape[0]", self.name)
validator.check_equal_int(dhy_shape[1], dcy_shape[1], "h_shape[1]", self.name)
validator.check_equal_int(dhy_shape[2], dcy_shape[2], "h_shape[2]", self.name)
validator.check_integer("h_shape[0]", dhy_shape[0], self.num_layers * self.num_directions, Rel.EQ, self.name)
validator.check_integer("h_shape[2]", dhy_shape[2], self.hidden_size, Rel.EQ, self.name)
validator.check_int(dhy_shape[0], self.num_layers * self.num_directions, Rel.EQ, "h_shape[0]", self.name)
validator.check_equal_int(dhy_shape[2], self.hidden_size, "h_shape[2]", self.name)
# dy: (seq_len, batch_size, hidden_size * num_directions)
validator.check_integer("dy_shape", len(dy_shape), 3, Rel.EQ, self.name)
validator.check_integer("dy[1]", dy_shape[1], dhy_shape[1], Rel.EQ, self.name)
validator.check_integer("dy[2]", dy_shape[2], self.hidden_size * self.num_directions, Rel.EQ, self.name)
validator.check_equal_int(len(dy_shape), 3, "dy_shape", self.name)
validator.check_equal_int(dy_shape[1], dhy_shape[1], "dy[1]", self.name)
validator.check_int(dy_shape[2], self.hidden_size * self.num_directions, Rel.EQ, "dy[2]", self.name)
# (seq_len, batch_size, input_size)
dx_shape = (y_shape[0], y_shape[1], self.input_size)
@ -1015,19 +1015,19 @@ class LSTMGrad(PrimitiveWithInfer):
def infer_shape(self, x_shape, hx_shape, cx_shape, w_shape, y_shape, hy_shape, cy_shape, dy_shape, dhy_shape,
dcy_shape, reserve_shape):
# dhy and dcy should be same shape
validator.check_integer("h_shape", len(dhy_shape), 3, Rel.EQ, self.name)
validator.check_integer("h_shape", len(dhy_shape), len(dcy_shape), Rel.EQ, self.name)
validator.check_integer("h_shape[0]", dhy_shape[0], dcy_shape[0], Rel.EQ, self.name)
validator.check_integer("h_shape[1]", dhy_shape[1], dcy_shape[1], Rel.EQ, self.name)
validator.check_integer("h_shape[2]", dhy_shape[2], dcy_shape[2], Rel.EQ, self.name)
validator.check_equal_int(len(dhy_shape), 3, "h_shape", self.name)
validator.check_equal_int(len(dhy_shape), len(dcy_shape), "h_shape", self.name)
validator.check_equal_int(dhy_shape[0], dcy_shape[0], "h_shape[0]", self.name)
validator.check_equal_int(dhy_shape[1], dcy_shape[1], "h_shape[1]", self.name)
validator.check_equal_int(dhy_shape[2], dcy_shape[2], "h_shape[2]", self.name)
validator.check_integer("h_shape[0]", dhy_shape[0], self.num_layers * self.num_directions, Rel.EQ, self.name)
validator.check_integer("h_shape[2]", dhy_shape[2], self.hidden_size, Rel.EQ, self.name)
validator.check_int(dhy_shape[0], self.num_layers * self.num_directions, Rel.EQ, "h_shape[0]", self.name)
validator.check_equal_int(dhy_shape[2], self.hidden_size, "h_shape[2]", self.name)
# dy: (seq_len, batch_size, hidden_size * num_directions)
validator.check_integer("dy_shape", len(dy_shape), 3, Rel.EQ, self.name)
validator.check_integer("dy[1]", dy_shape[1], dhy_shape[1], Rel.EQ, self.name)
validator.check_integer("dy[2]", dy_shape[2], self.hidden_size * self.num_directions, Rel.EQ, self.name)
validator.check_equal_int(len(dy_shape), 3, "dy_shape", self.name)
validator.check_equal_int(dy_shape[1], dhy_shape[1], "dy[1]", self.name)
validator.check_int(dy_shape[2], self.hidden_size * self.num_directions, Rel.EQ, "dy[2]", self.name)
# (seq_len, batch_size, input_size)
dx_shape = (y_shape[0], y_shape[1], self.input_size)
@ -1069,7 +1069,7 @@ class DynamicRNNGrad(PrimitiveWithInfer):
def infer_shape(self, x_shape, w_shape, b_shape, y_shape, init_h_shape, init_c_shape, h_shape,
c_shape, dy_shape, dh_shape, dc_shape, i_shape, j_shape, f_shape, o_shape, tanhc_shape):
validator.check_integer("x_shape", len(x_shape), 3, Rel.EQ, self.name)
validator.check_equal_int(len(x_shape), 3, "x_shape", self.name)
num_step, batch_size, input_size = x_shape
hidden_size = w_shape[-1] // 4
if w_shape[-1] % 4 != 0:
@ -1575,7 +1575,7 @@ class BasicLSTMCellCStateGrad(PrimitiveWithInfer):
def infer_shape(self, c_shape, dht_shape, dct_shape, it_shape, jt_shape, ft_shape, ot_shape, tanhct_shape):
# dhy and dcy should be same shape
validator.check_integer("c rank", len(c_shape), 2, Rel.EQ, self.name)
validator.check_equal_int(len(c_shape), 2, "c rank", self.name)
validator.check("dht rank", len(dht_shape), "c rank", len(c_shape), Rel.EQ, self.name)
validator.check("dct rank", len(dct_shape), "c rank", len(c_shape), Rel.EQ, self.name)
validator.check("it rank", len(it_shape), "c rank", len(c_shape), Rel.EQ, self.name)
@ -1624,7 +1624,7 @@ class BasicLSTMCellWeightGrad(PrimitiveWithInfer):
self.add_prim_attr("io_format", "HWCN")
def infer_shape(self, x_shape, h_shape, dgate_shape):
validator.check_integer("x rank", len(x_shape), 2, Rel.EQ, self.name)
validator.check_equal_int(len(x_shape), 2, "x rank", self.name)
validator.check("h rank", len(h_shape), " x rank", len(x_shape), Rel.EQ, self.name)
validator.check("dgate rank", len(dgate_shape), "x rank", len(x_shape), Rel.EQ, self.name)
validator.check("h_shape[0]", h_shape[0], "x_shape[0]", x_shape[0], Rel.EQ, self.name)
@ -1656,8 +1656,8 @@ class BasicLSTMCellInputGrad(PrimitiveWithInfer):
self.add_prim_attr("io_format", "ND")
def infer_shape(self, dgate_shape, w_shape):
validator.check_integer("dgate rank", len(dgate_shape), 2, Rel.EQ, self.name)
validator.check_integer("w rank", len(w_shape), 2, Rel.EQ, self.name)
validator.check_equal_int(len(dgate_shape), 2, "dgate rank", self.name)
validator.check_equal_int(len(w_shape), 2, "w rank", self.name)
validator.check("dgate_shape[1]", dgate_shape[1], "w_shape[1]", w_shape[1], Rel.EQ, self.name)
batch_size = dgate_shape[0]
hidden_size = dgate_shape[1] // 4

@ -347,7 +347,7 @@ class MatrixDiag(PrimitiveWithInfer):
return x_dtype
def infer_shape(self, x_shape, assist_shape):
validator.check_integer("assist rank", len(assist_shape), 2, Rel.GE, self.name)
validator.check_int(len(assist_shape), 2, Rel.GE, "assist rank", self.name)
validator.check('rank of x', len(x_shape)+1,
'rank of assist', len(assist_shape), Rel.LE, self.name)
validator.check('assist\'s penultimate dimension', assist_shape[-2], 'assist\'s last dimension',
@ -395,7 +395,7 @@ class MatrixDiagPart(PrimitiveWithInfer):
return x_dtype
def infer_shape(self, x_shape, assist_shape):
validator.check_integer("x rank", len(x_shape), 2, Rel.GE, self.name)
validator.check_int(len(x_shape), 2, Rel.GE, "x rank", self.name)
validator.check("x shape", x_shape, "assist shape", assist_shape, Rel.EQ, self.name)
if assist_shape[-2] < assist_shape[-1]:
@ -438,7 +438,7 @@ class MatrixSetDiag(PrimitiveWithInfer):
return x_dtype
def infer_shape(self, x_shape, diagonal_shape, assist_shape):
validator.check_integer("x rank", len(x_shape), 2, Rel.GE, self.name)
validator.check_int(len(x_shape), 2, Rel.GE, "x rank", self.name)
validator.check("x shape", x_shape, "assist shape", assist_shape, Rel.EQ, self.name)
if x_shape[-2] < x_shape[-1]:

@ -81,11 +81,10 @@ class MinMaxUpdatePerLayer(PrimitiveWithInfer):
outputs=['min_up', 'max_up'])
def infer_shape(self, x_shape, min_shape, max_shape):
validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name)
validator.check_int(len(x_shape), 1, Rel.GE, "x rank", self.name)
validator.check("min shape", min_shape, "max shape",
max_shape, Rel.EQ, self.name)
validator.check_integer("min shape", len(
min_shape), 1, Rel.EQ, self.name)
validator.check_equal_int(len(min_shape), 1, "min shape", self.name)
return min_shape, max_shape
def infer_dtype(self, x_type, min_type, max_type):
@ -147,11 +146,10 @@ class MinMaxUpdatePerChannel(PrimitiveWithInfer):
if self.is_ascend and len(x_shape) not in self.ascend_support_x_rank:
raise ValueError(f"For '{self.name}' x rank should be in '{self.ascend_support_x_rank}'")
if not self.is_ascend:
validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name)
validator.check_int(len(x_shape), 1, Rel.GE, "x rank", self.name)
validator.check("min shape", min_shape, "max shape",
max_shape, Rel.EQ, self.name)
validator.check_integer("min shape", len(
min_shape), 1, Rel.EQ, self.name)
validator.check_equal_int(len(min_shape), 1, "min shape", self.name)
return min_shape, max_shape
def infer_dtype(self, x_type, min_type, max_type):
@ -228,9 +226,9 @@ class FakeQuantPerLayer(PrimitiveWithInfer):
outputs=['out'])
def infer_shape(self, x_shape, min_shape, max_shape):
validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name)
validator.check_int(len(x_shape), 1, Rel.GE, "x rank", self.name)
validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
validator.check_integer("min shape", len(min_shape), 1, Rel.EQ, self.name)
validator.check_equal_int(len(min_shape), 1, "min shape", self.name)
return x_shape
def infer_dtype(self, x_type, min_type, max_type):
@ -284,8 +282,7 @@ class FakeQuantPerLayerGrad(PrimitiveWithInfer):
x_shape, Rel.EQ, self.name)
validator.check("min shape", min_shape, "max shape",
max_shape, Rel.EQ, self.name)
validator.check_integer("min shape", len(
min_shape), 1, Rel.EQ, self.name)
validator.check_equal_int(len(min_shape), 1, "min shape", self.name)
return dout_shape
def infer_dtype(self, dout_type, x_type, min_type, max_type):
@ -375,14 +372,12 @@ class FakeQuantPerChannel(PrimitiveWithInfer):
if self.is_ascend and len(x_shape) not in self.ascend_support_x_rank:
raise ValueError(f"For '{self.name}' x rank should be in '{self.ascend_support_x_rank}'")
if not self.is_ascend:
validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name)
validator.check_int(len(x_shape), 1, Rel.GE, "x rank", self.name)
if len(x_shape) == 1:
self.channel_axis = 0
validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
validator.check_integer(
"min shape", min_shape[0], x_shape[self.channel_axis], Rel.EQ, self.name)
validator.check_integer(
"max shape", max_shape[0], x_shape[self.channel_axis], Rel.EQ, self.name)
validator.check_equal_int(min_shape[0], x_shape[self.channel_axis], "min shape", self.name)
validator.check_equal_int(max_shape[0], x_shape[self.channel_axis], "max shape", self.name)
return x_shape
def infer_dtype(self, x_type, min_type, max_type):
@ -501,7 +496,7 @@ class BatchNormFold(PrimitiveWithInfer):
def infer_shape(self, x_shape, mean_shape, variance_shape, global_step_shape):
validator.check("mean shape", mean_shape, "gamma_shape", variance_shape, Rel.EQ, self.name)
validator.check("mean_shape[0]", mean_shape[0], "input channel", x_shape[self.channel_axis], Rel.EQ, self.name)
validator.check_integer("global step shape len", len(global_step_shape), 1, Rel.EQ, self.name)
validator.check_equal_int(len(global_step_shape), 1, "global step shape len", self.name)
return mean_shape, mean_shape, mean_shape, mean_shape
def infer_dtype(self, x_type, mean_type, variance_type, global_step_type):
@ -548,7 +543,7 @@ class BatchNormFoldGrad(PrimitiveWithInfer):
"batch_std shape", batch_std_shape, Rel.EQ, self.name)
validator.check("d_batch_mean_shape[0]", d_batch_mean_shape[0],
"input channel", x_shape[self.channel_axis], Rel.EQ, self.name)
validator.check_integer("global step shape len", len(global_step_shape), 1, Rel.EQ, self.name)
validator.check_equal_int(len(global_step_shape), 1, "global step shape len", self.name)
return x_shape
def infer_dtype(self, d_batch_mean_type, d_batch_std_type, x_type, batch_mean_type, batch_std_type,
@ -723,7 +718,7 @@ class BatchNormFold2(PrimitiveWithInfer):
validator.check("batch_std shape", batch_std_shape, "batch_mean shape", gamma_shape, Rel.EQ, self.name)
validator.check("batch_std_shape[0]", batch_std_shape[0], "x_shape channel size", x_shape[self.channel_axis],
Rel.EQ, self.name)
validator.check_integer("global step shape len", len(global_step_shape), 1, Rel.EQ, self.name)
validator.check_equal_int(len(global_step_shape), 1, "global step shape len", self.name)
return x_shape
def infer_dtype(self, x_type, beta_type, gamma_type, batch_std_type, running_std_type, batch_mean_type,
@ -771,7 +766,7 @@ class BatchNormFold2Grad(PrimitiveWithInfer):
validator.check("batch_std shape", batch_std_shape, "gamma shape", gamma_shape, Rel.EQ, self.name)
validator.check("batch_std size", batch_std_shape[0], "dout channel size", dout_shape[self.channel_axis],
Rel.EQ, self.name)
validator.check_integer("global step shape len", len(global_step_shape), 1, Rel.EQ, self.name)
validator.check_equal_int(len(global_step_shape), 1, "global step shape len", self.name)
return gamma_shape, gamma_shape, gamma_shape, gamma_shape, x_shape
def infer_dtype(self, dout_type, x_type, gamma_type,

@ -520,7 +520,7 @@ class Im2Col(PrimitiveWithInfer):
self.add_prim_attr('data_format', "NCHW")
def infer_shape(self, x_shape):
validator.check_integer("x rank", len(x_shape), 4, Rel.EQ, self.name)
validator.check_equal_int(len(x_shape), 4, "x rank", self.name)
kernel_size_h = self.kernel_size[0]
kernel_size_w = self.kernel_size[1]
stride_h = self.stride[2]

@ -583,8 +583,8 @@ class Transpose(PrimitiveWithInfer):
tmp = list(p_value)
for i, dim in enumerate(p_value):
validator.check_integer("perm[%d]" % i, dim, 0, Rel.GE, self.name)
validator.check_integer("perm[%d]" % i, dim, len(p_value), Rel.LT, self.name)
validator.check_int(dim, 0, Rel.GE, f'perm[{i}]', self.name)
validator.check_int(dim, len(p_value), Rel.LT, f'perm[{i}]', self.name)
tmp.remove(dim)
if dim in tmp:
raise ValueError('The value of perm is wrong.')
@ -725,8 +725,8 @@ class Padding(PrimitiveWithInfer):
def __infer__(self, x):
validator.check_subclass("x", x['dtype'], mstype.tensor, self.name)
x_shape = list(x['shape'])
validator.check_integer("rank of x", len(x_shape), 1, Rel.GT, self.name)
validator.check_integer("last dim of x", x_shape[-1], 1, Rel.EQ, self.name)
validator.check_int(len(x_shape), 1, Rel.GT, "rank of x", self.name)
validator.check_int(x_shape[-1], 1, Rel.EQ, "last dim of x", self.name)
out_shape = x_shape
out_shape[-1] = self.pad_dim_size
out = {'shape': out_shape,
@ -1575,7 +1575,7 @@ class UnsortedSegmentMin(PrimitiveWithInfer):
valid_type = [mstype.float16, mstype.float32, mstype.int32]
validator.check_tensor_type_same({"x": x['dtype']}, valid_type, self.name)
validator.check_tensor_type_same({"segment_ids": segment_ids['dtype']}, [mstype.int32], self.name)
validator.check_integer("rank of segment_ids_shape", len(segment_ids_shape), 1, Rel.EQ, self.name)
validator.check_equal_int(len(segment_ids_shape), 1, "rank of segment_ids_shape", self.name)
validator.check(f'first shape of input_x', x_shape[0],
'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name)
num_segments_v = num_segments['value']
@ -1628,7 +1628,7 @@ class UnsortedSegmentProd(PrimitiveWithInfer):
valid_type = [mstype.float16, mstype.float32, mstype.int32]
validator.check_tensor_type_same({"x": x['dtype']}, valid_type, self.name)
validator.check_tensor_type_same({"segment_ids": segment_ids['dtype']}, [mstype.int32], self.name)
validator.check_integer("rank of segment_ids_shape", len(segment_ids_shape), 1, Rel.EQ, self.name)
validator.check_equal_int(len(segment_ids_shape), 1, "rank of segment_ids_shape", self.name)
validator.check(f'first shape of input_x', x_shape[0],
'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name)
num_segments_v = num_segments['value']
@ -1730,7 +1730,7 @@ class ParallelConcat(PrimitiveWithInfer):
x_shp = values['shape']
x_type = values['dtype']
validator.check_integer(f'x_shp length', len(x_shp), 1, Rel.GE, self.name)
validator.check_int(len(x_shp), 1, Rel.GE, f'x_shp length', self.name)
args = {f"x_type[{i}]": elem for i, elem in enumerate(x_type)}
validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), self.name)
@ -1738,7 +1738,7 @@ class ParallelConcat(PrimitiveWithInfer):
first_elem = x_shp[0]
for i, elem in enumerate(x_shp[1:]):
j = i + 1
validator.check_integer(f'x_shp[{j}][0]', elem[0], 1, Rel.EQ, self.name)
validator.check_equal_int(elem[0], 1, f'x_shp[{j}][0]', self.name)
validator.check(f"x_shp[0] shape", first_elem, f"x_shp[{j}] shape", elem, Rel.EQ, self.name)
ret_shp = x_shp[0].copy()
@ -1755,7 +1755,7 @@ class ParallelConcat(PrimitiveWithInfer):
def _get_pack_shape(x_shape, x_type, axis, prim_name):
"""for pack output shape"""
validator.check_value_type("shape", x_shape, [tuple, list], prim_name)
validator.check_integer("len of input_x", len(x_shape), 1, Rel.GE, prim_name)
validator.check_int(len(x_shape), 1, Rel.GE, "len of input_x", prim_name)
validator.check_subclass("input_x[0]", x_type[0], mstype.tensor, prim_name)
rank_base = len(x_shape[0])
N = len(x_shape)
@ -1871,8 +1871,8 @@ class Unpack(PrimitiveWithInfer):
validator.check_positive_int(output_num, "output_num", self.name)
self.add_prim_attr('num', output_num)
output_valid_check = x_shape[self.axis] - output_num
validator.check_integer("The dimension which to unpack divides output_num", output_valid_check, 0, Rel.EQ,
self.name)
validator.check_int(output_valid_check, 0, Rel.EQ,
"The dimension which to unpack divides output_num", self.name)
out_shapes = []
out_dtypes = []
out_shape = x_shape[:self.axis] + x_shape[self.axis + 1:]
@ -2523,7 +2523,7 @@ class ResizeNearestNeighbor(PrimitiveWithInfer):
"""Initialize ResizeNearestNeighbor"""
validator.check_value_type("size", size, [tuple, list], self.name)
validator.check_value_type("align_corners", align_corners, [bool], self.name)
validator.check_integer("length of size", len(size), 2, Rel.EQ, self.name)
validator.check_equal_int(len(size), 2, "length of size", self.name)
for i, value in enumerate(size):
validator.check_non_negative_int(value, f'{i}th value of size', self.name)
self.init_prim_io_names(inputs=['image_in'], outputs=['image_out'])
@ -3134,9 +3134,8 @@ class DepthToSpace(PrimitiveWithInfer):
for i in range(2):
out_shape[i + 2] *= self.block_size
validator.check_integer('x_shape[1] % (block_size*block_size)',
x_shape[1] % (self.block_size * self.block_size),
0, Rel.EQ, self.name)
validator.check_int(x_shape[1] % (self.block_size * self.block_size),
0, Rel.EQ, 'x_shape[1] % (block_size*block_size)', self.name)
out_shape[1] //= self.block_size * self.block_size
return out_shape
@ -3205,7 +3204,7 @@ class SpaceToBatch(PrimitiveWithInfer):
return x_dtype
def infer_shape(self, x_shape):
validator.check_integer('rank of input_x', len(x_shape), 4, Rel.EQ, self.name)
validator.check_equal_int(len(x_shape), 4, 'rank of input_x', self.name)
out_shape = copy.deepcopy(x_shape)
for i in range(2):
padded = out_shape[i + 2] + self.paddings[i][0] + self.paddings[i][1]
@ -3367,7 +3366,7 @@ class SpaceToBatchND(PrimitiveWithInfer):
def infer_shape(self, x_shape):
x_rank = len(x_shape)
validator.check_integer('x_shape rank', x_rank, 4, Rel.EQ, self.name)
validator.check_equal_int(x_rank, 4, 'x_shape rank', self.name)
out_shape = copy.deepcopy(x_shape)
block_shape_prod = 1
@ -3460,7 +3459,7 @@ class BatchToSpaceND(PrimitiveWithInfer):
def infer_shape(self, x_shape):
x_rank = len(x_shape)
validator.check_integer('x_shape rank', x_rank, 4, Rel.EQ, self.name)
validator.check_int(x_rank, 4, Rel.EQ, 'x_shape rank', self.name)
out_shape = copy.deepcopy(x_shape)
block_shape_prod = 1
@ -3607,11 +3606,11 @@ class Meshgrid(PrimitiveWithInfer):
def infer_shape(self, x_shape):
validator.check_value_type("shape", x_shape, [tuple, list], self.name)
validator.check_integer("len of input_x", len(x_shape), 2, Rel.GE, self.name)
validator.check_int(len(x_shape), 2, Rel.GE, "len of input_x", self.name)
n = len(x_shape)
shape_0 = []
for s in x_shape:
validator.check_integer('each_input_rank', len(s), 1, Rel.EQ, self.name)
validator.check_int(len(s), 1, Rel.EQ, 'each_input_rank', self.name)
shape_0.append(s[0])
if self.indexing == "xy":
shape_0[0], shape_0[1] = shape_0[1], shape_0[0]

@ -204,7 +204,7 @@ class _HostAllGather(PrimitiveWithInfer):
if group is None:
raise ValueError(f"For '{self.name}' group must be set.")
validator.check_value_type('group', group, (tuple, list), self.name)
validator.check_integer("group size", len(group), 2, Rel.GE, self.name)
validator.check_int(len(group), 2, Rel.GE, "group size", self.name)
for r in group:
validator.check_int_range(r, 0, 7, Rel.INC_BOTH, "rank_id", self.name)
validator.check_value_type("rank_id", r, (int,), self.name)
@ -313,7 +313,7 @@ class _HostReduceScatter(PrimitiveWithInfer):
raise ValueError(f"For '{self.name}' group must be set.")
validator.check_value_type('op', op, (type(ReduceOp.SUM),), self.name)
validator.check_value_type('group', group, (tuple, list), self.name)
validator.check_integer("group size", len(group), 2, Rel.GE, self.name)
validator.check_int(len(group), 2, Rel.GE, "group size", self.name)
for r in group:
validator.check_int_range(r, 0, 7, Rel.INC_BOTH, "rank_id", self.name)
validator.check_value_type("rank_id", r, (int,), self.name)

@ -126,7 +126,7 @@ class GeSwitch(PrimitiveWithInfer):
raise NotImplementedError
def infer_shape(self, data, pred):
validator.check_integer("pred rank", len(pred), 0, Rel.EQ, self.name)
validator.check_equal_int(len(pred), 0, "pred rank", self.name)
return (data, data)
def infer_dtype(self, data_type, pred_type):

@ -374,9 +374,9 @@ class Assert(PrimitiveWithInfer):
def infer_shape(self, condition, inputs):
condition_len = len(condition)
validator.check_integer("condition's rank", condition_len, 1, Rel.LE, self.name)
validator.check_int(condition_len, 1, Rel.LE, "condition's rank", self.name)
if condition_len == 1:
validator.check_integer("condition[0]", condition[0], 1, Rel.EQ, self.name)
validator.check_equal_int(condition[0], 1, "condition[0]", self.name)
return [1]
def infer_dtype(self, condition, inputs):

@ -17,7 +17,6 @@
import numbers
from ..._checkparam import Validator as validator
from ..._checkparam import Rel
from ...common.dtype import tensor, dtype_to_pytype
from ..primitive import prim_attr_register, PrimitiveWithInfer
@ -43,7 +42,7 @@ class ScalarCast(PrimitiveWithInfer):
pass
def __infer__(self, x, t):
validator.check_integer('x shape', len(x['shape']), 0, Rel.EQ, self.name)
validator.check_equal_int(len(x['shape']), 0, 'x shape', self.name)
value, to = x['value'], t['value']
if value is not None:
validator.check_value_type("value", value, [numbers.Number, bool], self.name)

@ -827,7 +827,7 @@ class AddN(PrimitiveWithInfer):
def infer_shape(self, inputs):
cls_name = self.name
validator.check_integer("inputs", len(inputs), 1, Rel.GE, cls_name)
validator.check_int(len(inputs), 1, Rel.GE, "inputs", cls_name)
self.add_prim_attr('n', len(inputs))
shp0 = inputs[0]
for i, shp in enumerate(inputs):
@ -837,7 +837,7 @@ class AddN(PrimitiveWithInfer):
def infer_dtype(self, inputs):
cls_name = self.name
validator.check_value_type("inputs", inputs, [tuple, list], cls_name)
validator.check_integer("inputs", len(inputs), 1, Rel.GE, cls_name)
validator.check_int(len(inputs), 1, Rel.GE, "inputs", cls_name)
args = {}
contains_undetermined = False
for i, dtype in enumerate(inputs):
@ -910,7 +910,7 @@ class AccumulateNV2(PrimitiveWithInfer):
def infer_shape(self, inputs):
cls_name = self.name
validator.check_integer("inputs", len(inputs), 1, Rel.GE, cls_name)
validator.check_int(len(inputs), 1, Rel.GE, "inputs", cls_name)
self.add_prim_attr('n', len(inputs))
shp0 = inputs[0]
for i, shp in enumerate(inputs):
@ -920,7 +920,7 @@ class AccumulateNV2(PrimitiveWithInfer):
def infer_dtype(self, inputs):
cls_name = self.name
validator.check_value_type("inputs", inputs, [tuple, list], cls_name)
validator.check_integer("inputs", len(inputs), 1, Rel.GE, cls_name)
validator.check_int(len(inputs), 1, Rel.GE, "inputs", cls_name)
args = {}
for i, dtype in enumerate(inputs):
args[f"inputs[{i}]"] = dtype
@ -1488,7 +1488,7 @@ class HistogramFixedWidth(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, nbins, dtype='int32'):
self.nbins = validator.check_value_type("nbins", nbins, [int], self.name)
validator.check_integer("nbins", nbins, 1, Rel.GE, self.name)
validator.check_int(nbins, 1, Rel.GE, "nbins", self.name)
valid_values = ['int32', 'int64']
self.dtype = validator.check_string(dtype, valid_values, "dtype", self.name)
self.init_prim_io_names(inputs=['x', 'range'], outputs=['y'])
@ -2810,8 +2810,8 @@ class NPUGetFloatStatus(PrimitiveWithInfer):
def infer_shape(self, x_shape):
cls_name = self.name
validator.check_integer("len(x_shape)", len(x_shape), 1, Rel.EQ, cls_name)
validator.check_integer("x_shape[0]", x_shape[0], 8, Rel.EQ, cls_name)
validator.check_equal_int(len(x_shape), 1, "len(x_shape)", cls_name)
validator.check_equal_int(x_shape[0], 8, "x_shape[0]", cls_name)
return [8]
def infer_dtype(self, x_dtype):
@ -2853,8 +2853,8 @@ class NPUClearFloatStatus(PrimitiveWithInfer):
def infer_shape(self, x_shape):
cls_name = self.name
validator.check_integer("len(x_shape)", len(x_shape), 1, Rel.EQ, cls_name)
validator.check_integer("x_shape[0]", x_shape[0], 8, Rel.EQ, cls_name)
validator.check_equal_int(len(x_shape), 1, "len(x_shape)", cls_name)
validator.check_equal_int(x_shape[0], 8, "x_shape[0]", cls_name)
return [8]
def infer_dtype(self, x_dtype):
@ -3023,9 +3023,9 @@ class NMSWithMask(PrimitiveWithInfer):
def infer_shape(self, bboxes_shape):
cls_name = self.name
validator.check_integer("bboxes rank", len(bboxes_shape), 2, Rel.EQ, cls_name)
validator.check_equal_int(len(bboxes_shape), 2, "bboxes rank", cls_name)
validator.check_positive_int(bboxes_shape[0], "bboxes.shape[0]", cls_name)
validator.check_integer("bboxes.shape[1]", bboxes_shape[1], 5, Rel.EQ, cls_name)
validator.check_equal_int(bboxes_shape[1], 5, "bboxes.shape[1]", cls_name)
num = bboxes_shape[0]
return (bboxes_shape, (num,), (num,))
@ -3572,11 +3572,11 @@ class IFMR(PrimitiveWithInfer):
validator.check_value_type("offset_flag", with_offset, [bool], self.name)
def infer_shape(self, data_shape, data_min_shape, data_max_shape, cumsum_shape):
validator.check_integer("dims of data_min", len(data_min_shape), 1, Rel.EQ, self.name)
validator.check_integer("data_min[0]", data_min_shape[0], 1, Rel.EQ, self.name)
validator.check_integer("dims of data_max", len(data_max_shape), 1, Rel.EQ, self.name)
validator.check_integer("data_max[0]", data_max_shape[0], 1, Rel.EQ, self.name)
validator.check_integer("dims of cumsum", len(cumsum_shape), 1, Rel.EQ, self.name)
validator.check_equal_int(len(data_min_shape), 1, "dims of data_min", self.name)
validator.check_equal_int(data_min_shape[0], 1, "data_min[0]", self.name)
validator.check_equal_int(len(data_max_shape), 1, "dims of data_max", self.name)
validator.check_equal_int(data_max_shape[0], 1, "data_max[0]", self.name)
validator.check_equal_int(len(cumsum_shape), 1, "dims of cumsum", self.name)
return (1,), (1,)
def infer_dtype(self, data_dtype, data_min_dtype, data_max_dtype, cumsum_dtype):

File diff suppressed because it is too large

@ -98,16 +98,16 @@ class BoundingBoxEncode(PrimitiveWithInfer):
validator.check_value_type("means[%d]" % i, value, [float], self.name)
for i, value in enumerate(stds):
validator.check_value_type("stds[%d]" % i, value, [float], self.name)
validator.check_integer("means len", len(means), 4, Rel.EQ, self.name)
validator.check_integer("stds len", len(stds), 4, Rel.EQ, self.name)
validator.check_equal_int(len(means), 4, "means len", self.name)
validator.check_equal_int(len(stds), 4, "stds len", self.name)
def infer_shape(self, anchor_box, groundtruth_box):
validator.check('anchor_box shape[0]', anchor_box[0], 'groundtruth_box shape[0]', groundtruth_box[0], Rel.EQ,
self.name)
validator.check("anchor_box rank", len(anchor_box), "", 2, Rel.EQ, self.name)
validator.check("groundtruth_box rank", len(groundtruth_box), "", 2, Rel.EQ, self.name)
validator.check_integer('anchor_box shape[1]', anchor_box[1], 4, Rel.EQ, self.name)
validator.check_integer('groundtruth_box shape[1]', groundtruth_box[1], 4, Rel.EQ, self.name)
validator.check_equal_int(anchor_box[1], 4, 'anchor_box shape[1]', self.name)
validator.check_equal_int(groundtruth_box[1], 4, 'groundtruth_box shape[1]', self.name)
return anchor_box
def infer_dtype(self, anchor_box, groundtruth_box):
@ -153,18 +153,18 @@ class BoundingBoxDecode(PrimitiveWithInfer):
for i, value in enumerate(stds):
validator.check_value_type("stds[%d]" % i, value, [float], self.name)
validator.check_value_type('wh_ratio_clip', wh_ratio_clip, [float], self.name)
validator.check_integer("means len", len(means), 4, Rel.EQ, self.name)
validator.check_integer("stds len", len(stds), 4, Rel.EQ, self.name)
validator.check_equal_int(len(means), 4, "means len", self.name)
validator.check_equal_int(len(stds), 4, "stds len", self.name)
if max_shape is not None:
validator.check_value_type('max_shape', max_shape, [tuple], self.name)
validator.check_integer("max_shape len", len(max_shape), 2, Rel.EQ, self.name)
validator.check_equal_int(len(max_shape), 2, "max_shape len", self.name)
def infer_shape(self, anchor_box, deltas):
validator.check('anchor_box shape[0]', anchor_box[0], 'deltas shape[0]', deltas[0], Rel.EQ, self.name)
validator.check("anchor_box rank", len(anchor_box), "", 2, Rel.EQ, self.name)
validator.check("deltas rank", len(deltas), "", 2, Rel.EQ, self.name)
validator.check_integer('anchor_box shape[1]', anchor_box[1], 4, Rel.EQ, self.name)
validator.check_integer('deltas shape[1]', deltas[1], 4, Rel.EQ, self.name)
validator.check_equal_int(anchor_box[1], 4, 'anchor_box shape[1]', self.name)
validator.check_equal_int(deltas[1], 4, 'deltas shape[1]', self.name)
return anchor_box
def infer_dtype(self, anchor_box, deltas):
@ -272,10 +272,10 @@ class IOU(PrimitiveWithInfer):
self.init_prim_io_names(inputs=['anchor_boxes', 'gt_boxes'], outputs=['overlap'])
def infer_shape(self, anchor_boxes, gt_boxes):
validator.check_integer('gt_boxes shape[1]', gt_boxes[1], 4, Rel.EQ, self.name)
validator.check_integer('anchor_boxes shape[1]', anchor_boxes[1], 4, Rel.EQ, self.name)
validator.check_integer('anchor_boxes rank', len(anchor_boxes), 2, Rel.EQ, self.name)
validator.check_integer('gt_boxes rank', len(gt_boxes), 2, Rel.EQ, self.name)
validator.check_equal_int(gt_boxes[1], 4, 'gt_boxes shape[1]', self.name)
validator.check_equal_int(anchor_boxes[1], 4, 'anchor_boxes shape[1]', self.name)
validator.check_equal_int(len(anchor_boxes), 2, 'anchor_boxes rank', self.name)
validator.check_equal_int(len(gt_boxes), 2, 'gt_boxes rank', self.name)
iou = [gt_boxes[0], anchor_boxes[0]]
return iou

@ -356,8 +356,8 @@ class RandomChoiceWithMask(PrimitiveWithInfer):
Validator.check_value_type('seed2', seed2, [int], self.name)
def infer_shape(self, x_shape):
Validator.check_integer("input_x rank", len(x_shape), 1, Rel.GE, self.name)
Validator.check_integer("input_x rank", len(x_shape), 5, Rel.LE, self.name)
Validator.check_int(len(x_shape), 1, Rel.GE, "input_x rank", self.name)
Validator.check_int(len(x_shape), 5, Rel.LE, "input_x rank", self.name)
return ([self.count, len(x_shape)], [self.count])
def infer_dtype(self, x_dtype):

@ -227,7 +227,7 @@ class PrimitiveWithCheck(Primitive):
>>> def __init__(self):
>>> pass
>>> def check_shape(self, input_x):
>>> validator.check_integer('input_x rank', len(input_x), 1, Rel.GE, self.name)
>>> validator.check_int(len(input_x), 1, Rel.GE, 'input_x rank', self.name)
>>>
>>> def check_dtype(self, input_x):
>>> validator.check_subclass("input_x", input_x, mstype.tensor, self.name)

@ -89,12 +89,12 @@ class ConvertToQuantNetwork:
def __init__(self, **kwargs):
self.network = Validator.check_isinstance('network', kwargs["network"], (nn.Cell,))
self.weight_qdelay = Validator.check_integer("quant delay", kwargs["quant_delay"][0], 0, Rel.GE)
self.act_qdelay = Validator.check_integer("quant delay", kwargs["quant_delay"][-1], 0, Rel.GE)
self.weight_qdelay = Validator.check_non_negative_int(kwargs["quant_delay"][0], "quant delay")
self.act_qdelay = Validator.check_int(kwargs["quant_delay"][-1], 0, Rel.GE, "quant delay")
self.bn_fold = Validator.check_bool(kwargs["bn_fold"], "bn fold")
self.freeze_bn = Validator.check_integer("freeze bn", kwargs["freeze_bn"], 0, Rel.GE)
self.weight_bits = Validator.check_integer("weights bit", kwargs["num_bits"][0], 0, Rel.GE)
self.act_bits = Validator.check_integer("activations bit", kwargs["num_bits"][-1], 0, Rel.GE)
self.freeze_bn = Validator.check_non_negative_int(kwargs["freeze_bn"], "freeze bn")
self.weight_bits = Validator.check_non_negative_int(kwargs["num_bits"][0], "weights bit")
self.act_bits = Validator.check_int(kwargs["num_bits"][-1], 0, Rel.GE, "activations bit")
self.weight_channel = Validator.check_bool(kwargs["per_channel"][0], "per channel")
self.act_channel = Validator.check_bool(kwargs["per_channel"][-1], "per channel")
self.weight_symmetric = Validator.check_bool(kwargs["symmetric"][0], "symmetric")

@ -21,7 +21,7 @@ from PIL import Image
from mindspore import log as logger
from ..._checkparam import _check_str_by_regular
from ..._checkparam import Validator
from ..anf_ir_pb2 import DataType, ModelProto
from ..summary_pb2 import Event
@ -47,8 +47,8 @@ def get_event_file_name(prefix, suffix):
Returns:
String, the name of event log file.
"""
_check_str_by_regular(prefix)
_check_str_by_regular(suffix)
Validator.check_str_by_regular(prefix)
Validator.check_str_by_regular(suffix)
file_name = ""
time_second = str(int(time.time()))
hostname = platform.node()

Some files were not shown because too many files have changed in this diff
