|
|
|
@ -34,10 +34,12 @@ class LstmUnitOp : public framework::OperatorWithKernel {
|
|
|
|
|
auto c_prev_dims = ctx->GetInputDim("C_prev");
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(x_dims[0], c_prev_dims[0],
|
|
|
|
|
"Batch size of inputs and states must be equal");
|
|
|
|
|
PADDLE_ENFORCE_EQ(x_dims[1], c_prev_dims[1] * 4,
|
|
|
|
|
"Dimension of FC should equal to prev state * 4");
|
|
|
|
|
if (ctx->IsRuntime()) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(x_dims[0], c_prev_dims[0],
|
|
|
|
|
"Batch size of inputs and states must be equal");
|
|
|
|
|
PADDLE_ENFORCE_EQ(x_dims[1], c_prev_dims[1] * 4,
|
|
|
|
|
"Dimension of FC should equal to prev state * 4");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int b_size = c_prev_dims[0]; // batch size
|
|
|
|
|
int s_dim = c_prev_dims[1]; // state dim
|
|
|
|
|