@@ -583,33 +583,64 @@ void Blas<platform::CPUDeviceContext>::BatchedGEMM(
 template <>
 template <typename T>
 void Blas<platform::CPUDeviceContext>::BatchedGEMMWithHead(
-    CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
-    T alpha, const T *A, const T *B, T beta, T *C, int batchCount,
-    int64_t strideA, int64_t strideB, int64_t head_number) const {
-  int lda = (transA == CblasNoTrans) ? K : M;
-  int ldb = (transB == CblasNoTrans) ? N : K;
-  int ldc = N * head_number;
-  int sub_width = K / head_number;
+    CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int W1, int H1, int W2,
+    int H2, T alpha, const T *A, const T *B, T beta, T *C, int batchCount,
+    int64_t strideA, int64_t strideB, int64_t head_number,
+    bool split_b_vertical) const {
+  int lda = (transA == CblasNoTrans) ? W1 : H1;
+  int ldb = (transB == CblasNoTrans) ? W2 : H2;
   auto a_array = std::vector<const T *>(batchCount);
   auto b_array = std::vector<const T *>(batchCount);
   auto c_array = std::vector<T *>(batchCount);
-  for (int i = 0; i < head_number; i++) {
-    int sub_matA_offset = (transA == CblasNoTrans) ? i * (K / head_number)
-                                                   : i * (K / head_number) * M;
-    int sub_matB_offset = (transB == CblasNoTrans) ? i * (K / head_number) * N
-                                                   : i * (K / head_number);
-    int sub_matC_offset = i * N;
-    for (int k = 0; k < batchCount; ++k) {
-      a_array[k] = &A[k * strideA] + sub_matA_offset;
-      b_array[k] = &B[k * strideB] + sub_matB_offset;
-      c_array[k] = &C[k * M * head_number * N] + sub_matC_offset;
+  if (split_b_vertical) {
+    int ldc = W2;
+    int sub_width = W2 / head_number;
+    for (int i = 0; i < head_number; i++) {
+      int sub_matA_offset = (transA == CblasNoTrans)
+                                ? i * (W1 / head_number)
+                                : i * (W1 / head_number) * H1;
+      int sub_matB_offset = (transB == CblasNoTrans)
+                                ? i * (W2 / head_number)
+                                : i * (W2 / head_number) * H2;
+      int sub_matC_offset = i * W2 / head_number;
+      for (int k = 0; k < batchCount; ++k) {
+        a_array[k] = &A[k * strideA] + sub_matA_offset;
+        b_array[k] = &B[k * strideB] + sub_matB_offset;
+        c_array[k] = &C[k * H1 * W2] + sub_matC_offset;
+      }
+      CBlas<T>::GEMM_BATCH(CblasRowMajor, &transA, &transB, &H1, &sub_width,
+                           &H2, &alpha, a_array.data(), &lda, b_array.data(),
+                           &ldb, &beta, c_array.data(), &ldc,
+                           1 /* group_count */, &batchCount);
     }
-    CBlas<T>::GEMM_BATCH(CblasRowMajor, &transA, &transB, &M, &N, &sub_width,
-                         &alpha, a_array.data(), &lda, b_array.data(), &ldb,
-                         &beta, c_array.data(), &ldc, 1 /* group_count */,
-                         &batchCount);
+  } else {
+    PADDLE_ENFORCE_EQ(W1, H2);
+    int ldc = W2 * head_number;
+    int sub_width = W1 / head_number;
+    for (int i = 0; i < head_number; i++) {
+      int sub_matA_offset = (transA == CblasNoTrans)
+                                ? i * (W1 / head_number)
+                                : i * (W1 / head_number) * H1;
+      int sub_matB_offset = (transB == CblasNoTrans)
+                                ? i * (W1 / head_number) * W2
+                                : i * (W1 / head_number);
+      int sub_matC_offset = i * W2;
+      for (int k = 0; k < batchCount; ++k) {
+        a_array[k] = &A[k * strideA] + sub_matA_offset;
+        b_array[k] = &B[k * strideB] + sub_matB_offset;
+        c_array[k] = &C[k * H1 * head_number * W2] + sub_matC_offset;
+      }
+      CBlas<T>::GEMM_BATCH(CblasRowMajor, &transA, &transB, &H1, &W2,
+                           &sub_width, &alpha, a_array.data(), &lda,
+                           b_array.data(), &ldb, &beta, c_array.data(), &ldc,
+                           1 /* group_count */, &batchCount);
+    }
   }
 }
 #endif
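
For context on the pointer arithmetic above: each head multiplies one column block of A against one block of B, and the sub-matrices are addressed purely through offsets plus the leading dimensions of the full row-major buffers. Below is a minimal standalone sketch of the `split_b_vertical == false` branch, with `CBlas<T>::GEMM_BATCH` replaced by a naive triple loop. `naive_gemm` and `batched_gemm_with_head` are illustrative names invented for this sketch (not Paddle or MKL APIs), and only the `CblasNoTrans` case is covered.

```cpp
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <vector>

// C[M x N] = alpha * A[M x K] * B[K x N] + beta * C, row-major, with explicit
// leading dimensions so that A, B and C may be blocks of larger buffers.
static void naive_gemm(int M, int N, int K, float alpha, const float *A,
                       int lda, const float *B, int ldb, float beta, float *C,
                       int ldc) {
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float acc = 0.f;
      for (int k = 0; k < K; ++k) acc += A[m * lda + k] * B[k * ldb + n];
      C[m * ldc + n] = alpha * acc + beta * C[m * ldc + n];
    }
  }
}

// Mirrors the no-transpose, split_b_vertical == false branch: head i reads
// A's i-th column block [H1 x W1/h] and B's i-th row block [W1/h x W2], and
// writes the i-th column block of a [H1 x h*W2] output.
void batched_gemm_with_head(int W1, int H1, int W2, int H2, const float *A,
                            const float *B, float *C, int batchCount,
                            int64_t strideA, int64_t strideB,
                            int head_number) {
  assert(W1 == H2);                        // same check as PADDLE_ENFORCE_EQ
  const int lda = W1, ldb = W2, ldc = W2 * head_number;
  const int sub_width = W1 / head_number;  // assumes head_number divides W1
  for (int i = 0; i < head_number; ++i) {
    const int sub_matA_offset = i * sub_width;       // column block of A
    const int sub_matB_offset = i * sub_width * W2;  // row block of B
    const int sub_matC_offset = i * W2;              // column block of C
    for (int k = 0; k < batchCount; ++k) {
      naive_gemm(H1, W2, sub_width, 1.f, A + k * strideA + sub_matA_offset,
                 lda, B + k * strideB + sub_matB_offset, ldb, 0.f,
                 C + k * H1 * head_number * W2 + sub_matC_offset, ldc);
    }
  }
}

int main() {
  // Smoke test: A is [2 x 4], B is [4 x 3], 2 heads -> C is [2 x 6].
  const int H1 = 2, W1 = 4, H2 = 4, W2 = 3, heads = 2;
  std::vector<float> A(H1 * W1, 1.f), B(H2 * W2, 1.f), C(H1 * heads * W2, 0.f);
  batched_gemm_with_head(W1, H1, W2, H2, A.data(), B.data(), C.data(),
                         /*batchCount=*/1, /*strideA=*/0, /*strideB=*/0, heads);
  for (float v : C) std::printf("%.0f ", v);  // every entry is W1 / heads = 2
  std::printf("\n");
}
```

The detail the sketch makes explicit is that `lda`, `ldb` and `ldc` remain the widths of the full matrices (`W1`, `W2`, `W2 * head_number`) even though each per-head GEMM touches only a `W1 / head_number`-wide slice.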
@@ -690,51 +721,86 @@ void Blas<DeviceContext>::MatMul(const framework::Tensor &mat_a,
  * When a user calls this API, the multiplication of two big matrices is split
  * into multiplications of several (head_number_) small matrices. e.g. if Mat A
  * is [3, 24] and Mat B is [24, 4], when multiplying A and B with head_number as
- * 4, Mat A will be split into 4 matrices of [3, 6] and Mat B into 4 matrices
- * of [6, 4]. The final result will be 4 matrices of [3, 4], i.e. [3, 16].
- *
+ * 4, Mat A will be split into 4 matrices of [3, 6] and Mat B will be
+ * (horizontally) split into 4 matrices of [6, 4]. The final result will be
+ * 4 matrices of [3, 4], i.e. [3, 16].
+ * Another example: A is [3, 8], B is [2, 16], and head_number is 4. In this
+ * case, A will be split into 4 matrices of [3, 2] and B will be (vertically)
+ * split into 4 matrices of [2, 4]. The final result is 4 matrices of [3, 4],
+ * i.e. [3, 16].
  */
 template <typename DeviceContext>
 template <typename T>
-void Blas<DeviceContext>::MatMulWithHead(
-    const framework::Tensor &mat_a, const MatDescriptor &dim_a,
-    const framework::Tensor &mat_b, const MatDescriptor &dim_b, T alpha,
-    int head_number, framework::Tensor *mat_out, T beta) const {
-  PADDLE_ENFORCE_EQ(dim_a.width_, dim_b.height_);
+void Blas<DeviceContext>::MatMulWithHead(const framework::Tensor &mat_a,
+                                         const MatDescriptor &dim_a,
+                                         const framework::Tensor &mat_b,
+                                         const MatDescriptor &dim_b, T alpha,
+                                         int head_number,
+                                         framework::Tensor *mat_out, T beta,
+                                         bool mat_b_split_vertical) const {
   PADDLE_ENFORCE_EQ(dim_a.width_ % head_number, 0);
   PADDLE_ENFORCE_GE(head_number, 1);
   PADDLE_ENFORCE_LE(head_number, dim_a.width_);
   CBLAS_TRANSPOSE transA = !dim_a.trans_ ? CblasNoTrans : CblasTrans;
   CBLAS_TRANSPOSE transB = !dim_b.trans_ ? CblasNoTrans : CblasTrans;
+  if (mat_b_split_vertical) {
+    PADDLE_ENFORCE_EQ(dim_b.height_, dim_a.width_ / head_number);
+    PADDLE_ENFORCE_EQ(dim_b.width_ % head_number, 0);
+  }
   if (dim_a.batch_size_ == 0 && dim_b.batch_size_ == 0) {
+    int lda = !dim_a.trans_ ? dim_a.width_ : dim_a.height_;
+    int ldb = !dim_b.trans_ ? dim_b.width_ : dim_b.height_;
+    int sub_matA_offset;
+    int sub_matB_offset;
+    int sub_matC_offset;
+    int sub_mat_M = dim_a.height_;
+    int sub_mat_N;
+    int sub_mat_K;
+    int ldc;
     for (int i = 0; i < head_number; i++) {
-      int sub_matA_offset =
-          dim_a.trans_ ? i * (dim_a.width_ / head_number) * dim_a.height_
-                       : i * (dim_a.width_ / head_number);
-      int sub_matB_offset =
-          dim_b.trans_ ? i * (dim_b.height_ / head_number)
-                       : i * (dim_b.height_ / head_number) * dim_b.width_;
-      int sub_matC_offset = i * dim_b.width_;
-      int lda = !dim_a.trans_ ? dim_a.width_ : dim_a.height_;
-      int ldb = !dim_b.trans_ ? dim_b.width_ : dim_b.height_;
-      int ldc = head_number * dim_b.width_;
-      this->template GEMM<T>(transA, transB, dim_a.height_, dim_b.width_,
-                             dim_a.width_ / head_number, alpha,
-                             mat_a.data<T>() + sub_matA_offset, lda,
+      sub_matA_offset = dim_a.trans_
+                            ? i * (dim_a.width_ / head_number) * dim_a.height_
+                            : i * (dim_a.width_ / head_number);
+      if (mat_b_split_vertical) {
+        sub_matB_offset = dim_b.trans_
+                              ? i * (dim_b.width_ / head_number) * dim_b.height_
+                              : i * (dim_b.width_ / head_number);
+        sub_matC_offset = i * dim_b.width_ / head_number;
+        sub_mat_N = dim_b.width_ / head_number;
+        sub_mat_K = dim_b.height_;
+        ldc = dim_b.width_;
+      } else {
+        sub_matB_offset =
+            dim_b.trans_ ? i * (dim_b.height_ / head_number)
+                         : i * (dim_b.height_ / head_number) * dim_b.width_;
+        sub_matC_offset = i * dim_b.width_;
+        sub_mat_N = dim_b.width_;
+        sub_mat_K = dim_a.width_ / head_number;
+        ldc = head_number * dim_b.width_;
+      }
+      this->template GEMM<T>(transA, transB, sub_mat_M, sub_mat_N, sub_mat_K,
+                             alpha, mat_a.data<T>() + sub_matA_offset, lda,
                              mat_b.data<T>() + sub_matB_offset, ldb, beta,
                              mat_out->data<T>() + sub_matC_offset, ldc);
     }
   } else {
-    PADDLE_ENFORCE(dim_a.batch_size_ == dim_b.batch_size_ ||
-                   dim_a.batch_size_ == 0 || dim_b.batch_size_ == 0);
+    PADDLE_ENFORCE_EQ((dim_a.batch_size_ == dim_b.batch_size_ ||
+                       dim_a.batch_size_ == 0 || dim_b.batch_size_ == 0),
+                      true);
     this->template BatchedGEMMWithHead<T>(
-        transA, transB, dim_a.height_, dim_b.width_, dim_a.width_, alpha,
-        mat_a.data<T>(), mat_b.data<T>(), beta, mat_out->data<T>(),
+        transA, transB, dim_a.width_, dim_a.height_, dim_b.width_,
+        dim_b.height_, alpha, mat_a.data<T>(), mat_b.data<T>(), beta,
+        mat_out->data<T>(),
         dim_a.batch_size_ == 0 ? dim_b.batch_size_ : dim_a.batch_size_,
-        dim_a.stride_, dim_b.stride_, head_number);
+        dim_a.stride_, dim_b.stride_, head_number, mat_b_split_vertical);
   }
 }
 #endif
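
The two worked examples in the comment above can be checked end to end with a short self-contained program. The sketch below mirrors the no-transpose, non-batched path of `MatMulWithHead` for both split modes; `gemm` and `matmul_with_head` are illustrative names invented for this sketch, not the Paddle API.

```cpp
#include <cassert>
#include <cstdio>
#include <vector>

// Row-major C[M x N] = A[M x K] * B[K x N], with leading dimensions so the
// three operands may be blocks of larger matrices.
static void gemm(int M, int N, int K, const float *A, int lda, const float *B,
                 int ldb, float *C, int ldc) {
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float acc = 0.f;
      for (int k = 0; k < K; ++k) acc += A[m * lda + k] * B[k * ldb + n];
      C[m * ldc + n] = acc;
    }
  }
}

// No-transpose analogue of MatMulWithHead's single-matrix branch. A is
// [ha x wa]; C is [ha x head_number * wb] horizontally, [ha x wb] vertically.
void matmul_with_head(int ha, int wa, const float *A, int hb, int wb,
                      const float *B, int head_number, float *C,
                      bool mat_b_split_vertical) {
  assert(wa % head_number == 0);
  for (int i = 0; i < head_number; ++i) {
    if (mat_b_split_vertical) {  // B is [wa/h x wb], sliced by columns
      assert(hb == wa / head_number && wb % head_number == 0);
      gemm(ha, wb / head_number, hb, A + i * (wa / head_number), wa,
           B + i * (wb / head_number), wb, C + i * (wb / head_number), wb);
    } else {  // B is [wa x wb], sliced by rows
      assert(hb == wa);
      gemm(ha, wb, wa / head_number, A + i * (wa / head_number), wa,
           B + i * (wa / head_number) * wb, wb, C + i * wb,
           head_number * wb);
    }
  }
}

int main() {
  // Example 1: A [3 x 24] * B [24 x 4] with 4 heads -> C [3 x 16].
  std::vector<float> A1(3 * 24, 1.f), B1(24 * 4, 1.f), C1(3 * 16, 0.f);
  matmul_with_head(3, 24, A1.data(), 24, 4, B1.data(), 4, C1.data(), false);
  std::printf("horizontal: C[0] = %.0f (expect 6)\n", C1[0]);

  // Example 2: A [3 x 8] * B [2 x 16] with 4 heads -> C [3 x 16].
  std::vector<float> A2(3 * 8, 1.f), B2(2 * 16, 1.f), C2(3 * 16, 0.f);
  matmul_with_head(3, 8, A2.data(), 2, 16, B2.data(), 4, C2.data(), true);
  std::printf("vertical:   C[0] = %.0f (expect 2)\n", C2[0]);
}
```

The two modes differ only in which dimension supplies the per-head K: the horizontal mode slices B by rows (K = wa / head_number) and concatenates head outputs into `head_number * wb` columns, while the vertical mode slices B by columns (K = hb) and writes each head into its own `wb / head_number`-column slice.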