|
|
|
@ -75,26 +75,25 @@ void MklSmmCompare(int m, int n, int k) {
|
|
|
|
|
for (int i = 0; i < mat_b.numel(); ++i) {
|
|
|
|
|
B[i] = static_cast<T>(i);
|
|
|
|
|
}
|
|
|
|
|
// lda,ldb,ldc follow RowMajor
|
|
|
|
|
int lda = k;
|
|
|
|
|
int ldb = n;
|
|
|
|
|
int ldc = n;
|
|
|
|
|
|
|
|
|
|
auto smm = [&, m, n, k, alpha, beta]() {
|
|
|
|
|
auto smm = [&, m, n, k, lda, ldb, ldc, alpha, beta]() {
|
|
|
|
|
const char transa = 'N';
|
|
|
|
|
const char transb = 'N';
|
|
|
|
|
const int lda = m;
|
|
|
|
|
const int ldb = k;
|
|
|
|
|
const int ldc = m;
|
|
|
|
|
paddle::operators::math::CBlas<T>::SMM_GEMM(&transa, &transb, &n, &m, &k,
|
|
|
|
|
&alpha, B, &ldb, A, &lda, &beta,
|
|
|
|
|
CSMM, &ldc);
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
auto mkl = [&, m, n, k, alpha, beta]() {
|
|
|
|
|
int lda = k;
|
|
|
|
|
int ldb = n;
|
|
|
|
|
int ldc = n;
|
|
|
|
|
auto mkl = [&, m, n, k, lda, ldb, ldc, alpha, beta]() {
|
|
|
|
|
paddle::operators::math::CBlas<T>::GEMM(CblasRowMajor, CblasNoTrans,
|
|
|
|
|
CblasNoTrans, m, n, k, alpha, A,
|
|
|
|
|
lda, B, ldb, beta, CMKL, ldc);
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
smm();
|
|
|
|
|
mkl();
|
|
|
|
|
ASSERT_EQ(mat_c_mkl.numel(), mat_c_smm.numel());
|
|
|
|
@ -105,6 +104,8 @@ void MklSmmCompare(int m, int n, int k) {
|
|
|
|
|
TEST(math_function, gemm_mkl_vs_smm) {
|
|
|
|
|
MklSmmCompare<float>(1, 2, 3);
|
|
|
|
|
MklSmmCompare<double>(1, 2, 3);
|
|
|
|
|
MklSmmCompare<float>(3, 2, 1);
|
|
|
|
|
MklSmmCompare<double>(3, 2, 1);
|
|
|
|
|
MklSmmCompare<float>(3, 8, 5);
|
|
|
|
|
MklSmmCompare<double>(3, 8, 5);
|
|
|
|
|
}
|
|
|
|
|