|
|
|
@ -236,33 +236,32 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
|
|
|
|
|
const int D = wh_dims[0]; \
|
|
|
|
|
const int D4 = wh_dims[1]
|
|
|
|
|
|
|
|
|
|
#define INIT_OTHER_DEFINES \
|
|
|
|
|
const T* x_data = x->data<T>(); \
|
|
|
|
|
const T* wx_data = wx->data<T>(); \
|
|
|
|
|
const T* wh_data = wh->data<T>(); \
|
|
|
|
|
/* diagonal weight*/ \
|
|
|
|
|
const T* wp_data = bias->data<T>() + D4; \
|
|
|
|
|
/* for peephole only*/ \
|
|
|
|
|
T* checked_cell_data = nullptr; \
|
|
|
|
|
auto place = ctx.GetPlace(); \
|
|
|
|
|
if (use_peepholes) { \
|
|
|
|
|
/* w_ic * Ct-1, w_fc * Ct-1 ; w_oc * Ct => ih*/ \
|
|
|
|
|
auto* checked_cell = ctx.Output<Tensor>("CheckedCell"); \
|
|
|
|
|
checked_cell_data = checked_cell->mutable_data<T>(place); \
|
|
|
|
|
} \
|
|
|
|
|
const jit \
|
|
|
|
|
: lstm_attr_t attr( \
|
|
|
|
|
D, jit::to_kerneltype(ctx.Attr<std::string>("gate_activation")), \
|
|
|
|
|
jit::to_kerneltype(ctx.Attr<std::string>("candidate_activation")), \
|
|
|
|
|
jit::to_kerneltype(ctx.Attr<std::string>("cell_activation")), \
|
|
|
|
|
use_peepholes); \
|
|
|
|
|
math::jitkernel::lstm_t one_step; \
|
|
|
|
|
one_step.wp = wp_data; \
|
|
|
|
|
one_step.checked = checked_cell_data; \
|
|
|
|
|
auto ComputeC1H1 = \
|
|
|
|
|
jit::Get<jit::lstmc1h1, jit::LSTMTuples, platform::CPUPlace>(attr); \
|
|
|
|
|
auto ComputeCtHt = \
|
|
|
|
|
jit::Get<jit::lstmctht, jit::LSTMTuples, platform::CPUPlace>(attr)
|
|
|
|
|
#define INIT_OTHER_DEFINES \
|
|
|
|
|
const T* x_data = x->data<T>(); \
|
|
|
|
|
const T* wx_data = wx->data<T>(); \
|
|
|
|
|
const T* wh_data = wh->data<T>(); \
|
|
|
|
|
/* diagonal weight*/ \
|
|
|
|
|
const T* wp_data = bias->data<T>() + D4; \
|
|
|
|
|
/* for peephole only*/ \
|
|
|
|
|
T* checked_cell_data = nullptr; \
|
|
|
|
|
auto place = ctx.GetPlace(); \
|
|
|
|
|
if (use_peepholes) { \
|
|
|
|
|
/* w_ic * Ct-1, w_fc * Ct-1 ; w_oc * Ct => ih*/ \
|
|
|
|
|
auto* checked_cell = ctx.Output<Tensor>("CheckedCell"); \
|
|
|
|
|
checked_cell_data = checked_cell->mutable_data<T>(place); \
|
|
|
|
|
} \
|
|
|
|
|
const jit::lstm_attr_t attr( \
|
|
|
|
|
D, jit::to_kerneltype(ctx.Attr<std::string>("gate_activation")), \
|
|
|
|
|
jit::to_kerneltype(ctx.Attr<std::string>("candidate_activation")), \
|
|
|
|
|
jit::to_kerneltype(ctx.Attr<std::string>("cell_activation")), \
|
|
|
|
|
use_peepholes); \
|
|
|
|
|
jit::lstm_t one_step; \
|
|
|
|
|
one_step.wp = wp_data; \
|
|
|
|
|
one_step.checked = checked_cell_data; \
|
|
|
|
|
auto ComputeC1H1 = \
|
|
|
|
|
jit::Get<jit::lstmc1h1, jit::LSTMTuples<T>, platform::CPUPlace>(attr); \
|
|
|
|
|
auto ComputeCtHt = \
|
|
|
|
|
jit::Get<jit::lstmctht, jit::LSTMTuples<T>, platform::CPUPlace>(attr)
|
|
|
|
|
|
|
|
|
|
// Wh GEMM
|
|
|
|
|
#define GEMM_WH_ADDON(bs, prev, out) \
|
|
|
|
@ -434,7 +433,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
|
|
|
|
|
one_step.ct_1 = cur_prev_c_data;
|
|
|
|
|
one_step.ct = cur_c_out_data;
|
|
|
|
|
one_step.ht = cur_h_out_data;
|
|
|
|
|
ComputeC1H1(&one_step, &attr);
|
|
|
|
|
ComputeCtHt(&one_step, &attr);
|
|
|
|
|
|
|
|
|
|
// move one batch
|
|
|
|
|
cur_in_data += D4;
|
|
|
|
|