|
|
|
@ -285,18 +285,23 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
|
|
|
|
|
act_cell(D, ct, gates + D2); \
|
|
|
|
|
blas.VMUL(D, gates + D2, gates + D3, ht)
|
|
|
|
|
|
|
|
|
|
#define COMPUTE_CtHt_WITHOUT_H0C0(gates, ct, ht) \
|
|
|
|
|
act_gate(D, gates + D, gates + D); \
|
|
|
|
|
act_cand(D, gates, gates); \
|
|
|
|
|
/* C_t = igated * cgated*/ \
|
|
|
|
|
blas.VMUL(D, gates, gates + D, ct); \
|
|
|
|
|
/* get outgated*/ \
|
|
|
|
|
if (use_peepholes) { \
|
|
|
|
|
/* put W_oc * C_t on igated */ \
|
|
|
|
|
blas.VMUL(D, wc_data + D2, ct, gates + D); \
|
|
|
|
|
blas.VADD(D, gates + D, gates + D3, gates + D3); \
|
|
|
|
|
} \
|
|
|
|
|
act_gate(D, gates + D3, gates + D3); \
|
|
|
|
|
#define GET_Ct_NOH0C0(gates, ct) \
|
|
|
|
|
/* C_t = igated * cgated*/ \
|
|
|
|
|
act_gate(D, gates + D, gates + D); \
|
|
|
|
|
act_cand(D, gates, gates); \
|
|
|
|
|
blas.VMUL(D, gates, gates + D, ct)
|
|
|
|
|
|
|
|
|
|
#define COMPUTE_CtHt_NOH0C0(gates, ct, ht) \
|
|
|
|
|
GET_Ct_NOH0C0(gates, ct); \
|
|
|
|
|
act_gate(D, gates + D3, gates + D3); \
|
|
|
|
|
GET_Ht(ct, gates, ht)
|
|
|
|
|
|
|
|
|
|
#define COMPUTE_CtHt_PEEPHOLE_NOH0C0(gates, ct, ht) \
|
|
|
|
|
GET_Ct_NOH0C0(gates, ct); \
|
|
|
|
|
/* get outgated, put W_oc * C_t on igated */ \
|
|
|
|
|
blas.VMUL(D, wc_data + D2, ct, gates + D); \
|
|
|
|
|
blas.VADD(D, gates + D, gates + D3, gates + D3); \
|
|
|
|
|
act_gate(D, gates + D3, gates + D3); \
|
|
|
|
|
GET_Ht(ct, gates, ht)
|
|
|
|
|
|
|
|
|
|
#define COMPUTE_CtHt(gates, ct_1, ct, ht) \
|
|
|
|
@ -354,24 +359,38 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
|
|
|
|
|
h_out_data = h_out_data + gate_offset; \
|
|
|
|
|
c_out_data = c_out_data + gate_offset
|
|
|
|
|
|
|
|
|
|
#define PROCESS_H0C0 \
|
|
|
|
|
int bid = is_reverse ? N - 1 - i : i; \
|
|
|
|
|
int seq_len = x_lod[0][bid + 1] - x_lod[0][bid]; \
|
|
|
|
|
const T* prev_c_data = nullptr; \
|
|
|
|
|
const T* prev_h_data = nullptr; \
|
|
|
|
|
int tstart = 0; \
|
|
|
|
|
if (h0_data) { \
|
|
|
|
|
prev_h_data = h0_data + bid * D; \
|
|
|
|
|
prev_c_data = c0_data + bid * D; \
|
|
|
|
|
} else { \
|
|
|
|
|
COMPUTE_CtHt_WITHOUT_H0C0(xx_data, c_out_data, h_out_data); \
|
|
|
|
|
MOVE_ONE_STEP; \
|
|
|
|
|
tstart = 1; \
|
|
|
|
|
#define PROCESS_H0C0_DEFINES \
|
|
|
|
|
int bid = is_reverse ? N - 1 - i : i; \
|
|
|
|
|
int seq_len = x_lod[0][bid + 1] - x_lod[0][bid]; \
|
|
|
|
|
const T* prev_c_data = nullptr; \
|
|
|
|
|
const T* prev_h_data = nullptr; \
|
|
|
|
|
int tstart = 0
|
|
|
|
|
|
|
|
|
|
#define PROCESS_H0C0_PEEPHOLE \
|
|
|
|
|
PROCESS_H0C0_DEFINES; \
|
|
|
|
|
if (h0_data) { \
|
|
|
|
|
prev_h_data = h0_data + bid * D; \
|
|
|
|
|
prev_c_data = c0_data + bid * D; \
|
|
|
|
|
} else { \
|
|
|
|
|
COMPUTE_CtHt_PEEPHOLE_NOH0C0(xx_data, c_out_data, h_out_data); \
|
|
|
|
|
MOVE_ONE_STEP; \
|
|
|
|
|
tstart = 1; \
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#define PROCESS_H0C0 \
|
|
|
|
|
PROCESS_H0C0_DEFINES; \
|
|
|
|
|
if (h0_data) { \
|
|
|
|
|
prev_h_data = h0_data + bid * D; \
|
|
|
|
|
prev_c_data = c0_data + bid * D; \
|
|
|
|
|
} else { \
|
|
|
|
|
COMPUTE_CtHt_NOH0C0(xx_data, c_out_data, h_out_data); \
|
|
|
|
|
MOVE_ONE_STEP; \
|
|
|
|
|
tstart = 1; \
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (use_peepholes) {
|
|
|
|
|
for (int i = 0; i < N; ++i) {
|
|
|
|
|
PROCESS_H0C0;
|
|
|
|
|
PROCESS_H0C0_PEEPHOLE
|
|
|
|
|
for (int step = tstart; step < seq_len; ++step) {
|
|
|
|
|
GEMM_WH_ADDON(1, prev_h_data, xx_data);
|
|
|
|
|
COMPUTE_CtHt_PEEPHOLE(xx_data, prev_c_data, c_out_data, h_out_data);
|
|
|
|
@ -380,7 +399,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
for (int i = 0; i < N; ++i) {
|
|
|
|
|
PROCESS_H0C0;
|
|
|
|
|
PROCESS_H0C0
|
|
|
|
|
for (int step = tstart; step < seq_len; ++step) {
|
|
|
|
|
GEMM_WH_ADDON(1, prev_h_data, xx_data);
|
|
|
|
|
COMPUTE_CtHt(xx_data, prev_c_data, c_out_data, h_out_data);
|
|
|
|
@ -388,6 +407,8 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
#undef PROCESS_H0C0_DEFINES
|
|
|
|
|
#undef PROCESS_H0C0_PEEPHOLE
|
|
|
|
|
#undef PROCESS_H0C0
|
|
|
|
|
#undef MOVE_ONE_STEP
|
|
|
|
|
}
|
|
|
|
@ -460,7 +481,13 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
|
|
|
|
|
T* cur_h_out_data = batched_h_out_data;
|
|
|
|
|
T* cur_c_out_data = batched_c_out_data;
|
|
|
|
|
for (int i = 0; i < max_bs; ++i) {
|
|
|
|
|
COMPUTE_CtHt_WITHOUT_H0C0(cur_in_data, cur_c_out_data, cur_h_out_data);
|
|
|
|
|
GET_Ct_NOH0C0(cur_in_data, cur_c_out_data);
|
|
|
|
|
if (use_peepholes) {
|
|
|
|
|
blas.VMUL(D, wc_data + D2, cur_c_out_data, cur_in_data + D);
|
|
|
|
|
blas.VADD(D, cur_in_data + D, cur_in_data + D3, cur_in_data + D3);
|
|
|
|
|
}
|
|
|
|
|
act_gate(D, cur_in_data + D3, cur_in_data + D3);
|
|
|
|
|
GET_Ht(cur_c_out_data, cur_in_data, cur_h_out_data);
|
|
|
|
|
cur_in_data += D4;
|
|
|
|
|
cur_c_out_data += D;
|
|
|
|
|
cur_h_out_data += D;
|
|
|
|
@ -541,7 +568,9 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
#undef COMPUTE_CtHt_PEEPHOLE
|
|
|
|
|
#undef COMPUTE_CtHt
|
|
|
|
|
#undef COMPUTE_CtHt_WITHOUT_H0C0
|
|
|
|
|
#undef GET_Ct_NOH0C0
|
|
|
|
|
#undef COMPUTE_CtHt_NOH0C0
|
|
|
|
|
#undef COMPUTE_CtHt_PEEPHOLE_NOH0C0
|
|
|
|
|
#undef GET_Ht
|
|
|
|
|
#undef GET_Ct
|
|
|
|
|
#undef GEMM_WH_ADDON
|
|
|
|
|