commit 60e86d80d8

File diff suppressed because it is too large

@@ -0,0 +1,51 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_IMPL_CUH
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_IMPL_CUH

template <typename T>
void CalculateFwdVar(T *log_alpha_b, int *label_value_with_blank, T *softmax_probs, const int *sequence_length,
                     bool ctc_merge_repeated, int batch, int SOffSet, int maxtime, int blank, int *label_squence_length,
                     int *cum_labels_length, bool ignore_longer_outputs_than_inputs, cudaStream_t stream);

template <typename T>
void CalculateBwdVar(T *log_beta_b, int *label_value_with_blank, T *softmax_probs, const int *sequence_length,
                     bool ctc_merge_repeated, int batch, int SOffSet, int maxtime, int blank, int *label_squence_length,
                     int *cum_labels_length, bool ignore_longer_outputs_than_inputs, cudaStream_t stream);

template <typename T>
void InnerSoftMax(const T *probs, T *softmax_cost, const int *sequence_length, int max_time, int batch, int numclass,
                  cudaStream_t stream);

void GenLabelValuePCR(int *label_value_sp, int *label_value_pcr, int *label_squence_length, int *cum_labels_length,
                      int *max_labels_length, int batch, cudaStream_t stream);

void GenLabelWithBlank(int *label_value, int *label_value_with_blank, int *label_squence_length,
                       int *precum_labels_length, int *cum_labels_length, int batch, int blank, cudaStream_t stream);

void GenLabelValue(int *label_value_sp, const int64_t *label_indices, const int *label_values,
                   int *label_squence_length, int *cum_labels_length, int *max_labels_length, int size, int blank,
                   int batch, cudaStream_t stream);

void CalculatePreLength(int *label_squence_length, int *precum_labels_length, int *cum_labels_length,
                        int *max_labels_length, const int64_t *label_indices, int batch, int size, cudaStream_t stream);
void CalculateMaxSequence(const int *sequence_length, int *max_labels_length, int batch, cudaStream_t stream);
template <typename T>
void CTCLoss(T *log_alpha_b, T *log_beta_b, T *softmax_probs, int *label_value_with_blank, int batch, int SOffSet,
             int maxtime, int numclass, const int *sequence_length, int *label_squence_length, int *cum_labels_length,
             T *cost, T *grads, T *prob_num, bool ignore_longer_outputs_than_inputs, cudaStream_t stream);
#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_IMPL_CUH
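
The launchers declared above cover the stages of a native CTC pipeline: label preprocessing (CalculateMaxSequence, CalculatePreLength, GenLabelValue, GenLabelValuePCR, GenLabelWithBlank), a softmax over the class axis (InnerSoftMax), the forward/backward variables (CalculateFwdVar, CalculateBwdVar), and the final loss/gradient reduction (CTCLoss). Their definitions live in the suppressed .cu diff above. For orientation, here is a minimal CPU sketch of the standard log-space CTC forward recursion that CalculateFwdVar parallelizes per batch element; all names and the data layout are illustrative, not the kernel's actual implementation.

#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

const float kLogZero = -std::numeric_limits<float>::infinity();

// log(exp(a) + exp(b)) without overflow.
float LogSumExp(float a, float b) {
  if (a == kLogZero) return b;
  if (b == kLogZero) return a;
  return std::max(a, b) + std::log1p(std::exp(-std::fabs(a - b)));
}

// probs[t][c]: softmax output at time t for class c (one batch element).
// label_with_blank: the target interleaved with blanks, e.g. {-, c, -, a, -, t, -},
// where the blank is encoded as class index `blank`.
// Returns log alpha of shape [T][S], S = label_with_blank.size().
std::vector<std::vector<float>> CtcLogAlpha(const std::vector<std::vector<float>> &probs,
                                            const std::vector<int> &label_with_blank, int blank) {
  const size_t T = probs.size(), S = label_with_blank.size();
  std::vector<std::vector<float>> log_alpha(T, std::vector<float>(S, kLogZero));
  // Initialization: a valid path starts at the leading blank or the first label.
  log_alpha[0][0] = std::log(probs[0][label_with_blank[0]]);
  if (S > 1) log_alpha[0][1] = std::log(probs[0][label_with_blank[1]]);
  for (size_t t = 1; t < T; ++t) {
    for (size_t s = 0; s < S; ++s) {
      float sum = log_alpha[t - 1][s];                             // stay on the same symbol
      if (s >= 1) sum = LogSumExp(sum, log_alpha[t - 1][s - 1]);   // advance one symbol
      // Skip over a blank only when the neighbouring labels differ.
      if (s >= 2 && label_with_blank[s] != blank && label_with_blank[s] != label_with_blank[s - 2]) {
        sum = LogSumExp(sum, log_alpha[t - 1][s - 2]);
      }
      log_alpha[t][s] = sum + std::log(probs[t][label_with_blank[s]]);
    }
  }
  return log_alpha;
}
// loss = -LogSumExp(log_alpha[T-1][S-1], log_alpha[T-1][S-2]); CalculateBwdVar
// computes the mirror-image beta recursion used to form the gradient.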
@@ -1,31 +1,31 @@
 /**
  * Copyright 2020 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
  * You may obtain a copy of the License at
  *
  * http://www.apache.org/licenses/LICENSE-2.0
  *
  * Unless required by applicable law or agreed to in writing, software
  * distributed under the License is distributed on an "AS IS" BASIS,
  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
 
 #include "backend/kernel_compiler/gpu/nn/ctcloss_gpu_kernel.h"
 
 namespace mindspore {
 namespace kernel {
-MS_REG_GPU_KERNEL_ONE(CTCLossV2,
+MS_REG_GPU_KERNEL_ONE(CTCLoss,
                       KernelAttr()
                         .AddInputAttr(kNumberTypeFloat32)
-                        .AddInputAttr(kNumberTypeInt32)
+                        .AddInputAttr(kNumberTypeInt64)
                         .AddInputAttr(kNumberTypeInt32)
                         .AddInputAttr(kNumberTypeInt32)
                         .AddOutputAttr(kNumberTypeFloat32)
                         .AddOutputAttr(kNumberTypeFloat32),
                       CtcLossGpuKernel, float)
 }  // namespace kernel
 }  // namespace mindspore
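
The hunk above retargets CtcLossGpuKernel from CTCLossV2 to CTCLoss and widens its second input from Int32 to Int64. That matches the `label_indices` parameter consumed by GenLabelValue and CalculatePreLength in the header above: labels now arrive in a sparse (indices, values) encoding, consistent with TensorFlow-style SparseTensor CTC labels. A minimal sketch of that layout, with made-up values, purely for illustration:

#include <cstdint>
#include <cstdio>

int main() {
  // Two sequences: {3, 5} for batch 0 and {7} for batch 1.
  // label_indices is an [n, 2] int64 array of (batch, position) pairs,
  // label_values the matching class ids.
  const int64_t label_indices[3][2] = {{0, 0}, {0, 1}, {1, 0}};
  const int label_values[3] = {3, 5, 7};
  for (int i = 0; i < 3; ++i) {
    std::printf("batch %lld, pos %lld -> class %d\n", (long long)label_indices[i][0],
                (long long)label_indices[i][1], label_values[i]);
  }
  return 0;
}

GenLabelValue and CalculatePreLength consume exactly this pair, recovering per-batch label sequences and lengths on the GPU.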
											
												
File diff suppressed because it is too large

@@ -0,0 +1,31 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/kernel_compiler/gpu/nn/ctclossv2_gpu_kernel.h"

namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(CTCLossV2,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeInt32)
                        .AddInputAttr(kNumberTypeInt32)
                        .AddInputAttr(kNumberTypeInt32)
                        .AddOutputAttr(kNumberTypeFloat32)
                        .AddOutputAttr(kNumberTypeFloat32),
                      CtcLossV2GpuKernel, float)
}  // namespace kernel
}  // namespace mindspore
@@ -0,0 +1,192 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_CTCLOSS_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_CTCLOSS_GPU_KERNEL_H_

#include <cuda_runtime_api.h>
#include <vector>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "runtime/device/gpu/gpu_memory_allocator.h"

namespace mindspore {
namespace kernel {
template <typename T>
class CtcLossV2GpuKernel : public GpuKernel {
 public:
  CtcLossV2GpuKernel()
      : cudnn_handle_(nullptr),
        probs_desc_(nullptr),
        ctcloss_desc_(nullptr),
        label_size_(0),
        input_lengths_size_(0),
        label_lengths_size_(0) {}
  ~CtcLossV2GpuKernel() override { DestroyResource(); }

  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
    float *probs = GetDeviceAddress<float>(inputs, 0);
    float *costs = GetDeviceAddress<float>(outputs, 0);
    float *grads = GetDeviceAddress<float>(outputs, 1);

    // Copy labels/input_lengths/label_length to host as cudnn7.x.x requires
    int *labels_host = nullptr;
    int *no_blank_labels_host = nullptr;
    void *input_lengths_host = nullptr;
    void *label_lengths_host = nullptr;
    cudaStream_t stream = reinterpret_cast<cudaStream_t>(stream_ptr);
    AllocHostMem(&labels_host, &no_blank_labels_host, &input_lengths_host, &label_lengths_host, inputs);
    CopyToHostSync(labels_host, no_blank_labels_host, input_lengths_host, label_lengths_host, inputs, stream);

    size_t workspace_size = 0;
    CHECK_CUDNN_RET_WITH_EXCEPT(
      cudnnGetCTCLossWorkspaceSize(
        cudnn_handle_, probs_desc_, probs_desc_, reinterpret_cast<int *>(no_blank_labels_host),
        reinterpret_cast<int *>(label_lengths_host), reinterpret_cast<int *>(input_lengths_host),
        CUDNN_CTC_LOSS_ALGO_DETERMINISTIC, ctcloss_desc_, &workspace_size),
      "cudnnGetCTCLossWorkspaceSize failed.");
    void *workspace = device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(workspace_size);
    if (workspace == nullptr) {
      MS_LOG(EXCEPTION) << "Failed to alloc workspace, size: " << workspace_size;
    }

    CHECK_CUDNN_RET_WITH_EXCEPT(
      cudnnCTCLoss(cudnn_handle_, probs_desc_, probs, reinterpret_cast<int *>(no_blank_labels_host),
                   reinterpret_cast<int *>(label_lengths_host), reinterpret_cast<int *>(input_lengths_host), costs,
                   probs_desc_, grads, CUDNN_CTC_LOSS_ALGO_DETERMINISTIC, ctcloss_desc_, workspace, workspace_size),
      "cudnnCtcLoss failed.");
    CHECK_CUDA_RET_WITH_EXCEPT(cudaStreamSynchronize(stream), "cudaStreamSynchronize failed.");

    device::gpu::GPUMemoryAllocator::GetInstance().FreeTensorMem(workspace);
    FreeHostMem(labels_host, no_blank_labels_host, input_lengths_host, label_lengths_host);
    return true;
  }
  bool Init(const CNodePtr &kernel_node) override {
    InitResource();
    auto probs_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
    if (probs_shape.size() != 3) {
      MS_LOG(EXCEPTION) << "probs dims: " << probs_shape.size() << " not support.";
    }
    probs_dims_[0] = probs_shape[0];
    probs_dims_[1] = probs_shape[1];
    probs_dims_[2] = probs_shape[2];

    auto labels_dims = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
    if (labels_dims.size() != 1 && labels_dims.size() != 2) {
      MS_LOG(EXCEPTION) << "labels dims: " << labels_dims.size() << " not support.";
    }
    label_size_ = sizeof(int);
    for (auto i : labels_dims) {
      label_size_ *= i;
    }

    auto input_length_dims = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2);
    input_lengths_size_ = input_length_dims[0] * sizeof(int);
    auto label_length_dims = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 3);
    label_lengths_size_ = label_length_dims[0] * sizeof(int);
    CHECK_CUDNN_RET_WITH_EXCEPT(
      cudnnSetTensorNdDescriptorEx(probs_desc_, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 3, probs_dims_),
      "cudnnSetTensorNdDescriptorEx failed.");
    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetCTCLossDescriptorEx(ctcloss_desc_, CUDNN_DATA_FLOAT,
                                                            CUDNN_LOSS_NORMALIZATION_SOFTMAX, CUDNN_PROPAGATE_NAN),
                                "cudnnSetCTCLossDescriptorEx failed.");
    InitSizeLists();
    return true;
  }

 protected:
  void InitResource() override {
    cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&probs_desc_), "cudnnCreateTensorDescriptor failed.");
    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateCTCLossDescriptor(&ctcloss_desc_), "cudnnCreateCTCLossDescriptor failed.");
  }

  void InitSizeLists() override {
    input_size_list_.push_back(probs_dims_[0] * probs_dims_[1] * probs_dims_[2] * sizeof(float));
    input_size_list_.push_back(label_size_);
    input_size_list_.push_back(input_lengths_size_);
    input_size_list_.push_back(label_lengths_size_);

    output_size_list_.push_back(probs_dims_[1] * sizeof(float));
    output_size_list_.push_back(probs_dims_[0] * probs_dims_[1] * probs_dims_[2] * sizeof(float));
  }

 private:
  void DestroyResource() noexcept {
    CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyCTCLossDescriptor(ctcloss_desc_), "cudnnDestroyCTCLossDescriptor failed.");
    CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(probs_desc_), "cudnnDestroyTensorDescriptor failed.");
  }

  void AllocHostMem(int **labels_host, int **no_blank_labels_host, void **input_lengths_host, void **label_lengths_host,
                    const std::vector<AddressPtr> &inputs) {
    CHECK_CUDA_RET_WITH_EXCEPT(cudaMallocHost(labels_host, inputs[1]->size), "cudaMallocHost failed.");
    CHECK_CUDA_RET_WITH_EXCEPT(cudaMallocHost(no_blank_labels_host, inputs[1]->size), "cudaMallocHost failed.");
    CHECK_CUDA_RET_WITH_EXCEPT(cudaMallocHost(input_lengths_host, inputs[2]->size), "cudaMallocHost failed.");
    CHECK_CUDA_RET_WITH_EXCEPT(cudaMallocHost(label_lengths_host, inputs[3]->size), "cudaMallocHost failed.");
  }

  void FreeHostMem(int *labels_host, int *no_blank_labels_host, void *input_lengths_host, void *label_lengths_host) {
    CHECK_CUDA_RET_WITH_EXCEPT(cudaFreeHost(label_lengths_host), "cudaFreeHost failed.");
    CHECK_CUDA_RET_WITH_EXCEPT(cudaFreeHost(input_lengths_host), "cudaFreeHost failed.");
    CHECK_CUDA_RET_WITH_EXCEPT(cudaFreeHost(labels_host), "cudaFreeHost failed.");
    CHECK_CUDA_RET_WITH_EXCEPT(cudaFreeHost(no_blank_labels_host), "cudaFreeHost failed.");
  }

  void CopyToHostSync(int *labels_host, int *no_blank_labels_host, void *input_lengths_host, void *label_lengths_host,
                      const std::vector<AddressPtr> &inputs, cudaStream_t stream) {
    CHECK_CUDA_RET_WITH_EXCEPT(cudaStreamSynchronize(stream), "cudaStreamSynchronize failed.");
    CHECK_CUDA_RET_WITH_EXCEPT(
      cudaMemcpyAsync(labels_host, inputs[1]->addr, inputs[1]->size, cudaMemcpyDeviceToHost, stream),
      "cudaMemcpyAsync failed.");
    CHECK_CUDA_RET_WITH_EXCEPT(
      cudaMemcpyAsync(input_lengths_host, inputs[2]->addr, inputs[2]->size, cudaMemcpyDeviceToHost, stream),
      "cudaMemcpyAsync failed.");
    CHECK_CUDA_RET_WITH_EXCEPT(
      cudaMemcpyAsync(label_lengths_host, inputs[3]->addr, inputs[3]->size, cudaMemcpyDeviceToHost, stream),
      "cudaMemcpyAsync failed.");
    CHECK_CUDA_RET_WITH_EXCEPT(cudaStreamSynchronize(stream), "cudaStreamSynchronize failed.");

    // remove blank element
    size_t j = 0;
    for (size_t i = 0; i < inputs[1]->size / sizeof(int); i++) {
      if (labels_host[i] != 0) {
        no_blank_labels_host[j] = labels_host[i];
        j++;
      }
    }
  }

  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;

  cudnnHandle_t cudnn_handle_;
  cudnnTensorDescriptor_t probs_desc_;
  cudnnCTCLossDescriptor_t ctcloss_desc_;
  int probs_dims_[3] = {0};
  int label_size_;
  int input_lengths_size_;
  int label_lengths_size_;
};
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_CTCLOSS_GPU_KERNEL_H_
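
CtcLossV2GpuKernel defers the workspace query to Launch() because cuDNN sizes the CTC scratch buffer from the actual label content, which is also why the label tensors are staged into pinned host memory first (the "as cudnn7.x.x requires" comment above). A condensed, standalone sketch of the same call order; the (max_time, batch, num_classes) layout and helper names are assumptions for illustration, and raw cudaMalloc stands in for the framework's GPUMemoryAllocator.

#include <cuda_runtime_api.h>
#include <cudnn.h>
#include <cstdio>
#include <cstdlib>

#define CHECK_CUDNN(expr)                                    \
  do {                                                       \
    cudnnStatus_t s_ = (expr);                               \
    if (s_ != CUDNN_STATUS_SUCCESS) {                        \
      std::fprintf(stderr, "%s\n", cudnnGetErrorString(s_)); \
      std::abort();                                          \
    }                                                        \
  } while (0)

void RunCudnnCtcLoss(cudnnHandle_t handle, const float *probs_dev, float *costs_dev, float *grads_dev,
                     const int *labels_host, const int *label_lengths_host, const int *input_lengths_host,
                     int max_time, int batch, int num_classes) {
  cudnnTensorDescriptor_t probs_desc;
  cudnnCTCLossDescriptor_t ctc_desc;
  CHECK_CUDNN(cudnnCreateTensorDescriptor(&probs_desc));
  CHECK_CUDNN(cudnnCreateCTCLossDescriptor(&ctc_desc));

  // 3-D descriptor over (max_time, batch, num_classes), mirroring Init() above.
  int dims[3] = {max_time, batch, num_classes};
  CHECK_CUDNN(cudnnSetTensorNdDescriptorEx(probs_desc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 3, dims));
  CHECK_CUDNN(cudnnSetCTCLossDescriptorEx(ctc_desc, CUDNN_DATA_FLOAT, CUDNN_LOSS_NORMALIZATION_SOFTMAX,
                                          CUDNN_PROPAGATE_NAN));

  // The workspace depends on the labels themselves, so it is queried per
  // launch, after the label arrays have been copied to the host.
  size_t workspace_size = 0;
  CHECK_CUDNN(cudnnGetCTCLossWorkspaceSize(handle, probs_desc, probs_desc, labels_host, label_lengths_host,
                                           input_lengths_host, CUDNN_CTC_LOSS_ALGO_DETERMINISTIC, ctc_desc,
                                           &workspace_size));
  void *workspace = nullptr;
  if (cudaMalloc(&workspace, workspace_size) != cudaSuccess) std::abort();

  CHECK_CUDNN(cudnnCTCLoss(handle, probs_desc, probs_dev, labels_host, label_lengths_host, input_lengths_host,
                           costs_dev, probs_desc, grads_dev, CUDNN_CTC_LOSS_ALGO_DETERMINISTIC, ctc_desc, workspace,
                           workspace_size));

  cudaFree(workspace);
  cudnnDestroyCTCLossDescriptor(ctc_desc);
  cudnnDestroyTensorDescriptor(probs_desc);
}

Launch() additionally strips zero-valued (blank) labels before the call, apparently because the flat label buffer is padded and cuDNN reads only label_lengths real symbols per sequence.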