|
|
|
@ -23,7 +23,7 @@ namespace operators {
|
|
|
|
|
using Tensor = framework::Tensor;
|
|
|
|
|
|
|
|
|
|
template <typename T, int blockDimX, int blockDimY, int gridDimX>
|
|
|
|
|
__global__ void LookupTable(T* output, const T* table, const uint32_t* ids,
|
|
|
|
|
__global__ void LookupTable(T* output, const T* table, const int32_t* ids,
|
|
|
|
|
const int N, const int K, const int D) {
|
|
|
|
|
int idx = threadIdx.x;
|
|
|
|
|
int idy = blockIdx.x + threadIdx.y * gridDimX;
|
|
|
|
@ -32,8 +32,8 @@ __global__ void LookupTable(T* output, const T* table, const uint32_t* ids,
|
|
|
|
|
int id = ids[idy];
|
|
|
|
|
PADDLE_ASSERT(id >= 0);
|
|
|
|
|
PADDLE_ASSERT(id < N);
|
|
|
|
|
T* out = output + idy;
|
|
|
|
|
const T* tab = table + id;
|
|
|
|
|
T* out = output + idy * D;
|
|
|
|
|
const T* tab = table + id * D;
|
|
|
|
|
for (int i = idx; i < D; i += blockDimX) {
|
|
|
|
|
out[i] = tab[i];
|
|
|
|
|
}
|
|
|
|
@ -42,9 +42,8 @@ __global__ void LookupTable(T* output, const T* table, const uint32_t* ids,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T, int blockDimX, int blockDimY, int gridDimX>
|
|
|
|
|
__global__ void LookupTableGradKernel(T* table, const T* output,
|
|
|
|
|
const uint32_t* ids, const int N,
|
|
|
|
|
const int K, const int D) {
|
|
|
|
|
__global__ void LookupTableGrad(T* table, const T* output, const int32_t* ids,
|
|
|
|
|
const int N, const int K, const int D) {
|
|
|
|
|
int idx = threadIdx.x;
|
|
|
|
|
int idy = blockIdx.x + threadIdx.y * gridDimX;
|
|
|
|
|
|
|
|
|
@ -52,10 +51,10 @@ __global__ void LookupTableGradKernel(T* table, const T* output,
|
|
|
|
|
int id = ids[idy];
|
|
|
|
|
PADDLE_ASSERT(id >= 0);
|
|
|
|
|
PADDLE_ASSERT(id < N);
|
|
|
|
|
const T* out = output + idy;
|
|
|
|
|
T* tab = table + id;
|
|
|
|
|
const T* out = output + idy * D;
|
|
|
|
|
T* tab = table + id * D;
|
|
|
|
|
for (int i = idx; i < D; i += blockDimX) {
|
|
|
|
|
paddle::platform::CudaAtomicAdd(tab + i, out[i]);
|
|
|
|
|
paddle::platform::CudaAtomicAdd(&tab[i], out[i]);
|
|
|
|
|
}
|
|
|
|
|
idy += blockDimY * gridDimX;
|
|
|
|
|
}
|
|
|
|
@ -72,7 +71,7 @@ class LookupTableCUDAKernel : public framework::OpKernel {
|
|
|
|
|
size_t N = table_t->dims()[0];
|
|
|
|
|
size_t D = table_t->dims()[1];
|
|
|
|
|
size_t K = product(ids_t->dims());
|
|
|
|
|
auto ids = ids_t->data<uint32_t>();
|
|
|
|
|
auto ids = ids_t->data<int32_t>();
|
|
|
|
|
auto table = table_t->data<T>();
|
|
|
|
|
auto output = output_t->mutable_data<T>(context.GetPlace());
|
|
|
|
|
|
|
|
|
@ -83,7 +82,7 @@ class LookupTableCUDAKernel : public framework::OpKernel {
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
class LookupTableGrad : public framework::OpKernel {
|
|
|
|
|
class LookupTableGradCUDAKernel : public framework::OpKernel {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
|
auto ids_t = context.Input<Tensor>("Ids");
|
|
|
|
@ -93,9 +92,9 @@ class LookupTableGrad : public framework::OpKernel {
|
|
|
|
|
int N = d_table_t->dims()[0];
|
|
|
|
|
int D = d_table_t->dims()[1];
|
|
|
|
|
int K = product(ids_t->dims());
|
|
|
|
|
const uint32_t* ids = ids_t->data<uint32_t>();
|
|
|
|
|
T* d_table = d_table_t->mutable_data<T>(context.GetPlace());
|
|
|
|
|
const int32_t* ids = ids_t->data<int32_t>();
|
|
|
|
|
const T* d_output = d_output_t->data<T>();
|
|
|
|
|
T* d_table = d_table_t->mutable_data<T>(context.GetPlace());
|
|
|
|
|
|
|
|
|
|
auto* device_context =
|
|
|
|
|
const_cast<platform::DeviceContext*>(context.device_context_);
|
|
|
|
@ -103,8 +102,8 @@ class LookupTableGrad : public framework::OpKernel {
|
|
|
|
|
device_context);
|
|
|
|
|
dim3 threads(128, 8);
|
|
|
|
|
dim3 grids(8, 1);
|
|
|
|
|
LookupTableGradKernel<T, 128, 8, 8><<<grids, threads>>>(d_table, d_output,
|
|
|
|
|
ids, N, K, D);
|
|
|
|
|
LookupTableGrad<T, 128, 8, 8><<<grids, threads>>>(d_table, d_output, ids, N,
|
|
|
|
|
K, D);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -113,4 +112,5 @@ class LookupTableGrad : public framework::OpKernel {
|
|
|
|
|
|
|
|
|
|
namespace ops = paddle::operators;
|
|
|
|
|
REGISTER_OP_GPU_KERNEL(lookup_table, ops::LookupTableCUDAKernel<float>);
|
|
|
|
|
REGISTER_OP_GPU_KERNEL(lookup_table_grad, ops::LookupTableGrad<float>);
|
|
|
|
|
REGISTER_OP_GPU_KERNEL(lookup_table_grad,
|
|
|
|
|
ops::LookupTableGradCUDAKernel<float>);
|
|
|
|
|