[ME] delete check_bool and replace with Validate.check_bool

pull/7110/head
chenzomi 5 years ago
parent 6c9b6d491d
commit d4e8e94981

@ -26,10 +26,6 @@ from mindspore import log as logger
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
# Named string regular expression
_name_re = r"^\w+[0-9a-zA-Z\_\.]*$"
class Rel(Enum): class Rel(Enum):
"""Numerical relationship between variables, logical relationship enumeration definition of range.""" """Numerical relationship between variables, logical relationship enumeration definition of range."""
# scalar compare # scalar compare
@ -114,7 +110,7 @@ class Validator:
@staticmethod @staticmethod
def check_integer(arg_name, arg_value, value, rel, prim_name=None): def check_integer(arg_name, arg_value, value, rel, prim_name=None):
"""Integer value judgment.""" """Check argument is integer"""
rel_fn = Rel.get_fns(rel) rel_fn = Rel.get_fns(rel)
type_mismatch = not isinstance(arg_value, int) or isinstance(arg_value, bool) type_mismatch = not isinstance(arg_value, int) or isinstance(arg_value, bool)
excp_cls = TypeError if type_mismatch else ValueError excp_cls = TypeError if type_mismatch else ValueError
@ -125,6 +121,7 @@ class Validator:
f' with type `{type(arg_value).__name__}`.') f' with type `{type(arg_value).__name__}`.')
return arg_value return arg_value
@staticmethod @staticmethod
def check_number(arg_name, arg_value, value, rel, prim_name): def check_number(arg_name, arg_value, value, rel, prim_name):
"""Number value judgment.""" """Number value judgment."""
@ -142,10 +139,11 @@ class Validator:
return arg_value return arg_value
@staticmethod @staticmethod
def check_bool(arg_name, arg_value): def check_bool(arg_value, arg_name=None):
"""Check arg isinstance of bool""" """Check argument is instance of bool"""
if not isinstance(arg_value, bool): if not isinstance(arg_value, bool):
raise ValueError(f'The `{arg_name}` should be isinstance of bool, but got {arg_value}.') arg_name = arg_name if arg_name else "Parameter"
raise TypeError(f'`{arg_name}` should be isinstance of bool, but got `{arg_value}`.')
return arg_value return arg_value
@staticmethod @staticmethod
@ -170,15 +168,14 @@ class Validator:
return arg_value return arg_value
@staticmethod @staticmethod
def check_string(arg_name, arg_value, valid_values, prim_name): def check_string(arg_value, valid_values, arg_name=None, prim_name=None):
"""Checks whether a string is in some value list""" """Checks whether a string is in some value list"""
if isinstance(arg_value, str) and arg_value in valid_values: if isinstance(arg_value, str) and arg_value in valid_values:
return arg_value return arg_value
if len(valid_values) == 1: arg_name = arg_name if arg_name else "Parameter"
raise ValueError(f'For \'{prim_name}\' the `{arg_name}` should be str and must be {valid_values[0]},' msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The"
f' but got {arg_value}.') raise ValueError(f'{msg_prefix} `{arg_name}` should be str and must be in `{valid_values}`,'
raise ValueError(f'For \'{prim_name}\' the `{arg_name}` should be str and must be one of {valid_values},' f' but got `{arg_value}`.')
f' but got {arg_value}.')
@staticmethod @staticmethod
def check_pad_value_by_mode(pad_mode, padding, prim_name): def check_pad_value_by_mode(pad_mode, padding, prim_name):
@ -404,24 +401,6 @@ def check_int_zero_one(input_param):
raise ValueError("The data must be 0 or 1.") raise ValueError("The data must be 0 or 1.")
def check_bool(input_param):
"""Bool type judgment."""
if isinstance(input_param, bool):
return input_param
raise TypeError("Input type must be bool!")
def check_string(input_param, valid_values):
"""String type judgment."""
if isinstance(input_param, str) and input_param in valid_values:
return input_param
if len(valid_values) == 1:
raise ValueError(f'Input should be str and must be {valid_values[0]},'
f' but got {input_param}.')
raise ValueError(f'Input should be str and must be one of {valid_values},'
f' but got {input_param}.')
def check_input_format(input_param): def check_input_format(input_param):
"""Judge input format.""" """Judge input format."""
if input_param == "NCHW": if input_param == "NCHW":
@ -587,7 +566,8 @@ def check_shape(arg_name, arg_value):
def _check_str_by_regular(target, reg=None, flag=re.ASCII): def _check_str_by_regular(target, reg=None, flag=re.ASCII):
if reg is None: if reg is None:
reg = _name_re # Named string regular expression
reg = r"^\w+[0-9a-zA-Z\_\.]*$"
if re.match(reg, target, flag) is None: if re.match(reg, target, flag) is None:
raise ValueError("'{}' is illegal, it should be match regular'{}' by flags'{}'".format(target, reg, flag)) raise ValueError("'{}' is illegal, it should be match regular'{}' by flags'{}'".format(target, reg, flag))
return True return True

@ -27,7 +27,7 @@ from mindspore.ops.operations import _inner_ops as inner
from mindspore.ops.primitive import constexpr from mindspore.ops.primitive import constexpr
from mindspore.common.parameter import Parameter from mindspore.common.parameter import Parameter
from mindspore._extends import cell_attr_register from mindspore._extends import cell_attr_register
from mindspore._checkparam import Rel, Validator as validator, check_int_positive, check_bool from mindspore._checkparam import Rel, Validator, check_int_positive
from mindspore.common.api import ms_function from mindspore.common.api import ms_function
from mindspore import context from mindspore import context
from ..cell import Cell from ..cell import Cell
@ -86,8 +86,8 @@ class Dropout(Cell):
super(Dropout, self).__init__() super(Dropout, self).__init__()
if keep_prob <= 0 or keep_prob > 1: if keep_prob <= 0 or keep_prob > 1:
raise ValueError("dropout probability should be a number in range (0, 1], but got {}".format(keep_prob)) raise ValueError("dropout probability should be a number in range (0, 1], but got {}".format(keep_prob))
validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name) Validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name)
validator.check_value_type('keep_prob', keep_prob, [float], self.cls_name) Validator.check_value_type('keep_prob', keep_prob, [float], self.cls_name)
self.keep_prob = keep_prob self.keep_prob = keep_prob
seed0 = get_seed() seed0 = get_seed()
self.seed0 = seed0 if seed0 is not None else 0 self.seed0 = seed0 if seed0 is not None else 0
@ -205,7 +205,7 @@ class Dense(Cell):
super(Dense, self).__init__() super(Dense, self).__init__()
self.in_channels = check_int_positive(in_channels) self.in_channels = check_int_positive(in_channels)
self.out_channels = check_int_positive(out_channels) self.out_channels = check_int_positive(out_channels)
self.has_bias = check_bool(has_bias) self.has_bias = Validator.check_bool(has_bias)
if isinstance(weight_init, Tensor): if isinstance(weight_init, Tensor):
if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \ if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
@ -348,7 +348,7 @@ class Norm(Cell):
def __init__(self, axis=(), keep_dims=False): def __init__(self, axis=(), keep_dims=False):
super(Norm, self).__init__() super(Norm, self).__init__()
validator.check_value_type("keep_dims", keep_dims, [bool], self.cls_name) Validator.check_value_type("keep_dims", keep_dims, [bool], self.cls_name)
self.axis = axis self.axis = axis
self.keep_dims = keep_dims self.keep_dims = keep_dims
self.reduce_sum = P.ReduceSum(True) self.reduce_sum = P.ReduceSum(True)
@ -472,7 +472,7 @@ class Pad(Cell):
super(Pad, self).__init__() super(Pad, self).__init__()
self.mode = mode self.mode = mode
self.paddings = paddings self.paddings = paddings
validator.check_string('mode', self.mode, ["CONSTANT", "REFLECT", "SYMMETRIC"], self.cls_name) Validator.check_string(self.mode, ["CONSTANT", "REFLECT", "SYMMETRIC"], 'mode', self.cls_name)
if not isinstance(paddings, tuple): if not isinstance(paddings, tuple):
raise TypeError('Paddings must be tuple type.') raise TypeError('Paddings must be tuple type.')
for item in paddings: for item in paddings:
@ -549,7 +549,7 @@ class Unfold(Cell):
@constexpr @constexpr
def _get_matrix_diag_assist(x_shape, x_dtype): def _get_matrix_diag_assist(x_shape, x_dtype):
validator.check_integer("x rank", len(x_shape), 1, Rel.GE, "_get_matrix_diag_assist") Validator.check_integer("x rank", len(x_shape), 1, Rel.GE, "_get_matrix_diag_assist")
base_eye = np.eye(x_shape[-1], x_shape[-1]).reshape(-1) base_eye = np.eye(x_shape[-1], x_shape[-1]).reshape(-1)
assist = np.tile(base_eye, x_shape[:-1]).reshape(x_shape + (x_shape[-1],)) assist = np.tile(base_eye, x_shape[:-1]).reshape(x_shape + (x_shape[-1],))
return Tensor(assist, x_dtype) return Tensor(assist, x_dtype)
@ -557,7 +557,7 @@ def _get_matrix_diag_assist(x_shape, x_dtype):
@constexpr @constexpr
def _get_matrix_diag_part_assist(x_shape, x_dtype): def _get_matrix_diag_part_assist(x_shape, x_dtype):
validator.check_integer("x rank", len(x_shape), 2, Rel.GE, "_get_matrix_diag_part_assist") Validator.check_integer("x rank", len(x_shape), 2, Rel.GE, "_get_matrix_diag_part_assist")
base_eye = np.eye(x_shape[-2], x_shape[-1]).reshape(-1) base_eye = np.eye(x_shape[-2], x_shape[-1]).reshape(-1)
assist = np.tile(base_eye, x_shape[:-2]).reshape(x_shape) assist = np.tile(base_eye, x_shape[:-2]).reshape(x_shape)
return Tensor(assist, x_dtype) return Tensor(assist, x_dtype)

