Fix bug in ProgramTranslator.get_output, convert all items into VarBase in nested list. (#24267)

revert-24314-dev/fix_err_msg
liym27 5 years ago committed by GitHub
parent 381492fca3
commit e8869a907b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -27,8 +27,9 @@ from paddle.fluid.dygraph.dygraph_to_static.ast_transformer import convert_to_st
from paddle.fluid.dygraph.dygraph_to_static.ast_transformer import DygraphToStaticAst
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import data_layer_not_check
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.data_feeder import check_type
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.layers.utils import map_structure
__all__ = ['ProgramTranslator', 'convert_function_with_cache']
@ -403,10 +404,11 @@ class ProgramTranslator(object):
if not program_cache.in_build_process:
outputs = self._run(*args, **kwargs)
with guard():
outputs = map_structure(to_variable, outputs)
if len(outputs) == 1:
outputs = to_variable(outputs[0])
outputs = outputs[0]
else:
outputs = tuple(to_variable(x) for x in outputs)
outputs = tuple(outputs)
return outputs
def get_func(self, dygraph_func):

Loading…
Cancel
Save