|
|
@ -367,8 +367,13 @@ class FusedEmbeddingFCLSTMKernel : public framework::OpKernel<T> {
|
|
|
|
auto blas = math::GetBlas<DeviceContext, T>(ctx);
|
|
|
|
auto blas = math::GetBlas<DeviceContext, T>(ctx);
|
|
|
|
|
|
|
|
|
|
|
|
for (int64_t i = 0; i < ids_numel; ++i) {
|
|
|
|
for (int64_t i = 0; i < ids_numel; ++i) {
|
|
|
|
PADDLE_ENFORCE_LT(ids_data[i], row_number);
|
|
|
|
PADDLE_ENFORCE_LT(
|
|
|
|
PADDLE_ENFORCE_GE(ids_data[i], 0, "ids %d", i);
|
|
|
|
ids_data[i], row_number,
|
|
|
|
|
|
|
|
platform::errors::OutOfRange(
|
|
|
|
|
|
|
|
"Value of Ids %d should less than dict size %d.", i, row_number));
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_GE(ids_data[i], 0,
|
|
|
|
|
|
|
|
platform::errors::OutOfRange(
|
|
|
|
|
|
|
|
"Value of Ids %d should greater than ZERO.", i));
|
|
|
|
memcpy(xx_data + i * row_width, embeddings_data + ids_data[i] * row_width,
|
|
|
|
memcpy(xx_data + i * row_width, embeddings_data + ids_data[i] * row_width,
|
|
|
|
row_width * sizeof(T));
|
|
|
|
row_width * sizeof(T));
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -473,8 +478,13 @@ class FusedEmbeddingFCLSTMKernel : public framework::OpKernel<T> {
|
|
|
|
auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
|
|
|
|
auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
|
|
|
|
|
|
|
|
|
|
|
|
for (int64_t i = 0; i < ids_numel; ++i) {
|
|
|
|
for (int64_t i = 0; i < ids_numel; ++i) {
|
|
|
|
PADDLE_ENFORCE_LT(ids_data[i], row_number);
|
|
|
|
PADDLE_ENFORCE_LT(
|
|
|
|
PADDLE_ENFORCE_GE(ids_data[i], 0, "ids %d", i);
|
|
|
|
ids_data[i], row_number,
|
|
|
|
|
|
|
|
platform::errors::OutOfRange(
|
|
|
|
|
|
|
|
"Value of Ids %d should less than dict size %d.", i, row_number));
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_GE(ids_data[i], 0,
|
|
|
|
|
|
|
|
platform::errors::OutOfRange(
|
|
|
|
|
|
|
|
"Value of Ids %d should greater than ZERO.", i));
|
|
|
|
memcpy(xx_data + i * row_width, embeddings_data + ids_data[i] * row_width,
|
|
|
|
memcpy(xx_data + i * row_width, embeddings_data + ids_data[i] * row_width,
|
|
|
|
row_width * sizeof(T));
|
|
|
|
row_width * sizeof(T));
|
|
|
|
}
|
|
|
|
}
|
|
|
|