|
|
|
@ -211,6 +211,24 @@ struct TestFuncWithRefer<jit::GRUTuples<T>, std::vector<T>, std::vector<T>,
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
struct TestFuncWithRefer<jit::SeqPoolTuples<T>, std::vector<T>,
|
|
|
|
|
std::vector<T>> {
|
|
|
|
|
void operator()(const typename jit::SeqPoolTuples<T>::func_type tgt,
|
|
|
|
|
const std::vector<T>& x, const std::vector<T>& yref,
|
|
|
|
|
const typename jit::SeqPoolTuples<T>::attr_type& attr) {
|
|
|
|
|
EXPECT_TRUE(tgt != nullptr);
|
|
|
|
|
EXPECT_EQ(x.size() % yref.size(), 0);
|
|
|
|
|
int w = yref.size();
|
|
|
|
|
std::vector<T> y(w);
|
|
|
|
|
const T* x_data = x.data();
|
|
|
|
|
const T* yref_data = yref.data();
|
|
|
|
|
T* y_data = y.data();
|
|
|
|
|
tgt(x_data, y_data, &attr);
|
|
|
|
|
ExpectEQ<T>(y_data, yref_data, w);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <paddle::operators::jit::KernelType KT, typename KernelTuples,
|
|
|
|
|
typename PlaceType, typename... Args>
|
|
|
|
|
void TestAllImpls(const typename KernelTuples::attr_type& attr, Args... args) {
|
|
|
|
@ -415,6 +433,30 @@ void TestGRUKernel() {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
|
|
|
|
|
void TestSeqPoolKernel() {
|
|
|
|
|
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
|
|
|
|
|
// TODO(TJ): support more
|
|
|
|
|
std::vector<jit::SeqPoolType> pool_types = {jit::SeqPoolType::kSum};
|
|
|
|
|
for (auto type : pool_types) {
|
|
|
|
|
for (int h : TestSizes()) {
|
|
|
|
|
for (int w : TestSizes()) {
|
|
|
|
|
const jit::seq_pool_attr_t attr(h, w, type);
|
|
|
|
|
auto ref = jit::GetRefer<KT, jit::SeqPoolTuples<T>>();
|
|
|
|
|
EXPECT_TRUE(ref != nullptr);
|
|
|
|
|
std::vector<T> x(h * w), yref(w);
|
|
|
|
|
RandomVec<T>(h * w, x.data(), -2.f, 2.f);
|
|
|
|
|
const T* x_data = x.data();
|
|
|
|
|
T* yref_data = yref.data();
|
|
|
|
|
ref(x_data, yref_data, &attr);
|
|
|
|
|
VLOG(10) << attr;
|
|
|
|
|
TestAllImpls<KT, jit::SeqPoolTuples<T>, PlaceType, std::vector<T>,
|
|
|
|
|
std::vector<T>>(attr, x, yref, attr);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
|
|
|
|
|
void TestNCHW16CMulNCKernel() {
|
|
|
|
|
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
|
|
|
|
@ -569,6 +611,12 @@ TEST(JITKernel, kGRUHtPart2) {
|
|
|
|
|
TestGRUKernel<jit::kGRUHtPart2, double, paddle::platform::CPUPlace>();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST(JITKernel, kSeqPool) {
|
|
|
|
|
namespace jit = paddle::operators::jit;
|
|
|
|
|
TestSeqPoolKernel<jit::kSeqPool, float, paddle::platform::CPUPlace>();
|
|
|
|
|
TestSeqPoolKernel<jit::kSeqPool, double, paddle::platform::CPUPlace>();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST(JITKernel, kNCHW16CMulNC) {
|
|
|
|
|
namespace jit = paddle::operators::jit;
|
|
|
|
|
TestNCHW16CMulNCKernel<jit::kNCHW16CMulNC, float,
|
|
|
|
|