|
|
|
@ -30,12 +30,12 @@ class LookupTableKernel : public framework::OpKernel {
|
|
|
|
|
auto ids_t = context.Input<Tensor>("Ids"); // int tensor
|
|
|
|
|
auto output_t = context.Output<Tensor>("Out"); // float tensor
|
|
|
|
|
|
|
|
|
|
size_t N = table_t->dims()[0];
|
|
|
|
|
size_t D = table_t->dims()[1];
|
|
|
|
|
int N = table_t->dims()[0];
|
|
|
|
|
int D = table_t->dims()[1];
|
|
|
|
|
auto ids = ids_t->data<int32_t>();
|
|
|
|
|
auto table = table_t->data<T>();
|
|
|
|
|
auto output = output_t->mutable_data<T>(context.GetPlace());
|
|
|
|
|
for (size_t i = 0; i < product(ids_t->dims()); ++i) {
|
|
|
|
|
for (ssize_t i = 0; i < product(ids_t->dims()); ++i) {
|
|
|
|
|
PADDLE_ENFORCE_LT(ids[i], N);
|
|
|
|
|
PADDLE_ENFORCE_GE(ids[i], 0);
|
|
|
|
|
memcpy(output + i * D, table + ids[i] * D, D * sizeof(T));
|
|
|
|
@ -51,8 +51,8 @@ class LookupTableGradKernel : public framework::OpKernel {
|
|
|
|
|
auto d_output_t = context.Input<Tensor>(framework::GradVarName("Out"));
|
|
|
|
|
auto d_table_t = context.Output<Tensor>(framework::GradVarName("W"));
|
|
|
|
|
|
|
|
|
|
size_t N = d_table_t->dims()[0];
|
|
|
|
|
size_t D = d_table_t->dims()[1];
|
|
|
|
|
int N = d_table_t->dims()[0];
|
|
|
|
|
int D = d_table_t->dims()[1];
|
|
|
|
|
auto ids = ids_t->data<int32_t>();
|
|
|
|
|
const T* d_output = d_output_t->data<T>();
|
|
|
|
|
T* d_table = d_table_t->mutable_data<T>(context.GetPlace());
|
|
|
|
@ -61,10 +61,10 @@ class LookupTableGradKernel : public framework::OpKernel {
|
|
|
|
|
t.device(context.GetEigenDevice<platform::CPUPlace>()) =
|
|
|
|
|
t.constant(static_cast<T>(0));
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < product(ids_t->dims()); ++i) {
|
|
|
|
|
for (ssize_t i = 0; i < product(ids_t->dims()); ++i) {
|
|
|
|
|
PADDLE_ENFORCE_LT(ids[i], N);
|
|
|
|
|
PADDLE_ENFORCE_GE(ids[i], 0);
|
|
|
|
|
for (size_t j = 0; j < D; ++j) {
|
|
|
|
|
for (int j = 0; j < D; ++j) {
|
|
|
|
|
d_table[ids[i] * D + j] += d_output[i * D + j];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|