@ -14,9 +14,25 @@ limitations under the License. */
# pragma once
# ifdef PADDLE_WITH_XPU
# include <string>
# include <unordered_map>
# include "paddle/fluid/framework/tensor.h"
# include "paddle/fluid/platform/place.h"
inline std : : string get_xpu_error_message ( int error_type ) {
static std : : unordered_map < int , std : : string > xpu_error_map = {
{ baidu : : xpu : : api : : INVALID_PARAM , " Parameter is invalid. " } ,
{ baidu : : xpu : : api : : RUNTIME_ERROR ,
" Please check whether Baidu Kunlun Card "
" is properly installed. " } ,
{ baidu : : xpu : : api : : NO_ENOUGH_WORKSPACE ,
" There is not enough memory in Baidu "
" Kunlun Card. " } } ;
if ( xpu_error_map . find ( error_type ) = = xpu_error_map . end ( ) ) {
return " Unknown error type! " ;
}
return xpu_error_map [ error_type ] ;
}
# define XPU_MALLOC(addr, num_bytes) \
PADDLE_ENFORCE_EQ ( xpu_malloc ( reinterpret_cast < void * * > ( addr ) , num_bytes ) , \
XPU_SUCCESS , \
@ -102,21 +118,27 @@ limitations under the License. */
int res = \
xpu : : broadcast_ew ( dev_ctx . x_context ( ) , y_data , y_broadcast , pre , \
n , post , xpu : : ElementwiseOp : : ASSIGN ) ; \
PADDLE_ENFORCE_EQ ( res , xpu : : Error_t : : SUCCESS , \
platform : : errors : : Fatal ( " XPU kernel error! " ) ) ; \
PADDLE_ENFORCE_EQ ( \
res , xpu : : Error_t : : SUCCESS , \
platform : : errors : : External ( " XPU kernel error occur! %s " , \
get_xpu_error_message ( res ) ) ) ; \
y_data = y_broadcast ; \
} \
} \
int res = xpu : : elementwise_ # # kernel_name # # _grad ( \
dev_ctx . x_context ( ) , x_data , y_data , dout - > data < T > ( ) /*out*/ , \
dout - > data < T > ( ) , dx_data , dy_data , len ) ; \
PADDLE_ENFORCE_EQ ( res , xpu : : Error_t : : SUCCESS , \
platform : : errors : : Fatal ( " XPU kernel error! " ) ) ; \
PADDLE_ENFORCE_EQ ( \
res , xpu : : Error_t : : SUCCESS , \
platform : : errors : : External ( " XPU kernel error occur! %s " , \
get_xpu_error_message ( res ) ) ) ; \
if ( ( dy ! = nullptr ) & & ( len ! = n ) ) { \
int res = xpu : : reduce_ew ( dev_ctx . x_context ( ) , dy_data , dy - > data < T > ( ) , \
pre , n , post , xpu : : ElementwiseOp : : ASSIGN ) ; \
PADDLE_ENFORCE_EQ ( res , xpu : : Error_t : : SUCCESS , \
platform : : errors : : Fatal ( " XPU kernel error! " ) ) ; \
PADDLE_ENFORCE_EQ ( \
res , xpu : : Error_t : : SUCCESS , \
platform : : errors : : External ( " XPU kernel error occur! %s " , \
get_xpu_error_message ( res ) ) ) ; \
dev_ctx . Wait ( ) ; \
xpu_free ( dy_data ) ; \
} \
@ -161,8 +183,8 @@ void XPUElementwise(const framework::ExecutionContext& ctx) {
platform : : errors : : PreconditionNotMet (
" This kernel only runs on XPU device. " ) ) ;
auto x_var = ctx . InputVar ( " X " ) ;
PADDLE_ENFORCE_NE ( x_var , nullptr ,
platform : : errors : : Fatal ( " Cannot get input Variable X " ) ) ;
PADDLE_ENFORCE_NE ( x_var , nullptr , platform : : errors : : InvalidArgument (
" Cannot get input Variable X " ) ) ;
PADDLE_ENFORCE_EQ (
x_var - > IsType < framework : : LoDTensor > ( ) , true ,
platform : : errors : : InvalidArgument (
@ -206,36 +228,36 @@ void XPUElementwise(const framework::ExecutionContext& ctx) {
if ( std : : is_same < Functor , XPUAddFunctor < T > > : : value ) {
int res = xpu : : matrix_vector_add ( dev_ctx . x_context ( ) , x_data , y_data ,
z_data , pre , n ) ;
PADDLE_ENFORCE_EQ (
res , xpu : : Error_t : : SUCCESS ,
platform : : errors : : Fatal ( " XPU kernel error! res = %d " , res ) ) ;
PADDLE_ENFORCE_EQ ( res , xpu : : Error_t : : SUCCESS ,
platform : : errors : : External ( " XPU kernel error occur! %s " ,
get_xpu_error_message ( res ) ) ) ;
return ;
}
if ( std : : is_same < Functor , XPUMulFunctor < T > > : : value ) {
int res = xpu : : matrix_vector_mul ( dev_ctx . x_context ( ) , x_data , y_data ,
z_data , pre , n ) ;
PADDLE_ENFORCE_EQ (
res , xpu : : Error_t : : SUCCESS ,
platform : : errors : : Fatal ( " XPU kernel error! res = %d " , res ) ) ;
PADDLE_ENFORCE_EQ ( res , xpu : : Error_t : : SUCCESS ,
platform : : errors : : External ( " XPU kernel error occur! %s " ,
get_xpu_error_message ( res ) ) ) ;
return ;
}
}
if ( pre ! = 1 | | post ! = 1 ) {
PADDLE_ENFORCE ( xpu_malloc ( reinterpret_cast < void * * > ( & y_broadcast ) ,
len * sizeof ( T ) ) = = XPU_SUCCESS ) ;
XPU_MALLOC ( & y_broadcast , len * sizeof ( T ) ) ;
int res = xpu : : broadcast_ew ( dev_ctx . x_context ( ) , y_data , y_broadcast , pre ,
n , post , xpu : : ElementwiseOp : : ASSIGN ) ;
PADDLE_ENFORCE_EQ (
res , xpu : : Error_t : : SUCCESS ,
platform : : errors : : Fatal ( " XPU kernel error! res = %d " , res ) ) ;
PADDLE_ENFORCE_EQ ( res , xpu : : Error_t : : SUCCESS ,
platform : : errors : : External ( " XPU kernel error occur! %s " ,
get_xpu_error_message ( res ) ) ) ;
y_data = y_broadcast ;
}
Functor functor ;
int res = functor ( dev_ctx . x_context ( ) , x_data , y_data , z_data , len ) ;
PADDLE_ENFORCE_EQ ( res , xpu : : Error_t : : SUCCESS ,
platform : : errors : : Fatal ( " XPU kernel error! res = %d " , res ) ) ;
platform : : errors : : External ( " XPU kernel error occur! %s " ,
get_xpu_error_message ( res ) ) ) ;
if ( pre ! = 1 | | post ! = 1 ) {
dev_ctx . Wait ( ) ;