|
|
|
@ -1284,20 +1284,8 @@ class DistributeTranspiler(object):
|
|
|
|
|
# If one op's input is another op's output or
|
|
|
|
|
# one op's output is another op's input, we say
|
|
|
|
|
# the two operator is connected.
|
|
|
|
|
def _append_inname(varname_list):
|
|
|
|
|
op_input_names = []
|
|
|
|
|
for in_name in varname_list:
|
|
|
|
|
op_input_names.append(in_name)
|
|
|
|
|
return op_input_names
|
|
|
|
|
|
|
|
|
|
op1_input_names = _append_inname(op1.desc.input_arg_names())
|
|
|
|
|
op1_output_names = op1.desc.output_arg_names()
|
|
|
|
|
|
|
|
|
|
op2_input_names = _append_inname(op2.desc.input_arg_names())
|
|
|
|
|
op2_output_names = op2.desc.output_arg_names()
|
|
|
|
|
|
|
|
|
|
if set(op1_output_names) & set(op2_input_names) or \
|
|
|
|
|
set(op1_input_names) & set(op2_output_names):
|
|
|
|
|
if set(op1.desc.output_arg_names()) & set(op2.desc.input_arg_names()) or \
|
|
|
|
|
set(op1.desc.input_arg_names()) & set(op2.desc.output_arg_names()):
|
|
|
|
|
return True
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|