@@ -21,74 +21,74 @@ namespace math {
 
 template <typename T, bool Padding>
 __global__ void SequencePaddingKernel(
-    T* padding_data, T* seq_data, const size_t* abs_offset,
-    const size_t& seq_num, const size_t& max_seq_len, const size_t& seq_width,
-    const PaddingLayout& padding_layout, bool norm_by_times = false,
-    const T& padding_value = 0) {
-  size_t padding_idx = blockIdx.y;
-  size_t seq_start = abs_offset[padding_idx];
-  size_t seq_len = abs_offset[padding_idx + 1] - seq_start;
+    T* pad_data, T* seq_data, const size_t* seq_offset, const size_t& seq_num,
+    const size_t& max_seq_len, const size_t& seq_width, bool norm_by_times,
+    const T& pad_value, const OutputLayout& output_layout) {
+  size_t seq_idx = blockIdx.y;
+  size_t seq_start = seq_offset[seq_idx];
+  size_t seq_len = seq_offset[seq_idx + 1] - seq_start;
 
-  size_t seq_idx = blockIdx.x * blockDim.y + threadIdx.y;
+  size_t seq_step_idx = blockIdx.x * blockDim.y + threadIdx.y;
 
-  size_t seq_offset = (seq_start + seq_idx) * seq_width;
+  size_t seq_data_offset = (seq_start + seq_step_idx) * seq_width;
 
-  size_t padding_offset = 0;
+  size_t pad_data_offset = 0;
 
-  if (padding_layout == LENGTH_BATCH_WIDTH) {
-    padding_offset = (seq_idx * seq_num + padding_idx) * seq_width;
+  if (output_layout == kLengthBatchWidth) {
+    pad_data_offset = (seq_step_idx * seq_num + seq_idx) * seq_width;
   } else {
-    padding_offset = (padding_idx * max_seq_len + seq_idx) * seq_width;
+    pad_data_offset = (seq_idx * max_seq_len + seq_step_idx) * seq_width;
   }
 
-  if (seq_idx < seq_len) {
+  if (seq_step_idx < seq_len) {
     T scale = norm_by_times ? (1.0f / static_cast<T>(seq_len)) : 1.0f;
     if (Padding) {
-      /* sequence -> padding */
+      /* seq -> pad */
       for (size_t i = threadIdx.x; i < seq_width; i += blockDim.x) {
-        padding_data[padding_offset + i] = scale * seq_data[seq_offset + i];
+        pad_data[pad_data_offset + i] = scale * seq_data[seq_data_offset + i];
      }
    } else {
-      /* padding -> sequence */
+      /* pad -> seq */
      for (size_t i = threadIdx.x; i < seq_width; i += blockDim.x) {
-        seq_data[seq_offset + i] = scale * padding_data[padding_offset + i];
+        seq_data[seq_data_offset + i] = scale * pad_data[pad_data_offset + i];
      }
    }
-  } else if (seq_idx < max_seq_len) {
+  } else if (seq_step_idx < max_seq_len) {
    if (Padding) {
-      /* sequence -> padding */
+      /* seq -> pad */
      for (size_t i = threadIdx.x; i < seq_width; i += blockDim.x) {
-        padding_data[padding_offset + i] = padding_value;
+        pad_data[pad_data_offset + i] = pad_value;
      }
    }
  }
 }
 
-template <typename T, PaddingLayout padding_layout>
-class PaddingLoDTensorFunctor<platform::CUDADeviceContext, T, padding_layout> {
+template <typename T>
+class PaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
  public:
   void operator()(const platform::CUDADeviceContext& context,
                   const framework::LoDTensor& seq_tensor,
-                  framework::Tensor* padding_tensor,
-                  T padding_value = static_cast<T>(0),
-                  bool norm_by_times = false, size_t lod_level = 0) {
-    ValidateLoD(seq_tensor, lod_level);
+                  framework::Tensor* pad_tensor,
+                  T pad_value = static_cast<T>(0), bool norm_by_times = false,
+                  size_t lod_level = 0,
+                  OutputLayout output_layout = kBatchLengthWidth) {
+    CheckLoD(seq_tensor, lod_level);
 
     auto& lod = seq_tensor.lod();
-    auto& abs_offset = framework::ToAbsOffset(lod)[lod_level];
+    auto& seq_offset = framework::ToAbsOffset(lod)[lod_level];
 
-    auto seq_dims = seq_tensor.dims();
-    auto padding_dims = padding_tensor->dims();
-    int64_t max_seq_len = MaximumSequenceLength(lod, lod_level);
-    const int64_t seq_num = abs_offset.size() - 1;
-    const int64_t seq_width = seq_tensor.numel() / seq_dims[0];
+    auto seq_tensor_dims = seq_tensor.dims();
+    auto pad_tensor_dims = pad_tensor->dims();
+    int64_t max_seq_len = MaximumSequenceLength(seq_offset);
+    int64_t seq_num = seq_offset.size() - 1;
+    int64_t seq_width = seq_tensor.numel() / seq_tensor_dims[0];
 
-    ValidateShape(seq_dims, abs_offset.back(), padding_dims, max_seq_len,
-                  seq_num, seq_width, padding_layout);
+    CheckDims(seq_tensor_dims, seq_offset.back(), pad_tensor_dims, max_seq_len,
+              seq_num, seq_width, output_layout);
 
     if (!norm_by_times && seq_num == 1UL) {
-      TensorCopy(seq_tensor, context.GetPlace(), context, padding_tensor);
-      padding_tensor->Resize(padding_dims);
+      TensorCopy(seq_tensor, context.GetPlace(), context, pad_tensor);
+      pad_tensor->Resize(pad_tensor_dims);
       return;
     }
@@ -107,37 +107,40 @@ class PaddingLoDTensorFunctor<platform::CUDADeviceContext, T, padding_layout> {
     dim3 grid(grid_dim_x, grid_dim_y);
 
     const T* seq_data = seq_tensor.data<T>();
-    T* padding_data = padding_tensor->data<T>();
+    T* pad_data = pad_tensor->data<T>();
 
     SequencePaddingKernel<T, 1><<<grid, threads, 0, context.stream()>>>(
-        padding_data, const_cast<T*>(seq_data),
-        abs_offset.CUDAData(context.GetPlace()), seq_num, max_seq_len,
-        seq_width, padding_layout, norm_by_times, padding_value);
+        pad_data, const_cast<T*>(seq_data),
+        seq_offset.CUDAData(context.GetPlace()), seq_num, max_seq_len,
+        seq_width, norm_by_times, pad_value, output_layout);
   }
 };
 
-template <typename T, PaddingLayout padding_layout>
-class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, T,
-                                padding_layout> {
+template <typename T>
+class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
  public:
   void operator()(const platform::CUDADeviceContext& context,
                   framework::LoDTensor* seq_tensor,
-                  const framework::Tensor& padding_tensor,
-                  bool norm_by_times = false, size_t lod_level = 0) {
-    ValidateLoD(*seq_tensor, lod_level);
+                  const framework::Tensor& pad_tensor,
+                  bool norm_by_times = false, size_t lod_level = 0,
+                  OutputLayout output_layout = kBatchLengthWidth) {
+    CheckLoD(*seq_tensor, lod_level);
 
     auto& lod = seq_tensor->lod();
-    auto& abs_offset = framework::ToAbsOffset(lod)[lod_level];
+    auto& seq_offset = framework::ToAbsOffset(lod)[lod_level];
 
-    auto seq_dims = seq_tensor->dims();
-    auto padding_dims = padding_tensor.dims();
-    int64_t max_seq_len = MaximumSequenceLength(lod, lod_level);
-    int64_t seq_num = abs_offset.size() - 1;
-    int64_t seq_width = seq_tensor->numel() / seq_dims[0];
+    auto seq_tensor_dims = seq_tensor->dims();
+    auto pad_tensor_dims = pad_tensor.dims();
+    int64_t max_seq_len = MaximumSequenceLength(seq_offset);
+    int64_t seq_num = seq_offset.size() - 1;
+    int64_t seq_width = seq_tensor->numel() / seq_tensor_dims[0];
+
+    CheckDims(seq_tensor_dims, seq_offset.back(), pad_tensor_dims, max_seq_len,
+              seq_num, seq_width, output_layout);
 
     if (!norm_by_times && seq_num == 1UL) {
-      TensorCopy(padding_tensor, context.GetPlace(), context, seq_tensor);
-      seq_tensor->Resize(seq_dims);
+      TensorCopy(pad_tensor, context.GetPlace(), context, seq_tensor);
+      seq_tensor->Resize(seq_tensor_dims);
       return;
     }
@@ -155,20 +158,25 @@ class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, T,
     size_t grid_dim_y = seq_num;
     dim3 grid(grid_dim_x, grid_dim_y);
 
-    const T* padding_data = padding_tensor.data<T>();
+    const T* pad_data = pad_tensor.data<T>();
     T* seq_data = seq_tensor->data<T>();
 
-    SequencePaddingKernel<T, 1><<<grid, threads, 0, context.stream()>>>(
-        const_cast<T*>(padding_data), seq_data,
-        abs_offset.CUDAData(context.GetPlace()), seq_num, max_seq_len,
-        seq_width, padding_layout, norm_by_times);
+    SequencePaddingKernel<T, 0><<<grid, threads, 0, context.stream()>>>(
+        const_cast<T*>(pad_data), seq_data,
+        seq_offset.CUDAData(context.GetPlace()), seq_num, max_seq_len,
+        seq_width, norm_by_times, static_cast<T>(0), output_layout);
   }
 };
 
-template class PaddingLoDTensorFunctor<platform::CUDADeviceContext, float,
-                                       LENGTH_BATCH_WIDTH>;
-template class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, float,
-                                         LENGTH_BATCH_WIDTH>;
+template class PaddingLoDTensorFunctor<platform::CUDADeviceContext, int>;
+template class PaddingLoDTensorFunctor<platform::CUDADeviceContext, int64_t>;
+template class PaddingLoDTensorFunctor<platform::CUDADeviceContext, float>;
+template class PaddingLoDTensorFunctor<platform::CUDADeviceContext, double>;
+
+template class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, int>;
+template class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, int64_t>;
+template class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, float>;
+template class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, double>;
 
 }  // namespace math
 }  // namespace operators
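
Editor's note, not part of the patch: the only place the kernel distinguishes kLengthBatchWidth from kBatchLengthWidth is the pair of offset formulas added above. A minimal host-side sketch of those two formulas, with toy sizes that are assumptions for illustration only, is:

// Illustrative sketch only (not from the patch): checks the pad-data offset
// formulas used by SequencePaddingKernel for each output layout.
#include <cstddef>
#include <cstdio>

int main() {
  const std::size_t seq_num = 2, max_seq_len = 4, seq_width = 3;  // assumed toy sizes
  const std::size_t seq_idx = 1;       // sequence index (blockIdx.y in the kernel)
  const std::size_t seq_step_idx = 2;  // time step within the sequence

  // kLengthBatchWidth: padded tensor laid out as [max_seq_len, seq_num, seq_width]
  std::size_t length_batch = (seq_step_idx * seq_num + seq_idx) * seq_width;   // 15

  // kBatchLengthWidth: padded tensor laid out as [seq_num, max_seq_len, seq_width]
  std::size_t batch_length = (seq_idx * max_seq_len + seq_step_idx) * seq_width;  // 18

  std::printf("kLengthBatchWidth offset: %zu\n", length_batch);
  std::printf("kBatchLengthWidth offset: %zu\n", batch_length);
  return 0;
}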