|
|
|
@ -23,8 +23,12 @@ limitations under the License. */
|
|
|
|
|
#include "paddle/fluid/framework/selected_rows.h"
|
|
|
|
|
#include "paddle/fluid/operators/math/blas.h"
|
|
|
|
|
|
|
|
|
|
#ifdef PADDLE_WITH_DISTRIBUTE
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/operators/distributed/parameter_prefetch.h"
|
|
|
|
|
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
|
|
|
|
|
@ -43,44 +47,64 @@ class LookupTableKernel : public framework::OpKernel<T> {
|
|
|
|
|
auto *output_t = context.Output<LoDTensor>("Out"); // float tensor
|
|
|
|
|
auto *table_var = context.InputVar("W");
|
|
|
|
|
|
|
|
|
|
int64_t padding_idx = context.Attr<int64_t>("padding_idx");
|
|
|
|
|
int64_t *ids = const_cast<int64_t *>(ids_t->data<int64_t>());
|
|
|
|
|
int64_t ids_numel = ids_t->numel();
|
|
|
|
|
|
|
|
|
|
if (table_var->IsType<LoDTensor>()) {
|
|
|
|
|
auto *table_t = context.Input<LoDTensor>("W");
|
|
|
|
|
int64_t row_number = table_t->dims()[0];
|
|
|
|
|
int64_t row_width = table_t->dims()[1];
|
|
|
|
|
|
|
|
|
|
auto *table = table_t->data<T>();
|
|
|
|
|
auto *output = output_t->mutable_data<T>(context.GetPlace());
|
|
|
|
|
|
|
|
|
|
for (int64_t i = 0; i < ids_numel; ++i) {
|
|
|
|
|
if (padding_idx != kNoPadding && ids[i] == padding_idx) {
|
|
|
|
|
memset(output + i * row_width, 0, row_width * sizeof(T));
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_ENFORCE_LT(ids[i], row_number);
|
|
|
|
|
PADDLE_ENFORCE_GE(ids[i], 0, "ids %d", i);
|
|
|
|
|
memcpy(output + i * row_width, table + ids[i] * row_width,
|
|
|
|
|
row_width * sizeof(T));
|
|
|
|
|
auto id_name = context.Inputs("Ids").front();
|
|
|
|
|
auto out_name = context.Outputs("Out").front();
|
|
|
|
|
auto table_name = context.Inputs("W").front();
|
|
|
|
|
auto epmap = context.Attr<std::vector<std::string>>("epmap");
|
|
|
|
|
auto height_sections =
|
|
|
|
|
context.Attr<std::vector<int64_t>>("height_sections");
|
|
|
|
|
|
|
|
|
|
if (!epmap.empty()) {
|
|
|
|
|
// if emap is not empty, then the paramter will be fetched from remote parameter
|
|
|
|
|
// server
|
|
|
|
|
#ifdef PADDLE_WITH_DISTRIBUTE
|
|
|
|
|
operators::distributed::prefetch(id_name, out_name, table_name, epmap,
|
|
|
|
|
height_sections, context);
|
|
|
|
|
#else
|
|
|
|
|
PADDLE_THROW(
|
|
|
|
|
"paddle is not compiled with distribute support, can not do "
|
|
|
|
|
"parameter prefetch!");
|
|
|
|
|
#endif
|
|
|
|
|
} else {
|
|
|
|
|
int64_t padding_idx = context.Attr<int64_t>("padding_idx");
|
|
|
|
|
int64_t *ids = const_cast<int64_t *>(ids_t->data<int64_t>());
|
|
|
|
|
int64_t ids_numel = ids_t->numel();
|
|
|
|
|
|
|
|
|
|
if (table_var->IsType<LoDTensor>()) {
|
|
|
|
|
auto *table_t = context.Input<LoDTensor>("W");
|
|
|
|
|
int64_t row_number = table_t->dims()[0];
|
|
|
|
|
int64_t row_width = table_t->dims()[1];
|
|
|
|
|
|
|
|
|
|
auto *table = table_t->data<T>();
|
|
|
|
|
auto *output = output_t->mutable_data<T>(context.GetPlace());
|
|
|
|
|
|
|
|
|
|
for (int64_t i = 0; i < ids_numel; ++i) {
|
|
|
|
|
if (padding_idx != kNoPadding && ids[i] == padding_idx) {
|
|
|
|
|
memset(output + i * row_width, 0, row_width * sizeof(T));
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_ENFORCE_LT(ids[i], row_number);
|
|
|
|
|
PADDLE_ENFORCE_GE(ids[i], 0, "ids %d", i);
|
|
|
|
|
memcpy(output + i * row_width, table + ids[i] * row_width,
|
|
|
|
|
row_width * sizeof(T));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} else if (table_var->IsType<SelectedRows>()) {
|
|
|
|
|
const auto &table_t = table_var->Get<SelectedRows>();
|
|
|
|
|
int64_t row_width = table_t.value().dims()[1];
|
|
|
|
|
const auto *table = table_t.value().data<T>();
|
|
|
|
|
auto *output = output_t->mutable_data<T>(context.GetPlace());
|
|
|
|
|
|
|
|
|
|
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
|
|
|
|
|
for (int64_t i = 0; i < ids_numel; ++i) {
|
|
|
|
|
if (padding_idx != kNoPadding && ids[i] == padding_idx) {
|
|
|
|
|
memset(output + i * row_width, 0, row_width * sizeof(T));
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_ENFORCE_GE(ids[i], 0);
|
|
|
|
|
auto id_index = table_t.Index(ids[i]);
|
|
|
|
|
PADDLE_ENFORCE_GE(id_index, 0, "the input key should be exists.");
|
|
|
|
|
blas.VCOPY(row_width, table + id_index * row_width,
|
|
|
|
|
output + i * row_width);
|
|
|
|
|
} else if (table_var->IsType<SelectedRows>()) {
|
|
|
|
|
const auto &table_t = table_var->Get<SelectedRows>();
|
|
|
|
|
int64_t row_width = table_t.value().dims()[1];
|
|
|
|
|
const auto *table = table_t.value().data<T>();
|
|
|
|
|
auto *output = output_t->mutable_data<T>(context.GetPlace());
|
|
|
|
|
|
|
|
|
|
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
|
|
|
|
|
for (int64_t i = 0; i < ids_numel; ++i) {
|
|
|
|
|
if (padding_idx != kNoPadding && ids[i] == padding_idx) {
|
|
|
|
|
memset(output + i * row_width, 0, row_width * sizeof(T));
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_ENFORCE_GE(ids[i], 0);
|
|
|
|
|
auto id_index = table_t.Index(ids[i]);
|
|
|
|
|
PADDLE_ENFORCE_GE(id_index, 0, "the input key should be exists.");
|
|
|
|
|
blas.VCOPY(row_width, table + id_index * row_width,
|
|
|
|
|
output + i * row_width);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|