|
|
|
@ -350,6 +350,143 @@ TEST(JITKernel, vtanh) {
|
|
|
|
|
TestXYNKernel<jit::vtanh, double, paddle::platform::CPUPlace>();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T, typename KernelTuples>
|
|
|
|
|
void TestLSTMFunc(const typename KernelTuples::func_type tgt,
|
|
|
|
|
const std::vector<T>& xsrc, const std::vector<T>& wp,
|
|
|
|
|
const std::vector<T>& ct_1, const std::vector<T>& ct_ref,
|
|
|
|
|
const std::vector<T>& ht_ref,
|
|
|
|
|
const paddle::operators::jit::lstm_attr_t& attr) {
|
|
|
|
|
EXPECT_TRUE(tgt != nullptr);
|
|
|
|
|
EXPECT_EQ(ct_ref.size(), ht_ref.size());
|
|
|
|
|
EXPECT_EQ(ct_1.size(), ht_ref.size());
|
|
|
|
|
EXPECT_EQ(xsrc.size(), 4 * ht_ref.size());
|
|
|
|
|
EXPECT_EQ(wp.size(), 3 * ht_ref.size());
|
|
|
|
|
|
|
|
|
|
// x could be changed after compute, so copy to save src
|
|
|
|
|
int d = ht_ref.size();
|
|
|
|
|
std::vector<T> x(xsrc.size()), ct(ct_ref.size()), ht(ht_ref.size());
|
|
|
|
|
std::vector<T> checked(2 * d);
|
|
|
|
|
std::copy(xsrc.begin(), xsrc.end(), x.begin());
|
|
|
|
|
|
|
|
|
|
const T* ct_1_data = ct_1.data();
|
|
|
|
|
const T* wp_data = wp.data();
|
|
|
|
|
const T* ct_ref_data = ct_ref.data();
|
|
|
|
|
const T* ht_ref_data = ht_ref.data();
|
|
|
|
|
T* x_data = x.data();
|
|
|
|
|
T* ct_data = ct.data();
|
|
|
|
|
T* ht_data = ht.data();
|
|
|
|
|
T* checked_data = checked.data();
|
|
|
|
|
|
|
|
|
|
paddle::operators::jit::lstm_t step;
|
|
|
|
|
step.gates = x_data;
|
|
|
|
|
step.ct_1 = ct_1_data;
|
|
|
|
|
step.ct = ct_data;
|
|
|
|
|
step.ht = ht_data;
|
|
|
|
|
if (attr.use_peephole) {
|
|
|
|
|
step.wp = wp_data;
|
|
|
|
|
step.checked = checked_data;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
tgt(&step, &attr);
|
|
|
|
|
ExpectEQ<T>(ct_data, ct_ref_data, d);
|
|
|
|
|
ExpectEQ<T>(ht_data, ht_ref_data, d);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
|
|
|
|
|
void TestLSTMKernel() {
|
|
|
|
|
namespace jit = paddle::operators::jit;
|
|
|
|
|
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
|
|
|
|
|
std::vector<std::string> all_acts = {"sigmoid", "tanh", "relu", "identity"};
|
|
|
|
|
for (int d : TestSizes()) {
|
|
|
|
|
for (bool use_peephole : {true, false}) {
|
|
|
|
|
for (auto& act_gate : all_acts) {
|
|
|
|
|
for (auto& act_cand : all_acts) {
|
|
|
|
|
for (auto& act_cell : all_acts) {
|
|
|
|
|
std::string info = act_gate + act_cand + act_cell +
|
|
|
|
|
(use_peephole ? "peephole_" : "") + "size_" +
|
|
|
|
|
std::to_string(d);
|
|
|
|
|
const jit::lstm_attr_t attr(
|
|
|
|
|
d, jit::to_kerneltype(act_gate), jit::to_kerneltype(act_cand),
|
|
|
|
|
jit::to_kerneltype(act_cell), use_peephole);
|
|
|
|
|
auto ref = jit::GetRefer<KT, jit::LSTMTuples<T>>();
|
|
|
|
|
EXPECT_TRUE(ref != nullptr);
|
|
|
|
|
std::vector<T> xsrc(4 * d), wp(3 * d), ct_1(d);
|
|
|
|
|
std::vector<T> ct_ref(d), ht_ref(d), checked(2 * d);
|
|
|
|
|
RandomVec<T>(4 * d, xsrc.data(), -2.f, 2.f);
|
|
|
|
|
RandomVec<T>(3 * d, wp.data(), -2.f, 2.f);
|
|
|
|
|
RandomVec<T>(d, ct_1.data(), -2.f, 2.f);
|
|
|
|
|
// x could be changed after compute, so copy to save src
|
|
|
|
|
std::vector<T> x(xsrc.size());
|
|
|
|
|
std::copy(xsrc.begin(), xsrc.end(), x.begin());
|
|
|
|
|
const T* ct_1_data = ct_1.data();
|
|
|
|
|
const T* wp_data = wp.data();
|
|
|
|
|
T* x_data = x.data();
|
|
|
|
|
T* checked_data = checked.data();
|
|
|
|
|
T* ct_ref_data = ct_ref.data();
|
|
|
|
|
T* ht_ref_data = ht_ref.data();
|
|
|
|
|
jit::lstm_t step;
|
|
|
|
|
step.gates = x_data;
|
|
|
|
|
step.ct_1 = ct_1_data;
|
|
|
|
|
step.ct = ct_ref_data;
|
|
|
|
|
step.ht = ht_ref_data;
|
|
|
|
|
if (use_peephole) {
|
|
|
|
|
step.wp = wp_data;
|
|
|
|
|
step.checked = checked_data;
|
|
|
|
|
}
|
|
|
|
|
ref(&step, &attr);
|
|
|
|
|
|
|
|
|
|
// test jitcode
|
|
|
|
|
auto jitcode =
|
|
|
|
|
jit::GetJitCode<KT, jit::LSTMTuples<T>, PlaceType>(attr);
|
|
|
|
|
if (jitcode) {
|
|
|
|
|
VLOG(10) << "Test Jitcode Kernel " << info;
|
|
|
|
|
TestLSTMFunc<T, jit::LSTMTuples<T>>(jitcode, xsrc, wp, ct_1,
|
|
|
|
|
ct_ref, ht_ref, attr);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// test all impls in more
|
|
|
|
|
jit::KernelKey kkey(KT, PlaceType());
|
|
|
|
|
auto& pool = jit::KernelPool().Instance().AllKernels();
|
|
|
|
|
auto iter = pool.find(kkey);
|
|
|
|
|
if (iter != pool.end()) {
|
|
|
|
|
auto& impls = iter->second;
|
|
|
|
|
for (auto& impl : impls) {
|
|
|
|
|
auto i =
|
|
|
|
|
dynamic_cast<const jit::KernelImpl<jit::LSTMTuples<T>>*>(
|
|
|
|
|
impl.get());
|
|
|
|
|
if (i && i->UseMe(attr)) {
|
|
|
|
|
auto more = i->GetFunc();
|
|
|
|
|
VLOG(10) << "Test More Kernel " << info;
|
|
|
|
|
TestLSTMFunc<T, jit::LSTMTuples<T>>(more, xsrc, wp, ct_1,
|
|
|
|
|
ct_ref, ht_ref, attr);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
// Test result from Get function
|
|
|
|
|
auto tgt = jit::Get<KT, jit::LSTMTuples<T>, PlaceType>(attr);
|
|
|
|
|
TestLSTMFunc<T, jit::LSTMTuples<T>>(tgt, xsrc, wp, ct_1, ct_ref,
|
|
|
|
|
ht_ref, attr);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST(JITKernel, lstmctht) {
|
|
|
|
|
namespace jit = paddle::operators::jit;
|
|
|
|
|
TestLSTMKernel<jit::lstmctht, float, paddle::platform::CPUPlace>();
|
|
|
|
|
TestLSTMKernel<jit::lstmctht, double, paddle::platform::CPUPlace>();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST(JITKernel, lstmc1h1) {
|
|
|
|
|
namespace jit = paddle::operators::jit;
|
|
|
|
|
TestLSTMKernel<jit::lstmc1h1, float, paddle::platform::CPUPlace>();
|
|
|
|
|
TestLSTMKernel<jit::lstmc1h1, double, paddle::platform::CPUPlace>();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// TODO(TJ): refine the tests template
|
|
|
|
|
|
|
|
|
|
TEST(JITKernel, pool) {
|
|
|
|
|
// TODO(TJ): add some test
|
|
|
|
|
}
|
|
|
|
|