@@ -358,7 +358,7 @@ class DistributeTranspiler:
                 type=v.type,
                 dtype=v.dtype,
                 shape=v.shape)
-            if self.trainer_num > 1:
+            if self.sync_mode and self.trainer_num > 1:
                 for trainer_id in xrange(self.trainer_num):
                     var = pserver_program.global_block().create_var(
                         name="%s.trainer_%d" % (orig_var_name, trainer_id),
@@ -688,6 +688,16 @@ class DistributeTranspiler:
                 self.table_name)],
             persistable=False)
 
+        # create table optimize block in pserver program
+        table_opt_op = [
+            op for op in self.optimize_ops
+            if op.input("Param")[0] == self.table_name
+        ][0]
+        table_opt_block = pserver_program.create_block(pre_block_idx)
+        # only support sgd now
+        assert table_opt_op.type == "sgd"
+
+        if self.sync_mode:
             # create grad vars in pserver program
             table_grad_var = self.table_param_grad[1]
             table_grad_list = [
@@ -696,18 +706,10 @@ class DistributeTranspiler:
                     (table_grad_var.name, index, pserver_index),
                     type=table_grad_var.type,
                     shape=table_grad_var.shape,
-                    dtype=table_grad_var.dtype) for index in range(self.trainer_num)
+                    dtype=table_grad_var.dtype)
+                for index in range(self.trainer_num)
             ]
 
-        # create table optimize block in pserver program
-        table_opt_op = [
-            op for op in self.optimize_ops
-            if op.input("Param")[0] == self.table_name
-        ][0]
-        table_opt_block = pserver_program.create_block(pre_block_idx)
-        # only support sgd now
-        assert table_opt_op.type == "sgd"
-
             # append sum op for table_grad_list
             table_opt_block.append_op(
                 type="sum",
@@ -751,7 +753,7 @@ class DistributeTranspiler:
         for varname, splited in block_map.iteritems():
             orig_var = program.global_block().var(varname)
             if len(splited) == 1:
-                if add_trainer_suffix:
+                if self.sync_mode and add_trainer_suffix:
                     new_var_name = "%s.trainer_%d" % \
                         (orig_var.name, self.trainer_id)
                     program.global_block().rename_var(varname, new_var_name)
@@ -775,7 +777,7 @@ class DistributeTranspiler:
                 if len(orig_shape) >= 2:
                     splited_shape.extend(orig_shape[1:])
                 new_var_name = ""
-                if add_trainer_suffix:
+                if self.sync_mode and add_trainer_suffix:
                     new_var_name = "%s.block%d.trainer_%d" % \
                         (varname, i, self.trainer_id)
                 else:
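
Both rename paths (whole variables in the previous hunk, split blocks here) now append the ".trainer_%d" suffix only in sync mode, since async pservers keep one shared variable per gradient regardless of which trainer sent it. A sketch of the resulting names with a hypothetical helper (the "%s.block%d.trainer_%d" pattern is from the patch; the async form is assumed to drop only the trainer part):

    # Hypothetical helper showing the split-var naming chosen above.
    def split_var_name(varname, block_idx, trainer_id, sync_mode):
        if sync_mode:
            # sync mode: each trainer sends to its own suffixed var
            return "%s.block%d.trainer_%d" % (varname, block_idx, trainer_id)
        # assumed async form: all trainers share the per-block var
        return "%s.block%d" % (varname, block_idx)

    print(split_var_name("fc_0.w_0@GRAD", 1, 0, True))   # fc_0.w_0@GRAD.block1.trainer_0
    print(split_var_name("fc_0.w_0@GRAD", 1, 0, False))  # fc_0.w_0@GRAD.block1
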
@@ -907,7 +909,7 @@ class DistributeTranspiler:
             pserver_block.vars[self._orig_varname(grad_block.name)]
         grad_to_block_id.append(merged_var.name + ":" + str(
             optimize_block.idx))
-        if self.trainer_num > 1:
+        if self.sync_mode and self.trainer_num > 1:
             vars2merge = []
             for i in xrange(self.trainer_num):
                 per_trainer_name = "%s.trainer_%d" % \
@@ -925,6 +927,7 @@ class DistributeTranspiler:
                 inputs={"X": merged_var},
                 outputs={"Out": merged_var},
                 attrs={"scale": 1.0 / float(self.trainer_num)})
+
         new_inputs[key] = merged_var
     elif key == "Param":
         # param is already created on global program
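
The scale op divides the summed gradient by the trainer count, so a synchronous pserver applies the average of all trainers' gradients rather than their sum. A plain-Python sketch of the sum-then-scale pair (values are made up):

    # Sketch of the sum + scale pair built above:
    # merged = (g_0 + g_1 + ... + g_{n-1}) / n
    trainer_grads = [[0.5, 1.0], [1.5, 3.0]]             # hypothetical g_0, g_1
    n = float(len(trainer_grads))
    summed = [sum(col) for col in zip(*trainer_grads)]   # the "sum" op
    merged = [v * (1.0 / n) for v in summed]             # the "scale" op
    print(merged)  # [1.0, 2.0] -- the gradient the optimizer op consumes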