You can not select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							101 lines
						
					
					
						
							3.6 KiB
						
					
					
				
			
		
		
	
	
							101 lines
						
					
					
						
							3.6 KiB
						
					
					
				/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
 | 
						|
 | 
						|
Licensed under the Apache License, Version 2.0 (the "License");
 | 
						|
you may not use this file except in compliance with the License.
 | 
						|
You may obtain a copy of the License at
 | 
						|
 | 
						|
    http://www.apache.org/licenses/LICENSE-2.0
 | 
						|
 | 
						|
Unless required by applicable law or agreed to in writing, software
 | 
						|
distributed under the License is distributed on an "AS IS" BASIS,
 | 
						|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
						|
See the License for the specific language governing permissions and
 | 
						|
limitations under the License. */
 | 
						|
 | 
						|
#include <stdio.h>
 | 
						|
#include <thrust/device_vector.h>
 | 
						|
#include <thrust/host_vector.h>
 | 
						|
#include <vector>
 | 
						|
#include "paddle/fluid/operators/ctc_align_op.h"
 | 
						|
 | 
						|
namespace paddle {
 | 
						|
namespace operators {
 | 
						|
 | 
						|
template <typename T>
 | 
						|
__global__ void MergeAndDelCudaKernel(const int64_t num_token, const T* tokens,
 | 
						|
                                      const size_t num_seq, size_t* lod0,
 | 
						|
                                      const int blank, const int merge_repeated,
 | 
						|
                                      size_t* out_lod0, T* output) {
 | 
						|
  int ouput_idx = 0;
 | 
						|
  out_lod0[0] = 0;
 | 
						|
 | 
						|
  for (int i = 0; i < num_seq; ++i) {
 | 
						|
    T pre_token = -1;
 | 
						|
    for (int j = lod0[i]; j < lod0[i + 1]; ++j) {
 | 
						|
      if (tokens[j] != blank && !(merge_repeated && tokens[j] == pre_token)) {
 | 
						|
        output[ouput_idx] = tokens[j];
 | 
						|
        ++ouput_idx;
 | 
						|
      }
 | 
						|
      pre_token = tokens[j];
 | 
						|
    }
 | 
						|
    out_lod0[i + 1] = ouput_idx;
 | 
						|
  }
 | 
						|
}
 | 
						|
 | 
						|
template <typename T>
 | 
						|
class CTCAlignOpCUDAKernel : public framework::OpKernel<T> {
 | 
						|
 public:
 | 
						|
  void Compute(const framework::ExecutionContext& ctx) const override {
 | 
						|
    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
 | 
						|
                   "It must use CUDAPlace.");
 | 
						|
    const size_t level = 0;
 | 
						|
    auto* input = ctx.Input<LoDTensor>("Input");
 | 
						|
    auto* output = ctx.Output<LoDTensor>("Output");
 | 
						|
    auto input_lod = framework::ToAbsOffset(input->lod());
 | 
						|
 | 
						|
    const T* tokens = input->data<T>();
 | 
						|
    const int64_t num_tokens = input->dims()[0];
 | 
						|
    const size_t num_seq = input_lod[level].size() - 1;
 | 
						|
 | 
						|
    const int blank = ctx.Attr<int>("blank");
 | 
						|
    const int merge_repeated =
 | 
						|
        static_cast<int>(ctx.Attr<bool>("merge_repeated"));
 | 
						|
 | 
						|
    // prepare a lod to record lod information while merging elements
 | 
						|
    thrust::device_vector<size_t> dev_out_lod0(input_lod[level].size());
 | 
						|
    size_t* dev_out_lod0_ptr = thrust::raw_pointer_cast(dev_out_lod0.data());
 | 
						|
 | 
						|
    // merge elements and delete blank
 | 
						|
    T* output_data = output->mutable_data<T>({num_tokens, 1}, ctx.GetPlace());
 | 
						|
 | 
						|
    auto stream = ctx.cuda_device_context().stream();
 | 
						|
    MergeAndDelCudaKernel<T><<<1, 1, 0, stream>>>(
 | 
						|
        num_tokens, tokens, num_seq,
 | 
						|
        input_lod[level].CUDAMutableData(ctx.GetPlace()), blank, merge_repeated,
 | 
						|
        dev_out_lod0_ptr, output_data);
 | 
						|
 | 
						|
    // set output lod
 | 
						|
    std::vector<size_t> host_out_lod0(dev_out_lod0.begin(), dev_out_lod0.end());
 | 
						|
    framework::LoD out_lod;
 | 
						|
    out_lod.push_back(host_out_lod0);
 | 
						|
    output->set_lod(out_lod);
 | 
						|
 | 
						|
    // resize output dims
 | 
						|
    output->Resize({static_cast<int64_t>(host_out_lod0.back()), 1});
 | 
						|
 | 
						|
    if (host_out_lod0.back() == 0) {
 | 
						|
      output->Resize({1, 1});
 | 
						|
      output->mutable_data<T>(ctx.GetPlace());
 | 
						|
      math::SetConstant<platform::CUDADeviceContext, T> set_constant;
 | 
						|
      set_constant(ctx.template device_context<platform::CUDADeviceContext>(),
 | 
						|
                   output, -1);
 | 
						|
    }
 | 
						|
  }
 | 
						|
};
 | 
						|
 | 
						|
}  // namespace operators
 | 
						|
}  // namespace paddle
 | 
						|
 | 
						|
REGISTER_OP_CUDA_KERNEL(ctc_align, paddle::operators::CTCAlignOpCUDAKernel<int>,
 | 
						|
                        paddle::operators::CTCAlignOpCUDAKernel<int64_t>);
 |