|
|
|
@ -85,8 +85,18 @@ class LookupTableKernel : public framework::OpKernel<T> {
|
|
|
|
|
if (padding_idx != kNoPadding && ids[i] == padding_idx) {
|
|
|
|
|
memset(output + i * row_width, 0, row_width * sizeof(T));
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_ENFORCE_LT(ids[i], row_number);
|
|
|
|
|
PADDLE_ENFORCE_GE(ids[i], 0, "ids %d", i);
|
|
|
|
|
PADDLE_ENFORCE_LT(
|
|
|
|
|
ids[i], row_number,
|
|
|
|
|
"Variable value (input) of OP(fluid.layers.embedding) "
|
|
|
|
|
"expected >= 0 and < %ld, but got %ld. Please check input "
|
|
|
|
|
"value.",
|
|
|
|
|
row_number, ids[i]);
|
|
|
|
|
PADDLE_ENFORCE_GE(
|
|
|
|
|
ids[i], 0,
|
|
|
|
|
"Variable value (input) of OP(fluid.layers.embedding) "
|
|
|
|
|
"expected >= 0 and < %ld, but got %ld. Please check input "
|
|
|
|
|
"value.",
|
|
|
|
|
row_number, ids[i]);
|
|
|
|
|
memcpy(output + i * row_width, table + ids[i] * row_width,
|
|
|
|
|
row_width * sizeof(T));
|
|
|
|
|
}
|
|
|
|
@ -181,8 +191,8 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
auto *ids_data = ids->data<int64_t>();
|
|
|
|
|
|
|
|
|
|
int N = table_dim[0];
|
|
|
|
|
int D = table_dim[1];
|
|
|
|
|
int64_t N = table_dim[0];
|
|
|
|
|
int64_t D = table_dim[1];
|
|
|
|
|
|
|
|
|
|
auto *d_output_data = d_output->data<T>();
|
|
|
|
|
auto *d_table_data = d_table->mutable_data<T>(context.GetPlace());
|
|
|
|
@ -194,8 +204,16 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
// the gradient of padding_idx should be 0, already done by memset, so
|
|
|
|
|
// do nothing.
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_ENFORCE_LT(ids_data[i], N);
|
|
|
|
|
PADDLE_ENFORCE_GE(ids_data[i], 0);
|
|
|
|
|
PADDLE_ENFORCE_LT(
|
|
|
|
|
ids_data[i], N,
|
|
|
|
|
"Variable value (input) of OP(fluid.layers.embedding) "
|
|
|
|
|
"expected >= 0 and < %ld, but got %ld. Please check input value.",
|
|
|
|
|
N, ids_data[i]);
|
|
|
|
|
PADDLE_ENFORCE_GE(
|
|
|
|
|
ids_data[i], 0,
|
|
|
|
|
"Variable value (input) of OP(fluid.layers.embedding) "
|
|
|
|
|
"expected >= 0 and < %ld, but got %ld. Please check input value.",
|
|
|
|
|
N, ids_data[i]);
|
|
|
|
|
for (int j = 0; j < D; ++j) {
|
|
|
|
|
d_table_data[ids_data[i] * D + j] += d_output_data[i * D + j];
|
|
|
|
|
}
|
|
|
|
|