|
|
|
@ -29,15 +29,6 @@ ExecutionStrategy = core.ParallelExecutor.ExecutionStrategy
|
|
|
|
|
# Module-level alias: expose the BuildStrategy type defined on the C++-backed
# core.ParallelExecutor so callers can reference it from this module directly.
BuildStrategy = core.ParallelExecutor.BuildStrategy
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _is_pserver_mode(main_program):
|
|
|
|
|
main = main_program if main_program \
|
|
|
|
|
else framework.default_main_program()
|
|
|
|
|
for op in main.global_block().ops:
|
|
|
|
|
if op.type in ["send", "recv"]:
|
|
|
|
|
return True
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ParallelExecutor(object):
|
|
|
|
|
"""
|
|
|
|
|
ParallelExecutor is designed for data parallelism, which focuses on distributing
|
|
|
|
@ -140,7 +131,7 @@ class ParallelExecutor(object):
|
|
|
|
|
# FIXME(zcd): is_distribution_ is a temporary field, because in pserver mode,
|
|
|
|
|
# num_trainers is 1, so the current fields of build_strategy doesn't tell if
|
|
|
|
|
# it's distributed model.
|
|
|
|
|
build_strategy.is_distribution = _is_pserver_mode(
|
|
|
|
|
build_strategy.is_distribution = framework.is_pserver_mode(
|
|
|
|
|
main_program) or num_trainers > 1
|
|
|
|
|
|
|
|
|
|
# step4: get main_program, scope, local_scopes
|
|
|
|
|