|
|
|
@ -53,21 +53,21 @@ def create_program_from_desc(program_desc):
|
|
|
|
|
return program
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _extract_vars(inputs, result_list):
|
|
|
|
|
def _extract_vars(inputs, result_list, err_tag='inputs'):
|
|
|
|
|
if isinstance(inputs, Variable):
|
|
|
|
|
result_list.append(inputs)
|
|
|
|
|
elif isinstance(inputs, (list, tuple)):
|
|
|
|
|
for var in inputs:
|
|
|
|
|
_extract_vars(var, result_list)
|
|
|
|
|
_extract_vars(var, result_list, err_tag)
|
|
|
|
|
else:
|
|
|
|
|
raise TypeError(
|
|
|
|
|
"The type of 'each element of inputs' in fluid.dygraph.jit.TracedLayer.trace must be fluid.Variable, but received {}.".
|
|
|
|
|
format(type(inputs)))
|
|
|
|
|
"The type of 'each element of {}' in fluid.dygraph.jit.TracedLayer.trace must be fluid.Variable, but received {}.".
|
|
|
|
|
format(err_tag, type(inputs)))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def extract_vars(inputs):
|
|
|
|
|
def extract_vars(inputs, err_tag='inputs'):
|
|
|
|
|
result_list = []
|
|
|
|
|
_extract_vars(inputs, result_list)
|
|
|
|
|
_extract_vars(inputs, result_list, err_tag)
|
|
|
|
|
return result_list
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -1032,7 +1032,7 @@ def _trace(layer,
|
|
|
|
|
outputs = [original_outputs]
|
|
|
|
|
else:
|
|
|
|
|
outputs = original_outputs
|
|
|
|
|
out_vars = [var for var in outputs]
|
|
|
|
|
out_vars = extract_vars(outputs, err_tag='outputs')
|
|
|
|
|
|
|
|
|
|
program_desc, feed_names, fetch_names, parameters = tracer.create_program_desc(
|
|
|
|
|
var_list, feed_prefix, out_vars, fetch_prefix, tmp_prefix)
|
|
|
|
|