@@ -17,11 +17,11 @@ from __future__ import print_function
 from six.moves import reduce
 from .. import core
 from ..layers import utils
-from ..layers import nn
+from ..layers import nn as F
 from .. import dygraph_utils
 from . import layers
-from ..framework import Variable, in_dygraph_mode, OpProtoHolder, Parameter, _dygraph_tracer, _varbase_creator
+from ..framework import Variable, in_dygraph_mode, OpProtoHolder, Parameter, _dygraph_tracer, _varbase_creator, default_main_program
 from ..data_feeder import convert_dtype, check_variable_and_dtype, check_type, check_dtype
 from ..param_attr import ParamAttr
 from ..initializer import Normal, Constant, NumpyArrayInitializer
 from .. import unique_name
@@ -31,9 +31,10 @@ import numbers
 import logging

 __all__ = [
-    'Conv2D', 'Conv3D', 'Pool2D', 'Linear', 'BatchNorm', 'Embedding', 'GRUUnit',
-    'LayerNorm', 'NCE', 'PRelu', 'BilinearTensorProduct', 'Conv2DTranspose',
-    'Conv3DTranspose', 'GroupNorm', 'SpectralNorm', 'TreeConv'
+    'Conv2D', 'Conv3D', 'Pool2D', 'Linear', 'BatchNorm', 'Dropout', 'Embedding',
+    'GRUUnit', 'LayerNorm', 'NCE', 'PRelu', 'BilinearTensorProduct',
+    'Conv2DTranspose', 'Conv3DTranspose', 'GroupNorm', 'SpectralNorm',
+    'TreeConv'
 ]
@@ -1007,7 +1008,9 @@ class BatchNorm(layers.Layer):
     Parameters:
         num_channels(int): Indicate the number of channels of the input ``Tensor``.
         act(str, optional): Activation to be applied to the output of batch normalization. Default: None.
-        is_test (bool, optional): A flag indicating whether it is in test phrase or not. Default: False.
+        is_test (bool, optional): A flag indicating whether it is in test phase or not.
+            This flag only has an effect in static graph mode. For dygraph mode, please use ``eval()``.
+            Default: False.
         momentum(float, optional): The value used for the moving_mean and moving_var computation. Default: 0.9.
         epsilon(float, optional): The small value added to the variance to prevent division by zero. Default: 1e-5.
         param_attr(ParamAttr, optional): The parameter attribute for Parameter `scale`
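For reference, a minimal dygraph sketch of the ``eval()`` switch the note above points to. This is an illustration only, not part of the patch; it assumes the public ``fluid.dygraph`` API already used elsewhere in this file:

    import numpy as np
    import paddle.fluid as fluid

    with fluid.dygraph.guard():
        bn = fluid.dygraph.BatchNorm(num_channels=3)
        x = fluid.dygraph.to_variable(
            np.random.random((4, 3, 8, 8)).astype('float32'))

        out_train = bn(x)   # training behaviour: uses batch statistics

        bn.eval()           # dygraph counterpart of is_test=True
        out_eval = bn(x)    # evaluation behaviour: uses accumulated statistics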
@@ -1134,8 +1137,7 @@ class BatchNorm(layers.Layer):
         variance_out = self._variance

         if in_dygraph_mode():
-            _is_test = (not _dygraph_tracer()._train_mode) and (
-                not self._trainable_statistics)
+            _is_test = not self.training and not self._trainable_statistics
             attrs = ("momentum", self._momentum, "epsilon", self._epsilon,
                      "is_test", _is_test, "data_layout", self._data_layout,
                      "use_mkldnn", False, "fuse_with_relu",
@@ -1157,8 +1159,7 @@ class BatchNorm(layers.Layer):
             "data_layout": self._data_layout,
             "use_mkldnn": False,
             "fuse_with_relu": self._fuse_with_relu,
-            "use_global_stats": self._use_global_stats,
-            "trainable_statistics": self._trainable_statistics
+            "use_global_stats": self._use_global_stats
         }

         inputs = {
@@ -1191,6 +1192,115 @@ class BatchNorm(layers.Layer):
         return self._helper.append_activation(batch_norm_out, self._act)


+class Dropout(layers.Layer):
+    """
+    This interface is used to construct a callable object of the ``Dropout`` class.
+    For more details, refer to the code examples.
+
+    Drop or keep each element of the input independently. Dropout is a regularization
+    technique for reducing overfitting by preventing neuron co-adaptation during
+    training. The dropout operator randomly sets the outputs of some units to zero
+    (according to the given dropout probability), while the others remain unchanged.
+
+    The dropout layer can be removed for efficiency concerns.
+
+    Parameters:
+        p (float, optional): Probability of setting units to zero. Default: 0.5
+        seed (int, optional): A Python integer used to create random seeds. If this
+            parameter is set to None, a random seed is used.
+            NOTE: If an integer seed is given, the same output units will always be
+            dropped. DO NOT use a fixed seed in training. Default: None.
+        dropout_implementation (str, optional): ['downgrade_in_infer'(default)|'upscale_in_train']
+
+            1. downgrade_in_infer (default), downgrade the outcome at inference time
+
+                - train: out = input * mask
+                - inference: out = input * (1.0 - p)
+
+                (mask is a tensor with the same shape as the input; its values are
+                0 or 1, and the ratio of zeros is p)
+            2. upscale_in_train, upscale the outcome at training time
+
+                - train: out = input * mask / (1.0 - p)
+                - inference: out = input
+
+                (mask is a tensor with the same shape as the input; its values are
+                0 or 1, and the ratio of zeros is p)
+        is_test (bool, optional): A flag indicating whether it is in test phase or not.
+            This flag only has an effect in static graph mode. For dygraph mode, please use ``eval()``.
+            Default: False.
+
+    Returns:
+        None
+
+    Examples:
+
+        .. code-block:: python
+
+            import paddle.fluid as fluid
+            from paddle.fluid.dygraph.base import to_variable
+            import numpy as np
+
+            x = np.random.random(size=(3, 10, 3, 7)).astype('float32')
+            with fluid.dygraph.guard():
+                x = to_variable(x)
+                m = fluid.dygraph.Dropout(p=0.5)
+                dropped_train = m(x)
+                # switch to eval mode
+                m.eval()
+                dropped_eval = m(x)
+    """
+
+    def __init__(self,
+                 p=0.5,
+                 seed=None,
+                 dropout_implementation="downgrade_in_infer",
+                 is_test=False):
+        super(Dropout, self).__init__()
+        assert isinstance(p, (float, int)), "p argument should be a number"
+        assert 0 <= p <= 1, "p argument should be between 0 and 1"
+        self._dropout_prob = p
+        assert seed is None or isinstance(
+            seed, int), "seed argument should be None or an integer"
+        self._seed = seed
+        assert dropout_implementation in (
+            'downgrade_in_infer', 'upscale_in_train'
+        ), "dropout_implementation argument should be 'downgrade_in_infer' or 'upscale_in_train'"
+        self._dropout_implementation = dropout_implementation
+        self._is_test = is_test
+
+    def forward(self, input):
+        prog = default_main_program()
+        if (self._seed is None or self._seed == 0) and prog.random_seed != 0:
+            self._seed = prog.random_seed
+        attrs = {
+            'dropout_prob': self._dropout_prob,
+            'is_test': not self.training
+            if in_dygraph_mode() else self._is_test,
+            'fix_seed': self._seed is not None,
+            'seed': self._seed if self._seed is not None else 0,
+            'dropout_implementation': self._dropout_implementation,
+        }
+
+        if in_dygraph_mode():
+            attrs = sum(attrs.items(), ())
+            out, mask = core.ops.dropout(input, *attrs)
+            return out
+
+        out = self._helper.create_variable_for_type_inference(dtype=input.dtype)
+        mask = self._helper.create_variable_for_type_inference(
+            dtype=core.VarDesc.VarType.UINT8, stop_gradient=True)
+
+        self._helper.append_op(
+            type='dropout',
+            inputs={'X': [input]},
+            outputs={'Out': [out],
+                     'Mask': [mask]},
+            attrs=attrs)
+        return out
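The ``attrs = sum(attrs.items(), ())`` call above flattens the attribute dict into the alternating name/value tuple that ``core.ops.dropout`` takes in dygraph mode. A minimal plain-Python sketch of that flattening, with made-up values, assuming the dict preserves insertion order (Python 3.6+):

    attrs = {'dropout_prob': 0.5, 'fix_seed': False, 'seed': 0}
    flat = sum(attrs.items(), ())
    # flat == ('dropout_prob', 0.5, 'fix_seed', False, 'seed', 0)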
+
+
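A minimal NumPy sketch of the two ``dropout_implementation`` modes documented in the Dropout docstring above. Illustration only, not part of the patch; ``p`` is the drop probability and ``mask`` is a single sampled 0/1 mask:

    import numpy as np

    p = 0.5
    x = np.random.random((4, 8)).astype('float32')
    mask = (np.random.random(x.shape) >= p).astype('float32')  # zeros with ratio p

    # downgrade_in_infer: keep training output unscaled, downscale at inference
    train_downgrade = x * mask
    infer_downgrade = x * (1.0 - p)

    # upscale_in_train: upscale at training time, identity at inference
    train_upscale = x * mask / (1.0 - p)
    infer_upscale = x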
 class Embedding(layers.Layer):
     """
     **Embedding Layer**