Add static mode check on data() (#27495)

* add static check on data()

* follow comments

* fix ut
revert-27356-init_low_level_gloo
Leo Chen 4 years ago committed by GitHub
parent a5b3263782
commit 0b4bb023a7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -19,10 +19,12 @@ from paddle.fluid import core
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.data_feeder import check_dtype, check_type
from ..utils import deprecated
from paddle.fluid.framework import static_only
__all__ = ['data']
@static_only
@deprecated(since="2.0.0", update_to="paddle.static.data")
def data(name, shape, dtype='float32', lod_level=0):
"""

@ -217,7 +217,16 @@ def _dygraph_not_support_(func):
def _dygraph_only_(func):
def __impl__(*args, **kwargs):
assert in_dygraph_mode(
), "We Only support %s in dynamic mode, please call 'paddle.disable_static()' to enter dynamic mode." % func.__name__
), "We only support '%s()' in dynamic graph mode, please call 'paddle.disable_static()' to enter dynamic graph mode." % func.__name__
return func(*args, **kwargs)
return __impl__
def _static_only_(func):
def __impl__(*args, **kwargs):
assert not in_dygraph_mode(
), "We only support '%s()' in static graph mode, please call 'paddle.enable_static()' to enter static graph mode." % func.__name__
return func(*args, **kwargs)
return __impl__
@ -260,6 +269,7 @@ def deprecate_stat_dict(func):
dygraph_not_support = wrap_decorator(_dygraph_not_support_)
dygraph_only = wrap_decorator(_dygraph_only_)
static_only = wrap_decorator(_static_only_)
fake_interface_only = wrap_decorator(_fake_interface_only_)

@ -31,6 +31,7 @@ from ..unique_name import generate as unique_name
import logging
from ..data_feeder import check_dtype, check_type
from paddle.fluid.framework import static_only
__all__ = [
'data', 'read_file', 'double_buffer', 'py_reader',
@ -38,6 +39,7 @@ __all__ = [
]
@static_only
def data(name,
shape,
append_batch_size=True,

@ -99,5 +99,17 @@ class TestApiStaticDataError(unittest.TestCase):
self.assertRaises(TypeError, test_shape_type)
class TestApiErrorWithDynamicMode(unittest.TestCase):
def test_error(self):
with program_guard(Program(), Program()):
paddle.disable_static()
self.assertRaises(AssertionError, fluid.data, 'a', [2, 25])
self.assertRaises(
AssertionError, fluid.layers.data, 'b', shape=[2, 25])
self.assertRaises(
AssertionError, paddle.static.data, 'c', shape=[2, 25])
paddle.enable_static()
if __name__ == "__main__":
unittest.main()

@ -72,6 +72,7 @@ class TestDeprecatedDocorator(unittest.TestCase):
test old fluid elementwise_mul api, it should fire Warinng function,
which insert the Warinng info on top of API's doc string.
"""
paddle.enable_static()
# Initialization
x = fluid.data(name='x', shape=[3, 2, 1], dtype='float32')
@ -80,6 +81,7 @@ class TestDeprecatedDocorator(unittest.TestCase):
# captured
captured = get_warning_index(fluid.data)
paddle.disable_static()
# testting
self.assertGreater(expected, captured)

@ -19,10 +19,12 @@ from paddle.fluid import core, Variable
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.data_feeder import check_type
from paddle.fluid.framework import convert_np_dtype_to_dtype_
from paddle.fluid.framework import static_only
__all__ = ['data', 'InputSpec']
@static_only
def data(name, shape, dtype=None, lod_level=0):
"""
**Data Layer**

Loading…
Cancel
Save