@@ -31,6 +31,7 @@ Steps to transpile pserver:
 """
 
 import math
+import sys
 import numpy as np
 import collections
 import six
@@ -181,7 +182,8 @@ class DistributeTranspiler(object):
                   program=None,
                   pservers="127.0.0.1:6174",
                   trainers=1,
-                  sync_mode=True):
+                  sync_mode=True,
+                  startup_program=None):
         """
         Run the transpiler.
 
@@ -194,13 +196,17 @@ class DistributeTranspiler(object):
                 list.
             trainers (int): number of trainers in the distributed job.
             sync_mode (bool): Do sync training or not, default is True.
+            startup_program (Program|None): startup_program to transpile,
+                default is fluid.default_startup_program().
         """
         if program is None:
             program = default_main_program()
+        if startup_program is None:
+            startup_program = default_startup_program()
         self.origin_program = program
-        self.origin_startup_program = default_startup_program().clone()
+        self.startup_program = startup_program
+        self.origin_startup_program = self.startup_program.clone()
 
-        self.startup_program = default_startup_program()
         self.trainer_num = trainers
         self.sync_mode = sync_mode
         self.trainer_id = trainer_id
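With this change, a trainer can hand transpile() an explicit startup program instead of relying on the global default. A minimal usage sketch, assuming a two-pserver job; the endpoints, trainer count, and trainer_id below are illustrative values, not part of this patch:

import paddle.fluid as fluid

t = fluid.DistributeTranspiler()
t.transpile(
    trainer_id=0,
    program=fluid.default_main_program(),
    pservers="127.0.0.1:6174,127.0.0.1:6175",
    trainers=2,
    sync_mode=True,
    startup_program=fluid.default_startup_program())
trainer_prog = t.get_trainer_program()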
@@ -376,21 +382,18 @@ class DistributeTranspiler(object):
 
         return self.origin_program
 
-    def _get_trainer_startup_program(self,
-                                     recv_vars,
-                                     eplist,
-                                     startup_program=None):
+    def _get_trainer_startup_program(self, recv_vars, eplist):
         """
         Get transpiled trainer side startup program.
 
         Args:
-            startup_program(Program): Startup program.
             recv_vars (list): Variable list to recv for current trainer_id
             eplist (list): A list of strings indicating the endpoints to recv from
 
         Returns:
             Program: trainer side startup program.
         """
-        if startup_program is None:
-            startup_program = self.startup_program
+        startup_program = self.startup_program
 
+        # FIXME(gongwb): delete the ops that are not needed.
+        # Note that some parameters are not trainable, and the ops on those can't be deleted.
@@ -438,7 +441,18 @@ class DistributeTranspiler(object):
             # add concat ops to merge split parameters received from parameter servers.
             if len(splited_var) <= 1:
                 continue
-            orig_param = startup_program.global_block().vars[varname]
+            # NOTE: if memory optimization is enabled, origin vars may have been removed.
+            if varname in startup_program.global_block().vars:
+                orig_param = startup_program.global_block().vars[varname]
+            else:
+                origin_param_var = self.origin_program.global_block().vars[
+                    varname]
+                orig_param = startup_program.global_block().create_var(
+                    name=varname,
+                    persistable=origin_param_var.persistable,
+                    type=origin_param_var.type,
+                    dtype=origin_param_var.dtype,
+                    shape=origin_param_var.shape)
             startup_program.global_block().append_op(
                 type="concat",
                 inputs={"X": splited_var},
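The concat is needed because a large parameter may be sliced into blocks that live on different pservers; after the recv ops fetch the blocks, the trainer stitches them back together. A toy numpy sketch of the merge semantics (the 10x5 and 5x5 shapes are made up; actual block sizes come from the transpiler's split decisions):

import numpy as np

# two 5x5 blocks received from two pservers for one original 10x5 parameter
block0 = np.random.rand(5, 5)
block1 = np.random.rand(5, 5)
merged = np.concatenate([block0, block1], axis=0)  # what the concat op computes
assert merged.shape == (10, 5)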
@@ -461,7 +475,9 @@ class DistributeTranspiler(object):
         # NOTE: assume blocks of the same variable are not distributed
         # on the same pserver; only change param/grad varnames for
         # trainers to fetch.
+        sys.stderr.write("get_pserver_program() is deprecated, call \
+get_pserver_programs() to get pserver main and startup \
+in a single call.")
         # step1
         pserver_program = Program()
         pserver_program.random_seed = self.origin_program.random_seed
@@ -651,32 +667,58 @@ class DistributeTranspiler(object):
                 endpoint)
 
         pserver_program._sync_with_cpp()
+        # save pserver program to generate the pserver side startup program later.
+        self.pserver_program = pserver_program
         return pserver_program
 
+    def get_pserver_programs(self, endpoint):
+        """
+        Get pserver side main program and startup program for distributed training.
+
+        Args:
+            endpoint (str): current pserver endpoint.
+
+        Returns:
+            tuple: (main_program, startup_program), each of type "Program"
+        """
+        pserver_prog = self.get_pserver_program(endpoint)
+        pserver_startup = self.get_startup_program(endpoint)
+        return pserver_prog, pserver_startup
+
     def get_startup_program(self,
                             endpoint,
-                            pserver_program,
+                            pserver_program=None,
                             startup_program=None):
         """
+        **Deprecated**
+
         Get startup program for current parameter server.
         Modify operator input variables if there are variables that
         were split into several blocks.
 
         Args:
             endpoint (str): current pserver endpoint.
-            pserver_program (Program): call get_pserver_program first and
-                pass the result here.
-            startup_program (Program): if pass None, will use
-                default_startup_program
+            pserver_program (Program): deprecated, call get_pserver_program first.
+            startup_program (Program): deprecated, should pass startup_program
+                when initializing.
 
         Returns:
             Program: parameter server side startup program.
         """
+        sys.stderr.write("get_startup_program() is deprecated, call \
+get_pserver_programs() to get pserver main and startup \
+in a single call.")
+        if pserver_program is not None:
+            sys.stderr.write("passing pserver_program to get_startup_program() \
+is deprecated, you can use new API get_pserver_programs() to \
+get both pserver main program and startup program.")
+        if startup_program is not None:
+            sys.stderr.write("passing startup_program to get_startup_program() \
+is deprecated, use fluid.program_guard() or pass this argument \
+to transpile() call.")
 
         s_prog = Program()
-        if not startup_program:
-            orig_s_prog = default_startup_program()
-        else:
-            orig_s_prog = startup_program
+        orig_s_prog = self.startup_program
         s_prog.random_seed = orig_s_prog.random_seed
         params = self.param_grad_ep_mapping[endpoint]["params"]
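Taken together, get_pserver_programs() replaces the old two-call pattern of get_pserver_program() followed by get_startup_program(). A minimal sketch of the pserver side, reusing the transpiler t from the trainer sketch above; the endpoint string is illustrative:

import paddle.fluid as fluid

current_endpoint = "127.0.0.1:6174"
pserver_prog, pserver_startup = t.get_pserver_programs(current_endpoint)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(pserver_startup)  # initialize pserver-side parameters first
exe.run(pserver_prog)     # then enter the listen-and-serve loop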