@ -21,7 +21,7 @@ from mindspore.ops.primitive import constexpr
from mindspore.common.parameter import Parameter from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer, Initializer from mindspore.common.initializer import initializer, Initializer
from mindspore.common.tensor import Tensor from mindspore.common.tensor import Tensor
from mindspore._checkparam import Validator, Rel, check_bool, twice, check_int_positive from mindspore._checkparam import Validator, Rel, twice, check_int_positive
from mindspore._extends import cell_attr_register from mindspore._extends import cell_attr_register
from ..cell import Cell from ..cell import Cell
@ -92,7 +92,7 @@ class _Conv(Cell):
shape = [out_channels, in_channels // group, *kernel_size] shape = [out_channels, in_channels // group, *kernel_size]
self.weight = Parameter(initializer(self.weight_init, shape), name='weight') self.weight = Parameter(initializer(self.weight_init, shape), name='weight')
if check_bool(has_bias): if Validator.check_bool(has_bias):
self.bias = Parameter(initializer(self.bias_init, [out_channels]), name='bias') self.bias = Parameter(initializer(self.bias_init, [out_channels]), name='bias')
else: else:
if self.bias_init != 'zeros': if self.bias_init != 'zeros':
@ -566,7 +566,7 @@ class Conv2dTranspose(_Conv):
self.is_valid = self.pad_mode == 'valid' self.is_valid = self.pad_mode == 'valid'
self.is_same = self.pad_mode == 'same' self.is_same = self.pad_mode == 'same'
self.is_pad = self.pad_mode == 'pad' self.is_pad = self.pad_mode == 'pad'
if check_bool(has_bias): if Validator.check_bool(has_bias):
self.bias = Parameter(initializer(bias_init, [out_channels]), name='bias') self.bias = Parameter(initializer(bias_init, [out_channels]), name='bias')
# cause Conv2DBackpropInput's out_channel refers to Conv2D's out_channel. # cause Conv2DBackpropInput's out_channel refers to Conv2D's out_channel.
@ -745,7 +745,7 @@ class Conv1dTranspose(_Conv):
self.is_valid = self.pad_mode == 'valid' self.is_valid = self.pad_mode == 'valid'
self.is_same = self.pad_mode == 'same' self.is_same = self.pad_mode == 'same'
self.is_pad = self.pad_mode == 'pad' self.is_pad = self.pad_mode == 'pad'
if check_bool(has_bias): if Validator.check_bool(has_bias):
self.bias = Parameter(initializer(bias_init, [out_channels]), name='bias') self.bias = Parameter(initializer(bias_init, [out_channels]), name='bias')
# cause Conv2DBackpropInput's out_channel refers to Conv2D's out_channel. # cause Conv2DBackpropInput's out_channel refers to Conv2D's out_channel.

@ -19,7 +19,7 @@ from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer from mindspore.common.initializer import initializer
from mindspore.ops.primitive import constexpr from mindspore.ops.primitive import constexpr
import mindspore.context as context import mindspore.context as context
from mindspore._checkparam import check_bool, check_typename, check_int_positive from mindspore._checkparam import Validator, check_typename, check_int_positive
from mindspore._extends import cell_attr_register from mindspore._extends import cell_attr_register
from mindspore.communication.management import get_group_size, get_rank from mindspore.communication.management import get_group_size, get_rank
from mindspore.communication import management from mindspore.communication import management
@ -604,7 +604,7 @@ class GroupNorm(Cell):
if num_channels % num_groups != 0: if num_channels % num_groups != 0:
raise ValueError("num_channels should be divided by num_groups") raise ValueError("num_channels should be divided by num_groups")
self.eps = check_typename('eps', eps, (float,)) self.eps = check_typename('eps', eps, (float,))
self.affine = check_bool(affine) self.affine = Validator.check_bool(affine)
gamma = initializer(gamma_init, num_channels) gamma = initializer(gamma_init, num_channels)
beta = initializer(beta_init, num_channels) beta = initializer(beta_init, num_channels)

@ -27,7 +27,7 @@ class _PoolNd(Cell):
def __init__(self, kernel_size, stride, pad_mode): def __init__(self, kernel_size, stride, pad_mode):
super(_PoolNd, self).__init__() super(_PoolNd, self).__init__()
self.pad_mode = validator.check_string('pad_mode', pad_mode.upper(), ['VALID', 'SAME'], self.cls_name) self.pad_mode = validator.check_string(pad_mode.upper(), ['VALID', 'SAME'], 'pad_mode', self.cls_name)
def _check_int_or_tuple(arg_name, arg_value): def _check_int_or_tuple(arg_name, arg_value):
validator.check_value_type(arg_name, arg_value, [int, tuple], self.cls_name) validator.check_value_type(arg_name, arg_value, [int, tuple], self.cls_name)
@ -270,7 +270,7 @@ class AvgPool1d(_PoolNd):
super(AvgPool1d, self).__init__(kernel_size, stride, pad_mode) super(AvgPool1d, self).__init__(kernel_size, stride, pad_mode)
validator.check_value_type('kernel_size', kernel_size, [int], self.cls_name) validator.check_value_type('kernel_size', kernel_size, [int], self.cls_name)
validator.check_value_type('stride', stride, [int], self.cls_name) validator.check_value_type('stride', stride, [int], self.cls_name)
self.pad_mode = validator.check_string('pad_mode', pad_mode.upper(), ['VALID', 'SAME'], self.cls_name) self.pad_mode = validator.check_string(pad_mode.upper(), ['VALID', 'SAME'], 'pad_mode', self.cls_name)
validator.check_integer("kernel_size", kernel_size, 1, Rel.GE, self.cls_name) validator.check_integer("kernel_size", kernel_size, 1, Rel.GE, self.cls_name)
validator.check_integer("stride", stride, 1, Rel.GE, self.cls_name) validator.check_integer("stride", stride, 1, Rel.GE, self.cls_name)
self.kernel_size = (1, kernel_size) self.kernel_size = (1, kernel_size)

@ -23,7 +23,7 @@ from mindspore.ops import functional as F
from mindspore.common.parameter import Parameter from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer from mindspore.common.initializer import initializer
from mindspore.common.tensor import Tensor from mindspore.common.tensor import Tensor
from mindspore._checkparam import Rel, check_int_positive, check_bool, twice, Validator from mindspore._checkparam import Validator, Rel, check_int_positive, twice
import mindspore.context as context import mindspore.context as context
from .normalization import BatchNorm2d, BatchNorm1d from .normalization import BatchNorm2d, BatchNorm1d
from .activation import get_activation, ReLU, LeakyReLU from .activation import get_activation, ReLU, LeakyReLU
@ -133,7 +133,7 @@ class Conv2dBnAct(Cell):
has_bias=has_bias, has_bias=has_bias,
weight_init=weight_init, weight_init=weight_init,
bias_init=bias_init) bias_init=bias_init)
self.has_bn = Validator.check_bool("has_bn", has_bn) self.has_bn = Validator.check_bool(has_bn, "has_bn")
self.has_act = activation is not None self.has_act = activation is not None
self.after_fake = after_fake self.after_fake = after_fake
if has_bn: if has_bn:
@ -201,7 +201,7 @@ class DenseBnAct(Cell):
weight_init, weight_init,
bias_init, bias_init,
has_bias) has_bias)
self.has_bn = Validator.check_bool("has_bn", has_bn) self.has_bn = Validator.check_bool(has_bn, "has_bn")
self.has_act = activation is not None self.has_act = activation is not None
self.after_fake = after_fake self.after_fake = after_fake
if has_bn: if has_bn:
@ -511,7 +511,7 @@ class Conv2dBnFoldQuant(Cell):
channel_axis = 0 channel_axis = 0
self.weight = Parameter(initializer(weight_init, weight_shape), name='weight') self.weight = Parameter(initializer(weight_init, weight_shape), name='weight')
self.bias_add = P.BiasAdd() self.bias_add = P.BiasAdd()
if check_bool(has_bias): if Validator.check_bool(has_bias):
self.bias = Parameter(initializer(bias_init, [out_channels]), name='bias') self.bias = Parameter(initializer(bias_init, [out_channels]), name='bias')
else: else:
self.bias = None self.bias = None
@ -668,7 +668,7 @@ class Conv2dBnWithoutFoldQuant(Cell):
self.quant_delay = quant_delay self.quant_delay = quant_delay
self.bias_add = P.BiasAdd() self.bias_add = P.BiasAdd()
if check_bool(has_bias): if Validator.check_bool(has_bias):
self.bias = Parameter(initializer(bias_init, [out_channels]), name='bias') self.bias = Parameter(initializer(bias_init, [out_channels]), name='bias')
else: else:
self.bias = None self.bias = None
@ -799,7 +799,7 @@ class Conv2dQuant(Cell):
self.weight = Parameter(initializer(weight_init, weight_shape), name='weight') self.weight = Parameter(initializer(weight_init, weight_shape), name='weight')
self.bias_add = P.BiasAdd() self.bias_add = P.BiasAdd()
if check_bool(has_bias): if Validator.check_bool(has_bias):
self.bias = Parameter(initializer(bias_init, [out_channels]), name='bias') self.bias = Parameter(initializer(bias_init, [out_channels]), name='bias')
else: else:
self.bias = None self.bias = None
@ -888,7 +888,7 @@ class DenseQuant(Cell):
super(DenseQuant, self).__init__() super(DenseQuant, self).__init__()
self.in_channels = check_int_positive(in_channels) self.in_channels = check_int_positive(in_channels)
self.out_channels = check_int_positive(out_channels) self.out_channels = check_int_positive(out_channels)
self.has_bias = check_bool(has_bias) self.has_bias = Validator.check_bool(has_bias)
if isinstance(weight_init, Tensor): if isinstance(weight_init, Tensor):
if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \ if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \

