|
|
|
@ -119,6 +119,8 @@ class Fleet(object):
|
|
|
|
|
self.strategy_compiler = None
|
|
|
|
|
self._is_collective = False
|
|
|
|
|
self._runtime_handle = None
|
|
|
|
|
self._util = None
|
|
|
|
|
self._context = {}
|
|
|
|
|
|
|
|
|
|
def init(self, role_maker=None, is_collective=False):
|
|
|
|
|
"""
|
|
|
|
@ -233,7 +235,7 @@ class Fleet(object):
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
int: worker numbers
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
|
@ -569,8 +571,9 @@ class Fleet(object):
|
|
|
|
|
|
|
|
|
|
if strategy == None:
|
|
|
|
|
strategy = DistributedStrategy()
|
|
|
|
|
self.user_defined_strategy = strategy
|
|
|
|
|
self.valid_strategy = None
|
|
|
|
|
|
|
|
|
|
self._user_defined_strategy = copy.deepcopy(strategy)
|
|
|
|
|
self._context = {}
|
|
|
|
|
return self
|
|
|
|
|
|
|
|
|
|
@dygraph_only
|
|
|
|
@ -909,6 +912,15 @@ class Fleet(object):
|
|
|
|
|
# imitate target optimizer retrieval
|
|
|
|
|
return self.user_defined_optimizer.clear_grad()
|
|
|
|
|
|
|
|
|
|
def _final_strategy(self):
|
|
|
|
|
if "valid_strategy" not in self._context:
|
|
|
|
|
print(
|
|
|
|
|
"WARNING: You may need to call minimize function before this function is called"
|
|
|
|
|
)
|
|
|
|
|
return {}
|
|
|
|
|
else:
|
|
|
|
|
return self._context["valid_strategy"]
|
|
|
|
|
|
|
|
|
|
def minimize(self,
|
|
|
|
|
loss,
|
|
|
|
|
startup_program=None,
|
|
|
|
@ -958,12 +970,15 @@ class Fleet(object):
|
|
|
|
|
# for more examples, please reference https://github.com/PaddlePaddle/FleetX
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
context = {}
|
|
|
|
|
context["user_defined_strategy"] = copy.deepcopy(
|
|
|
|
|
self._user_defined_strategy)
|
|
|
|
|
if paddle.fluid.framework.in_dygraph_mode():
|
|
|
|
|
# imitate target optimizer retrieval
|
|
|
|
|
target_opt = self.user_defined_optimizer
|
|
|
|
|
self._context = context
|
|
|
|
|
return target_opt.minimize(loss)
|
|
|
|
|
|
|
|
|
|
context = {}
|
|
|
|
|
# cache original feed forward program
|
|
|
|
|
self.origin_main_program = loss.block.program
|
|
|
|
|
context["origin_main_program"] = self.origin_main_program
|
|
|
|
@ -984,17 +999,19 @@ class Fleet(object):
|
|
|
|
|
MetaOptimizerFactory()._get_valid_meta_optimizers(
|
|
|
|
|
self.user_defined_optimizer)
|
|
|
|
|
|
|
|
|
|
context["user_defined_strategy"] = copy.copy(self.user_defined_strategy)
|
|
|
|
|
context["user_defined_strategy"] = copy.deepcopy(
|
|
|
|
|
self._user_defined_strategy)
|
|
|
|
|
copy_user_defined_strategy = copy.deepcopy(self._user_defined_strategy)
|
|
|
|
|
|
|
|
|
|
# trigger the auto-parallel in very strict condition
|
|
|
|
|
# strategy = DistributedStrategy()
|
|
|
|
|
# strategy.auto = True
|
|
|
|
|
# optimizer = paddle.optimizer.SGD(learning_rate=0.1)
|
|
|
|
|
# optimizer = fleet.distributed_optimizer(optimizer, strategy)
|
|
|
|
|
if self.user_defined_strategy._is_strict_auto():
|
|
|
|
|
if copy_user_defined_strategy._is_strict_auto():
|
|
|
|
|
# turn on all the strategy for each optimizer
|
|
|
|
|
for opt in distributed_optimizer_list:
|
|
|
|
|
opt._enable_strategy(self.user_defined_strategy, context)
|
|
|
|
|
opt._enable_strategy(copy_user_defined_strategy, context)
|
|
|
|
|
|
|
|
|
|
valid_optimizer_list = []
|
|
|
|
|
valid_graph_optimizer_list = []
|
|
|
|
@ -1003,7 +1020,7 @@ class Fleet(object):
|
|
|
|
|
for opt in distributed_optimizer_list:
|
|
|
|
|
opt._set_basic_info(loss, self._role_maker,
|
|
|
|
|
self.user_defined_optimizer,
|
|
|
|
|
self.user_defined_strategy)
|
|
|
|
|
copy_user_defined_strategy)
|
|
|
|
|
if opt._can_apply() and not opt._is_graph_out():
|
|
|
|
|
valid_optimizer_list.append(opt)
|
|
|
|
|
elif opt._can_apply() and opt._is_graph_out():
|
|
|
|
@ -1014,13 +1031,15 @@ class Fleet(object):
|
|
|
|
|
meta_optimizer, graph_optimizer = \
|
|
|
|
|
self.strategy_compiler.generate_optimizer(
|
|
|
|
|
loss, self._role_maker, self.user_defined_optimizer,
|
|
|
|
|
self.user_defined_strategy, valid_optimizer_list,
|
|
|
|
|
copy_user_defined_strategy, valid_optimizer_list,
|
|
|
|
|
valid_graph_optimizer_list)
|
|
|
|
|
|
|
|
|
|
valid_strategy = self.strategy_compiler._get_valid_strategy(
|
|
|
|
|
self.user_defined_strategy, can_not_apply_optimizer_list)
|
|
|
|
|
copy_user_defined_strategy, can_not_apply_optimizer_list)
|
|
|
|
|
|
|
|
|
|
context["valid_strategy"] = copy.deepcopy(valid_strategy)
|
|
|
|
|
|
|
|
|
|
context["valid_strategy"] = valid_strategy
|
|
|
|
|
self._context = context
|
|
|
|
|
|
|
|
|
|
self.valid_strategy = valid_strategy
|
|
|
|
|
self.valid_strategy._enable_env()
|
|
|
|
|