@@ -149,7 +149,7 @@ class Softmax(PrimitiveWithInfer):
         validator.check_integer("length of axis", len(self.axis), 1, Rel.GE, self.name)
         rank = len(logits)
         for axis_v in self.axis:
-            validator.check_int_range("axis", axis_v, -rank, rank, Rel.INC_LEFT, self.name)
+            validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name)
         return logits
 
     def infer_dtype(self, logits):
@@ -193,7 +193,7 @@ class LogSoftmax(PrimitiveWithInfer):
 
     def infer_shape(self, logits):
         rank = len(logits)
-        validator.check_int_range('axis', self.axis, -rank, rank, Rel.INC_LEFT, self.name)
+        validator.check_int_range(self.axis, -rank, rank, Rel.INC_LEFT, 'axis', self.name)
         return logits
 
     def infer_dtype(self, logits):
@@ -637,8 +637,8 @@ class FusedBatchNorm(Primitive):
         self.init_prim_io_names(inputs=['x', 'scale', 'b', 'mean', 'variance'],
                                 outputs=['y', 'running_mean', 'running_variance', 'save_mean', 'save_inv_variance'])
         self.mode = validator.check_integer('mode', mode, [0, 1], Rel.IN, self.name)
-        self.epsilon = validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT, self.name)
-        self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name)
+        self.epsilon = validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name)
+        self.momentum = validator.check_float_range(momentum, 0, 1, Rel.INC_BOTH, 'momentum', self.name)
         self._update_parameter = True
@@ -710,8 +710,8 @@ class FusedBatchNormEx(PrimitiveWithInfer):
         self.init_prim_io_names(inputs=['x', 'scale', 'b', 'mean', 'variance'],
                                 outputs=['y', 'save_scale', 'save_bias', 'save_mean', 'save_inv_variance', 'reserve'])
         self.mode = validator.check_integer('mode', mode, [0, 1], Rel.IN, self.name)
-        self.epsilon = validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT, self.name)
-        self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name)
+        self.epsilon = validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name)
+        self.momentum = validator.check_float_range(momentum, 0, 1, Rel.INC_BOTH, 'momentum', self.name)
         self._update_parameter = True
         self.add_prim_attr('data_format', "NCHW")
@@ -818,8 +818,8 @@ class BNTrainingUpdate(PrimitiveWithInfer):
         validator.check_value_type("isRef", isRef, [bool], self.name)
         validator.check_value_type("epsilon", epsilon, [float], self.name)
         validator.check_value_type("factor", factor, [float], self.name)
-        self.epsilon = validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT, 'BNTrainingUpdate')
-        self.factor = validator.check_number_range('factor', factor, 0, 1, Rel.INC_BOTH, 'BNTrainingUpdate')
+        self.epsilon = validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', 'BNTrainingUpdate')
+        self.factor = validator.check_float_range(factor, 0, 1, Rel.INC_BOTH, 'factor', 'BNTrainingUpdate')
 
     def infer_shape(self, x, sum, square_sum, scale, b, mean, variance):
         validator.check_integer("x rank", len(x), 4, Rel.EQ, self.name)
@@ -898,7 +898,7 @@ class BatchNorm(PrimitiveWithInfer):
     @prim_attr_register
     def __init__(self, is_training=False, epsilon=1e-5):
         validator.check_value_type('is_training', is_training, (bool,), self.name)
-        validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT, self.name)
+        validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name)
         self.add_prim_attr('data_format', "NCHW")
         self.init_prim_io_names(inputs=['x', 'scale', 'offset', 'mean', 'variance'],
                                 outputs=['y', 'batch_mean', 'batch_variance', 'reserve_space_1', 'reserve_space_2'])
@@ -2383,7 +2383,7 @@ class L2Normalize(PrimitiveWithInfer):
 
     def infer_shape(self, input_x):
         dim = len(input_x)
-        validator.check_int_range('axis value', self.axis, -dim, dim, Rel.INC_LEFT, self.name)
+        validator.check_int_range(self.axis, -dim, dim, Rel.INC_LEFT, 'axis value', self.name)
         return input_x
 
     def infer_dtype(self, input_x):
@@ -2481,10 +2481,10 @@ class DropoutDoMask(PrimitiveWithInfer):
         keep_prob_v = keep_prob['value']
         if keep_prob_v is not None:
             if isinstance(keep_prob['dtype'], type(mstype.tensor)):
-                validator.check_number_range('keep_prob', keep_prob_v.asnumpy(), 0, 1, Rel.INC_BOTH, self.name)
+                validator.check_float_range(keep_prob_v.asnumpy(), 0, 1, Rel.INC_BOTH, 'keep_prob', self.name)
             else:
                 validator.check_value_type("keep_prob", keep_prob_v, [float], self.name)
-                validator.check_number_range('keep_prob', keep_prob_v, 0, 1, Rel.INC_BOTH, self.name)
+                validator.check_float_range(keep_prob_v, 0, 1, Rel.INC_BOTH, 'keep_prob', self.name)
 
         out = {'shape': input_x_shape,
                'dtype': input_x['dtype'],
@@ -2584,7 +2584,7 @@ class OneHot(PrimitiveWithInfer):
 
         # check shape
         indices_shp = indices['shape']
-        validator.check_int_range("axis", self.axis, -1, len(indices_shp), Rel.INC_BOTH, self.name)
+        validator.check_int_range(self.axis, -1, len(indices_shp), Rel.INC_BOTH, "axis", self.name)
         depth_val = depth['value']
         validator.check_non_negative_int(depth_val, "depth", self.name)
         # create new dimension at end if self.axis is -1
@@ -2771,7 +2771,7 @@ class LSTM(PrimitiveWithInfer):
         self.has_bias = validator.check_value_type("has_bias", has_bias, (bool,), self.name)
         self.bidirectional = validator.check_value_type("bidirectional", bidirectional, (bool,), self.name)
         self.dropout = validator.check_value_type("dropout", dropout, [float], self.name)
-        self.dropout = validator.check_number_range('dropout', dropout, 0, 1, Rel.INC_BOTH, self.name)
+        self.dropout = validator.check_float_range(dropout, 0, 1, Rel.INC_BOTH, 'dropout', self.name)
 
         if bidirectional:
             self.num_directions = 2
@@ -3054,7 +3054,7 @@ class ROIAlign(PrimitiveWithInfer):
         validator.check_value_type("spatial_scale", spatial_scale, [float], self.name)
         validator.check_value_type("sample_num", sample_num, [int], self.name)
         validator.check_value_type("roi_end_mode", roi_end_mode, [int], self.name)
-        validator.check_int_range("roi_end_mode", roi_end_mode, 0, 1, Rel.INC_BOTH, self.name)
+        validator.check_int_range(roi_end_mode, 0, 1, Rel.INC_BOTH, "roi_end_mode", self.name)
         self.pooled_height = pooled_height
         self.pooled_width = pooled_width
         self.spatial_scale = spatial_scale
@@ -3502,9 +3502,9 @@ class FusedSparseFtrl(PrimitiveWithInfer):
         validator.check_value_type("l1", l1, [float], self.name)
         validator.check_value_type("l2", l2, [float], self.name)
         validator.check_value_type("lr_power", lr_power, [float], self.name)
-        self.lr = validator.check_number_range("lr", lr, 0.0, float("inf"), Rel.INC_NEITHER, self.name)
-        self.l1 = validator.check_number_range("l1", l1, 0.0, float("inf"), Rel.INC_LEFT, self.name)
-        self.l2 = validator.check_number_range("l2", l2, 0.0, float("inf"), Rel.INC_LEFT, self.name)
+        self.lr = validator.check_positive_float(lr, "lr", self.name)
+        self.l1 = validator.check_non_negative_float(l1, "l1", self.name)
+        self.l2 = validator.check_non_negative_float(l2, "l2", self.name)
         self.lr_power = validator.check_number("lr_power", lr_power, 0, Rel.LE, self.name)
         self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
@@ -4240,7 +4240,7 @@ class SparseApplyAdagrad(PrimitiveWithInfer):
     @prim_attr_register
     def __init__(self, lr, update_slots=True, use_locking=False):
         validator.check_value_type("lr", lr, [float], self.name)
-        validator.check_number_range("lr", lr, float("-inf"), float("inf"), Rel.INC_NEITHER, self.name)
+        validator.check_is_float(lr, "lr", self.name)
         validator.check_value_type("update_slots", update_slots, [bool], self.name)
         validator.check_value_type("use_locking", use_locking, [bool], self.name)
@@ -5142,9 +5142,9 @@ class SparseApplyFtrl(PrimitiveWithCheck):
         validator.check_value_type("l1", l1, [float], self.name)
         validator.check_value_type("l2", l2, [float], self.name)
         validator.check_value_type("lr_power", lr_power, [float], self.name)
-        self.lr = validator.check_number_range("lr", lr, 0.0, float("inf"), Rel.INC_NEITHER, self.name)
-        self.l1 = validator.check_number_range("l1", l1, 0.0, float("inf"), Rel.INC_LEFT, self.name)
-        self.l2 = validator.check_number_range("l2", l2, 0.0, float("inf"), Rel.INC_LEFT, self.name)
+        self.lr = validator.check_positive_float(lr, "lr", self.name)
+        self.l1 = validator.check_non_negative_float(l1, "l1", self.name)
+        self.l2 = validator.check_non_negative_float(l2, "l2", self.name)
         self.lr_power = validator.check_number("lr_power", lr_power, 0, Rel.LE, self.name)
         self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
         self.init_prim_io_names(inputs=['var', 'accum', 'linear', 'grad', 'indices'],
@@ -5239,9 +5239,9 @@ class SparseApplyFtrlV2(PrimitiveWithInfer):
         validator.check_value_type("l1", l1, [float], self.name)
         validator.check_value_type("l2", l2, [float], self.name)
         validator.check_value_type("lr_power", lr_power, [float], self.name)
-        self.lr = validator.check_number_range("lr", lr, 0.0, float("inf"), Rel.INC_NEITHER, self.name)
-        self.l1 = validator.check_number_range("l1", l1, 0.0, float("inf"), Rel.INC_LEFT, self.name)
-        self.l2 = validator.check_number_range("l2", l2, 0.0, float("inf"), Rel.INC_LEFT, self.name)
+        self.lr = validator.check_positive_float(lr, "lr", self.name)
+        self.l1 = validator.check_non_negative_float(l1, "l1", self.name)
+        self.l2 = validator.check_non_negative_float(l2, "l2", self.name)
         self.lr_power = validator.check_number("lr_power", lr_power, 0, Rel.LE, self.name)
         self.l2_shrinkage = validator.check_value_type("l2_shrinkage", l2_shrinkage, [float], self.name)
         self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
@@ -5285,7 +5285,7 @@ class Dropout(PrimitiveWithInfer):
 
     @prim_attr_register
     def __init__(self, keep_prob=0.5):
-        self.keep_prob = validator.check_number_range("keep_prob", keep_prob, 0, 1, Rel.INC_RIGHT, self.name)
+        self.keep_prob = validator.check_float_range(keep_prob, 0, 1, Rel.INC_RIGHT, "keep_prob", self.name)
 
     def infer_shape(self, x_shape):
         validator.check_integer("x_shape", len(x_shape), 1, Rel.GE, self.name)
@@ -5510,7 +5510,7 @@ class BasicLSTMCell(PrimitiveWithInfer):
     @prim_attr_register
     def __init__(self, keep_prob=1.0, forget_bias=1.0, state_is_tuple=True, activation='tanh'):
         self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name)
-        self.keep_prob = validator.check_number_range("keep_prob", keep_prob, 0.0, 1.0, Rel.INC_BOTH, self.name)
+        self.keep_prob = validator.check_float_range(keep_prob, 0.0, 1.0, Rel.INC_BOTH, "keep_prob", self.name)
         self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name)
         self.state_is_tuple = validator.check_value_type("state_is_tuple", state_is_tuple, [bool], self.name)
         self.activation = validator.check_string(activation, ['tanh'], "activation", self.name)
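
Every hunk above applies the same mechanical migration of the parameter validator: name-first helpers such as check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT, self.name) become value-first calls such as check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name), and the open-ended float("inf") ranges move to dedicated helpers (check_positive_float, check_non_negative_float, check_is_float). The sketch below is a minimal, self-contained illustration of that value-first calling convention only; the Rel enum values and error wording are assumptions for the sketch, not MindSpore's real _checkparam implementation.

# Minimal stand-in sketch (hypothetical, not MindSpore's real Validator):
# it only mirrors the value-first argument order used by the new-style calls above.
from enum import Enum


class Rel(Enum):
    INC_NEITHER = "()"
    INC_LEFT = "[)"
    INC_RIGHT = "(]"
    INC_BOTH = "[]"


def check_float_range(arg_value, lower, upper, rel, arg_name, prim_name):
    """Return arg_value if it lies in the interval described by `rel`, else raise."""
    left_ok = arg_value > lower if rel in (Rel.INC_NEITHER, Rel.INC_RIGHT) else arg_value >= lower
    right_ok = arg_value < upper if rel in (Rel.INC_NEITHER, Rel.INC_LEFT) else arg_value <= upper
    if not (left_ok and right_ok):
        raise ValueError(f"For '{prim_name}', '{arg_name}' must be in "
                         f"{rel.value[0]}{lower}, {upper}{rel.value[1]}, but got {arg_value}.")
    return arg_value


def check_positive_float(arg_value, arg_name, prim_name):
    """Return arg_value if it is a float strictly greater than zero, else raise."""
    if not (isinstance(arg_value, float) and arg_value > 0):
        raise ValueError(f"For '{prim_name}', '{arg_name}' must be a positive float, but got {arg_value}.")
    return arg_value


# Usage mirroring the new-style calls seen in the Dropout and FusedSparseFtrl hunks:
keep_prob = check_float_range(0.5, 0, 1, Rel.INC_RIGHT, "keep_prob", "Dropout")
lr = check_positive_float(0.001, "lr", "FusedSparseFtrl")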