@ -18,8 +18,7 @@ from mindspore.ops import _selected_ops
from mindspore.common.parameter import Parameter from mindspore.common.parameter import Parameter
from mindspore.common.tensor import Tensor from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype import mindspore.common.dtype as mstype
from mindspore._checkparam import check_bool from mindspore._checkparam import Validator
from mindspore._checkparam import Validator as validator
from .optimizer import Optimizer from .optimizer import Optimizer
_momentum_opt = C.MultitypeFuncGraph("momentum_opt") _momentum_opt = C.MultitypeFuncGraph("momentum_opt")
@ -126,12 +125,12 @@ class Momentum(Optimizer):
""" """
def __init__(self, params, learning_rate, momentum, weight_decay=0.0, loss_scale=1.0, use_nesterov=False): def __init__(self, params, learning_rate, momentum, weight_decay=0.0, loss_scale=1.0, use_nesterov=False):
super(Momentum, self).__init__(learning_rate, params, weight_decay, loss_scale) super(Momentum, self).__init__(learning_rate, params, weight_decay, loss_scale)
validator.check_value_type("momentum", momentum, [float], self.cls_name) Validator.check_value_type("momentum", momentum, [float], self.cls_name)
if isinstance(momentum, float) and momentum < 0.0: if isinstance(momentum, float) and momentum < 0.0:
raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum)) raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum))
self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum") self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum")
self.params = self.parameters self.params = self.parameters
self.use_nesterov = check_bool(use_nesterov) self.use_nesterov = Validator.check_bool(use_nesterov)
self.moments = self.params.clone(prefix="moments", init='zeros') self.moments = self.params.clone(prefix="moments", init='zeros')
self.hyper_map = C.HyperMap() self.hyper_map = C.HyperMap()
self.opt = _selected_ops.ApplyMomentum(use_nesterov=self.use_nesterov) self.opt = _selected_ops.ApplyMomentum(use_nesterov=self.use_nesterov)

@ -15,7 +15,7 @@
"""dense_variational""" """dense_variational"""
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor from mindspore.common.tensor import Tensor
from mindspore._checkparam import check_int_positive, check_bool from mindspore._checkparam import check_int_positive, Validator
from ...cell import Cell from ...cell import Cell
from ...layer.activation import get_activation from ...layer.activation import get_activation
from .layer_distribution import NormalPrior, NormalPosterior from .layer_distribution import NormalPrior, NormalPosterior
@ -41,7 +41,7 @@ class _DenseVariational(Cell):
super(_DenseVariational, self).__init__() super(_DenseVariational, self).__init__()
self.in_channels = check_int_positive(in_channels) self.in_channels = check_int_positive(in_channels)
self.out_channels = check_int_positive(out_channels) self.out_channels = check_int_positive(out_channels)
self.has_bias = check_bool(has_bias) self.has_bias = Validator.check_bool(has_bias)
if isinstance(weight_prior_fn, Cell): if isinstance(weight_prior_fn, Cell):
self.weight_prior = weight_prior_fn self.weight_prior = weight_prior_fn

