|
|
|
@ -198,10 +198,6 @@ class Momentum(Optimizer):
|
|
|
|
|
|
|
|
|
|
velocity_acc = self._get_accumulator(self._velocity_acc_str,
|
|
|
|
|
param_and_grad[0])
|
|
|
|
|
find_master = self._multi_precision and param_and_grad[
|
|
|
|
|
0].dtype == core.VarDesc.VarType.FP16
|
|
|
|
|
master_weight = (self._master_weights[param_and_grad[0].name]
|
|
|
|
|
if find_master else None)
|
|
|
|
|
lr = self._create_param_lr(param_and_grad)
|
|
|
|
|
|
|
|
|
|
if framework.in_dygraph_mode():
|
|
|
|
@ -213,6 +209,11 @@ class Momentum(Optimizer):
|
|
|
|
|
self._regularization_coeff)
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
find_master = self._multi_precision and param_and_grad[
|
|
|
|
|
0].dtype == core.VarDesc.VarType.FP16
|
|
|
|
|
master_weight = (self._master_weights[param_and_grad[0].name]
|
|
|
|
|
if find_master else None)
|
|
|
|
|
|
|
|
|
|
attrs = {
|
|
|
|
|
"mu": self._momentum,
|
|
|
|
|
"use_nesterov": self._use_nesterov,
|
|
|
|
|