|
|
|
@ -105,9 +105,24 @@ class LookupTableCUDAKernel : public framework::OpKernel<T> {
|
|
|
|
|
auto *table = table_t->data<T>();
|
|
|
|
|
auto *output = output_t->mutable_data<T>(context.GetPlace());
|
|
|
|
|
|
|
|
|
|
#ifdef PADDLE_WITH_HIP
|
|
|
|
|
dim3 threads(64, 4);
|
|
|
|
|
#else
|
|
|
|
|
dim3 threads(128, 8);
|
|
|
|
|
#endif // PADDLE_WITH_HIP
|
|
|
|
|
dim3 grids(8, 1);
|
|
|
|
|
|
|
|
|
|
#ifdef PADDLE_WITH_HIP
|
|
|
|
|
if (padding_idx == -1)
|
|
|
|
|
LookupTable<
|
|
|
|
|
T, 64, 4, 8,
|
|
|
|
|
false><<<grids, threads, 0, context.cuda_device_context().stream()>>>(
|
|
|
|
|
output, table, ids, N, K, D, padding_idx);
|
|
|
|
|
else
|
|
|
|
|
LookupTable<
|
|
|
|
|
T, 64, 4, 8,
|
|
|
|
|
true><<<grids, threads, 0, context.cuda_device_context().stream()>>>(
|
|
|
|
|
output, table, ids, N, K, D, padding_idx);
|
|
|
|
|
#else
|
|
|
|
|
if (padding_idx == -1)
|
|
|
|
|
LookupTable<
|
|
|
|
|
T, 128, 8, 8,
|
|
|
|
@ -118,6 +133,7 @@ class LookupTableCUDAKernel : public framework::OpKernel<T> {
|
|
|
|
|
T, 128, 8, 8,
|
|
|
|
|
true><<<grids, threads, 0, context.cuda_device_context().stream()>>>(
|
|
|
|
|
output, table, ids, N, K, D, padding_idx);
|
|
|
|
|
#endif // PADDLE_WITH_HIP
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -185,10 +201,20 @@ class LookupTableGradCUDAKernel : public framework::OpKernel<T> {
|
|
|
|
|
auto t = framework::EigenVector<T>::Flatten(*d_table_t);
|
|
|
|
|
t.device(*dev_ctx.eigen_device()) = t.constant(static_cast<T>(0));
|
|
|
|
|
|
|
|
|
|
#ifdef PADDLE_WITH_HIP
|
|
|
|
|
dim3 threads(64, 4);
|
|
|
|
|
#else
|
|
|
|
|
dim3 threads(128, 8);
|
|
|
|
|
#endif // PADDLE_WITH_HIP
|
|
|
|
|
dim3 grids(8, 1);
|
|
|
|
|
|
|
|
|
|
#ifdef PADDLE_WITH_HIP
|
|
|
|
|
LookupTableGrad<T, 64, 4, 8><<<grids, threads, 0, dev_ctx.stream()>>>(
|
|
|
|
|
d_table, d_output, ids, N, K, D);
|
|
|
|
|
#else
|
|
|
|
|
LookupTableGrad<T, 128, 8, 8><<<grids, threads, 0, dev_ctx.stream()>>>(
|
|
|
|
|
d_table, d_output, ids, N, K, D);
|
|
|
|
|
#endif // PADDLE_WITH_HIP
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|