|
|
@ -15,6 +15,7 @@
|
|
|
|
from __future__ import print_function
|
|
|
|
from __future__ import print_function
|
|
|
|
import paddle
|
|
|
|
import paddle
|
|
|
|
from .strategy_compiler import StrategyCompiler
|
|
|
|
from .strategy_compiler import StrategyCompiler
|
|
|
|
|
|
|
|
from .distributed_strategy import DistributedStrategy
|
|
|
|
from .meta_optimizer_factory import MetaOptimizerFactory
|
|
|
|
from .meta_optimizer_factory import MetaOptimizerFactory
|
|
|
|
from .runtime_factory import RuntimeFactory
|
|
|
|
from .runtime_factory import RuntimeFactory
|
|
|
|
from .util_factory import UtilFactory
|
|
|
|
from .util_factory import UtilFactory
|
|
|
@ -209,7 +210,7 @@ class Fleet(object):
|
|
|
|
assert self._runtime_handle is not None
|
|
|
|
assert self._runtime_handle is not None
|
|
|
|
self._runtime_handle._stop_worker()
|
|
|
|
self._runtime_handle._stop_worker()
|
|
|
|
|
|
|
|
|
|
|
|
def distributed_optimizer(self, optimizer, strategy):
|
|
|
|
def distributed_optimizer(self, optimizer, strategy=None):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
distirbuted_optimizer
|
|
|
|
distirbuted_optimizer
|
|
|
|
Returns:
|
|
|
|
Returns:
|
|
|
@ -225,6 +226,8 @@ class Fleet(object):
|
|
|
|
optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
|
|
|
|
optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
self.user_defined_optimizer = optimizer
|
|
|
|
self.user_defined_optimizer = optimizer
|
|
|
|
|
|
|
|
if strategy == None:
|
|
|
|
|
|
|
|
strategy = DistributedStrategy()
|
|
|
|
self.user_defined_strategy = strategy
|
|
|
|
self.user_defined_strategy = strategy
|
|
|
|
self.valid_strategy = None
|
|
|
|
self.valid_strategy = None
|
|
|
|
return self
|
|
|
|
return self
|
|
|
|