From 04f4be48184f4628ec452eab37ffde2d010cb187 Mon Sep 17 00:00:00 2001 From: baihuawei Date: Wed, 19 Aug 2020 20:12:50 +0800 Subject: [PATCH] fix gpu loss grad --- .../gpu/cuda_impl/loss_with_reduction_impl.cu | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/loss_with_reduction_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/loss_with_reduction_impl.cu index 9a2c560bc7..edf1929261 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/loss_with_reduction_impl.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/loss_with_reduction_impl.cu @@ -97,10 +97,14 @@ __global__ void KLDivLossGradKernel(const int input_size, const int reduction, c dy[i] = (logf(denominator) + 1 - input_x[i]) * dloss[i]; } } else { + T dloss1 = dloss[0]; + if (reduction == 1) { + dloss1 = dloss[0] / input_size; + } for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) { T denominator = max(input_y[i], epsilon); - dx[i] = -input_y[i] * dloss[0]; - dy[i] = (logf(denominator) + 1 - input_x[i]) * dloss[0]; + dx[i] = -input_y[i] * dloss1; + dy[i] = (logf(denominator) + 1 - input_x[i]) * dloss1; } } } @@ -169,10 +173,14 @@ __global__ void BinaryCrossEntropyLossGradKernel(const int input_size, const int dx[i] = value * dloss[i]; } } else { + T dloss1 = dloss[0]; + if (reduction == 1) { + dloss1 = dloss[0] / input_size; + } for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) { T denominator = max(input_x[i] * (1 - input_x[i]), epsilon); T value = weight[i] * (input_x[i] - input_y[i]) / denominator; - dx[i] = value * dloss[0]; + dx[i] = value * dloss1; } } }