|
|
|
@ -250,6 +250,106 @@ TEST(JITKernel, vaddbias) {
|
|
|
|
|
TestAXYNKernel<jit::vaddbias, double, paddle::platform::CPUPlace>();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T, typename KernelTuples>
|
|
|
|
|
void TestXYNFunc(const typename KernelTuples::func_type tgt,
|
|
|
|
|
const std::vector<T>& x, const std::vector<T>& yref) {
|
|
|
|
|
EXPECT_TRUE(tgt != nullptr);
|
|
|
|
|
EXPECT_EQ(yref.size(), x.size());
|
|
|
|
|
const T* x_data = x.data();
|
|
|
|
|
const T* yref_data = yref.data();
|
|
|
|
|
const int d = yref.size();
|
|
|
|
|
std::vector<T> ytgt(d);
|
|
|
|
|
T* ytgt_data = ytgt.data();
|
|
|
|
|
// test normal
|
|
|
|
|
tgt(x_data, ytgt_data, d);
|
|
|
|
|
ExpectEQ<T>(ytgt_data, yref_data, d);
|
|
|
|
|
// test inplace x
|
|
|
|
|
std::copy(x.begin(), x.end(), ytgt.begin());
|
|
|
|
|
tgt(ytgt_data, ytgt_data, d);
|
|
|
|
|
ExpectEQ<T>(ytgt_data, yref_data, d);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
|
|
|
|
|
void TestXYNKernel() {
|
|
|
|
|
namespace jit = paddle::operators::jit;
|
|
|
|
|
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
|
|
|
|
|
for (int d : TestSizes()) {
|
|
|
|
|
auto ref = jit::GetRefer<KT, jit::XYNTuples<T>>();
|
|
|
|
|
EXPECT_TRUE(ref != nullptr);
|
|
|
|
|
|
|
|
|
|
std::vector<T> x(d), yref(d);
|
|
|
|
|
std::vector<T> xinp(d); // inplace test
|
|
|
|
|
RandomVec<T>(d, x.data());
|
|
|
|
|
std::copy(x.begin(), x.end(), xinp.begin());
|
|
|
|
|
|
|
|
|
|
const T* x_data = x.data();
|
|
|
|
|
T* yref_data = yref.data();
|
|
|
|
|
T* xinp_data = xinp.data();
|
|
|
|
|
// test refer code inplace
|
|
|
|
|
ref(x_data, yref_data, d);
|
|
|
|
|
ref(xinp_data, xinp_data, d);
|
|
|
|
|
ExpectEQ<T>(xinp_data, yref_data, d);
|
|
|
|
|
|
|
|
|
|
// test jitcode
|
|
|
|
|
auto jitcode = jit::GetJitCode<KT, jit::XYNTuples<T>, PlaceType>(d);
|
|
|
|
|
if (jitcode) {
|
|
|
|
|
VLOG(10) << "Test Jitcode Kernel, size: " << d;
|
|
|
|
|
TestXYNFunc<T, jit::XYNTuples<T>>(jitcode, x, yref);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// test all impls in more
|
|
|
|
|
jit::KernelKey kkey(KT, PlaceType());
|
|
|
|
|
auto& pool = jit::KernelPool().Instance().AllKernels();
|
|
|
|
|
auto iter = pool.find(kkey);
|
|
|
|
|
if (iter != pool.end()) {
|
|
|
|
|
auto& impls = iter->second;
|
|
|
|
|
for (auto& impl : impls) {
|
|
|
|
|
auto i =
|
|
|
|
|
dynamic_cast<const jit::KernelImpl<jit::XYNTuples<T>>*>(impl.get());
|
|
|
|
|
if (i && i->UseMe(d)) {
|
|
|
|
|
auto more = i->GetFunc();
|
|
|
|
|
VLOG(10) << "Test More Kernel, size: " << d;
|
|
|
|
|
TestXYNFunc<T, jit::XYNTuples<T>>(more, x, yref);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
// Test result from Get function
|
|
|
|
|
VLOG(10) << "Test Get function, size: " << d;
|
|
|
|
|
auto tgt = jit::Get<KT, jit::XYNTuples<T>, PlaceType>(d);
|
|
|
|
|
TestXYNFunc<T, jit::XYNTuples<T>>(tgt, x, yref);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST(JITKernel, vrelu) {
|
|
|
|
|
namespace jit = paddle::operators::jit;
|
|
|
|
|
TestXYNKernel<jit::vrelu, float, paddle::platform::CPUPlace>();
|
|
|
|
|
TestXYNKernel<jit::vrelu, double, paddle::platform::CPUPlace>();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST(JITKernel, videntity) {
|
|
|
|
|
namespace jit = paddle::operators::jit;
|
|
|
|
|
TestXYNKernel<jit::videntity, float, paddle::platform::CPUPlace>();
|
|
|
|
|
TestXYNKernel<jit::videntity, double, paddle::platform::CPUPlace>();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST(JITKernel, vexp) {
|
|
|
|
|
namespace jit = paddle::operators::jit;
|
|
|
|
|
TestXYNKernel<jit::vexp, float, paddle::platform::CPUPlace>();
|
|
|
|
|
TestXYNKernel<jit::vexp, double, paddle::platform::CPUPlace>();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST(JITKernel, vsigmoid) {
|
|
|
|
|
namespace jit = paddle::operators::jit;
|
|
|
|
|
TestXYNKernel<jit::vsigmoid, float, paddle::platform::CPUPlace>();
|
|
|
|
|
TestXYNKernel<jit::vsigmoid, double, paddle::platform::CPUPlace>();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST(JITKernel, vtanh) {
|
|
|
|
|
namespace jit = paddle::operators::jit;
|
|
|
|
|
TestXYNKernel<jit::vtanh, float, paddle::platform::CPUPlace>();
|
|
|
|
|
TestXYNKernel<jit::vtanh, double, paddle::platform::CPUPlace>();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST(JITKernel, pool) {
|
|
|
|
|
// TODO(TJ): add some test
|
|
|
|
|
}
|
|
|
|
|