@@ -1084,7 +1084,7 @@ struct ELUGradFunctor : public BaseActivationFunctor<T> {
   void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
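+    // The mask below changes from (x < 0) to (x <= 0): with both comparisons
+    // strict, x == 0 fell into neither branch and its gradient came out as 0,
+    // instead of alpha * exp(0) = alpha (which matches the right-hand
+    // derivative 1 when alpha = 1).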
     dx.device(d) = dout * (x > static_cast<T>(0)).template cast<T>() +
                    dout * static_cast<T>(alpha) * x.exp() *
-                       (x < static_cast<T>(0)).template cast<T>();
+                       (x <= static_cast<T>(0)).template cast<T>();
   }

   static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
@@ -1405,6 +1405,39 @@ struct LeakyReluGradGradFunctor : public BaseActivationFunctor<T> {
   static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
 };
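+
+// Double grad for ELU. Writing f(x) = ELU(x) = x for x > 0 and
+// alpha * (exp(x) - 1) for x <= 0, the first derivative is f'(x) = 1 (x > 0)
+// or alpha * exp(x) (x <= 0), and the second derivative is
+// f''(x) = alpha * exp(x) on the negative branch. Given ddX (the grad of dX),
+// this functor produces ddOut = ddX * f'(x) and dX = ddX * dOut * f''(x).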
+template <typename T>
+struct ELUGradGradFunctor : public BaseActivationFunctor<T> {
+  float alpha;
+  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
+    return {{"alpha", &alpha}};
+  }
+  template <typename Device>
+  void operator()(const Device& dev, const framework::Tensor* X,
+                  const framework::Tensor* ddX, framework::Tensor* ddOut,
+                  const framework::Tensor* dOut, framework::Tensor* dX) const {
+    auto* d = dev.eigen_device();
+    auto ddx = framework::EigenVector<T>::Flatten(detail::Ref(ddX));
+    auto x = framework::EigenVector<T>::Flatten(detail::Ref(X));
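+
+    // dX = ddX * dOut * f''(x); f''(x) = alpha * exp(x) only on the negative
+    // branch, so the (x < 0) mask zeroes the contribution everywhere else.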
+    if (dX) {
+      auto dx = framework::EigenVector<T>::Flatten(detail::Ref(dX));
+      auto dout = framework::EigenVector<T>::Flatten(detail::Ref(dOut));
+      dx.device(*d) = ddx * dout * static_cast<T>(alpha) * x.exp() *
+                      (x < static_cast<T>(0)).template cast<T>();
+    }
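+
+    // ddOut = ddX * f'(x), with f'(x) = 1 for x > 0 and alpha * exp(x)
+    // for x <= 0 (matching the fixed <= mask in ELUGradFunctor above).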
+    if (ddOut) {
+      auto ddout = framework::EigenVector<T>::Flatten(detail::Ref(ddOut));
+      ddout.device(*d) = ddx *
+                         ((x > static_cast<T>(0)).template cast<T>() +
+                          static_cast<T>(alpha) * x.exp() *
+                              (x <= static_cast<T>(0)).template cast<T>())
+                             .template cast<T>();
+    }
+  }
+
+  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
+};
+
 template <typename T>
 struct SqrtGradGradFunctor : public BaseActivationFunctor<T> {
   template <typename Device>
@@ -1515,6 +1548,33 @@ class SquareDoubleGradKernel
   }
 };
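+
+// Op kernel wrapping ELUGradGradFunctor: pulls X / ddX / dOut out of the
+// execution context, allocates only the outputs that were requested, copies
+// the alpha attribute into the functor, then runs it on the device.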
+template <typename DeviceContext, typename Functor>
+class ELUDoubleGradKernel
+    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
+ public:
+  using T = typename Functor::ELEMENT_TYPE;
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    const framework::Tensor *X, *ddX, *dOut;
+    X = ddX = dOut = nullptr;
+    framework::Tensor *dX, *ddOut;
+    dX = ddOut = nullptr;
+    ExtractDoubleGradTensorWithInputDOut(ctx, &X, &ddX, &dX, &dOut, &ddOut);
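+    // Allocate output buffers only for the gradients this call actually asks
+    // for; either dX or ddOut may be absent.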
+    if (dX) dX->mutable_data<T>(X->dims(), ctx.GetPlace());
+    if (ddOut) ddOut->mutable_data<T>(ctx.GetPlace());
+
+    auto& place = ctx.template device_context<DeviceContext>();
+
+    Functor functor;
+    auto attrs = functor.GetAttrs();
+    for (auto& attr : attrs) {
+      *attr.second = ctx.Attr<float>(attr.first);
+    }
+    functor(place, X, ddX, ddOut, dOut, dX);
+  }
+};
+
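+// Note: with this dedicated double-grad kernel, elu is dropped from the
+// generic FOR_EACH_ACTIVATION_OP registration macro below; presumably it is
+// now registered separately (with its grad-grad op) in activation_op.cc/.cu.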
 template <typename DeviceContext, typename Functor>
 class SqrtDoubleGradKernel
     : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
@@ -1688,7 +1748,6 @@ class PowGradKernel
   __macro(softsign, Softsign, SoftsignFunctor, SoftsignGradFunctor);          \
   __macro(relu6, Relu6, Relu6Functor, Relu6GradFunctor);                      \
   __macro(tanh_shrink, TanhShrink, TanhShrinkFunctor, TanhShrinkGradFunctor); \
-  __macro(elu, ELU, ELUFunctor, ELUGradFunctor);                              \
   __macro(hard_shrink, HardShrink, HardShrinkFunctor, HardShrinkGradFunctor); \
   __macro(hard_sigmoid, HardSigmoid, HardSigmoidFunctor,                      \
           HardSigmoidGradFunctor);                                            \