|
|
|
@ -22,6 +22,7 @@ limitations under the License. */
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
|
|
|
|
|
|
// Alias so that framework::Tensor can be written as Tensor in the code below.
using Tensor = framework::Tensor;
|
|
|
|
|
// Alias so that framework::LoDTensor can be written as LoDTensor in the code
// below.
using LoDTensor = framework::LoDTensor;
|
|
|
|
|
// Alias so that framework::SelectedRows can be written as SelectedRows in the
// code below.
using SelectedRows = framework::SelectedRows;
|
|
|
|
|
|
|
|
|
@ -29,25 +30,46 @@ template <typename T>
|
|
|
|
|
class LookupTableKernel : public framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
|
auto* table_t = context.Input<LoDTensor>("W"); // float tensor
|
|
|
|
|
auto* ids_t = context.Input<LoDTensor>("Ids"); // int tensor
|
|
|
|
|
auto* output_t = context.Output<LoDTensor>("Out"); // float tensor
|
|
|
|
|
auto* table_t = context.Input<LoDTensor>("W"); // float tensor
|
|
|
|
|
auto* ids_var = context.InputVar("Ids"); // int tensor
|
|
|
|
|
|
|
|
|
|
int64_t* ids;
|
|
|
|
|
int64_t ids_numel;
|
|
|
|
|
Tensor* output_t;
|
|
|
|
|
|
|
|
|
|
// The type of ids_var can also be LOD_TENSOR_ARRAY, in which case it is used as concat_rows.
|
|
|
|
|
// A dedicated concat_rows op may be added in the near future.
|
|
|
|
|
if (ids_var->IsType<LoDTensor>()) {
|
|
|
|
|
auto* ids_t = context.Input<LoDTensor>("Ids");
|
|
|
|
|
output_t = context.Output<LoDTensor>("Out");
|
|
|
|
|
ids = const_cast<int64_t*>(ids_t->data<int64_t>());
|
|
|
|
|
ids_numel = ids_t->numel();
|
|
|
|
|
} else if (ids_var->IsType<SelectedRows>()) {
|
|
|
|
|
auto* ids_t = context.Input<SelectedRows>("Ids");
|
|
|
|
|
output_t =
|
|
|
|
|
const_cast<Tensor*>(&(context.Output<SelectedRows>("Out")->value()));
|
|
|
|
|
ids = const_cast<int64_t*>(ids_t->rows().data());
|
|
|
|
|
ids_numel = ids_t->rows().size();
|
|
|
|
|
output_t->Resize({ids_numel, table_t->dims()[1]});
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_THROW("Unsupported Variable Type of Ids");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int64_t padding_idx = context.Attr<int64_t>("padding_idx");
|
|
|
|
|
|
|
|
|
|
int N = table_t->dims()[0];
|
|
|
|
|
int D = table_t->dims()[1];
|
|
|
|
|
auto* ids = ids_t->data<int64_t>();
|
|
|
|
|
auto* table = table_t->data<T>();
|
|
|
|
|
auto* output = output_t->mutable_data<T>(context.GetPlace());
|
|
|
|
|
|
|
|
|
|
if (padding_idx == -1) {
|
|
|
|
|
for (int64_t i = 0; i < ids_t->numel(); ++i) {
|
|
|
|
|
for (int64_t i = 0; i < ids_numel; ++i) {
|
|
|
|
|
PADDLE_ENFORCE_LT(ids[i], N);
|
|
|
|
|
PADDLE_ENFORCE_GE(ids[i], 0);
|
|
|
|
|
memcpy(output + i * D, table + ids[i] * D, D * sizeof(T));
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
for (int64_t i = 0; i < ids_t->numel(); ++i) {
|
|
|
|
|
for (int64_t i = 0; i < ids_numel; ++i) {
|
|
|
|
|
if (ids[i] == padding_idx) {
|
|
|
|
|
memset(output + i * D, 0, D * sizeof(T));
|
|
|
|
|
} else {
|
|
|
|
|