@ -21,6 +21,17 @@ limitations under the License. */
namespace paddle {
namespace operators {
/**
* Printing shape information into a string is easy to use .
*/
inline static std : : string DumpMatrixShape ( const math : : MatDescriptor & desc ) {
std : : stringstream buffer ;
buffer < < " [ " < < desc . batch_size_ < < " , " < < desc . height_ < < " , "
< < desc . width_ < < " ] " ;
return buffer . str ( ) ;
}
/**
* Get row matrix shape from a vector shape . If the rank of x_dim > 1 , the
* original x_dim is returned .
@ -303,21 +314,37 @@ class MatMulOp : public framework::OperatorWithKernel {
context - > Attrs ( ) . Get < bool > ( " transpose_Y " ) ) ;
if ( context - > IsRuntime ( ) ) {
PADDLE_ENFORCE ( mat_dim_x . batch_size_ = = mat_dim_y . batch_size_ | |
mat_dim_x . batch_size_ = = 0 | | mat_dim_y . batch_size_ = = 0 ) ;
PADDLE_ENFORCE (
mat_dim_x . batch_size_ = = mat_dim_y . batch_size_ | |
mat_dim_x . batch_size_ = = 0 | | mat_dim_y . batch_size_ = = 0 ,
" ShapeError: The batch size of the two matrices should be equal, or "
" at least one is zero. \n "
" But received X's shape: %s, Y's shape: %s. " ,
DumpMatrixShape ( mat_dim_x ) . c_str ( ) ,
DumpMatrixShape ( mat_dim_y ) . c_str ( ) ) ;
}
std : : vector < int64_t > dim_out ;
int64_t dim_out_y = mat_dim_y . width_ ;
# if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA)
int head_number = context - > Attrs ( ) . Get < int > ( " head_number " ) ;
bool split_vertical_y = ( mat_dim_x . width_ ! = mat_dim_y . height_ ) ;
PADDLE_ENFORCE_LE ( head_number , mat_dim_x . width_ ) ;
PADDLE_ENFORCE_LE (
head_number , mat_dim_x . width_ ,
" ShapeError: Unsatisfied mkl acceleration library requirements: "
" The number of heads "
" (%d) must be equal to X's width. But received X's shape: %s. " ,
head_number , DumpMatrixShape ( mat_dim_x ) . c_str ( ) ) ;
if ( ! split_vertical_y & & head_number > 0 ) {
dim_out_y = head_number * mat_dim_y . width_ ;
}
# else
PADDLE_ENFORCE_EQ ( mat_dim_x . width_ , mat_dim_y . height_ ) ;
PADDLE_ENFORCE_EQ (
mat_dim_x . width_ , mat_dim_y . height_ ,
" ShapeError: Input X's width should be equal to the Y's height, "
" but received X's shape: %s, "
" Y's shape: %s. " ,
DumpMatrixShape ( mat_dim_x ) . c_str ( ) , DumpMatrixShape ( mat_dim_y ) . c_str ( ) ) ;
# endif
if ( mat_dim_x . batch_size_ ! = 0 ) {
@ -461,15 +488,11 @@ REGISTER_OPERATOR(matmul, ops::MatMulOp, ops::MatMulOpMaker,
// Register the gradient operator `matmul_grad`, implemented by
// ops::MatMulOpGrad.
REGISTER_OPERATOR ( matmul_grad , ops : : MatMulOpGrad ) ;
// Register the CPU kernels of `matmul`: ops::MatMulKernel instantiated on
// paddle::platform::CPUDeviceContext for each listed element type.
// NOTE(review): this span looks like unified-diff residue. A float / double /
// float16 argument list is closed on the float16 line, and is then followed
// by a second `double` line that closes the call again. Presumably the
// float16 instantiation is being dropped so the list ends at `double` --
// confirm against the actual file before relying on either reading.
REGISTER_OP_CPU_KERNEL (
matmul , ops : : MatMulKernel < paddle : : platform : : CPUDeviceContext , float > ,
ops : : MatMulKernel < paddle : : platform : : CPUDeviceContext , double > ,
ops : : MatMulKernel < paddle : : platform : : CPUDeviceContext ,
paddle : : platform : : float16 > ) ;
ops : : MatMulKernel < paddle : : platform : : CPUDeviceContext , double > ) ;
// Register the CPU kernels of `matmul_grad`: ops::MatMulGradKernel
// instantiated on paddle::platform::CPUDeviceContext for each listed type.
// NOTE(review): same interleaving as the forward registration above -- the
// float16 lines and the trailing `double` closer cannot both be live code;
// confirm which element types are meant to remain registered.
REGISTER_OP_CPU_KERNEL (
matmul_grad ,
ops : : MatMulGradKernel < paddle : : platform : : CPUDeviceContext , float > ,
ops : : MatMulGradKernel < paddle : : platform : : CPUDeviceContext , double > ,
ops : : MatMulGradKernel < paddle : : platform : : CPUDeviceContext ,
paddle : : platform : : float16 > ) ;
ops : : MatMulGradKernel < paddle : : platform : : CPUDeviceContext , double > ) ;
# ifdef PADDLE_WITH_CUDA
REGISTER_OP_CUDA_KERNEL (