[API2.0] add dropout, dropout2d and dropout3d in nn.functional and nn.layer (#26111)

* [API2.0] add dropout, dropout2d and dropout3d, test=develop

* refine Interface and assertion after review, test=develop

* fix alias p=1 and use scale, test=develop

* fix doc and training, test=develop

* fix doc in Dropout2D, test=develop
test_feature_precision_test_c
huangjun12 5 years ago committed by GitHub
parent 07c1c47bc9
commit 412eca679f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -932,6 +932,7 @@ def cos_sim(X, Y):
return out
@deprecated(since="2.0.0", update_to="paddle.nn.functional.dropout")
def dropout(x,
dropout_prob,
is_test=False,
@ -939,9 +940,6 @@ def dropout(x,
name=None,
dropout_implementation="downgrade_in_infer"):
"""
:alias_main: paddle.nn.functional.dropout
:alias: paddle.nn.functional.dropout,paddle.nn.functional.common.dropout
:old_api: paddle.fluid.layers.dropout
Computes dropout.

File diff suppressed because it is too large Load Diff

@ -87,6 +87,9 @@ from .layer.common import Embedding #DEFINE_ALIAS
from .layer.common import Linear #DEFINE_ALIAS
from .layer.common import Flatten #DEFINE_ALIAS
from .layer.common import UpSample #DEFINE_ALIAS
from .layer.common import Dropout #DEFINE_ALIAS
from .layer.common import Dropout2D #DEFINE_ALIAS
from .layer.common import Dropout3D #DEFINE_ALIAS
from .layer.pooling import AdaptiveAvgPool2d #DEFINE_ALIAS
from .layer.pooling import AdaptiveAvgPool3d #DEFINE_ALIAS
from .layer.conv import Conv2D #DEFINE_ALIAS

@ -54,6 +54,8 @@ from .activation import tanhshrink #DEFINE_ALIAS
from .activation import thresholded_relu #DEFINE_ALIAS
from .activation import log_softmax #DEFINE_ALIAS
from .common import dropout #DEFINE_ALIAS
from .common import dropout2d #DEFINE_ALIAS
from .common import dropout3d #DEFINE_ALIAS
# from .common import embedding #DEFINE_ALIAS
# from .common import fc #DEFINE_ALIAS
from .common import label_smooth #DEFINE_ALIAS

File diff suppressed because it is too large Load Diff

@ -52,6 +52,9 @@ from .common import Embedding #DEFINE_ALIAS
from .common import Linear #DEFINE_ALIAS
from .common import Flatten #DEFINE_ALIAS
from .common import UpSample #DEFINE_ALIAS
from .common import Dropout #DEFINE_ALIAS
from .common import Dropout2D #DEFINE_ALIAS
from .common import Dropout3D #DEFINE_ALIAS
from .pooling import AdaptiveAvgPool2d #DEFINE_ALIAS
from .pooling import AdaptiveAvgPool3d #DEFINE_ALIAS
from .conv import Conv2D #DEFINE_ALIAS

@ -20,6 +20,7 @@ from ...fluid.dygraph import Linear #DEFINE_ALIAS
from ...fluid.dygraph import Flatten #DEFINE_ALIAS
from ...fluid.dygraph import layers
from .. import functional as F
from ...fluid.framework import _dygraph_tracer
__all__ = [
'BilinearTensorProduct',
@ -38,6 +39,9 @@ __all__ = [
'ConstantPad3d',
'ReplicationPad3d',
'CosineSimilarity',
'Dropout',
'Dropout2D',
'Dropout3D',
]
@ -348,6 +352,189 @@ class Pad2D(layers.Layer):
data_format=self._data_format)
class Dropout(layers.Layer):
"""
Dropout is a regularization technique for reducing overfitting by preventing
neuron co-adaption during training as described in the paper:
`Improving neural networks by preventing co-adaptation of feature detectors <https://arxiv.org/abs/1207.0580>`_
The dropout operator randomly sets the outputs of some units to zero, while upscale others
according to the given dropout probability.
See ``paddle.nn.functional.dropout`` for more details.
In dygraph mode, please use ``eval()`` to indicate whether it is in test phrase or not.
Parameters:
p (float | int): Probability of setting units to zero. Default: 0.5
axis (int | list): The axis along which the dropout is performed. Default None.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
mode(str, optional): ['upscale_in_train'(default) | 'downscale_in_infer']
1. upscale_in_train(default), upscale the output at training time
- train: out = input * mask / ( 1.0 - p )
- inference: out = input
2. downscale_in_infer, downscale the output at inference
- train: out = input * mask
- inference: out = input * (1.0 - p)
Shape:
- input: N-D tensor.
- output: N-D tensor, the same shape as input.
Examples:
.. code-block:: python
import paddle
import numpy as np
paddle.disable_static()
x = np.array([[1,2,3], [4,5,6]]).astype('float32')
x = paddle.to_tensor(x)
m = paddle.nn.Dropout(p=0.5)
y_train = m(x)
m.eval() # switch the model to test phase
y_test = m(x)
print(x.numpy())
print(y_train.numpy())
print(y_test.numpy())
"""
def __init__(self, p=0.5, axis=None, mode="upscale_in_train", name=None):
super(Dropout, self).__init__()
self.p = p
self.training = _dygraph_tracer()._train_mode
self.axis = axis
self.mode = mode
self.name = name
def forward(self, input):
out = F.dropout(
input,
p=self.p,
axis=self.axis,
training=self.training,
mode=self.mode,
name=self.name)
return out
class Dropout2D(layers.Layer):
"""
Randomly zero out entire channels (in the batched input 4d tensor with the shape `NCHW` ,
a channel is a 2D feature map with the shape `HW`). Each channel will be zeroed out independently
on every forward call with probability `p` using samples from a Bernoulli distribution.
Dropout2d will help promote independence between feature maps as described in the paper:
`Efficient Object Localization Using Convolutional Networks <https://arxiv.org/abs/1411.4280>`_
See ``paddle.nn.functional.dropout2d`` for more details.
Please use ``eval()`` to indicate whether it is in test phrase or not.
Parameters:
p (float, optional): Probability of setting units to zero. Default: 0.5
data_format (str, optional): Specify the data format of the input, and the data format of the output
will be consistent with that of the input. An optional string from:
`NCHW`, `NHWC`. The default is `NCHW`. When it is `NCHW`, the data is
stored in the order of: [batch_size, input_channels, input_height, input_width].
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Shape:
- input: 4-D tensor.
- output: 4-D tensor, the same shape as input.
Examples:
.. code-block:: python
import paddle
import numpy as np
paddle.disable_static()
x = np.random.random(size=(2, 3, 4, 5)).astype('float32')
x = paddle.to_tensor(x)
m = paddle.nn.Dropout2D(p=0.5)
y_train = m(x)
m.eval() # switch the model to test phase
y_test = m(x)
print(x.numpy())
print(y_train.numpy())
print(y_test.numpy())
"""
def __init__(self, p=0.5, data_format='NCHW', name=None):
super(Dropout2D, self).__init__()
self.p = p
self.data_format = data_format
self.name = name
def forward(self, input):
out = F.dropout2d(
input,
p=self.p,
training=self.training,
data_format=self.data_format,
name=self.name)
return out
class Dropout3D(layers.Layer):
"""
Randomly zero out entire channels (in the batched input 5d tensor with the shape `NCDHW` ,
a channel is a 3D feature map with the shape `DHW` ). Each channel will be zeroed out independently
on every forward call with probability `p` using samples from a Bernoulli distribution.
Dropout3d will help promote independence between feature maps as described in the paper:
`Efficient Object Localization Using Convolutional Networks <https://arxiv.org/abs/1411.4280>`_
See ``paddle.nn.functional.dropout3d`` for more details.
Please use ``eval()`` to indicate whether it is in test phrase or not.
Parameters:
p (float | int): Probability of setting units to zero. Default: 0.5
data_format (str, optional): Specify the data format of the input, and the data format of the output
will be consistent with that of the input. An optional string from:
`NCDHW`, `NDHWC`. The default is `NCDHW`. When it is `NCDHW`, the data is
stored in the order of: [batch_size, input_channels, input_depth, input_height, input_width].
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Shape:
- input: 5-D tensor.
- output: 5-D tensor, the same shape as input.
Examples:
.. code-block:: python
import paddle
import numpy as np
paddle.disable_static()
x = np.random.random(size=(2, 3, 4, 5, 6)).astype('float32')
x = paddle.to_tensor(x)
m = paddle.nn.Dropout3D(p=0.5)
y_train = m(x)
m.eval() # switch the model to test phase
y_test = m(x)
print(x.numpy())
print(y_train.numpy())
print(y_test.numpy())
"""
def __init__(self, p=0.5, data_format='NCDHW', name=None):
super(Dropout3D, self).__init__()
self.p = p
self.training = _dygraph_tracer()._train_mode
self.data_format = data_format
self.name = name
def forward(self, input):
out = F.dropout3d(
input,
p=self.p,
training=self.training,
data_format=self.data_format,
name=self.name)
return out
class ReflectionPad1d(layers.Layer):
"""
This interface is used to construct a callable object of the ``ReflectionPad1d`` class.

Loading…
Cancel
Save