|
|
|
@ -385,7 +385,7 @@ class DistributeTranspiler:
|
|
|
|
|
# param is already created on global program
|
|
|
|
|
param_block = None
|
|
|
|
|
for p in self.param_grad_ep_mapping[endpoint]["params"]:
|
|
|
|
|
if same_or_split_var(p.name, opt_op.input(key)):
|
|
|
|
|
if same_or_split_var(p.name, opt_op.input(key)[0]):
|
|
|
|
|
param_block = p
|
|
|
|
|
break
|
|
|
|
|
if not param_block:
|
|
|
|
@ -403,7 +403,7 @@ class DistributeTranspiler:
|
|
|
|
|
continue
|
|
|
|
|
# update accumulator variable shape
|
|
|
|
|
param_shape = new_inputs["Param"].shape
|
|
|
|
|
var = program.global_block().vars[opt_op.input(key)]
|
|
|
|
|
var = program.global_block().vars[opt_op.input(key)[0]]
|
|
|
|
|
new_shape = self._get_optimizer_input_shape(opt_op.type, key,
|
|
|
|
|
var.shape, param_shape)
|
|
|
|
|
tmpvar = program.global_block().create_var(
|
|
|
|
@ -440,20 +440,18 @@ class DistributeTranspiler:
|
|
|
|
|
else:
|
|
|
|
|
varlist = [var]
|
|
|
|
|
for var in varlist:
|
|
|
|
|
# TODO(typhoonzero): will remove below line later.
|
|
|
|
|
program.global_block().create_var(
|
|
|
|
|
name=var.name,
|
|
|
|
|
persistable=var.persistable,
|
|
|
|
|
dtype=var.dtype,
|
|
|
|
|
shape=var.shape)
|
|
|
|
|
try:
|
|
|
|
|
if not pserver_program.global_block().vars.has_key(var.name):
|
|
|
|
|
pserver_program.global_block().create_var(
|
|
|
|
|
name=var.name,
|
|
|
|
|
persistable=var.persistable,
|
|
|
|
|
dtype=var.dtype,
|
|
|
|
|
shape=var.shape)
|
|
|
|
|
except ValueError:
|
|
|
|
|
# create var if not created yet.
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
outputs = self._get_output_map_from_op(self.program.global_block().vars,
|
|
|
|
|
opt_op)
|
|
|
|
|