|
|
|
|
@ -14,6 +14,9 @@ limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#pragma once
|
|
|
|
|
|
|
|
|
|
#include <string>
|
|
|
|
|
#include <vector>
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/framework/eigen.h"
|
|
|
|
|
#include "paddle/fluid/framework/lod_tensor.h"
|
|
|
|
|
#include "paddle/fluid/framework/op_registry.h"
|
|
|
|
|
@ -25,14 +28,35 @@ namespace operators {
|
|
|
|
|
using Tensor = framework::Tensor;
|
|
|
|
|
using LoDTensor = framework::LoDTensor;
|
|
|
|
|
using SelectedRows = framework::SelectedRows;
|
|
|
|
|
using DDim = framework::DDim;
|
|
|
|
|
|
|
|
|
|
static const int64_t kNoPadding = -1;
|
|
|
|
|
|
|
|
|
|
inline size_t getIndex(const std::vector<int64_t> &rows, int64_t value) {
|
|
|
|
|
auto it = std::find(rows.begin(), rows.end(), value);
|
|
|
|
|
PADDLE_ENFORCE(it != rows.end(), "id should be in rows");
|
|
|
|
|
return std::distance(rows.begin(), it);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
class LookupTableKernel : public framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext &context) const override {
|
|
|
|
|
auto* table_t = context.Input<LoDTensor>("W");
|
|
|
|
|
auto *table_var = context.InputVar("W");
|
|
|
|
|
auto *ids_var = context.InputVar("Ids");
|
|
|
|
|
Tensor *output_t = context.Output<Tensor>("Out");
|
|
|
|
|
int64_t padding_idx = context.Attr<int64_t>("padding_idx");
|
|
|
|
|
|
|
|
|
|
DDim table_dim;
|
|
|
|
|
|
|
|
|
|
if (table_var->IsType<LoDTensor>()) {
|
|
|
|
|
table_dim = context.Input<LoDTensor>("W")->dims();
|
|
|
|
|
} else if (table_var->IsType<SelectedRows>()) {
|
|
|
|
|
auto *table_t = context.Input<SelectedRows>("W");
|
|
|
|
|
table_dim = table_t->value().dims();
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_THROW("table only support LoDTensor and SelectedRows");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int64_t *ids;
|
|
|
|
|
int64_t ids_numel;
|
|
|
|
|
@ -49,32 +73,43 @@ class LookupTableKernel : public framework::OpKernel<T> {
|
|
|
|
|
auto *ids_t = context.Input<SelectedRows>("Ids");
|
|
|
|
|
ids = const_cast<int64_t *>(ids_t->rows().data());
|
|
|
|
|
ids_numel = ids_t->rows().size();
|
|
|
|
|
output_t->Resize({ids_numel, table_t->dims()[1]});
|
|
|
|
|
output_t->Resize({ids_numel, table_dim[1]});
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_THROW("Unsupported Variable Type of Ids");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int64_t padding_idx = context.Attr<int64_t>("padding_idx");
|
|
|
|
|
if (table_var->IsType<LoDTensor>()) {
|
|
|
|
|
auto *table_t = context.Input<LoDTensor>("W");
|
|
|
|
|
int64_t row_number = table_t->dims()[0];
|
|
|
|
|
int64_t row_width = table_t->dims()[1];
|
|
|
|
|
|
|
|
|
|
int N = table_t->dims()[0];
|
|
|
|
|
int D = table_t->dims()[1];
|
|
|
|
|
auto *table = table_t->data<T>();
|
|
|
|
|
auto *output = output_t->mutable_data<T>(context.GetPlace());
|
|
|
|
|
|
|
|
|
|
if (padding_idx == -1) {
|
|
|
|
|
for (int64_t i = 0; i < ids_numel; ++i) {
|
|
|
|
|
PADDLE_ENFORCE_LT(ids[i], N);
|
|
|
|
|
if (padding_idx != kNoPadding && ids[i] == padding_idx) {
|
|
|
|
|
memset(output + i * row_width, 0, row_width * sizeof(T));
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_ENFORCE_LT(ids[i], row_number);
|
|
|
|
|
PADDLE_ENFORCE_GE(ids[i], 0);
|
|
|
|
|
memcpy(output + i * D, table + ids[i] * D, D * sizeof(T));
|
|
|
|
|
memcpy(output + i * row_width, table + ids[i] * row_width,
|
|
|
|
|
row_width * sizeof(T));
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
}
|
|
|
|
|
} else if (table_var->IsType<SelectedRows>()) {
|
|
|
|
|
const auto &table_t = table_var->Get<SelectedRows>();
|
|
|
|
|
int64_t row_width = table_t.value().dims()[1];
|
|
|
|
|
const auto *table = table_t.value().data<T>();
|
|
|
|
|
auto *output = output_t->mutable_data<T>(context.GetPlace());
|
|
|
|
|
|
|
|
|
|
for (int64_t i = 0; i < ids_numel; ++i) {
|
|
|
|
|
if (ids[i] == padding_idx) {
|
|
|
|
|
memset(output + i * D, 0, D * sizeof(T));
|
|
|
|
|
if (padding_idx != kNoPadding && ids[i] == padding_idx) {
|
|
|
|
|
memset(output + i * row_width, 0, row_width * sizeof(T));
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_ENFORCE_LT(ids[i], N);
|
|
|
|
|
PADDLE_ENFORCE_GE(ids[i], 0);
|
|
|
|
|
memcpy(output + i * D, table + ids[i] * D, D * sizeof(T));
|
|
|
|
|
auto id_index = getIndex(table_t.rows(), ids[i]);
|
|
|
|
|
memcpy(output + i * row_width, table + id_index * row_width,
|
|
|
|
|
row_width * sizeof(T));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|