fix dist transpiler bug

fea/docker_cudnn7
typhoonzero 7 years ago
parent ddff83ff14
commit d02b17e597

@ -278,6 +278,15 @@ class DistributeTranspiler:
# we don't need to create them when grad arrives.
# change client side var name to origin name by
# removing ".trainer_%d" suffix
# NOTE: single_trainer_var must be created for multi-trainer
# case to merge grads from multiple trainers
single_trainer_var = \
pserver_program.global_block().create_var(
name=orig_var_name,
persistable=True,
type=v.type,
dtype=v.dtype,
shape=v.shape)
suff_idx = v.name.find(".trainer_")
if suff_idx >= 0:
orig_var_name = v.name[:suff_idx]
@ -293,12 +302,6 @@ class DistributeTranspiler:
shape=v.shape)
recv_inputs.append(var)
else:
single_trainer_var = pserver_program.global_block().create_var(
name=orig_var_name,
persistable=True,
type=v.type,
dtype=v.dtype,
shape=v.shape)
recv_inputs.append(single_trainer_var)
# step3

Loading…
Cancel
Save