|
|
|
@ -82,7 +82,7 @@ class CTCAlignOpCUDAKernel : public framework::OpKernel<T> {
|
|
|
|
|
output->Resize({static_cast<int64_t>(host_out_lod0.back()), 1});
|
|
|
|
|
|
|
|
|
|
if (host_out_lod0.back() == 0) {
|
|
|
|
|
output->Resize({1});
|
|
|
|
|
output->Resize({1, 1});
|
|
|
|
|
output->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
math::SetConstant<platform::CUDADeviceContext, T> set_constant;
|
|
|
|
|
set_constant(ctx.template device_context<platform::CUDADeviceContext>(),
|
|
|
|
|