|
|
|
@ -103,7 +103,8 @@ class LookupTableKernel : public framework::OpKernel<T> {
|
|
|
|
|
memset(output + i * row_width, 0, row_width * sizeof(T));
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_ENFORCE_GE(ids[i], 0);
|
|
|
|
|
auto id_index = table_t.index(ids[i]);
|
|
|
|
|
auto id_index = table_t.Index(ids[i]);
|
|
|
|
|
PADDLE_ENFORCE_GE(id_index, 0, "the input key should be exists.");
|
|
|
|
|
memcpy(output + i * row_width, table + id_index * row_width,
|
|
|
|
|
row_width * sizeof(T));
|
|
|
|
|
}
|
|
|
|
|