@@ -623,6 +623,73 @@ class FusedBatchNorm(Primitive):
        self._update_parameter = True


class FusedBatchNormEx(PrimitiveWithInfer):
    r"""
    FusedBatchNormEx is an extension of FusedBatchNorm.
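
    For reference, the normalization applied is the standard batch normalization transform (stated here in its
    generic form; during training the statistics :math:`\mathrm{E}[x]` and :math:`\mathrm{Var}[x]` are computed
    from the current batch, and the moving statistics are updated with the `momentum` rule described below):

    .. math::
        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * scale + bias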

    Args:
        mode (int): Mode of batch normalization, value is 0 or 1. Default: 0.
        epsilon (float): A small value added for numerical stability. Default: 1e-5.
        momentum (float): The hyper parameter used to compute the moving averages of running_mean and running_var
            (e.g. :math:`new\_running\_mean = momentum * running\_mean + (1 - momentum) * current\_mean`).
            Momentum value must be in the range [0, 1]. Default: 0.1.
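            For example, with momentum = 0.1, running_mean = 1.0 and current_mean = 3.0, the update gives
            :math:`new\_running\_mean = 0.1 * 1.0 + 0.9 * 3.0 = 2.8`.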

    Inputs:
        - **input_x** (Tensor) - Tensor of shape :math:`(N, C)`, with float16 or float32 data type.
        - **scale** (Tensor) - Tensor of shape :math:`(C,)`, with float32 data type.
        - **bias** (Tensor) - Tensor of shape :math:`(C,)`, with float32 data type.
        - **mean** (Tensor) - Tensor of shape :math:`(C,)`, with float32 data type.
        - **variance** (Tensor) - Tensor of shape :math:`(C,)`, with float32 data type.

    Outputs:
        Tuple of 6 Tensors, the normalized input and the updated parameters.

        - **output_x** (Tensor) - The same type and shape as `input_x`.
        - **updated_scale** (Tensor) - Tensor of shape :math:`(C,)`.
        - **updated_bias** (Tensor) - Tensor of shape :math:`(C,)`.
        - **updated_moving_mean** (Tensor) - Tensor of shape :math:`(C,)`.
        - **updated_moving_variance** (Tensor) - Tensor of shape :math:`(C,)`.
        - **reserve** (Tensor) - Tensor of shape :math:`(C,)`.

    Examples:
        >>> input_x = Tensor(np.ones([128, 64, 32, 64]), mindspore.float32)
        >>> scale = Tensor(np.ones([64]), mindspore.float32)
        >>> bias = Tensor(np.ones([64]), mindspore.float32)
        >>> mean = Tensor(np.ones([64]), mindspore.float32)
        >>> variance = Tensor(np.ones([64]), mindspore.float32)
        >>> op = P.FusedBatchNormEx()
        >>> output = op(input_x, scale, bias, mean, variance)
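        >>> # `output` is the 6-element tuple described in Outputs; for example, the normalized tensor
        >>> # is output[0] (a usage sketch based on the documented output order).
        >>> output_x = output[0]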
    """

    @prim_attr_register
    def __init__(self, mode=0, epsilon=1e-5, momentum=0.1):
        self.init_prim_io_names(inputs=['x', 'scale', 'b', 'mean', 'variance'],
                                outputs=['y', 'save_scale', 'save_bias', 'save_mean', 'save_inv_variance', 'reserve'])
        self.mode = validator.check_integer('mode', mode, [0, 1], Rel.IN, self.name)
        self.epsilon = validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT, self.name)
        self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name)
        # Signals that this primitive updates its Parameter inputs (the moving statistics).
        self._update_parameter = True
        self.add_prim_attr('data_format', "NCHW")

    def infer_shape(self, input_x, scale, bias, mean, variance):
        # scale, bias, mean and variance must all be 1-D tensors whose length matches the channel
        # dimension of input_x (axis 1 under the NCHW layout).
        validator.check_integer("scale rank", len(scale), 1, Rel.EQ, self.name)
        validator.check("scale shape", scale, "bias shape", bias, Rel.EQ, self.name)
        validator.check("scale shape[0]", scale[0], "input_x shape[1]", input_x[1], Rel.EQ, self.name)
        validator.check_integer("mean rank", len(mean), 1, Rel.EQ, self.name)
        validator.check("mean shape", mean, "variance shape", variance, Rel.EQ, self.name)
        validator.check("mean shape", mean, "scale shape", scale, Rel.EQ, self.name)
        return (input_x, scale, scale, scale, scale, scale)

    def infer_dtype(self, input_x, scale, bias, mean, variance):
        # input_x may be float16 or float32; scale/bias and the moving statistics must be float32.
        validator.check_tensor_type_same({"input_x": input_x}, [mstype.float16, mstype.float32], self.name)
        args = {"scale": scale, "bias": bias}
        validator.check_tensor_type_same(args, [mstype.float32], self.name)
        args_moving = {"mean": mean, "variance": variance}
        valid_types = [mstype.tensor_type(mstype.float32)]
        validator.check_type_same(args_moving, valid_types, self.name)
        return (input_x, scale, scale, scale, scale, scale)


class BNTrainingReduce(PrimitiveWithInfer):
    """
    Performs a reduce-sum over axes [0, 2, 3].