@ -16,7 +16,7 @@
from copy import deepcopy from copy import deepcopy
import numpy as np import numpy as np
from mindspore._checkparam import check_int_positive, check_bool from mindspore._checkparam import check_int_positive, Validator
from mindspore.ops import composite as C from mindspore.ops import composite as C
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.train import Model from mindspore.train import Model
@ -84,7 +84,7 @@ class UncertaintyEvaluation:
self.epochs = check_int_positive(epochs) self.epochs = check_int_positive(epochs)
self.epi_uncer_model_path = epi_uncer_model_path self.epi_uncer_model_path = epi_uncer_model_path
self.ale_uncer_model_path = ale_uncer_model_path self.ale_uncer_model_path = ale_uncer_model_path
self.save_model = check_bool(save_model) self.save_model = Validator.check_bool(save_model)
self.epi_uncer_model = None self.epi_uncer_model = None
self.ale_uncer_model = None self.ale_uncer_model = None
self.concat = P.Concat(axis=0) self.concat = P.Concat(axis=0)

@ -216,7 +216,7 @@ class KLDivLossGrad(PrimitiveWithInfer):
@prim_attr_register @prim_attr_register
def __init__(self, reduction='mean'): def __init__(self, reduction='mean'):
self.reduction = validator.check_string('reduction', reduction, ['none', 'mean', 'sum'], self.name) self.reduction = validator.check_string(reduction, ['none', 'mean', 'sum'], 'reduction', self.name)
def infer_shape(self, x_shape, y_shape, doutput_shape): def infer_shape(self, x_shape, y_shape, doutput_shape):
validator.check('x_shape', x_shape, 'y_shape', y_shape, Rel.EQ, self.name) validator.check('x_shape', x_shape, 'y_shape', y_shape, Rel.EQ, self.name)
@ -233,7 +233,7 @@ class BinaryCrossEntropyGrad(PrimitiveWithInfer):
@prim_attr_register @prim_attr_register
def __init__(self, reduction='mean'): def __init__(self, reduction='mean'):
self.reduction = validator.check_string('reduction', reduction, ['none', 'mean', 'sum'], self.name) self.reduction = validator.check_string(reduction, ['none', 'mean', 'sum'], 'reduction', self.name)
def infer_shape(self, x_shape, y_shape, doutput_shape, weight_shape): def infer_shape(self, x_shape, y_shape, doutput_shape, weight_shape):
validator.check('x_shape', x_shape, 'y_shape', y_shape, Rel.EQ, self.name) validator.check('x_shape', x_shape, 'y_shape', y_shape, Rel.EQ, self.name)
@ -609,7 +609,7 @@ class _PoolGrad(PrimitiveWithInfer):
validator.check_value_type('ksize', ksize, [int, tuple], self.name) validator.check_value_type('ksize', ksize, [int, tuple], self.name)
validator.check_value_type('strides', strides, [int, tuple], self.name) validator.check_value_type('strides', strides, [int, tuple], self.name)
self.padding = validator.check_string('padding', padding.upper(), ['VALID', 'SAME'], self.name) self.padding = validator.check_string(padding.upper(), ['VALID', 'SAME'], 'padding', self.name)
self.add_prim_attr("padding", self.padding) self.add_prim_attr("padding", self.padding)
self.is_maxpoolgradwithargmax = (self.name == "MaxPoolGradWithArgmax") self.is_maxpoolgradwithargmax = (self.name == "MaxPoolGradWithArgmax")
if not self.is_maxpoolgradwithargmax: if not self.is_maxpoolgradwithargmax:
@ -1457,7 +1457,7 @@ class MirrorPadGrad(PrimitiveWithInfer):
@prim_attr_register @prim_attr_register
def __init__(self, mode="REFLECT"): def __init__(self, mode="REFLECT"):
"""Initialize MirrorPad""" """Initialize MirrorPad"""
validator.check_string('mode', mode, ['REFLECT', 'SYMMETRIC'], self.name) validator.check_string(mode, ['REFLECT', 'SYMMETRIC'], 'mode', self.name)
self.mode = mode self.mode = mode
def __infer__(self, dout, paddings): def __infer__(self, dout, paddings):
@ -1570,7 +1570,7 @@ class BasicLSTMCellCStateGrad(PrimitiveWithInfer):
@prim_attr_register @prim_attr_register
def __init__(self, forget_bias, activation): def __init__(self, forget_bias, activation):
self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name) self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name)
self.activation = validator.check_string("activation", activation, ['tanh'], self.name) self.activation = validator.check_string(activation, ['tanh'], "activation", self.name)
self.add_prim_attr("io_format", "ND") self.add_prim_attr("io_format", "ND")
def infer_shape(self, c_shape, dht_shape, dct_shape, it_shape, jt_shape, ft_shape, ot_shape, tanhct_shape): def infer_shape(self, c_shape, dht_shape, dct_shape, it_shape, jt_shape, ft_shape, ot_shape, tanhct_shape):

@ -67,7 +67,7 @@ class ExtractImagePatches(PrimitiveWithInfer):
_check_tuple_or_list("ksize", ksizes, self.name) _check_tuple_or_list("ksize", ksizes, self.name)
_check_tuple_or_list("stride", strides, self.name) _check_tuple_or_list("stride", strides, self.name)
_check_tuple_or_list("rate", rates, self.name) _check_tuple_or_list("rate", rates, self.name)
self.padding = validator.check_string('padding', padding.upper(), ['VALID', 'SAME'], self.name) self.padding = validator.check_string(padding.upper(), ['VALID', 'SAME'], 'padding', self.name)
self.add_prim_attr("padding", self.padding) self.add_prim_attr("padding", self.padding)
self.add_prim_attr("io_format", "NHWC") self.add_prim_attr("io_format", "NHWC")
self.is_ge = context.get_context("enable_ge") self.is_ge = context.get_context("enable_ge")
@ -206,8 +206,8 @@ class Quant(PrimitiveWithInfer):
self.scale = validator.check_value_type("scale", scale, [float], self.name) self.scale = validator.check_value_type("scale", scale, [float], self.name)
self.offset = validator.check_value_type("offset", offset, [float], self.name) self.offset = validator.check_value_type("offset", offset, [float], self.name)
self.sqrt_mode = validator.check_value_type("sqrt_mode", sqrt_mode, [bool], self.name) self.sqrt_mode = validator.check_value_type("sqrt_mode", sqrt_mode, [bool], self.name)
self.round_mode = validator.check_string("round_mode", round_mode, self.round_mode = validator.check_string(round_mode, ["Round", "Floor", "Ceil", "Trunc"],
["Round", "Floor", "Ceil", "Trunc"], self.name) "round_mode", self.name)
self.add_prim_attr("io_format", "ND") self.add_prim_attr("io_format", "ND")
def infer_shape(self, x_shape): def infer_shape(self, x_shape):

@ -513,7 +513,7 @@ class Im2Col(PrimitiveWithInfer):
self.dilation = _check_positive_int_or_tuple('dilation', dilation, self.name, allow_four=True, ret_four=True) self.dilation = _check_positive_int_or_tuple('dilation', dilation, self.name, allow_four=True, ret_four=True)
self.add_prim_attr('dilation', self.dilation) self.add_prim_attr('dilation', self.dilation)
validator.check_value_type('pad', pad, (int,), self.name) validator.check_value_type('pad', pad, (int,), self.name)
self.pad_mode = validator.check_string('pad_mode', pad_mode, ['valid', 'same', 'pad'], self.name) self.pad_mode = validator.check_string(pad_mode, ['valid', 'same', 'pad'], 'pad_mode', self.name)
self.pad = validator.check_pad_value_by_mode(pad_mode, pad, self.name) self.pad = validator.check_pad_value_by_mode(pad_mode, pad, self.name)
if self.pad_mode == 'pad': if self.pad_mode == 'pad':
validator.check_integer('pad', self.pad, 0, Rel.GE, self.name) validator.check_integer('pad', self.pad, 0, Rel.GE, self.name)

