【paddle.fleet】Set default value to strategy in distributed_optimizer (#26246)

* set default value to strategy in distributed_optimizer test=develop
revert-24895-update_cub
Qinghe JING 5 years ago committed by GitHub
parent 672578a797
commit d549a9b1fe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -15,6 +15,7 @@
from __future__ import print_function
import paddle
from .strategy_compiler import StrategyCompiler
from .distributed_strategy import DistributedStrategy
from .meta_optimizer_factory import MetaOptimizerFactory
from .runtime_factory import RuntimeFactory
from .util_factory import UtilFactory
@ -209,7 +210,7 @@ class Fleet(object):
assert self._runtime_handle is not None
self._runtime_handle._stop_worker()
def distributed_optimizer(self, optimizer, strategy):
def distributed_optimizer(self, optimizer, strategy=None):
"""
distirbuted_optimizer
Returns:
@ -225,6 +226,8 @@ class Fleet(object):
optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
"""
self.user_defined_optimizer = optimizer
if strategy == None:
strategy = DistributedStrategy()
self.user_defined_strategy = strategy
self.valid_strategy = None
return self

@ -145,9 +145,9 @@ class TestFleetBase(unittest.TestCase):
import paddle.fluid.incubate.fleet.base.role_maker as role_maker
role = role_maker.PaddleCloudRoleMaker(is_collective=True)
fleet.init(role)
strategy = fleet.DistributedStrategy()
optimizer = paddle.optimizer.SGD(learning_rate=0.001)
optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
optimizer = fleet.distributed_optimizer(optimizer)
def test_minimize(self):
import paddle

Loading…
Cancel
Save