@ -11,6 +11,7 @@ limitations under the License. */
# pragma once
# pragma once
# include <glog/logging.h>
# include <glog/logging.h>
# include <algorithm>
# include <string>
# include <string>
# include <unordered_set>
# include <unordered_set>
# include <utility>
# include <utility>
@ -24,6 +25,7 @@ limitations under the License. */
# include "paddle/fluid/framework/eigen.h"
# include "paddle/fluid/framework/eigen.h"
# include "paddle/fluid/framework/op_registry.h"
# include "paddle/fluid/framework/op_registry.h"
# include "paddle/fluid/operators/detail/safe_ref.h"
# include "paddle/fluid/operators/detail/safe_ref.h"
# include "paddle/fluid/operators/math/blas.h"
# include "paddle/fluid/platform/float16.h"
# include "paddle/fluid/platform/float16.h"
# ifdef PADDLE_WITH_MKLDNN
# ifdef PADDLE_WITH_MKLDNN
@ -301,8 +303,28 @@ template <typename T>
struct GeluFunctor : public BaseActivationFunctor < T > {
struct GeluFunctor : public BaseActivationFunctor < T > {
template < typename Device , typename X , typename Out >
template < typename Device , typename X , typename Out >
void operator ( ) ( Device d , X x , Out out ) const {
void operator ( ) ( Device d , X x , Out out ) const {
// Because the execute or device context can not be deliver here, it keep the
// marco for NVCC.
# if defined(PADDLE_WITH_MKLML) && !defined(_WIN32) && !defined(__APPLE__) && \
! defined ( __OSX__ ) & & ! defined ( PADDLE_WITH_CUDA )
auto x_data = x . data ( ) ;
auto out_data = out . data ( ) ;
int n = std : : min ( x . size ( ) , out . size ( ) ) ;
std : : memset ( out_data , 0 , n * sizeof ( T ) ) ;
math : : CBlas < T > : : AXPY ( n , static_cast < T > ( M_SQRT1_2 ) , x_data , 1 , out_data , 1 ) ;
math : : CBlas < T > : : VMERF ( n , out_data , out_data , VML_LA ) ;
for ( int i = 0 ; i < n ; i + + ) {
out_data [ i ] + = static_cast < T > ( 1 ) ;
}
math : : CBlas < T > : : VMUL ( n , x_data , out_data , out_data ) ;
for ( int i = 0 ; i < n ; i + + ) {
out_data [ i ] * = static_cast < T > ( 0.5 ) ;
}
# else
auto temp = ( x * static_cast < T > ( M_SQRT1_2 ) ) . erf ( ) ;
auto temp = ( x * static_cast < T > ( M_SQRT1_2 ) ) . erf ( ) ;
out . device ( d ) = x * static_cast < T > ( 0.5 ) * ( static_cast < T > ( 1 ) + temp ) ;
out . device ( d ) = x * static_cast < T > ( 0.5 ) * ( static_cast < T > ( 1 ) + temp ) ;
# endif
}
}
} ;
} ;