Merge pull request #16902 from guoshengCS/refine-infer-shape

Refine ENFORCE in infer_shape of gru_op and lstm_unit_op.
shanyi15-patch-1
Guo Sheng 6 years ago committed by GitHub
commit 9f1d4a152b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -47,8 +47,11 @@ class GRUOp : public framework::OperatorWithKernel {
auto weight_dims = ctx->GetInputDim("Weight");
int input_size = input_dims[1];
int frame_size = weight_dims[0];
PADDLE_ENFORCE_EQ(input_size, frame_size * 3,
"The input_size must be 3 times of frame_size in GRUOp.");
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(
input_size, frame_size * 3,
"The input_size must be 3 times of frame_size in GRUOp.");
}
PADDLE_ENFORCE_EQ(
weight_dims[1], frame_size * 3,
"The shape of Weight matrix must be [frame_size, frame_size * 3].");

@ -34,10 +34,12 @@ class LstmUnitOp : public framework::OperatorWithKernel {
auto c_prev_dims = ctx->GetInputDim("C_prev");
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2.");
PADDLE_ENFORCE_EQ(x_dims[0], c_prev_dims[0],
"Batch size of inputs and states must be equal");
PADDLE_ENFORCE_EQ(x_dims[1], c_prev_dims[1] * 4,
"Dimension of FC should equal to prev state * 4");
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(x_dims[0], c_prev_dims[0],
"Batch size of inputs and states must be equal");
PADDLE_ENFORCE_EQ(x_dims[1], c_prev_dims[1] * 4,
"Dimension of FC should equal to prev state * 4");
}
int b_size = c_prev_dims[0]; // batch size
int s_dim = c_prev_dims[1]; // state dim

Loading…
Cancel
Save