|
|
|
@ -27,8 +27,9 @@ from paddle.fluid.dygraph.dygraph_to_static.ast_transformer import convert_to_st
|
|
|
|
|
from paddle.fluid.dygraph.dygraph_to_static.ast_transformer import DygraphToStaticAst
|
|
|
|
|
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
|
|
|
|
|
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import data_layer_not_check
|
|
|
|
|
from paddle.fluid.framework import in_dygraph_mode
|
|
|
|
|
from paddle.fluid.data_feeder import check_type
|
|
|
|
|
from paddle.fluid.framework import in_dygraph_mode
|
|
|
|
|
from paddle.fluid.layers.utils import map_structure
|
|
|
|
|
|
|
|
|
|
__all__ = ['ProgramTranslator', 'convert_function_with_cache']
|
|
|
|
|
|
|
|
|
@ -403,10 +404,11 @@ class ProgramTranslator(object):
|
|
|
|
|
if not program_cache.in_build_process:
|
|
|
|
|
outputs = self._run(*args, **kwargs)
|
|
|
|
|
with guard():
|
|
|
|
|
outputs = map_structure(to_variable, outputs)
|
|
|
|
|
if len(outputs) == 1:
|
|
|
|
|
outputs = to_variable(outputs[0])
|
|
|
|
|
outputs = outputs[0]
|
|
|
|
|
else:
|
|
|
|
|
outputs = tuple(to_variable(x) for x in outputs)
|
|
|
|
|
outputs = tuple(outputs)
|
|
|
|
|
return outputs
|
|
|
|
|
|
|
|
|
|
def get_func(self, dygraph_func):
|
|
|
|
|