Port WarpCTC Operator (#5107)
* Add Seq2BatchFunctor, which will be used in WarpCTCOp.
* Implement WarpCTCFunctor and WarpCTCKernel.
* Add unittest of warpctc_op.
* Modify the check_output interface in the python unittest framework to allow checking a subset of outputs.
* Use absolute offset lod in warpctc_op and related functors.
* Refine the comments of warpctc_op.
* The new python unittest supports checking a subset of the outputs, so revoke the previous change.
* Rename the transform from LoDTensor to Tensor with shape [max_sequence_length, num_sequences, sequence_width] to PaddingSequenceFunctor.
* Update to the newest codes.
* Rename the PaddingSequenceFunctor to PaddingLoDTensorFunctor and remove the computation of dimensions out of the functors.
parent
fe341bacde
commit
b5fda2723f
@@ -0,0 +1,144 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/math/sequence_padding.h"

namespace paddle {
namespace operators {
namespace math {

template <typename T>
class PaddingLoDTensorFunctor<platform::CPUDeviceContext, T> {
 public:
  void operator()(const platform::CPUDeviceContext& context,
                  const framework::LoDTensor& seq, framework::Tensor& padding,
                  bool norm_by_times) {
    auto lod = seq.lod();
    PADDLE_ENFORCE_GT(lod.size(), 0UL,
                      "The LoD of LoDTensor seq should not be null.");

    const size_t level = 0;
    framework::LoD abs_offset_lod = framework::ToAbsOffset(lod);

    auto seq_dims = seq.dims();
    PADDLE_ENFORCE_EQ(seq_dims[0], abs_offset_lod[level].back(),
                      "The first dimension of LoDTensor seq should be "
                      "equal to the sum of all sequences' lengths.");

    auto padding_dims = padding.dims();
    PADDLE_ENFORCE_EQ(padding_dims.size(), 3UL,
                      "The input padding should be a 3-D Tensor of shape "
                      "[max_sequence_length, num_sequences, sequence_width].");

    const size_t max_sequence_length = MaximumSequenceLength(lod, level);
    PADDLE_ENFORCE_EQ(padding_dims[0], max_sequence_length,
                      "The first dimension of Tensor padding should be the "
                      "maximum length of all sequences in LoDTensor seq.");

    const size_t num_sequences = abs_offset_lod[level].size() - 1;
    PADDLE_ENFORCE_EQ(padding_dims[1], num_sequences,
                      "The second dimension of Tensor padding should be the "
                      "number of sequences in LoDTensor seq.");

    const size_t sequence_width = seq.numel() / seq_dims[0];
    PADDLE_ENFORCE_EQ(padding_dims[2], sequence_width,
                      "The third dimension of Tensor padding should be the "
                      "width of sequence in LoDTensor seq.");

    const T* seq_data = seq.data<T>();
    T* padding_data = padding.data<T>();
    for (size_t i = 0; i < max_sequence_length; ++i) {
      for (size_t j = 0; j < num_sequences; ++j) {
        size_t start_pos = abs_offset_lod[level][j];
        size_t sequence_length = abs_offset_lod[level][j + 1] - start_pos;
        if (i < sequence_length) {
          // i < sequence_length implies sequence_length > 0, so the division
          // below is safe.
          T scale =
              norm_by_times ? (1.0f / static_cast<T>(sequence_length)) : 1.0f;
          for (size_t k = 0; k < sequence_width; ++k) {
            padding_data[(i * num_sequences + j) * sequence_width + k] =
                seq_data[(start_pos + i) * sequence_width + k] * scale;
          }
        } else {
          memset(padding_data + (i * num_sequences + j) * sequence_width, 0,
                 sequence_width * sizeof(T));
        }
      }
    }
  }
};

template <typename T>
class UnpaddingLoDTensorFunctor<platform::CPUDeviceContext, T> {
 public:
  void operator()(const platform::CPUDeviceContext& context,
                  framework::LoDTensor& seq, const framework::Tensor& padding,
                  bool norm_by_times) {
    auto lod = seq.lod();
    PADDLE_ENFORCE_GT(lod.size(), 0UL,
                      "The LoD of LoDTensor seq should not be null.");

    const size_t level = 0;
    framework::LoD abs_offset_lod = framework::ToAbsOffset(lod);

    auto seq_dims = seq.dims();
    PADDLE_ENFORCE_EQ(seq_dims[0], abs_offset_lod[level].back(),
                      "The first dimension of LoDTensor seq should be "
                      "equal to the sum of all sequences' lengths.");

    auto padding_dims = padding.dims();
    PADDLE_ENFORCE_EQ(padding_dims.size(), 3UL,
                      "The input padding should be a 3-D Tensor of shape "
                      "[max_sequence_length, num_sequences, sequence_width].");

    const size_t max_sequence_length = MaximumSequenceLength(lod, level);
    PADDLE_ENFORCE_EQ(padding_dims[0], max_sequence_length,
                      "The first dimension of Tensor padding should be "
                      "the maximum length of all sequences in LoDTensor seq.");

    const size_t num_sequences = abs_offset_lod[level].size() - 1;
    PADDLE_ENFORCE_EQ(padding_dims[1], num_sequences,
                      "The second dimension of Tensor padding should be "
                      "the number of sequences in LoDTensor seq.");

    const size_t sequence_width = seq.numel() / seq_dims[0];
    PADDLE_ENFORCE_EQ(padding_dims[2], sequence_width,
                      "The third dimension of Tensor padding should be the "
                      "width of sequence in LoDTensor seq.");

    const T* padding_data = padding.data<T>();
    T* seq_data = seq.data<T>();
    for (size_t i = 0; i < num_sequences; ++i) {
      size_t start_pos = abs_offset_lod[level][i];
      size_t sequence_length = abs_offset_lod[level][i + 1] - start_pos;
      for (size_t j = 0; j < sequence_length; ++j) {
        // j < sequence_length implies sequence_length > 0, so the division
        // below is safe.
        T scale =
            norm_by_times ? (1.0f / static_cast<T>(sequence_length)) : 1.0f;
        for (size_t k = 0; k < sequence_width; ++k) {
          seq_data[(start_pos + j) * sequence_width + k] =
              padding_data[(j * num_sequences + i) * sequence_width + k] *
              scale;
        }
      }
    }
  }
};

template class PaddingLoDTensorFunctor<platform::CPUDeviceContext, float>;
template class UnpaddingLoDTensorFunctor<platform::CPUDeviceContext, float>;

}  // namespace math
}  // namespace operators
}  // namespace paddle
@@ -0,0 +1,209 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/math/sequence_padding.h"

namespace paddle {
namespace operators {
namespace math {

template <typename T, bool NormByTimes, bool Padding>
__global__ void SequencePaddingKernel(T* padding, T* sequence,
                                      const size_t* sequence_start_positions,
                                      const size_t sequence_width,
                                      const size_t max_sequence_length,
                                      const size_t num_sequences) {
  size_t padding_idx = blockIdx.y;
  size_t start_pos = sequence_start_positions[padding_idx];
  size_t sequence_length =
      sequence_start_positions[padding_idx + 1] - start_pos;

  size_t sequence_idx = blockIdx.x * blockDim.y + threadIdx.y;
  size_t padding_base_idx =
      (sequence_idx * num_sequences + padding_idx) * sequence_width;
  size_t sequence_base_idx = (start_pos + sequence_idx) * sequence_width;

  if (sequence_idx < sequence_length) {
    T scale = NormByTimes ? (1.0f / static_cast<T>(sequence_length)) : 1.0f;
    if (Padding) {
      /* sequence -> padding */
      for (size_t i = threadIdx.x; i < sequence_width; i += blockDim.x) {
        padding[padding_base_idx + i] = scale * sequence[sequence_base_idx + i];
      }
    } else {
      /* padding -> sequence */
      for (size_t i = threadIdx.x; i < sequence_width; i += blockDim.x) {
        sequence[sequence_base_idx + i] = scale * padding[padding_base_idx + i];
      }
    }
  } else if (sequence_idx < max_sequence_length) {
    if (Padding) {
      /* sequence -> padding */
      for (size_t i = threadIdx.x; i < sequence_width; i += blockDim.x) {
        padding[padding_base_idx + i] = 0;
      }
    }
  }
}

template <typename T>
class PaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::LoDTensor& seq, framework::Tensor& padding,
                  bool norm_by_times) {
    auto lod = seq.lod();
    PADDLE_ENFORCE_GT(lod.size(), 0UL,
                      "The LoD of LoDTensor seq should not be null.");

    const size_t level = 0;
    framework::LoD abs_offset_lod = framework::ToAbsOffset(lod);

    auto seq_dims = seq.dims();
    PADDLE_ENFORCE_EQ(seq_dims[0], abs_offset_lod[level].back(),
                      "The first dimension of LoDTensor seq should be "
                      "equal to the sum of all sequences' lengths.");

    auto padding_dims = padding.dims();
    PADDLE_ENFORCE_EQ(padding_dims.size(), 3UL,
                      "The input padding should be a 3-D Tensor of shape "
                      "[max_sequence_length, num_sequences, sequence_width].");

    size_t max_sequence_length = MaximumSequenceLength(lod, level);
    PADDLE_ENFORCE_EQ(padding_dims[0], max_sequence_length,
                      "The first dimension of Tensor padding should be the "
                      "maximum length of all sequences in LoDTensor seq.");

    const size_t num_sequences = abs_offset_lod[level].size() - 1;
    PADDLE_ENFORCE_EQ(padding_dims[1], num_sequences,
                      "The second dimension of Tensor padding should be the "
                      "number of sequences in LoDTensor seq.");

    const size_t sequence_width = seq.numel() / seq_dims[0];
    PADDLE_ENFORCE_EQ(padding_dims[2], sequence_width,
                      "The third dimension of Tensor padding should be the "
                      "width of sequence in LoDTensor seq.");

    if (!norm_by_times && num_sequences == 1UL) {
      Copy(seq, context.GetPlace(), context, &padding);
      padding.Resize(padding_dims);
      return;
    }

    const size_t kBlockSize = 512;

    /* At least use 32 threads to copy sequence_width elements,
     * and at least 8 elements for each thread.
     */
    size_t block_dim_x =
        std::min(((((sequence_width + 7) >> 3) + 31) >> 5) << 5, kBlockSize);
    size_t block_dim_y = kBlockSize / block_dim_x;
    dim3 threads(block_dim_x, block_dim_y);

    size_t grid_dim_x = (max_sequence_length + block_dim_y - 1) / block_dim_y;
    size_t grid_dim_y = num_sequences;
    dim3 grid(grid_dim_x, grid_dim_y);

    const T* seq_data = seq.data<T>();
    T* padding_data = padding.data<T>();
    if (norm_by_times) {
      SequencePaddingKernel<T, 1, 1><<<grid, threads, 0, context.stream()>>>(
          padding_data, const_cast<T*>(seq_data), abs_offset_lod[level].data(),
          sequence_width, max_sequence_length, num_sequences);
    } else {
      SequencePaddingKernel<T, 0, 1><<<grid, threads, 0, context.stream()>>>(
          padding_data, const_cast<T*>(seq_data), abs_offset_lod[level].data(),
          sequence_width, max_sequence_length, num_sequences);
    }
  }
};

template <typename T>
class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  framework::LoDTensor& seq, const framework::Tensor& padding,
                  bool norm_by_times) {
    auto lod = seq.lod();
    PADDLE_ENFORCE_GT(lod.size(), 0UL,
                      "The LoD of LoDTensor seq should not be null.");

    const size_t level = 0;
    framework::LoD abs_offset_lod = framework::ToAbsOffset(lod);

    auto seq_dims = seq.dims();
    PADDLE_ENFORCE_EQ(seq_dims[0], abs_offset_lod[level].back(),
                      "The first dimension of LoDTensor seq should be "
                      "equal to the sum of all sequences' lengths.");

    auto padding_dims = padding.dims();
    PADDLE_ENFORCE_EQ(padding_dims.size(), 3UL,
                      "The input padding should be a 3-D Tensor of shape "
                      "[max_sequence_length, num_sequences, sequence_width].");

    size_t max_sequence_length = MaximumSequenceLength(lod, level);
    PADDLE_ENFORCE_EQ(padding_dims[0], max_sequence_length,
                      "The first dimension of Tensor padding should be "
                      "the maximum length of all sequences in LoDTensor seq.");

    const size_t num_sequences = abs_offset_lod[level].size() - 1;
    PADDLE_ENFORCE_EQ(padding_dims[1], num_sequences,
                      "The second dimension of Tensor padding should be "
                      "the number of sequences in LoDTensor seq.");

    const size_t sequence_width = seq.numel() / seq_dims[0];
    PADDLE_ENFORCE_EQ(padding_dims[2], sequence_width,
                      "The third dimension of Tensor padding should be the "
                      "width of sequence in LoDTensor seq.");

    if (!norm_by_times && num_sequences == 1UL) {
      Copy(padding, context.GetPlace(), context, &seq);
      seq.Resize(seq_dims);
      return;
    }

    const size_t kBlockSize = 512;

    /* At least use 32 threads to copy sequence_width elements,
     * and at least 8 elements for each thread.
     */
    size_t block_dim_x =
        std::min(((((sequence_width + 7) >> 3) + 31) >> 5) << 5, kBlockSize);
    size_t block_dim_y = kBlockSize / block_dim_x;
    dim3 threads(block_dim_x, block_dim_y);

    size_t grid_dim_x = (max_sequence_length + block_dim_y - 1) / block_dim_y;
    size_t grid_dim_y = num_sequences;
    dim3 grid(grid_dim_x, grid_dim_y);

    const T* padding_data = padding.data<T>();
    T* seq_data = seq.data<T>();
    if (norm_by_times) {
      SequencePaddingKernel<T, 1, 0><<<grid, threads, 0, context.stream()>>>(
          const_cast<T*>(padding_data), seq_data, abs_offset_lod[level].data(),
          sequence_width, max_sequence_length, num_sequences);
    } else {
      SequencePaddingKernel<T, 0, 0><<<grid, threads, 0, context.stream()>>>(
          const_cast<T*>(padding_data), seq_data, abs_offset_lod[level].data(),
          sequence_width, max_sequence_length, num_sequences);
    }
  }
};

template class PaddingLoDTensorFunctor<platform::CUDADeviceContext, float>;
template class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, float>;

}  // namespace math
}  // namespace operators
}  // namespace paddle
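The launch-configuration comment in the CUDA functors above ("at least 32 threads to copy sequence_width elements, and at least 8 elements for each thread") compresses the sizing rule into one shifted expression. The following standalone snippet, which is not part of the patch, just replays that arithmetic for a few illustrative widths; kBlockSize and the width values are taken from or modeled on the code above, everything else is hypothetical.

#include <algorithm>
#include <cstddef>
#include <cstdio>

// Reproduce the sizing used above: round sequence_width / 8 up to a multiple
// of 32, cap it at the 512-thread block, and give the remaining threads of
// the block to the y-dimension (one padded time step per y-thread).
int main() {
  const size_t kBlockSize = 512;
  const size_t max_sequence_length = 11;  // e.g. the longest sequence
  const size_t widths[] = {16, 128, 1000, 5000};
  for (size_t sequence_width : widths) {
    size_t block_dim_x =
        std::min(((((sequence_width + 7) >> 3) + 31) >> 5) << 5, kBlockSize);
    size_t block_dim_y = kBlockSize / block_dim_x;
    size_t grid_dim_x = (max_sequence_length + block_dim_y - 1) / block_dim_y;
    std::printf("width=%zu -> block=(%zu, %zu), grid_x=%zu\n", sequence_width,
                block_dim_x, block_dim_y, grid_dim_x);
  }
  return 0;
}

For example, sequence_width = 128 yields a 32 x 16 block, while 5000 saturates the block at 512 x 1.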
@@ -0,0 +1,79 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/framework/lod_tensor.h"
#include "paddle/platform/device_context.h"

namespace paddle {
namespace operators {
namespace math {

inline static size_t MaximumSequenceLength(const framework::LoD& lod,
                                           const size_t level) {
  const size_t num_sequences = lod[level].size() - 1;
  size_t max_sequence_length = 0;
  framework::LoD abs_offset_lod = framework::ToAbsOffset(lod);
  for (size_t i = 0; i < num_sequences; ++i) {
    max_sequence_length =
        std::max(max_sequence_length,
                 abs_offset_lod[level][i + 1] - abs_offset_lod[level][i]);
  }
  return max_sequence_length;
}

/*
 * \brief Padding/Unpadding a LoDTensor to/from a normal Tensor of the shape
 *        [max_sequence_length, num_sequences, sequence_width].
 *
 * Padding sequence:
 *        padding[i] = seq[lod[level][i]]
 * Unpadding sequence:
 *        seq[lod[level][i]] = padding[i]
 *
 * All sequences will be padded to the same length and stored in a transposed
 * shape.
 * Example:
 *   seq     (s0, s0, s0, s0; s1, s1; s2, s2, s2; s3)
 *   padding (s0, s1, s2, s3; s0, s1, s2, 0; s0, 0, s2, 0; s0, 0, 0, 0)
 *
 * \param context       device context of this functor.
 * \param seq           LoDTensor which is stored in sequence format, the shape
 *                      is [total_sequence_length, sequence_width] where
 *                      total_sequence_length is the sum of all sequences'
 *                      lengths.
 * \param padding       Tensor which is padded to the same length, the shape is
 *                      [max_sequence_length, num_sequences, sequence_width].
 * \param norm_by_times whether to divide each element by the length of its
 *                      sequence.
 *
 * \note  The transposition is also done in this functor.
 */
template <typename DeviceContext, typename T>
class PaddingLoDTensorFunctor {
 public:
  void operator()(const DeviceContext& context, const framework::LoDTensor& seq,
                  framework::Tensor& padding, bool norm_by_times);
};

template <typename DeviceContext, typename T>
class UnpaddingLoDTensorFunctor {
 public:
  void operator()(const DeviceContext& context, framework::LoDTensor& seq,
                  const framework::Tensor& padding, bool norm_by_times);
};

}  // namespace math
}  // namespace operators
}  // namespace paddle
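To make the layout described in the header comment above concrete, here is a small self-contained sketch that is not part of the patch: it pads a toy "LoDTensor" given as a flat array plus absolute offsets, using the same index arithmetic as the CPU functor. The sequence values, offsets, and variable names are illustrative only (sequence_width = 1, four sequences of lengths 4, 2, 3, 1).

#include <cstdio>
#include <vector>

int main() {
  // Absolute-offset LoD {0, 4, 6, 9, 10}: sequences s0..s3 of lengths 4, 2, 3, 1.
  const std::vector<size_t> offsets = {0, 4, 6, 9, 10};
  const std::vector<float> seq = {/* s0 */ 0, 1, 2, 3, /* s1 */ 4, 5,
                                  /* s2 */ 6, 7, 8, /* s3 */ 9};
  const size_t num_sequences = offsets.size() - 1;
  const size_t max_sequence_length = 4;

  // padding[i * num_sequences + j] holds step i of sequence j, or 0 past its end.
  std::vector<float> padding(max_sequence_length * num_sequences, 0.0f);
  for (size_t i = 0; i < max_sequence_length; ++i) {
    for (size_t j = 0; j < num_sequences; ++j) {
      size_t length = offsets[j + 1] - offsets[j];
      if (i < length) padding[i * num_sequences + j] = seq[offsets[j] + i];
    }
  }

  // Rows are time steps, columns are sequences: this prints the transposed,
  // zero-padded layout from the comment above.
  for (size_t i = 0; i < max_sequence_length; ++i) {
    for (size_t j = 0; j < num_sequences; ++j) {
      std::printf("%4.0f", padding[i * num_sequences + j]);
    }
    std::printf("\n");
  }
  return 0;
}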
@@ -0,0 +1,104 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/math/sequence_padding.h"
#include <gtest/gtest.h>

template <typename DeviceContext, typename Place, typename T>
void TestSequencePadding(const paddle::framework::LoD& lod,
                         const size_t sequence_width) {
  paddle::framework::LoDTensor cpu_seq;
  paddle::framework::LoDTensor cpu_seq_back;
  paddle::framework::LoDTensor seq;
  paddle::framework::LoDTensor seq_back;
  paddle::framework::Tensor padding;

  const size_t level = lod.size() - 1;
  auto seq_dims =
      paddle::framework::make_ddim({static_cast<int64_t>(lod[level].back()),
                                    static_cast<int64_t>(sequence_width)});

  cpu_seq.set_lod(lod);
  cpu_seq.mutable_data<T>(seq_dims, paddle::platform::CPUPlace());
  for (size_t i = 0; i < cpu_seq.numel(); ++i) {
    cpu_seq.data<T>()[i] = static_cast<T>(i);
  }

  auto* place = new Place();
  DeviceContext* context = new DeviceContext(*place);
  if (paddle::platform::is_cpu_place(*place)) {
    seq = cpu_seq;
  } else {
    Copy(cpu_seq, *place, *context, &seq);
    seq.set_lod(lod);
  }

  const size_t max_sequence_length =
      paddle::operators::math::MaximumSequenceLength(lod, level);
  const size_t num_sequences = lod[level].size() - 1;
  auto padding_dims =
      paddle::framework::make_ddim({static_cast<int64_t>(max_sequence_length),
                                    static_cast<int64_t>(num_sequences),
                                    static_cast<int64_t>(sequence_width)});
  padding.mutable_data<T>(padding_dims, *place);
  paddle::operators::math::PaddingLoDTensorFunctor<DeviceContext, T>()(
      *context, seq, padding, false);

  seq_back.set_lod(lod);
  seq_back.mutable_data<T>(seq_dims, *place);
  paddle::operators::math::UnpaddingLoDTensorFunctor<DeviceContext, T>()(
      *context, seq_back, padding, false);

  if (paddle::platform::is_cpu_place(*place)) {
    cpu_seq_back = seq_back;
  } else {
    Copy(seq_back, paddle::platform::CPUPlace(), *context, &cpu_seq_back);
    cpu_seq_back.set_lod(lod);
  }

  EXPECT_EQ(cpu_seq.numel(), cpu_seq_back.numel());
  EXPECT_EQ(cpu_seq.dims(), cpu_seq_back.dims());
  for (size_t i = 0; i < cpu_seq.numel(); ++i) {
    EXPECT_EQ(cpu_seq.data<T>()[i], cpu_seq_back.data<T>()[i]);
  }

  delete place;
  delete context;
}

TEST(Seq2BatchPadding, CPU) {
  paddle::framework::LoD lod1;
  lod1.push_back(std::vector<size_t>{0, 10});
  TestSequencePadding<paddle::platform::CPUDeviceContext,
                      paddle::platform::CPUPlace, float>(lod1, 16);

  paddle::framework::LoD lod2;
  lod2.push_back(std::vector<size_t>{0, 2, 7, 10});
  TestSequencePadding<paddle::platform::CPUDeviceContext,
                      paddle::platform::CPUPlace, float>(lod2, 128);
}

#ifdef PADDLE_WITH_CUDA
TEST(SequencePadding, CUDA) {
  paddle::framework::LoD lod1;
  lod1.push_back(std::vector<size_t>{0, 10});
  TestSequencePadding<paddle::platform::CUDADeviceContext,
                      paddle::platform::CUDAPlace, float>(lod1, 16);

  paddle::framework::LoD lod2;
  lod2.push_back(std::vector<size_t>{0, 2, 7, 10});
  TestSequencePadding<paddle::platform::CUDADeviceContext,
                      paddle::platform::CUDAPlace, float>(lod2, 128);
}
#endif
@@ -0,0 +1,141 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/warpctc_op.h"

namespace paddle {
namespace operators {

class WarpCTCOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("Logits"),
                   "Input(Logits) of WarpCTCOp should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("Label"),
                   "Input(Label) of WarpCTCOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("WarpCTCGrad"),
                   "Output(WarpCTCGrad) of WarpCTCOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Loss"),
                   "Output(Loss) of WarpCTCOp should not be null.");

    auto logits_dims = ctx->GetInputDim("Logits");
    int sequence_width =
        static_cast<int>(framework::product(logits_dims) / logits_dims[0]);
    int blank = ctx->Attrs().Get<int>("blank");
    PADDLE_ENFORCE((blank >= 0) && (blank < sequence_width),
                   "The value of Attr(blank) should be in interval [0, %d).",
                   sequence_width);
    // TODO(liuyiqun): it is tricky to set the wrong dimension here.
    ctx->SetOutputDim("Loss", {logits_dims[0], 1});
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
        framework::ToDataType(ctx.Input<Tensor>("Logits")->type()),
        ctx.device_context());
  }
};

class WarpCTCOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  WarpCTCOpMaker(OpProto* proto, OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("Logits",
             "(LoDTensor, default: LoDTensor<float>), the unscaled "
             "probabilities of variable-length sequences, which is a 2-D "
             "Tensor with LoD information. Its shape is "
             "[Lp, num_classes + 1], where Lp is the sum of all input "
             "sequences' lengths and num_classes is the true number of classes "
             "(not including the blank label).");
    AddInput("Label",
             "(LoDTensor, default: LoDTensor<int>), the ground truth "
             "of variable-length sequences, which is a 2-D Tensor with LoD "
             "information. It is of the shape [Lg, 1], where Lg is the sum of "
             "all labels' lengths.");
    AddOutput("WarpCTCGrad",
              "(Tensor, default: Tensor<float>), a temporary "
              "output Tensor to store the gradients of warp-ctc, which is "
              "computed with the loss together in one call. It is a 3-D Tensor "
              "of the shape [max_sequence_length, batch_size, num_classes + 1].")
        .AsIntermediate();
    AddOutput("Loss",
              "(Tensor, default: Tensor<float>), the Connectionist "
              "Temporal Classification (CTC) loss, which is a 2-D Tensor of "
              "the shape [batch_size, 1].");
    AddAttr<int>("blank",
                 "(int, default: 0), the blank label of Connectionist "
                 "Temporal Classification (CTC) loss, which lies in the "
                 "half-open interval [0, num_classes + 1).")
        .SetDefault(0);
    AddAttr<bool>("norm_by_times",
                  "(bool, default: false), whether to "
                  "normalize the gradients by the number of time-steps, "
                  "which is also the sequence's length.")
        .SetDefault(false);
    AddComment(R"DOC(
An operator integrating the open-source
[warp-ctc](https://github.com/baidu-research/warp-ctc) library, which is used in
[Deep Speech 2: End-to-End Speech Recognition in English and Mandarin](
https://arxiv.org/pdf/1512.02595v1.pdf),
to compute Connectionist Temporal Classification (CTC) loss.
It can be aliased as softmax with CTC, since a native softmax activation is
integrated into the warp-ctc library to normalize the values in each row of the
input tensor.

More details of the CTC loss can be found by referring to
[Connectionist Temporal Classification: Labelling Unsegmented Sequence Data with
Recurrent Neural Networks](
http://machinelearning.wustl.edu/mlpapers/paper_files/icml2006_GravesFGS06.pdf).
)DOC");
  }
};

class WarpCTCGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("WarpCTCGrad"),
                   "Input(WarpCTCGrad) of WarpCTCGradOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Logits")),
                   "Output(Logits@GRAD) of WarpCTCGradOp should not be null.");
    ctx->SetOutputDim(framework::GradVarName("Logits"),
                      ctx->GetInputDim("Logits"));
    ctx->ShareLoD("Logits", /*->*/ framework::GradVarName("Logits"));
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
        framework::ToDataType(ctx.Input<Tensor>("Logits")->type()),
        ctx.device_context());
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP(warpctc, ops::WarpCTCOp, ops::WarpCTCOpMaker, warpctc_grad,
            ops::WarpCTCGradOp);
REGISTER_OP_CPU_KERNEL(
    warpctc, ops::WarpCTCKernel<paddle::platform::CPUDeviceContext, float>);
REGISTER_OP_CPU_KERNEL(
    warpctc_grad,
    ops::WarpCTCGradKernel<paddle::platform::CPUDeviceContext, float>);
@@ -0,0 +1,22 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/warpctc_op.h"

namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
    warpctc, ops::WarpCTCKernel<paddle::platform::CUDADeviceContext, float>);
REGISTER_OP_CUDA_KERNEL(
    warpctc_grad,
    ops::WarpCTCGradKernel<paddle::platform::CUDADeviceContext, float>);
@@ -0,0 +1,218 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/operators/math/sequence_padding.h"
#include "paddle/platform/dynload/warpctc.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;

template <typename DeviceContext>
class WarpCTCFunctor {
 public:
  /*
   * \brief Compute the Connectionist Temporal Classification loss,
   *        and optionally compute the gradient with respect to the inputs.
   *
   * If gradient is nullptr, only the CTC loss is computed; otherwise both
   * the CTC loss and the gradient are computed.
   *
   * \param ctx               execution context of this functor.
   * \param input             batch matrix of input probabilities, in
   *                          max_sequence_length x num_sequences x
   *                          sequence_width, (row-major) format.
   * \param gradient          batch matrix of gradient, with the same shape as
   *                          input.
   * \param cpu_labels        labels, always in CPU memory.
   * \param cpu_label_lengths lengths of all labels, in CPU memory.
   * \param cpu_input_lengths lengths of all sequences, in CPU memory.
   * \param sequence_width    number of possible output symbols.
   * \param num_sequences     number of sequences.
   * \param blank             blank label used in the CTC loss function.
   * \param cpu_loss          cost of each sequence, in CPU memory.
   */
  void operator()(const framework::ExecutionContext& ctx, const float* input,
                  float* gradient, const int* cpu_labels,
                  const int* cpu_label_lengths, const int* cpu_input_lengths,
                  const size_t sequence_width, const size_t num_sequences,
                  const size_t blank, float* cpu_loss) {
    // Init warp-ctc options
    init(ctx, blank);

    // Compute the required workspace size.
    // No memory is allocated inside warp-ctc itself.
    size_t workspace_bytes = 0;
    ctcStatus_t status = platform::dynload::get_workspace_size(
        cpu_label_lengths, cpu_input_lengths, static_cast<int>(sequence_width),
        static_cast<int>(num_sequences), options_, &workspace_bytes);
    PADDLE_ENFORCE_EQ(CTC_STATUS_SUCCESS, status,
                      "warp-ctc [version %d] Error in get_workspace_size: ",
                      warpctc_version_,
                      platform::dynload::ctcGetStatusString(status));
    PADDLE_ENFORCE_GT(workspace_bytes, 0UL,
                      "Bytes of workspace got by warp-ctc function, "
                      "get_workspace_size(), should be larger than 0.");

    Tensor workspace;
    size_t workspace_elements = workspace_bytes / sizeof(float) + 1UL;
    float* workspace_data = workspace.mutable_data<float>(
        framework::make_ddim({static_cast<int64_t>(workspace_elements)}),
        ctx.GetPlace());
    math::SetConstant<DeviceContext, float>()(
        ctx.template device_context<DeviceContext>(), &workspace,
        static_cast<float>(0));

    // compute loss and gradient
    status = platform::dynload::compute_ctc_loss(
        input, gradient, cpu_labels, cpu_label_lengths, cpu_input_lengths,
        static_cast<int>(sequence_width), static_cast<int>(num_sequences),
        cpu_loss, workspace_data, options_);
    PADDLE_ENFORCE_EQ(CTC_STATUS_SUCCESS, status,
                      "warp-ctc [version %d] Error in compute_ctc_loss: ",
                      warpctc_version_,
                      platform::dynload::ctcGetStatusString(status));
  }

 protected:
  void init(const framework::ExecutionContext& ctx, const size_t blank) {
    warpctc_version_ = platform::dynload::get_warpctc_version();

    if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef PADDLE_WITH_CUDA
      options_.loc = CTC_GPU;
      options_.stream = reinterpret_cast<const platform::CUDADeviceContext&>(
                            ctx.device_context())
                            .stream();
#else
      PADDLE_THROW("[warpctc init] GPU is not enabled.");
#endif
    } else {
      options_.loc = CTC_CPU;
      options_.num_threads = 1;
    }

    options_.blank_label = blank;
  }

 private:
  int warpctc_version_;
  ctcOptions options_;
};

template <typename DeviceContext, typename T>
class WarpCTCKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* logits = ctx.Input<LoDTensor>("Logits");
    auto* label = ctx.Input<LoDTensor>("Label");
    auto* warpctc_grad = ctx.Output<Tensor>("WarpCTCGrad");
    auto* loss = ctx.Output<Tensor>("Loss");

    const size_t level = 0;

    auto logits_lod = framework::ToAbsOffset(logits->lod());
    auto logits_dims = logits->dims();
    PADDLE_ENFORCE_EQ(logits_dims[0],
                      static_cast<int64_t>(logits_lod[level].back()),
                      "The first dimension of Input(Logits) should be equal to "
                      "the sum of all sequences' lengths.");

    auto label_lod = framework::ToAbsOffset(label->lod());
    auto label_dims = label->dims();
    PADDLE_ENFORCE_EQ(
        label_dims[0], label->numel(),
        "The width of each timestep in Input(Label) should be 1.");

    const size_t num_sequences = logits_lod[level].size() - 1;
    PADDLE_ENFORCE_EQ(num_sequences, label_lod[level].size() - 1,
                      "The number of sequences of Input(Logits) should be "
                      "equal to that of Input(Label).");

    const size_t sequence_width = logits->numel() / logits_dims[0];
    auto loss_dims =
        framework::make_ddim({static_cast<int64_t>(num_sequences), 1});

    // warpctc needs sequence data stored in transposed padding format
    Tensor warpctc_logits;
    const size_t max_sequence_length =
        math::MaximumSequenceLength(logits_lod, level);
    auto warpctc_logits_dims =
        framework::make_ddim({static_cast<int64_t>(max_sequence_length),
                              static_cast<int64_t>(num_sequences),
                              static_cast<int64_t>(sequence_width)});
    warpctc_logits.mutable_data<T>(warpctc_logits_dims, ctx.GetPlace());
    math::PaddingLoDTensorFunctor<DeviceContext, T>()(
        ctx.template device_context<DeviceContext>(), *logits, warpctc_logits,
        false);
    const T* warpctc_logits_data = warpctc_logits.data<T>();

    std::vector<int> warpctc_label_lengths(num_sequences);
    std::vector<int> warpctc_logits_lengths(num_sequences);

    for (size_t i = 0; i < num_sequences; ++i) {
      warpctc_label_lengths[i] = label_lod[level][i + 1] - label_lod[level][i];
      warpctc_logits_lengths[i] =
          logits_lod[level][i + 1] - logits_lod[level][i];
    }

    // warpctc computes loss and gradient in one call, and the gradient data
    // is also stored in batch format
    T* warpctc_grad_data =
        warpctc_grad->mutable_data<T>(warpctc_logits.dims(), ctx.GetPlace());

    // warpctc accesses labels in CPU memory
    Tensor warpctc_label;
    Copy(*label, platform::CPUPlace(), ctx.device_context(), &warpctc_label);
    const int* warpctc_label_data = warpctc_label.data<int>();

    // warpctc stores loss in CPU memory
    Tensor warpctc_loss;
    T* warpctc_loss_data =
        warpctc_loss.mutable_data<T>(loss_dims, platform::CPUPlace());

    const size_t blank = static_cast<size_t>(ctx.Attr<int>("blank"));

    WarpCTCFunctor<DeviceContext>()(
        ctx, warpctc_logits_data, warpctc_grad_data, warpctc_label_data,
        warpctc_label_lengths.data(), warpctc_logits_lengths.data(),
        sequence_width, num_sequences, blank, warpctc_loss_data);

    // Copy the loss back
    Copy(warpctc_loss, ctx.GetPlace(), ctx.device_context(), loss);
  }
};

template <typename DeviceContext, typename T>
class WarpCTCGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* warpctc_grad = ctx.Input<Tensor>("WarpCTCGrad");
    auto* logits_grad = ctx.Output<LoDTensor>(framework::GradVarName("Logits"));

    bool norm_by_times = ctx.Attr<bool>("norm_by_times");
    math::UnpaddingLoDTensorFunctor<DeviceContext, T>()(
        ctx.template device_context<DeviceContext>(), *logits_grad,
        *warpctc_grad, norm_by_times);
  }
};

}  // namespace operators
}  // namespace paddle
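The kernel above converts the level-0 absolute-offset LoD of Logits and Label into the plain per-sequence length arrays that warp-ctc expects. The following standalone sketch, not part of the patch, shows just that conversion on the toy offsets used by the Python unit test near the end of this commit; the helper name and printed annotations are illustrative.

#include <cstdio>
#include <vector>

// Paddle stores absolute offsets, warp-ctc wants int lengths per sequence:
// length[i] is simply offsets[i + 1] - offsets[i].
std::vector<int> LengthsFromOffsets(const std::vector<size_t>& offsets) {
  std::vector<int> lengths(offsets.size() - 1);
  for (size_t i = 0; i + 1 < offsets.size(); ++i) {
    lengths[i] = static_cast<int>(offsets[i + 1] - offsets[i]);
  }
  return lengths;
}

int main() {
  // Offsets borrowed from the unit test: logits_lod and labels_lod.
  const std::vector<size_t> logits_offsets = {0, 4, 5, 8, 11};
  const std::vector<size_t> labels_offsets = {0, 3, 4, 8, 12};
  for (int len : LengthsFromOffsets(logits_offsets)) std::printf("%d ", len);
  std::printf("<- logits lengths\n");  // 4 1 3 3
  for (int len : LengthsFromOffsets(labels_offsets)) std::printf("%d ", len);
  std::printf("<- label lengths\n");   // 3 1 4 4
  return 0;
}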
@@ -1,3 +1,4 @@
cc_library(dynamic_loader SRCS dynamic_loader.cc DEPS glog gflags enforce)
nv_library(dynload_cuda SRCS cublas.cc cudnn.cc curand.cc nccl.cc
           DEPS dynamic_loader nccl)
cc_library(dynload_warpctc SRCS warpctc.cc DEPS dynamic_loader warpctc)
@@ -0,0 +1,30 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/platform/dynload/warpctc.h"

namespace paddle {
namespace platform {
namespace dynload {

std::once_flag warpctc_dso_flag;
void* warpctc_dso_handle = nullptr;

#define DEFINE_WRAP(__name) DynLoad__##__name __name

WARPCTC_ROUTINE_EACH(DEFINE_WRAP);

}  // namespace dynload
}  // namespace platform
}  // namespace paddle
@@ -0,0 +1,63 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <dlfcn.h>
#include <mutex>
#include "ctc.h"
#include "paddle/platform/dynload/dynamic_loader.h"

namespace paddle {
namespace platform {
namespace dynload {

extern std::once_flag warpctc_dso_flag;
extern void* warpctc_dso_handle;

/**
 * The following macro definition generates a struct (one per function) that
 * dynamically loads the corresponding warpctc routine via operator
 * overloading.
 */
#define DYNAMIC_LOAD_WARPCTC_WRAP(__name)                            \
  struct DynLoad__##__name {                                         \
    template <typename... Args>                                      \
    auto operator()(Args... args) -> decltype(__name(args...)) {     \
      using warpctcFunc = decltype(__name(args...)) (*)(Args...);    \
      std::call_once(warpctc_dso_flag,                               \
                     paddle::platform::dynload::GetWarpCTCDsoHandle, \
                     &warpctc_dso_handle);                           \
      void* p_##_name = dlsym(warpctc_dso_handle, #__name);          \
      return reinterpret_cast<warpctcFunc>(p_##_name)(args...);      \
    }                                                                \
  };                                                                 \
  extern DynLoad__##__name __name

#define DECLARE_DYNAMIC_LOAD_WARPCTC_WRAP(__name) \
  DYNAMIC_LOAD_WARPCTC_WRAP(__name)

#define WARPCTC_ROUTINE_EACH(__macro) \
  __macro(get_warpctc_version);       \
  __macro(ctcGetStatusString);        \
  __macro(compute_ctc_loss);          \
  __macro(get_workspace_size)

WARPCTC_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_WARPCTC_WRAP);

#undef DYNAMIC_LOAD_WARPCTC_WRAP

}  // namespace dynload
}  // namespace platform
}  // namespace paddle
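The comment in the header above describes the general pattern: one wrapper struct per routine, with the shared object opened once and symbols resolved lazily via dlsym on first call. Here is a minimal standalone illustration of that call_once + dlsym idiom, which is not part of the patch and not tied to warp-ctc; it resolves cos from libm on Linux, and all names in it are hypothetical.

#include <dlfcn.h>
#include <cstdio>
#include <mutex>

// Open the shared object exactly once, then resolve symbols on demand.
static std::once_flag dso_flag;
static void* dso_handle = nullptr;

static void* GetSymbol(const char* name) {
  std::call_once(dso_flag, [] { dso_handle = dlopen("libm.so.6", RTLD_LAZY); });
  return dso_handle ? dlsym(dso_handle, name) : nullptr;
}

int main() {
  using CosFunc = double (*)(double);
  if (auto* cos_fn = reinterpret_cast<CosFunc>(GetSymbol("cos"))) {
    std::printf("cos(0) = %f\n", cos_fn(0.0));  // prints 1.000000
  } else {
    std::printf("failed to load libm\n");
  }
  return 0;
}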
@@ -0,0 +1,200 @@
import sys
import unittest
import numpy as np
from op_test import OpTest
from test_softmax_op import stable_softmax


class CTCForward(object):
    def __init__(self, softmax, softmax_lod, labels, labels_lod, blank,
                 norm_by_times):
        self.softmax = softmax
        self.softmax_lod = softmax_lod
        assert labels.shape[1] == 1
        self.labels = labels
        self.labels_lod = labels_lod
        self.blank = blank
        self.norm_by_times = norm_by_times

        self.level = 0
        self.num_classes = softmax.shape[1]
        self.batch_size = len(softmax_lod[self.level]) - 1
        assert self.batch_size == len(labels_lod[self.level]) - 1

        self.loss = np.zeros([self.batch_size, 1], dtype="float32")
        self.gradient = np.zeros(self.softmax.shape, dtype="float32")

        # float64
        self.EXP_MAX = sys.float_info.max
        self.EXP_MIN = sys.float_info.min
        self.LOG_ZERO = np.log(self.EXP_MIN)
        self.LOG_INFINITY = np.log(self.EXP_MAX)

    def safe_exp(self, x):
        if x <= self.LOG_ZERO:
            return 0.0
        if x >= self.LOG_INFINITY:
            return self.EXP_MAX
        return np.exp(x)

    def safe_log(self, x):
        if x <= self.EXP_MIN:
            return self.LOG_ZERO
        return np.log(x)

    # x = lna and y = lnb are in log scale, ln(a / b) = lna - lnb
    def log_div(self, x, y):
        res = x - y
        if res <= self.LOG_ZERO:
            return self.LOG_ZERO
        if res >= self.LOG_INFINITY:
            return self.LOG_INFINITY
        return res

    # x = lna and y = lnb are in log scale, ln(a * b) = lna + lnb
    def log_mul(self, x, y):
        res = x + y
        if res <= self.LOG_ZERO:
            return self.LOG_ZERO
        if res >= self.LOG_INFINITY:
            return self.LOG_INFINITY
        return res

    # x = lna and y = lnb are in log scale,
    # ln(a + b) = lna + ln(1 + exp(lnb - lna)), evaluated with lna >= lnb
    def log_add(self, x, y):
        if x < y:
            t = y
            y = x
            x = t
        return x + self.safe_log(1 + self.safe_exp(y - x))

    def segment_range(self, time, total_times, total_segments):
        start = max(0, total_segments - (2 * (total_times - time)))
        end = min(total_segments, 2 * (time + 1))
        return start, end

    def forward_a_sequence(self, softmax_a_sequence, labels_a_sequence):
        total_times = softmax_a_sequence.shape[0]
        total_segments = labels_a_sequence.shape[0] * 2 + 1

        required_times = labels_a_sequence.shape[0]
        old_label = -1
        for i in range(labels_a_sequence.shape[0]):
            # two contiguous labels with the same value
            if labels_a_sequence[i, 0] == old_label:
                required_times = required_times + 1
            old_label = labels_a_sequence[i, 0]

        if total_times < required_times:
            return 0

        # calculate the forward and backward variables,
        # reference Chapter 7.3 of "Alex Graves, Supervised Sequence
        # Labelling with Recurrent Neural Networks"
        log_acts = np.zeros([total_times, self.num_classes], dtype="float32")
        for i in range(total_times):
            for j in range(self.num_classes):
                log_acts[i, j] = self.safe_log(softmax_a_sequence[i, j])

        # calculate the forward variables
        forward_vars = np.zeros([total_times, total_segments], dtype="float32")
        for i in range(total_times):
            for j in range(total_segments):
                forward_vars[i, j] = self.LOG_ZERO

        for i in range(total_times):
            # dp initialization at t0
            if i == 0:
                forward_vars[i, 0] = log_acts[0, self.blank]
                if total_segments > 1:
                    forward_vars[i, 1] = log_acts[0, labels_a_sequence[i, 0]]
                continue

            # dp from t1
            start, end = self.segment_range(i, total_times, total_segments)
            for k in range(end - start):
                j = k + start
                if j & 1 == 1:
                    label_idx = j // 2
                    label_val = labels_a_sequence[label_idx, 0]
                    fv = self.log_add(forward_vars[i - 1, j],
                                      forward_vars[i - 1, j - 1])
                    if j > 1 and label_val != labels_a_sequence[label_idx - 1,
                                                                0]:
                        fv = self.log_add(fv, forward_vars[i - 1, j - 2])
                    fv = self.log_mul(fv, log_acts[i, label_val])
                else:
                    fv = forward_vars[i - 1, j]
                    if j > 0:
                        fv = self.log_add(fv, forward_vars[i - 1, j - 1])
                    fv = self.log_mul(fv, log_acts[i, self.blank])
                forward_vars[i, j] = fv

        # sum the last two values as log_prob
        log_prob = forward_vars[total_times - 1, total_segments - 1]
        if total_segments > 1:
            log_prob = self.log_add(
                log_prob, forward_vars[total_times - 1, total_segments - 2])

        return -log_prob

    def forward(self):
        for i in range(self.batch_size):
            softmax_start_i = self.softmax_lod[self.level][i]
            softmax_end_i = self.softmax_lod[self.level][i + 1]
            labels_start_i = self.labels_lod[self.level][i]
            labels_end_i = self.labels_lod[self.level][i + 1]

            softmax_a_sequence = self.softmax[softmax_start_i:softmax_end_i, :]
            labels_a_sequence = self.labels[labels_start_i:labels_end_i, :]
            self.loss[i] = self.forward_a_sequence(softmax_a_sequence,
                                                   labels_a_sequence)
        return self.loss


class TestWarpCTCOp(OpTest):
    def setUp(self):
        self.op_type = "warpctc"

        batch_size = 4
        num_classes = 8
        logits_lod = [[0, 4, 5, 8, 11]]
        logits = np.random.uniform(0.1, 1.0,
                                   [11, num_classes]).astype("float32")
        softmax = np.apply_along_axis(stable_softmax, 1, logits)
        labels_lod = [[0, 3, 4, 8, 12]]
        # labels should not be blank
        labels = np.random.randint(0, num_classes - 1, [12, 1], dtype="int32")

        blank = num_classes - 1
        norm_by_times = False

        ctc = CTCForward(softmax, logits_lod, labels, labels_lod, blank,
                         norm_by_times)
        loss = ctc.forward()

        max_sequence_length = 0
        for i in range(batch_size):
            max_sequence_length = max(max_sequence_length,
                                      logits_lod[0][i + 1] - logits_lod[0][i])
        gradient = np.zeros(
            [max_sequence_length, batch_size, num_classes], dtype="float32")

        self.inputs = {
            "Logits": (logits, logits_lod),
            "Label": (labels, labels_lod)
        }
        self.outputs = {"Loss": loss}
        self.attrs = {"blank": blank, "norm_by_times": norm_by_times}

    def test_check_output(self):
        self.check_output()


# def test_check_grad(self):
#     self.outputs["WarpCTCGrad"] = None
#     self.check_grad(["Logits"], "Loss", max_relative_error=0.01)


if __name__ == "__main__":
    unittest.main()
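The log-space helpers in the reference implementation above lean on the identity ln(a + b) = ln a + ln(1 + exp(ln b - ln a)), evaluated with the larger term factored out so the exponential never overflows. A quick standalone check of that identity follows, written in C++ for consistency with the rest of the patch and not part of it; the numeric values are arbitrary.

#include <algorithm>
#include <cmath>
#include <cstdio>

// LogAdd(x, y) = ln(exp(x) + exp(y)), computed without exponentiating the
// large argument directly.
double LogAdd(double x, double y) {
  if (x < y) std::swap(x, y);              // keep the larger value in x
  return x + std::log1p(std::exp(y - x));  // exp(y - x) <= 1, so no overflow
}

int main() {
  // Small values: agrees with the naive formula.
  std::printf("%.12f vs %.12f\n", LogAdd(std::log(2.0), std::log(3.0)),
              std::log(2.0 + 3.0));
  // Large values: naive exp(1000) would overflow, LogAdd stays finite.
  std::printf("%.6f\n", LogAdd(1000.0, 999.0));  // ~1000.313262
  return 0;
}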