|
|
|
@ -24,30 +24,47 @@ class FusedEmbeddingSeqPoolOp : public framework::OperatorWithKernel {
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("W"),
|
|
|
|
|
"Input W of FusedEmbeddingSeqPoolOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Ids"),
|
|
|
|
|
"Input Ids of FusedEmbeddingSeqPoolOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
|
|
|
|
"Output of FusedEmbeddingSeqPoolOp should not be null.");
|
|
|
|
|
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("W"), "Input", "W", "FusedEmbeddingSeqPool");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("Ids"), "Input", "Ids",
|
|
|
|
|
"FusedEmbeddingSeqPool");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out",
|
|
|
|
|
"FusedEmbeddingSeqPool");
|
|
|
|
|
auto table_dims = ctx->GetInputDim("W");
|
|
|
|
|
auto ids_dims = ctx->GetInputDim("Ids");
|
|
|
|
|
const std::string& combiner = ctx->Attrs().Get<std::string>("combiner");
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(table_dims.size(), 2);
|
|
|
|
|
PADDLE_ENFORCE_GE(ids_dims.size(), 1,
|
|
|
|
|
"The dim size of the 'Ids' tensor must greater than 1.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(ids_dims[ids_dims.size() - 1], 1,
|
|
|
|
|
"The last dimension of the 'Ids' tensor must be 1.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(table_dims.size(), 2,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The dim size of the input tensor 'W' should be 2. "
|
|
|
|
|
"But received W's size = %d.",
|
|
|
|
|
table_dims.size()));
|
|
|
|
|
PADDLE_ENFORCE_GE(
|
|
|
|
|
ids_dims.size(), 1,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The dim size of the input tensor 'Ids' should be greater "
|
|
|
|
|
"than or equal to 1. But received Ids's size = %d.",
|
|
|
|
|
ids_dims.size()));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
ids_dims[ids_dims.size() - 1], 1,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The last dimension of the input tensor 'Ids' should be 1. "
|
|
|
|
|
"But received Ids's size in the last dimension = %d.",
|
|
|
|
|
ids_dims[ids_dims.size() - 1]));
|
|
|
|
|
// we only support sum now
|
|
|
|
|
PADDLE_ENFORCE_EQ(combiner, "sum");
|
|
|
|
|
PADDLE_ENFORCE_EQ(combiner, "sum",
|
|
|
|
|
platform::errors::Unimplemented(
|
|
|
|
|
"The pooling type of sequence_pool only support sum "
|
|
|
|
|
"now. So the 'combiner' must be 'sum'."));
|
|
|
|
|
|
|
|
|
|
int64_t last_dim = FusedEmbeddingSeqPoolLastDim(table_dims, ids_dims);
|
|
|
|
|
// in compile time, the lod level of ids must be 1
|
|
|
|
|
framework::VarDesc* ids_desc =
|
|
|
|
|
boost::get<framework::VarDesc*>(ctx->GetInputVarPtrs("Ids")[0]);
|
|
|
|
|
PADDLE_ENFORCE_EQ(ids_desc->GetLoDLevel(), 1);
|
|
|
|
|
PADDLE_ENFORCE_EQ(ids_desc->GetLoDLevel(), 1,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"In compile time, the LoD Level of Ids should be 1. "
|
|
|
|
|
"But received the LoD Level of Ids = %d.",
|
|
|
|
|
ids_desc->GetLoDLevel()));
|
|
|
|
|
|
|
|
|
|
// in compile time, the shape from Ids -> output
|
|
|
|
|
// should be [-1, 1] -> [-1, embedding_size]
|
|
|
|
|