@@ -401,11 +401,8 @@ class DistributeTranspiler:
         # HACK: optimization global ops only used to scale beta1 and beta2
         # replace it with dependency engine.
         for op in self.optimize_ops:
-            if op.type == "scale":
-                for in_name in op.input_arg_names:
-                    if in_name.startswith("beta1_pow_acc") or \
-                            in_name.startswith("beta2_pow_acc"):
-                        global_ops.append(op)
+            if self._is_adam_connected_op(op):
+                global_ops.append(op)
 
         def __append_optimize_op__(op, block, grad_to_block_id):
             if self._is_opt_op(op):
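
Note: this hunk (and the one below) replaces the inline beta1/beta2
accumulator check with the new _is_adam_connected_op helper. A minimal,
self-contained sketch of that predicate follows; the StubOp class and the
sample argument names are hypothetical stand-ins for Paddle's framework
Operator, for illustration only:

    class StubOp:
        """Hypothetical stand-in for a Paddle framework Operator."""

        def __init__(self, type, input_arg_names):
            self.type = type
            self.input_arg_names = input_arg_names


    def is_adam_connected_op(op):
        # Mirrors the new helper: Adam's beta1^t / beta2^t accumulators
        # are updated by standalone "scale" ops, so a scale op reading a
        # *_pow_acc variable belongs with the optimizer even though it
        # carries no Param/Grad inputs.
        if op.type == "scale":
            for in_name in op.input_arg_names:
                if in_name.startswith("beta1_pow_acc") or \
                        in_name.startswith("beta2_pow_acc"):
                    return True
        return False


    print(is_adam_connected_op(StubOp("scale", ["beta1_pow_acc_0"])))  # True
    print(is_adam_connected_op(StubOp("scale", ["learning_rate_0"])))  # False
    print(is_adam_connected_op(StubOp("adam", ["beta1_pow_acc_0"])))   # False

Factoring the check into one method keeps the two call sites (global-op
collection above, opt-op classification below) from drifting apart.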
@@ -1152,13 +1149,20 @@ class DistributeTranspiler:
                         op.input("Param")[0]),
                     self.origin_program.global_block().var(
                         op.input("Grad")[0])))
-            elif op.type == "scale":
-                # for adam optimize op
-                for in_name in op.input_arg_names:
-                    if in_name.startswith("beta1_pow_acc") or \
-                            in_name.startswith("beta2_pow_acc"):
-                        opt_ops.append(op)
-                        break
+            elif self._is_adam_connected_op(op):
+                opt_ops.append(op)
             else:
                 pass
         return opt_ops, params_grads
+
+    def _is_adam_connected_op(self, op):
+        """
+        A hack function to determine whether the input operator
+        is connected to an optimize operator.
+        """
+        if op.type == "scale":
+            for in_name in op.input_arg_names:
+                if in_name.startswith("beta1_pow_acc") or \
+                        in_name.startswith("beta2_pow_acc"):
+                    return True
+        return False
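
For context on why a plain "scale" op is treated as optimizer state here:
per the HACK comment in the first hunk, Adam keeps the running powers
beta1^t and beta2^t in beta1_pow_acc / beta2_pow_acc variables and scales
them once per step, since Adam's bias correction divides the moment
estimates by (1 - beta^t). A framework-free sketch of that update, assuming
Adam's common default betas (hypothetical values, not taken from this
patch):

    beta1, beta2 = 0.9, 0.999          # assumed defaults, configurable in Adam
    beta1_pow_acc, beta2_pow_acc = beta1, beta2

    for step in range(1, 4):
        # Bias correction at each step uses (1 - beta^t); the matched
        # "scale" ops keep these powers current on the pserver side.
        print(step, 1.0 - beta1_pow_acc, 1.0 - beta2_pow_acc)
        beta1_pow_acc *= beta1         # what the beta1_pow_acc scale op does
        beta2_pow_acc *= beta2         # what the beta2_pow_acc scale op does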