optimize momentum to speedup dygraph, a little, test=develop (#30099)

revert-31562-mean
wanghuancoder 4 years ago committed by GitHub
parent 254ad61959
commit 88e6dc4ac5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -198,10 +198,6 @@ class Momentum(Optimizer):
velocity_acc = self._get_accumulator(self._velocity_acc_str, velocity_acc = self._get_accumulator(self._velocity_acc_str,
param_and_grad[0]) param_and_grad[0])
find_master = self._multi_precision and param_and_grad[
0].dtype == core.VarDesc.VarType.FP16
master_weight = (self._master_weights[param_and_grad[0].name]
if find_master else None)
lr = self._create_param_lr(param_and_grad) lr = self._create_param_lr(param_and_grad)
if framework.in_dygraph_mode(): if framework.in_dygraph_mode():
@ -213,6 +209,11 @@ class Momentum(Optimizer):
self._regularization_coeff) self._regularization_coeff)
return None return None
find_master = self._multi_precision and param_and_grad[
0].dtype == core.VarDesc.VarType.FP16
master_weight = (self._master_weights[param_and_grad[0].name]
if find_master else None)
attrs = { attrs = {
"mu": self._momentum, "mu": self._momentum,
"use_nesterov": self._use_nesterov, "use_nesterov": self._use_nesterov,

Loading…
Cancel
Save