@@ -64,12 +64,19 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
 
   auto c_dims = ctx->GetInputDim("C0");
   PADDLE_ENFORCE_EQ(c_dims.size(), 2, "Input(C0)'s rank must be 2.");
-  PADDLE_ENFORCE_EQ(c_dims[1], D, "C0 dims should be N x %d.", D);
+  if (ctx->IsRuntime()) {
+    PADDLE_ENFORCE_EQ(c_dims[1], D, "C0 dims should be N x %d.", D);
+  }
+
   if (ctx->HasInput("H0")) {
     auto h_dims = ctx->GetInputDim("H0");
-    PADDLE_ENFORCE(h_dims == c_dims,
-                   "The dimension of Input(H0) and Input(C0) "
-                   "should be the same.");
+    PADDLE_ENFORCE_EQ(h_dims.size(), 2UL, "Input(H0)'s rank must be 2.");
+    if (ctx->IsRuntime() ||
+        (framework::product(c_dims) > 0 && framework::product(h_dims) > 0)) {
+      PADDLE_ENFORCE(h_dims == c_dims,
+                     "The dimension of Input(H0) and Input(C0) "
+                     "should be the same.");
+    }
   }
 
   auto atten_w_dims = ctx->GetInputDim("AttentionWeight");
@@ -79,6 +86,7 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
                     "AttentionWeight shapes must be (%d + %d) * 1.", M, D);
   PADDLE_ENFORCE_EQ(atten_w_dims[1], 1,
                     "AttentionWeight shapes must be (%d + %d) * 1.", M, D);
+
   if (ctx->HasInput("AttentionBias")) {
     auto atten_b_dims = ctx->GetInputDim("AttentionBias");
     PADDLE_ENFORCE_EQ(atten_b_dims.size(), 2,