|
|
|
@ -485,6 +485,108 @@ TEST(JITKernel, lstmc1h1) {
|
|
|
|
|
TestLSTMKernel<jit::lstmc1h1, double, paddle::platform::CPUPlace>();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T, typename KernelTuples>
|
|
|
|
|
void TestGRUFunc(const typename KernelTuples::func_type tgt,
|
|
|
|
|
const std::vector<T>& xsrc, const std::vector<T>& ht_1,
|
|
|
|
|
const std::vector<T>& ht_ref,
|
|
|
|
|
const paddle::operators::jit::gru_attr_t& attr) {
|
|
|
|
|
EXPECT_TRUE(tgt != nullptr);
|
|
|
|
|
EXPECT_EQ(ht_1.size(), ht_ref.size());
|
|
|
|
|
EXPECT_EQ(xsrc.size(), 3 * ht_ref.size());
|
|
|
|
|
|
|
|
|
|
// x could be changed after compute, so copy to save src
|
|
|
|
|
int d = ht_ref.size();
|
|
|
|
|
std::vector<T> x(xsrc.size()), ht(ht_ref.size());
|
|
|
|
|
std::copy(xsrc.begin(), xsrc.end(), x.begin());
|
|
|
|
|
const T* ht_1_data = ht_1.data();
|
|
|
|
|
const T* ht_ref_data = ht_ref.data();
|
|
|
|
|
T* x_data = x.data();
|
|
|
|
|
T* ht_data = ht.data();
|
|
|
|
|
paddle::operators::jit::gru_t step;
|
|
|
|
|
step.gates = x_data;
|
|
|
|
|
step.ht_1 = ht_1_data;
|
|
|
|
|
step.ht = ht_data;
|
|
|
|
|
tgt(&step, &attr);
|
|
|
|
|
ExpectEQ<T>(ht_data, ht_ref_data, d);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
|
|
|
|
|
void TestGRUKernel() {
|
|
|
|
|
namespace jit = paddle::operators::jit;
|
|
|
|
|
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
|
|
|
|
|
std::vector<std::string> all_acts = {"sigmoid", "tanh", "relu", "identity"};
|
|
|
|
|
for (int d : TestSizes()) {
|
|
|
|
|
for (auto& act_gate : all_acts) {
|
|
|
|
|
for (auto& act_cand : all_acts) {
|
|
|
|
|
std::string info = act_gate + act_cand + "size_" + std::to_string(d);
|
|
|
|
|
const jit::gru_attr_t attr(d, jit::to_kerneltype(act_gate),
|
|
|
|
|
jit::to_kerneltype(act_cand));
|
|
|
|
|
auto ref = jit::GetRefer<KT, jit::GRUTuples<T>>();
|
|
|
|
|
EXPECT_TRUE(ref != nullptr);
|
|
|
|
|
std::vector<T> xsrc(3 * d), ht_1(d), ht_ref(d);
|
|
|
|
|
RandomVec<T>(3 * d, xsrc.data(), -2.f, 2.f);
|
|
|
|
|
RandomVec<T>(d, ht_1.data(), -2.f, 2.f);
|
|
|
|
|
// x could be changed after compute, so copy to save src
|
|
|
|
|
std::vector<T> x(xsrc.size());
|
|
|
|
|
std::copy(xsrc.begin(), xsrc.end(), x.begin());
|
|
|
|
|
const T* ht_1_data = ht_1.data();
|
|
|
|
|
T* x_data = x.data();
|
|
|
|
|
T* ht_ref_data = ht_ref.data();
|
|
|
|
|
jit::gru_t step;
|
|
|
|
|
step.gates = x_data;
|
|
|
|
|
step.ht_1 = ht_1_data;
|
|
|
|
|
step.ht = ht_ref_data;
|
|
|
|
|
ref(&step, &attr);
|
|
|
|
|
|
|
|
|
|
// test jitcode
|
|
|
|
|
auto jitcode = jit::GetJitCode<KT, jit::GRUTuples<T>, PlaceType>(attr);
|
|
|
|
|
if (jitcode) {
|
|
|
|
|
VLOG(10) << "Test Jitcode Kernel " << info;
|
|
|
|
|
TestGRUFunc<T, jit::GRUTuples<T>>(jitcode, xsrc, ht_1, ht_ref, attr);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// test all impls in more
|
|
|
|
|
jit::KernelKey kkey(KT, PlaceType());
|
|
|
|
|
auto& pool = jit::KernelPool().Instance().AllKernels();
|
|
|
|
|
auto iter = pool.find(kkey);
|
|
|
|
|
if (iter != pool.end()) {
|
|
|
|
|
auto& impls = iter->second;
|
|
|
|
|
for (auto& impl : impls) {
|
|
|
|
|
auto i = dynamic_cast<const jit::KernelImpl<jit::GRUTuples<T>>*>(
|
|
|
|
|
impl.get());
|
|
|
|
|
if (i && i->UseMe(attr)) {
|
|
|
|
|
auto more = i->GetFunc();
|
|
|
|
|
VLOG(10) << "Test More Kernel " << info;
|
|
|
|
|
TestGRUFunc<T, jit::GRUTuples<T>>(more, xsrc, ht_1, ht_ref, attr);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
// Test result from Get function
|
|
|
|
|
auto tgt = jit::Get<KT, jit::GRUTuples<T>, PlaceType>(attr);
|
|
|
|
|
TestGRUFunc<T, jit::GRUTuples<T>>(tgt, xsrc, ht_1, ht_ref, attr);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST(JITKernel, gruh1) {
|
|
|
|
|
namespace jit = paddle::operators::jit;
|
|
|
|
|
TestGRUKernel<jit::gruh1, float, paddle::platform::CPUPlace>();
|
|
|
|
|
TestGRUKernel<jit::gruh1, double, paddle::platform::CPUPlace>();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST(JITKernel, gruhtpart1) {
|
|
|
|
|
namespace jit = paddle::operators::jit;
|
|
|
|
|
TestGRUKernel<jit::gruhtpart1, float, paddle::platform::CPUPlace>();
|
|
|
|
|
TestGRUKernel<jit::gruhtpart1, double, paddle::platform::CPUPlace>();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST(JITKernel, gruhtpart2) {
|
|
|
|
|
namespace jit = paddle::operators::jit;
|
|
|
|
|
TestGRUKernel<jit::gruhtpart2, float, paddle::platform::CPUPlace>();
|
|
|
|
|
TestGRUKernel<jit::gruhtpart2, double, paddle::platform::CPUPlace>();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// TODO(TJ): refine the tests template
|
|
|
|
|
|
|
|
|
|
TEST(JITKernel, pool) {
|
|
|
|
|