@@ -30,9 +30,8 @@ lars_opt = C.MultitypeFuncGraph("lars_opt")
 def _tensor_run_opt(lars, weight_decay, learning_rate, gradient, weight, decay_flag, lars_flag):
     """Apply lars optimizer to the weight parameter."""
     if lars_flag:
-        op_reduce = P.ReduceSum()
-        w_square_sum = op_reduce(F.square(weight))
-        grad_square_sum = op_reduce(F.square(gradient))
+        op_reduce_sum = P.SquareSumAll()
+        w_square_sum, grad_square_sum = op_reduce_sum(weight, gradient)
         if decay_flag:
             grad_t = lars(weight, gradient, w_square_sum, grad_square_sum, weight_decay, learning_rate)
         else:
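
A minimal sketch of the equivalence this hunk relies on: P.SquareSumAll produces the same two scalars (sum of squared weights and sum of squared gradients) that the removed F.square + P.ReduceSum pair computed, but as a single fused op. The tensors and values below are made up for illustration, and the snippet assumes a MindSpore build that runs primitives eagerly (PyNative mode) and, like the removed lines, accepts ReduceSum without an explicit axis.

# Illustrative only, not part of the patch.
import numpy as np
import mindspore as ms
from mindspore.ops import functional as F
from mindspore.ops import operations as P

weight = ms.Tensor(np.array([1.0, 2.0, 3.0], np.float32))
gradient = ms.Tensor(np.array([0.1, 0.2, 0.3], np.float32))

# Old pattern: two squares plus two reductions.
op_reduce = P.ReduceSum()
w_square_sum = op_reduce(F.square(weight))        # 1 + 4 + 9 = 14
grad_square_sum = op_reduce(F.square(gradient))   # 0.01 + 0.04 + 0.09 = 0.14

# New pattern: one fused op returning both sums of squares.
w_square_sum2, grad_square_sum2 = P.SquareSumAll()(weight, gradient)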