@ -1,11 +1,8 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@ -14,22 +11,21 @@
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/lookup_table_op.h"
#include "paddle/platform/assert.h"
#include "paddle/platform/cuda_helper.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T, int BlockDimX, int BlockDimY, int GridDimX>
__global__ void LookupTable(T* output, const T* table, const int32 _t* ids,
const int N, const int K, const int D) {
__global__ void LookupTable(T* output, const T* table, const int64_t* ids,
const int64_t N, const int64_t K, const int64_ t D) {
int idx = threadIdx.x;
int idy = blockIdx.x + threadIdx.y * GridDimX;
while (idy < K) {
int id = ids[idy];
int64_t id = ids[idy];
PADDLE_ASSERT(id >= 0);
PADDLE_ASSERT(id < N);
T* out = output + idy * D;
@ -42,8 +38,9 @@ __global__ void LookupTable(T* output, const T* table, const int32_t* ids,
}
template <typename T, int BlockDimX, int BlockDimY, int GridDimX>
__global__ void LookupTableGrad(T* table, const T* output, const int32_t* ids,
const int N, const int K, const int D) {
__global__ void LookupTableGrad(T* table, const T* output, const int64_t* ids,
const int64_t N, const int64_t K,
const int64_t D) {
int idx = threadIdx.x;
int idy = blockIdx.x + threadIdx.y * GridDimX;
@ -71,7 +68,7 @@ class LookupTableCUDAKernel : public framework::OpKernel<T> {
size_t N = table_t->dims()[0];
size_t D = table_t->dims()[1];
size_t K = ids_t->numel();
auto ids = ids_t->data<int32 _t>();
auto ids = ids_t->data<int64 _t>();
auto table = table_t->data<T>();
auto output = output_t->mutable_data<T>(context.GetPlace());
@ -88,27 +85,63 @@ template <typename T>
class LookupTableGradCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto ids_t = context.Input<Tensor>("Ids");
auto d_output_t = context.Input<Tensor>(framework::GradVarName("Out"));
auto d_table_t = context.Output<Tensor>(framework::GradVarName("W"));
int N = d_table_t->dims()[0];
int D = d_table_t->dims()[1];
int K = ids_t->numel();
const int32_t* ids = ids_t->data<int32_t>();
const T* d_output = d_output_t->data<T>();
T* d_table = d_table_t->mutable_data<T>(context.GetPlace());
auto t = framework::EigenVector<T>::Flatten(*d_table_t);
t.device(context.GetEigenDevice<platform::GPUPlace>()) =
t.constant(static_cast<T>(0));
dim3 threads(128, 8);
dim3 grids(8, 1);
LookupTableGrad<T, 128, 8, 8><<<
grids, threads, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
bool is_sparse = context.Attr<bool>("is_sparse");
if (is_sparse) {
auto* ids = context.Input<Tensor>("Ids");
auto* table = context.Input<Tensor>("W");
auto* d_output = context.Input<Tensor>(framework::GradVarName("Out"));
auto* d_table = context.Output<SelectedRows>(framework::GradVarName("W"));
auto* ids_data = ids->data<int64_t>();
auto ids_dim = ids->dims();
auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
context.device_context())
.stream();
// copy GPU memory to CPU pinned memory
framework::Vector<int64_t> new_rows;
new_rows.resize(ids_dim[0]);
auto gpu_place = boost::get<platform::GPUPlace>(context.GetPlace());
memory::Copy(platform::CPUPlace(), new_rows.data(), gpu_place, ids_data,
ids_dim[0] * sizeof(int64_t), stream);
d_table->set_rows(new_rows);
auto* d_table_value = d_table->mutable_value();
d_table_value->Resize({ids_dim[0], table->dims()[1]});
d_table_value->mutable_data<T>(context.GetPlace());
auto* d_table_data = d_table_value->data<T>();
auto* d_output_data = d_output->data<T>();
PADDLE_ENFORCE_EQ(d_table_value->dims(), d_output->dims());
memory::Copy(gpu_place, d_table_data, gpu_place, d_output_data,
d_output->numel(), stream);
} else {
auto ids_t = context.Input<Tensor>("Ids");
auto d_output_t = context.Input<Tensor>(framework::GradVarName("Out"));
auto d_table_t = context.Output<Tensor>(framework::GradVarName("W"));
int N = d_table_t->dims()[0];
int D = d_table_t->dims()[1];
int K = ids_t->numel();
const int64_t* ids = ids_t->data<int64_t>();
const T* d_output = d_output_t->data<T>();
T* d_table = d_table_t->mutable_data<T>(context.GetPlace());
auto t = framework::EigenVector<T>::Flatten(*d_table_t);
t.device(context.GetEigenDevice<platform::GPUPlace>()) =
t.constant(static_cast<T>(0));
dim3 threads(128, 8);
dim3 grids(8, 1);
LookupTableGrad<T, 128, 8,
8><<<grids, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(
context.device_context())
.stream()>>>(d_table, d_output, ids, N, K, D);
}
}
};
@ -116,6 +149,7 @@ class LookupTableGradCUDAKernel : public framework::OpKernel<T> {
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(lookup_table, ops::LookupTableCUDAKernel<float>);
REGISTER_OP_GPU_KERNEL(lookup_table_grad,
ops::LookupTableGradCUDAKernel<float>);
REGISTER_OP_GPU_KERNEL(lookup_table, ops::LookupTableCUDAKernel<float>,
ops::LookupTableCUDAKernel<double>);
REGISTER_OP_GPU_KERNEL(lookup_table_grad, ops::LookupTableGradCUDAKernel<float>,
ops::LookupTableGradCUDAKernel<double>);