[Dy2stat] Add Test and Example Code for Different Access to ProgramTranslator and Fix Related Bug (#23958)

To prepare for publishing APIs, I added tests for that we can access dy2stat through:

@fluid.dygraph.declarative
@fluid.dygraph.jit.declarative
fluid.dygraph.ProgramTranslator()
fluid.dygraph.dygraph_to_static.ProgramTranslator()
fluid.dygraph.dygraph_to_static.program_translator.ProgramTranslator()

It surprised me that we had bugs on those different usages. I have fixed them.

I also added example codes for these new APIs

This PR also pulls my current PR https://github.com/PaddlePaddle/Paddle/pull/23880, so the PR history is long. For reviewer information, you could review this PR after https://github.com/PaddlePaddle/Paddle/pull/23880 is merged
revert-22778-infer_var_type
Huihuang Zheng 5 years ago committed by GitHub
parent 2291634c5c
commit 45e48c3c32
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -47,6 +47,9 @@ from .jit import *
from . import static_runner
from .static_runner import StaticModelRunner
from . import dygraph_to_static
from .dygraph_to_static import ProgramTranslator
__all__ = []
__all__ += layers.__all__
__all__ += base.__all__
@ -57,3 +60,4 @@ __all__ += checkpoint.__all__
__all__ += learning_rate_scheduler.__all__
__all__ += backward_strategy.__all__
__all__ += jit.__all__
__all__ += ['ProgramTranslator']

@ -37,9 +37,10 @@ from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrappe
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import NodeVarType
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_func
from paddle.fluid.dygraph.dygraph_to_static.utils import dygraph_class_to_static_api
from paddle.fluid.dygraph.dygraph_to_static.utils import get_attribute_full_name
from paddle.fluid.dygraph.dygraph_to_static.utils import is_paddle_api, is_dygraph_api, is_to_variable
from paddle.fluid.dygraph.dygraph_to_static.utils import to_assign_node, to_static_ast, update_args_of_func
from paddle.fluid.dygraph.dygraph_to_static.utils import dygraph_class_to_static_api
__all__ = ['DygraphToStaticAst', 'convert_to_static']
@ -96,9 +97,24 @@ class DygraphToStaticAst(gast.NodeTransformer):
self.generic_visit(node)
# Remove the decorated name of dygraph_to_static
if hasattr(node, 'decorator_list'):
decorator_list = [
d for d in node.decorator_list if d.id not in DECORATOR_NAMES
]
decorator_list = []
for d in node.decorator_list:
if isinstance(d, gast.Name) and d.id not in DECORATOR_NAMES:
raise NotImplementedError(
"ProgramTranslator hasn't implemented multiple decorators. Please remove "
+ d.id + " in " + self.decorate_func_name)
if isinstance(d, gast.Attribute):
full_attribute_name = get_attribute_full_name(d)
has_translate_decorator = False
for deco in DECORATOR_NAMES:
if deco in full_attribute_name:
has_translate_decorator = True
break
if not has_translate_decorator:
raise NotImplementedError(
"ProgramTranslator hasn't implemented multiple decorators. Please remove "
+ full_attribute_name + " in " +
self.decorate_func_name)
node.decorator_list = decorator_list
return node

@ -107,7 +107,7 @@ def _dygraph_to_static_func_(dygraph_func):
if in_dygraph_mode() or not program_translator.enable_declarative:
logger.info(
"The decorator 'dygraph_to_static_func' doesn't work in "
"dygraph mode or set enable_declarative_function to False. "
"dygraph mode or set ProgramTranslator.enable to False. "
"We will just return dygraph output.")
return dygraph_func(*args, **kwargs)
static_func = program_translator.get_func(dygraph_func)
@ -159,7 +159,7 @@ def _declarative_(dygraph_func):
if in_dygraph_mode() or not program_translator.enable_declarative:
logger.info(
"The decorator 'declarative' doesn't work in dygraph "
"mode or set enable_declarative_function to False. We will "
"mode or set ProgramTranslator.enable to False. We will "
"just return dygraph output.")
return dygraph_func(*args, **kwargs)
program_translator = ProgramTranslator()

@ -0,0 +1,93 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import paddle.fluid as fluid
import unittest
from paddle.fluid.dygraph.jit import declarative
@fluid.dygraph.declarative
def dygraph_decorated_func(x):
x = fluid.dygraph.to_variable(x)
if fluid.layers.mean(x) > 0:
x_v = x - 1
else:
x_v = x + 1
return x_v
@fluid.dygraph.jit.declarative
def jit_decorated_func(x):
x = fluid.dygraph.to_variable(x)
if fluid.layers.mean(x) > 0:
x_v = x - 1
else:
x_v = x + 1
return x_v
@fluid.dygraph.declarative
def decorated_call_decorated(x):
return jit_decorated_func(x)
class DoubleDecorated(object):
@classmethod
@declarative
def double_decorated_func1(self, x):
return dygraph_decorated_func(x)
@classmethod
@fluid.dygraph.declarative
def double_decorated_func2(self, x):
return jit_decorated_func(x)
class TestFullNameDecorator(unittest.TestCase):
def test_run_success(self):
x = np.ones([1, 2]).astype("float32")
answer = np.zeros([1, 2]).astype("float32")
with fluid.program_guard(fluid.Program(), fluid.Program()):
self.assertTrue(
np.allclose(dygraph_decorated_func(x).numpy(), answer))
with fluid.program_guard(fluid.Program(), fluid.Program()):
self.assertTrue(np.allclose(jit_decorated_func(x).numpy(), answer))
with fluid.program_guard(fluid.Program(), fluid.Program()):
self.assertTrue(
np.allclose(decorated_call_decorated(x).numpy(), answer))
with fluid.program_guard(fluid.Program(), fluid.Program()):
with self.assertRaises(NotImplementedError):
DoubleDecorated().double_decorated_func1(x)
with fluid.program_guard(fluid.Program(), fluid.Program()):
with self.assertRaises(NotImplementedError):
DoubleDecorated().double_decorated_func2(x)
class TestImportProgramTranslator(unittest.TestCase):
def test_diff_pkg_same_cls(self):
dygraph_prog_trans = fluid.dygraph.ProgramTranslator()
dy_to_stat_prog_trans = fluid.dygraph.dygraph_to_static.ProgramTranslator(
)
full_pkg_prog_trans = fluid.dygraph.dygraph_to_static.program_translator.ProgramTranslator(
)
self.assertEqual(dygraph_prog_trans, dy_to_stat_prog_trans)
self.assertEqual(dygraph_prog_trans, full_pkg_prog_trans)
if __name__ == '__main__':
unittest.main()

@ -123,11 +123,11 @@ class TestEnableDeclarative(unittest.TestCase):
program_translator = ProgramTranslator()
with fluid.program_guard(fluid.Program(), fluid.Program()):
program_translator.enable_declarative_function(True)
program_translator.enable(True)
static_output = program_translator.get_output(simple_func, x,
weight)
program_translator.enable_declarative_function(False)
program_translator.enable(False)
with fluid.dygraph.guard():
dygraph_output = program_translator.get_output(simple_func, x,
weight)
@ -141,13 +141,13 @@ class TestEnableDeclarative(unittest.TestCase):
program_translator = ProgramTranslator()
with fluid.program_guard(fluid.Program(), fluid.Program()):
program_translator.enable_declarative_function(True)
program_translator.enable(True)
static_func = program_translator.get_func(simple_func)
self.assertTrue(callable(static_func))
static_output = static_func(x, weight)
self.assertTrue(isinstance(static_output, fluid.Variable))
program_translator.enable_declarative_function(False)
program_translator.enable(False)
with fluid.dygraph.guard():
dygraph_func = program_translator.get_func(simple_func)
self.assertTrue(callable(dygraph_func))
@ -160,7 +160,7 @@ class TestEnableDeclarative(unittest.TestCase):
program_translator = ProgramTranslator()
with fluid.program_guard(fluid.Program(), fluid.Program()):
program_translator.enable_declarative_function(True)
program_translator.enable(True)
static_output = program_translator.get_program(simple_func, x,
weight)
self.assertTrue(isinstance(static_output, tuple))
@ -168,7 +168,7 @@ class TestEnableDeclarative(unittest.TestCase):
self.assertTrue(isinstance(static_output[0], fluid.Program))
self.assertTrue(isinstance(static_output[1], fluid.Program))
program_translator.enable_declarative_function(False)
program_translator.enable(False)
with fluid.dygraph.guard():
dygraph_output = program_translator.get_program(simple_func, x,
weight)
@ -180,10 +180,10 @@ class TestEnableDeclarative(unittest.TestCase):
program_translator = ProgramTranslator()
with fluid.program_guard(fluid.Program(), fluid.Program()):
program_translator.enable_declarative_function(True)
program_translator.enable(True)
static_output = decorated_simple_func(x, weight)
program_translator.enable_declarative_function(False)
program_translator.enable(False)
with fluid.dygraph.guard():
dygraph_output = decorated_simple_func(x, weight)
self.assertTrue(

@ -40,7 +40,7 @@ class SimpleFcLayer(fluid.dygraph.Layer):
y = self._linear(x)
z = self._linear(y)
out = fluid.layers.mean(z)
return out
return out, y
class TestDyToStaticSaveInferenceModel(unittest.TestCase):
@ -69,6 +69,15 @@ class TestDyToStaticSaveInferenceModel(unittest.TestCase):
])
self.assertEqual(saved_var_names, expected_persistable_vars)
infer_model_dir = "./test_dy2stat_save_inference_model_with_fetch"
ProgramTranslator.get_instance().save_inference_model(
infer_model_dir, fetch=[0])
saved_var_names = set([
filename for filename in os.listdir(infer_model_dir)
if filename != '__model__'
])
self.assertEqual(saved_var_names, expected_persistable_vars)
if __name__ == '__main__':
unittest.main()

Loading…
Cancel
Save