|
|
@ -90,7 +90,8 @@ class Flatten(PrimitiveWithInfer):
|
|
|
|
>>> input_tensor = Tensor(np.ones(shape=[1, 2, 3, 4]), mindspore.float32)
|
|
|
|
>>> input_tensor = Tensor(np.ones(shape=[1, 2, 3, 4]), mindspore.float32)
|
|
|
|
>>> flatten = P.Flatten()
|
|
|
|
>>> flatten = P.Flatten()
|
|
|
|
>>> output = flatten(input_tensor)
|
|
|
|
>>> output = flatten(input_tensor)
|
|
|
|
>>> assert output.shape == (1, 24)
|
|
|
|
>>> print(output.shape)
|
|
|
|
|
|
|
|
(1, 24)
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
@prim_attr_register
|
|
|
|
@prim_attr_register
|
|
|
@ -700,7 +701,7 @@ class FusedBatchNormEx(PrimitiveWithInfer):
|
|
|
|
Outputs:
|
|
|
|
Outputs:
|
|
|
|
Tuple of 6 Tensors, the normalized input, the updated parameters and reserve.
|
|
|
|
Tuple of 6 Tensors, the normalized input, the updated parameters and reserve.
|
|
|
|
|
|
|
|
|
|
|
|
- **output_x** (Tensor) - The input of FusedBatchNormEx, same type and shape as the `input_x`.
|
|
|
|
- **output_x** (Tensor) - The output of FusedBatchNormEx, same type and shape as the `input_x`.
|
|
|
|
- **updated_scale** (Tensor) - Updated parameter scale, Tensor of shape :math:`(C,)`, data type: float32.
|
|
|
|
- **updated_scale** (Tensor) - Updated parameter scale, Tensor of shape :math:`(C,)`, data type: float32.
|
|
|
|
- **updated_bias** (Tensor) - Updated parameter bias, Tensor of shape :math:`(C,)`, data type: float32.
|
|
|
|
- **updated_bias** (Tensor) - Updated parameter bias, Tensor of shape :math:`(C,)`, data type: float32.
|
|
|
|
- **updated_moving_mean** (Tensor) - Updated mean value, Tensor of shape :math:`(C,)`, data type: float32.
|
|
|
|
- **updated_moving_mean** (Tensor) - Updated mean value, Tensor of shape :math:`(C,)`, data type: float32.
|
|
|
@ -3206,7 +3207,7 @@ class Adam(PrimitiveWithInfer):
|
|
|
|
If false, the result is unpredictable. Default: False.
|
|
|
|
If false, the result is unpredictable. Default: False.
|
|
|
|
use_nesterov (bool): Whether to use Nesterov Accelerated Gradient (NAG) algorithm to update the gradients.
|
|
|
|
use_nesterov (bool): Whether to use Nesterov Accelerated Gradient (NAG) algorithm to update the gradients.
|
|
|
|
If true, update the gradients using NAG.
|
|
|
|
If true, update the gradients using NAG.
|
|
|
|
If true, update the gradients without using NAG. Default: False.
|
|
|
|
If false, update the gradients without using NAG. Default: False.
|
|
|
|
|
|
|
|
|
|
|
|
Inputs:
|
|
|
|
Inputs:
|
|
|
|
- **var** (Tensor) - Weights to be updated.
|
|
|
|
- **var** (Tensor) - Weights to be updated.
|
|
|
@ -3306,7 +3307,7 @@ class FusedSparseAdam(PrimitiveWithInfer):
|
|
|
|
If false, the result is unpredictable. Default: False.
|
|
|
|
If false, the result is unpredictable. Default: False.
|
|
|
|
use_nesterov (bool): Whether to use Nesterov Accelerated Gradient (NAG) algorithm to update the gradients.
|
|
|
|
use_nesterov (bool): Whether to use Nesterov Accelerated Gradient (NAG) algorithm to update the gradients.
|
|
|
|
If true, update the gradients using NAG.
|
|
|
|
If true, update the gradients using NAG.
|
|
|
|
If true, update the gradients without using NAG. Default: False.
|
|
|
|
If false, update the gradients without using NAG. Default: False.
|
|
|
|
|
|
|
|
|
|
|
|
Inputs:
|
|
|
|
Inputs:
|
|
|
|
- **var** (Parameter) - Parameters to be updated with float32 data type.
|
|
|
|
- **var** (Parameter) - Parameters to be updated with float32 data type.
|
|
|
@ -3439,7 +3440,7 @@ class FusedSparseLazyAdam(PrimitiveWithInfer):
|
|
|
|
If false, the result is unpredictable. Default: False.
|
|
|
|
If false, the result is unpredictable. Default: False.
|
|
|
|
use_nesterov (bool): Whether to use Nesterov Accelerated Gradient (NAG) algorithm to update the gradients.
|
|
|
|
use_nesterov (bool): Whether to use Nesterov Accelerated Gradient (NAG) algorithm to update the gradients.
|
|
|
|
If true, update the gradients using NAG.
|
|
|
|
If true, update the gradients using NAG.
|
|
|
|
If true, update the gradients without using NAG. Default: False.
|
|
|
|
If false, update the gradients without using NAG. Default: False.
|
|
|
|
|
|
|
|
|
|
|
|
Inputs:
|
|
|
|
Inputs:
|
|
|
|
- **var** (Parameter) - Parameters to be updated with float32 data type.
|
|
|
|
- **var** (Parameter) - Parameters to be updated with float32 data type.
|
|
|
|