|
|
|
@ -278,6 +278,12 @@ class DistributeTranspiler:
|
|
|
|
|
# we don't need to create them when grad arrives.
|
|
|
|
|
# change client side var name to origin name by
|
|
|
|
|
# removing ".trainer_%d" suffix
|
|
|
|
|
|
|
|
|
|
suff_idx = v.name.find(".trainer_")
|
|
|
|
|
if suff_idx >= 0:
|
|
|
|
|
orig_var_name = v.name[:suff_idx]
|
|
|
|
|
else:
|
|
|
|
|
orig_var_name = v.name
|
|
|
|
|
# NOTE: single_trainer_var must be created for multi-trainer
|
|
|
|
|
# case to merge grads from multiple trainers
|
|
|
|
|
single_trainer_var = \
|
|
|
|
@ -287,11 +293,6 @@ class DistributeTranspiler:
|
|
|
|
|
type=v.type,
|
|
|
|
|
dtype=v.dtype,
|
|
|
|
|
shape=v.shape)
|
|
|
|
|
suff_idx = v.name.find(".trainer_")
|
|
|
|
|
if suff_idx >= 0:
|
|
|
|
|
orig_var_name = v.name[:suff_idx]
|
|
|
|
|
else:
|
|
|
|
|
orig_var_name = v.name
|
|
|
|
|
if self.trainers > 1:
|
|
|
|
|
for trainer_id in xrange(self.trainers):
|
|
|
|
|
var = pserver_program.global_block().create_var(
|
|
|
|
|