follow comments

emailweixu-patch-1
typhoonzero 7 years ago
parent c74445017d
commit 7a6000a0b8

@@ -385,7 +385,7 @@ class DistributeTranspiler:
                 # param is already created on global program
                 param_block = None
                 for p in self.param_grad_ep_mapping[endpoint]["params"]:
-                    if same_or_split_var(p.name, opt_op.input(key)):
+                    if same_or_split_var(p.name, opt_op.input(key)[0]):
                         param_block = p
                         break
                 if not param_block:
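
The first two hunks make the same fix: `opt_op.input(key)` apparently returns a list of input variable names for that slot rather than a single name, so the string has to be pulled out with `[0]` before it is compared by `same_or_split_var` or used as a key into `global_block().vars`. A minimal sketch of that distinction, using a hypothetical stand-in operator and a guessed behaviour for `same_or_split_var` (neither is the framework's real implementation):

```python
# Hypothetical stand-in for the operator API, only to show why [0] matters:
# input(name) hands back a list of variable names, even for single-var slots.
class FakeOp(object):
    def __init__(self, inputs):
        self._inputs = inputs            # e.g. {"Param": ["fc_0.w_0"]}

    def input(self, name):
        return self._inputs[name]        # a list, not a string


def same_or_split_var(p_name, var_name):
    # Guessed intent of the helper: match the variable itself or one of its
    # split blocks such as "fc_0.w_0.block0".
    return p_name == var_name or p_name.startswith(var_name + ".")


opt_op = FakeOp({"Param": ["fc_0.w_0"]})

print(opt_op.input("Param"))        # ['fc_0.w_0'] -- a list
print(opt_op.input("Param")[0])     # 'fc_0.w_0'   -- the variable name itself

# Comparing a name against the raw list can never be True (and inside the
# helper it would fail on list + str), which is what the [0] index fixes.
print("fc_0.w_0" == opt_op.input("Param"))                              # False
print(same_or_split_var("fc_0.w_0.block0", opt_op.input("Param")[0]))   # True
```
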
@@ -403,7 +403,7 @@ class DistributeTranspiler:
                 continue
             # update accumulator variable shape
             param_shape = new_inputs["Param"].shape
-            var = program.global_block().vars[opt_op.input(key)]
+            var = program.global_block().vars[opt_op.input(key)[0]]
             new_shape = self._get_optimizer_input_shape(opt_op.type, key,
                                                         var.shape, param_shape)
             tmpvar = program.global_block().create_var(
@@ -440,20 +440,18 @@ class DistributeTranspiler:
             else:
                 varlist = [var]
             for var in varlist:
                 # TODO(typhoonzero): will remove below line later.
                 program.global_block().create_var(
                     name=var.name,
                     persistable=var.persistable,
                     dtype=var.dtype,
                     shape=var.shape)
-                try:
+                if not pserver_program.global_block().vars.has_key(var.name):
                     pserver_program.global_block().create_var(
                         name=var.name,
                         persistable=var.persistable,
                         dtype=var.dtype,
                         shape=var.shape)
-                except ValueError:
-                    # create var if not created yet.
-                    pass
         outputs = self._get_output_map_from_op(self.program.global_block().vars,
                                                opt_op)
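
The last hunk trades the try/except around `create_var` for an explicit membership check on the pserver block's `vars` dict, so each variable is created at most once per name. A rough sketch of the two idioms with a toy block object (not the real framework classes); note that `dict.has_key()` exists only in Python 2, so the portable spelling of the new check is `var.name in pserver_program.global_block().vars`:

```python
# Toy stand-in for a program block whose create_var refuses duplicates;
# the class and its fields are illustrative, not the framework API.
class ToyBlock(object):
    def __init__(self):
        self.vars = {}

    def create_var(self, name, **attrs):
        if name in self.vars:
            raise ValueError("variable %s already exists" % name)
        self.vars[name] = attrs


block = ToyBlock()
block.create_var(name="fc_0.w_0", shape=(784, 10))

# Old style (EAFP): attempt the create and swallow the duplicate error.
try:
    block.create_var(name="fc_0.w_0", shape=(784, 10))
except ValueError:
    pass  # already created

# New style (look before you leap), as in the diff but with the Python 3
# compatible `in` operator instead of dict.has_key():
if "fc_0.w_0" not in block.vars:
    block.create_var(name="fc_0.w_0", shape=(784, 10))

print(sorted(block.vars))   # ['fc_0.w_0'] -- still a single variable
```
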
