Add static mode check on data() (#27495)

* add static check on data()

* follow comments

* fix ut
revert-27356-init_low_level_gloo
Leo Chen 5 years ago committed by GitHub
parent a5b3263782
commit 0b4bb023a7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -19,10 +19,12 @@ from paddle.fluid import core
from paddle.fluid.layer_helper import LayerHelper from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.data_feeder import check_dtype, check_type from paddle.fluid.data_feeder import check_dtype, check_type
from ..utils import deprecated from ..utils import deprecated
from paddle.fluid.framework import static_only
__all__ = ['data'] __all__ = ['data']
@static_only
@deprecated(since="2.0.0", update_to="paddle.static.data") @deprecated(since="2.0.0", update_to="paddle.static.data")
def data(name, shape, dtype='float32', lod_level=0): def data(name, shape, dtype='float32', lod_level=0):
""" """

@ -217,7 +217,16 @@ def _dygraph_not_support_(func):
def _dygraph_only_(func): def _dygraph_only_(func):
def __impl__(*args, **kwargs): def __impl__(*args, **kwargs):
assert in_dygraph_mode( assert in_dygraph_mode(
), "We Only support %s in dynamic mode, please call 'paddle.disable_static()' to enter dynamic mode." % func.__name__ ), "We only support '%s()' in dynamic graph mode, please call 'paddle.disable_static()' to enter dynamic graph mode." % func.__name__
return func(*args, **kwargs)
return __impl__
def _static_only_(func):
def __impl__(*args, **kwargs):
assert not in_dygraph_mode(
), "We only support '%s()' in static graph mode, please call 'paddle.enable_static()' to enter static graph mode." % func.__name__
return func(*args, **kwargs) return func(*args, **kwargs)
return __impl__ return __impl__
@ -260,6 +269,7 @@ def deprecate_stat_dict(func):
dygraph_not_support = wrap_decorator(_dygraph_not_support_) dygraph_not_support = wrap_decorator(_dygraph_not_support_)
dygraph_only = wrap_decorator(_dygraph_only_) dygraph_only = wrap_decorator(_dygraph_only_)
static_only = wrap_decorator(_static_only_)
fake_interface_only = wrap_decorator(_fake_interface_only_) fake_interface_only = wrap_decorator(_fake_interface_only_)

@ -31,6 +31,7 @@ from ..unique_name import generate as unique_name
import logging import logging
from ..data_feeder import check_dtype, check_type from ..data_feeder import check_dtype, check_type
from paddle.fluid.framework import static_only
__all__ = [ __all__ = [
'data', 'read_file', 'double_buffer', 'py_reader', 'data', 'read_file', 'double_buffer', 'py_reader',
@ -38,6 +39,7 @@ __all__ = [
] ]
@static_only
def data(name, def data(name,
shape, shape,
append_batch_size=True, append_batch_size=True,

@ -99,5 +99,17 @@ class TestApiStaticDataError(unittest.TestCase):
self.assertRaises(TypeError, test_shape_type) self.assertRaises(TypeError, test_shape_type)
class TestApiErrorWithDynamicMode(unittest.TestCase):
def test_error(self):
with program_guard(Program(), Program()):
paddle.disable_static()
self.assertRaises(AssertionError, fluid.data, 'a', [2, 25])
self.assertRaises(
AssertionError, fluid.layers.data, 'b', shape=[2, 25])
self.assertRaises(
AssertionError, paddle.static.data, 'c', shape=[2, 25])
paddle.enable_static()
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

@ -72,6 +72,7 @@ class TestDeprecatedDocorator(unittest.TestCase):
test old fluid elementwise_mul api, it should fire Warinng function, test old fluid elementwise_mul api, it should fire Warinng function,
which insert the Warinng info on top of API's doc string. which insert the Warinng info on top of API's doc string.
""" """
paddle.enable_static()
# Initialization # Initialization
x = fluid.data(name='x', shape=[3, 2, 1], dtype='float32') x = fluid.data(name='x', shape=[3, 2, 1], dtype='float32')
@ -80,6 +81,7 @@ class TestDeprecatedDocorator(unittest.TestCase):
# captured # captured
captured = get_warning_index(fluid.data) captured = get_warning_index(fluid.data)
paddle.disable_static()
# testting # testting
self.assertGreater(expected, captured) self.assertGreater(expected, captured)

@ -19,10 +19,12 @@ from paddle.fluid import core, Variable
from paddle.fluid.layer_helper import LayerHelper from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.data_feeder import check_type from paddle.fluid.data_feeder import check_type
from paddle.fluid.framework import convert_np_dtype_to_dtype_ from paddle.fluid.framework import convert_np_dtype_to_dtype_
from paddle.fluid.framework import static_only
__all__ = ['data', 'InputSpec'] __all__ = ['data', 'InputSpec']
@static_only
def data(name, shape, dtype=None, lod_level=0): def data(name, shape, dtype=None, lod_level=0):
""" """
**Data Layer** **Data Layer**

Loading…
Cancel
Save