Merge pull request #9558 from typhoonzero/fix_dist_transpile_one_trainer

Fix the distribute-transpiler error that occurs with a single pserver and a single trainer
helinwang-patch-1
武毅 7 years ago committed by GitHub
commit b9d8bbe4f2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -276,12 +276,15 @@ class DistributeTranspiler:
suff_idx = v.name.find(".trainer_") suff_idx = v.name.find(".trainer_")
if suff_idx >= 0: if suff_idx >= 0:
orig_var_name = v.name[:suff_idx] orig_var_name = v.name[:suff_idx]
pserver_program.global_block().create_var( else:
orig_var_name = v.name
single_trainer_var = pserver_program.global_block().create_var(
name=orig_var_name, name=orig_var_name,
persistable=True, persistable=True,
type=v.type, type=v.type,
dtype=v.dtype, dtype=v.dtype,
shape=v.shape) shape=v.shape)
if self.trainers > 1:
for trainer_id in xrange(self.trainers): for trainer_id in xrange(self.trainers):
var = pserver_program.global_block().create_var( var = pserver_program.global_block().create_var(
name="%s.trainer_%d" % (orig_var_name, trainer_id), name="%s.trainer_%d" % (orig_var_name, trainer_id),
@ -290,6 +293,8 @@ class DistributeTranspiler:
dtype=v.dtype, dtype=v.dtype,
shape=v.shape) shape=v.shape)
recv_inputs.append(var) recv_inputs.append(var)
else:
recv_inputs.append(single_trainer_var)
# step3 # step3
optimize_block = pserver_program.create_block(0) optimize_block = pserver_program.create_block(0)
@ -511,8 +516,11 @@ class DistributeTranspiler:
def _append_split_op(self, program, gradblocks): def _append_split_op(self, program, gradblocks):
# Split variables that need to be split and append respective ops # Split variables that need to be split and append respective ops
add_suffix = False
if self.trainers > 1:
add_suffix = True
var_mapping = self._create_vars_from_blocklist( var_mapping = self._create_vars_from_blocklist(
program, gradblocks, add_trainer_suffix=True) program, gradblocks, add_trainer_suffix=add_suffix)
for varname, splited_vars in var_mapping.iteritems(): for varname, splited_vars in var_mapping.iteritems():
# variables that don't need to be split have empty splited_vars # variables that don't need to be split have empty splited_vars
if len(splited_vars) <= 1: if len(splited_vars) <= 1:

Loading…
Cancel
Save