diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/square_sum_all_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/square_sum_all_impl.cu
index b8718db289..76cdc7204c 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/square_sum_all_impl.cu
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/square_sum_all_impl.cu
@@ -19,20 +19,58 @@
 
 template <typename T>
 __global__ void SquareSumAllKernel(const size_t size, const T* input_addr_0, const T* input_addr_1,
-                                   T* output_addr_0, T* output_addr_1) {
+                                   float* ws_addr_0, float* ws_addr_1) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
     size_t split = size / 2;
+    float power = 2.0;
     if (i < split) {
-      T ret = input_addr_0[i] * input_addr_0[i];
-      MsAtomicAdd(output_addr_0, ret);
+      float ret = powf(__half2float(input_addr_0[i]), power);
+      MsAtomicAdd(ws_addr_0, ret);
     } else {
-      T ret = input_addr_1[i - split] * input_addr_1[i - split];
-      MsAtomicAdd(output_addr_1, ret);
+      float ret = powf(__half2float(input_addr_1[i - split]), power);
+      MsAtomicAdd(ws_addr_1, ret);
     }
   }
   return;
 }
 
+template <>
+__global__ void SquareSumAllKernel(const size_t size, const float* input_addr_0, const float* input_addr_1,
+                                   float* ws_addr_0, float* ws_addr_1) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
+    size_t split = size / 2;
+    float power = 2.0;
+    if (i < split) {
+      float ret = powf(input_addr_0[i], power);
+      MsAtomicAdd(ws_addr_0, ret);
+    } else {
+      float ret = powf(input_addr_1[i - split], power);
+      MsAtomicAdd(ws_addr_1, ret);
+    }
+  }
+  return;
+}
+
+template <typename T>
+__global__ void AssignKernel(const size_t size, T* output_addr_0, T* output_addr_1,
+                             float* ws_addr_0, float* ws_addr_1) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
+    output_addr_0[0] = __float2half(ws_addr_0[0]);
+    output_addr_1[0] = __float2half(ws_addr_1[0]);
+  }
+  return;
+}
+
+template <>
+__global__ void AssignKernel(const size_t size, float* output_addr_0, float* output_addr_1,
+                             float* ws_addr_0, float* ws_addr_1) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
+    output_addr_0[0] = ws_addr_0[0];
+    output_addr_1[0] = ws_addr_1[0];
+  }
+  return;
+}
+
 template <typename T>
 __global__ void InitOutput(const size_t size, T *output) {
   T zero = 0;
@@ -44,15 +82,19 @@ __global__ void InitOutput(const size_t size, T *output)
 template <typename T>
 void SquareSumAll(const size_t input_size_, const T* input_addr_0, const T* input_addr_1,
-                  T* output_addr_0, T* output_addr_1, cudaStream_t cuda_stream) {
-  InitOutput<<<GET_BLOCKS(1), GET_THREADS, 0, cuda_stream>>>(1, output_addr_0);
-  InitOutput<<<GET_BLOCKS(1), GET_THREADS, 0, cuda_stream>>>(1, output_addr_1);
+                  T* output_addr_0, T* output_addr_1,
+                  float* ws_addr_0, float* ws_addr_1, cudaStream_t cuda_stream) {
+  InitOutput<<<GET_BLOCKS(1), GET_THREADS, 0, cuda_stream>>>(1, ws_addr_0);
+  InitOutput<<<GET_BLOCKS(1), GET_THREADS, 0, cuda_stream>>>(1, ws_addr_1);
   size_t size = input_size_ * 2;
   SquareSumAllKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input_addr_0, input_addr_1,
-                                                                        output_addr_0, output_addr_1);
+                                                                        ws_addr_0, ws_addr_1);
+  AssignKernel<<<GET_BLOCKS(1), GET_THREADS, 0, cuda_stream>>>(1, output_addr_0, output_addr_1, ws_addr_0, ws_addr_1);
 }
 
 template void SquareSumAll<half>(const size_t input_size_, const half* input_addr_0, const half* input_addr_1,
-                                 half* output_addr_0, half* output_addr_1, cudaStream_t cuda_stream);
+                                 half* output_addr_0, half* output_addr_1, float* ws_addr_0, float* ws_addr_1,
+                                 cudaStream_t cuda_stream);
 template void SquareSumAll<float>(const size_t input_size_, const float* input_addr_0, const float* input_addr_1,
-                                  float* output_addr_0, float* output_addr_1, cudaStream_t cuda_stream);
+                                  float* output_addr_0, float* output_addr_1, float* ws_addr_0, float* ws_addr_1,
+                                  cudaStream_t cuda_stream);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/square_sum_all_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/square_sum_all_impl.cuh
index 6182786d89..f950385888 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/square_sum_all_impl.cuh
+++
b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/square_sum_all_impl.cuh
@@ -20,6 +20,6 @@
 #include "runtime/device/gpu/cuda_common.h"
 template <typename T>
 void SquareSumAll(const size_t input_size_, const T* input_addr_0, const T* input_addr_1,
-                  T* output_addr_0, T* output_addr_1, cudaStream_t cuda_stream);
+                  T* output_addr_0, T* output_addr_1, float* ws_addr_0, float* ws_addr_1, cudaStream_t cuda_stream);
 
 #endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_SQUARE_SUM_ALL_IMPL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/square_sum_all_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/square_sum_all_gpu_kernel.h
index ef110e83b7..a095352b10 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/square_sum_all_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/square_sum_all_gpu_kernel.h
@@ -43,7 +43,9 @@ class SquareSumAllGpuFwdKernel : public GpuKernel {
     T *input_addr_1 = GetDeviceAddress<T>(inputs, 1);
     T *output_addr_0 = GetDeviceAddress<T>(outputs, 0);
     T *output_addr_1 = GetDeviceAddress<T>(outputs, 1);
-    SquareSumAll(input_size_, input_addr_0, input_addr_1, output_addr_0, output_addr_1,
+    float *ws_addr_0 = GetDeviceAddress<float>(workspace, 0);
+    float *ws_addr_1 = GetDeviceAddress<float>(workspace, 1);
+    SquareSumAll(input_size_, input_addr_0, input_addr_1, output_addr_0, output_addr_1, ws_addr_0, ws_addr_1,
                  reinterpret_cast<cudaStream_t>(stream_ptr));
 
     return true;
@@ -67,7 +69,8 @@ class SquareSumAllGpuFwdKernel : public GpuKernel {
     input_size_list_.push_back(input_size_ * sizeof(T));
     output_size_list_.push_back(sizeof(T));
     output_size_list_.push_back(sizeof(T));
-    workspace_size_list_.push_back(0);
+    workspace_size_list_.push_back(sizeof(float));
+    workspace_size_list_.push_back(sizeof(float));
   }
 
 private: