@@ -21,9 +21,11 @@ limitations under the License. */
 
 namespace paddle {
 namespace operators {
 
-template <typename T, int BlockDimX, int BlockDimY, int GridDimX>
+template <typename T, int BlockDimX, int BlockDimY, int GridDimX,
+          bool PaddingFlag>
 __global__ void LookupTable(T* output, const T* table, const int64_t* ids,
-                            const int64_t N, const int64_t K, const int64_t D) {
+                            const int64_t N, const int64_t K, const int64_t D,
+                            const int64_t padding_idx) {
   int idx = threadIdx.x;
   int idy = blockIdx.x + threadIdx.y * GridDimX;
@@ -34,7 +36,14 @@ __global__ void LookupTable(T* output, const T* table, const int64_t* ids,
     T* out = output + idy * D;
     const T* tab = table + id * D;
     for (int i = idx; i < D; i += BlockDimX) {
-      out[i] = tab[i];
+      if (PaddingFlag) {
+        if (id == padding_idx)
+          out[i] = static_cast<T>(0);
+        else
+          out[i] = tab[i];
+      } else {
+        out[i] = tab[i];
+      }
     }
     idy += BlockDimY * GridDimX;
   }
@@ -67,6 +76,7 @@ class LookupTableCUDAKernel : public framework::OpKernel<T> {
     auto* table_t = context.Input<LoDTensor>("W");
     auto* ids_t = context.Input<LoDTensor>("Ids");
     auto* output_t = context.Output<LoDTensor>("Out");
+    int64_t padding_idx = context.Attr<int64_t>("padding_idx");
 
     size_t N = table_t->dims()[0];
     size_t D = table_t->dims()[1];
@@ -77,10 +87,17 @@ class LookupTableCUDAKernel : public framework::OpKernel<T> {
 
     dim3 threads(128, 8);
     dim3 grids(8, 1);
-    LookupTable<
-        T, 128, 8,
-        8><<<grids, threads, 0, context.cuda_device_context().stream()>>>(
-        output, table, ids, N, K, D);
+
+    if (padding_idx == -1)
+      LookupTable<
+          T, 128, 8, 8,
+          false><<<grids, threads, 0, context.cuda_device_context().stream()>>>(
+          output, table, ids, N, K, D, padding_idx);
+    else
+      LookupTable<
+          T, 128, 8, 8,
+          true><<<grids, threads, 0, context.cuda_device_context().stream()>>>(
+          output, table, ids, N, K, D, padding_idx);
   }
 };
 
@@ -91,6 +108,8 @@ class LookupTableGradCUDAKernel : public framework::OpKernel<T> {
     auto& dev_ctx =
         context.template device_context<platform::CUDADeviceContext>();
     bool is_sparse = context.Attr<bool>("is_sparse");
+    // Since paddings are not trainable and fixed in forward, the gradient of
+    // paddings makes no sense and we don't deal with it in backward.
     if (is_sparse) {
       auto* ids = context.Input<LoDTensor>("Ids");
       auto* table = context.Input<LoDTensor>("W");
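
For context, below is a minimal standalone sketch (not part of the diff) of the compile-time PaddingFlag dispatch the patch introduces: a bool template parameter lets the compiler strip the padding branch from the no-padding specialization entirely, while the runtime padding_idx value picks which instantiation to launch. The harness (main, the Lookup name, the sizes, and the id values) is illustrative only and assumes a plain CUDA toolchain; it is not the operator's actual code or test.

// Table shape N x D, K lookups; a looked-up id equal to padding_idx
// produces a row of zeros when the PaddingFlag specialization is used.
#include <cstdio>
#include <cuda_runtime.h>

template <typename T, bool PaddingFlag>
__global__ void Lookup(T* output, const T* table, const int64_t* ids,
                       int64_t K, int64_t D, int64_t padding_idx) {
  for (int64_t idy = blockIdx.x; idy < K; idy += gridDim.x) {
    const int64_t id = ids[idy];
    T* out = output + idy * D;
    const T* tab = table + id * D;
    for (int64_t i = threadIdx.x; i < D; i += blockDim.x) {
      if (PaddingFlag && id == padding_idx)
        out[i] = static_cast<T>(0);  // padding rows read back as zeros
      else
        out[i] = tab[i];
    }
  }
}

int main() {
  const int64_t N = 4, D = 3, K = 3, padding_idx = 0;
  float h_table[N * D];
  for (int i = 0; i < N * D; ++i) h_table[i] = static_cast<float>(i);
  int64_t h_ids[K] = {2, 0, 1};  // id 0 is the padding row here

  float *d_table, *d_out;
  int64_t* d_ids;
  cudaMalloc(&d_table, sizeof(h_table));
  cudaMalloc(&d_out, K * D * sizeof(float));
  cudaMalloc(&d_ids, sizeof(h_ids));
  cudaMemcpy(d_table, h_table, sizeof(h_table), cudaMemcpyHostToDevice);
  cudaMemcpy(d_ids, h_ids, sizeof(h_ids), cudaMemcpyHostToDevice);

  // A runtime value selects the compile-time specialization, as in the diff.
  if (padding_idx == -1)
    Lookup<float, false><<<8, 128>>>(d_out, d_table, d_ids, K, D, padding_idx);
  else
    Lookup<float, true><<<8, 128>>>(d_out, d_table, d_ids, K, D, padding_idx);

  float h_out[K * D];
  cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
  for (int r = 0; r < K; ++r)  // second row prints as zeros
    printf("%g %g %g\n", h_out[r * D], h_out[r * D + 1], h_out[r * D + 2]);
  cudaFree(d_table);
  cudaFree(d_out);
  cudaFree(d_ids);
  return 0;
}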