diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_grad_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_grad_impl.cu
new file mode 100644
index 0000000000..a0082b84c8
--- /dev/null
+++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_grad_impl.cu
@@ -0,0 +1,41 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "kernel/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_grad_impl.cuh"
+
+template <typename T, typename S>
+__global__ void SigmoidCrossEntropyWithLogitsGradKernel(const size_t size, const T *logits, const S *labels,
+                                                        T *outputs) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
+    if (logits[i] >= 0) {
+      outputs[i] = 1. / (1. + exp(-logits[i])) - labels[i];
+    } else {
+      const T exp_val = exp(logits[i]);
+      outputs[i] = exp_val / (1. + exp_val) - labels[i];
+    }
+  }
+}
+
+template <typename T, typename S>
+void SigmoidCrossEntropyWithLogitsGrad(const size_t size, const T *logits, const S *labels, T *outputs,
+                                       cudaStream_t cuda_stream) {
+  SigmoidCrossEntropyWithLogitsGradKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, logits, labels,
+                                                                                             outputs);
+}
+
+template void SigmoidCrossEntropyWithLogitsGrad<float, float>(const size_t size, const float *logits,
+                                                              const float *labels, float *outputs,
+                                                              cudaStream_t cuda_stream);
diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_grad_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_grad_impl.cuh
new file mode 100644
index 0000000000..2cd4922d25
--- /dev/null
+++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_grad_impl.cuh
@@ -0,0 +1,25 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_GRAD_IMPL_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_GRAD_IMPL_H_
+
+#include "device/gpu/cuda_common.h"
+template <typename T, typename S>
+void SigmoidCrossEntropyWithLogitsGrad(const size_t size, const T *logits, const S *labels, T *outputs,
+                                       cudaStream_t cuda_stream);
+
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_GRAD_IMPL_H_
diff --git a/mindspore/ccsrc/kernel/gpu/nn/sigmoid_cross_entropy_with_logits_grad_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/sigmoid_cross_entropy_with_logits_grad_gpu_kernel.cc
new file mode 100644
index 0000000000..dabc4df850
--- /dev/null
+++ b/mindspore/ccsrc/kernel/gpu/nn/sigmoid_cross_entropy_with_logits_grad_gpu_kernel.cc
@@ -0,0 +1,29 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "kernel/gpu/nn/sigmoid_cross_entropy_with_logits_grad_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_TWO(SigmoidCrossEntropyWithLogitsGrad,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeFloat32),
+                      SigmoidCrossEntropyWithLogitsGradGpuKernel, float, float)
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/kernel/gpu/nn/sigmoid_cross_entropy_with_logits_grad_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/sigmoid_cross_entropy_with_logits_grad_gpu_kernel.h
new file mode 100644
index 0000000000..01f416f6b7
--- /dev/null
+++ b/mindspore/ccsrc/kernel/gpu/nn/sigmoid_cross_entropy_with_logits_grad_gpu_kernel.h
@@ -0,0 +1,96 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_GRAD_GPU_KERNEL_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_NN_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_GRAD_GPU_KERNEL_H_
+
+#include <vector>
+#include "kernel/gpu/gpu_kernel.h"
+#include "kernel/gpu/gpu_kernel_factory.h"
+#include "kernel/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_grad_impl.cuh"
+
+namespace mindspore {
+namespace kernel {
+template <typename T, typename S>
+class SigmoidCrossEntropyWithLogitsGradGpuKernel : public GpuKernel {
+ public:
+  SigmoidCrossEntropyWithLogitsGradGpuKernel() : logits_size_(0), labels_size_(0), outputs_size_(0) {}
+  ~SigmoidCrossEntropyWithLogitsGradGpuKernel() override = default;
+
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+    T *logits_addr = GetDeviceAddress<T>(inputs, 0);
+    S *labels_addr = GetDeviceAddress<S>(inputs, 1);
+    T *outputs_addr = GetDeviceAddress<T>(outputs, 0);
+
+    SigmoidCrossEntropyWithLogitsGrad(inputs[0]->size / sizeof(T), logits_addr, labels_addr, outputs_addr,
+                                      reinterpret_cast<cudaStream_t>(stream_ptr));
+    return true;
+  }
+
+  bool Init(const CNodePtr &kernel_node) override {
+    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+    if (input_num != 3) {
+      MS_LOG(ERROR) << "Input number is " << input_num << ", but SigmoidCrossEntropyWithLogitsGrad needs 3 inputs.";
+      return false;
+    }
+    logits_size_ = sizeof(T);
+    labels_size_ = sizeof(S);
+    outputs_size_ = sizeof(T);
+
+    auto logits_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    for (size_t i = 0; i < logits_shape.size(); i++) {
+      logits_size_ *= logits_shape[i];
+    }
+
+    auto labels_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
+    for (size_t i = 0; i < labels_shape.size(); i++) {
+      labels_size_ *= labels_shape[i];
+    }
+
+    auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
+    for (size_t i = 0; i < output_shape.size(); i++) {
+      outputs_size_ *= output_shape[i];
+    }
+
+    InitSizeLists();
+    return true;
+  }
+
+ protected:
+  void InitSizeLists() override {
+    input_size_list_.push_back(logits_size_);
+    input_size_list_.push_back(labels_size_);
+    output_size_list_.push_back(outputs_size_);
+  }
+
+ private:
+  size_t logits_size_;
+  size_t labels_size_;
+  size_t outputs_size_;
+
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_NN_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_GRAD_GPU_KERNEL_H_
diff --git a/tests/st/ops/gpu/test_sigmoid_cross_entropy_with_logits_grad_op.py b/tests/st/ops/gpu/test_sigmoid_cross_entropy_with_logits_grad_op.py
new file mode 100644
index 0000000000..a548cab0e7
--- /dev/null
+++ b/tests/st/ops/gpu/test_sigmoid_cross_entropy_with_logits_grad_op.py
@@ -0,0 +1,62 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.ops.operations import _grad_ops as G
+
+
+class NetSigmoidCrossEntropyWithLogits(nn.Cell):
+    def __init__(self):
+        super(NetSigmoidCrossEntropyWithLogits, self).__init__()
+        self.sigmoid_cross_entropy_with_logits_grad = G.SigmoidCrossEntropyWithLogitsGrad()
+
+    def construct(self, logits, labels, dout):
+        return self.sigmoid_cross_entropy_with_logits_grad(logits, labels, dout)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_sigmoid_cross_entropy_with_logits():
+    logits = Tensor(np.array([[1, 1, 2],
+                              [1, 2, 1],
+                              [2, 1, 1]]).astype(np.float32))
+    labels = Tensor(np.array([[0, 0, 1],
+                              [0, 1, 0],
+                              [1, 0, 0]]).astype(np.float32))
+    dout = Tensor(np.ones(shape=[3, 3]).astype(np.float32))
+
+    expect = np.array([[0.731059, 0.731059, -0.119203],
+                       [0.731059, -0.119203, 0.731059],
+                       [-0.119203, 0.731059, 0.731059]]).astype(np.float32)
+
+    error = np.ones(shape=[3, 3]) * 1.0e-6
+
+    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
+    sigmoid_cross_entropy_with_logits = NetSigmoidCrossEntropyWithLogits()
+    output = sigmoid_cross_entropy_with_logits(logits, labels, dout)
+    diff = output.asnumpy() - expect
+    assert np.all(abs(diff) < error)
+
+    context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
+    sigmoid_cross_entropy_with_logits = NetSigmoidCrossEntropyWithLogits()
+    output = sigmoid_cross_entropy_with_logits(logits, labels, dout)
+    diff = output.asnumpy() - expect
+    assert np.all(abs(diff) < error)
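
Note (not part of the patch): the CUDA kernel above computes the elementwise value sigmoid(logits) - labels, branching on the sign of the logit so that exp() is only ever evaluated on a non-positive argument and cannot overflow; the dout input is all ones in the test, so it does not change the expected values. As a rough reference, the following NumPy sketch reproduces the expect matrix used in the test. The helper name reference_grad and the np.maximum/np.minimum clamps are illustrative choices, not part of the MindSpore sources.

import numpy as np

def reference_grad(logits, labels):
    # Numerically stable sigmoid: for x >= 0 use 1 / (1 + exp(-x)),
    # for x < 0 use exp(x) / (1 + exp(x)). The clamps keep the unselected
    # np.where branch from overflowing, since np.where evaluates both sides.
    pos = 1.0 / (1.0 + np.exp(-np.maximum(logits, 0.0)))
    neg = np.exp(np.minimum(logits, 0.0)) / (1.0 + np.exp(np.minimum(logits, 0.0)))
    sig = np.where(logits >= 0, pos, neg)
    return sig - labels

logits = np.array([[1, 1, 2], [1, 2, 1], [2, 1, 1]], dtype=np.float32)
labels = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.float32)
# Prints values matching the expect matrix in the test above to about 1e-6,
# e.g. sigmoid(1) = 0.731059 and sigmoid(2) - 1 = -0.119203.
print(reference_grad(logits, labels))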