@ -82,7 +82,7 @@ class CropAndResize(PrimitiveWithInfer):
"""Initialize CropAndResize""" """Initialize CropAndResize"""
self.init_prim_io_names(inputs=['x', 'boxes', 'box_index', 'crop_size'], outputs=['y']) self.init_prim_io_names(inputs=['x', 'boxes', 'box_index', 'crop_size'], outputs=['y'])
validator.check_value_type("method", method, [str], self.name) validator.check_value_type("method", method, [str], self.name)
validator.check_string("method", method, ["bilinear", "nearest", "bilinear_v2"], self.name) validator.check_string(method, ["bilinear", "nearest", "bilinear_v2"], "method", self.name)
self.method = method self.method = method
validator.check_value_type("extrapolation_value", extrapolation_value, [float], self.name) validator.check_value_type("extrapolation_value", extrapolation_value, [float], self.name)
self.extrapolation_value = extrapolation_value self.extrapolation_value = extrapolation_value

@ -1484,7 +1484,7 @@ class HistogramFixedWidth(PrimitiveWithInfer):
self.nbins = validator.check_value_type("nbins", nbins, [int], self.name) self.nbins = validator.check_value_type("nbins", nbins, [int], self.name)
validator.check_integer("nbins", nbins, 1, Rel.GE, self.name) validator.check_integer("nbins", nbins, 1, Rel.GE, self.name)
valid_values = ['int32', 'int64'] valid_values = ['int32', 'int64']
self.dtype = validator.check_string("dtype", dtype, valid_values, self.name) self.dtype = validator.check_string(dtype, valid_values, "dtype", self.name)
self.init_prim_io_names(inputs=['x', 'range'], outputs=['y']) self.init_prim_io_names(inputs=['x', 'range'], outputs=['y'])
def infer_shape(self, x_shape, range_shape): def infer_shape(self, x_shape, range_shape):

@ -995,7 +995,7 @@ class Conv2D(PrimitiveWithInfer):
else: else:
validator.check_integer('pad size', len(pad), 4, Rel.EQ, self.name) validator.check_integer('pad size', len(pad), 4, Rel.EQ, self.name)
self.padding = pad self.padding = pad
self.pad_mode = validator.check_string('pad_mode', pad_mode, ['valid', 'same', 'pad'], self.name) self.pad_mode = validator.check_string(pad_mode, ['valid', 'same', 'pad'], 'pad_mode', self.name)
if pad_mode != 'pad' and pad != (0, 0, 0, 0): if pad_mode != 'pad' and pad != (0, 0, 0, 0):
raise ValueError(f"For '{self.name}', padding must be zero when pad_mode is '{pad_mode}'.") raise ValueError(f"For '{self.name}', padding must be zero when pad_mode is '{pad_mode}'.")
@ -1134,7 +1134,7 @@ class DepthwiseConv2dNative(PrimitiveWithInfer):
else: else:
validator.check_integer('pad size', len(pad), 4, Rel.EQ, self.name) validator.check_integer('pad size', len(pad), 4, Rel.EQ, self.name)
self.padding = pad self.padding = pad
self.pad_mode = validator.check_string('pad_mode', pad_mode, ['valid', 'same', 'pad'], self.name) self.pad_mode = validator.check_string(pad_mode, ['valid', 'same', 'pad'], 'pad_mode', self.name)
if pad_mode != 'pad' and pad != (0, 0, 0, 0): if pad_mode != 'pad' and pad != (0, 0, 0, 0):
raise ValueError(f"For '{self.name}', padding must be zero when pad_mode is '{pad_mode}'.") raise ValueError(f"For '{self.name}', padding must be zero when pad_mode is '{pad_mode}'.")
if self.pad_mode == 'pad': if self.pad_mode == 'pad':
@ -1216,7 +1216,7 @@ class _Pool(PrimitiveWithInfer):
self.init_prim_io_names(inputs=['x'], outputs=['output']) self.init_prim_io_names(inputs=['x'], outputs=['output'])
validator.check_value_type('ksize', ksize, [int, tuple], self.name) validator.check_value_type('ksize', ksize, [int, tuple], self.name)
validator.check_value_type('strides', strides, [int, tuple], self.name) validator.check_value_type('strides', strides, [int, tuple], self.name)
self.padding = validator.check_string('padding', padding.upper(), ['VALID', 'SAME'], self.name) self.padding = validator.check_string(padding.upper(), ['VALID', 'SAME'], 'padding', self.name)
self.add_prim_attr("padding", self.padding) self.add_prim_attr("padding", self.padding)
self.is_maxpoolwithargmax = (self.name == "MaxPoolWithArgmax") self.is_maxpoolwithargmax = (self.name == "MaxPoolWithArgmax")
if not self.is_maxpoolwithargmax: if not self.is_maxpoolwithargmax:
@ -1521,7 +1521,7 @@ class Conv2DBackpropInput(PrimitiveWithInfer):
else: else:
validator.check_integer('pad size', len(pad), 4, Rel.EQ, self.name) validator.check_integer('pad size', len(pad), 4, Rel.EQ, self.name)
self.padding = pad self.padding = pad
self.pad_mode = validator.check_string('pad_mode', pad_mode, ['valid', 'same', 'pad'], self.name) self.pad_mode = validator.check_string(pad_mode, ['valid', 'same', 'pad'], 'pad_mode', self.name)
if pad_mode != 'pad' and pad != (0, 0, 0, 0): if pad_mode != 'pad' and pad != (0, 0, 0, 0):
raise ValueError(f"For '{self.name}', padding must be zero when pad_mode is '{pad_mode}'.") raise ValueError(f"For '{self.name}', padding must be zero when pad_mode is '{pad_mode}'.")
if self.pad_mode == 'pad': if self.pad_mode == 'pad':
@ -1942,8 +1942,8 @@ class DataFormatDimMap(PrimitiveWithInfer):
@prim_attr_register @prim_attr_register
def __init__(self, src_format='NHWC', dst_format='NCHW'): def __init__(self, src_format='NHWC', dst_format='NCHW'):
valid_values = ['NHWC', 'NCHW'] valid_values = ['NHWC', 'NCHW']
self.src_format = validator.check_string("src_format", src_format, valid_values, self.name) self.src_format = validator.check_string(src_format, valid_values, "src_format", self.name)
self.dst_format = validator.check_string("dst_format", dst_format, valid_values, self.name) self.dst_format = validator.check_string(dst_format, valid_values, "dst_format", self.name)
self.init_prim_io_names(inputs=['input_x'], outputs=['output']) self.init_prim_io_names(inputs=['input_x'], outputs=['output'])
def infer_shape(self, x_shape): def infer_shape(self, x_shape):
@ -2961,7 +2961,7 @@ class MirrorPad(PrimitiveWithInfer):
@prim_attr_register @prim_attr_register
def __init__(self, mode='REFLECT'): def __init__(self, mode='REFLECT'):
"""Initialize Pad""" """Initialize Pad"""
validator.check_string('mode', mode, ['REFLECT', 'SYMMETRIC'], self.name) validator.check_string(mode, ['REFLECT', 'SYMMETRIC'], 'mode', self.name)
self.mode = mode self.mode = mode
self.set_const_input_indexes([1]) self.set_const_input_indexes([1])
@ -3651,7 +3651,7 @@ class KLDivLoss(PrimitiveWithInfer):
@prim_attr_register @prim_attr_register
def __init__(self, reduction='mean'): def __init__(self, reduction='mean'):
self.reduction = validator.check_string('reduction', reduction, ['none', 'mean', 'sum'], self.name) self.reduction = validator.check_string(reduction, ['none', 'mean', 'sum'], 'reduction', self.name)
def infer_shape(self, x_shape, y_shape): def infer_shape(self, x_shape, y_shape):
validator.check('x_shape', x_shape, 'y_shape', y_shape, Rel.EQ, self.name) validator.check('x_shape', x_shape, 'y_shape', y_shape, Rel.EQ, self.name)
@ -3727,7 +3727,7 @@ class BinaryCrossEntropy(PrimitiveWithInfer):
@prim_attr_register @prim_attr_register
def __init__(self, reduction='mean'): def __init__(self, reduction='mean'):
self.reduction = validator.check_string('reduction', reduction, ['none', 'mean', 'sum'], self.name) self.reduction = validator.check_string(reduction, ['none', 'mean', 'sum'], 'reduction', self.name)
def infer_shape(self, x_shape, y_shape, weight_shape): def infer_shape(self, x_shape, y_shape, weight_shape):
validator.check('x_shape', x_shape, 'y_shape', y_shape, Rel.EQ, self.name) validator.check('x_shape', x_shape, 'y_shape', y_shape, Rel.EQ, self.name)
@ -5487,7 +5487,7 @@ class BasicLSTMCell(PrimitiveWithInfer):
self.keep_prob = validator.check_number_range("keep_prob", keep_prob, 0.0, 1.0, Rel.INC_BOTH, self.name) self.keep_prob = validator.check_number_range("keep_prob", keep_prob, 0.0, 1.0, Rel.INC_BOTH, self.name)
self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name) self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name)
self.state_is_tuple = validator.check_value_type("state_is_tuple", state_is_tuple, [bool], self.name) self.state_is_tuple = validator.check_value_type("state_is_tuple", state_is_tuple, [bool], self.name)
self.activation = validator.check_string("activation", activation, ['tanh'], self.name) self.activation = validator.check_string(activation, ['tanh'], "activation", self.name)
self.add_prim_attr("io_format", "ND") self.add_prim_attr("io_format", "ND")
def infer_shape(self, x_shape, h_shape, c_shape, w_shape, b_shape): def infer_shape(self, x_shape, h_shape, c_shape, w_shape, b_shape):
@ -5605,9 +5605,9 @@ class DynamicRNN(PrimitiveWithInfer):
self.use_peephole = validator.check_value_type("use_peephole", use_peephole, [bool], self.name) self.use_peephole = validator.check_value_type("use_peephole", use_peephole, [bool], self.name)
self.time_major = validator.check_value_type("time_major", time_major, [bool], self.name) self.time_major = validator.check_value_type("time_major", time_major, [bool], self.name)
self.is_training = validator.check_value_type("is_training", is_training, [bool], self.name) self.is_training = validator.check_value_type("is_training", is_training, [bool], self.name)
self.cell_type = validator.check_string("cell_type", cell_type, ['LSTM'], self.name) self.cell_type = validator.check_string(cell_type, ['LSTM'], "cell_type", self.name)
self.direction = validator.check_string("direction", direction, ['UNIDIRECTIONAL'], self.name) self.direction = validator.check_string(direction, ['UNIDIRECTIONAL'], "direction", self.name)
self.activation = validator.check_string("activation", activation, ['tanh'], self.name) self.activation = validator.check_string(activation, ['tanh'], "activation", self.name)
self.add_prim_attr("io_format", "ND") self.add_prim_attr("io_format", "ND")
def infer_shape(self, x_shape, w_shape, b_shape, seq_shape, h_shape, c_shape): def infer_shape(self, x_shape, w_shape, b_shape, seq_shape, h_shape, c_shape):
@ -5720,7 +5720,7 @@ class LRN(PrimitiveWithInfer):
validator.check_value_type("alpha", alpha, [float], self.name) validator.check_value_type("alpha", alpha, [float], self.name)
validator.check_value_type("beta", beta, [float], self.name) validator.check_value_type("beta", beta, [float], self.name)
validator.check_value_type("norm_region", norm_region, [str], self.name) validator.check_value_type("norm_region", norm_region, [str], self.name)
validator.check_string('norm_region', norm_region, ['ACROSS_CHANNELS'], self.name) validator.check_string(norm_region, ['ACROSS_CHANNELS'], 'norm_region', self.name)
validator.check_integer("depth_radius", depth_radius, 0, Rel.GE, self.name) validator.check_integer("depth_radius", depth_radius, 0, Rel.GE, self.name)
def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):

