|
|
|
@ -32,6 +32,9 @@ class LookupTableOp : public framework::OperatorWithKernel {
|
|
|
|
|
auto table_dims = ctx->GetInputDim("W");
|
|
|
|
|
auto ids_dims = ctx->GetInputDim("Ids");
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(ids_dims.size(), 2);
|
|
|
|
|
PADDLE_ENFORCE_EQ(ids_dims[1], 1);
|
|
|
|
|
|
|
|
|
|
ctx->SetOutputDim("Out", {ids_dims[0], table_dims[1]});
|
|
|
|
|
ctx->ShareLoD("Ids", /*->*/ "Out");
|
|
|
|
|
}
|
|
|
|
@ -53,7 +56,9 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
" which is a learnable parameter.");
|
|
|
|
|
AddInput("Ids",
|
|
|
|
|
"An input with type int32 or int64"
|
|
|
|
|
"contains the ids to be looked up in W.");
|
|
|
|
|
"contains the ids to be looked up in W."
|
|
|
|
|
"Ids must be a column vector with rank = 2."
|
|
|
|
|
"The 2nd dimension size must be 1");
|
|
|
|
|
AddOutput("Out", "The lookup results, which have the same type with W.");
|
|
|
|
|
AddComment(R"DOC(
|
|
|
|
|
This operator is used to perform lookups on the parameter W,
|
|
|
|
|