@@ -377,11 +377,6 @@ class DistributeTranspiler(object):
 
         # append it into the sub program.
         global_ops = []
-        # HACK: optimization global ops only used to scale beta1 and beta2
-        # replace it with dependency engine.
-        for op in self.optimize_ops:
-            if self._is_adam_connected_op(op):
-                global_ops.append(op)
 
         def __append_optimize_op__(op, block, grad_to_block_id, merged_var,
                                    lr_ops):
@@ -1289,22 +1284,16 @@ class DistributeTranspiler(object):
         # If one op's input is another op's output or
         # one op's output is another op's input, we say
         # the two operator is connected.
-        def _append_inname_remove_beta(varname_list):
+        def _append_inname(varname_list):
             op_input_names = []
             for in_name in varname_list:
-                # HACK: remove beta1 and beta2 to avoid let all
-                # ops connected.
-                if in_name.startswith("beta2_pow_acc") or \
-                        in_name.startswith("beta1_pow_acc"):
-                    continue
-                else:
-                    op_input_names.append(in_name)
+                op_input_names.append(in_name)
             return op_input_names
 
-        op1_input_names = _append_inname_remove_beta(op1.desc.input_arg_names())
+        op1_input_names = _append_inname(op1.desc.input_arg_names())
         op1_output_names = op1.desc.output_arg_names()
 
-        op2_input_names = _append_inname_remove_beta(op2.desc.input_arg_names())
+        op2_input_names = _append_inname(op2.desc.input_arg_names())
         op2_output_names = op2.desc.output_arg_names()
 
         if set(op1_output_names) & set(op2_input_names) or \
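
The helper above feeds a plain set-intersection test over variable names. A minimal, self-contained sketch of that idea, using a hypothetical FakeOp stub instead of the real operator descriptors (which expose names through op.desc.input_arg_names() and op.desc.output_arg_names()); the variable names are made up for illustration:

# Illustrative sketch only; FakeOp is a hypothetical stand-in for a real op.
class FakeOp(object):
    def __init__(self, inputs, outputs):
        self.input_arg_names = inputs
        self.output_arg_names = outputs


def is_connected(op1, op2):
    # Two ops are "connected" if one consumes a variable the other produces.
    return bool(set(op1.output_arg_names) & set(op2.input_arg_names) or
                set(op1.input_arg_names) & set(op2.output_arg_names))


adam = FakeOp(inputs=["fc_0.w_0@GRAD", "beta1_pow_acc_0"],
              outputs=["fc_0.w_0"])
scale_beta1 = FakeOp(inputs=["beta1_pow_acc_0"], outputs=["beta1_pow_acc_0"])

print(is_connected(adam, scale_beta1))  # True: they share beta1_pow_acc_0

With the removed _append_inname_remove_beta filter, the shared beta1_pow_acc_0 name was dropped from the input sets, so a pair like this would not be reported as connected; once the filter is gone, the plain intersection appears to cover such scale ops without a separate _is_adam_connected_op special case.
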
@@ -1413,7 +1402,7 @@ class DistributeTranspiler(object):
 
     def _get_optimize_pass(self):
         """
-        Get optimizer operators, paramters and gradients from origin_program
+        Get optimizer operators, parameters and gradients from origin_program
         Returns:
             opt_ops (list): optimize operators.
             params_grads (dict): paramter->gradient.
@@ -1436,20 +1425,6 @@ class DistributeTranspiler(object):
                             origin_var_dict[param_name],
                             origin_var_dict[input_name]
                         ])
-            elif self._is_adam_connected_op(op):
-                opt_ops.append(op)
             else:
                 pass
         return opt_ops, params_grads
-
-    def _is_adam_connected_op(self, op):
-        """
-        A hack function to determinate whether the input operator
-        is connected to optimize operator.
-        """
-        if op.type == "scale":
-            for in_name in op.input_arg_names:
-                if in_name.startswith("beta1_pow_acc") or \
-                        in_name.startswith("beta2_pow_acc"):
-                    return True
-        return False
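
Taken together, the hunks drop both special cases for Adam's beta-power scale ops: the global_ops collection loop and the _is_adam_connected_op predicate. As a sanity check of the idea that shared variable names alone are enough to keep such ops grouped with the optimizer op, here is an illustrative sketch that clusters transitively connected ops; the op and variable names are invented, and the union-find grouping is this sketch's own device, not the transpiler's:

# Illustrative sketch: group ops into clusters of transitively connected ops,
# using only shared variable names (the same criterion as _append_inname).
from collections import namedtuple

Op = namedtuple("Op", ["name", "inputs", "outputs"])  # hypothetical stand-in

ops = [
    Op("adam", ["w@GRAD", "beta1_pow_acc_0", "beta2_pow_acc_0"], ["w"]),
    Op("scale_beta1", ["beta1_pow_acc_0"], ["beta1_pow_acc_0"]),
    Op("scale_beta2", ["beta2_pow_acc_0"], ["beta2_pow_acc_0"]),
    Op("unrelated", ["x"], ["y"]),
]

def connected(a, b):
    return bool(set(a.outputs) & set(b.inputs) or
                set(a.inputs) & set(b.outputs))

# Simple union-find over the op list.
parent = list(range(len(ops)))

def find(i):
    while parent[i] != i:
        parent[i] = parent[parent[i]]
        i = parent[i]
    return i

for i in range(len(ops)):
    for j in range(i + 1, len(ops)):
        if connected(ops[i], ops[j]):
            parent[find(i)] = find(j)

clusters = {}
for i, op in enumerate(ops):
    clusters.setdefault(find(i), []).append(op.name)

print(list(clusters.values()))
# e.g. [['adam', 'scale_beta1', 'scale_beta2'], ['unrelated']]

Under these assumptions the beta-scaling ops land in the adam cluster through ordinary input/output overlap, while an unrelated op stays in its own group.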