@ -21,7 +21,7 @@ import time
import threading import threading
import mindspore.context as context import mindspore.context as context
from mindspore import log as logger from mindspore import log as logger
from mindspore._checkparam import check_bool, check_int_non_negative from mindspore._checkparam import Validator, check_int_non_negative
from mindspore.train._utils import _make_directory from mindspore.train._utils import _make_directory
from mindspore.train.serialization import save_checkpoint, _save_graph from mindspore.train.serialization import save_checkpoint, _save_graph
from mindspore.parallel._ps_context import _is_role_pserver, _get_ps_mode_rank from mindspore.parallel._ps_context import _is_role_pserver, _get_ps_mode_rank
@ -132,8 +132,8 @@ class CheckpointConfig:
if not self._keep_checkpoint_per_n_minutes or self._keep_checkpoint_per_n_minutes == 0: if not self._keep_checkpoint_per_n_minutes or self._keep_checkpoint_per_n_minutes == 0:
self._keep_checkpoint_max = 1 self._keep_checkpoint_max = 1
self._integrated_save = check_bool(integrated_save) self._integrated_save = Validator.check_bool(integrated_save)
self._async_save = check_bool(async_save) self._async_save = Validator.check_bool(async_save)
@property @property
def save_checkpoint_steps(self): def save_checkpoint_steps(self):

@ -16,7 +16,7 @@
import math import math
import os import os
from mindspore._checkparam import check_bool, check_int from mindspore._checkparam import Validator, check_int
from .. import context, nn from .. import context, nn
from ._utils import _exec_datagraph, _get_types_and_shapes, _construct_tensor_list from ._utils import _exec_datagraph, _get_types_and_shapes, _construct_tensor_list
from ..nn.wrap import GetNextSingleOp from ..nn.wrap import GetNextSingleOp
@ -123,7 +123,7 @@ class DatasetHelper:
""" """
def __init__(self, dataset, dataset_sink_mode=True, sink_size=-1, epoch_num=1): def __init__(self, dataset, dataset_sink_mode=True, sink_size=-1, epoch_num=1):
check_bool(dataset_sink_mode) dataset_sink_mode = Validator.check_bool(dataset_sink_mode)
check_int(sink_size) check_int(sink_size)
if sink_size < -1 or sink_size == 0: if sink_size < -1 or sink_size == 0:
raise ValueError("The sink_size must be -1 or positive, but got sink_size {}.".format(sink_size)) raise ValueError("The sink_size must be -1 or positive, but got sink_size {}.".format(sink_size))

