|
|
|
@ -76,12 +76,18 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
|
|
|
|
|
PADDLE_ENFORCE_EQ(b_dims.size(), 2, "The rank of Input(Bias) should be 2.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(b_dims[0], 1,
|
|
|
|
|
"The first dimension of Input(Bias) should be 1.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
b_dims[1], (ctx->Attrs().Get<bool>("use_peepholes") ? 7 : 4) * frame_size,
|
|
|
|
|
"The second dimension of Input(Bias) should be "
|
|
|
|
|
"7 * %d if enable peepholes connection or"
|
|
|
|
|
"4 * %d if disable peepholes",
|
|
|
|
|
frame_size, frame_size);
|
|
|
|
|
if (ctx->Attrs().Get<bool>("use_peepholes")) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(b_dims[1], 7 * frame_size,
|
|
|
|
|
"The second dimension of Input(Bias) should be "
|
|
|
|
|
"7 * %d if enable peepholes connection",
|
|
|
|
|
frame_size);
|
|
|
|
|
ctx->SetOutputDim("CheckedCell", {2, frame_size});
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_ENFORCE_EQ(b_dims[1], 4 * frame_size,
|
|
|
|
|
"The second dimension of Input(Bias) should be "
|
|
|
|
|
"4 * %d if disable peepholes",
|
|
|
|
|
frame_size);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
framework::DDim out_dims({x_dims[0], frame_size});
|
|
|
|
|
ctx->SetOutputDim("Hidden", out_dims);
|
|
|
|
@ -173,6 +179,8 @@ void FusionLSTMOpMaker::Make() {
|
|
|
|
|
AddOutput("BatchedCell", "(LoDTensor) (T x D).").AsIntermediate();
|
|
|
|
|
AddOutput("ReorderedH0", "(LoDTensor) (N x D).").AsIntermediate();
|
|
|
|
|
AddOutput("ReorderedC0", "(LoDTensor) (N x D).").AsIntermediate();
|
|
|
|
|
AddOutput("CheckedCell", "(Tensor) (2 x D) only for peephole.")
|
|
|
|
|
.AsIntermediate();
|
|
|
|
|
AddAttr<bool>("use_peepholes",
|
|
|
|
|
"(bool, defalut: True) "
|
|
|
|
|
"whether to enable diagonal/peephole connections.")
|
|
|
|
@ -250,19 +258,19 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
|
|
|
|
|
const int D3 = D * 3; \
|
|
|
|
|
const int D4 = wh_dims[1];
|
|
|
|
|
|
|
|
|
|
#define INIT_BASE_INPUT_DATAS \
|
|
|
|
|
const T* x_data = x->data<T>(); \
|
|
|
|
|
const T* wx_data = wx->data<T>(); \
|
|
|
|
|
const T* wh_data = wh->data<T>(); \
|
|
|
|
|
/* diagonal weight*/ \
|
|
|
|
|
const T* wc_data = bias->data<T>() + D4; \
|
|
|
|
|
/* for peephole only*/ \
|
|
|
|
|
Tensor checked_cell; \
|
|
|
|
|
T* checked_cell_data = nullptr; \
|
|
|
|
|
auto place = ctx.GetPlace(); \
|
|
|
|
|
if (use_peepholes) { \
|
|
|
|
|
/* w_ic * Ct-1, w_fc * Ct-1 ; w_oc * Ct => ih*/ \
|
|
|
|
|
checked_cell_data = checked_cell.mutable_data<T>({2, D}, place); \
|
|
|
|
|
#define INIT_BASE_INPUT_DATAS \
|
|
|
|
|
const T* x_data = x->data<T>(); \
|
|
|
|
|
const T* wx_data = wx->data<T>(); \
|
|
|
|
|
const T* wh_data = wh->data<T>(); \
|
|
|
|
|
/* diagonal weight*/ \
|
|
|
|
|
const T* wc_data = bias->data<T>() + D4; \
|
|
|
|
|
/* for peephole only*/ \
|
|
|
|
|
T* checked_cell_data = nullptr; \
|
|
|
|
|
auto place = ctx.GetPlace(); \
|
|
|
|
|
if (use_peepholes) { \
|
|
|
|
|
/* w_ic * Ct-1, w_fc * Ct-1 ; w_oc * Ct => ih*/ \
|
|
|
|
|
auto* checked_cell = ctx.Output<Tensor>("CheckedCell"); \
|
|
|
|
|
checked_cell_data = checked_cell->mutable_data<T>(place); \
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Compute LSTM
|
|
|
|
|