diff --git a/mindspore/ccsrc/kernel/gpu/nn/activation_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/activation_gpu_kernel.cc index 0246707d16..5e80cccd75 100644 --- a/mindspore/ccsrc/kernel/gpu/nn/activation_gpu_kernel.cc +++ b/mindspore/ccsrc/kernel/gpu/nn/activation_gpu_kernel.cc @@ -27,5 +27,10 @@ MS_REG_GPU_KERNEL_ONE(Tanh, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOut ActivationGpuFwdKernel, float) MS_REG_GPU_KERNEL_ONE(Tanh, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), ActivationGpuFwdKernel, half) + +MS_REG_GPU_KERNEL_ONE(Sigmoid, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + ActivationGpuFwdKernel, float) +MS_REG_GPU_KERNEL_ONE(Sigmoid, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + ActivationGpuFwdKernel, half) } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/nn/activation_grad_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/activation_grad_kernel.cc index 506d2268f7..35d11f8b47 100644 --- a/mindspore/ccsrc/kernel/gpu/nn/activation_grad_kernel.cc +++ b/mindspore/ccsrc/kernel/gpu/nn/activation_grad_kernel.cc @@ -35,5 +35,14 @@ MS_REG_GPU_KERNEL_ONE( TanhGrad, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), ActivationGradGpuKernel, half) + +MS_REG_GPU_KERNEL_ONE( + SigmoidGrad, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + ActivationGradGpuKernel, float) +MS_REG_GPU_KERNEL_ONE( + SigmoidGrad, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + ActivationGradGpuKernel, half) } // namespace kernel } // namespace mindspore diff --git a/tests/st/ops/gpu/test_sigmoid_grad_op.py b/tests/st/ops/gpu/test_sigmoid_grad_op.py new file mode 100644 index 0000000000..92d1d4d9f7 --- /dev/null +++ b/tests/st/ops/gpu/test_sigmoid_grad_op.py @@ -0,0 +1,61 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops.operations import _grad_ops as G + + +class NetSigmoidGrad(nn.Cell): + def __init__(self): + super(NetSigmoidGrad, self).__init__() + self.sigmoid_grad = G.SigmoidGrad() + + def construct(self, y, dy): + return self.sigmoid_grad(y, dy) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_sigmoid_grad(): + y = Tensor(np.array([[[[-1, 1, 2], + [1, -1, 1], + [2, 1, -1]]]]).astype(np.float32)) + dy = Tensor(np.array([[[[-11, 2, 4], + [-1, 1, -1], + [-4, 4, -4]]]]).astype(np.float32)) + + expect = np.array([[[[22, 0, -8], + [0, -2, 0], + [8, 0, 8]]]]).astype(np.float32) + + error = np.ones(shape=[1, 1, 3, 3]) * 1.0e-6 + + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + sigmoid_grad = NetSigmoidGrad() + output = sigmoid_grad(y, dy) + diff = output.asnumpy() - expect + assert np.all(abs(diff) < error) + + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + sigmoid_grad = NetSigmoidGrad() + output = sigmoid_grad(y, dy) + diff = output.asnumpy() - expect + assert np.all(abs(diff) < error) diff --git a/tests/st/ops/gpu/test_sigmoid_op.py b/tests/st/ops/gpu/test_sigmoid_op.py new file mode 100644 index 0000000000..f3d724a35b --- /dev/null +++ b/tests/st/ops/gpu/test_sigmoid_op.py @@ -0,0 +1,57 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P + + +class NetSigmoid(nn.Cell): + def __init__(self): + super(NetSigmoid, self).__init__() + self.sigmoid = P.Sigmoid() + + def construct(self, x): + return self.sigmoid(x) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_sigmoid(): + x = Tensor(np.array([[[[-1, 1, 10], + [1, -1, 1], + [10, 1, -1]]]]).astype(np.float32)) + expect = np.array([[[[0.268941, 0.731059, 0.999955], + [0.731059, 0.268941, 0.731059], + [0.999955, 0.731059, 0.268941]]]]).astype(np.float32) + + error = np.ones(shape=[1, 1, 3, 3]) * 1.0e-6 + + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + sigmoid = NetSigmoid() + output = sigmoid(x) + diff = output.asnumpy() - expect + assert np.all(abs(diff) < error) + + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + sigmoid = NetSigmoid() + output = sigmoid(x) + diff = output.asnumpy() - expect + assert np.all(abs(diff) < error)