fleet support paddle.optimizer (#28026)

fleet support paddle.optimizer

* bug fix

* fix fleet_base

* bug fix

* fix coverage
MRXLT 4 years ago committed by GitHub
parent 5bb348a1c2
commit 55098b975e

@@ -1084,17 +1084,11 @@ class Fleet(object):
                     loss_name=loss.name, share_vars_from=None)
                 loss.block.program._graph = compiled_program
                 return self.user_defined_optimizer.minimize(
-                    loss,
-                    startup_program=startup_program,
-                    parameter_list=parameter_list,
-                    no_grad_set=no_grad_set)
+                    loss, startup_program, parameter_list, no_grad_set=no_grad_set)

         if meta_optimizer:
             optimize_ops, params_grads = meta_optimizer.minimize(
-                loss,
-                startup_program=startup_program,
-                parameter_list=parameter_list,
-                no_grad_set=no_grad_set)
+                loss, startup_program, parameter_list, no_grad_set=no_grad_set)

             default_program = paddle.static.default_main_program()
@@ -1103,20 +1097,14 @@ class Fleet(object):
         else:
             optimize_ops, params_grads = self.user_defined_optimizer.minimize(
-                loss,
-                startup_program=startup_program,
-                parameter_list=parameter_list,
-                no_grad_set=no_grad_set)
+                loss, startup_program, parameter_list, no_grad_set=no_grad_set)

         context["program_optimize_ops"] = optimize_ops
         context["program_params_grads"] = params_grads

         if graph_optimizer:
             optimize_ops, params_grads = graph_optimizer.minimize(
-                loss,
-                startup_program=startup_program,
-                parameter_list=parameter_list,
-                no_grad_set=no_grad_set)
+                loss, startup_program, parameter_list, no_grad_set=no_grad_set)
             # since we do not encourage users to use graph operations
             # if a graph optimizer takes effect, mostly
             # optimizers_ops and params_grads are None

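Note on why the calls above switch from keyword to positional arguments: the fluid optimizers name the third parameter `parameter_list`, while the 2.0-style `paddle.optimizer` classes are understood to name it `parameters` (an assumption inferred from this diff), so passing `parameter_list=` as a keyword would raise a TypeError for the new optimizers, whereas positional arguments satisfy both signatures. A minimal stand-alone sketch with stand-in classes (FluidOptimizer/OptimizerV2 are placeholders, not the real Paddle types):

    # Sketch only: paraphrased signatures to illustrate the keyword-name mismatch.
    class FluidOptimizer:
        # paddle.fluid.optimizer.Optimizer-style signature
        def minimize(self, loss, startup_program=None, parameter_list=None, no_grad_set=None):
            return [], []

    class OptimizerV2:
        # paddle.optimizer.Optimizer-style signature: "parameters", not "parameter_list"
        def minimize(self, loss, startup_program=None, parameters=None, no_grad_set=None):
            return [], []

    loss, params = object(), []
    for opt in (FluidOptimizer(), OptimizerV2()):
        # The positional form used in this patch works for both classes ...
        opt.minimize(loss, None, params, no_grad_set=None)
    # ... whereas opt.minimize(loss, parameter_list=params) would raise
    # TypeError on OptimizerV2, since that keyword does not exist there.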
@@ -19,6 +19,7 @@ import abc
 import paddle.fluid as fluid
 from paddle.fluid.executor import Executor
 from paddle.fluid.optimizer import SGD
+from paddle.optimizer import SGD as SGD_v2

 from paddle.fluid.incubate.fleet.base.mode import Mode
 from paddle.distributed.fleet.base.role_maker import RoleMakerBase
@@ -291,7 +292,8 @@ class DistributedOptimizer(object):

     def __init__(self, optimizer, strategy=None):
         if not isinstance(optimizer, SGD.__bases__) \
-                and not isinstance(optimizer, OptimizerWithMixedPrecision):
+                and not isinstance(optimizer, OptimizerWithMixedPrecision) \
+                and not isinstance(optimizer, SGD_v2.__base__):
             raise TypeError("optimizer must be an instance of Optimizer")

         self._optimizer = optimizer

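The new clause leans on the fact that `isinstance` accepts both a tuple of classes (`SGD.__bases__`) and a single class (`SGD_v2.__base__`), i.e. the common `Optimizer` base of each API family. A minimal sketch of that pattern with placeholder classes (not the real Paddle types):

    # FluidSGD / SGDV2 stand in for paddle.fluid.optimizer.SGD / paddle.optimizer.SGD.
    class FluidOptimizerBase: ...
    class FluidSGD(FluidOptimizerBase): ...

    class OptimizerV2Base: ...
    class SGDV2(OptimizerV2Base): ...

    def check(optimizer):
        # FluidSGD.__bases__ is the tuple (FluidOptimizerBase,), which isinstance accepts;
        # SGDV2.__base__ is the single class OptimizerV2Base.
        if not isinstance(optimizer, FluidSGD.__bases__) \
                and not isinstance(optimizer, SGDV2.__base__):
            raise TypeError("optimizer must be an instance of Optimizer")

    check(FluidSGD())   # passes: instance of the fluid Optimizer base
    check(SGDV2())      # passes only because of the new SGD_v2.__base__ clause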
@@ -28,6 +28,8 @@ from paddle.fluid.incubate.fleet.base.fleet_base import DistributedOptimizer
 from paddle.fluid import compiler
 from paddle.fluid.incubate.checkpoint.checkpoint_saver import PaddleModel, CheckpointSaver

+import paddle
+
 import os
 import sys
 import six
@@ -505,10 +507,7 @@ class CollectiveOptimizer(DistributedOptimizer):
             self._strategy)

         optimize_ops, param_grads = self._optimizer.minimize(
-            loss,
-            startup_program=startup_program,
-            parameter_list=parameter_list,
-            no_grad_set=no_grad_set)
+            loss, startup_program, parameter_list, no_grad_set=no_grad_set)

         fleet._origin_program = main_program.clone(for_test=False)
         fleet._transpiled_program = main_program

@@ -60,7 +60,7 @@ class TestFleetGraphExecutionMetaOptimizer(unittest.TestCase):
             strategy = paddle.distributed.fleet.DistributedStrategy()
             strategy.nccl_comm_num = 2
             strategy.sync_nccl_allreduce = True
-            optimizer = paddle.fluid.optimizer.SGD(learning_rate=0.01)
+            optimizer = paddle.optimizer.SGD(learning_rate=0.01)
             optimizer = fleet.distributed_optimizer(
                 optimizer, strategy=strategy)
             optimizer.minimize(avg_cost)

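For reference, a hedged end-to-end sketch of the usage the updated test exercises: a static-graph collective fleet program driven by a `paddle.optimizer.SGD`. The toy model, data shapes, and launch setup are assumptions; only the strategy and optimizer lines mirror the test above.

    import paddle
    import paddle.distributed.fleet as fleet

    paddle.enable_static()
    fleet.init(is_collective=True)

    # Toy regression model (assumption, not from the test file).
    x = paddle.static.data(name="x", shape=[-1, 13], dtype="float32")
    y = paddle.static.data(name="y", shape=[-1, 1], dtype="float32")
    pred = paddle.static.nn.fc(x, size=1)
    avg_cost = paddle.mean(paddle.nn.functional.square_error_cost(pred, y))

    strategy = fleet.DistributedStrategy()
    strategy.nccl_comm_num = 2
    strategy.sync_nccl_allreduce = True

    # After this patch a 2.0-style optimizer is accepted here as well.
    optimizer = paddle.optimizer.SGD(learning_rate=0.01)
    optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
    optimizer.minimize(avg_cost)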