@@ -143,7 +143,8 @@ class DistributeTranspiler:
                   program=None,
                   pservers="127.0.0.1:6174",
                   trainers=1,
-                  split_method=splitter.round_robin):
+                  split_method=splitter.round_robin,
+                  sync_mode=True):
         """
         Transpile the program to distributed data-parallelism programs.
         The main_program will be transformed to use a remote parameter server
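
Note: the new `sync_mode` flag added here is threaded through the rest of this diff. For context, `split_method` decides which parameter server each (split) variable lands on; the stock `splitter.round_robin` is not part of this diff, so the following is only a minimal stand-in sketch of the idea (name and signature are assumptions, not this PR's code):

```python
# Illustrative stand-in for splitter.round_robin (not shown in this diff):
# assign each variable to the pserver endpoints in turn.
def round_robin(var_list, pserver_endpoints):
    eplist = []
    for i, _ in enumerate(var_list):
        eplist.append(pserver_endpoints[i % len(pserver_endpoints)])
    return eplist
```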
@@ -184,6 +185,9 @@ class DistributeTranspiler:
         :param split_method: A function to determin how to split variables
            to different servers equally.
         :type split_method: function
+        :type sync_mode: bool, default True
+        :param sync_mode: if True, transpile into synchronous pserver and
+            trainer programs; if False, into asynchronous ones.
         """
         assert (callable(split_method))
         if program is None:
@@ -191,6 +195,7 @@ class DistributeTranspiler:
         self.origin_program = program
         self.trainer_num = trainers
         self.optimize_ops = optimize_ops
+        self.sync_mode = sync_mode
         # TODO(typhoonzero): currently trainer_id is fetched from cluster system
         # like Kubernetes, we should port this to use etcd later when developing
         # fluid distributed training with fault-tolerance.
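
With the flag stored on the transpiler, a caller can opt out of synchronous training. A minimal usage sketch; the exact `transpile()` signature varies across Fluid versions (this diff also implies an `optimize_ops` argument not shown here), so treat the call below as illustrative rather than definitive:

```python
import paddle.fluid as fluid

# Sketch only: keyword arguments mirror the parameters visible in this diff.
t = fluid.DistributeTranspiler()
t.transpile(
    trainer_id=0,
    pservers="127.0.0.1:6174,127.0.0.1:6175",
    trainers=2,
    sync_mode=False)  # request the asynchronous pserver/trainer programs

pserver_prog = t.get_pserver_program("127.0.0.1:6174")
trainer_prog = t.get_trainer_program()
```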
@@ -295,8 +300,11 @@ class DistributeTranspiler:
             inputs={"X": send_inputs},
             outputs={"Out": send_outputs,
                      "RPCClient": rpc_client_var},
-            attrs={"endpoints": pserver_endpoints,
-                   "epmap": eplist})
+            attrs={
+                "endpoints": pserver_endpoints,
+                "epmap": eplist,
+                "sync_mode": self.sync_mode
+            })
         # step4: Concat the parameters splits together after recv.
         for varname, splited_var in param_var_mapping.iteritems():
             if len(splited_var) <= 1:
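
The trainer-side `send` op now carries `sync_mode` alongside its endpoint map. For illustration, `epmap` (`eplist`) pairs each send input with the endpoint that receives it; the values below are hypothetical:

```python
# Hypothetical values: the i-th entry of eplist is the endpoint that
# receives the i-th send input (here, split gradient blocks).
pserver_endpoints = ["127.0.0.1:6174", "127.0.0.1:6175"]
send_inputs = ["w@GRAD.block0", "w@GRAD.block1", "b@GRAD"]
eplist = ["127.0.0.1:6174", "127.0.0.1:6175", "127.0.0.1:6174"]
```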
@@ -356,7 +364,7 @@ class DistributeTranspiler:
                 type=v.type,
                 dtype=v.dtype,
                 shape=v.shape)
-            if self.trainer_num > 1:
+            if self.sync_mode and self.trainer_num > 1:
                 for trainer_id in xrange(self.trainer_num):
                     var = pserver_program.global_block().create_var(
                         name="%s.trainer_%d" % (orig_var_name, trainer_id),
@@ -402,13 +410,13 @@ class DistributeTranspiler:
         for op in self.optimize_ops:
             if op.type == "scale":
                 for in_name in op.input_arg_names:
-                    if in_name.startswith("beta1_pow_acc") or\
-                            in_name.startswith("beta2_pow_acc"):
+                    if in_name.startswith("beta1_pow_acc") or \
+                            in_name.startswith("beta2_pow_acc"):
                         global_ops.append(op)

-        def __append_optimize_op__(op, block):
+        def __append_optimize_op__(op, block, grad_to_block_id):
             if self._is_opt_op(op):
-                self._append_pserver_ops(block, op, endpoint,
+                self._append_pserver_ops(block, op, endpoint, grad_to_block_id,
                                          default_main_program())
             else:
                 self._append_pserver_non_opt_ops(block, op)
@@ -422,16 +430,16 @@ class DistributeTranspiler:
                 self._append_pserver_non_opt_ops(lr_decay_block, op)

         # append op to the current block
+        grad_to_block_id = []
         pre_block_idx = pserver_program.num_blocks - 1
         for idx, opt_op in enumerate(opt_op_on_pserver):
             per_opt_block = pserver_program.create_block(pre_block_idx)
             for _, op in enumerate(self.optimize_ops):
                 # optimizer is connected to itself
                 if ufind.is_connected(op, opt_op) and op not in global_ops:
-                    __append_optimize_op__(op, per_opt_block)
+                    __append_optimize_op__(op, per_opt_block, grad_to_block_id)

         # append global ops
         opt_state_block = None
         if global_ops:
             opt_state_block = pserver_program.create_block(
                 pserver_program.num_blocks - 1)
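
`grad_to_block_id` collects one "gradient_name:block_index" string per optimize block (the format is set in `_append_pserver_ops`, changed further down), so the server can route each incoming gradient to the block that optimizes it. An illustrative sketch of the mapping's shape, with hypothetical values:

```python
# Hypothetical contents: "gradient_name:pserver_block_index" strings.
grad_to_block_id = ["w@GRAD:1", "b@GRAD:2"]
# The mapping can be consumed as a dict, e.g. to find w@GRAD's block:
block_of = dict(entry.split(":") for entry in grad_to_block_id)
assert block_of["w@GRAD"] == "1"
```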
@@ -472,7 +480,9 @@ class DistributeTranspiler:
                 "OptimizeBlock": pserver_program.block(1),
                 "endpoint": endpoint,
                 "Fanin": self.trainer_num,
-                "PrefetchBlock": prefetch_block
+                "PrefetchBlock": prefetch_block,
+                "sync_mode": self.sync_mode,
+                "grad_to_block_id": grad_to_block_id
             })

         pserver_program.sync_with_cpp()
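
Both new attributes feed the `listen_and_serv` op, which is implemented in C++ and not part of this diff. The following is a conceptual Python sketch only, under the assumption that sync mode imposes a barrier over the `Fanin` trainers while async mode applies each update on arrival; the function and its parameters are invented for illustration:

```python
def serve(recv, apply_updates, sync_mode, fanin):
    """Conceptual sketch of the pserver loop (the real listen_and_serv
    op is C++); recv() yields one gradient sent by some trainer."""
    while True:
        if sync_mode:
            # Barrier: one gradient from each of the fanin trainers,
            # merged and applied once per global step.
            grads = [recv() for _ in range(fanin)]
            apply_updates(sum(grads) / float(fanin))
        else:
            # Async: apply every incoming gradient immediately, no barrier.
            apply_updates(recv())
```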
@@ -683,17 +693,6 @@ class DistributeTranspiler:
                 self.table_name)],
             persistable=False)

-        # create grad vars in pserver program
-        table_grad_var = self.table_param_grad[1]
-        table_grad_list = [
-            pserver_program.global_block().create_var(
-                name="%s.trainer_%d.pserver_%d" %
-                (table_grad_var.name, index, pserver_index),
-                type=table_grad_var.type,
-                shape=table_grad_var.shape,
-                dtype=table_grad_var.dtype) for index in range(self.trainer_num)
-        ]
-
         # create table optimize block in pserver program
         table_opt_op = [
             op for op in self.optimize_ops
@@ -703,11 +702,24 @@ class DistributeTranspiler:
         # only support sgd now
         assert table_opt_op.type == "sgd"

-        # append sum op for table_grad_list
-        table_opt_block.append_op(
-            type="sum",
-            inputs={"X": table_grad_list},
-            outputs={"Out": [grad_var]})
+        if self.sync_mode:
+            # create grad vars in pserver program
+            table_grad_var = self.table_param_grad[1]
+            table_grad_list = [
+                pserver_program.global_block().create_var(
+                    name="%s.trainer_%d.pserver_%d" %
+                    (table_grad_var.name, index, pserver_index),
+                    type=table_grad_var.type,
+                    shape=table_grad_var.shape,
+                    dtype=table_grad_var.dtype)
+                for index in range(self.trainer_num)
+            ]
+
+            # append sum op for table_grad_list
+            table_opt_block.append_op(
+                type="sum",
+                inputs={"X": table_grad_list},
+                outputs={"Out": [grad_var]})

         lr_var = pserver_program.global_block().vars[table_opt_op.input(
             "LearningRate")[0]]
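
In sync mode every trainer contributes one slice of the distributed lookup-table gradient per step, so the per-trainer variables plus the `sum` op are now created only under `self.sync_mode`. A toy numeric illustration of what that `sum` computes, with numpy standing in for the runtime and made-up shapes:

```python
import numpy as np

# One table-gradient slice per trainer, reduced into grad_var on the pserver.
table_grad_list = [np.full((4, 8), t, dtype="float32") for t in range(3)]
grad_var = np.sum(table_grad_list, axis=0)  # elementwise sum over trainers
```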
@@ -746,7 +758,7 @@ class DistributeTranspiler:
         for varname, splited in block_map.iteritems():
             orig_var = program.global_block().var(varname)
             if len(splited) == 1:
-                if add_trainer_suffix:
+                if self.sync_mode and add_trainer_suffix:
                     new_var_name = "%s.trainer_%d" % \
                         (orig_var.name, self.trainer_id)
                     program.global_block().rename_var(varname, new_var_name)
@@ -770,7 +782,7 @@ class DistributeTranspiler:
             if len(orig_shape) >= 2:
                 splited_shape.extend(orig_shape[1:])
             new_var_name = ""
-            if add_trainer_suffix:
+            if self.sync_mode and add_trainer_suffix:
                 new_var_name = "%s.block%d.trainer_%d" % \
                     (varname, i, self.trainer_id)
             else:
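
The per-trainer suffix again appears only in sync mode. An illustration of the two name shapes, assuming (the `else` body is elided in this hunk, so this is an assumption) that the async branch simply omits the trainer suffix because all trainers share one variable:

```python
# Name shapes produced above, with varname = "w@GRAD", i = 1, trainer_id = 0:
sync_name = "%s.block%d.trainer_%d" % ("w@GRAD", 1, 0)  # w@GRAD.block1.trainer_0
async_name = "%s.block%d" % ("w@GRAD", 1)  # w@GRAD.block1 (assumed else branch)
```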
@@ -879,7 +891,7 @@ class DistributeTranspiler:
         return orig_var_name

     def _append_pserver_ops(self, optimize_block, opt_op, endpoint,
-                            origin_program):
+                            grad_to_block_id, origin_program):
         program = optimize_block.program
         pserver_block = program.global_block()
         new_inputs = dict()
@@ -900,7 +912,9 @@ class DistributeTranspiler:
                     return
                 merged_var = \
                     pserver_block.vars[self._orig_varname(grad_block.name)]
-                if self.trainer_num > 1:
+                grad_to_block_id.append(merged_var.name + ":" + str(
+                    optimize_block.idx))
+                if self.sync_mode and self.trainer_num > 1:
                     vars2merge = []
                     for i in xrange(self.trainer_num):
                         per_trainer_name = "%s.trainer_%d" % \
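
`_append_pserver_ops` now records the "name:block" entry for every gradient, while the merge path stays sync-only: the pserver sums the per-trainer copies and rescales by `1.0 / trainer_num` (the `scale` op in the next hunk). A toy numeric sketch of that merge-then-average step, with numpy as a stand-in:

```python
import numpy as np

trainer_num = 4
per_trainer = [np.full((2, 3), g, dtype="float32") for g in (1., 2., 3., 4.)]
merged = np.sum(per_trainer, axis=0)  # the "sum" op over vars2merge
merged *= 1.0 / float(trainer_num)    # the "scale" op: every entry is 2.5
```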
@@ -918,6 +932,7 @@ class DistributeTranspiler:
                     inputs={"X": merged_var},
                     outputs={"Out": merged_var},
                     attrs={"scale": 1.0 / float(self.trainer_num)})
+
                 new_inputs[key] = merged_var
             elif key == "Param":
                 # param is already created on global program