|
|
|
@ -23,6 +23,9 @@ class FusedEmbeddingSeqPoolOp : public framework::OperatorWithKernel {
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
if (ctx->IsRuntime()) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("W"),
|
|
|
|
|
"Input W of FusedEmbeddingSeqPoolOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Ids"),
|
|
|
|
@ -42,36 +45,15 @@ class FusedEmbeddingSeqPoolOp : public framework::OperatorWithKernel {
|
|
|
|
|
// we only support sum now
|
|
|
|
|
PADDLE_ENFORCE_EQ(combiner, "sum");
|
|
|
|
|
|
|
|
|
|
int64_t last_dim = table_dims[1];
|
|
|
|
|
for (int i = 1; i != ids_dims.size(); ++i) {
|
|
|
|
|
last_dim *= ids_dims[i];
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (ctx->IsRuntime()) {
|
|
|
|
|
framework::Variable* ids_var =
|
|
|
|
|
boost::get<framework::Variable*>(ctx->GetInputVarPtrs("Ids")[0]);
|
|
|
|
|
const auto& ids_lod = ids_var->Get<LoDTensor>().lod();
|
|
|
|
|
int64_t last_dim = FusedEmbeddingSeqPoolLastDim(table_dims, ids_dims);
|
|
|
|
|
// in compile time, the lod level of ids must be 1
|
|
|
|
|
framework::VarDesc* ids_desc =
|
|
|
|
|
boost::get<framework::VarDesc*>(ctx->GetInputVarPtrs("Ids")[0]);
|
|
|
|
|
PADDLE_ENFORCE_EQ(ids_desc->GetLoDLevel(), 1);
|
|
|
|
|
|
|
|
|
|
// in run time, the LoD of ids must be 1
|
|
|
|
|
PADDLE_ENFORCE(ids_lod.size(), 1u,
|
|
|
|
|
"The LoD level of Input(Ids) must be 1");
|
|
|
|
|
PADDLE_ENFORCE_GE(ids_lod[0].size(), 1u, "The LoD could NOT be empty");
|
|
|
|
|
|
|
|
|
|
int64_t batch_size = ids_lod[0].size() - 1;
|
|
|
|
|
|
|
|
|
|
// in run time, the shape from Ids -> output
|
|
|
|
|
// should be [seq_length, 1] -> [batch_size, embedding_size]
|
|
|
|
|
ctx->SetOutputDim("Out", framework::make_ddim({batch_size, last_dim}));
|
|
|
|
|
} else {
|
|
|
|
|
// in compile time, the lod level of ids must be 1
|
|
|
|
|
framework::VarDesc* ids_desc =
|
|
|
|
|
boost::get<framework::VarDesc*>(ctx->GetInputVarPtrs("Ids")[0]);
|
|
|
|
|
PADDLE_ENFORCE_EQ(ids_desc->GetLoDLevel(), 1);
|
|
|
|
|
|
|
|
|
|
// in compile time, the shape from Ids -> output
|
|
|
|
|
// should be [-1, 1] -> [-1, embedding_size]
|
|
|
|
|
ctx->SetOutputDim("Out", framework::make_ddim({-1, last_dim}));
|
|
|
|
|
}
|
|
|
|
|
// in compile time, the shape from Ids -> output
|
|
|
|
|
// should be [-1, 1] -> [-1, embedding_size]
|
|
|
|
|
ctx->SetOutputDim("Out", framework::make_ddim({-1, last_dim}));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|