|
|
|
@ -47,8 +47,11 @@ class GRUOp : public framework::OperatorWithKernel {
|
|
|
|
|
auto weight_dims = ctx->GetInputDim("Weight");
|
|
|
|
|
int input_size = input_dims[1];
|
|
|
|
|
int frame_size = weight_dims[0];
|
|
|
|
|
PADDLE_ENFORCE_EQ(input_size, frame_size * 3,
|
|
|
|
|
if (ctx->IsRuntime()) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
input_size, frame_size * 3,
|
|
|
|
|
"The input_size must be 3 times of frame_size in GRUOp.");
|
|
|
|
|
}
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
weight_dims[1], frame_size * 3,
|
|
|
|
|
"The shape of Weight matrix must be [frame_size, frame_size * 3].");
|
|
|
|
|