|
|
|
@ -714,15 +714,20 @@ class FusedBatchNormEx(PrimitiveWithInfer):
|
|
|
|
|
|
|
|
|
|
class BNTrainingReduce(PrimitiveWithInfer):
|
|
|
|
|
"""
|
|
|
|
|
reduce sum at axis [0, 2, 3].
|
|
|
|
|
For BatchNorm operator, this operator updates the moving averages for training and is used in conjunction with
BNTrainingUpdate.
|
|
|
|
|
|
|
|
|
|
Inputs:
    - **x** (Tensor) - A 4-D Tensor with float16 or float32 data type. Tensor of shape :math:`(N, C, A, B)`.
|
|
|
|
|
|
|
|
|
|
Outputs:
    - **sum** (Tensor) - A 1-D Tensor with float32 data type. Tensor of shape :math:`(C,)`.
    - **square_sum** (Tensor) - A 1-D Tensor with float32 data type. Tensor of shape :math:`(C,)`.
|
|
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
|
>>> input_x = Tensor(np.ones([128, 64, 32, 64]), mindspore.float32)
|
|
|
|
|
>>> bn_training_reduce = P.BNTrainingReduce()
|
|
|
|
|
>>> output = bn_training_reduce(input_x)
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
@prim_attr_register
|
|
|
|
@ -734,24 +739,90 @@ class BNTrainingReduce(PrimitiveWithInfer):
|
|
|
|
|
return ([x_shape[1]], [x_shape[1]])
|
|
|
|
|
|
|
|
|
|
def infer_dtype(self, x_type):
    """Check that `x` is float16/float32; both outputs (sum, square_sum) carry the input dtype."""
    allowed_types = [mstype.float16, mstype.float32]
    validator.check_tensor_type_same({"x_type": x_type}, allowed_types, self.name)
    return x_type, x_type
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BNTrainingUpdate(PrimitiveWithInfer):
    """
    The primitive operator of the register and info descriptor in bn_training_update.

    For BatchNorm operator, this operator updates the moving averages for training and is used in conjunction with
    BNTrainingReduce.

    Args:
        isRef (bool): If a ref. Default: True.
        epsilon (float): A small value added to variance avoid dividing by zero. Default: 1e-5.
        factor (float): A weight for updating the mean and variance. Default: 0.1.

    Inputs:
        - **x** (Tensor) - A 4-D Tensor with float16 or float32 data type. Tensor of shape :math:`(N, C, A, B)`.
        - **sum** (Tensor) - A 1-D Tensor with float16 or float32 data type for the output of operator
          BNTrainingReduce. Tensor of shape :math:`(C,)`.
        - **square_sum** (Tensor) - A 1-D Tensor with float16 or float32 data type for the output of operator
          BNTrainingReduce. Tensor of shape :math:`(C,)`.
        - **scale** (Tensor) - A 1-D Tensor with float16 or float32, for the scaling factor.
          Tensor of shape :math:`(C,)`.
        - **offset** (Tensor) - A 1-D Tensor with float16 or float32, for the scaling offset.
          Tensor of shape :math:`(C,)`.
        - **mean** (Tensor) - A 1-D Tensor with float16 or float32, for the scaling mean. Tensor of shape :math:`(C,)`.
        - **variance** (Tensor) - A 1-D Tensor with float16 or float32, for the update variance.
          Tensor of shape :math:`(C,)`.

    Outputs:
        - **y** (Tensor) - Tensor, has the same shape and data type as `x`.
        - **mean** (Tensor) - Tensor for the updated mean, with float32 data type.
          Has the same shape as `variance`.
        - **variance** (Tensor) - Tensor for the updated variance, with float32 data type.
          Has the same shape as `variance`.
        - **batch_mean** (Tensor) - Tensor for the mean of `x`, with float32 data type.
          Has the same shape as `variance`.
        - **batch_variance** (Tensor) - Tensor for the mean of `variance`, with float32 data type.
          Has the same shape as `variance`.

    Examples:
        >>> input_x = Tensor(np.ones([128, 64, 32, 64]), mindspore.float32)
        >>> sum = Tensor(np.ones([64]), mindspore.float32)
        >>> square_sum = Tensor(np.ones([64]), mindspore.float32)
        >>> scale = Tensor(np.ones([64]), mindspore.float32)
        >>> offset = Tensor(np.ones([64]), mindspore.float32)
        >>> mean = Tensor(np.ones([64]), mindspore.float32)
        >>> variance = Tensor(np.ones([64]), mindspore.float32)
        >>> bn_training_update = P.BNTrainingUpdate()
        >>> output = bn_training_update(input_x, sum, square_sum, scale, offset, mean, variance)
    """

    @prim_attr_register
    def __init__(self, isRef=True, epsilon=1e-5, factor=0.1):
        # Register the operator's input/output names for the framework, then
        # validate the attribute types and ranges.
        self.init_prim_io_names(inputs=['x', 'sum', 'square_sum', 'scale', 'b', 'mean', 'variance'],
                                outputs=['y', 'running_mean', 'running_variance', 'save_mean', 'save_inv_variance'])
        validator.check_value_type("isRef", isRef, [bool], self.name)
        validator.check_value_type("epsilon", epsilon, [float], self.name)
        validator.check_value_type("factor", factor, [float], self.name)
        # epsilon must lie in (0, 1]; factor (the moving-average weight) in [0, 1].
        self.epsilon = validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT, 'BNTrainingUpdate')
        self.factor = validator.check_number_range('factor', factor, 0, 1, Rel.INC_BOTH, 'BNTrainingUpdate')

    def infer_shape(self, x, sum, square_sum, scale, b, mean, variance):
        """Validate input ranks and channel dimensions; return the five output shapes.

        `x` must be 4-D; every other input must be 1-D with length equal to
        the channel dimension x[1].
        """
        validator.check_integer("x rank", len(x), 4, Rel.EQ, self.name)
        validator.check_integer("sum rank", len(sum), 1, Rel.EQ, self.name)
        validator.check_integer("square_sum rank", len(square_sum), 1, Rel.EQ, self.name)
        validator.check_integer("scale rank", len(scale), 1, Rel.EQ, self.name)
        validator.check_integer("b rank", len(b), 1, Rel.EQ, self.name)
        validator.check_integer("mean rank", len(mean), 1, Rel.EQ, self.name)
        validator.check_integer("variance rank", len(variance), 1, Rel.EQ, self.name)
        # NOTE(review): `sum` here is a 1-D shape (e.g. [C]) while `x[1]` is the
        # scalar C, so EQ compares a list against an int; the intent is presumably
        # sum[0] == x[1] (or sum == [x[1]]). Confirm validator.check's semantics
        # before relying on these checks. Same concern applies to the scale,
        # offset, mean and variance checks below.
        validator.check("sum shape", sum, "x_shape[1]", x[1], Rel.EQ, self.name)
        validator.check("square_sum shape", square_sum, "sum", sum, Rel.EQ, self.name)
        validator.check("scale shape", scale, "x_shape[1]", x[1], Rel.EQ, self.name)
        validator.check("offset shape", b, "x_shape[1]", x[1], Rel.EQ, self.name)
        validator.check("mean shape", mean, "x_shape[1]", x[1], Rel.EQ, self.name)
        validator.check("variance shape", variance, "x_shape[1]", x[1], Rel.EQ, self.name)
        # Outputs: y keeps x's shape; the four statistics keep variance's shape.
        return (x, variance, variance, variance, variance)

    def infer_dtype(self, x, sum, square_sum, scale, b, mean, variance):
        """Validate that every input is float16/float32; outputs mirror x's and variance's dtypes."""
        validator.check_tensor_type_same({"x_type": x}, [mstype.float16, mstype.float32], self.name)
        validator.check_tensor_type_same({"sum_type": sum}, [mstype.float16, mstype.float32], self.name)
        validator.check_tensor_type_same({"square_sum_type": square_sum}, [mstype.float16, mstype.float32], self.name)
        validator.check_tensor_type_same({"scale_type": scale}, [mstype.float16, mstype.float32], self.name)
        validator.check_tensor_type_same({"b_type": b}, [mstype.float16, mstype.float32], self.name)
        validator.check_tensor_type_same({"mean_type": mean}, [mstype.float16, mstype.float32], self.name)
        validator.check_tensor_type_same({"variance_type": variance}, [mstype.float16, mstype.float32], self.name)
        return (x, variance, variance, variance, variance)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|