|
|
|
@ -5690,18 +5690,19 @@ class LARSUpdate(PrimitiveWithInfer):
|
|
|
|
|
... self.lars = ops.LARSUpdate()
|
|
|
|
|
... self.reduce = ops.ReduceSum()
|
|
|
|
|
... def construct(self, weight, gradient):
|
|
|
|
|
... w_square_sum = self.reduce(ops.Square(weight))
|
|
|
|
|
... grad_square_sum = self.reduce(ops.Square(gradient))
|
|
|
|
|
... w_square_sum = self.reduce(ops.Square()(weight))
|
|
|
|
|
... grad_square_sum = self.reduce(ops.Square()(gradient))
|
|
|
|
|
... grad_t = self.lars(weight, gradient, w_square_sum, grad_square_sum, 0.0, 1.0)
|
|
|
|
|
... return grad_t
|
|
|
|
|
...
|
|
|
|
|
>>> np.random.seed(0)
|
|
|
|
|
>>> weight = np.random.random(size=(2, 3)).astype(np.float32)
|
|
|
|
|
>>> gradient = np.random.random(size=(2, 3)).astype(np.float32)
|
|
|
|
|
>>> net = Net()
|
|
|
|
|
>>> output = net(Tensor(weight), Tensor(gradient))
|
|
|
|
|
>>> print(output)
|
|
|
|
|
[[1.0630977e-03 1.0647357e-03 1.0038106e-03]
|
|
|
|
|
[2.9038603e-04 5.9235965e-05 6.8709702e-04]]
|
|
|
|
|
[[0.00036534 0.00074454 0.00080456]
|
|
|
|
|
[0.00032014 0.00066101 0.00044157]]
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
@prim_attr_register
|
|
|
|
|