@@ -32,6 +32,7 @@ class LookupTableKernel : public framework::OpKernel<T> {
     auto* table_t = context.Input<LoDTensor>("W");      // float tensor
     auto* ids_t = context.Input<LoDTensor>("Ids");      // int tensor
     auto* output_t = context.Output<LoDTensor>("Out");  // float tensor
+    int64_t padding_idx = context.Attr<int64_t>("padding_idx");
 
     int N = table_t->dims()[0];
     int D = table_t->dims()[1];
@@ -39,9 +40,13 @@ class LookupTableKernel : public framework::OpKernel<T> {
     auto* table = table_t->data<T>();
     auto* output = output_t->mutable_data<T>(context.GetPlace());
     for (int64_t i = 0; i < ids_t->numel(); ++i) {
-      PADDLE_ENFORCE_LT(ids[i], N);
-      PADDLE_ENFORCE_GE(ids[i], 0);
-      memcpy(output + i * D, table + ids[i] * D, D * sizeof(T));
+      if (ids[i] == padding_idx) {
+        memset(output + i * D, 0, D * sizeof(T));
+      } else {
+        PADDLE_ENFORCE_LT(ids[i], N);
+        PADDLE_ENFORCE_GE(ids[i], 0);
+        memcpy(output + i * D, table + ids[i] * D, D * sizeof(T));
+      }
     }
   }
 };
@@ -51,6 +56,8 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
     bool is_sparse = context.Attr<bool>("is_sparse");
+    int64_t padding_idx = context.Attr<int64_t>("padding_idx");
+
     if (is_sparse) {
       auto* ids = context.Input<LoDTensor>("Ids");
       auto* table = context.Input<LoDTensor>("W");
@@ -63,6 +70,9 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
       framework::Vector<int64_t> new_rows;
       new_rows.reserve(ids_dim[0]);
       for (int64_t i = 0; i < ids_dim[0]; i++) {
+        if (ids_data[i] == padding_idx)
+          continue;  // Paddings are not trainable and the gradient is not
+                     // necessary.
         new_rows.push_back(ids_data[i]);
       }
       d_table->set_rows(new_rows);
@@ -96,6 +106,9 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
       memset(d_table_data, 0, d_table->numel() * sizeof(T));
 
       for (int64_t i = 0; i < ids->numel(); ++i) {
+        if (ids_data[i] == padding_idx)
+          continue;  // Paddings are not trainable and the gradient is not
+                     // necessary.
         PADDLE_ENFORCE_LT(ids_data[i], N);
         PADDLE_ENFORCE_GE(ids_data[i], 0);
         for (int j = 0; j < D; ++j) {
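
For reference, below is a minimal standalone sketch of the semantics this patch introduces: an id equal to padding_idx yields an all-zero output row in the forward pass, and that row contributes nothing to the table gradient in the backward pass. This is not Paddle code; the free functions LookupForward and LookupBackward and the driver in main are illustrative assumptions, not part of the operator.

// Standalone sketch (illustrative only) of the padding_idx behavior added
// by this diff: padding rows are zeroed in forward and skipped in backward.
#include <cassert>
#include <cstdint>
#include <cstring>
#include <vector>

void LookupForward(const float* table, const int64_t* ids, int64_t num_ids,
                   int64_t N, int64_t D, int64_t padding_idx, float* out) {
  for (int64_t i = 0; i < num_ids; ++i) {
    if (ids[i] == padding_idx) {
      std::memset(out + i * D, 0, D * sizeof(float));  // padding row -> zeros
    } else {
      assert(ids[i] >= 0 && ids[i] < N);  // mirrors PADDLE_ENFORCE_GE / _LT
      std::memcpy(out + i * D, table + ids[i] * D, D * sizeof(float));
    }
  }
}

void LookupBackward(const float* d_out, const int64_t* ids, int64_t num_ids,
                    int64_t N, int64_t D, int64_t padding_idx, float* d_table) {
  std::memset(d_table, 0, N * D * sizeof(float));
  for (int64_t i = 0; i < num_ids; ++i) {
    if (ids[i] == padding_idx) continue;  // padding rows receive no gradient
    assert(ids[i] >= 0 && ids[i] < N);
    for (int64_t j = 0; j < D; ++j) {
      d_table[ids[i] * D + j] += d_out[i * D + j];
    }
  }
}

int main() {
  const int64_t N = 4, D = 2, padding_idx = 0;
  std::vector<float> table = {0, 0, 1, 1, 2, 2, 3, 3};
  std::vector<int64_t> ids = {2, 0, 3};  // id 0 is the padding id
  std::vector<float> out(ids.size() * D);
  LookupForward(table.data(), ids.data(), ids.size(), N, D, padding_idx,
                out.data());
  // out is now {2, 2, 0, 0, 3, 3}: the padding row was zeroed.
  std::vector<float> d_out(ids.size() * D, 1.0f), d_table(N * D);
  LookupBackward(d_out.data(), ids.data(), ids.size(), N, D, padding_idx,
                 d_table.data());
  // Row 0 of d_table stays all-zero because the padding id is skipped.
  return 0;
}

Design note: as written, the forward kernel compares ids[i] against padding_idx unconditionally, so the attribute presumably needs a sentinel default (e.g. a negative value that can never match a valid id) for the case where no padding entry is wanted.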