|
|
|
@ -30,13 +30,7 @@ using LoDTensor = framework::LoDTensor;
|
|
|
|
|
using SelectedRows = framework::SelectedRows;
|
|
|
|
|
using DDim = framework::DDim;
|
|
|
|
|
|
|
|
|
|
static constexpr int64_t kNoPadding = -1;
|
|
|
|
|
|
|
|
|
|
inline size_t getIndex(const std::vector<int64_t> &rows, int64_t value) {
|
|
|
|
|
auto it = std::find(rows.begin(), rows.end(), value);
|
|
|
|
|
PADDLE_ENFORCE(it != rows.end(), "id should be in rows");
|
|
|
|
|
return static_cast<size_t>(std::distance(rows.begin(), it));
|
|
|
|
|
}
|
|
|
|
|
constexpr int64_t kNoPadding = -1;
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
class LookupTableKernel : public framework::OpKernel<T> {
|
|
|
|
@ -55,7 +49,9 @@ class LookupTableKernel : public framework::OpKernel<T> {
|
|
|
|
|
auto *table_t = context.Input<SelectedRows>("W");
|
|
|
|
|
table_dim = table_t->value().dims();
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_THROW("table only support LoDTensor and SelectedRows");
|
|
|
|
|
PADDLE_THROW(
|
|
|
|
|
"The parameter W of a LookupTable "
|
|
|
|
|
"must be either LoDTensor or SelectedRows");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int64_t *ids;
|
|
|
|
@ -107,7 +103,7 @@ class LookupTableKernel : public framework::OpKernel<T> {
|
|
|
|
|
memset(output + i * row_width, 0, row_width * sizeof(T));
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_ENFORCE_GE(ids[i], 0);
|
|
|
|
|
auto id_index = getIndex(table_t.rows(), ids[i]);
|
|
|
|
|
auto id_index = table_t.index(ids[i]);
|
|
|
|
|
memcpy(output + i * row_width, table + id_index * row_width,
|
|
|
|
|
row_width * sizeof(T));
|
|
|
|
|
}
|
|
|
|
@ -128,7 +124,9 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
auto *table_t = context.Input<SelectedRows>("W");
|
|
|
|
|
table_dim = table_t->value().dims();
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_THROW("table only support LoDTensor and SelectedRows");
|
|
|
|
|
PADDLE_THROW(
|
|
|
|
|
"The parameter W of a LookupTable "
|
|
|
|
|
"must be either LoDTensor or SelectedRows");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool is_sparse = context.Attr<bool>("is_sparse");
|
|
|
|
|