|
|
|
@ -13,32 +13,38 @@
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
|
|
|
|
from __future__ import print_function
|
|
|
|
|
import gast
|
|
|
|
|
|
|
|
|
|
import collections
|
|
|
|
|
import inspect
|
|
|
|
|
import warnings
|
|
|
|
|
import textwrap
|
|
|
|
|
import threading
|
|
|
|
|
import collections
|
|
|
|
|
import warnings
|
|
|
|
|
|
|
|
|
|
import gast
|
|
|
|
|
import numpy as np
|
|
|
|
|
from paddle.fluid import core, scope_guard
|
|
|
|
|
from paddle.fluid import framework
|
|
|
|
|
from paddle.fluid import core
|
|
|
|
|
from paddle.fluid import executor
|
|
|
|
|
from paddle.fluid import framework
|
|
|
|
|
from paddle.fluid import scope_guard
|
|
|
|
|
from paddle.fluid import unique_name
|
|
|
|
|
from paddle.fluid.data_feeder import check_type
|
|
|
|
|
from paddle.fluid.dygraph import layers
|
|
|
|
|
from paddle.fluid.layers.utils import flatten
|
|
|
|
|
from paddle.fluid.layers.utils import pack_sequence_as
|
|
|
|
|
from paddle.fluid.dygraph.base import param_guard
|
|
|
|
|
from paddle.fluid.dygraph.base import switch_to_static_graph
|
|
|
|
|
from paddle.fluid.dygraph.dygraph_to_static.ast_transformer import DygraphToStaticAst
|
|
|
|
|
from paddle.fluid.dygraph.dygraph_to_static.error import ERROR_DATA
|
|
|
|
|
from paddle.fluid.dygraph.dygraph_to_static.error import attach_error_data
|
|
|
|
|
from paddle.fluid.dygraph.dygraph_to_static.origin_info import attach_origin_info
|
|
|
|
|
from paddle.fluid.dygraph.dygraph_to_static.origin_info import create_and_update_origin_info_map
|
|
|
|
|
from paddle.fluid.dygraph.dygraph_to_static.origin_info import update_op_callstack_with_origin_info
|
|
|
|
|
from paddle.fluid.dygraph.dygraph_to_static.partial_program import partial_program_from
|
|
|
|
|
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_func
|
|
|
|
|
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
|
|
|
|
|
from paddle.fluid.dygraph.dygraph_to_static.utils import func_to_source_code
|
|
|
|
|
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_func
|
|
|
|
|
from paddle.fluid.dygraph.dygraph_to_static.utils import unwrap
|
|
|
|
|
from paddle.fluid.layers.utils import flatten
|
|
|
|
|
from paddle.fluid.layers.utils import pack_sequence_as
|
|
|
|
|
from paddle.fluid.wrapped_decorator import signature_safe_contextmanager
|
|
|
|
|
from paddle.fluid.dygraph.base import param_guard
|
|
|
|
|
from paddle.fluid.data_feeder import check_type
|
|
|
|
|
from paddle.fluid.dygraph.dygraph_to_static.partial_program import partial_program_from
|
|
|
|
|
from paddle.fluid.dygraph.dygraph_to_static.origin_info import attach_origin_info, create_and_update_origin_info_map
|
|
|
|
|
from paddle.fluid.dygraph.dygraph_to_static.origin_info import update_op_callstack_with_origin_info
|
|
|
|
|
from paddle.fluid.dygraph.dygraph_to_static.error import attach_error_data, ERROR_DATA
|
|
|
|
|
|
|
|
|
|
__all__ = ['ProgramTranslator', 'convert_to_static']
|
|
|
|
|
|
|
|
|
@ -89,7 +95,7 @@ class FunctionCache(object):
|
|
|
|
|
"""
|
|
|
|
|
# Note: In Python2, it will raise OSError when inspect function
|
|
|
|
|
# with decorator directly and function.__wrapped__ holds the actual function.
|
|
|
|
|
func = getattr(func, '__wrapped__', func)
|
|
|
|
|
func = unwrap(func)
|
|
|
|
|
source_code = func_to_source_code(func)
|
|
|
|
|
|
|
|
|
|
# TODO(liym27):
|
|
|
|
@ -669,7 +675,9 @@ class ProgramTranslator(object):
|
|
|
|
|
dygraph_func
|
|
|
|
|
), "Input dygraph_func is not a callable in ProgramTranslator.get_code"
|
|
|
|
|
# Gets AST from dygraph function
|
|
|
|
|
raw_code = inspect.getsource(dygraph_func)
|
|
|
|
|
|
|
|
|
|
unwrap_func = unwrap(dygraph_func)
|
|
|
|
|
raw_code = inspect.getsource(unwrap_func)
|
|
|
|
|
code = textwrap.dedent(raw_code)
|
|
|
|
|
root = gast.parse(code)
|
|
|
|
|
|
|
|
|
|