|
|
|
@ -228,3 +228,57 @@ TEST(math_funciton, set_constant) {
|
|
|
|
|
}
|
|
|
|
|
delete ctx;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
void GemmWarpTest(int m, int n, int k, T alpha, T beta) {
|
|
|
|
|
paddle::framework::Tensor mat_a;
|
|
|
|
|
paddle::framework::Tensor mat_b;
|
|
|
|
|
paddle::framework::Tensor mat_c_ref;
|
|
|
|
|
paddle::framework::Tensor mat_c_mkl;
|
|
|
|
|
auto* cpu_place = new paddle::platform::CPUPlace();
|
|
|
|
|
|
|
|
|
|
T* A = mat_a.mutable_data<T>({m, k}, *cpu_place);
|
|
|
|
|
T* B = mat_b.mutable_data<T>({k, n}, *cpu_place);
|
|
|
|
|
T* CREF = mat_c_ref.mutable_data<T>({m, n}, *cpu_place);
|
|
|
|
|
T* CMKL = mat_c_mkl.mutable_data<T>({m, n}, *cpu_place);
|
|
|
|
|
|
|
|
|
|
ASSERT_EQ(mat_c_mkl.numel(), mat_c_ref.numel());
|
|
|
|
|
for (int i = 0; i < mat_a.numel(); ++i) {
|
|
|
|
|
A[i] = static_cast<T>(i);
|
|
|
|
|
}
|
|
|
|
|
for (int i = 0; i < mat_b.numel(); ++i) {
|
|
|
|
|
B[i] = static_cast<T>(i + 1);
|
|
|
|
|
}
|
|
|
|
|
for (int i = 0; i < mat_c_ref.numel(); ++i) {
|
|
|
|
|
CREF[i] = static_cast<T>(i + 2);
|
|
|
|
|
CMKL[i] = CREF[i];
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// this would call gemm_warp
|
|
|
|
|
paddle::platform::CPUDeviceContext context(*cpu_place);
|
|
|
|
|
GetBlas<T>(context).GEMM(CblasNoTrans, CblasNoTrans, m, n, k, alpha, A, B,
|
|
|
|
|
beta, CREF);
|
|
|
|
|
|
|
|
|
|
// lda,ldb,ldc follow RowMajor
|
|
|
|
|
int lda = k;
|
|
|
|
|
int ldb = n;
|
|
|
|
|
int ldc = n;
|
|
|
|
|
paddle::operators::math::CBlas<T>::GEMM(CblasRowMajor, CblasNoTrans,
|
|
|
|
|
CblasNoTrans, m, n, k, alpha, A, lda,
|
|
|
|
|
B, ldb, beta, CMKL, ldc);
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < mat_c_mkl.numel(); ++i) {
|
|
|
|
|
EXPECT_FLOAT_EQ(CREF[i], CMKL[i]);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Runs GemmWarpTest for float and then double over two problem sizes, each
// with (alpha, beta) = (1, 0) followed by (alpha, beta) = (2, 1).
TEST(math_function, gemm_warp) {
  // Each row is one {m, n, k} problem size.
  const int kShapes[2][3] = {{3, 2, 5}, {8, 5, 6}};

  // Single precision.
  for (const auto& dims : kShapes) {
    GemmWarpTest<float>(dims[0], dims[1], dims[2], 1.f, 0.f);
    GemmWarpTest<float>(dims[0], dims[1], dims[2], 2.f, 1.f);
  }

  // Double precision.
  for (const auto& dims : kShapes) {
    GemmWarpTest<double>(dims[0], dims[1], dims[2], 1.0, 0.0);
    GemmWarpTest<double>(dims[0], dims[1], dims[2], 2.0, 1.0);
  }
}
|
|
|
|
|