|
|
|
@ -56,7 +56,7 @@ class LookupSparseTableOp : public framework::OperatorBase {
|
|
|
|
|
PADDLE_ENFORCE(w_var->IsType<framework::SelectedRows>(),
|
|
|
|
|
"The type of W var should be SelectedRows.");
|
|
|
|
|
PADDLE_ENFORCE(ids_var->IsType<framework::LoDTensor>(),
|
|
|
|
|
"The type of Ids var should be SelectedRows.");
|
|
|
|
|
"The type of Ids var should be LoDTensor.");
|
|
|
|
|
auto &ids_t = ids_var->Get<framework::LoDTensor>();
|
|
|
|
|
auto out_t = out_var->GetMutable<framework::LoDTensor>();
|
|
|
|
|
auto w_t = w_var->GetMutable<framework::SelectedRows>();
|
|
|
|
@ -111,10 +111,10 @@ class LookupSparseTableOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
"(SelectedRows) The input represents embedding table, "
|
|
|
|
|
"which is a learnable parameter.");
|
|
|
|
|
AddInput("Ids",
|
|
|
|
|
"(SelectedRows) Ids's type should be SelectedRows "
|
|
|
|
|
"the rows of Ids contains the Ids to be looked up in W.");
|
|
|
|
|
"(LoDTensor) Ids's type should be LoDTensor"
|
|
|
|
|
"THe ids to be looked up in W.");
|
|
|
|
|
AddOutput("Out",
|
|
|
|
|
"(SelectedRows) The lookup results, which have the "
|
|
|
|
|
"(LoDTensor) The lookup results, which have the "
|
|
|
|
|
"same type as W.");
|
|
|
|
|
AddAttr<int64_t>("padding_idx",
|
|
|
|
|
"(int64, default -1) "
|
|
|
|
|