|
|
|
@ -229,6 +229,25 @@ struct TestFuncWithRefer<jit::SeqPoolTuples<T>, std::vector<T>,
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
struct TestFuncWithRefer<jit::MatMulTuples<T>, std::vector<T>, std::vector<T>> {
|
|
|
|
|
void operator()(const typename jit::MatMulTuples<T>::func_type tgt,
|
|
|
|
|
const std::vector<T>& a, const std::vector<T>& b,
|
|
|
|
|
const std::vector<T>& cref, int m, int n, int k) {
|
|
|
|
|
EXPECT_TRUE(tgt != nullptr);
|
|
|
|
|
EXPECT_EQ(a.size(), static_cast<size_t>(m * k));
|
|
|
|
|
EXPECT_EQ(b.size(), static_cast<size_t>(k * n));
|
|
|
|
|
EXPECT_EQ(cref.size(), static_cast<size_t>(m * n));
|
|
|
|
|
std::vector<T> c(cref.size());
|
|
|
|
|
const T* a_data = a.data();
|
|
|
|
|
const T* b_data = b.data();
|
|
|
|
|
const T* cref_data = cref.data();
|
|
|
|
|
T* c_data = c.data();
|
|
|
|
|
tgt(a_data, b_data, c_data, m, n, k);
|
|
|
|
|
ExpectEQ<T>(c_data, cref_data, m * n);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <paddle::operators::jit::KernelType KT, typename KernelTuples,
|
|
|
|
|
typename PlaceType, typename... Args>
|
|
|
|
|
void TestAllImpls(const typename KernelTuples::attr_type& attr, Args... args) {
|
|
|
|
@ -458,6 +477,28 @@ void TestSeqPoolKernel() {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
|
|
|
|
|
void TestMatMulKernel() {
|
|
|
|
|
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
|
|
|
|
|
for (int m : {1, 2, 3, 4}) {
|
|
|
|
|
for (int n : {1, 2, 3, 4}) {
|
|
|
|
|
for (int k : TestSizes()) {
|
|
|
|
|
auto ref = jit::GetRefer<KT, jit::MatMulTuples<T>>();
|
|
|
|
|
EXPECT_TRUE(ref != nullptr);
|
|
|
|
|
std::vector<T> a(m * k), b(k * n), c(m * n);
|
|
|
|
|
RandomVec<T>(m * k, a.data(), -2.f, 2.f);
|
|
|
|
|
RandomVec<T>(k * n, b.data(), -2.f, 2.f);
|
|
|
|
|
const T* a_data = a.data();
|
|
|
|
|
const T* b_data = b.data();
|
|
|
|
|
T* c_data = c.data();
|
|
|
|
|
ref(a_data, b_data, c_data, m, n, k);
|
|
|
|
|
TestAllImpls<KT, jit::MatMulTuples<T>, PlaceType, std::vector<T>,
|
|
|
|
|
std::vector<T>, std::vector<T>>(k, a, b, c, m, n, k);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
|
|
|
|
|
void TestNCHW16CMulNCKernel() {
|
|
|
|
|
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
|
|
|
|
@ -618,6 +659,12 @@ TEST(JITKernel, kSeqPool) {
|
|
|
|
|
TestSeqPoolKernel<jit::kSeqPool, double, paddle::platform::CPUPlace>();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST(JITKernel, kMatMul) {
|
|
|
|
|
namespace jit = paddle::operators::jit;
|
|
|
|
|
TestMatMulKernel<jit::kMatMul, float, paddle::platform::CPUPlace>();
|
|
|
|
|
TestMatMulKernel<jit::kMatMul, double, paddle::platform::CPUPlace>();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST(JITKernel, kNCHW16CMulNC) {
|
|
|
|
|
namespace jit = paddle::operators::jit;
|
|
|
|
|
TestNCHW16CMulNCKernel<jit::kNCHW16CMulNC, float,
|
|
|
|
|