|
|
|
@ -279,8 +279,11 @@ class Fleet(object):
|
|
|
|
|
# for more examples, please reference https://github.com/PaddlePaddle/Fleet
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
context = {}
|
|
|
|
|
# cache original feed forward program
|
|
|
|
|
self.origin_main_program = loss.block.program
|
|
|
|
|
context["origin_main_program"] = self.origin_main_program
|
|
|
|
|
context["loss"] = loss
|
|
|
|
|
if startup_program == None:
|
|
|
|
|
self.origin_startup_program = \
|
|
|
|
|
paddle.default_startup_program().clone(for_test=False)
|
|
|
|
@ -288,6 +291,8 @@ class Fleet(object):
|
|
|
|
|
else:
|
|
|
|
|
self.origin_startup_program = \
|
|
|
|
|
startup_program.clone(for_test=False)
|
|
|
|
|
context["origin_startup_program"] = startup_program
|
|
|
|
|
context["role_maker"] = self._role_maker
|
|
|
|
|
|
|
|
|
|
# compile time
|
|
|
|
|
distributed_optimizer_list = \
|
|
|
|
@ -317,6 +322,9 @@ class Fleet(object):
|
|
|
|
|
|
|
|
|
|
valid_strategy = self.strategy_compiler._get_valid_strategy(
|
|
|
|
|
self.user_defined_strategy, can_not_apply_optimizer_list)
|
|
|
|
|
|
|
|
|
|
context["valid_strategy"] = valid_strategy
|
|
|
|
|
|
|
|
|
|
self.valid_strategy = valid_strategy
|
|
|
|
|
|
|
|
|
|
optimize_ops = []
|
|
|
|
@ -334,6 +342,8 @@ class Fleet(object):
|
|
|
|
|
parameter_list=parameter_list,
|
|
|
|
|
no_grad_set=no_grad_set)
|
|
|
|
|
|
|
|
|
|
context["program_optimize_ops"] = optimize_ops
|
|
|
|
|
context["program_params_grads"] = params_grads
|
|
|
|
|
if graph_optimizer:
|
|
|
|
|
optimize_ops, params_grads = graph_optimizer.minimize(
|
|
|
|
|
loss,
|
|
|
|
@ -344,12 +354,13 @@ class Fleet(object):
|
|
|
|
|
# if a graph optimizer takes effect, mostly
|
|
|
|
|
# optimizers_ops and params_grads are None
|
|
|
|
|
# i.e. users can not modify current computation graph anymore
|
|
|
|
|
context["graph_optimize_ops"] = optimize_ops
|
|
|
|
|
context["graph_optimize_grads"] = params_grads
|
|
|
|
|
|
|
|
|
|
if self._runtime_handle is None:
|
|
|
|
|
self._runtime_handle = RuntimeFactory()._create_runtime(
|
|
|
|
|
valid_strategy, self._role_maker, optimize_ops, params_grads)
|
|
|
|
|
self._runtime_handle = RuntimeFactory()._create_runtime(context)
|
|
|
|
|
|
|
|
|
|
if self._util is None:
|
|
|
|
|
self._util = UtilFactory()._create_util(
|
|
|
|
|
valid_strategy, self._role_maker, optimize_ops, params_grads)
|
|
|
|
|
self._util = UtilFactory()._create_util(context)
|
|
|
|
|
|
|
|
|
|
return optimize_ops, params_grads
|
|
|
|
|