@@ -20,6 +20,7 @@ from ...fluid.dygraph import Linear  #DEFINE_ALIAS
from ...fluid.dygraph import Flatten  #DEFINE_ALIAS
from ...fluid.dygraph import layers
from .. import functional as F
from ...fluid.framework import _dygraph_tracer
__all__ = [
    'BilinearTensorProduct',
@@ -38,6 +39,9 @@ __all__ = [
    'ConstantPad3d',
    'ReplicationPad3d',
    'CosineSimilarity',
    'Dropout',
    'Dropout2D',
    'Dropout3D',
]
@@ -348,6 +352,189 @@ class Pad2D(layers.Layer):
            data_format=self._data_format)

class Dropout(layers.Layer):
    """
    Dropout is a regularization technique for reducing overfitting by preventing
    neuron co-adaptation during training, as described in the paper:
    `Improving neural networks by preventing co-adaptation of feature detectors <https://arxiv.org/abs/1207.0580>`_

    The dropout operator randomly sets the outputs of some units to zero, while upscaling
    the others according to the given dropout probability.

    See ``paddle.nn.functional.dropout`` for more details.

    In dygraph mode, please use ``eval()`` to switch to the test phase when needed.

    Parameters:
        p (float|int): Probability of setting units to zero. Default: 0.5
        axis (int|list): The axis along which the dropout is performed. Default: None.
        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
        mode (str, optional): ['upscale_in_train' (default) | 'downscale_in_infer']

            1. upscale_in_train (default), upscale the output at training time

                - train: out = input * mask / (1.0 - p)
                - inference: out = input

            2. downscale_in_infer, downscale the output at inference time

                - train: out = input * mask
                - inference: out = input * (1.0 - p)

    Shape:
        - input: N-D tensor.
        - output: N-D tensor, the same shape as input.

    Examples:
        .. code-block:: python

            import paddle
            import numpy as np

            paddle.disable_static()
            x = np.array([[1, 2, 3], [4, 5, 6]]).astype('float32')
            x = paddle.to_tensor(x)
            m = paddle.nn.Dropout(p=0.5)
            y_train = m(x)
            m.eval()  # switch the model to test phase
            y_test = m(x)
            print(x.numpy())
            print(y_train.numpy())
            print(y_test.numpy())
    """
    def __init__(self, p=0.5, axis=None, mode="upscale_in_train", name=None):
        super(Dropout, self).__init__()

        self.p = p
        self.training = _dygraph_tracer()._train_mode
        self.axis = axis
        self.mode = mode
        self.name = name

    def forward(self, input):
        out = F.dropout(
            input,
            p=self.p,
            axis=self.axis,
            training=self.training,
            mode=self.mode,
            name=self.name)
        return out
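
# Illustrative sketch (not part of the module code): a minimal comparison of the
# two ``mode`` values documented above, using only APIs that already appear in
# the docstring examples.
#
#     import paddle
#     import numpy as np
#     paddle.disable_static()
#     x = paddle.to_tensor(np.array([[1., 2., 3.], [4., 5., 6.]]).astype('float32'))
#     # 'upscale_in_train': kept units are scaled by 1 / (1 - p) during training,
#     # and the layer behaves as an identity mapping at inference time.
#     m_up = paddle.nn.Dropout(p=0.5, mode="upscale_in_train")
#     # 'downscale_in_infer': units are only masked during training, and the
#     # output is scaled by (1 - p) at inference time.
#     m_down = paddle.nn.Dropout(p=0.5, mode="downscale_in_infer")
#     y_train = m_down(x)     # masked, no rescaling
#     m_down.eval()
#     y_infer = m_down(x)     # equals x * (1 - p)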


class Dropout2D(layers.Layer):
    """
    Randomly zero out entire channels (in the batched input 4D tensor with the shape `NCHW`,
    a channel is a 2D feature map with the shape `HW`). Each channel will be zeroed out independently
    on every forward call with probability `p`, using samples from a Bernoulli distribution.

    Dropout2D will help promote independence between feature maps as described in the paper:
    `Efficient Object Localization Using Convolutional Networks <https://arxiv.org/abs/1411.4280>`_

    See ``paddle.nn.functional.dropout2d`` for more details.

    Please use ``eval()`` to switch to the test phase when needed.

    Parameters:
        p (float, optional): Probability of setting units to zero. Default: 0.5
        data_format (str, optional): Specify the data format of the input, and the data format of the output
            will be consistent with that of the input. An optional string from:
            `NCHW`, `NHWC`. The default is `NCHW`. When it is `NCHW`, the data is
            stored in the order of: [batch_size, input_channels, input_height, input_width].
        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

    Shape:
        - input: 4-D tensor.
        - output: 4-D tensor, the same shape as input.

    Examples:
        .. code-block:: python

            import paddle
            import numpy as np

            paddle.disable_static()
            x = np.random.random(size=(2, 3, 4, 5)).astype('float32')
            x = paddle.to_tensor(x)
            m = paddle.nn.Dropout2D(p=0.5)
            y_train = m(x)
            m.eval()  # switch the model to test phase
            y_test = m(x)
            print(x.numpy())
            print(y_train.numpy())
            print(y_test.numpy())
    """
    def __init__(self, p=0.5, data_format='NCHW', name=None):
        super(Dropout2D, self).__init__()

        self.p = p
        self.training = _dygraph_tracer()._train_mode
        self.data_format = data_format
        self.name = name

    def forward(self, input):
        out = F.dropout2d(
            input,
            p=self.p,
            training=self.training,
            data_format=self.data_format,
            name=self.name)
        return out
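
# Illustrative sketch (not part of the module code): channel-last input can be
# handled through ``data_format`` ('NHWC' here, 'NDHWC' for Dropout3D below),
# using only APIs from the docstring examples.
#
#     import paddle
#     import numpy as np
#     paddle.disable_static()
#     x = paddle.to_tensor(np.random.random((2, 4, 5, 3)).astype('float32'))  # [N, H, W, C]
#     m = paddle.nn.Dropout2D(p=0.5, data_format='NHWC')
#     y_train = m(x)     # whole channels (last axis) are zeroed during training
#     m.eval()
#     y_infer = m(x)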


class Dropout3D(layers.Layer):
    """
    Randomly zero out entire channels (in the batched input 5D tensor with the shape `NCDHW`,
    a channel is a 3D feature map with the shape `DHW`). Each channel will be zeroed out independently
    on every forward call with probability `p`, using samples from a Bernoulli distribution.

    Dropout3D will help promote independence between feature maps as described in the paper:
    `Efficient Object Localization Using Convolutional Networks <https://arxiv.org/abs/1411.4280>`_

    See ``paddle.nn.functional.dropout3d`` for more details.

    Please use ``eval()`` to switch to the test phase when needed.

    Parameters:
        p (float|int): Probability of setting units to zero. Default: 0.5
        data_format (str, optional): Specify the data format of the input, and the data format of the output
            will be consistent with that of the input. An optional string from:
            `NCDHW`, `NDHWC`. The default is `NCDHW`. When it is `NCDHW`, the data is
            stored in the order of: [batch_size, input_channels, input_depth, input_height, input_width].
        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

    Shape:
        - input: 5-D tensor.
        - output: 5-D tensor, the same shape as input.

    Examples:
        .. code-block:: python

            import paddle
            import numpy as np

            paddle.disable_static()
            x = np.random.random(size=(2, 3, 4, 5, 6)).astype('float32')
            x = paddle.to_tensor(x)
            m = paddle.nn.Dropout3D(p=0.5)
            y_train = m(x)
            m.eval()  # switch the model to test phase
            y_test = m(x)
            print(x.numpy())
            print(y_train.numpy())
            print(y_test.numpy())
    """
    def __init__(self, p=0.5, data_format='NCDHW', name=None):
        super(Dropout3D, self).__init__()

        self.p = p
        self.training = _dygraph_tracer()._train_mode
        self.data_format = data_format
        self.name = name

    def forward(self, input):
        out = F.dropout3d(
            input,
            p=self.p,
            training=self.training,
            data_format=self.data_format,
            name=self.name)
        return out


class ReflectionPad1d(layers.Layer):
    """
    This interface is used to construct a callable object of the ``ReflectionPad1d`` class.