diff --git a/mindspore/ccsrc/kernel/gpu/arrays/argmaxwithvalue_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/arrays/argmaxwithvalue_gpu_kernel.h index fb7796b022..304f0ab161 100644 --- a/mindspore/ccsrc/kernel/gpu/arrays/argmaxwithvalue_gpu_kernel.h +++ b/mindspore/ccsrc/kernel/gpu/arrays/argmaxwithvalue_gpu_kernel.h @@ -38,7 +38,7 @@ class ArgmaxWithValueGpuKernel : public GpuKernel { T *input = GetDeviceAddress(inputs, 0); T *output = GetDeviceAddress(outputs, 1); S *index = GetDeviceAddress(outputs, 0); - CalArgmaxWithValue(input_size_ / sizeof(T), input, bound_, outerSize_, innerSize_, index, output, + CalArgmaxWithValue(input, bound_, outerSize_, innerSize_, index, output, reinterpret_cast(stream_ptr)); return true; } diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/argmaxwithvalue_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/argmaxwithvalue_impl.cu index a0687a2768..3313fc6853 100644 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/argmaxwithvalue_impl.cu +++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/argmaxwithvalue_impl.cu @@ -18,41 +18,39 @@ #include "device/gpu/cuda_common.h" #include "include/cuda_fp16.h" template -__global__ void ArgmaxWithValue(size_t size, const T* input, const int bound, int outerSize, int innerSize, - S* index, T* output) { - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { - for (int i = 0; i < outerSize; i++) { - int inputOutterOffset = i * innerSize * bound; - int outputOutterOffset = i * innerSize; - for (int j = 0; j < innerSize; j++) { - auto outputInnerOffset = outputOutterOffset + j; - S idx = 0; - T maxData = input[j + inputOutterOffset]; - for (S c = 0; c < bound; c++) { - int offset = j + c * innerSize; - auto inputData = input[inputOutterOffset + offset]; - idx = inputData > maxData ? c : idx; - maxData = inputData > maxData ? inputData : maxData; - } - output[outputInnerOffset] = maxData; - index[outputInnerOffset] = idx; +__global__ void ArgmaxWithValue(const T* input, const int bound, int outerSize, int innerSize, S* index, + T* output) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (outerSize); pos += blockDim.x * gridDim.x) { + int inputOutterOffset = pos * innerSize * bound; + int outputOutterOffset = pos * innerSize; + for (int j = 0; j < innerSize; j++) { + auto outputInnerOffset = outputOutterOffset + j; + S idx = 0; + T maxData = input[j + inputOutterOffset]; + for (S c = 0; c < bound; c++) { + int offset = j + c * innerSize; + auto inputData = input[inputOutterOffset + offset]; + idx = inputData > maxData ? c : idx; + maxData = inputData > maxData ? inputData : maxData; + } + output[outputInnerOffset] = maxData; + index[outputInnerOffset] = idx; } - } } return; } template -void CalArgmaxWithValue(size_t size, const T* input, const int bound_, const int outerSize_, const int innerSize_, +void CalArgmaxWithValue(const T* input, const int bound_, const int outerSize_, const int innerSize_, S* index, T* output, cudaStream_t cuda_stream) { - ArgmaxWithValue<<>>(size, input, bound_, outerSize_, innerSize_, - index, output); + ArgmaxWithValue<<>>(input, bound_, outerSize_, innerSize_, + index, output); return; } -template void CalArgmaxWithValue(size_t size, const float* input, const int bound_, const int outerSize_, +template void CalArgmaxWithValue(const float* input, const int bound_, const int outerSize_, const int innerSize_, int* index, float* output, cudaStream_t cuda_stream); -template void CalArgmaxWithValue(size_t size, const half* input, const int bound_, const int outerSize_, +template void CalArgmaxWithValue(const half* input, const int bound_, const int outerSize_, const int innerSize_, int* index, half* output, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/argmaxwithvalue_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/argmaxwithvalue_impl.cuh index 0d4f4b62a3..67c061a966 100644 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/argmaxwithvalue_impl.cuh +++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/argmaxwithvalue_impl.cuh @@ -17,6 +17,6 @@ #ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ARGMAXWITHVALUE_H_ #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ARGMAXWITHVALUE_H_ template -void CalArgmaxWithValue(size_t size, const T* input, const int bound_, const int outerSize_, const int innerSize_, - S* index, T* output, cudaStream_t cuda_stream); +void CalArgmaxWithValue(const T *input, const int bound_, const int outerSize_, const int innerSize_, S *index, + T *output, cudaStream_t cuda_stream); #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ARGMAXWITHVALUE_H_ diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/unary_op_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/unary_op_impl.cu index 4a8af83aa4..09b347e3d5 100755 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/unary_op_impl.cu +++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/unary_op_impl.cu @@ -36,6 +36,13 @@ __global__ void LogarithmKernel(T *input, T *output, size_t count) { } return; } +template <> +__global__ void LogarithmKernel(half *input, half *output, size_t count) { + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) { + output[i] = hlog(input[i]); + } + return; +} template __global__ void NegativeKernel(T *input, T *output, size_t count) { T neg_one = -1;