|
|
|
@ -27,20 +27,28 @@ class LookupTableOp : public framework::OperatorWithKernel {
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("W"),
|
|
|
|
|
"Input(W) of LookupTableOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Ids"),
|
|
|
|
|
"Input(Ids) of LookupTableOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
|
|
|
|
"Output(Out) of LookupTableOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasInput("W"), true,
|
|
|
|
|
"Input(W) of LookupTableOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasInput("Ids"), true,
|
|
|
|
|
"Input(Ids) of LookupTableOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
|
|
|
|
|
"Output(Out) of LookupTableOp should not be null.");
|
|
|
|
|
|
|
|
|
|
auto table_dims = ctx->GetInputDim("W");
|
|
|
|
|
auto ids_dims = ctx->GetInputDim("Ids");
|
|
|
|
|
int ids_rank = ids_dims.size();
|
|
|
|
|
VLOG(5) << "ids rank is " << ids_rank << std::endl;
|
|
|
|
|
PADDLE_ENFORCE_EQ(table_dims.size(), 2);
|
|
|
|
|
PADDLE_ENFORCE_EQ(ids_dims[ids_rank - 1], 1,
|
|
|
|
|
"The last dimension of the 'Ids' tensor must be 1.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
table_dims.size(), 2,
|
|
|
|
|
"ShapeError: The dimensions of the 'lookup table' must be 2. "
|
|
|
|
|
"But received lookup table's dimensions = %d, "
|
|
|
|
|
"lookup table's shape = [%s].",
|
|
|
|
|
table_dims.size(), table_dims);
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
ids_dims[ids_rank - 1], 1,
|
|
|
|
|
"ShapeError: The last dimensions of the 'Ids' tensor must be 1. "
|
|
|
|
|
"But received Ids's last dimensions = %d, Ids's shape = [%s].",
|
|
|
|
|
ids_dims[ids_rank - 1], ids_dims);
|
|
|
|
|
|
|
|
|
|
auto output_dims =
|
|
|
|
|
framework::vectorize(framework::slice_ddim(ids_dims, 0, ids_rank - 1));
|
|
|
|
|