|
|
|
@ -44,14 +44,18 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
|
|
|
|
|
|
|
|
|
|
auto x_dims = ctx->GetInputDim("X");
|
|
|
|
|
const int M = x_dims[1];
|
|
|
|
|
PADDLE_ENFORCE_EQ(x_dims.size(), 2, platform::errors::InvalidArgument(
|
|
|
|
|
"Input(X)'s rank must be 2."));
|
|
|
|
|
PADDLE_ENFORCE_EQ(x_dims.size(), 2,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Expected input(X)'s dimension is 2. But received %d.",
|
|
|
|
|
x_dims.size()));
|
|
|
|
|
|
|
|
|
|
auto w_dims = ctx->GetInputDim("LSTMWeight");
|
|
|
|
|
const int D = w_dims[1] / 4;
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
w_dims.size(), 2,
|
|
|
|
|
platform::errors::InvalidArgument("Input(LSTMWeight)'s rank must be 2."));
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Expected input(LSTMWeight)'s dimension is 2.But received %d.",
|
|
|
|
|
w_dims.size()));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
w_dims[0], D + M,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
@ -77,8 +81,11 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
|
|
|
|
|
|
|
|
|
|
if (ctx->HasInput("H0")) {
|
|
|
|
|
auto h_dims = ctx->GetInputDim("H0");
|
|
|
|
|
PADDLE_ENFORCE_EQ(h_dims.size(), 2UL, platform::errors::InvalidArgument(
|
|
|
|
|
"Input(H0)'s rank must be 2."));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
h_dims.size(), 2UL,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Expected input(H0)'s dimension is 2. But received %d.",
|
|
|
|
|
h_dims.size()));
|
|
|
|
|
if (ctx->IsRuntime() ||
|
|
|
|
|
(framework::product(c_dims) > 0 && framework::product(h_dims) > 0)) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(h_dims, c_dims,
|
|
|
|
@ -94,7 +101,9 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
|
|
|
|
|
"Input(AttentionWeight)'s rank must be 2."));
|
|
|
|
|
PADDLE_ENFORCE_EQ(atten_w_dims[0], M + D,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"AttentionWeight shapes must be (%d + %d) * 1.", M, D));
|
|
|
|
|
"Expected `AttentionWeight` shape is [(%d + %d), 1]. "
|
|
|
|
|
"But received shape = [%d, 1], shape[0] is not %d.",
|
|
|
|
|
M, D, atten_w_dims[0], M + D));
|
|
|
|
|
PADDLE_ENFORCE_EQ(atten_w_dims[1], 1,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"AttentionWeight shapes must be (%d + %d) * 1.", M, D));
|
|
|
|
|