|
|
|
@ -19,6 +19,7 @@ import sys
|
|
|
|
|
from .. import compat as cpt
|
|
|
|
|
|
|
|
|
|
from . import core
|
|
|
|
|
from . import framework
|
|
|
|
|
|
|
|
|
|
__all__ = ['CompiledProgram', 'ExecutionStrategy', 'BuildStrategy']
|
|
|
|
|
|
|
|
|
@ -34,6 +35,15 @@ def _place_obj(place):
|
|
|
|
|
return p
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _is_pserver_mode(main_program):
|
|
|
|
|
main = main_program if main_program \
|
|
|
|
|
else framework.default_main_program()
|
|
|
|
|
for op in main.global_block().ops:
|
|
|
|
|
if op.type in ["send", "recv"]:
|
|
|
|
|
return True
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CompiledProgram(object):
|
|
|
|
|
"""
|
|
|
|
|
Compiles a Program for execution.
|
|
|
|
@ -110,6 +120,8 @@ class CompiledProgram(object):
|
|
|
|
|
self._exec_strategy = ExecutionStrategy()
|
|
|
|
|
if self._build_strategy is None:
|
|
|
|
|
self._build_strategy = BuildStrategy()
|
|
|
|
|
self._build_strategy.is_distribution = _is_pserver_mode(
|
|
|
|
|
self._program) or self._build_strategy.num_trainers > 1
|
|
|
|
|
return self
|
|
|
|
|
|
|
|
|
|
def with_inference_optimize(self, config):
|
|
|
|
@ -185,8 +197,7 @@ class CompiledProgram(object):
|
|
|
|
|
self._build_strategy.trainers_endpoints = trainers_endpoints
|
|
|
|
|
|
|
|
|
|
self._persistable_vars = set([
|
|
|
|
|
cpt.to_text(v.name)
|
|
|
|
|
for v in [
|
|
|
|
|
cpt.to_text(v.name) for v in [
|
|
|
|
|
var for var in self._program.list_vars()
|
|
|
|
|
if var.persistable and var.type != core.VarDesc.VarType.RAW
|
|
|
|
|
]
|
|
|
|
|