|
|
|
@ -353,16 +353,6 @@ class DistributeTranspiler:
|
|
|
|
|
pass
|
|
|
|
|
return orig_shape
|
|
|
|
|
|
|
|
|
|
# def _fetch_var_names(self, param_dict):
|
|
|
|
|
# res = []
|
|
|
|
|
# if not param_dict:
|
|
|
|
|
# return res
|
|
|
|
|
# for _, values in param_dict.iteritems():
|
|
|
|
|
# if not isinstance(values, list):
|
|
|
|
|
# values = [values]
|
|
|
|
|
# res += [v.name for v in values]
|
|
|
|
|
# return res
|
|
|
|
|
|
|
|
|
|
def _append_pserver_ops(self, optimize_block, opt_op, endpoint):
|
|
|
|
|
program = optimize_block.program
|
|
|
|
|
pserver_block = program.global_block()
|
|
|
|
@ -483,13 +473,9 @@ class DistributeTranspiler:
|
|
|
|
|
# If one op's input is another op's output or
|
|
|
|
|
# one op's output is another op's input, we say
|
|
|
|
|
# the two operator is connected.
|
|
|
|
|
# op1_input_names = self._fetch_var_names(op1.inputs)
|
|
|
|
|
# op1_output_names = self._fetch_var_names(op1.outputs)
|
|
|
|
|
op1_input_names = op1.desc.input_arg_names()
|
|
|
|
|
op1_output_names = op1.desc.output_arg_names()
|
|
|
|
|
|
|
|
|
|
# op2_input_names = self._fetch_var_names(op2.inputs)
|
|
|
|
|
# op2_output_names = self._fetch_var_names(op2.outputs)
|
|
|
|
|
op2_input_names = op2.desc.input_arg_names()
|
|
|
|
|
op2_output_names = op2.desc.output_arg_names()
|
|
|
|
|
|
|
|
|
|