|
|
|
@ -307,15 +307,57 @@ class DistributeTranspiler:
|
|
|
|
|
# Iterate through the ops, and if an op and the optimize ops
|
|
|
|
|
# which located on current pserver are in one set, then
|
|
|
|
|
# append it into the sub program.
|
|
|
|
|
for _, op in enumerate(self.optimize_ops):
|
|
|
|
|
for _, opt_op in enumerate(opt_op_on_pserver):
|
|
|
|
|
if ufind.is_connected(op, opt_op):
|
|
|
|
|
if self._is_opt_op(op):
|
|
|
|
|
self._append_pserver_ops(optimize_block, op, endpoint,
|
|
|
|
|
default_main_program())
|
|
|
|
|
else:
|
|
|
|
|
self._append_pserver_non_opt_ops(optimize_block, op)
|
|
|
|
|
break
|
|
|
|
|
|
|
|
|
|
# We try to put optimization program run parallelly, assume
|
|
|
|
|
# optimization program always looks like:
|
|
|
|
|
#
|
|
|
|
|
# prevop -> prevop -> opt op -> following op -> following op; ->
|
|
|
|
|
# prevop -> prevop -> opt op -> following op -> following op; ->
|
|
|
|
|
# global op -> global op
|
|
|
|
|
#
|
|
|
|
|
# we put operators that can run parallelly to many program blocks.
|
|
|
|
|
# in above example, we seperate ops by the ";". Global ops must run
|
|
|
|
|
# after all the optimize ops finished.
|
|
|
|
|
|
|
|
|
|
global_ops = []
|
|
|
|
|
# HACK: optimization global ops only used to scale beta1 and beta2
|
|
|
|
|
# replace it with dependency engine.
|
|
|
|
|
for op in self.optimize_ops:
|
|
|
|
|
if op.type == "scale":
|
|
|
|
|
for in_name in op.input_arg_names:
|
|
|
|
|
if in_name.startswith("beta1_pow_acc") or\
|
|
|
|
|
in_name.startswith("beta2_pow_acc"):
|
|
|
|
|
global_ops.append(op)
|
|
|
|
|
|
|
|
|
|
def __append_optimize_op__(op, block):
|
|
|
|
|
if self._is_opt_op(op):
|
|
|
|
|
self._append_pserver_ops(block, op, endpoint,
|
|
|
|
|
default_main_program())
|
|
|
|
|
else:
|
|
|
|
|
self._append_pserver_non_opt_ops(block, op)
|
|
|
|
|
|
|
|
|
|
# append op to the current block
|
|
|
|
|
per_opt_block = optimize_block
|
|
|
|
|
for _, opt_op in enumerate(opt_op_on_pserver):
|
|
|
|
|
for _, op in enumerate(self.optimize_ops):
|
|
|
|
|
# optimizer is connected to itself
|
|
|
|
|
if ufind.is_connected(op, opt_op) and \
|
|
|
|
|
op not in global_ops:
|
|
|
|
|
__append_optimize_op__(op, per_opt_block)
|
|
|
|
|
per_opt_block = pserver_program.create_block(0)
|
|
|
|
|
|
|
|
|
|
# append global ops
|
|
|
|
|
for glb_op in global_ops:
|
|
|
|
|
__append_optimize_op__(glb_op, per_opt_block)
|
|
|
|
|
|
|
|
|
|
# NOT USED: single block version:
|
|
|
|
|
#
|
|
|
|
|
# for _, op in enumerate(self.optimize_ops):
|
|
|
|
|
# for _, opt_op in enumerate(opt_op_on_pserver):
|
|
|
|
|
# if ufind.is_connected(op, opt_op):
|
|
|
|
|
# __append_optimize_op__(glb_op, optimize_block)
|
|
|
|
|
# break
|
|
|
|
|
|
|
|
|
|
# step5 append the listen_and_serv op
|
|
|
|
|
pserver_program.global_block().append_op(
|
|
|
|
|
type="listen_and_serv",
|
|
|
|
@ -660,10 +702,22 @@ class DistributeTranspiler:
|
|
|
|
|
# If one op's input is another op's output or
|
|
|
|
|
# one op's output is another op's input, we say
|
|
|
|
|
# the two operator is connected.
|
|
|
|
|
op1_input_names = op1.desc.input_arg_names()
|
|
|
|
|
def _append_inname_remove_beta(varname_list):
|
|
|
|
|
op_input_names = []
|
|
|
|
|
for in_name in varname_list:
|
|
|
|
|
# HACK: remove beta1 and beta2 to avoid let all
|
|
|
|
|
# ops connected.
|
|
|
|
|
if in_name.startswith("beta2_pow_acc") or \
|
|
|
|
|
in_name.startswith("beta1_pow_acc"):
|
|
|
|
|
continue
|
|
|
|
|
else:
|
|
|
|
|
op_input_names.append(in_name)
|
|
|
|
|
return op_input_names
|
|
|
|
|
|
|
|
|
|
op1_input_names = _append_inname_remove_beta(op1.desc.input_arg_names())
|
|
|
|
|
op1_output_names = op1.desc.output_arg_names()
|
|
|
|
|
|
|
|
|
|
op2_input_names = op2.desc.input_arg_names()
|
|
|
|
|
op2_input_names = _append_inname_remove_beta(op2.desc.input_arg_names())
|
|
|
|
|
op2_output_names = op2.desc.output_arg_names()
|
|
|
|
|
|
|
|
|
|
if set(op1_output_names) & set(op2_input_names) or \
|
|
|
|
|