!10399 fix bugs of op BatchNorm2d, Lamb, LARSUpdate and Depend

From: @lihongkang1
Reviewed-by: @liangchenghui,@wuxuejian
Signed-off-by: @liangchenghui
pull/10399/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 791e994ea6

@@ -383,10 +383,8 @@ class BatchNorm2d(_BatchNorm):
>>> print(output)
[[[[171.99915 46.999763 ]
[116.99941 191.99904 ]]
[[ 66.999664 250.99875 ]
[194.99902 102.99948 ]]
[[ 8.999955 210.99895 ]
[ 20.999895 241.9988 ]]]]
"""

@@ -250,7 +250,7 @@ class Lamb(Optimizer):
>>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
>>> group_params = [{'params': conv_params, 'weight_decay': 0.01},
... {'params': no_conv_params, 'lr': poly_decay_lr},
... {'order_params': net.trainable_params(0.01, 0.0001, 10, 0.5)}]
... {'order_params': net.trainable_params(0.01)}]
>>> optim = nn.Lamb(group_params, learning_rate=0.1, weight_decay=0.0)
>>> # The conv_params's parameters will use default learning rate of 0.1 and weight decay of 0.01.
>>> # The no_conv_params's parameters will use dynamic learning rate of poly decay learning rate and default

@@ -5690,18 +5690,19 @@ class LARSUpdate(PrimitiveWithInfer):
... self.lars = ops.LARSUpdate()
... self.reduce = ops.ReduceSum()
... def construct(self, weight, gradient):
... w_square_sum = self.reduce(ops.Square(weight))
... grad_square_sum = self.reduce(ops.Square(gradient))
... w_square_sum = self.reduce(ops.Square()(weight))
... grad_square_sum = self.reduce(ops.Square()(gradient))
... grad_t = self.lars(weight, gradient, w_square_sum, grad_square_sum, 0.0, 1.0)
... return grad_t
...
>>> np.random.seed(0)
>>> weight = np.random.random(size=(2, 3)).astype(np.float32)
>>> gradient = np.random.random(size=(2, 3)).astype(np.float32)
>>> net = Net()
>>> output = net(Tensor(weight), Tensor(gradient))
>>> print(output)
[[1.0630977e-03 1.0647357e-03 1.0038106e-03]
[2.9038603e-04 5.9235965e-05 6.8709702e-04]]
[[0.00036534 0.00074454 0.00080456]
[0.00032014 0.00066101 0.00044157]]
"""
@prim_attr_register

@@ -426,6 +426,9 @@ class Depend(Primitive):
Outputs:
Tensor, the value passed by last operator.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
"""
@prim_attr_register

Loading…
Cancel
Save