diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/ctcloss_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/ctcloss_gpu_kernel.h index dca1136228..cf72a32b65 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/ctcloss_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/ctcloss_gpu_kernel.h @@ -51,10 +51,12 @@ class CtcLossGpuKernel : public GpuKernel { float *grads = GetDeviceAddress(outputs, 1); // Copy labels/input_lengths/label_length to host as cudnn7.x.x requires - void *labels_host = nullptr; + int *labels_host = nullptr; + int *no_blank_labels_host = nullptr; void *input_lengths_host = nullptr; void *label_lengths_host = nullptr; CHECK_CUDA_RET_WITH_EXCEPT(cudaMallocHost(&labels_host, inputs[1]->size), "cudaMallocHost failed."); + CHECK_CUDA_RET_WITH_EXCEPT(cudaMallocHost(&no_blank_labels_host, inputs[1]->size), "cudaMallocHost failed."); CHECK_CUDA_RET_WITH_EXCEPT(cudaMallocHost(&input_lengths_host, inputs[2]->size), "cudaMallocHost failed."); CHECK_CUDA_RET_WITH_EXCEPT(cudaMallocHost(&label_lengths_host, inputs[3]->size), "cudaMallocHost failed."); cudaStream_t stream = reinterpret_cast(stream_ptr); @@ -68,12 +70,21 @@ class CtcLossGpuKernel : public GpuKernel { "cudaMemcpyAsync failed."); CHECK_CUDA_RET_WITH_EXCEPT(cudaStreamSynchronize(stream), "cudaStreamSynchronize failed."); + + size_t j = 0; + for (size_t i = 0; i < inputs[1]->size / sizeof(int); i++) { + if (labels_host[i] != 0) { + no_blank_labels_host[j] = labels_host[i]; + j++; + } + } + size_t workspace_size = 0; CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnGetCTCLossWorkspaceSize(cudnn_handle_, probs_desc_, probs_desc_, reinterpret_cast(labels_host), - reinterpret_cast(label_lengths_host), - reinterpret_cast(input_lengths_host), CUDNN_CTC_LOSS_ALGO_DETERMINISTIC, - ctcloss_desc_, &workspace_size), + cudnnGetCTCLossWorkspaceSize( + cudnn_handle_, probs_desc_, probs_desc_, reinterpret_cast(no_blank_labels_host), + reinterpret_cast(label_lengths_host), reinterpret_cast(input_lengths_host), + CUDNN_CTC_LOSS_ALGO_DETERMINISTIC, ctcloss_desc_, &workspace_size), "cudnnGetCTCLossWorkspaceSize failed."); void *workspace = device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(workspace_size); if (workspace == nullptr) { @@ -81,7 +92,7 @@ class CtcLossGpuKernel : public GpuKernel { } CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnCTCLoss(cudnn_handle_, probs_desc_, probs, reinterpret_cast(labels_host), + cudnnCTCLoss(cudnn_handle_, probs_desc_, probs, reinterpret_cast(no_blank_labels_host), reinterpret_cast(label_lengths_host), reinterpret_cast(input_lengths_host), costs, probs_desc_, grads, CUDNN_CTC_LOSS_ALGO_DETERMINISTIC, ctcloss_desc_, workspace, workspace_size), "cudnnCtcLoss failed."); @@ -91,6 +102,7 @@ class CtcLossGpuKernel : public GpuKernel { CHECK_CUDA_RET_WITH_EXCEPT(cudaFreeHost(label_lengths_host), "cudaFreeHost failed."); CHECK_CUDA_RET_WITH_EXCEPT(cudaFreeHost(input_lengths_host), "cudaFreeHost failed."); CHECK_CUDA_RET_WITH_EXCEPT(cudaFreeHost(labels_host), "cudaFreeHost failed."); + CHECK_CUDA_RET_WITH_EXCEPT(cudaFreeHost(no_blank_labels_host), "cudaFreeHost failed."); return true; } bool Init(const CNodePtr &kernel_node) override {