|
|
|
@ -276,20 +276,25 @@ class DistributeTranspiler:
|
|
|
|
|
suff_idx = v.name.find(".trainer_")
|
|
|
|
|
if suff_idx >= 0:
|
|
|
|
|
orig_var_name = v.name[:suff_idx]
|
|
|
|
|
pserver_program.global_block().create_var(
|
|
|
|
|
else:
|
|
|
|
|
orig_var_name = v.name
|
|
|
|
|
single_trainer_var = pserver_program.global_block().create_var(
|
|
|
|
|
name=orig_var_name,
|
|
|
|
|
persistable=True,
|
|
|
|
|
type=v.type,
|
|
|
|
|
dtype=v.dtype,
|
|
|
|
|
shape=v.shape)
|
|
|
|
|
for trainer_id in xrange(self.trainers):
|
|
|
|
|
var = pserver_program.global_block().create_var(
|
|
|
|
|
name="%s.trainer_%d" % (orig_var_name, trainer_id),
|
|
|
|
|
persistable=False,
|
|
|
|
|
type=v.type,
|
|
|
|
|
dtype=v.dtype,
|
|
|
|
|
shape=v.shape)
|
|
|
|
|
recv_inputs.append(var)
|
|
|
|
|
if self.trainers > 1:
|
|
|
|
|
for trainer_id in xrange(self.trainers):
|
|
|
|
|
var = pserver_program.global_block().create_var(
|
|
|
|
|
name="%s.trainer_%d" % (orig_var_name, trainer_id),
|
|
|
|
|
persistable=False,
|
|
|
|
|
type=v.type,
|
|
|
|
|
dtype=v.dtype,
|
|
|
|
|
shape=v.shape)
|
|
|
|
|
recv_inputs.append(var)
|
|
|
|
|
else:
|
|
|
|
|
recv_inputs.append(single_trainer_var)
|
|
|
|
|
|
|
|
|
|
# step3
|
|
|
|
|
optimize_block = pserver_program.create_block(0)
|
|
|
|
@ -511,8 +516,11 @@ class DistributeTranspiler:
|
|
|
|
|
|
|
|
|
|
def _append_split_op(self, program, gradblocks):
|
|
|
|
|
# Split variables that need to be split and append respective ops
|
|
|
|
|
add_suffix = False
|
|
|
|
|
if self.trainers > 1:
|
|
|
|
|
add_suffix = True
|
|
|
|
|
var_mapping = self._create_vars_from_blocklist(
|
|
|
|
|
program, gradblocks, add_trainer_suffix=True)
|
|
|
|
|
program, gradblocks, add_trainer_suffix=add_suffix)
|
|
|
|
|
for varname, splited_vars in var_mapping.iteritems():
|
|
|
|
|
# variable that don't need to split have empty splited_vars
|
|
|
|
|
if len(splited_vars) <= 1:
|
|
|
|
|