diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_grad_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_grad_impl.cu index e83dbff060..68cc73fabd 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_grad_impl.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_grad_impl.cu @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,10 +21,10 @@ __global__ void SigmoidCrossEntropyWithLogitsGradKernel(const size_t size, const const T *dout_addr, T *outputs) { for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) { if (logits[i] >= 0) { - outputs[i] = (1. / (1. + exp(-logits[i])) - labels[i]) * dout_addr[i]; + outputs[i] = (static_cast(1.) / (static_cast(1.) + exp(-logits[i])) - labels[i]) * dout_addr[i]; } else { const T exp_val = exp(logits[i]); - outputs[i] = (exp_val / (1. + exp_val) - labels[i]) * dout_addr[i]; + outputs[i] = (exp_val / (static_cast(1.) + exp_val) - labels[i]) * dout_addr[i]; } } } @@ -39,3 +39,6 @@ void SigmoidCrossEntropyWithLogitsGrad(const size_t size, const T *logits, const template void SigmoidCrossEntropyWithLogitsGrad(const size_t size, const float *logits, const float *labels, const float *dout_addr, float *outputs, cudaStream_t cuda_stream); +template void SigmoidCrossEntropyWithLogitsGrad(const size_t size, const double *logits, + const double *labels, const double *dout_addr, + double *outputs, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_impl.cu index 7425ac3809..823267f240 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_impl.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_impl.cu @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,7 +20,8 @@ template __global__ void SigmoidCrossEntropyWithLogitsKernel(const size_t size, const T *logits, const S *labels, T *outputs) { for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) { const T reverse_factor = static_cast(logits[i] >= 0); - outputs[i] = log1p(exp(logits[i] - 2 * reverse_factor * logits[i])) - logits[i] * (labels[i] - reverse_factor); + outputs[i] = + log1p(exp(logits[i] - static_cast(2) * reverse_factor * logits[i])) - logits[i] * (labels[i] - reverse_factor); } } @@ -32,3 +33,6 @@ void SigmoidCrossEntropyWithLogits(const size_t size, const T *logits, const S * template void SigmoidCrossEntropyWithLogits(const size_t size, const float *logits, const float *labels, float *outputs, cudaStream_t cuda_stream); +template void SigmoidCrossEntropyWithLogits(const size_t size, const double *logits, + const double *labels, double *outputs, + cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sigmoid_cross_entropy_with_logits_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sigmoid_cross_entropy_with_logits_gpu_kernel.cc index 96d2d29549..f331a162ce 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sigmoid_cross_entropy_with_logits_gpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sigmoid_cross_entropy_with_logits_gpu_kernel.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,5 +22,9 @@ MS_REG_GPU_KERNEL_TWO( SigmoidCrossEntropyWithLogits, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), SigmoidCrossEntropyWithLogitsGpuKernel, float, float) +MS_REG_GPU_KERNEL_TWO( + SigmoidCrossEntropyWithLogits, + KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), + SigmoidCrossEntropyWithLogitsGpuKernel, double, double) } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sigmoid_cross_entropy_with_logits_grad_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sigmoid_cross_entropy_with_logits_grad_gpu_kernel.cc index 05c9a4234b..deee2afc80 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sigmoid_cross_entropy_with_logits_grad_gpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sigmoid_cross_entropy_with_logits_grad_gpu_kernel.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -25,5 +25,12 @@ MS_REG_GPU_KERNEL_TWO(SigmoidCrossEntropyWithLogitsGrad, .AddInputAttr(kNumberTypeFloat32) .AddOutputAttr(kNumberTypeFloat32), SigmoidCrossEntropyWithLogitsGradGpuKernel, float, float) +MS_REG_GPU_KERNEL_TWO(SigmoidCrossEntropyWithLogitsGrad, + KernelAttr() + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeFloat64) + .AddOutputAttr(kNumberTypeFloat64), + SigmoidCrossEntropyWithLogitsGradGpuKernel, double, double) } // namespace kernel } // namespace mindspore diff --git a/tests/st/ops/gpu/test_sigmoid_cross_entropy_with_logits_grad_op.py b/tests/st/ops/gpu/test_sigmoid_cross_entropy_with_logits_grad_op.py index a548cab0e7..356f38ebab 100644 --- a/tests/st/ops/gpu/test_sigmoid_cross_entropy_with_logits_grad_op.py +++ b/tests/st/ops/gpu/test_sigmoid_cross_entropy_with_logits_grad_op.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -31,32 +31,43 @@ class NetSigmoidCrossEntropyWithLogits(nn.Cell): return self.sigmoid_cross_entropy_with_logits_grad(logits, labels, dout) -@pytest.mark.level0 -@pytest.mark.platform_x86_gpu_training -@pytest.mark.env_onecard -def test_sigmoid_cross_entropy_with_logits(): +def sigmoid_cross_entropy_with_logits_grad(nptype): logits = Tensor(np.array([[1, 1, 2], [1, 2, 1], - [2, 1, 1]]).astype(np.float32)) + [2, 1, 1]]).astype(nptype)) labels = Tensor(np.array([[0, 0, 1], [0, 1, 0], - [1, 0, 0]]).astype(np.float32)) - dout = Tensor(np.ones(shape=[3, 3]).astype(np.float32)) + [1, 0, 0]]).astype(nptype)) + dout = Tensor(np.ones(shape=[3, 3]).astype(nptype)) expect = np.array([[0.731059, 0.731059, -0.119203], [0.731059, -0.119203, 0.731059], - [-0.119203, 0.731059, 0.731059]]).astype(np.float32) + [-0.119203, 0.731059, 0.731059]]).astype(nptype) error = np.ones(shape=[3, 3]) * 1.0e-6 context.set_context(mode=context.GRAPH_MODE, device_target='GPU') - sigmoid_cross_entropy_with_logits = NetSigmoidCrossEntropyWithLogits() - output = sigmoid_cross_entropy_with_logits(logits, labels, dout) + net = NetSigmoidCrossEntropyWithLogits() + output = net(logits, labels, dout) diff = output.asnumpy() - expect assert np.all(abs(diff) < error) context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') - sigmoid_cross_entropy_with_logits = NetSigmoidCrossEntropyWithLogits() - output = sigmoid_cross_entropy_with_logits(logits, labels, dout) + net = NetSigmoidCrossEntropyWithLogits() + output = net(logits, labels, dout) diff = output.asnumpy() - expect assert np.all(abs(diff) < error) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_sigmoid_cross_entropy_with_logits_float32(): + sigmoid_cross_entropy_with_logits_grad(np.float32) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_sigmoid_cross_entropy_with_logits_float64(): + sigmoid_cross_entropy_with_logits_grad(np.float64) diff --git a/tests/st/ops/gpu/test_sigmoid_cross_entropy_with_logits_op.py b/tests/st/ops/gpu/test_sigmoid_cross_entropy_with_logits_op.py index e3f8512e9c..bfc95c13d8 100644 --- a/tests/st/ops/gpu/test_sigmoid_cross_entropy_with_logits_op.py +++ b/tests/st/ops/gpu/test_sigmoid_cross_entropy_with_logits_op.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -31,30 +31,41 @@ class NetSigmoidCrossEntropyWithLogits(nn.Cell): return self.loss(logits, labels) -@pytest.mark.level0 -@pytest.mark.platform_x86_gpu_training -@pytest.mark.env_onecard -def test_sigmoid_cross_entropy_with_logits(): +def sigmoid_cross_entropy_with_logits(nptype): logits = Tensor(np.array([[1, 1, 2], [1, 2, 1], - [2, 1, 1]]).astype(np.float32)) + [2, 1, 1]]).astype(nptype)) labels = Tensor(np.array([[0, 0, 1], [0, 1, 0], - [1, 0, 0]]).astype(np.float32)) + [1, 0, 0]]).astype(nptype)) expect_loss = np.array([[1.313262, 1.313262, 0.126928], [1.313262, 0.126928, 1.313262], - [0.126928, 1.313262, 1.313262]]).astype(np.float32) + [0.126928, 1.313262, 1.313262]]).astype(nptype) error = np.ones(shape=[3, 3]) * 1.0e-6 context.set_context(mode=context.GRAPH_MODE, device_target='GPU') - sigmoid_cross_entropy_with_logits = NetSigmoidCrossEntropyWithLogits() - output = sigmoid_cross_entropy_with_logits(logits, labels) + net = NetSigmoidCrossEntropyWithLogits() + output = net(logits, labels) diff = output.asnumpy() - expect_loss assert np.all(abs(diff) < error) context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') - sigmoid_cross_entropy_with_logits = NetSigmoidCrossEntropyWithLogits() - output = sigmoid_cross_entropy_with_logits(logits, labels) + net = NetSigmoidCrossEntropyWithLogits() + output = net(logits, labels) diff = output.asnumpy() - expect_loss assert np.all(abs(diff) < error) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_sigmoid_cross_entropy_with_logits_float32(): + sigmoid_cross_entropy_with_logits(np.float32) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_sigmoid_cross_entropy_with_logits_float64(): + sigmoid_cross_entropy_with_logits(np.float64)