|
|
|
@ -30,14 +30,7 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
|
|
|
|
|
"Input(WeightX) of GRU should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("WeightH"),
|
|
|
|
|
"Input(WeightH) of GRU should not be null.");
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("XX"), "Output(XX) of GRU should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("ReorderedH0"),
|
|
|
|
|
"Output(ReorderedH0) of GRU should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("BatchedInput"),
|
|
|
|
|
"Output(BatchedInput) of GRU should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("BatchedOut"),
|
|
|
|
|
"Output(BatchedOut) of GRU should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
|
|
|
|
|
"Output(Hidden) of GRU should not be null.");
|
|
|
|
|
|
|
|
|
@ -80,15 +73,20 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
|
|
|
|
|
}
|
|
|
|
|
framework::DDim out_dims({x_dims[0], frame_size});
|
|
|
|
|
ctx->SetOutputDim("Hidden", out_dims);
|
|
|
|
|
ctx->SetOutputDim("BatchedInput", {x_dims[0], wx_dims[1]});
|
|
|
|
|
ctx->SetOutputDim("BatchedOut", out_dims);
|
|
|
|
|
ctx->ShareLoD("X", "Hidden");
|
|
|
|
|
|
|
|
|
|
int xx_width;
|
|
|
|
|
if (ctx->Attrs().Get<bool>("use_seq")) {
|
|
|
|
|
xx_width = wx_dims[1];
|
|
|
|
|
} else {
|
|
|
|
|
xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1];
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("ReorderedH0"),
|
|
|
|
|
"Output(ReorderedH0) of GRU should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("BatchedInput"),
|
|
|
|
|
"Output(BatchedInput) of GRU should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("BatchedOut"),
|
|
|
|
|
"Output(BatchedOut) of GRU should not be null.");
|
|
|
|
|
ctx->SetOutputDim("BatchedInput", {x_dims[0], wx_dims[1]});
|
|
|
|
|
ctx->SetOutputDim("BatchedOut", out_dims);
|
|
|
|
|
}
|
|
|
|
|
ctx->SetOutputDim("XX", {x_dims[0], xx_width});
|
|
|
|
|
ctx->ShareLoD("X", "XX");
|
|
|
|
|