@@ -44,9 +44,6 @@ class CtcLossGpuKernel : public GpuKernel {
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
    float *probs = GetDeviceAddress<float>(inputs, 0);
    int *labels = GetDeviceAddress<int>(inputs, 1);
    int *input_lengths = GetDeviceAddress<int>(inputs, 2);
    int *label_lengths = GetDeviceAddress<int>(inputs, 3);
    float *costs = GetDeviceAddress<float>(outputs, 0);
    float *grads = GetDeviceAddress<float>(outputs, 1);
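For readers unfamiliar with the kernel API: GetDeviceAddress is the GpuKernel helper that turns the framework-supplied AddressPtr slots into typed device pointers. A minimal sketch of the idea (the real helper also validates the index and the recorded size) looks like this, assuming AddressPtr exposes a raw addr pointer and a byte size:

  template <typename T>
  T *GetDeviceAddress(const std::vector<AddressPtr> &addr_list, size_t index) {
    // Each AddressPtr wraps a device buffer assigned by the runtime; reinterpret it as the requested type.
    return reinterpret_cast<T *>(addr_list[index]->addr);
  }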
@@ -55,29 +52,9 @@ class CtcLossGpuKernel : public GpuKernel {
    int *no_blank_labels_host = nullptr;
    void *input_lengths_host = nullptr;
    void *label_lengths_host = nullptr;
    CHECK_CUDA_RET_WITH_EXCEPT(cudaMallocHost(&labels_host, inputs[1]->size), "cudaMallocHost failed.");
    CHECK_CUDA_RET_WITH_EXCEPT(cudaMallocHost(&no_blank_labels_host, inputs[1]->size), "cudaMallocHost failed.");
    CHECK_CUDA_RET_WITH_EXCEPT(cudaMallocHost(&input_lengths_host, inputs[2]->size), "cudaMallocHost failed.");
    CHECK_CUDA_RET_WITH_EXCEPT(cudaMallocHost(&label_lengths_host, inputs[3]->size), "cudaMallocHost failed.");
    cudaStream_t stream = reinterpret_cast<cudaStream_t>(stream_ptr);
    CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(labels_host, labels, inputs[1]->size, cudaMemcpyDeviceToHost, stream),
                               "cudaMemcpyAsync failed.");
    CHECK_CUDA_RET_WITH_EXCEPT(
      cudaMemcpyAsync(input_lengths_host, input_lengths, inputs[2]->size, cudaMemcpyDeviceToHost, stream),
      "cudaMemcpyAsync failed.");
    CHECK_CUDA_RET_WITH_EXCEPT(
      cudaMemcpyAsync(label_lengths_host, label_lengths, inputs[3]->size, cudaMemcpyDeviceToHost, stream),
      "cudaMemcpyAsync failed.");

    CHECK_CUDA_RET_WITH_EXCEPT(cudaStreamSynchronize(stream), "cudaStreamSynchronize failed.");

    size_t j = 0;
    for (size_t i = 0; i < inputs[1]->size / sizeof(int); i++) {
      if (labels_host[i] != 0) {
        no_blank_labels_host[j] = labels_host[i];
        j++;
      }
    }
    AllocHostMem(&labels_host, &no_blank_labels_host, &input_lengths_host, &label_lengths_host, inputs);
    CopyToHostSync(labels_host, no_blank_labels_host, input_lengths_host, label_lengths_host, inputs, stream);

    size_t workspace_size = 0;
    CHECK_CUDNN_RET_WITH_EXCEPT(
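The hunk cuts off inside the first cuDNN check; the cuDNN calls themselves are unchanged by this change. For orientation only, a typical cuDNN CTC sequence at this point looks roughly like the sketch below. It is not the verbatim file; the handle member name (cudnn_handle_), the reuse of probs_desc_ as the gradient descriptor, and the AllocTensorMem call are assumptions:

    // Illustrative sketch, not the original code.
    CHECK_CUDNN_RET_WITH_EXCEPT(
      cudnnGetCTCLossWorkspaceSize(cudnn_handle_, probs_desc_, probs_desc_, no_blank_labels_host,
                                   reinterpret_cast<int *>(label_lengths_host),
                                   reinterpret_cast<int *>(input_lengths_host),
                                   CUDNN_CTC_LOSS_ALGO_DETERMINISTIC, ctcloss_desc_, &workspace_size),
      "cudnnGetCTCLossWorkspaceSize failed.");
    void *workspace = device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(workspace_size);
    CHECK_CUDNN_RET_WITH_EXCEPT(
      cudnnCTCLoss(cudnn_handle_, probs_desc_, probs, no_blank_labels_host,
                   reinterpret_cast<int *>(label_lengths_host), reinterpret_cast<int *>(input_lengths_host), costs,
                   probs_desc_, grads, CUDNN_CTC_LOSS_ALGO_DETERMINISTIC, ctcloss_desc_, workspace, workspace_size),
      "cudnnCTCLoss failed.");

cuDNN's CTC API expects the labels and the two length arrays in host memory and reserves label value 0 for the blank symbol, which is why the kernel copies them off the device and strips the zeros before handing them to cuDNN.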
@@ -99,10 +76,7 @@ class CtcLossGpuKernel : public GpuKernel {
    CHECK_CUDA_RET_WITH_EXCEPT(cudaStreamSynchronize(stream), "cudaStreamSynchronize failed.");

    device::gpu::GPUMemoryAllocator::GetInstance().FreeTensorMem(workspace);
    CHECK_CUDA_RET_WITH_EXCEPT(cudaFreeHost(label_lengths_host), "cudaFreeHost failed.");
    CHECK_CUDA_RET_WITH_EXCEPT(cudaFreeHost(input_lengths_host), "cudaFreeHost failed.");
    CHECK_CUDA_RET_WITH_EXCEPT(cudaFreeHost(labels_host), "cudaFreeHost failed.");
    CHECK_CUDA_RET_WITH_EXCEPT(cudaFreeHost(no_blank_labels_host), "cudaFreeHost failed.");
    FreeHostMem(labels_host, no_blank_labels_host, input_lengths_host, label_lengths_host);
    return true;
  }
  bool Init(const CNodePtr &kernel_node) override {
@@ -160,6 +134,46 @@ class CtcLossGpuKernel : public GpuKernel {
    CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyCTCLossDescriptor(ctcloss_desc_), "cudnnDestroyCTCLossDescriptor failed.");
    CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(probs_desc_), "cudnnDestroyTensorDescriptor failed.");
  }

  void AllocHostMem(int **labels_host, int **no_blank_labels_host, void **input_lengths_host, void **label_lengths_host,
                    const std::vector<AddressPtr> &inputs) {
    CHECK_CUDA_RET_WITH_EXCEPT(cudaMallocHost(labels_host, inputs[1]->size), "cudaMallocHost failed.");
    CHECK_CUDA_RET_WITH_EXCEPT(cudaMallocHost(no_blank_labels_host, inputs[1]->size), "cudaMallocHost failed.");
    CHECK_CUDA_RET_WITH_EXCEPT(cudaMallocHost(input_lengths_host, inputs[2]->size), "cudaMallocHost failed.");
    CHECK_CUDA_RET_WITH_EXCEPT(cudaMallocHost(label_lengths_host, inputs[3]->size), "cudaMallocHost failed.");
  }

  void FreeHostMem(int *labels_host, int *no_blank_labels_host, void *input_lengths_host, void *label_lengths_host) {
    CHECK_CUDA_RET_WITH_EXCEPT(cudaFreeHost(label_lengths_host), "cudaFreeHost failed.");
    CHECK_CUDA_RET_WITH_EXCEPT(cudaFreeHost(input_lengths_host), "cudaFreeHost failed.");
    CHECK_CUDA_RET_WITH_EXCEPT(cudaFreeHost(labels_host), "cudaFreeHost failed.");
    CHECK_CUDA_RET_WITH_EXCEPT(cudaFreeHost(no_blank_labels_host), "cudaFreeHost failed.");
  }

  void CopyToHostSync(int *labels_host, int *no_blank_labels_host, void *input_lengths_host, void *label_lengths_host,
                      const std::vector<AddressPtr> &inputs, cudaStream_t stream) {
    CHECK_CUDA_RET_WITH_EXCEPT(cudaStreamSynchronize(stream), "cudaStreamSynchronize failed.");
    CHECK_CUDA_RET_WITH_EXCEPT(
      cudaMemcpyAsync(labels_host, inputs[1]->addr, inputs[1]->size, cudaMemcpyDeviceToHost, stream),
      "cudaMemcpyAsync failed.");
    CHECK_CUDA_RET_WITH_EXCEPT(
      cudaMemcpyAsync(input_lengths_host, inputs[2]->addr, inputs[2]->size, cudaMemcpyDeviceToHost, stream),
      "cudaMemcpyAsync failed.");
    CHECK_CUDA_RET_WITH_EXCEPT(
      cudaMemcpyAsync(label_lengths_host, inputs[3]->addr, inputs[3]->size, cudaMemcpyDeviceToHost, stream),
      "cudaMemcpyAsync failed.");
    CHECK_CUDA_RET_WITH_EXCEPT(cudaStreamSynchronize(stream), "cudaStreamSynchronize failed.");

    // remove blank element
    size_t j = 0;
    for (size_t i = 0; i < inputs[1]->size / sizeof(int); i++) {
      if (labels_host[i] != 0) {
        no_blank_labels_host[j] = labels_host[i];
        j++;
      }
    }
  }
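A note on the copy path above: the buffers come from cudaMallocHost in AllocHostMem, i.e. page-locked (pinned) host memory, which is what allows the cudaMemcpyAsync calls to be queued on the kernel's stream; the second cudaStreamSynchronize then guarantees labels_host is fully populated before the blank-stripping loop reads it on the CPU.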
  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;
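Taken together, the hunks leave Launch with roughly the following shape (condensed from the fragments above, not the verbatim file; the labels_host declaration sits just outside the shown hunk, and the cuDNN portion is unchanged and only indicated by a comment):

    int *labels_host = nullptr;
    int *no_blank_labels_host = nullptr;
    void *input_lengths_host = nullptr;
    void *label_lengths_host = nullptr;
    cudaStream_t stream = reinterpret_cast<cudaStream_t>(stream_ptr);
    AllocHostMem(&labels_host, &no_blank_labels_host, &input_lengths_host, &label_lengths_host, inputs);
    CopyToHostSync(labels_host, no_blank_labels_host, input_lengths_host, label_lengths_host, inputs, stream);
    // ... cuDNN workspace query and cudnnCTCLoss run here, unchanged by this diff ...
    FreeHostMem(labels_host, no_blank_labels_host, input_lengths_host, label_lengths_host);
    return true;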