|
|
|
@ -38,6 +38,7 @@ from paddle.fluid.incubate.fleet.base.role_maker import MPISymetricRoleMaker
|
|
|
|
|
from paddle.fluid.incubate.fleet.parameter_server import version
|
|
|
|
|
from paddle.fluid.incubate.fleet.parameter_server.ir.public import get_sparse_tablenames
|
|
|
|
|
from paddle.fluid.incubate.fleet.parameter_server.ir.public import _get_lr_ops
|
|
|
|
|
from paddle.fluid.incubate.fleet.parameter_server.ir.public import _has_global_step
|
|
|
|
|
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler.distributed_strategy import TrainerRuntimeConfig, DistributedStrategy, \
|
|
|
|
|
SyncStrategy, AsyncStrategy, HalfAsyncStrategy, GeoStrategy, StrategyFactory
|
|
|
|
|
|
|
|
|
@ -161,9 +162,9 @@ class FleetTranspiler(Fleet):
|
|
|
|
|
|
|
|
|
|
print(trainer_config)
|
|
|
|
|
|
|
|
|
|
lrs = _get_lr_ops(self._origin_main_program)
|
|
|
|
|
lrs = _has_global_step(_get_lr_ops(self._origin_main_program))
|
|
|
|
|
|
|
|
|
|
if len(lrs) > 0:
|
|
|
|
|
if lrs > 0:
|
|
|
|
|
kwargs = {"need_global_step": "1"}
|
|
|
|
|
else:
|
|
|
|
|
kwargs = {"need_global_step": "0"}
|
|
|
|
@ -186,14 +187,6 @@ class FleetTranspiler(Fleet):
|
|
|
|
|
recv_ctx = fleet.compiled_config.get_communicator_recv_context(
|
|
|
|
|
recv_type=1)
|
|
|
|
|
|
|
|
|
|
for name, ctx in send_ctx.items():
|
|
|
|
|
print("name: {}, ctx: {}".format(name, ctx))
|
|
|
|
|
|
|
|
|
|
print("==== = ==== =============== ====")
|
|
|
|
|
|
|
|
|
|
for name, ctx in recv_ctx.items():
|
|
|
|
|
print("name: {}, ctx: {}".format(name, ctx))
|
|
|
|
|
|
|
|
|
|
from paddle.fluid.communicator import Communicator
|
|
|
|
|
self._communicator = Communicator(
|
|
|
|
|
trainer_config.mode, kwargs,
|
|
|
|
|