|
|
|
@ -328,6 +328,123 @@ TEST(JitKernel, vtanh) {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void lstm_ctht_ref(
|
|
|
|
|
const std::shared_ptr<
|
|
|
|
|
const paddle::operators::math::jitkernel::VSigmoidKernel<float>>&
|
|
|
|
|
vsigmoid_3d,
|
|
|
|
|
const std::shared_ptr<
|
|
|
|
|
const paddle::operators::math::jitkernel::VTanhKernel<float>>& vtanh_d,
|
|
|
|
|
const std::shared_ptr<
|
|
|
|
|
const paddle::operators::math::jitkernel::VExpKernel<float>>& vexp_1,
|
|
|
|
|
const int d, float* gates, const float* ct_1, float* ct, float* ht) {
|
|
|
|
|
vsigmoid_3d->Compute(gates + d, gates + d);
|
|
|
|
|
vtanh_d->Compute(gates, gates);
|
|
|
|
|
const float *i = gates + d, *f = gates + d * 2, *o = gates + d * 3;
|
|
|
|
|
const float min = SIGMOID_THRESHOLD_MIN;
|
|
|
|
|
const float max = SIGMOID_THRESHOLD_MAX;
|
|
|
|
|
for (int k = 0; k < d; ++k) {
|
|
|
|
|
// C_t = C_t-1 * fgated + cand_gated * igated
|
|
|
|
|
ct[k] = ct_1[k] * f[k] + gates[k] * i[k];
|
|
|
|
|
// H_t = act_cell(C_t) * ogated
|
|
|
|
|
float tmp = ct[k] * 2;
|
|
|
|
|
tmp = 0.f - ((tmp < min) ? min : ((tmp > max) ? max : tmp));
|
|
|
|
|
vexp_1->Compute(&tmp, &tmp);
|
|
|
|
|
tmp = 2.f / (1.f + tmp) - 1.f;
|
|
|
|
|
ht[k] = tmp * o[k];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void lstm_ctht_better(
|
|
|
|
|
const std::shared_ptr<
|
|
|
|
|
const paddle::operators::math::jitkernel::VSigmoidKernel<float>>&
|
|
|
|
|
vsigmoid_3d,
|
|
|
|
|
const std::shared_ptr<
|
|
|
|
|
const paddle::operators::math::jitkernel::VTanhKernel<float>>& vtanh_d,
|
|
|
|
|
const std::shared_ptr<
|
|
|
|
|
const paddle::operators::math::jitkernel::VMulKernel<float>>& vmul_d,
|
|
|
|
|
const std::shared_ptr<
|
|
|
|
|
const paddle::operators::math::jitkernel::VAddKernel<float>>& vadd_d,
|
|
|
|
|
const int d, float* gates, const float* ct_1, float* ct, float* ht) {
|
|
|
|
|
int d2 = d * 2;
|
|
|
|
|
vsigmoid_3d->Compute(gates + d, gates + d);
|
|
|
|
|
vtanh_d->Compute(gates, gates);
|
|
|
|
|
vmul_d->Compute(gates, gates + d, gates + d);
|
|
|
|
|
vmul_d->Compute(ct_1, gates + d2, gates + d2);
|
|
|
|
|
vadd_d->Compute(gates + d, gates + d2, ct);
|
|
|
|
|
/* H_t = act_cell(C_t) * ogated */
|
|
|
|
|
vtanh_d->Compute(ct, gates + d2);
|
|
|
|
|
vmul_d->Compute(gates + d2, gates + d * 3, ht);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// For each vector size: checks LSTMKernel<float>::ComputeCtHt against
// lstm_ctht_ref on random inputs (tolerance 1e-3), then times the kernel-only
// variant, the scalar reference and the target kernel over `repeat`
// iterations and logs the per-call averages at VLOG(3).
// NOTE(review): `repeat`, RandomVec and GetCurrentUS are defined outside this
// block.
TEST(JitKernel, lstm) {
  namespace jit = paddle::operators::math::jitkernel;
  for (int d : {7, 8, 15, 16, 30, 32, 64, 100}) {
    const int d4 = d * 4;
    const int d3 = d * 3;

    // Random gate pre-activations and previous cell state.
    std::vector<float> x(d4);
    std::vector<float> ct_1(d), ct_tgt(d), ht_tgt(d);
    std::vector<float> ct_ref(d), ht_ref(d);
    RandomVec<float>(d4, x.data(), -2.f, 2.f);
    RandomVec<float>(d, ct_1.data(), -2.f, 2.f);
    // The reference mutates its gate buffer, so it works on its own copy.
    std::vector<float> xref(x);

    std::string act_gate = "sigmoid", act_cand = "tanh", act_cell = "tanh";
    const auto& ker =
        jit::KernelPool::Instance()
            .template Get<jit::LSTMKernel<float>, int, const std::string&,
                          const std::string&, const std::string&>(
                d, act_gate, act_cand, act_cell);

    // below kernels are used to compute refer
    const auto& vsigmoid_3d =
        jit::KernelPool::Instance().template Get<jit::VSigmoidKernel<float>>(
            d3);
    const auto& vtanh_d =
        jit::KernelPool::Instance().template Get<jit::VTanhKernel<float>>(d);
    const auto& vexp_1 =
        jit::KernelPool::Instance().template Get<jit::VExpKernel<float>>(1);
    const auto& vmul_d =
        jit::KernelPool::Instance().template Get<jit::VMulKernel<float>>(d);
    const auto& vadd_d =
        jit::KernelPool::Instance().template Get<jit::VAddKernel<float>>(d);

    // compute once to check correctness
    lstm_ctht_ref(vsigmoid_3d, vtanh_d, vexp_1, d, xref.data(), ct_1.data(),
                  ct_ref.data(), ht_ref.data());
    ker->ComputeCtHt(x.data(), ct_1.data(), ct_tgt.data(), ht_tgt.data());
    for (int k = 0; k < d; ++k) {
      EXPECT_NEAR(ct_tgt[k], ct_ref[k], 1e-3);
      EXPECT_NEAR(ht_tgt[k], ht_ref[k], 1e-3);
    }

    // Timing only from here on: the buffers keep being overwritten and the
    // results are no longer compared.
    auto tmkls = GetCurrentUS();
    for (int iter = 0; iter < repeat; ++iter) {
      lstm_ctht_better(vsigmoid_3d, vtanh_d, vmul_d, vadd_d, d, xref.data(),
                       ct_1.data(), ct_ref.data(), ht_ref.data());
    }
    auto tmkle = GetCurrentUS();

    auto trefs = GetCurrentUS();
    for (int iter = 0; iter < repeat; ++iter) {
      lstm_ctht_ref(vsigmoid_3d, vtanh_d, vexp_1, d, xref.data(), ct_1.data(),
                    ct_ref.data(), ht_ref.data());
    }
    auto trefe = GetCurrentUS();

    auto ttgts = GetCurrentUS();
    for (int iter = 0; iter < repeat; ++iter) {
      ker->ComputeCtHt(x.data(), ct_1.data(), ct_tgt.data(), ht_tgt.data());
    }
    auto ttgte = GetCurrentUS();

    VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat
            << " us, better(jit) takes: " << (tmkle - tmkls) / repeat
            << " us, tgt takes: " << (ttgte - ttgts) / repeat;
  }
}
|
|
|
|
|
|
|
|
|
|
void vscal_ref(const int n, const float a, const float* x, float* y) {
|
|
|
|
|
for (int i = 0; i < n; ++i) {
|
|
|
|
|
y[i] = a * x[i];
|
|
|
|
|