|
|
|
@ -14,8 +14,8 @@
|
|
|
|
|
|
|
|
|
|
from __future__ import print_function
|
|
|
|
|
|
|
|
|
|
import paddle
|
|
|
|
|
from paddle.fluid import program_guard, layers, default_main_program
|
|
|
|
|
from paddle.fluid.optimizer import Momentum, SGD
|
|
|
|
|
from .meta_optimizer_base import MetaOptimizerBase
|
|
|
|
|
from .common import OpRole, OP_ROLE_KEY, CollectiveHelper, is_update_op
|
|
|
|
|
|
|
|
|
@ -35,8 +35,10 @@ class LocalSGDOptimizer(MetaOptimizerBase):
|
|
|
|
|
if self.role_maker.worker_num() <= 1:
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
return isinstance(self.inner_opt, Momentum) \
|
|
|
|
|
or isinstance(self.inner_opt, SGD)
|
|
|
|
|
return isinstance(self.inner_opt, paddle.optimizer.momentum.Momentum) \
|
|
|
|
|
or isinstance(self.inner_opt, paddle.fluid.optimizer.Momentum) \
|
|
|
|
|
or isinstance(self.inner_opt, paddle.optimizer.sgd.SGD) \
|
|
|
|
|
or isinstance(self.inner_opt, paddle.fluid.optimizer.SGD)
|
|
|
|
|
|
|
|
|
|
def _disable_strategy(self, dist_strategy):
|
|
|
|
|
dist_strategy.localsgd = False
|
|
|
|
|