|
|
|
@ -38,16 +38,6 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
|
|
|
|
|
"Output(Hidden) of LSTM should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Cell"),
|
|
|
|
|
"Output(Cell) of LSTM should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("BatchedInput"),
|
|
|
|
|
"Output(BatchedInput) of LSTM should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("BatchedHidden"),
|
|
|
|
|
"Output(BatchedHidden) of LSTM should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("BatchedCell"),
|
|
|
|
|
"Output(BatchedCell) of LSTM should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("ReorderedH0"),
|
|
|
|
|
"Output(ReorderedH0) of LSTM should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("ReorderedC0"),
|
|
|
|
|
"Output(ReorderedC0) of LSTM should not be null.");
|
|
|
|
|
|
|
|
|
|
auto x_dims = ctx->GetInputDim("X");
|
|
|
|
|
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2.");
|
|
|
|
@ -99,17 +89,26 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
|
|
|
|
|
framework::DDim out_dims({x_dims[0], frame_size});
|
|
|
|
|
ctx->SetOutputDim("Hidden", out_dims);
|
|
|
|
|
ctx->SetOutputDim("Cell", out_dims);
|
|
|
|
|
ctx->SetOutputDim("BatchedInput", {x_dims[0], wx_dims[1]});
|
|
|
|
|
ctx->SetOutputDim("BatchedHidden", out_dims);
|
|
|
|
|
ctx->SetOutputDim("BatchedCell", out_dims);
|
|
|
|
|
ctx->ShareLoD("X", "Hidden");
|
|
|
|
|
ctx->ShareLoD("X", "Cell");
|
|
|
|
|
|
|
|
|
|
int xx_width;
|
|
|
|
|
if (ctx->Attrs().Get<bool>("use_seq")) {
|
|
|
|
|
xx_width = wx_dims[1];
|
|
|
|
|
} else {
|
|
|
|
|
xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1];
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("BatchedInput"),
|
|
|
|
|
"Output(BatchedInput) of LSTM should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("BatchedHidden"),
|
|
|
|
|
"Output(BatchedHidden) of LSTM should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("BatchedCell"),
|
|
|
|
|
"Output(BatchedCell) of LSTM should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("ReorderedH0"),
|
|
|
|
|
"Output(ReorderedH0) of LSTM should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("ReorderedC0"),
|
|
|
|
|
"Output(ReorderedC0) of LSTM should not be null.");
|
|
|
|
|
ctx->SetOutputDim("BatchedInput", {x_dims[0], wx_dims[1]});
|
|
|
|
|
ctx->SetOutputDim("BatchedHidden", out_dims);
|
|
|
|
|
ctx->SetOutputDim("BatchedCell", out_dims);
|
|
|
|
|
}
|
|
|
|
|
ctx->SetOutputDim("XX", {x_dims[0], xx_width});
|
|
|
|
|
ctx->ShareLoD("X", "XX");
|
|
|
|
|