|
|
|
@ -2478,6 +2478,27 @@ class LARSUpdate(PrimitiveWithInfer):
|
|
|
|
|
|
|
|
|
|
Outputs:
|
|
|
|
|
Tensor, representing the new gradient.
|
|
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
|
>>> from mindspore import Tensor
|
|
|
|
|
>>> from mindspore.ops import operations as P
|
|
|
|
|
>>> from mindspore.ops import functional as F
|
|
|
|
|
>>> import mindspore.nn as nn
|
|
|
|
|
>>> import numpy as np
|
|
|
|
|
>>> class Net(nn.Cell):
|
|
|
|
|
>>> def __init__(self):
|
|
|
|
|
>>> super(Net, self).__init__()
|
|
|
|
|
>>> self.lars = P.LARSUpdate()
|
|
|
|
|
>>> self.reduce = P.ReduceSum()
|
|
|
|
|
>>> def construct(self, weight, gradient):
|
|
|
|
|
>>> w_square_sum = self.reduce(F.square(weight))
|
|
|
|
|
>>> grad_square_sum = self.reduce(F.square(gradient))
|
|
|
|
|
>>> grad_t = self.lars(weight, gradient, w_square_sum, grad_square_sum, 0.0, 1.0)
|
|
|
|
|
>>> return grad_t
|
|
|
|
|
>>> weight = np.random.random(size=(2, 3)).astype(np.float32)
|
|
|
|
|
>>> gradient = np.random.random(size=(2, 3)).astype(np.float32)
|
|
|
|
|
>>> net = Net()
|
|
|
|
|
>>> ms_output = net(Tensor(weight), Tensor(gradient))
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
@prim_attr_register
|
|
|
|
|