@ -21,7 +21,7 @@ template <class T>
// GEMM helper built on Eigen tensor contraction: maps the raw A, B and C
// buffers as rank-2 tensors and stores/accumulates a.contract(b) into c.
//
// NOTE(review): this text looks like a unified-diff excerpt whose +/- markers
// were stripped. Several statements appear twice (presumably the removed and
// the added version side by side), and the `@ -x,y +x,y @@` lines mark places
// where intermediate source lines are not shown. Read it as a change
// description, not as a compilable unit.
struct EigenBlasGemm {
// Rank-2, row-major tensor view with int indices over memory owned elsewhere
// (TensorMap), declared with the Eigen::Aligned option.
typedef Eigen : : TensorMap < Eigen : : Tensor < T , 2 , Eigen : : RowMajor , int > ,
Eigen : : Aligned >
Matrix;
// NOTE(review): `Eigen Matrix` is presumably the renamed typedef (possibly
// `EigenMatrix` split by tokenization) -- confirm against the real file.
Eigen Matrix;
// Entry point; only the first two parameters are visible in this excerpt.
static void compute ( const bool transA ,
const bool transB ,
@ -56,14 +56,13 @@ struct EigenBlasGemm {
sizeB [ 1 ] = N ;
CHECK_EQ ( N , ldb ) ;
}
// First variant of the C shape: M x N, with a hard check that ldc equals N.
Eigen : : array < int , 2 > sizeC ;
sizeC [ 0 ] = M ;
sizeC [ 1 ] = N ;
CHECK_EQ ( N , ldc ) ;
// Second variant: C is mapped as M x ldc, and offsetC/extentC describe the
// M x N region starting at (0, 0) that is used for slicing further down.
Eigen : : array < int , 2 > sizeC = { { M , ldc } } ;
Eigen : : array < int , 2 > offsetC = { { 0 , 0 } } ;
Eigen : : array < int , 2 > extentC = { { M , N } } ;
// Wrap the raw pointers as tensor maps. A and B are cast to T* for the map
// constructor, but the resulting maps a and b are themselves declared const.
const Matrix a ( const_cast < T * > ( A ) , sizeA ) ;
const Matrix b ( const_cast < T * > ( B ) , sizeB ) ;
Matrix c ( C , sizeC ) ;
// Same three declarations using the `Eigen Matrix` spelling of the typedef.
const Eigen Matrix a ( const_cast < T * > ( A ) , sizeA ) ;
const Eigen Matrix b ( const_cast < T * > ( B ) , sizeB ) ;
Eigen Matrix c ( C , sizeC ) ;
// A single index pair tells contract() which dimension of a is summed against
// which dimension of b.
typedef typename Eigen : : Tensor < T , 2 > : : DimensionPair DimPair ;
Eigen : : array < DimPair , 1 > dims ;
@ -72,12 +71,23 @@ struct EigenBlasGemm {
// Contracted dimension of b: index 1 when transB is set, otherwise index 0.
dims [ 0 ] . second = transB ? 1 : 0 ;
Eigen : : DefaultDevice device ;
// Three cases on (alpha, beta): plain store for (1, 0), accumulate for (1, 1),
// and the general alpha * (a x b) + beta * c otherwise.
if ( alpha = = T ( 1 ) & & beta = = T ( 0 ) ) {
c . device ( device ) = a . contract ( b , dims ) ;
} else if ( alpha = = T ( 1 ) & & beta = = T ( 1 ) ) {
c . device ( device ) + = a . contract ( b , dims ) ;
// When N equals ldc the full tensor c is written directly, using the same
// three (alpha, beta) cases.
if ( N = = ldc ) {
if ( alpha = = T ( 1 ) & & beta = = T ( 0 ) ) {
c . device ( device ) = a . contract ( b , dims ) ;
} else if ( alpha = = T ( 1 ) & & beta = = T ( 1 ) ) {
c . device ( device ) + = a . contract ( b , dims ) ;
} else {
c . device ( device ) = alpha * a . contract ( b , dims ) + beta * c ;
}
} else {
c . device ( device ) = alpha * a . contract ( b , dims ) + beta * c ;
// Otherwise (N != ldc) every read and write of c goes through
// c.slice(offsetC, extentC), i.e. only the M x N region at the origin of the
// M x ldc map.
if ( alpha = = T ( 1 ) & & beta = = T ( 0 ) ) {
c . slice ( offsetC , extentC ) . device ( device ) = a . contract ( b , dims ) ;
} else if ( alpha = = T ( 1 ) & & beta = = T ( 1 ) ) {
c . slice ( offsetC , extentC ) . device ( device ) + = a . contract ( b , dims ) ;
} else {
c . slice ( offsetC , extentC ) . device ( device ) =
alpha * a . contract ( b , dims ) + beta * c . slice ( offsetC , extentC ) ;
}
}
}
} ;