|
|
|
@ -995,7 +995,7 @@ class Conv2D(PrimitiveWithInfer):
|
|
|
|
|
else:
|
|
|
|
|
validator.check_integer('pad size', len(pad), 4, Rel.EQ, self.name)
|
|
|
|
|
self.padding = pad
|
|
|
|
|
self.pad_mode = validator.check_string('pad_mode', pad_mode, ['valid', 'same', 'pad'], self.name)
|
|
|
|
|
self.pad_mode = validator.check_string(pad_mode, ['valid', 'same', 'pad'], 'pad_mode', self.name)
|
|
|
|
|
|
|
|
|
|
if pad_mode != 'pad' and pad != (0, 0, 0, 0):
|
|
|
|
|
raise ValueError(f"For '{self.name}', padding must be zero when pad_mode is '{pad_mode}'.")
|
|
|
|
@ -1134,7 +1134,7 @@ class DepthwiseConv2dNative(PrimitiveWithInfer):
|
|
|
|
|
else:
|
|
|
|
|
validator.check_integer('pad size', len(pad), 4, Rel.EQ, self.name)
|
|
|
|
|
self.padding = pad
|
|
|
|
|
self.pad_mode = validator.check_string('pad_mode', pad_mode, ['valid', 'same', 'pad'], self.name)
|
|
|
|
|
self.pad_mode = validator.check_string(pad_mode, ['valid', 'same', 'pad'], 'pad_mode', self.name)
|
|
|
|
|
if pad_mode != 'pad' and pad != (0, 0, 0, 0):
|
|
|
|
|
raise ValueError(f"For '{self.name}', padding must be zero when pad_mode is '{pad_mode}'.")
|
|
|
|
|
if self.pad_mode == 'pad':
|
|
|
|
@ -1216,7 +1216,7 @@ class _Pool(PrimitiveWithInfer):
|
|
|
|
|
self.init_prim_io_names(inputs=['x'], outputs=['output'])
|
|
|
|
|
validator.check_value_type('ksize', ksize, [int, tuple], self.name)
|
|
|
|
|
validator.check_value_type('strides', strides, [int, tuple], self.name)
|
|
|
|
|
self.padding = validator.check_string('padding', padding.upper(), ['VALID', 'SAME'], self.name)
|
|
|
|
|
self.padding = validator.check_string(padding.upper(), ['VALID', 'SAME'], 'padding', self.name)
|
|
|
|
|
self.add_prim_attr("padding", self.padding)
|
|
|
|
|
self.is_maxpoolwithargmax = (self.name == "MaxPoolWithArgmax")
|
|
|
|
|
if not self.is_maxpoolwithargmax:
|
|
|
|
@ -1521,7 +1521,7 @@ class Conv2DBackpropInput(PrimitiveWithInfer):
|
|
|
|
|
else:
|
|
|
|
|
validator.check_integer('pad size', len(pad), 4, Rel.EQ, self.name)
|
|
|
|
|
self.padding = pad
|
|
|
|
|
self.pad_mode = validator.check_string('pad_mode', pad_mode, ['valid', 'same', 'pad'], self.name)
|
|
|
|
|
self.pad_mode = validator.check_string(pad_mode, ['valid', 'same', 'pad'], 'pad_mode', self.name)
|
|
|
|
|
if pad_mode != 'pad' and pad != (0, 0, 0, 0):
|
|
|
|
|
raise ValueError(f"For '{self.name}', padding must be zero when pad_mode is '{pad_mode}'.")
|
|
|
|
|
if self.pad_mode == 'pad':
|
|
|
|
@ -1942,8 +1942,8 @@ class DataFormatDimMap(PrimitiveWithInfer):
|
|
|
|
|
@prim_attr_register
|
|
|
|
|
def __init__(self, src_format='NHWC', dst_format='NCHW'):
|
|
|
|
|
valid_values = ['NHWC', 'NCHW']
|
|
|
|
|
self.src_format = validator.check_string("src_format", src_format, valid_values, self.name)
|
|
|
|
|
self.dst_format = validator.check_string("dst_format", dst_format, valid_values, self.name)
|
|
|
|
|
self.src_format = validator.check_string(src_format, valid_values, "src_format", self.name)
|
|
|
|
|
self.dst_format = validator.check_string(dst_format, valid_values, "dst_format", self.name)
|
|
|
|
|
self.init_prim_io_names(inputs=['input_x'], outputs=['output'])
|
|
|
|
|
|
|
|
|
|
def infer_shape(self, x_shape):
|
|
|
|
@ -2961,7 +2961,7 @@ class MirrorPad(PrimitiveWithInfer):
|
|
|
|
|
@prim_attr_register
|
|
|
|
|
def __init__(self, mode='REFLECT'):
|
|
|
|
|
"""Initialize Pad"""
|
|
|
|
|
validator.check_string('mode', mode, ['REFLECT', 'SYMMETRIC'], self.name)
|
|
|
|
|
validator.check_string(mode, ['REFLECT', 'SYMMETRIC'], 'mode', self.name)
|
|
|
|
|
self.mode = mode
|
|
|
|
|
self.set_const_input_indexes([1])
|
|
|
|
|
|
|
|
|
@ -3651,7 +3651,7 @@ class KLDivLoss(PrimitiveWithInfer):
|
|
|
|
|
|
|
|
|
|
@prim_attr_register
|
|
|
|
|
def __init__(self, reduction='mean'):
|
|
|
|
|
self.reduction = validator.check_string('reduction', reduction, ['none', 'mean', 'sum'], self.name)
|
|
|
|
|
self.reduction = validator.check_string(reduction, ['none', 'mean', 'sum'], 'reduction', self.name)
|
|
|
|
|
|
|
|
|
|
def infer_shape(self, x_shape, y_shape):
|
|
|
|
|
validator.check('x_shape', x_shape, 'y_shape', y_shape, Rel.EQ, self.name)
|
|
|
|
@ -3727,7 +3727,7 @@ class BinaryCrossEntropy(PrimitiveWithInfer):
|
|
|
|
|
|
|
|
|
|
@prim_attr_register
|
|
|
|
|
def __init__(self, reduction='mean'):
|
|
|
|
|
self.reduction = validator.check_string('reduction', reduction, ['none', 'mean', 'sum'], self.name)
|
|
|
|
|
self.reduction = validator.check_string(reduction, ['none', 'mean', 'sum'], 'reduction', self.name)
|
|
|
|
|
|
|
|
|
|
def infer_shape(self, x_shape, y_shape, weight_shape):
|
|
|
|
|
validator.check('x_shape', x_shape, 'y_shape', y_shape, Rel.EQ, self.name)
|
|
|
|
@ -5487,7 +5487,7 @@ class BasicLSTMCell(PrimitiveWithInfer):
|
|
|
|
|
self.keep_prob = validator.check_number_range("keep_prob", keep_prob, 0.0, 1.0, Rel.INC_BOTH, self.name)
|
|
|
|
|
self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name)
|
|
|
|
|
self.state_is_tuple = validator.check_value_type("state_is_tuple", state_is_tuple, [bool], self.name)
|
|
|
|
|
self.activation = validator.check_string("activation", activation, ['tanh'], self.name)
|
|
|
|
|
self.activation = validator.check_string(activation, ['tanh'], "activation", self.name)
|
|
|
|
|
self.add_prim_attr("io_format", "ND")
|
|
|
|
|
|
|
|
|
|
def infer_shape(self, x_shape, h_shape, c_shape, w_shape, b_shape):
|
|
|
|
@ -5605,9 +5605,9 @@ class DynamicRNN(PrimitiveWithInfer):
|
|
|
|
|
self.use_peephole = validator.check_value_type("use_peephole", use_peephole, [bool], self.name)
|
|
|
|
|
self.time_major = validator.check_value_type("time_major", time_major, [bool], self.name)
|
|
|
|
|
self.is_training = validator.check_value_type("is_training", is_training, [bool], self.name)
|
|
|
|
|
self.cell_type = validator.check_string("cell_type", cell_type, ['LSTM'], self.name)
|
|
|
|
|
self.direction = validator.check_string("direction", direction, ['UNIDIRECTIONAL'], self.name)
|
|
|
|
|
self.activation = validator.check_string("activation", activation, ['tanh'], self.name)
|
|
|
|
|
self.cell_type = validator.check_string(cell_type, ['LSTM'], "cell_type", self.name)
|
|
|
|
|
self.direction = validator.check_string(direction, ['UNIDIRECTIONAL'], "direction", self.name)
|
|
|
|
|
self.activation = validator.check_string(activation, ['tanh'], "activation", self.name)
|
|
|
|
|
self.add_prim_attr("io_format", "ND")
|
|
|
|
|
|
|
|
|
|
def infer_shape(self, x_shape, w_shape, b_shape, seq_shape, h_shape, c_shape):
|
|
|
|
@ -5720,7 +5720,7 @@ class LRN(PrimitiveWithInfer):
|
|
|
|
|
validator.check_value_type("alpha", alpha, [float], self.name)
|
|
|
|
|
validator.check_value_type("beta", beta, [float], self.name)
|
|
|
|
|
validator.check_value_type("norm_region", norm_region, [str], self.name)
|
|
|
|
|
validator.check_string('norm_region', norm_region, ['ACROSS_CHANNELS'], self.name)
|
|
|
|
|
validator.check_string(norm_region, ['ACROSS_CHANNELS'], 'norm_region', self.name)
|
|
|
|
|
validator.check_integer("depth_radius", depth_radius, 0, Rel.GE, self.name)
|
|
|
|
|
|
|
|
|
|
def infer_dtype(self, x_dtype):
|
|
|
|
|