@@ -14,13 +14,10 @@
# TODO: define activation functions of neural network
from ...fluid.layers import brelu  #DEFINE_ALIAS
from ...fluid.layers import elu  #DEFINE_ALIAS
from ...fluid.layers import erf  #DEFINE_ALIAS
from ...fluid.layers import gelu  #DEFINE_ALIAS
from ...fluid.layers import hard_sigmoid  #DEFINE_ALIAS
from ...fluid.layers import hard_swish  #DEFINE_ALIAS
from ...fluid.layers import leaky_relu  #DEFINE_ALIAS
from ...fluid.layers import logsigmoid  #DEFINE_ALIAS
from ...fluid.layers import maxout  #DEFINE_ALIAS
from ...fluid.layers import relu6  #DEFINE_ALIAS
from ...fluid.layers import selu  #DEFINE_ALIAS
@@ -69,6 +66,108 @@ from ...fluid.data_feeder import check_variable_and_dtype, check_dtype

import paddle


def elu(x, alpha=1.0, name=None):
    """
    elu activation.

    .. math::

        elu(x) = max(0, x) + min(0, \\alpha * (e^{x}-1))

    Parameters:
        x (Tensor): The input Tensor with data type float32, float64.
        alpha (float, optional): The 'alpha' value of the ELU formulation. Default is 1.0.
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        A Tensor with the same data type and shape as ``x`` .

    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn.functional as F
            import numpy as np

            paddle.disable_static()

            x = paddle.to_tensor(np.array([[-1, 6], [1, 15.6]]))
            out = F.elu(x, alpha=0.2)
            # [[-0.12642411  6.        ]
            #  [ 1.          15.6      ]]
    """
    # Dynamic graph mode: call the elu op directly.
    if in_dygraph_mode():
        return core.ops.elu(x, 'alpha', alpha)

    # Static graph mode: validate the input dtype, then register an elu op
    # that writes into a freshly created output variable.
    check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'elu')
    helper = LayerHelper("elu", **locals())
    out = helper.create_variable_for_type_inference(x.dtype)
    helper.append_op(
        type='elu',
        inputs={'X': x},
        outputs={'Out': out},
        attrs={'alpha': alpha})
    return out
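
# Quick numeric sketch of the elu formula above (plain NumPy, independent of Paddle):
# for alpha = 0.2 and x = -1,
#   np.maximum(0, -1.0) + np.minimum(0, 0.2 * (np.exp(-1.0) - 1))  # ~ -0.12642411
# which matches the first entry of the docstring example.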


def gelu(x, approximate=False, name=None):
    """
    gelu activation.

    if approximate is True

    .. math::

        gelu(x) = 0.5 * x * (1 + tanh(\\sqrt{\\frac{2}{\\pi}} * (x + 0.044715x^{3})))

    else

    .. math::

        gelu(x) = 0.5 * x * (1 + erf(\\frac{x}{\\sqrt{2}}))

    Parameters:
        x (Tensor): The input Tensor with data type float32, float64.
        approximate (bool, optional): Whether to enable approximation. Default is False.
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        A Tensor with the same data type and shape as ``x`` .

    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn.functional as F
            import numpy as np

            paddle.disable_static()

            data = np.random.randn(2, 3).astype("float32")
            x = paddle.to_tensor(data)

            out = F.gelu(x)

            data
            # array([[ 0.87165993, -1.0541513 , -0.37214822],
            #        [ 0.15647964,  0.32496083,  0.33045998]], dtype=float32)
            out
            # array([[ 0.70456535, -0.15380788, -0.13207214],
            #        [ 0.08796856,  0.20387867,  0.2080159 ]], dtype=float32)
    """
    if in_dygraph_mode():
        return core.ops.gelu(x, 'approximate', approximate)

    check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'gelu')
    helper = LayerHelper("gelu", **locals())
    out = helper.create_variable_for_type_inference(x.dtype)
    helper.append_op(
        type='gelu',
        inputs={'X': x},
        outputs={'Out': out},
        attrs={'approximate': approximate})
    return out
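
# Quick numeric sketch of the two gelu variants documented above (NumPy + math.erf,
# independent of Paddle):
#   0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x ** 3)))  # approximate=True
#   0.5 * x * (1 + math.erf(x / math.sqrt(2)))                             # approximate=False
# The tanh form closely approximates the erf form for moderate |x|.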


def hardshrink(x, threshold=0.5, name=None):
    """
    hard shrinkage activation
@@ -245,11 +344,8 @@ def hsigmoid(input,
    return out


def relu(input, inplace=False, name=None):
def relu(x, name=None):
    """
    :alias_main: paddle.nn.functional.relu
    :alias: paddle.nn.functional.relu,paddle.nn.functional.activation.relu

    ReLU Activation.

    .. math::
@@ -257,44 +353,74 @@ def relu(input, inplace=False, name=None):

        out = max(x, 0)

    Parameters:
        input (Variable): The input variable. A multi-dimension Tensor with type float16, float32, or float64.
        inplace (bool, optional): If inplace is True, the input and output of ``ReLU`` are the same variable.
            Otherwise, the input and output of ``ReLU`` are different variables. Default: False. Note that if x is
            more than one OPs' input, inplace must be False.
        name (str, optional): The default value is None. Normally there is no need for user to set this property.
            For more information, please refer to :ref:`api_guide_Name` .
        x (Tensor): The input Tensor with data type float32, float64.
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        Output of relu operator, a Tensor with shape same as input
        A Tensor with the same data type and shape as ``x`` .

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid
            import paddle.nn.functional as functional
            import numpy as np
            import paddle
            import paddle.nn.functional as F
            import numpy as np

            paddle.disable_static()

            data = np.array([-2, 0, 1]).astype('float32')
            with fluid.dygraph.guard():
                data = fluid.dygraph.to_variable(data)
                res = functional.relu(data) # [0, 0, 1]
            x = paddle.to_tensor(np.array([-2, 0, 1]).astype('float32'))
            out = F.relu(x) # [0., 0., 1.]
    """
    if in_dygraph_mode():
        if inplace:
            warnings.warn(
                "Inplace on ReLU is not allowed and will be discarded in dygraph mode currently."
            )
        return core.ops.relu(input)

    check_variable_and_dtype(input, 'input', ['float16', 'float32', 'float64'],
                             'relu')
        return core.ops.relu(x)

    check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'relu')
    helper = LayerHelper('relu', **locals())
    outs = input if inplace else helper.create_variable_for_type_inference(
        input.dtype)
    helper.append_op(type='relu', inputs={'X': [input]}, outputs={'Out': outs})
    return outs
    out = helper.create_variable_for_type_inference(x.dtype)
    helper.append_op(type='relu', inputs={'X': x}, outputs={'Out': out})
    return out
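
# Quick numeric sketch of relu (elementwise max with 0, independent of Paddle):
#   np.maximum(np.array([-2., 0., 1.]), 0.)  # -> [0., 0., 1.], matching the docstring example.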


def logsigmoid(x, name=None):
    """
    logsigmoid activation.

    .. math::

        logsigmoid(x) = \\log \\frac{1}{1 + e^{-x}}

    Parameters:
        x (Tensor): The input Tensor with data type float32, float64.
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        A Tensor with the same data type and shape as ``x`` .

    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn.functional as F
            import numpy as np

            paddle.disable_static()

            x = paddle.to_tensor(np.array([1.0, 2.0, 3.0, 4.0]))
            out = F.logsigmoid(x) # [-0.313262, -0.126928, -0.048587, -0.018150]
    """
    if in_dygraph_mode():
        return core.ops.logsigmoid(x)

    check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
                             'logsigmoid')
    helper = LayerHelper("logsigmoid", **locals())
    out = helper.create_variable_for_type_inference(x.dtype)
    helper.append_op(type='logsigmoid', inputs={'X': x}, outputs={'Out': out})
    return out
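
# Quick numeric check of logsigmoid (plain NumPy, independent of Paddle):
#   np.log(1 / (1 + np.exp(-np.array([1.0, 2.0, 3.0, 4.0]))))
#   # -> [-0.31326169, -0.12692801, -0.04858735, -0.01814993]
# consistent with the values in the docstring example above.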


def softmax(x, axis=-1, name=None):