|
|
|
@ -19,46 +19,32 @@ namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
|
namespace math {
|
|
|
|
|
|
|
|
|
|
template <typename T, bool Padding>
|
|
|
|
|
template <typename T, CopyType Type>
|
|
|
|
|
__global__ void SequencePaddingKernel(
|
|
|
|
|
T* pad_data, T* seq_data, const size_t* seq_offset, const size_t& seq_num,
|
|
|
|
|
const size_t& max_seq_len, const size_t& seq_width, bool norm_by_times,
|
|
|
|
|
const T& pad_value, const OutputLayout& output_layout) {
|
|
|
|
|
T* dst, const T* src, const T* pad_value, bool is_constant_pad,
|
|
|
|
|
const size_t* seq_offsets, const size_t& seq_num, const size_t& pad_seq_len,
|
|
|
|
|
const size_t& step_width, bool norm_by_len, const PadLayout& layout) {
|
|
|
|
|
size_t seq_idx = blockIdx.y;
|
|
|
|
|
size_t seq_start = seq_offset[seq_idx];
|
|
|
|
|
size_t seq_len = seq_offset[seq_idx + 1] - seq_start;
|
|
|
|
|
|
|
|
|
|
size_t seq_step_idx = blockIdx.x * blockDim.y + threadIdx.y;
|
|
|
|
|
|
|
|
|
|
size_t seq_data_offset = (seq_start + seq_step_idx) * seq_width;
|
|
|
|
|
|
|
|
|
|
size_t pad_data_offset = 0;
|
|
|
|
|
|
|
|
|
|
if (output_layout == kLengthBatchWidth) {
|
|
|
|
|
pad_data_offset = (seq_step_idx * seq_num + seq_idx) * seq_width;
|
|
|
|
|
} else {
|
|
|
|
|
pad_data_offset = (seq_idx * max_seq_len + seq_step_idx) * seq_width;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (seq_step_idx < seq_len) {
|
|
|
|
|
T scale = norm_by_times ? (1.0f / static_cast<T>(seq_len)) : 1.0f;
|
|
|
|
|
if (Padding) {
|
|
|
|
|
/* seq -> pad */
|
|
|
|
|
for (size_t i = threadIdx.x; i < seq_width; i += blockDim.x) {
|
|
|
|
|
pad_data[pad_data_offset + i] = scale * seq_data[seq_data_offset + i];
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
/* pad -> seq */
|
|
|
|
|
for (size_t i = threadIdx.x; i < seq_width; i += blockDim.x) {
|
|
|
|
|
seq_data[seq_data_offset + i] = scale * pad_data[pad_data_offset + i];
|
|
|
|
|
}
|
|
|
|
|
size_t seq_len = seq_offsets[seq_idx + 1] - seq_offsets[seq_idx];
|
|
|
|
|
|
|
|
|
|
size_t step_idx = blockIdx.x * blockDim.y + threadIdx.y;
|
|
|
|
|
size_t seq_data_offset = (seq_offsets[seq_idx] + step_idx) * step_width;
|
|
|
|
|
size_t pad_data_offset = layout == kBatchLengthWidth
|
|
|
|
|
? (seq_idx * pad_seq_len + step_idx) * step_width
|
|
|
|
|
: (step_idx * seq_num + seq_idx) * step_width;
|
|
|
|
|
|
|
|
|
|
T* dst_data = dst + (Type == kSeqToPad ? pad_data_offset : seq_data_offset);
|
|
|
|
|
const T* src_data =
|
|
|
|
|
src + (Type == kSeqToPad ? seq_data_offset : pad_data_offset);
|
|
|
|
|
|
|
|
|
|
if (step_idx < seq_len) {
|
|
|
|
|
float scale = norm_by_len ? (1.0f / static_cast<float>(seq_len)) : 1.0f;
|
|
|
|
|
for (size_t i = threadIdx.x; i < step_width; i += blockDim.x) {
|
|
|
|
|
dst_data[i] = scale * src_data[i];
|
|
|
|
|
}
|
|
|
|
|
} else if (seq_step_idx < max_seq_len) {
|
|
|
|
|
if (Padding) {
|
|
|
|
|
/* seq -> pad */
|
|
|
|
|
for (size_t i = threadIdx.x; i < seq_width; i += blockDim.x) {
|
|
|
|
|
pad_data[pad_data_offset + i] = pad_value;
|
|
|
|
|
}
|
|
|
|
|
} else if (step_idx < pad_seq_len && Type == kSeqToPad) {
|
|
|
|
|
for (size_t i = threadIdx.x; i < seq_width; i += blockDim.x) {
|
|
|
|
|
dst_data[i] = is_constant_pad ? pad_value[0] : pad_value[i];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -69,24 +55,26 @@ class PaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
|
|
|
|
|
void operator()(const platform::CUDADeviceContext& context,
|
|
|
|
|
const framework::LoDTensor& seq_tensor,
|
|
|
|
|
framework::Tensor* pad_tensor,
|
|
|
|
|
T pad_value = static_cast<T>(0), bool norm_by_times = false,
|
|
|
|
|
size_t lod_level = 0,
|
|
|
|
|
OutputLayout output_layout = kBatchLengthWidth) {
|
|
|
|
|
CheckLoD(seq_tensor, lod_level);
|
|
|
|
|
|
|
|
|
|
auto& lod = seq_tensor.lod();
|
|
|
|
|
auto& seq_offset = framework::ToAbsOffset(lod)[lod_level];
|
|
|
|
|
|
|
|
|
|
auto seq_tensor_dims = seq_tensor.dims();
|
|
|
|
|
auto pad_tensor_dims = pad_tensor->dims();
|
|
|
|
|
int64_t max_seq_len = MaximumSequenceLength(seq_offset);
|
|
|
|
|
int64_t seq_num = seq_offset.size() - 1;
|
|
|
|
|
int64_t seq_width = seq_tensor.numel() / seq_tensor_dims[0];
|
|
|
|
|
const framework::LoDTensor& pad_value, int pad_seq_len = -1,
|
|
|
|
|
int lod_level = 0, bool norm_by_times = false,
|
|
|
|
|
const PadLayout layout = kBatchLengthWidth) {
|
|
|
|
|
auto seq_lod = seq_tensor.lod();
|
|
|
|
|
const auto seq_offsets = framework::ToAbsOffset(seq_lod)[lod_level];
|
|
|
|
|
const auto& seq_tensor_dims = seq_tensor.dims();
|
|
|
|
|
const auto& pad_tensor_dims = pad_tensor->dims();
|
|
|
|
|
if (pad_seq_len == -1) {
|
|
|
|
|
pad_seq_len = MaximumSequenceLength(seq_offsets);
|
|
|
|
|
}
|
|
|
|
|
int step_width = seq_tensor.numel() / seq_tensor_dims[0];
|
|
|
|
|
int seq_num = seq_offset.size() - 1;
|
|
|
|
|
|
|
|
|
|
CheckDims(seq_tensor_dims, seq_offset.back(), pad_tensor_dims, max_seq_len,
|
|
|
|
|
seq_num, seq_width, output_layout);
|
|
|
|
|
CheckDims(seq_tensor_dims, pad_tensor_dims, seq_offsets, pad_seq_len,
|
|
|
|
|
step_width, layout);
|
|
|
|
|
PADDLE_ENFORCE(pad_value.numel() == 1 || pad_value.numel() == step_width,
|
|
|
|
|
"The numel of 'pad_value' can only be 1 or be equal to the "
|
|
|
|
|
"'step_width'.");
|
|
|
|
|
|
|
|
|
|
if (!norm_by_times && seq_num == 1UL) {
|
|
|
|
|
if (!norm_by_times && seq_num == 1UL && pad_seq_len == -1) {
|
|
|
|
|
TensorCopy(seq_tensor, context.GetPlace(), context, pad_tensor);
|
|
|
|
|
pad_tensor->Resize(pad_tensor_dims);
|
|
|
|
|
return;
|
|
|
|
@ -98,21 +86,22 @@ class PaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
|
|
|
|
|
* and at least 8 elements for each thread.
|
|
|
|
|
*/
|
|
|
|
|
size_t block_dim_x =
|
|
|
|
|
std::min(((((seq_width + 7) >> 3) + 31) >> 5) << 5, kBlockSize);
|
|
|
|
|
std::min(((((step_width + 7) >> 3) + 31) >> 5) << 5, kBlockSize);
|
|
|
|
|
size_t block_dim_y = kBlockSize / block_dim_x;
|
|
|
|
|
dim3 threads(block_dim_x, block_dim_y);
|
|
|
|
|
|
|
|
|
|
size_t grid_dim_x = (max_seq_len + block_dim_y - 1) / block_dim_y;
|
|
|
|
|
size_t grid_dim_x = (pad_seq_len + block_dim_y - 1) / block_dim_y;
|
|
|
|
|
size_t grid_dim_y = seq_num;
|
|
|
|
|
dim3 grid(grid_dim_x, grid_dim_y);
|
|
|
|
|
|
|
|
|
|
const T* seq_data = seq_tensor.data<T>();
|
|
|
|
|
T* pad_data = pad_tensor->data<T>();
|
|
|
|
|
const T* pad_value_data = pad_value.data<T>();
|
|
|
|
|
|
|
|
|
|
SequencePaddingKernel<T, 1><<<grid, threads, 0, context.stream()>>>(
|
|
|
|
|
pad_data, const_cast<T*>(seq_data),
|
|
|
|
|
seq_offset.CUDAData(context.GetPlace()), seq_num, max_seq_len,
|
|
|
|
|
seq_width, norm_by_times, pad_value, output_layout);
|
|
|
|
|
SequencePaddingKernel<T, kSeqToPad><<<grid, threads, 0, context.stream()>>>(
|
|
|
|
|
pad_data, seq_data, pad_value_data, pad_value.numel() == 1,
|
|
|
|
|
seq_offset.CUDAData(context.GetPlace()), seq_num, pad_seq_len,
|
|
|
|
|
step_width, norm_by_times, layout);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -120,25 +109,23 @@ template <typename T>
|
|
|
|
|
class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
|
|
|
|
|
public:
|
|
|
|
|
void operator()(const platform::CUDADeviceContext& context,
|
|
|
|
|
framework::LoDTensor* seq_tensor,
|
|
|
|
|
const framework::Tensor& pad_tensor,
|
|
|
|
|
bool norm_by_times = false, size_t lod_level = 0,
|
|
|
|
|
OutputLayout output_layout = kBatchLengthWidth) {
|
|
|
|
|
CheckLoD(*seq_tensor, lod_level);
|
|
|
|
|
|
|
|
|
|
auto& lod = seq_tensor->lod();
|
|
|
|
|
auto& seq_offset = framework::ToAbsOffset(lod)[lod_level];
|
|
|
|
|
|
|
|
|
|
auto seq_tensor_dims = seq_tensor->dims();
|
|
|
|
|
auto pad_tensor_dims = pad_tensor.dims();
|
|
|
|
|
int64_t max_seq_len = MaximumSequenceLength(seq_offset);
|
|
|
|
|
int64_t seq_num = seq_offset.size() - 1;
|
|
|
|
|
int64_t seq_width = seq_tensor->numel() / seq_tensor_dims[0];
|
|
|
|
|
const framework::LoDTensor& pad_tensor,
|
|
|
|
|
framework::LoDTensor* seq_tensor, int pad_seq_len = -1,
|
|
|
|
|
int lod_level = 0, bool norm_by_times = false,
|
|
|
|
|
const PadLayout layout = kBatchLengthWidth) {
|
|
|
|
|
auto seq_offsets = framework::ToAbsOffset(seq_tensor->lod())[lod_level];
|
|
|
|
|
const auto& seq_tensor_dims = seq_tensor->dims();
|
|
|
|
|
const auto& pad_tensor_dims = pad_tensor.dims();
|
|
|
|
|
if (pad_seq_len == -1) {
|
|
|
|
|
pad_seq_len = MaximumSequenceLength(seq_offsets);
|
|
|
|
|
}
|
|
|
|
|
int step_width = seq_tensor->numel() / seq_tensor_dims[0];
|
|
|
|
|
int seq_num = seq_offset.size() - 1;
|
|
|
|
|
|
|
|
|
|
CheckDims(seq_tensor_dims, seq_offset.back(), pad_tensor_dims, max_seq_len,
|
|
|
|
|
seq_num, seq_width, output_layout);
|
|
|
|
|
CheckDims(seq_tensor_dims, pad_tensor_dims, seq_offsets, pad_seq_len,
|
|
|
|
|
step_width, layout);
|
|
|
|
|
|
|
|
|
|
if (!norm_by_times && seq_num == 1UL) {
|
|
|
|
|
if (!norm_by_times && seq_num == 1UL && pad_seq_len == -1) {
|
|
|
|
|
TensorCopy(pad_tensor, context.GetPlace(), context, seq_tensor);
|
|
|
|
|
seq_tensor->Resize(seq_tensor_dims);
|
|
|
|
|
return;
|
|
|
|
@ -150,21 +137,21 @@ class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
|
|
|
|
|
* and at least 8 elements for each thread.
|
|
|
|
|
*/
|
|
|
|
|
size_t block_dim_x =
|
|
|
|
|
std::min(((((seq_width + 7) >> 3) + 31) >> 5) << 5, kBlockSize);
|
|
|
|
|
std::min(((((step_width + 7) >> 3) + 31) >> 5) << 5, kBlockSize);
|
|
|
|
|
size_t block_dim_y = kBlockSize / block_dim_x;
|
|
|
|
|
dim3 threads(block_dim_x, block_dim_y);
|
|
|
|
|
|
|
|
|
|
size_t grid_dim_x = (max_seq_len + block_dim_y - 1) / block_dim_y;
|
|
|
|
|
size_t grid_dim_x = (pad_seq_len + block_dim_y - 1) / block_dim_y;
|
|
|
|
|
size_t grid_dim_y = seq_num;
|
|
|
|
|
dim3 grid(grid_dim_x, grid_dim_y);
|
|
|
|
|
|
|
|
|
|
const T* pad_data = pad_tensor.data<T>();
|
|
|
|
|
T* seq_data = seq_tensor->data<T>();
|
|
|
|
|
|
|
|
|
|
SequencePaddingKernel<T, 0><<<grid, threads, 0, context.stream()>>>(
|
|
|
|
|
const_cast<T*>(pad_data), seq_data,
|
|
|
|
|
seq_offset.CUDAData(context.GetPlace()), seq_num, max_seq_len,
|
|
|
|
|
seq_width, norm_by_times, static_cast<T>(0), output_layout);
|
|
|
|
|
SequencePaddingKernel<T, kPadToSeq><<<grid, threads, 0, context.stream()>>>(
|
|
|
|
|
seq_data, pad_data, nullptr, false,
|
|
|
|
|
seq_offset.CUDAData(context.GetPlace()), seq_num, pad_seq_len,
|
|
|
|
|
step_width, norm_by_times, layout);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|