fix dist transpiler bug

fea/docker_cudnn7
typhoonzero 7 years ago
parent ddff83ff14
commit d02b17e597

@ -278,6 +278,15 @@ class DistributeTranspiler:
# we don't need to create them when grad arrives. # we don't need to create them when grad arrives.
# change client side var name to origin name by # change client side var name to origin name by
# removing ".trainer_%d" suffix # removing ".trainer_%d" suffix
# NOTE: single_trainer_var must be created for multi-trainer
# case to merge grads from multiple trainers
single_trainer_var = \
pserver_program.global_block().create_var(
name=orig_var_name,
persistable=True,
type=v.type,
dtype=v.dtype,
shape=v.shape)
suff_idx = v.name.find(".trainer_") suff_idx = v.name.find(".trainer_")
if suff_idx >= 0: if suff_idx >= 0:
orig_var_name = v.name[:suff_idx] orig_var_name = v.name[:suff_idx]
@ -293,12 +302,6 @@ class DistributeTranspiler:
shape=v.shape) shape=v.shape)
recv_inputs.append(var) recv_inputs.append(var)
else: else:
single_trainer_var = pserver_program.global_block().create_var(
name=orig_var_name,
persistable=True,
type=v.type,
dtype=v.dtype,
shape=v.shape)
recv_inputs.append(single_trainer_var) recv_inputs.append(single_trainer_var)
# step3 # step3

Loading…
Cancel
Save