|
|
|
@ -64,8 +64,7 @@ class SequenceEraseOpCUDAKernel : public framework::OpKernel<T> {
|
|
|
|
|
auto* out = ctx.Output<LoDTensor>("Out");
|
|
|
|
|
|
|
|
|
|
auto lod = in->lod();
|
|
|
|
|
PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(lod[0].back(), (size_t)in->numel(),
|
|
|
|
|
PADDLE_ENFORCE_EQ(lod[lod.size() - 1].back(), (size_t)in->numel(),
|
|
|
|
|
"The actual size mismatches with the LoD information.");
|
|
|
|
|
auto tokens = ctx.Attr<std::vector<int>>("tokens");
|
|
|
|
|
auto in_len = in->numel();
|
|
|
|
@ -85,10 +84,9 @@ class SequenceEraseOpCUDAKernel : public framework::OpKernel<T> {
|
|
|
|
|
num_erased.begin() + 1);
|
|
|
|
|
|
|
|
|
|
// Copy LoD to GPU
|
|
|
|
|
auto lod0 = lod[0];
|
|
|
|
|
auto lod_len = lod0.size();
|
|
|
|
|
const size_t* dev_in_lod_ptr = lod0.CUDAData(ctx.GetPlace());
|
|
|
|
|
|
|
|
|
|
auto last_lod = lod[lod.size() - 1];
|
|
|
|
|
auto lod_len = last_lod.size();
|
|
|
|
|
const size_t* dev_in_lod_ptr = last_lod.CUDAData(ctx.GetPlace());
|
|
|
|
|
// Calc output LoD
|
|
|
|
|
thrust::device_vector<size_t> dev_out_lod(lod_len);
|
|
|
|
|
size_t* dev_out_lod_ptr = thrust::raw_pointer_cast(dev_out_lod.data());
|
|
|
|
@ -96,13 +94,16 @@ class SequenceEraseOpCUDAKernel : public framework::OpKernel<T> {
|
|
|
|
|
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
|
|
|
|
|
num_erased_ptr, dev_in_lod_ptr, lod_len, dev_out_lod_ptr);
|
|
|
|
|
// Set LoD for output
|
|
|
|
|
std::vector<size_t> out_lod0(dev_out_lod.begin(), dev_out_lod.end());
|
|
|
|
|
std::vector<size_t> out_last_lod(dev_out_lod.begin(), dev_out_lod.end());
|
|
|
|
|
framework::LoD out_lod;
|
|
|
|
|
out_lod.push_back(out_lod0);
|
|
|
|
|
for (size_t i = 0; i < lod.size() - 1; ++i) {
|
|
|
|
|
out_lod.push_back(lod[i]);
|
|
|
|
|
}
|
|
|
|
|
out_lod.push_back(out_last_lod);
|
|
|
|
|
out->set_lod(out_lod);
|
|
|
|
|
|
|
|
|
|
// Set output
|
|
|
|
|
out->Resize({static_cast<int64_t>(out_lod0.back()), 1});
|
|
|
|
|
out->Resize({static_cast<int64_t>(out_last_lod.back()), 1});
|
|
|
|
|
auto out_dat = out->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
SetOutput<<<(in_len - 1) / PADDLE_CUDA_NUM_THREADS + 1,
|
|
|
|
|
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(in_dat, in_len,
|
|
|
|
|