|
|
@ -28,7 +28,7 @@ from paddle.fluid.data_feeder import check_type
|
|
|
|
from paddle.fluid.dygraph.base import program_desc_tracing_guard, switch_to_static_graph
|
|
|
|
from paddle.fluid.dygraph.base import program_desc_tracing_guard, switch_to_static_graph
|
|
|
|
from paddle.fluid.dygraph.dygraph_to_static import logging_utils
|
|
|
|
from paddle.fluid.dygraph.dygraph_to_static import logging_utils
|
|
|
|
from paddle.fluid.dygraph.dygraph_to_static.logging_utils import set_code_level, set_verbosity
|
|
|
|
from paddle.fluid.dygraph.dygraph_to_static.logging_utils import set_code_level, set_verbosity
|
|
|
|
from paddle.fluid.dygraph.dygraph_to_static.program_translator import ProgramTranslator, StaticLayer, unwrap_decorators
|
|
|
|
from paddle.fluid.dygraph.dygraph_to_static.program_translator import ProgramTranslator, StaticFunction, unwrap_decorators
|
|
|
|
from paddle.fluid.dygraph.io import EXTRA_VAR_INFO_FILENAME, VARIABLE_FILENAME, TranslatedLayer
|
|
|
|
from paddle.fluid.dygraph.io import EXTRA_VAR_INFO_FILENAME, VARIABLE_FILENAME, TranslatedLayer
|
|
|
|
from paddle.fluid.dygraph.layers import Layer
|
|
|
|
from paddle.fluid.dygraph.layers import Layer
|
|
|
|
from paddle.fluid.executor import Executor, scope_guard
|
|
|
|
from paddle.fluid.executor import Executor, scope_guard
|
|
|
@ -141,7 +141,7 @@ def copy_decorator_attrs(original_func, decorated_obj):
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
Args:
|
|
|
|
original_func(callable): the original decorated function.
|
|
|
|
original_func(callable): the original decorated function.
|
|
|
|
decorated_obj(StaticLayer): the target decorated StaticLayer object.
|
|
|
|
decorated_obj(StaticFunction): the target decorated StaticFunction object.
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
decorator_name = "declarative"
|
|
|
|
decorator_name = "declarative"
|
|
|
|
|
|
|
|
|
|
|
@ -198,7 +198,7 @@ def declarative(function=None, input_spec=None):
|
|
|
|
|
|
|
|
|
|
|
|
def decorated(python_func):
|
|
|
|
def decorated(python_func):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
Decorates a python function into a StaticLayer object.
|
|
|
|
Decorates a python function into a StaticFunction object.
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
# Step 1. unwrap the function if it is already decorated.
|
|
|
|
# Step 1. unwrap the function if it is already decorated.
|
|
|
|
_, python_func = unwrap_decorators(python_func)
|
|
|
|
_, python_func = unwrap_decorators(python_func)
|
|
|
@ -206,7 +206,7 @@ def declarative(function=None, input_spec=None):
|
|
|
|
# Step 2. copy some attributes from original python function.
|
|
|
|
# Step 2. copy some attributes from original python function.
|
|
|
|
static_layer = copy_decorator_attrs(
|
|
|
|
static_layer = copy_decorator_attrs(
|
|
|
|
original_func=python_func,
|
|
|
|
original_func=python_func,
|
|
|
|
decorated_obj=StaticLayer(
|
|
|
|
decorated_obj=StaticFunction(
|
|
|
|
function=python_func, input_spec=input_spec))
|
|
|
|
function=python_func, input_spec=input_spec))
|
|
|
|
|
|
|
|
|
|
|
|
return static_layer
|
|
|
|
return static_layer
|
|
|
@ -214,7 +214,7 @@ def declarative(function=None, input_spec=None):
|
|
|
|
# for usage: `declarative(foo, ...)`
|
|
|
|
# for usage: `declarative(foo, ...)`
|
|
|
|
if function is not None:
|
|
|
|
if function is not None:
|
|
|
|
if isinstance(function, Layer):
|
|
|
|
if isinstance(function, Layer):
|
|
|
|
if isinstance(function.forward, StaticLayer):
|
|
|
|
if isinstance(function.forward, StaticFunction):
|
|
|
|
class_name = function.__class__.__name__
|
|
|
|
class_name = function.__class__.__name__
|
|
|
|
logging_utils.warn(
|
|
|
|
logging_utils.warn(
|
|
|
|
"`{}.forward` has already been decorated somewhere. It will be redecorated to replace previous one.".
|
|
|
|
"`{}.forward` has already been decorated somewhere. It will be redecorated to replace previous one.".
|
|
|
@ -868,7 +868,7 @@ def save(layer, model_path, input_spec=None, config=None):
|
|
|
|
|
|
|
|
|
|
|
|
# 2. get program from Layer
|
|
|
|
# 2. get program from Layer
|
|
|
|
# TODO(chenweihang): add support for other method, not only forward
|
|
|
|
# TODO(chenweihang): add support for other method, not only forward
|
|
|
|
if isinstance(layer.forward, StaticLayer):
|
|
|
|
if isinstance(layer.forward, StaticFunction):
|
|
|
|
concrete_program = layer.forward.concrete_program
|
|
|
|
concrete_program = layer.forward.concrete_program
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
# transform in jit.save, if input_spec is incomplete, declarative will throw error
|
|
|
|
# transform in jit.save, if input_spec is incomplete, declarative will throw error
|
|
|
|