|
|
|
@ -28,16 +28,12 @@ __global__ void LabelErasedIdx(const T* in_dat, const int64_t in_len,
|
|
|
|
|
size_t* num_erased) {
|
|
|
|
|
int index = blockIdx.x * blockDim.x + threadIdx.x;
|
|
|
|
|
if (index < in_len) {
|
|
|
|
|
int erased = 0;
|
|
|
|
|
for (size_t i = 0; i < tokens_len; ++i) {
|
|
|
|
|
if (in_dat[index] == tokens[i]) {
|
|
|
|
|
erased = 1;
|
|
|
|
|
num_erased[index + 1] = 1;
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
num_erased[index + 1] = erased;
|
|
|
|
|
if (index == 0) {
|
|
|
|
|
num_erased[0] = 0;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -60,26 +56,6 @@ __global__ void SetOutput(const T* in_dat, const int64_t in_len,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T, typename Vector>
|
|
|
|
|
thrust::device_vector<T> set_device_vector(Vector& vector) {
|
|
|
|
|
thrust::host_vector<T> host_vec(vector.size());
|
|
|
|
|
for (size_t i = 0; i < vector.size(); ++i) {
|
|
|
|
|
host_vec[i] = vector[i];
|
|
|
|
|
}
|
|
|
|
|
thrust::device_vector<T> dev_vec = host_vec;
|
|
|
|
|
return dev_vec;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
std::vector<T> get_std_vector(thrust::device_vector<T>& dev_vec) {
|
|
|
|
|
thrust::host_vector<T> host_vec = dev_vec;
|
|
|
|
|
std::vector<T> std_vec(host_vec.size(), 0);
|
|
|
|
|
for (size_t i = 0; i < host_vec.size(); ++i) {
|
|
|
|
|
std_vec[i] = host_vec[i];
|
|
|
|
|
}
|
|
|
|
|
return std_vec;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
class SequenceEraseOpCUDAKernel : public framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
@ -95,12 +71,11 @@ class SequenceEraseOpCUDAKernel : public framework::OpKernel<T> {
|
|
|
|
|
auto in_len = in->numel();
|
|
|
|
|
auto in_dat = in->data<T>();
|
|
|
|
|
// Copy tokens to GPU
|
|
|
|
|
thrust::device_vector<int> dev_tokens =
|
|
|
|
|
set_device_vector<int, std::vector<int>>(tokens);
|
|
|
|
|
thrust::device_vector<int> dev_tokens(tokens.begin(), tokens.end());
|
|
|
|
|
int* dev_tokens_ptr = thrust::raw_pointer_cast(dev_tokens.data());
|
|
|
|
|
|
|
|
|
|
// Count number of elements to be erased
|
|
|
|
|
thrust::device_vector<size_t> num_erased(in_len + 1);
|
|
|
|
|
thrust::device_vector<size_t> num_erased(in_len + 1, 0);
|
|
|
|
|
size_t* num_erased_ptr = thrust::raw_pointer_cast(num_erased.data());
|
|
|
|
|
auto stream = ctx.cuda_device_context().stream();
|
|
|
|
|
LabelErasedIdx<<<(in_len - 1) / PADDLE_CUDA_NUM_THREADS + 1,
|
|
|
|
@ -112,8 +87,7 @@ class SequenceEraseOpCUDAKernel : public framework::OpKernel<T> {
|
|
|
|
|
// Copy LoD to GPU
|
|
|
|
|
auto lod0 = lod[0];
|
|
|
|
|
auto lod_len = lod0.size();
|
|
|
|
|
thrust::device_vector<size_t> dev_in_lod =
|
|
|
|
|
set_device_vector<size_t, paddle::framework::Vector<size_t>>(lod0);
|
|
|
|
|
thrust::device_vector<size_t> dev_in_lod = lod0;
|
|
|
|
|
size_t* dev_in_lod_ptr = thrust::raw_pointer_cast(dev_in_lod.data());
|
|
|
|
|
|
|
|
|
|
// Calc output LoD
|
|
|
|
@ -124,7 +98,7 @@ class SequenceEraseOpCUDAKernel : public framework::OpKernel<T> {
|
|
|
|
|
num_erased_ptr, dev_in_lod_ptr, lod_len, dev_out_lod_ptr);
|
|
|
|
|
|
|
|
|
|
// Set LoD for output
|
|
|
|
|
std::vector<size_t> out_lod0 = get_std_vector<size_t>(dev_out_lod);
|
|
|
|
|
thrust::host_vector<size_t> out_lod0 = dev_out_lod;
|
|
|
|
|
framework::LoD out_lod;
|
|
|
|
|
out_lod.push_back(out_lod0);
|
|
|
|
|
out->set_lod(out_lod);
|
|
|
|
@ -142,4 +116,5 @@ class SequenceEraseOpCUDAKernel : public framework::OpKernel<T> {
|
|
|
|
|
} // namespace paddle
|
|
|
|
|
|
|
|
|
|
REGISTER_OP_CUDA_KERNEL(sequence_erase,
|
|
|
|
|
paddle::operators::SequenceEraseOpCUDAKernel<int32_t>);
|
|
|
|
|
paddle::operators::SequenceEraseOpCUDAKernel<int32_t>,
|
|
|
|
|
paddle::operators::SequenceEraseOpCUDAKernel<int64_t>);
|
|
|
|
|