|
|
|
@@ -230,8 +230,35 @@ class WarpCTCKernel : public framework::OpKernel<T> {
|
|
|
|
|
static_cast<T>(0));
|
|
|
|
|
|
|
|
|
|
// warpctc accesses labels in CPU memory
|
|
|
|
|
Tensor warpctc_label;
|
|
|
|
|
TensorCopySync(*label, platform::CPUPlace(), &warpctc_label);
|
|
|
|
|
LoDTensor warpctc_label;
|
|
|
|
|
if (ctx.HasInput("LogitsLength")) {
|
|
|
|
|
warpctc_label.mutable_data<int>(
|
|
|
|
|
{static_cast<int64_t>(math::TotalSequenceLength(label_lod)), 1},
|
|
|
|
|
platform::CPUPlace());
|
|
|
|
|
std::vector<framework::Vector<size_t>> lod;
|
|
|
|
|
lod.push_back(label_lod);
|
|
|
|
|
warpctc_label.set_lod(lod);
|
|
|
|
|
|
|
|
|
|
if (platform::is_cpu_place(ctx.GetPlace())) {
|
|
|
|
|
math::UnpaddingLoDTensorFunctor<DeviceContext, int>()(
|
|
|
|
|
ctx.template device_context<DeviceContext>(), *label,
|
|
|
|
|
&warpctc_label, label->dims()[1] /*pad_seq_len*/, 0 /*lod_level*/,
|
|
|
|
|
false /*norm_by_times*/, math::kBatchLengthWidth);
|
|
|
|
|
} else {
|
|
|
|
|
LoDTensor gpu_label;
|
|
|
|
|
gpu_label.mutable_data<int>(
|
|
|
|
|
{static_cast<int64_t>(math::TotalSequenceLength(label_lod)), 1},
|
|
|
|
|
ctx.GetPlace());
|
|
|
|
|
gpu_label.set_lod(lod);
|
|
|
|
|
math::UnpaddingLoDTensorFunctor<DeviceContext, int>()(
|
|
|
|
|
ctx.template device_context<DeviceContext>(), *label, &gpu_label,
|
|
|
|
|
label->dims()[1] /*pad_seq_len*/, 0 /*lod_level*/,
|
|
|
|
|
false /*norm_by_times*/, math::kBatchLengthWidth);
|
|
|
|
|
TensorCopySync(gpu_label, platform::CPUPlace(), &warpctc_label);
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
TensorCopySync(*label, platform::CPUPlace(), &warpctc_label);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const int* warpctc_label_data = warpctc_label.data<int>();
|
|
|
|
|
// warpctc stores loss in CPU memory
|
|
|
|
|