|
|
|
@ -236,27 +236,31 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
|
|
|
|
|
const int D = wh_dims[0]; \
|
|
|
|
|
const int D4 = wh_dims[1]
|
|
|
|
|
|
|
|
|
|
#define INIT_OTHER_DEFINES \
|
|
|
|
|
const T* x_data = x->data<T>(); \
|
|
|
|
|
const T* wx_data = wx->data<T>(); \
|
|
|
|
|
const T* wh_data = wh->data<T>(); \
|
|
|
|
|
/* diagonal weight*/ \
|
|
|
|
|
const T* wp_data = bias->data<T>() + D4; \
|
|
|
|
|
/* for peephole only*/ \
|
|
|
|
|
T* checked_cell_data = nullptr; \
|
|
|
|
|
auto place = ctx.GetPlace(); \
|
|
|
|
|
if (use_peepholes) { \
|
|
|
|
|
/* w_ic * Ct-1, w_fc * Ct-1 ; w_oc * Ct => ih*/ \
|
|
|
|
|
auto* checked_cell = ctx.Output<Tensor>("CheckedCell"); \
|
|
|
|
|
checked_cell_data = checked_cell->mutable_data<T>(place); \
|
|
|
|
|
} \
|
|
|
|
|
const auto& ker = \
|
|
|
|
|
math::jitkernel::KernelPool::Instance() \
|
|
|
|
|
.template Get<math::jitkernel::LSTMKernel<T>, const std::string&, \
|
|
|
|
|
const std::string&, const std::string&>( \
|
|
|
|
|
ctx.Attr<std::string>("gate_activation"), \
|
|
|
|
|
ctx.Attr<std::string>("candidate_activation"), \
|
|
|
|
|
ctx.Attr<std::string>("cell_activation"), D, use_peepholes)
|
|
|
|
|
#define INIT_OTHER_DEFINES \
|
|
|
|
|
const T* x_data = x->data<T>(); \
|
|
|
|
|
const T* wx_data = wx->data<T>(); \
|
|
|
|
|
const T* wh_data = wh->data<T>(); \
|
|
|
|
|
/* diagonal weight*/ \
|
|
|
|
|
const T* wp_data = bias->data<T>() + D4; \
|
|
|
|
|
/* for peephole only*/ \
|
|
|
|
|
T* checked_cell_data = nullptr; \
|
|
|
|
|
auto place = ctx.GetPlace(); \
|
|
|
|
|
if (use_peepholes) { \
|
|
|
|
|
/* w_ic * Ct-1, w_fc * Ct-1 ; w_oc * Ct => ih*/ \
|
|
|
|
|
auto* checked_cell = ctx.Output<Tensor>("CheckedCell"); \
|
|
|
|
|
checked_cell_data = checked_cell->mutable_data<T>(place); \
|
|
|
|
|
} \
|
|
|
|
|
const math::jitkernel::lstm_attr_t attr( \
|
|
|
|
|
D, ctx.Attr<std::string>("gate_activation"), \
|
|
|
|
|
ctx.Attr<std::string>("candidate_activation"), \
|
|
|
|
|
ctx.Attr<std::string>("cell_activation"), use_peepholes); \
|
|
|
|
|
math::jitkernel::lstm_t one_step; \
|
|
|
|
|
one_step.wp = wp_data; \
|
|
|
|
|
one_step.checked = checked_cell_data; \
|
|
|
|
|
const auto& ker = \
|
|
|
|
|
math::jitkernel::KernelPool::Instance() \
|
|
|
|
|
.template Get<math::jitkernel::LSTMKernel<T>, \
|
|
|
|
|
const math::jitkernel::lstm_attr_t&>(attr)
|
|
|
|
|
|
|
|
|
|
// Wh GEMM
|
|
|
|
|
#define GEMM_WH_ADDON(bs, prev, out) \
|
|
|
|
@ -299,7 +303,10 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
|
|
|
|
|
prev_h_data = h0_data + bid * D;
|
|
|
|
|
prev_c_data = c0_data + bid * D;
|
|
|
|
|
} else {
|
|
|
|
|
ker->ComputeC1H1(xx_data, c_out_data, h_out_data, wp_data);
|
|
|
|
|
one_step.gates = xx_data;
|
|
|
|
|
one_step.ct = c_out_data;
|
|
|
|
|
one_step.ht = h_out_data;
|
|
|
|
|
ker->ComputeC1H1(&one_step, &attr);
|
|
|
|
|
tstart = 1;
|
|
|
|
|
// move one step
|
|
|
|
|
prev_h_data = h_out_data;
|
|
|
|
@ -310,8 +317,12 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
|
|
|
|
|
}
|
|
|
|
|
for (int step = tstart; step < seq_len; ++step) {
|
|
|
|
|
GEMM_WH_ADDON(1, prev_h_data, xx_data);
|
|
|
|
|
ker->ComputeCtHt(xx_data, prev_c_data, c_out_data, h_out_data, wp_data,
|
|
|
|
|
checked_cell_data);
|
|
|
|
|
|
|
|
|
|
one_step.gates = xx_data;
|
|
|
|
|
one_step.ct_1 = prev_c_data;
|
|
|
|
|
one_step.ct = c_out_data;
|
|
|
|
|
one_step.ht = h_out_data;
|
|
|
|
|
ker->ComputeCtHt(&one_step, &attr);
|
|
|
|
|
// move one step
|
|
|
|
|
prev_h_data = h_out_data;
|
|
|
|
|
prev_c_data = c_out_data;
|
|
|
|
@ -388,7 +399,11 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
|
|
|
|
|
T* cur_h_out_data = batched_h_out_data;
|
|
|
|
|
T* cur_c_out_data = batched_c_out_data;
|
|
|
|
|
for (int i = 0; i < max_bs; ++i) {
|
|
|
|
|
ker->ComputeC1H1(cur_in_data, cur_c_out_data, cur_h_out_data, wp_data);
|
|
|
|
|
one_step.gates = cur_in_data;
|
|
|
|
|
one_step.ct = cur_c_out_data;
|
|
|
|
|
one_step.ht = cur_h_out_data;
|
|
|
|
|
ker->ComputeC1H1(&one_step, &attr);
|
|
|
|
|
|
|
|
|
|
cur_in_data += D4;
|
|
|
|
|
cur_c_out_data += D;
|
|
|
|
|
cur_h_out_data += D;
|
|
|
|
@ -413,8 +428,12 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
|
|
|
|
|
T* cur_c_out_data = batched_c_out_data;
|
|
|
|
|
T* cur_h_out_data = batched_h_out_data;
|
|
|
|
|
for (int i = 0; i < cur_bs; ++i) {
|
|
|
|
|
ker->ComputeCtHt(cur_in_data, cur_prev_c_data, cur_c_out_data,
|
|
|
|
|
cur_h_out_data, wp_data, checked_cell_data);
|
|
|
|
|
one_step.gates = cur_in_data;
|
|
|
|
|
one_step.ct_1 = cur_prev_c_data;
|
|
|
|
|
one_step.ct = cur_c_out_data;
|
|
|
|
|
one_step.ht = cur_h_out_data;
|
|
|
|
|
ker->ComputeCtHt(&one_step, &attr);
|
|
|
|
|
|
|
|
|
|
// move one batch
|
|
|
|
|
cur_in_data += D4;
|
|
|
|
|
cur_prev_c_data += D;
|
|
|
|
|