From 37ac7ebcf6e55e7f6ff045eee459a00c28f9e5d1 Mon Sep 17 00:00:00 2001
From: zhouyuanshen
Date: Wed, 9 Dec 2020 15:51:16 +0800
Subject: [PATCH] fix bug of op l2loss on gpu

---
 .../backend/kernel_compiler/gpu/cuda_impl/l2_loss.cu | 9 ++++++++-
 mindspore/ops/operations/math_ops.py                 | 4 ++--
 2 files changed, 10 insertions(+), 3 deletions(-)

diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/l2_loss.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/l2_loss.cu
index 41103cc92b..b80db775ba 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/l2_loss.cu
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/l2_loss.cu
@@ -22,15 +22,22 @@ template <typename T>
 __global__ void L2LossKernel(const size_t input_size, const T *input , T *output) {
   T ret = 0;
   for (size_t id = blockIdx.x * blockDim.x + threadIdx.x; id < input_size; id += blockDim.x * gridDim.x) {
-    ret = (input[id] * input[id]);
+    ret = input[id] * input[id];
     ret /= static_cast<T>(2);
     MsAtomicAdd(output, ret);
   }
   return;
 }
 
+template <typename T>
+__global__ void ClearOutputMem(T *output) {
+  output[0] = static_cast<T>(0);
+  return;
+}
+
 template <typename T>
 void L2Loss(const size_t input_size, const T *input , T *output, cudaStream_t stream) {
+  ClearOutputMem<<<1, 1, 0, stream>>>(output);
   L2LossKernel<<<GET_BLOCKS(input_size), GET_THREADS, 0, stream>>>(input_size, input, output);
 }
 
diff --git a/mindspore/ops/operations/math_ops.py b/mindspore/ops/operations/math_ops.py
index cf40af3e26..3a8c29e520 100644
--- a/mindspore/ops/operations/math_ops.py
+++ b/mindspore/ops/operations/math_ops.py
@@ -445,7 +445,7 @@ class ReduceAll(_Reduce):
         the shape of output is :math:`(x_1, x_4, ..., x_R)`.
 
     Supported Platforms:
-        ``Ascend``
+        ``Ascend`` ``GPU``
 
     Examples:
         >>> input_x = Tensor(np.array([[True, False], [True, True]]))
@@ -487,7 +487,7 @@ class ReduceAny(_Reduce):
         the shape of output is :math:`(x_1, x_4, ..., x_R)`.
 
     Supported Platforms:
-        ``Ascend``
+        ``Ascend`` ``GPU``
 
     Examples:
         >>> input_x = Tensor(np.array([[True, False], [True, True]]))
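
Note on the fix: L2LossKernel accumulates into output[0] with MsAtomicAdd, but the
output buffer was never zeroed first. Device memory from cudaMalloc is uninitialized,
and a buffer reused across launches still holds the previous sum, so repeated calls
returned accumulated garbage. The patch adds a single-thread ClearOutputMem kernel on
the same stream so the zero-write is ordered before the reduction. The standalone CUDA
sketch below reproduces the failure mode outside MindSpore; the kernel names, sizes,
and launch configuration here are illustrative, not the ones from the patch.

// Minimal repro sketch (hypothetical names, not MindSpore code): an atomic-add
// reduction into a reused, uncleared device buffer returns a stale sum on the
// second call unless the buffer is zeroed first.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void ClearOutput(float *output) {
  output[0] = 0.0f;  // same idea as ClearOutputMem in the patch
}

// Mirrors L2LossKernel: every thread atomically adds x*x/2 into output[0].
__global__ void SquareHalfSum(const size_t n, const float *input, float *output) {
  for (size_t id = blockIdx.x * blockDim.x + threadIdx.x; id < n; id += blockDim.x * gridDim.x) {
    atomicAdd(output, input[id] * input[id] / 2.0f);
  }
}

int main() {
  const size_t n = 4;
  const float h_in[n] = {1.0f, 2.0f, 3.0f, 4.0f};  // expected L2 loss: (1+4+9+16)/2 = 15
  float *d_in = nullptr, *d_out = nullptr, h_out = 0.0f;
  cudaMalloc(&d_in, n * sizeof(float));
  cudaMalloc(&d_out, sizeof(float));  // uninitialized: holds garbage until cleared
  cudaMemcpy(d_in, h_in, n * sizeof(float), cudaMemcpyHostToDevice);

  for (int call = 0; call < 2; ++call) {
    ClearOutput<<<1, 1>>>(d_out);   // drop this line and the second call prints 30, not 15
    SquareHalfSum<<<1, 128>>>(n, d_in, d_out);
    cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
    printf("call %d: l2 loss = %f\n", call, h_out);  // 15.000000 both times
  }

  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}

Clearing on the same stream, as the patch does, keeps the zero-write ordered before the
reduction without a host-side synchronization; a cudaMemsetAsync(output, 0, sizeof(T),
stream) would give the same ordering with one fewer kernel launch.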