@ -22,7 +22,7 @@ import numpy as np
from mindspore import log as logger from mindspore import log as logger
from ..common.tensor import Tensor from ..common.tensor import Tensor
from ..nn.metrics import get_metrics from ..nn.metrics import get_metrics
from .._checkparam import check_input_data, check_output_data, check_int_positive, check_bool, check_int from .._checkparam import check_input_data, check_output_data, check_int_positive, Validator, check_int
from .callback import _InternalCallbackParam, RunContext, _CallbackManager from .callback import _InternalCallbackParam, RunContext, _CallbackManager
from .. import context from .. import context
from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \ from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
@ -548,7 +548,7 @@ class Model:
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None, loss_scale_manager=loss_scale_manager) >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None, loss_scale_manager=loss_scale_manager)
>>> model.train(2, dataset) >>> model.train(2, dataset)
""" """
check_bool(dataset_sink_mode) dataset_sink_mode = Validator.check_bool(dataset_sink_mode)
if sink_size == -1: if sink_size == -1:
sink_size = train_dataset.get_dataset_size() sink_size = train_dataset.get_dataset_size()
check_int(sink_size) check_int(sink_size)
@ -664,7 +664,7 @@ class Model:
>>> model = Model(net, loss_fn=loss, optimizer=None, metrics={'acc'}) >>> model = Model(net, loss_fn=loss, optimizer=None, metrics={'acc'})
>>> model.eval(dataset) >>> model.eval(dataset)
""" """
check_bool(dataset_sink_mode) dataset_sink_mode = Validator.check_bool(dataset_sink_mode)
_device_number_check(self._parallel_mode, self._device_number) _device_number_check(self._parallel_mode, self._device_number)
if not self._metric_fns: if not self._metric_fns:
raise ValueError("metric fn can not be None or empty.") raise ValueError("metric fn can not be None or empty.")

@ -22,8 +22,7 @@ import mindspore.context as context
from ... import log as logger from ... import log as logger
from ... import nn, ops from ... import nn, ops
from ..._checkparam import Validator from ..._checkparam import Validator, Rel
from ..._checkparam import Rel
from ...common import Tensor from ...common import Tensor
from ...common import dtype as mstype from ...common import dtype as mstype
from ...common.api import _executor from ...common.api import _executor
@ -92,16 +91,16 @@ class ConvertToQuantNetwork:
self.network = Validator.check_isinstance('network', kwargs["network"], (nn.Cell,)) self.network = Validator.check_isinstance('network', kwargs["network"], (nn.Cell,))
self.weight_qdelay = Validator.check_integer("quant delay", kwargs["quant_delay"][0], 0, Rel.GE) self.weight_qdelay = Validator.check_integer("quant delay", kwargs["quant_delay"][0], 0, Rel.GE)
self.act_qdelay = Validator.check_integer("quant delay", kwargs["quant_delay"][-1], 0, Rel.GE) self.act_qdelay = Validator.check_integer("quant delay", kwargs["quant_delay"][-1], 0, Rel.GE)
self.bn_fold = Validator.check_bool("bn fold", kwargs["bn_fold"]) self.bn_fold = Validator.check_bool(kwargs["bn_fold"], "bn fold")
self.freeze_bn = Validator.check_integer("freeze bn", kwargs["freeze_bn"], 0, Rel.GE) self.freeze_bn = Validator.check_integer("freeze bn", kwargs["freeze_bn"], 0, Rel.GE)
self.weight_bits = Validator.check_integer("weights bit", kwargs["num_bits"][0], 0, Rel.GE) self.weight_bits = Validator.check_integer("weights bit", kwargs["num_bits"][0], 0, Rel.GE)
self.act_bits = Validator.check_integer("activations bit", kwargs["num_bits"][-1], 0, Rel.GE) self.act_bits = Validator.check_integer("activations bit", kwargs["num_bits"][-1], 0, Rel.GE)
self.weight_channel = Validator.check_bool("per channel", kwargs["per_channel"][0]) self.weight_channel = Validator.check_bool(kwargs["per_channel"][0], "per channel")
self.act_channel = Validator.check_bool("per channel", kwargs["per_channel"][-1]) self.act_channel = Validator.check_bool(kwargs["per_channel"][-1], "per channel")
self.weight_symmetric = Validator.check_bool("symmetric", kwargs["symmetric"][0]) self.weight_symmetric = Validator.check_bool(kwargs["symmetric"][0], "symmetric")
self.act_symmetric = Validator.check_bool("symmetric", kwargs["symmetric"][-1]) self.act_symmetric = Validator.check_bool(kwargs["symmetric"][-1], "symmetric")
self.weight_range = Validator.check_bool("narrow range", kwargs["narrow_range"][0]) self.weight_range = Validator.check_bool(kwargs["narrow_range"][0], "narrow range")
self.act_range = Validator.check_bool("narrow range", kwargs["narrow_range"][-1]) self.act_range = Validator.check_bool(kwargs["narrow_range"][-1], "narrow range")
self._convert_method_map = {quant.Conv2dBnAct: self._convert_conv, self._convert_method_map = {quant.Conv2dBnAct: self._convert_conv,
quant.DenseBnAct: self._convert_dense} quant.DenseBnAct: self._convert_dense}

@ -15,7 +15,7 @@
"""Dataset help for minddata dataset""" """Dataset help for minddata dataset"""
import math import math
import os import os
from mindspore._checkparam import check_bool, check_int from mindspore._checkparam import Validator, check_int
from mindspore import context from mindspore import context
from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes
from mindspore.nn.wrap import GetNextSingleOp from mindspore.nn.wrap import GetNextSingleOp
@ -61,7 +61,7 @@ class DatasetHelper:
""" """
def __init__(self, dataset, dataset_sink_mode=True, sink_size=-1, epoch_num=1, iter_first_order=1): def __init__(self, dataset, dataset_sink_mode=True, sink_size=-1, epoch_num=1, iter_first_order=1):
check_bool(dataset_sink_mode) dataset_sink_mode = Validator.check_bool(dataset_sink_mode)
check_int(sink_size) check_int(sink_size)
if sink_size < -1 or sink_size == 0: if sink_size < -1 or sink_size == 0:
raise ValueError("The sink_size must be -1 or positive, but got sink_size {}.".format(sink_size)) raise ValueError("The sink_size must be -1 or positive, but got sink_size {}.".format(sink_size))

@ -18,8 +18,7 @@ from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter, ParameterTuple from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.common.tensor import Tensor from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype import mindspore.common.dtype as mstype
from mindspore._checkparam import check_bool from mindspore._checkparam import Validator
from mindspore._checkparam import Validator as validator
from mindspore.nn.optim.optimizer import Optimizer from mindspore.nn.optim.optimizer import Optimizer
from mindspore.parallel._utils import _get_device_num, _get_gradients_mean from mindspore.parallel._utils import _get_device_num, _get_gradients_mean
from src.grad_reducer_thor import DistributedGradReducerThor from src.grad_reducer_thor import DistributedGradReducerThor
@ -53,12 +52,12 @@ class THOR_GPU(Optimizer):
def __init__(self, params, learning_rate, momentum, matrix_A, matrix_G, A_inv_max, G_inv_max, def __init__(self, params, learning_rate, momentum, matrix_A, matrix_G, A_inv_max, G_inv_max,
weight_decay=0.0, loss_scale=1.0, use_nesterov=False, decay_filter=lambda x: x.name not in []): weight_decay=0.0, loss_scale=1.0, use_nesterov=False, decay_filter=lambda x: x.name not in []):
super(THOR_GPU, self).__init__(learning_rate, params, weight_decay, loss_scale) super(THOR_GPU, self).__init__(learning_rate, params, weight_decay, loss_scale)
validator.check_value_type("momentum", momentum, [float], self.cls_name) Validator.check_value_type("momentum", momentum, [float], self.cls_name)
if isinstance(momentum, float) and momentum < 0.0: if isinstance(momentum, float) and momentum < 0.0:
raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum)) raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum))
self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum") self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum")
self.params = self.parameters self.params = self.parameters
self.use_nesterov = check_bool(use_nesterov) self.use_nesterov = Validator.check_bool(use_nesterov)
self.moments = self.params.clone(prefix="moments", init='zeros') self.moments = self.params.clone(prefix="moments", init='zeros')
self.hyper_map = C.HyperMap() self.hyper_map = C.HyperMap()
self.opt = P.ApplyMomentum(use_nesterov=self.use_nesterov) self.opt = P.ApplyMomentum(use_nesterov=self.use_nesterov)

