|
|
|
@ -1706,13 +1706,27 @@ to transpile() call.")
|
|
|
|
|
outputs=outputs,
|
|
|
|
|
attrs=opt_op.all_attrs())
|
|
|
|
|
|
|
|
|
|
def _is_splited_grad_var(self, var, var_dict):
|
|
|
|
|
def _get_pserver_grad_param_var(self, var, var_dict):
|
|
|
|
|
"""
|
|
|
|
|
Return pserver side grad/param variable, return None
|
|
|
|
|
if the variable is not grad/param, e.g.
|
|
|
|
|
|
|
|
|
|
a@GRAD -> a@GRAD.block0
|
|
|
|
|
a@GRAD -> a@GRAD (a is not splited)
|
|
|
|
|
fc_0.w_0 -> fc_0.w_0.block_0
|
|
|
|
|
fc_0.w_0 -> fc_0.w_0 (weight is not splited)
|
|
|
|
|
_generated_var_123 -> None
|
|
|
|
|
"""
|
|
|
|
|
grad_block = None
|
|
|
|
|
for _, g in six.iteritems(var_dict):
|
|
|
|
|
if self._orig_varname(g.name) == self._orig_varname(var.name):
|
|
|
|
|
# skip per trainer vars
|
|
|
|
|
if g.name.find(".trainer_") == -1:
|
|
|
|
|
grad_block = g
|
|
|
|
|
break
|
|
|
|
|
# only param or grads have splited blocks
|
|
|
|
|
if self._orig_varname(g.name) in self.grad_name_to_param_name or\
|
|
|
|
|
self._orig_varname(g.name) in self.param_name_to_grad_name:
|
|
|
|
|
grad_block = g
|
|
|
|
|
break
|
|
|
|
|
return grad_block
|
|
|
|
|
|
|
|
|
|
def _clone_lr_op(self, program, block, op):
|
|
|
|
@ -1745,32 +1759,38 @@ to transpile() call.")
|
|
|
|
|
for key, varlist in six.iteritems(inputs):
|
|
|
|
|
if not isinstance(varlist, list):
|
|
|
|
|
varlist = [varlist]
|
|
|
|
|
for var in varlist:
|
|
|
|
|
# for ops like clipping and weight decay, get the splited var
|
|
|
|
|
for i in range(len(varlist)):
|
|
|
|
|
var = varlist[i]
|
|
|
|
|
# for ops like clipping and weight decay, get the splited var (xxx.block0)
|
|
|
|
|
# for inputs/outputs
|
|
|
|
|
grad_block = self._is_splited_grad_var(
|
|
|
|
|
grad_block = self._get_pserver_grad_param_var(
|
|
|
|
|
var, program.global_block().vars)
|
|
|
|
|
if grad_block:
|
|
|
|
|
inputs[key] = grad_block
|
|
|
|
|
varlist[i] = grad_block
|
|
|
|
|
elif var.name not in program.global_block().vars:
|
|
|
|
|
program.global_block().create_var(
|
|
|
|
|
name=var.name,
|
|
|
|
|
persistable=var.persistable,
|
|
|
|
|
dtype=var.dtype,
|
|
|
|
|
shape=var.shape)
|
|
|
|
|
tmpvar = program.global_block()._clone_variable(var)
|
|
|
|
|
varlist[i] = tmpvar
|
|
|
|
|
else:
|
|
|
|
|
varlist[i] = program.global_block().vars[var.name]
|
|
|
|
|
inputs[key] = varlist
|
|
|
|
|
|
|
|
|
|
outputs = self._get_output_map_from_op(
|
|
|
|
|
self.origin_program.global_block().vars, opt_op)
|
|
|
|
|
for key, varlist in six.iteritems(outputs):
|
|
|
|
|
if not isinstance(varlist, list):
|
|
|
|
|
varlist = [varlist]
|
|
|
|
|
for var in varlist:
|
|
|
|
|
grad_block = self._is_splited_grad_var(
|
|
|
|
|
for i in range(len(varlist)):
|
|
|
|
|
var = varlist[i]
|
|
|
|
|
grad_block = self._get_pserver_grad_param_var(
|
|
|
|
|
var, program.global_block().vars)
|
|
|
|
|
if grad_block:
|
|
|
|
|
outputs[key] = grad_block
|
|
|
|
|
varlist[i] = grad_block
|
|
|
|
|
elif var.name not in program.global_block().vars:
|
|
|
|
|
program.global_block()._clone_variable(var)
|
|
|
|
|
tmpvar = program.global_block()._clone_variable(var)
|
|
|
|
|
varlist[i] = tmpvar
|
|
|
|
|
else:
|
|
|
|
|
varlist[i] = program.global_block().vars[var.name]
|
|
|
|
|
outputs[key] = varlist
|
|
|
|
|
|
|
|
|
|
return optimize_block.append_op(
|
|
|
|
|
type=opt_op.type,
|
|
|
|
|