@@ -14,6 +14,9 @@ limitations under the License. */
 
 #pragma once
 
+#include <string>
+#include <vector>
+
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/framework/op_registry.h"
@@ -25,16 +28,37 @@ namespace operators {
 using Tensor = framework::Tensor;
 using LoDTensor = framework::LoDTensor;
 using SelectedRows = framework::SelectedRows;
+using DDim = framework::DDim;
+
+static constexpr int64_t kNoPadding = -1;
+
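+// Looks up the position of `value` inside the rows index of a SelectedRows
+// table by linear search; the id is required to be present.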
+inline size_t getIndex(const std::vector<int64_t> &rows, int64_t value) {
+  auto it = std::find(rows.begin(), rows.end(), value);
+  PADDLE_ENFORCE(it != rows.end(), "id should be in rows");
+  return static_cast<size_t>(std::distance(rows.begin(), it));
+}
 
 template <typename T>
 class LookupTableKernel : public framework::OpKernel<T> {
  public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    auto* table_t = context.Input<LoDTensor>("W");
-    auto* ids_var = context.InputVar("Ids");
-    Tensor* output_t = context.Output<Tensor>("Out");
+  void Compute(const framework::ExecutionContext &context) const override {
+    auto *table_var = context.InputVar("W");
+    auto *ids_var = context.InputVar("Ids");
+    Tensor *output_t = context.Output<Tensor>("Out");
+    int64_t padding_idx = context.Attr<int64_t>("padding_idx");
+
+    DDim table_dim;
 
-    int64_t* ids;
+    if (table_var->IsType<LoDTensor>()) {
+      table_dim = context.Input<LoDTensor>("W")->dims();
+    } else if (table_var->IsType<SelectedRows>()) {
+      auto *table_t = context.Input<SelectedRows>("W");
+      table_dim = table_t->value().dims();
+    } else {
+      PADDLE_THROW("table only support LoDTensor and SelectedRows");
+    }
+
+    int64_t *ids;
     int64_t ids_numel;
 
     // The type of Ids(Input) is SelectedRows or LoDTensor, when Ids's type
@@ -42,39 +66,50 @@ class LookupTableKernel : public framework::OpKernel<T> {
     // when Ids's type is SelectedRows, the rows of Ids contains the
     // ids to be looked up in W.
     if (ids_var->IsType<LoDTensor>()) {
-      auto* ids_t = context.Input<LoDTensor>("Ids");
-      ids = const_cast<int64_t*>(ids_t->data<int64_t>());
+      auto *ids_t = context.Input<LoDTensor>("Ids");
+      ids = const_cast<int64_t *>(ids_t->data<int64_t>());
       ids_numel = ids_t->numel();
     } else if (ids_var->IsType<SelectedRows>()) {
-      auto* ids_t = context.Input<SelectedRows>("Ids");
-      ids = const_cast<int64_t*>(ids_t->rows().data());
+      auto *ids_t = context.Input<SelectedRows>("Ids");
+      ids = const_cast<int64_t *>(ids_t->rows().data());
       ids_numel = ids_t->rows().size();
-      output_t->Resize({ids_numel, table_t->dims()[1]});
+      output_t->Resize({ids_numel, table_dim[1]});
     } else {
       PADDLE_THROW("Unsupported Variable Type of Ids");
     }
 
-    int64_t padding_idx = context.Attr<int64_t>("padding_idx");
+    if (table_var->IsType<LoDTensor>()) {
+      auto *table_t = context.Input<LoDTensor>("W");
+      int64_t row_number = table_t->dims()[0];
+      int64_t row_width = table_t->dims()[1];
 
-    int N = table_t->dims()[0];
-    int D = table_t->dims()[1];
-    auto* table = table_t->data<T>();
-    auto* output = output_t->mutable_data<T>(context.GetPlace());
+      auto *table = table_t->data<T>();
+      auto *output = output_t->mutable_data<T>(context.GetPlace());
 
-    if (padding_idx == -1) {
       for (int64_t i = 0; i < ids_numel; ++i) {
-        PADDLE_ENFORCE_LT(ids[i], N);
-        PADDLE_ENFORCE_GE(ids[i], 0);
-        memcpy(output + i * D, table + ids[i] * D, D * sizeof(T));
+        if (padding_idx != kNoPadding && ids[i] == padding_idx) {
+          memset(output + i * row_width, 0, row_width * sizeof(T));
+        } else {
+          PADDLE_ENFORCE_LT(ids[i], row_number);
+          PADDLE_ENFORCE_GE(ids[i], 0);
+          memcpy(output + i * row_width, table + ids[i] * row_width,
+                 row_width * sizeof(T));
+        }
       }
-    } else {
+    } else if (table_var->IsType<SelectedRows>()) {
+      const auto &table_t = table_var->Get<SelectedRows>();
+      int64_t row_width = table_t.value().dims()[1];
+      const auto *table = table_t.value().data<T>();
+      auto *output = output_t->mutable_data<T>(context.GetPlace());
+
       for (int64_t i = 0; i < ids_numel; ++i) {
-        if (ids[i] == padding_idx) {
-          memset(output + i * D, 0, D * sizeof(T));
+        if (padding_idx != kNoPadding && ids[i] == padding_idx) {
+          memset(output + i * row_width, 0, row_width * sizeof(T));
         } else {
-          PADDLE_ENFORCE_LT(ids[i], N);
           PADDLE_ENFORCE_GE(ids[i], 0);
-          memcpy(output + i * D, table + ids[i] * D, D * sizeof(T));
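+          // W is stored as SelectedRows here, so the id is first mapped to
+          // its position in table_t.rows() before the row is copied out.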
+          auto id_index = getIndex(table_t.rows(), ids[i]);
+          memcpy(output + i * row_width, table + id_index * row_width,
+                 row_width * sizeof(T));
         }
       }
     }
@@ -84,17 +119,27 @@ class LookupTableKernel : public framework::OpKernel<T> {
 template <typename T>
 class LookupTableGradKernel : public framework::OpKernel<T> {
  public:
-  void Compute(const framework::ExecutionContext& context) const override {
+  void Compute(const framework::ExecutionContext &context) const override {
+    auto *table_var = context.InputVar("W");
+    DDim table_dim;
+    if (table_var->IsType<LoDTensor>()) {
+      table_dim = context.Input<LoDTensor>("W")->dims();
+    } else if (table_var->IsType<SelectedRows>()) {
+      auto *table_t = context.Input<SelectedRows>("W");
+      table_dim = table_t->value().dims();
+    } else {
+      PADDLE_THROW("table only support LoDTensor and SelectedRows");
+    }
+
     bool is_sparse = context.Attr<bool>("is_sparse");
     // Since paddings are not trainable and fixed in forward, the gradient of
     // paddings makes no sense and we don't deal with it in backward.
     if (is_sparse) {
-      auto* ids = context.Input<LoDTensor>("Ids");
-      auto* table = context.Input<LoDTensor>("W");
-      auto* d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
-      auto* d_table = context.Output<SelectedRows>(framework::GradVarName("W"));
+      auto *ids = context.Input<LoDTensor>("Ids");
+      auto *d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
+      auto *d_table = context.Output<SelectedRows>(framework::GradVarName("W"));
 
-      auto* ids_data = ids->data<int64_t>();
+      auto *ids_data = ids->data<int64_t>();
       auto ids_dim = ids->dims();
 
       framework::Vector<int64_t> new_rows;
@@ -104,31 +149,30 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
       }
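+      // The sparse gradient is a SelectedRows holding one row per input id;
+      // its value is d_output copied over unchanged.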
       d_table->set_rows(new_rows);
 
-      auto* d_table_value = d_table->mutable_value();
-      d_table_value->Resize({ids_dim[0], table->dims()[1]});
+      auto *d_table_value = d_table->mutable_value();
+      d_table_value->Resize({ids_dim[0], table_dim[1]});
       d_table_value->mutable_data<T>(context.GetPlace());
 
-      d_table->set_height(table->dims()[0]);
+      d_table->set_height(table_dim[0]);
 
-      auto* d_output_data = d_output->data<T>();
-      auto* d_table_data = d_table_value->data<T>();
+      auto *d_output_data = d_output->data<T>();
+      auto *d_table_data = d_table_value->data<T>();
 
       PADDLE_ENFORCE_EQ(d_table_value->dims(), d_output->dims());
       memcpy(d_table_data, d_output_data, sizeof(T) * d_output->numel());
     } else {
-      auto* ids = context.Input<LoDTensor>("Ids");
-      auto* d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
-      auto* d_table = context.Output<LoDTensor>(framework::GradVarName("W"));
-      auto* table = context.Input<LoDTensor>("W");
+      auto *ids = context.Input<LoDTensor>("Ids");
+      auto *d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
+      auto *d_table = context.Output<LoDTensor>(framework::GradVarName("W"));
 
-      auto* ids_data = ids->data<int64_t>();
+      auto *ids_data = ids->data<int64_t>();
       auto ids_dim = ids->dims();
 
-      int N = table->dims()[0];
+      int N = table_dim[0];
       int D = d_output->dims()[1];
 
-      auto* d_output_data = d_output->data<T>();
-      auto* d_table_data = d_table->mutable_data<T>(context.GetPlace());
+      auto *d_output_data = d_output->data<T>();
+      auto *d_table_data = d_table->mutable_data<T>(context.GetPlace());
 
       memset(d_table_data, 0, d_table->numel() * sizeof(T));
 