|
|
|
@ -54,8 +54,63 @@ TEST(math_function, gemm_notrans_cblas) {
|
|
|
|
|
EXPECT_EQ(input3_ptr[6], 86);
|
|
|
|
|
EXPECT_EQ(input3_ptr[7], 99);
|
|
|
|
|
}
|
|
|
|
|
#ifdef PADDLE_WITH_LIBXSMM
|
|
|
|
|
template <typename T>
|
|
|
|
|
void MklSmmCompare(int m, int n, int k) {
|
|
|
|
|
paddle::framework::Tensor mat_a;
|
|
|
|
|
paddle::framework::Tensor mat_b;
|
|
|
|
|
paddle::framework::Tensor mat_c_smm;
|
|
|
|
|
paddle::framework::Tensor mat_c_mkl;
|
|
|
|
|
auto* cpu_place = new paddle::platform::CPUPlace();
|
|
|
|
|
|
|
|
|
|
T* A = mat_a.mutable_data<T>({m, k}, *cpu_place);
|
|
|
|
|
T* B = mat_b.mutable_data<T>({k, n}, *cpu_place);
|
|
|
|
|
T* CSMM = mat_c_smm.mutable_data<T>({m, n}, *cpu_place);
|
|
|
|
|
T* CMKL = mat_c_mkl.mutable_data<T>({m, n}, *cpu_place);
|
|
|
|
|
T alpha = static_cast<T>(1);
|
|
|
|
|
T beta = static_cast<T>(0);
|
|
|
|
|
for (int i = 0; i < mat_a.numel(); ++i) {
|
|
|
|
|
A[i] = static_cast<T>(i);
|
|
|
|
|
}
|
|
|
|
|
for (int i = 0; i < mat_b.numel(); ++i) {
|
|
|
|
|
B[i] = static_cast<T>(i);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto smm = [&, m, n, k, alpha, beta]() {
|
|
|
|
|
const char transa = 'N';
|
|
|
|
|
const char transb = 'N';
|
|
|
|
|
const int lda = m;
|
|
|
|
|
const int ldb = k;
|
|
|
|
|
const int ldc = m;
|
|
|
|
|
paddle::operators::math::CBlas<T>::SMM_GEMM(&transa, &transb, &m, &n, &k,
|
|
|
|
|
&alpha, A, &lda, B, &ldb, &beta,
|
|
|
|
|
CSMM, &ldc);
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
auto mkl = [&, m, n, k, alpha, beta]() {
|
|
|
|
|
int lda = k;
|
|
|
|
|
int ldb = n;
|
|
|
|
|
int ldc = n;
|
|
|
|
|
paddle::operators::math::CBlas<T>::GEMM(CblasRowMajor, CblasNoTrans,
|
|
|
|
|
CblasNoTrans, m, n, k, alpha, A,
|
|
|
|
|
lda, B, ldb, beta, CMKL, ldc);
|
|
|
|
|
};
|
|
|
|
|
smm();
|
|
|
|
|
mkl();
|
|
|
|
|
ASSERT_EQ(mat_c_mkl.numel(), mat_c_smm.numel());
|
|
|
|
|
for (int i = 0; i < mat_c_mkl.numel(); ++i) {
|
|
|
|
|
EXPECT_FLOAT_EQ(CSMM[i], CMKL[i]);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
TEST(math_function, gemm_mkl_vs_smm) {
|
|
|
|
|
MklSmmCompare<float>(1, 2, 3);
|
|
|
|
|
MklSmmCompare<double>(1, 2, 3);
|
|
|
|
|
MklSmmCompare<float>(3, 8, 5);
|
|
|
|
|
MklSmmCompare<double>(3, 8, 5);
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
TEST(math_function, gemm_trans_clbas) {
|
|
|
|
|
TEST(math_function, gemm_trans_cblas) {
|
|
|
|
|
paddle::framework::Tensor input1;
|
|
|
|
|
paddle::framework::Tensor input2;
|
|
|
|
|
paddle::framework::Tensor input3;
|
|
|
|
|