diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index 1a020a655f..0b52089e8b 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -1267,12 +1267,20 @@ class BatchNorm(PrimitiveWithInfer): Default: "NCHW". Inputs: + If `is_training` is False, inputs are Tensors. - **input_x** (Tensor) - Tensor of shape :math:`(N, C)`, with float16 or float32 data type. - **scale** (Tensor) - Tensor of shape :math:`(C,)`, with float16 or float32 data type. - **bias** (Tensor) - Tensor of shape :math:`(C,)`, has the same data type with `scale`. - **mean** (Tensor) - Tensor of shape :math:`(C,)`, with float16 or float32 data type. - **variance** (Tensor) - Tensor of shape :math:`(C,)`, has the same data type with `mean`. + If `is_training` is True, `scale`, `bias`, `mean` and `variance` are Parameters. + - **input_x** (Tensor) - Tensor of shape :math:`(N, C)`, with float16 or float32 data type. + - **scale** (Parameter) - Parameter of shape :math:`(C,)`, with float16 or float32 data type. + - **bias** (Parameter) - Parameter of shape :math:`(C,)`, has the same data type with `scale`. + - **mean** (Parameter) - Parameter of shape :math:`(C,)`, with float16 or float32 data type. + - **variance** (Parameter) - Parameter of shape :math:`(C,)`, has the same data type with `mean`. + Outputs: Tuple of 5 Tensor, the normalized inputs and the updated parameters.