@ -16,7 +16,7 @@
import numpy as np import numpy as np
import mindspore.common.dtype as mstype import mindspore.common.dtype as mstype
from mindspore._checkparam import check_bool, twice, check_int_positive from mindspore._checkparam import Validator, twice, check_int_positive
from mindspore._extends import cell_attr_register from mindspore._extends import cell_attr_register
from mindspore.common.initializer import initializer from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter from mindspore.common.parameter import Parameter
@ -111,7 +111,7 @@ class _Conv(Cell):
self.weight = Parameter(initializer( self.weight = Parameter(initializer(
weight_init, [out_channels, in_channels // group, *kernel_size]), name='weight') weight_init, [out_channels, in_channels // group, *kernel_size]), name='weight')
if check_bool(has_bias): if Validator.check_bool(has_bias):
self.bias = Parameter(_initializer( self.bias = Parameter(_initializer(
bias_init, [out_channels]), name='bias') bias_init, [out_channels]), name='bias')
else: else:
@ -294,7 +294,7 @@ class Dense_Thor_GPU(Cell):
super(Dense_Thor_GPU, self).__init__() super(Dense_Thor_GPU, self).__init__()
self.in_channels = check_int_positive(in_channels) self.in_channels = check_int_positive(in_channels)
self.out_channels = check_int_positive(out_channels) self.out_channels = check_int_positive(out_channels)
self.has_bias = check_bool(has_bias) self.has_bias = Validator.check_bool(has_bias)
self.thor = True self.thor = True
if isinstance(weight_init, Tensor): if isinstance(weight_init, Tensor):
if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \ if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
@ -643,7 +643,7 @@ class Dense_Thor(Cell):
super(Dense_Thor, self).__init__() super(Dense_Thor, self).__init__()
self.in_channels = check_int_positive(in_channels) self.in_channels = check_int_positive(in_channels)
self.out_channels = check_int_positive(out_channels) self.out_channels = check_int_positive(out_channels)
self.has_bias = check_bool(has_bias) self.has_bias = Validator.check_bool(has_bias)
self.thor = True self.thor = True
self.batch_size = batch_size self.batch_size = batch_size
if isinstance(weight_init, Tensor): if isinstance(weight_init, Tensor):

@ -19,7 +19,7 @@ from mindspore.ops import functional as F
from mindspore._extends import cell_attr_register from mindspore._extends import cell_attr_register
from mindspore import Tensor, Parameter from mindspore import Tensor, Parameter
from mindspore.common.initializer import initializer from mindspore.common.initializer import initializer
from mindspore._checkparam import check_int_positive, check_bool from mindspore._checkparam import check_int_positive, Validator
from mindspore.nn.layer.activation import get_activation from mindspore.nn.layer.activation import get_activation
@ -74,7 +74,7 @@ class GNNFeatureTransform(nn.Cell):
super(GNNFeatureTransform, self).__init__() super(GNNFeatureTransform, self).__init__()
self.in_channels = check_int_positive(in_channels) self.in_channels = check_int_positive(in_channels)
self.out_channels = check_int_positive(out_channels) self.out_channels = check_int_positive(out_channels)
self.has_bias = check_bool(has_bias) self.has_bias = Validator.check_bool(has_bias)
if isinstance(weight_init, Tensor): if isinstance(weight_init, Tensor):
if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \ if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
@ -284,7 +284,7 @@ class AttentionHead(nn.Cell):
self.matmul = P.MatMul() self.matmul = P.MatMul()
self.bias_add = P.BiasAdd() self.bias_add = P.BiasAdd()
self.bias = Parameter(initializer('zeros', self.out_channel), name='bias') self.bias = Parameter(initializer('zeros', self.out_channel), name='bias')
self.residual = check_bool(residual) self.residual = Validator.check_bool(residual)
if self.residual: if self.residual:
if in_channel != out_channel: if in_channel != out_channel:
self.residual_transform_flag = True self.residual_transform_flag = True
@ -458,7 +458,7 @@ class GAT(nn.Cell):
self.attn_drop = attn_drop self.attn_drop = attn_drop
self.ftr_drop = ftr_drop self.ftr_drop = ftr_drop
self.activation = activation self.activation = activation
self.residual = check_bool(residual) self.residual = Validator.check_bool(residual)
self.layers = [] self.layers = []
# first layer # first layer
self.layers.append(AttentionAggregator( self.layers.append(AttentionAggregator(

@ -16,7 +16,7 @@
import os import os
from mindspore import context from mindspore import context
from mindspore._checkparam import check_bool, check_int from mindspore._checkparam import Validator, check_int
from mindspore.parallel._utils import _get_device_num, _need_to_full, _to_full_shapes from mindspore.parallel._utils import _get_device_num, _need_to_full, _to_full_shapes
from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes
@ -58,7 +58,7 @@ class DatasetHelper:
""" """
def __init__(self, dataset, dataset_sink_mode=True, sink_size=-1, epoch_num=1, iter_first_order=0): def __init__(self, dataset, dataset_sink_mode=True, sink_size=-1, epoch_num=1, iter_first_order=0):
check_bool(dataset_sink_mode) dataset_sink_mode = Validator.check_bool(dataset_sink_mode)
check_int(sink_size) check_int(sink_size)
if sink_size < -1 or sink_size == 0: if sink_size < -1 or sink_size == 0:
raise ValueError("The sink_size must be -1 or positive, but got sink_size {}.".format(sink_size)) raise ValueError("The sink_size must be -1 or positive, but got sink_size {}.".format(sink_size))

@ -22,7 +22,7 @@ from mindspore._c_expression import init_exec_dataset
from mindspore import context from mindspore import context
from mindspore import log as logger from mindspore import log as logger
from mindspore import nn from mindspore import nn
from mindspore._checkparam import check_input_data, check_output_data, check_int_positive, check_bool, check_int from mindspore._checkparam import check_input_data, check_output_data, check_int_positive, Validator, check_int
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from mindspore.common.dtype import pytype_to_dtype from mindspore.common.dtype import pytype_to_dtype
from mindspore.common.tensor import Tensor from mindspore.common.tensor import Tensor
@ -603,7 +603,7 @@ class Model:
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None, loss_scale_manager=loss_scale_manager) >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None, loss_scale_manager=loss_scale_manager)
>>> model.train(2, dataset) >>> model.train(2, dataset)
""" """
check_bool(dataset_sink_mode) dataset_sink_mode = Validator.check_bool(dataset_sink_mode)
check_int(sink_size) check_int(sink_size)
if sink_size < -1 or sink_size == 0: if sink_size < -1 or sink_size == 0:
raise ValueError("The sink_size must be -1 or positive, but got sink_size {}.".format(sink_size)) raise ValueError("The sink_size must be -1 or positive, but got sink_size {}.".format(sink_size))
@ -718,7 +718,7 @@ class Model:
>>> model = Model(net, loss_fn=loss, optimizer=None, metrics={'acc'}) >>> model = Model(net, loss_fn=loss, optimizer=None, metrics={'acc'})
>>> model.eval(dataset) >>> model.eval(dataset)
""" """
check_bool(dataset_sink_mode) dataset_sink_mode = Validator.check_bool(dataset_sink_mode)
_device_number_check(self._parallel_mode, self._device_number) _device_number_check(self._parallel_mode, self._device_number)
if not self._metric_fns: if not self._metric_fns:
raise ValueError("metric fn can not be None or empty.") raise ValueError("metric fn can not be None or empty.")

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save