|
|
|
@ -1182,18 +1182,39 @@ class DistributeTranspiler(object):
|
|
|
|
|
program = optimize_block.program
|
|
|
|
|
pserver_block = program.global_block()
|
|
|
|
|
new_inputs = dict()
|
|
|
|
|
|
|
|
|
|
# update param/grad shape first, then other inputs like
|
|
|
|
|
# moment can use the updated shape
|
|
|
|
|
for key in opt_op.input_names:
|
|
|
|
|
if key == "Grad":
|
|
|
|
|
new_inputs[key] = merged_var
|
|
|
|
|
elif key == "Param":
|
|
|
|
|
def _get_param_block(opt_op):
|
|
|
|
|
# param is already created on global program
|
|
|
|
|
param_block = None
|
|
|
|
|
for p in self.param_grad_ep_mapping[endpoint]["params"]:
|
|
|
|
|
if same_or_split_var(p.name, opt_op.input(key)[0]):
|
|
|
|
|
if same_or_split_var(p.name, opt_op.input("Param")[0]):
|
|
|
|
|
param_block = p
|
|
|
|
|
break
|
|
|
|
|
return param_block
|
|
|
|
|
|
|
|
|
|
for key in opt_op.input_names:
|
|
|
|
|
if key == "Grad":
|
|
|
|
|
new_inputs[key] = merged_var
|
|
|
|
|
# For RMSProp optimizer
|
|
|
|
|
elif key == "Moment" or key == "MeanSquare":
|
|
|
|
|
param_block = _get_param_block(opt_op)
|
|
|
|
|
if not param_block:
|
|
|
|
|
return
|
|
|
|
|
moment_var = origin_program.global_block().vars[opt_op.input(
|
|
|
|
|
key)[0]]
|
|
|
|
|
tmpvar = pserver_block.create_var(
|
|
|
|
|
name=moment_var.name,
|
|
|
|
|
persistable=moment_var.persistable,
|
|
|
|
|
dtype=moment_var.dtype,
|
|
|
|
|
# change to use same shape as param
|
|
|
|
|
# TODO(typhoonzero): didn't append .block in the var name,
|
|
|
|
|
# may affect checkpoint saving? Need to verify.
|
|
|
|
|
shape=param_block.shape)
|
|
|
|
|
new_inputs[key] = tmpvar
|
|
|
|
|
elif key == "Param":
|
|
|
|
|
param_block = _get_param_block(opt_op)
|
|
|
|
|
if not param_block:
|
|
|
|
|
return
|
|
|
|
|
tmpvar = pserver_block.create_var(
|
|
|
|
@ -1219,7 +1240,7 @@ class DistributeTranspiler(object):
|
|
|
|
|
|
|
|
|
|
for key in opt_op.input_names:
|
|
|
|
|
new_shape = None
|
|
|
|
|
if key in ["Param", "Grad", "LearningRate"]:
|
|
|
|
|
if key in ["Param", "Grad", "LearningRate", "Moment", "MeanSquare"]:
|
|
|
|
|
continue
|
|
|
|
|
var = self.origin_program.global_block().vars[opt_op.input(key)[0]]
|
|
|
|
|
# update accumulator variable shape
|
|
|
|
|