|
|
|
@ -34,28 +34,26 @@ class SequenceUnpadOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
auto* len_t = ctx.Input<LoDTensor>("Length");
|
|
|
|
|
auto* out_t = ctx.Output<LoDTensor>("Out");
|
|
|
|
|
|
|
|
|
|
const int64_t* seq_len_ptr = nullptr;
|
|
|
|
|
auto& dev_ctx = ctx.template device_context<DeviceContext>();
|
|
|
|
|
framework::Tensor seq_len_cpu =
|
|
|
|
|
ctx.AllocateTmpTensor<T, DeviceContext>(len_t->dims(), dev_ctx);
|
|
|
|
|
if (platform::is_gpu_place(ctx.GetPlace())) {
|
|
|
|
|
LoDTensor seq_len_cpu;
|
|
|
|
|
seq_len_cpu.Resize(len_t->dims());
|
|
|
|
|
seq_len_ptr = seq_len_cpu.mutable_data<int64_t>(platform::CPUPlace());
|
|
|
|
|
framework::TensorCopy(*len_t, platform::CPUPlace(),
|
|
|
|
|
ctx.template device_context<DeviceContext>(),
|
|
|
|
|
&seq_len_cpu);
|
|
|
|
|
seq_len_cpu.mutable_data<int64_t>(platform::CPUPlace());
|
|
|
|
|
framework::TensorCopySync(*len_t, platform::CPUPlace(), &seq_len_cpu);
|
|
|
|
|
} else {
|
|
|
|
|
seq_len_ptr = len_t->data<int64_t>();
|
|
|
|
|
seq_len_cpu = *len_t;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
size_t batch_size = x_t->dims()[0];
|
|
|
|
|
const int64_t* seq_len_ptr = seq_len_cpu.data<int64_t>();
|
|
|
|
|
int64_t batch_size = len_t->dims()[0];
|
|
|
|
|
std::vector<size_t> out_lod0(batch_size + 1, 0);
|
|
|
|
|
for (size_t i = 0; i < batch_size; ++i) {
|
|
|
|
|
out_lod0[i + 1] = out_lod0[i] + seq_len_ptr[i];
|
|
|
|
|
for (int64_t i = 0; i < batch_size; ++i) {
|
|
|
|
|
out_lod0[i + 1] = out_lod0[i] + static_cast<size_t>(seq_len_ptr[i]);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
framework::LoD out_lod;
|
|
|
|
|
out_lod.push_back(out_lod0);
|
|
|
|
|
out_t->set_lod(out_lod);
|
|
|
|
|
|
|
|
|
|
std::vector<int64_t> out_dims_vec{static_cast<int64_t>(out_lod0.back())};
|
|
|
|
|
if (x_t->dims().size() == 2) {
|
|
|
|
|
out_dims_vec.push_back(1);
|
|
|
|
@ -71,8 +69,7 @@ class SequenceUnpadOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
int64_t padded_length = x_t->dims()[1];
|
|
|
|
|
math::UnpaddingLoDTensorFunctor<DeviceContext, T>()(
|
|
|
|
|
ctx.template device_context<DeviceContext>(), *x_t, out_t,
|
|
|
|
|
padded_length, 0, false, math::kBatchLengthWidth);
|
|
|
|
|
dev_ctx, *x_t, out_t, padded_length, 0, false, math::kBatchLengthWidth);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|