|
|
|
@ -24,68 +24,94 @@ namespace operators {
|
|
|
|
|
|
|
|
|
|
void FusedEmbeddingFCLSTMOp::InferShape(
|
|
|
|
|
framework::InferShapeContext* ctx) const {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Embeddings"),
|
|
|
|
|
"Assert only one Input(Embeddings) of LSTM.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("WeightH"),
|
|
|
|
|
"Assert only one Input(WeightH) of LSTM.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Bias"), "Assert only one Input(Bias) of LSTM.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("XX"), "Assert only one Output(XX) of LSTM.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
|
|
|
|
|
"Assert only one Output(Hidden) of LSTM.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Cell"),
|
|
|
|
|
"Assert only one Output(Cell) of LSTM.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Ids"),
|
|
|
|
|
"Input(Ids) of LookupTableOp should not be null.");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("Embeddings"), "Input", "Embeddings",
|
|
|
|
|
"fused_embedding_fc_lstm");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("WeightH"), "Input", "WeightH",
|
|
|
|
|
"fused_embedding_fc_lstm");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("Bias"), "Input", "Bias",
|
|
|
|
|
"fused_embedding_fc_lstm");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("XX"), "Output", "XX",
|
|
|
|
|
"fused_embedding_fc_lstm");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("Hidden"), "Output", "Hidden",
|
|
|
|
|
"fused_embedding_fc_lstm");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("Cell"), "Output", "Cell",
|
|
|
|
|
"fused_embedding_fc_lstm");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("Ids"), "Input", "Ids",
|
|
|
|
|
"fused_embedding_fc_lstm");
|
|
|
|
|
|
|
|
|
|
auto table_dims = ctx->GetInputDim("Embeddings");
|
|
|
|
|
auto ids_dims = ctx->GetInputDim("Ids");
|
|
|
|
|
int ids_rank = ids_dims.size();
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(table_dims.size(), 2);
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
table_dims.size(), 2,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The Embeddings's rank should be 2, but received value is:%d.",
|
|
|
|
|
table_dims.size()));
|
|
|
|
|
PADDLE_ENFORCE_EQ(ids_dims[ids_rank - 1], 1,
|
|
|
|
|
"The last dimension of the 'Ids' tensor must be 1.");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The last dimension of the 'Ids' tensor must be 1, but "
|
|
|
|
|
"received value is:%d.",
|
|
|
|
|
ids_dims[ids_rank - 1]));
|
|
|
|
|
|
|
|
|
|
auto x_dims = ctx->GetInputDim("Ids");
|
|
|
|
|
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(Ids)'s rank must be 2.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
x_dims.size(), 2,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Input(Ids)'s rank must be 2, but received value is:%d.",
|
|
|
|
|
x_dims.size()));
|
|
|
|
|
|
|
|
|
|
if (ctx->HasInput("H0")) {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("C0"),
|
|
|
|
|
"Input(Cell) and Input(Hidden) of LSTM should not "
|
|
|
|
|
"be null at the same time.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasInput("C0"), true,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Input(Cell) and Input(Hidden) of LSTM should exist "
|
|
|
|
|
"at the same time."));
|
|
|
|
|
auto h_dims = ctx->GetInputDim("H0");
|
|
|
|
|
auto c_dims = ctx->GetInputDim("C0");
|
|
|
|
|
PADDLE_ENFORCE(h_dims == c_dims,
|
|
|
|
|
"The dimension of Input(H0) and Input(C0) "
|
|
|
|
|
"should be the same.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
h_dims, c_dims,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The dimension of Input(H0) and Input(C0) "
|
|
|
|
|
"should be the same, but received H0 dim is:[%s], C0 dim is[%s]",
|
|
|
|
|
h_dims, c_dims));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto embeddings_dims = ctx->GetInputDim("Embeddings");
|
|
|
|
|
PADDLE_ENFORCE_EQ(embeddings_dims.size(), 2,
|
|
|
|
|
"The rank of Input(Embeddings) should be 2.");
|
|
|
|
|
|
|
|
|
|
auto wh_dims = ctx->GetInputDim("WeightH");
|
|
|
|
|
int frame_size = wh_dims[1] / 4;
|
|
|
|
|
PADDLE_ENFORCE_EQ(wh_dims.size(), 2,
|
|
|
|
|
"The rank of Input(WeightH) should be 2.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
wh_dims.size(), 2,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The rank of Input(WeightH) should be 2, but received value is:%d.",
|
|
|
|
|
wh_dims.size()));
|
|
|
|
|
PADDLE_ENFORCE_EQ(wh_dims[0], frame_size,
|
|
|
|
|
"The first dimension of Input(WeightH) "
|
|
|
|
|
"should be %d.",
|
|
|
|
|
frame_size);
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The first dimension of Input(WeightH) should equal to "
|
|
|
|
|
"frame size:%d, but received value is:%d.",
|
|
|
|
|
frame_size, wh_dims[0]));
|
|
|
|
|
PADDLE_ENFORCE_EQ(wh_dims[1], 4 * frame_size,
|
|
|
|
|
"The second dimension of Input(WeightH) "
|
|
|
|
|
"should be 4 * %d.",
|
|
|
|
|
frame_size);
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The second dimension of Input(WeightH) should equal "
|
|
|
|
|
"to 4 * %d, but received value is:%d.",
|
|
|
|
|
frame_size, wh_dims[1]));
|
|
|
|
|
|
|
|
|
|
auto b_dims = ctx->GetInputDim("Bias");
|
|
|
|
|
PADDLE_ENFORCE_EQ(b_dims.size(), 2, "The rank of Input(Bias) should be 2.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(b_dims[0], 1,
|
|
|
|
|
"The first dimension of Input(Bias) should be 1.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
b_dims.size(), 2,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The rank of Input(Bias) should be 2, but received value is:%d.",
|
|
|
|
|
b_dims.size()));
|
|
|
|
|
PADDLE_ENFORCE_EQ(b_dims[0], 1, platform::errors::InvalidArgument(
|
|
|
|
|
"The first dimension of Input(Bias) "
|
|
|
|
|
"should be 1, but received value is:%d.",
|
|
|
|
|
b_dims[0]));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
b_dims[1], (ctx->Attrs().Get<bool>("use_peepholes") ? 7 : 4) * frame_size,
|
|
|
|
|
"The second dimension of Input(Bias) should be "
|
|
|
|
|
"7 * %d if enable peepholes connection or"
|
|
|
|
|
"4 * %d if disable peepholes",
|
|
|
|
|
frame_size, frame_size);
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The second dimension of Input(Bias) should be "
|
|
|
|
|
"7 * %d if enable peepholes connection or"
|
|
|
|
|
"4 * %d if disable peepholes, bias dim is:%d, use_peepholes:%d",
|
|
|
|
|
frame_size, frame_size, b_dims[1],
|
|
|
|
|
ctx->Attrs().Get<bool>("use_peepholes")));
|
|
|
|
|
|
|
|
|
|
framework::DDim out_dims({x_dims[0], frame_size});
|
|
|
|
|
ctx->SetOutputDim("Hidden", out_dims);
|
|
|
|
@ -93,16 +119,17 @@ void FusedEmbeddingFCLSTMOp::InferShape(
|
|
|
|
|
ctx->ShareLoD("Ids", "Hidden");
|
|
|
|
|
ctx->ShareLoD("Ids", "Cell");
|
|
|
|
|
if (!ctx->Attrs().Get<bool>("use_seq")) {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("BatchedInput"),
|
|
|
|
|
"Assert only one Output(BatchedInput) of LSTM.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("BatchedHidden"),
|
|
|
|
|
"Assert only one Output(BatchedHidden) of LSTM.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("BatchedCell"),
|
|
|
|
|
"Assert only one Output(BatchedCell) of LSTM.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("ReorderedH0"),
|
|
|
|
|
"Assert only one Output(ReorderedH0) of LSTM");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("ReorderedC0"),
|
|
|
|
|
"Assert only one Output(ReorderedC0) of LSTM.");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("BatchedInput"), "Output", "BatchedInput",
|
|
|
|
|
"fused_embedding_fc_lstm");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("BatchedHidden"), "Output", "BatchedHidden",
|
|
|
|
|
"fused_embedding_fc_lstm");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("BatchedCell"), "Output", "BatchedCell",
|
|
|
|
|
"fused_embedding_fc_lstm");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("ReorderedH0"), "Output", "ReorderedH0",
|
|
|
|
|
"fused_embedding_fc_lstm");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("ReorderedC0"), "Output", "ReorderedC0",
|
|
|
|
|
"fused_embedding_fc_lstm");
|
|
|
|
|
|
|
|
|
|
ctx->SetOutputDim("BatchedInput", {x_dims[0], wh_dims[1]});
|
|
|
|
|
ctx->SetOutputDim("BatchedHidden", out_dims);
|
|
|
|
|
ctx->SetOutputDim("BatchedCell", out_dims);
|
|
|
|
|