diff --git a/paddle/fluid/operators/lookup_table_op.cc b/paddle/fluid/operators/lookup_table_op.cc
index 461e5bd2d3..50eeadab72 100644
--- a/paddle/fluid/operators/lookup_table_op.cc
+++ b/paddle/fluid/operators/lookup_table_op.cc
@@ -34,8 +34,11 @@ class LookupTableOp : public framework::OperatorWithKernel {
     auto ids_dims = ctx->GetInputDim("Ids");
     auto ids_var_type = ctx->GetInputsVarType("Ids").front();
-    // ids_var_types also can be LOD_TENSOR_ARRAY, it's used as concat_rows.
-    // Maybe near future we will add concat_rows op.
+    // The type of Ids(Input) is SelectedRows or LoDTensor, when Ids's type
+    // is LoDTensor, this tensor contains the ids to be looked up in W
+    // and it must be a column vector with rank = 2 while the 2nd dimension
+    // size must be 1, when Ids's type is SelectedRows, the rows of Ids
+    // contains the ids to be looked up in W;
     if (ids_var_type == framework::proto::VarType::LOD_TENSOR) {
       PADDLE_ENFORCE_EQ(ids_dims.size(), 2);
       PADDLE_ENFORCE_EQ(ids_dims[1], 1);
@@ -59,17 +62,22 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
   LookupTableOpMaker(OpProto* proto, OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("W",
-             "An input represents embedding tensors, "
+             "(Tensor) The input represents embedding tensors, "
              "which is a learnable parameter.");
-    AddInput("Ids",
-             "An input with type int32 or int64 "
-             "contains the ids to be looked up in W. "
-             "Ids must be a column vector with rank = 2. "
-             "The 2nd dimension size must be 1.");
-    AddOutput("Out", "The lookup results, which have the same type as W.");
+    AddInput(
+        "Ids",
+        "(Tensor or SelectedRows) Ids's type can be Tensor or "
+        "SelectedRows, when Ids's type is Tensor, this tensor contains "
+        "the ids to be looked up in W and it must be a column vector with "
+        "rank = 2 while the 2nd dimension size must be 1; when Ids's type is "
+        "SelectedRows, the rows of Ids contains the ids to be looked up "
+        "in W.");
+    AddOutput("Out",
+              "(Tensor or SelectedRows) The lookup results, which have the "
+              "same type as W.");
     AddAttr<bool>("is_sparse",
                   "(boolean, default false) "
-                  "Sparse update")
+                  "Sparse update.")
        .SetDefault(false);
     AddAttr<int64_t>("padding_idx",
                      "(int64, default -1) "
@@ -81,10 +89,15 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
 Lookup Table Operator.
 
 This operator is used to perform lookups on the parameter W,
-then concatenated into a dense tensor.
+then concatenated into a dense or sparse tensor.
+
+The type of Ids(Input) is SelectedRows, Tensor or LoDTensor, when Ids's
+type is SelectedRows, the rows of Ids contains the ids to be looked up in W;
+when Ids's type is Tensor, this tensor contains the ids to be looked up in W
+and it must be a column vector with rank = 2 while the 2nd dimension size must be 1,
+at this time, Ids can carry the LoD (Level of Details) information, or not, and
+the output only shares the LoD information with input Ids.
 
-The input Ids can carry the LoD (Level of Details) information,
-or not. And the output only shares the LoD information with input Ids.
 
 )DOC");
   }
diff --git a/paddle/fluid/operators/lookup_table_op.cu b/paddle/fluid/operators/lookup_table_op.cu
index f314fdbbff..6d81fccd20 100644
--- a/paddle/fluid/operators/lookup_table_op.cu
+++ b/paddle/fluid/operators/lookup_table_op.cu
@@ -75,22 +75,22 @@ class LookupTableCUDAKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext& context) const override {
     auto* table_t = context.Input<LoDTensor>("W");
     int64_t padding_idx = context.Attr<int64_t>("padding_idx");
-    auto* ids_var = context.InputVar("Ids");  // int tensor
+    auto* ids_var = context.InputVar("Ids");
+    Tensor* output_t = context.Output<Tensor>("Out");
 
     int64_t* ids;
     int64_t K;
-    framework::Tensor* output_t;
 
-    // ids_var_types also can be LOD_TENSOR_ARRAY, it's used as concat_rows.
-    // Maybe near future we will add concat_rows op.
+    // The type of Ids(Input) is SelectedRows or LoDTensor, when Ids's type
+    // is LoDTensor, this tensor contains the ids to be looked up in W;
+    // when Ids's type is SelectedRows, the rows of Ids contains the
+    // ids to be looked up in W.
     if (ids_var->IsType<framework::LoDTensor>()) {
       auto* ids_t = context.Input<LoDTensor>("Ids");
-      output_t = context.Output<Tensor>("Out");  // float tensor
       ids = const_cast<int64_t*>(ids_t->data<int64_t>());
       K = ids_t->numel();
     } else if (ids_var->IsType<framework::SelectedRows>()) {
       auto* ids_t = context.Input<framework::SelectedRows>("Ids");
-      output_t = context.Output<framework::SelectedRows>("Out")->mutable_value();
       ids = const_cast<int64_t*>(ids_t->rows().CUDAData(context.GetPlace()));
       K = ids_t->rows().size();
       output_t->Resize({K, table_t->dims()[1]});
diff --git a/paddle/fluid/operators/lookup_table_op.h b/paddle/fluid/operators/lookup_table_op.h
index 4495b4e9e2..c92ce78eef 100644
--- a/paddle/fluid/operators/lookup_table_op.h
+++ b/paddle/fluid/operators/lookup_table_op.h
@@ -30,23 +30,23 @@ template <typename T>
 class LookupTableKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    auto* table_t = context.Input<LoDTensor>("W");  // float tensor
-    auto* ids_var = context.InputVar("Ids");        // int tensor
+    auto* table_t = context.Input<LoDTensor>("W");
+    auto* ids_var = context.InputVar("Ids");
+    Tensor* output_t = context.Output<Tensor>("Out");
 
     int64_t* ids;
     int64_t ids_numel;
-    Tensor* output_t;
 
-    // ids_var_types also can be LOD_TENSOR_ARRAY, it's used as concat_rows.
-    // Maybe near future we will add concat_rows op.
+    // The type of Ids(Input) is SelectedRows or LoDTensor, when Ids's type
+    // is LoDTensor, this tensor contains the ids to be looked up in W;
+    // when Ids's type is SelectedRows, the rows of Ids contains the
+    // ids to be looked up in W.
     if (ids_var->IsType<framework::LoDTensor>()) {
       auto* ids_t = context.Input<LoDTensor>("Ids");
-      output_t = context.Output<Tensor>("Out");
       ids = const_cast<int64_t*>(ids_t->data<int64_t>());
       ids_numel = ids_t->numel();
     } else if (ids_var->IsType<framework::SelectedRows>()) {
      auto* ids_t = context.Input<framework::SelectedRows>("Ids");
-      output_t = context.Output<framework::SelectedRows>("Out")->mutable_value();
      ids = const_cast<int64_t*>(ids_t->rows().data());
       ids_numel = ids_t->rows().size();
       output_t->Resize({ids_numel, table_t->dims()[1]});