|
|
|
@ -72,9 +72,9 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, grad
|
|
|
|
|
update_with_lr = op_mul(lr, update)
|
|
|
|
|
next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32))
|
|
|
|
|
|
|
|
|
|
next_v = F.depend(next_v, F.assign(param, next_param))
|
|
|
|
|
next_v = F.depend(next_v, F.assign(m, next_m))
|
|
|
|
|
next_v = F.depend(next_v, F.assign(v, next_v))
|
|
|
|
|
next_v = F.depend(next_v, F.assign(param, op_cast(next_param, mstype.float16)))
|
|
|
|
|
next_v = F.depend(next_v, F.assign(m, op_cast(next_m, mstype.float16)))
|
|
|
|
|
next_v = F.depend(next_v, F.assign(v, op_cast(next_v, mstype.float16)))
|
|
|
|
|
return next_v
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|