commit
60e86d80d8
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,51 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_IMPL_CUH
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_IMPL_CUH
|
||||
|
||||
template <typename T>
|
||||
void CalculateFwdVar(T *log_alpha_b, int *label_value_with_blank, T *softmax_probs, const int *sequence_length,
|
||||
bool ctc_merge_repeated, int batch, int SOffSet, int maxtime, int blank, int *label_squence_length,
|
||||
int *cum_labels_length, bool ignore_longer_outputs_than_inputs, cudaStream_t stream);
|
||||
|
||||
template <typename T>
|
||||
void CalculateBwdVar(T *log_beta_b, int *label_value_with_blank, T *softmax_probs, const int *sequence_length,
|
||||
bool ctc_merge_repeated, int batch, int SOffSet, int maxtime, int blank, int *label_squence_length,
|
||||
int *cum_labels_length, bool ignore_longer_outputs_than_inputs, cudaStream_t stream);
|
||||
|
||||
template <typename T>
|
||||
void InnerSoftMax(const T *probs, T *softmax_cost, const int *sequence_length, int max_time, int batch, int numclass,
|
||||
cudaStream_t stream);
|
||||
|
||||
void GenLabelValuePCR(int *label_value_sp, int *label_value_pcr, int *label_squence_length, int *cum_labels_length,
|
||||
int *max_labels_length, int batch, cudaStream_t stream);
|
||||
|
||||
void GenLabelWithBlank(int *label_value, int *label_value_with_blank, int *label_squence_length,
|
||||
int *precum_labels_length, int *cum_labels_length, int batch, int blank, cudaStream_t stream);
|
||||
|
||||
void GenLabelValue(int *label_value_sp, const int64_t *label_indices, const int *label_values,
|
||||
int *label_squence_length, int *cum_labels_length, int *max_labels_length, int size, int blank,
|
||||
int batch, cudaStream_t stream);
|
||||
|
||||
void CalculatePreLength(int *label_squence_length, int *precum_labels_length, int *cum_labels_length,
|
||||
int *max_labels_length, const int64_t *label_indices, int batch, int size, cudaStream_t stream);
|
||||
void CalculateMaxSequence(const int *sequence_length, int *max_labels_length, int batch, cudaStream_t stream);
|
||||
template <typename T>
|
||||
void CTCLoss(T *log_alpha_b, T *log_beta_b, T *softmax_probs, int *label_value_with_blank, int batch, int SOffSet,
|
||||
int maxtime, int numclass, const int *sequence_length, int *label_squence_length, int *cum_labels_length,
|
||||
T *cost, T *grads, T *prob_num, bool ignore_longer_outputs_than_inputs, cudaStream_t stream);
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_IMPL_CUH
|
@ -1,31 +1,31 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/gpu/nn/ctcloss_gpu_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
MS_REG_GPU_KERNEL_ONE(CTCLossV2,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
CtcLossGpuKernel, float)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/gpu/nn/ctcloss_gpu_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
MS_REG_GPU_KERNEL_ONE(CTCLoss,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
CtcLossGpuKernel, float)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,31 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/gpu/nn/ctclossv2_gpu_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
MS_REG_GPU_KERNEL_ONE(CTCLossV2,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
CtcLossV2GpuKernel, float)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
@ -0,0 +1,192 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_CTCLOSS_GPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_CTCLOSS_GPU_KERNEL_H_
|
||||
|
||||
#include <cuda_runtime_api.h>
|
||||
#include <vector>
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
|
||||
#include "runtime/device/gpu/gpu_memory_allocator.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T>
|
||||
class CtcLossV2GpuKernel : public GpuKernel {
|
||||
public:
|
||||
CtcLossV2GpuKernel()
|
||||
: cudnn_handle_(nullptr),
|
||||
probs_desc_(nullptr),
|
||||
ctcloss_desc_(nullptr),
|
||||
label_size_(0),
|
||||
input_lengths_size_(0),
|
||||
label_lengths_size_(0) {}
|
||||
~CtcLossV2GpuKernel() override { DestroyResource(); }
|
||||
|
||||
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
|
||||
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
|
||||
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||
float *probs = GetDeviceAddress<float>(inputs, 0);
|
||||
float *costs = GetDeviceAddress<float>(outputs, 0);
|
||||
float *grads = GetDeviceAddress<float>(outputs, 1);
|
||||
|
||||
// Copy labels/input_lengths/label_length to host as cudnn7.x.x requires
|
||||
int *labels_host = nullptr;
|
||||
int *no_blank_labels_host = nullptr;
|
||||
void *input_lengths_host = nullptr;
|
||||
void *label_lengths_host = nullptr;
|
||||
cudaStream_t stream = reinterpret_cast<cudaStream_t>(stream_ptr);
|
||||
AllocHostMem(&labels_host, &no_blank_labels_host, &input_lengths_host, &label_lengths_host, inputs);
|
||||
CopyToHostSync(labels_host, no_blank_labels_host, input_lengths_host, label_lengths_host, inputs, stream);
|
||||
|
||||
size_t workspace_size = 0;
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
cudnnGetCTCLossWorkspaceSize(
|
||||
cudnn_handle_, probs_desc_, probs_desc_, reinterpret_cast<int *>(no_blank_labels_host),
|
||||
reinterpret_cast<int *>(label_lengths_host), reinterpret_cast<int *>(input_lengths_host),
|
||||
CUDNN_CTC_LOSS_ALGO_DETERMINISTIC, ctcloss_desc_, &workspace_size),
|
||||
"cudnnGetCTCLossWorkspaceSize failed.");
|
||||
void *workspace = device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(workspace_size);
|
||||
if (workspace == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Failed to alloc workspace, size: " << workspace_size;
|
||||
}
|
||||
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
cudnnCTCLoss(cudnn_handle_, probs_desc_, probs, reinterpret_cast<int *>(no_blank_labels_host),
|
||||
reinterpret_cast<int *>(label_lengths_host), reinterpret_cast<int *>(input_lengths_host), costs,
|
||||
probs_desc_, grads, CUDNN_CTC_LOSS_ALGO_DETERMINISTIC, ctcloss_desc_, workspace, workspace_size),
|
||||
"cudnnCtcLoss failed.");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaStreamSynchronize(stream), "cudaStreamSynchronize failed.");
|
||||
|
||||
device::gpu::GPUMemoryAllocator::GetInstance().FreeTensorMem(workspace);
|
||||
FreeHostMem(labels_host, no_blank_labels_host, input_lengths_host, label_lengths_host);
|
||||
return true;
|
||||
}
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
InitResource();
|
||||
auto probs_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
if (probs_shape.size() != 3) {
|
||||
MS_LOG(EXCEPTION) << "probs dims: " << probs_shape.size() << " not support.";
|
||||
}
|
||||
probs_dims_[0] = probs_shape[0];
|
||||
probs_dims_[1] = probs_shape[1];
|
||||
probs_dims_[2] = probs_shape[2];
|
||||
|
||||
auto labels_dims = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
|
||||
if (labels_dims.size() != 1 && labels_dims.size() != 2) {
|
||||
MS_LOG(EXCEPTION) << "labels dims: " << labels_dims.size() << " not support.";
|
||||
}
|
||||
label_size_ = sizeof(int);
|
||||
for (auto i : labels_dims) {
|
||||
label_size_ *= i;
|
||||
}
|
||||
|
||||
auto input_length_dims = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2);
|
||||
input_lengths_size_ = input_length_dims[0] * sizeof(int);
|
||||
auto label_length_dims = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 3);
|
||||
label_lengths_size_ = label_length_dims[0] * sizeof(int);
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
cudnnSetTensorNdDescriptorEx(probs_desc_, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 3, probs_dims_),
|
||||
"cudnnSetTensorNdDescriptorEx failed.");
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetCTCLossDescriptorEx(ctcloss_desc_, CUDNN_DATA_FLOAT,
|
||||
CUDNN_LOSS_NORMALIZATION_SOFTMAX, CUDNN_PROPAGATE_NAN),
|
||||
"cudnnSetCTCLossDescriptorEx failed.");
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
|
||||
protected:
|
||||
void InitResource() override {
|
||||
cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&probs_desc_), "cudnnCreateTensorDescriptor failed.");
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateCTCLossDescriptor(&ctcloss_desc_), "cudnnCreateCTCLossDescriptor failed.");
|
||||
}
|
||||
|
||||
void InitSizeLists() override {
|
||||
input_size_list_.push_back(probs_dims_[0] * probs_dims_[1] * probs_dims_[2] * sizeof(float));
|
||||
input_size_list_.push_back(label_size_);
|
||||
input_size_list_.push_back(input_lengths_size_);
|
||||
input_size_list_.push_back(label_lengths_size_);
|
||||
|
||||
output_size_list_.push_back(probs_dims_[1] * sizeof(float));
|
||||
output_size_list_.push_back(probs_dims_[0] * probs_dims_[1] * probs_dims_[2] * sizeof(float));
|
||||
}
|
||||
|
||||
private:
|
||||
void DestroyResource() noexcept {
|
||||
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyCTCLossDescriptor(ctcloss_desc_), "cudnnDestroyCTCLossDescriptor failed.");
|
||||
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(probs_desc_), "cudnnDestroyTensorDescriptor failed.");
|
||||
}
|
||||
|
||||
void AllocHostMem(int **labels_host, int **no_blank_labels_host, void **input_lengths_host, void **label_lengths_host,
|
||||
const std::vector<AddressPtr> &inputs) {
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaMallocHost(labels_host, inputs[1]->size), "cudaMallocHost failed.");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaMallocHost(no_blank_labels_host, inputs[1]->size), "cudaMallocHost failed.");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaMallocHost(input_lengths_host, inputs[2]->size), "cudaMallocHost failed.");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaMallocHost(label_lengths_host, inputs[3]->size), "cudaMallocHost failed.");
|
||||
}
|
||||
|
||||
void FreeHostMem(int *labels_host, int *no_blank_labels_host, void *input_lengths_host, void *label_lengths_host) {
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaFreeHost(label_lengths_host), "cudaFreeHost failed.");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaFreeHost(input_lengths_host), "cudaFreeHost failed.");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaFreeHost(labels_host), "cudaFreeHost failed.");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaFreeHost(no_blank_labels_host), "cudaFreeHost failed.");
|
||||
}
|
||||
|
||||
void CopyToHostSync(int *labels_host, int *no_blank_labels_host, void *input_lengths_host, void *label_lengths_host,
|
||||
const std::vector<AddressPtr> &inputs, cudaStream_t stream) {
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaStreamSynchronize(stream), "cudaStreamSynchronize failed.");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(
|
||||
cudaMemcpyAsync(labels_host, inputs[1]->addr, inputs[1]->size, cudaMemcpyDeviceToHost, stream),
|
||||
"cudaMemcpyAsync failed.");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(
|
||||
cudaMemcpyAsync(input_lengths_host, inputs[2]->addr, inputs[2]->size, cudaMemcpyDeviceToHost, stream),
|
||||
"cudaMemcpyAsync failed.");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(
|
||||
cudaMemcpyAsync(label_lengths_host, inputs[3]->addr, inputs[3]->size, cudaMemcpyDeviceToHost, stream),
|
||||
"cudaMemcpyAsync failed.");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaStreamSynchronize(stream), "cudaStreamSynchronize failed.");
|
||||
|
||||
// remove blank element
|
||||
size_t j = 0;
|
||||
for (size_t i = 0; i < inputs[1]->size / sizeof(int); i++) {
|
||||
if (labels_host[i] != 0) {
|
||||
no_blank_labels_host[j] = labels_host[i];
|
||||
j++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<size_t> input_size_list_;
|
||||
std::vector<size_t> output_size_list_;
|
||||
std::vector<size_t> workspace_size_list_;
|
||||
|
||||
cudnnHandle_t cudnn_handle_;
|
||||
cudnnTensorDescriptor_t probs_desc_;
|
||||
cudnnCTCLossDescriptor_t ctcloss_desc_;
|
||||
int probs_dims_[3] = {0};
|
||||
int label_size_;
|
||||
int input_lengths_size_;
|
||||
int label_lengths_size_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_CTCLOSS_GPU_KERNEL_H_
|
Loading…
Reference in new issue