From 6dc36187585509b6a282cfa4b973344d15a63bfc Mon Sep 17 00:00:00 2001
From: linqingke
Date: Thu, 29 Oct 2020 10:28:53 +0800
Subject: [PATCH] Add Softplus and SoftplusGrad GPU ops.

---
 .../gpu/cuda_impl/softplus_impl.cu            |  83 ++++++++++++++
 .../gpu/cuda_impl/softplus_impl.cuh           |  27 +++++
 .../gpu/nn/softplus_gpu_kernel.cc             |  26 +++++
 .../gpu/nn/softplus_gpu_kernel.h              |  72 ++++++++++++
 .../gpu/nn/softplus_grad_gpu_kernel.cc        |  30 +++++
 .../gpu/nn/softplus_grad_gpu_kernel.h         |  74 ++++++++++++
 tests/st/ops/gpu/test_softplus_grad_op.py     |  77 +++++++++++++
 tests/st/ops/gpu/test_softplus_op.py          | 106 ++++++++++++++++++
 8 files changed, 495 insertions(+)
 create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/softplus_impl.cu
 create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/softplus_impl.cuh
 create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/nn/softplus_gpu_kernel.cc
 create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/nn/softplus_gpu_kernel.h
 create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/nn/softplus_grad_gpu_kernel.cc
 create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/nn/softplus_grad_gpu_kernel.h
 create mode 100644 tests/st/ops/gpu/test_softplus_grad_op.py
 create mode 100644 tests/st/ops/gpu/test_softplus_op.py

diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/softplus_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/softplus_impl.cu
new file mode 100644
index 0000000000..90ecfa53ab
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/softplus_impl.cu
@@ -0,0 +1,83 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/cuda_impl/softplus_impl.cuh"
+#include "runtime/device/gpu/cuda_common.h"
+
+template <typename T>
+__global__ void SoftplusKernel(const size_t size, const T *input_addr, T *output_addr) {
+  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
+    float x = input_addr[pos];
+    output_addr[pos] = logf(1. + exp(x));
+  }
+}
+
+template <>
+__global__ void SoftplusKernel(const size_t size, const half *input_addr, half *output_addr) {
+  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
+    float x = __half2float(input_addr[pos]);
+    output_addr[pos] = __float2half(logf(1. + exp(x)));
+  }
+}
+
+template <typename T>
+void Softplus(const size_t size, const T *input_addr, T *output_addr, cudaStream_t cuda_stream) {
+  SoftplusKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input_addr, output_addr);
+  return;
+}
+
+template <>
+void Softplus(const size_t size, const half *input_addr, half *output_addr, cudaStream_t cuda_stream) {
+  SoftplusKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input_addr, output_addr);
+  return;
+}
+
+template <typename T>
+__global__ void SoftplusGradKernel(const size_t size, const T *dy_addr, const T *x_addr, T *dx_addr) {
+  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
+    T exp_x = exp(x_addr[pos]);
+    dx_addr[pos] = dy_addr[pos] * exp_x / (1. + exp_x);
+  }
+}
+
+template <>
+__global__ void SoftplusGradKernel(const size_t size, const half *dy_addr, const half *x_addr, half *dx_addr) {
+  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
+    float x = __half2float(x_addr[pos]);
+    float dy = __half2float(dy_addr[pos]);
+    float exp_x = exp(x);
+    dx_addr[pos] = __float2half(dy * exp_x / (1. + exp_x));
+  }
+}
+
+template <typename T>
+void SoftplusGrad(const size_t size, const T *dy_addr, const T *x_addr, T *dx_addr, cudaStream_t cuda_stream) {
+  SoftplusGradKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, dy_addr, x_addr, dx_addr);
+  return;
+}
+
+template <>
+void SoftplusGrad(const size_t size, const half *dy_addr, const half *x_addr, half *dx_addr, cudaStream_t cuda_stream) {
+  SoftplusGradKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, dy_addr, x_addr, dx_addr);
+  return;
+}
+
+template void Softplus(const size_t size, const float *input_addr, float *output_addr, cudaStream_t cuda_stream);
+template void Softplus(const size_t size, const half *input_addr, half *output_addr, cudaStream_t cuda_stream);
+template void SoftplusGrad(const size_t size, const float *dy_addr, const float *x_addr, float *dx_addr,
+                           cudaStream_t cuda_stream);
+template void SoftplusGrad(const size_t size, const half *dy_addr, const half *x_addr, half *dx_addr,
+                           cudaStream_t cuda_stream);
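
[Note, not part of the patch] The kernels above evaluate logf(1. + exp(x)) literally. For large positive x, exp(x) overflows to inf even though softplus(x) approaches x, so the output becomes inf rather than roughly x. The tests in this patch draw inputs from [0, 1) or a standard normal, so the overflow never surfaces in CI. A minimal NumPy sketch of the failure mode and a numerically stable equivalent:

    import numpy as np

    x = np.float32(100.0)
    naive = np.log(1.0 + np.exp(x))            # exp(100) overflows float32 -> inf
    stable = np.logaddexp(np.float32(0.0), x)  # log(1 + exp(x)), computed stably
    print(naive, stable)                       # inf 100.0
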
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/softplus_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/softplus_impl.cuh
new file mode 100644
index 0000000000..639f71f87f
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/softplus_impl.cuh
@@ -0,0 +1,27 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_SOFTPLUS_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_SOFTPLUS_H_
+
+#include "runtime/device/gpu/cuda_common.h"
+template <typename T>
+void Softplus(const size_t input_size, const T* input_addr, T* output_addr, cudaStream_t cuda_stream);
+
+template <typename T>
+void SoftplusGrad(const size_t size, const T* dy_addr, const T* x_addr, T* dx_addr, cudaStream_t cuda_stream);
+
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_SOFTPLUS_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/softplus_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/softplus_gpu_kernel.cc
new file mode 100644
index 0000000000..1a4daf337f
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/softplus_gpu_kernel.cc
@@ -0,0 +1,26 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/nn/softplus_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_ONE(Softplus, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+                      SoftplusGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(Softplus, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+                      SoftplusGpuKernel, half)
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/softplus_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/softplus_gpu_kernel.h
new file mode 100644
index 0000000000..752690b05e
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/softplus_gpu_kernel.h
@@ -0,0 +1,72 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_SOFTPLUS_GPU_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_SOFTPLUS_GPU_KERNEL_H_
+
+#include <vector>
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+#include "backend/kernel_compiler/gpu/kernel_constants.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/softplus_impl.cuh"
+
+namespace mindspore {
+namespace kernel {
+template <typename T>
+class SoftplusGpuKernel : public GpuKernel {
+ public:
+  SoftplusGpuKernel() : input_size_(0) {}
+  ~SoftplusGpuKernel() override = default;
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+    T *input_addr = GetDeviceAddress<T>(inputs, 0);
+    T *output_addr = GetDeviceAddress<T>(outputs, 0);
+
+    Softplus(input_size_ / sizeof(T), input_addr, output_addr, reinterpret_cast<cudaStream_t>(stream_ptr));
+    return true;
+  }
+
+  bool Init(const CNodePtr &kernel_node) override {
+    InitResource();
+    input_size_ = sizeof(T);
+    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    for (auto dim : input_shape) {
+      input_size_ *= dim;
+    }
+    InitSizeLists();
+    return true;
+  }
+
+ protected:
+  void InitSizeLists() override {
+    input_size_list_.push_back(input_size_);
+    output_size_list_.push_back(input_size_);
+  }
+
+ private:
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+  size_t input_size_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_SOFTPLUS_GPU_KERNEL_H_
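
[Note, not part of the patch] For reference, the SoftplusGrad kernels registered below implement the analytic derivative of softplus, which is the logistic sigmoid:

    \frac{d}{dx}\log\left(1 + e^{x}\right) = \frac{e^{x}}{1 + e^{x}} = \sigma(x),
    \qquad \text{so} \qquad dx = dy \cdot \sigma(x).
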
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/softplus_grad_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/softplus_grad_gpu_kernel.cc
new file mode 100644
index 0000000000..47c034ec55
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/softplus_grad_gpu_kernel.cc
@@ -0,0 +1,30 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/nn/softplus_grad_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_ONE(
+  SoftplusGrad,
+  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+  SoftplusGpuGradKernel, float)
+MS_REG_GPU_KERNEL_ONE(
+  SoftplusGrad,
+  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+  SoftplusGpuGradKernel, half)
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/softplus_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/softplus_grad_gpu_kernel.h
new file mode 100644
index 0000000000..8b3f2a221a
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/softplus_grad_gpu_kernel.h
@@ -0,0 +1,74 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_SOFTPLUS_GRAD_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_SOFTPLUS_GRAD_KERNEL_H_
+
+#include <vector>
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+#include "backend/kernel_compiler/gpu/kernel_constants.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/softplus_impl.cuh"
+
+namespace mindspore {
+namespace kernel {
+template <typename T>
+class SoftplusGpuGradKernel : public GpuKernel {
+ public:
+  SoftplusGpuGradKernel() : input_size_(0) {}
+  ~SoftplusGpuGradKernel() override = default;
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+    T *dy_addr = GetDeviceAddress<T>(inputs, 0);
+    T *x_addr = GetDeviceAddress<T>(inputs, 1);
+    T *dx_addr = GetDeviceAddress<T>(outputs, 0);
+
+    SoftplusGrad(input_size_ / sizeof(T), dy_addr, x_addr, dx_addr, reinterpret_cast<cudaStream_t>(stream_ptr));
+    return true;
+  }
+
+  bool Init(const CNodePtr &kernel_node) override {
+    InitResource();
+    input_size_ = sizeof(T);
+    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    for (auto dim : input_shape) {
+      input_size_ *= dim;
+    }
+    InitSizeLists();
+    return true;
+  }
+
+ protected:
+  void InitSizeLists() override {
+    input_size_list_.push_back(input_size_);
+    input_size_list_.push_back(input_size_);
+    output_size_list_.push_back(input_size_);
+  }
+
+ private:
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+  size_t input_size_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_SOFTPLUS_GRAD_KERNEL_H_
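
[Note, not part of the patch] With the registrations above in place, the op is reachable from Python the same way the tests below drive it. A minimal usage sketch, assuming a GPU build of MindSpore:

    import numpy as np
    import mindspore.context as context
    from mindspore import Tensor
    from mindspore.ops import operations as P

    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

    x = Tensor(np.array([-1.0, 0.0, 1.0], dtype=np.float32))
    y = P.Softplus()(x)  # dispatches to the SoftplusGpuKernel registered above
    print(y)             # roughly [0.3133, 0.6931, 1.3133]
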
diff --git a/tests/st/ops/gpu/test_softplus_grad_op.py b/tests/st/ops/gpu/test_softplus_grad_op.py
new file mode 100644
index 0000000000..f3dda50877
--- /dev/null
+++ b/tests/st/ops/gpu/test_softplus_grad_op.py
@@ -0,0 +1,77 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.ops import composite as C
+from mindspore.ops import operations as P
+
+context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+
+
+class SoftplusNet(nn.Cell):
+    def __init__(self):
+        super(SoftplusNet, self).__init__()
+        self.softplus = P.Softplus()
+
+    def construct(self, x):
+        return self.softplus(x)
+
+
+class Grad(nn.Cell):
+    def __init__(self, network):
+        super(Grad, self).__init__()
+        self.grad = C.GradOperation(get_all=True, sens_param=True)
+        self.network = network
+
+    def construct(self, input_data, sens):
+        gout = self.grad(self.network)(input_data, sens)
+        return gout
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_softplusgrad():
+    x = np.array([0.58401114, 0.68800163, 0.9760397, 0.14702141, 0.46563736, 0.9607501,
+                  0.14567593, 0.12261796, 0.37054458, 0.46421242]).astype(np.float32)
+    dy = np.array([0.5559598, 0.96994054, 0.24770357, 0.34646875, 0.2984393, 0.03287048,
+                   0.55681044, 0.966908, 0.06015943, 0.6099489]).astype(np.float32)
+    x_ms = Tensor(x)
+    dy_ms = Tensor(dy)
+
+    net = SoftplusNet()
+    grad = Grad(net)
+
+    output = grad(x_ms, dy_ms)
+    expect = dy * np.exp(x) / (1 + np.exp(x))
+    assert np.allclose(output[0].asnumpy(), expect, rtol=1e-3)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_softplusgrad_fp16():
+    np.random.seed(42)
+    x_np = np.random.randn(5, 3, 6).astype(np.float16)
+    dy_np = np.random.randn(5, 3, 6).astype(np.float16)
+    net = SoftplusNet()
+    grad = Grad(net)
+    output = grad(Tensor(x_np), Tensor(dy_np))
+    expect = dy_np * np.exp(x_np) / (1 + np.exp(x_np))
+    assert np.allclose(output[0].asnumpy(), expect, rtol=1e-2)
diff --git a/tests/st/ops/gpu/test_softplus_op.py b/tests/st/ops/gpu/test_softplus_op.py
new file mode 100644
index 0000000000..1ee9d3c9c4
--- /dev/null
+++ b/tests/st/ops/gpu/test_softplus_op.py
@@ -0,0 +1,106 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.ops import operations as P
+
+context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+
+
+class SoftplusNet(nn.Cell):
+    def __init__(self):
+        super(SoftplusNet, self).__init__()
+        self.softplus = P.Softplus()
+
+    def construct(self, x):
+        return self.softplus(x)
+
+
+def SoftplusCompute(x):
+    return np.log(1 + np.exp(x))
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_softplus_1d():
+    x_np = np.random.random((50,)).astype(np.float32)
+    y_np = SoftplusCompute(x_np)
+
+    x_ms = Tensor(x_np)
+    net = SoftplusNet()
+    y_ms = net(x_ms)
+
+    assert np.allclose(y_np, y_ms.asnumpy())
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_softplus_2d():
+    x_np = np.random.random((50, 40)).astype(np.float32)
+    y_np = SoftplusCompute(x_np)
+
+    x_ms = Tensor(x_np)
+    net = SoftplusNet()
+    y_ms = net(x_ms)
+
+    assert np.allclose(y_np, y_ms.asnumpy())
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_softplus_4d():
+    x_np = np.random.random((32, 3, 224, 224)).astype(np.float32)
+    y_np = SoftplusCompute(x_np)
+
+    x_ms = Tensor(x_np)
+    net = SoftplusNet()
+    y_ms = net(x_ms)
+
+    assert np.allclose(y_np, y_ms.asnumpy())
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_softplus_neg():
+    x_np = np.random.random((32, 3, 224, 224)).astype(np.float32) * -1
+    y_np = SoftplusCompute(x_np)
+
+    x_ms = Tensor(x_np)
+    net = SoftplusNet()
+    y_ms = net(x_ms)
+
+    assert np.allclose(y_np, y_ms.asnumpy())
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_softplus_4d_fp16():
+    x_np = np.random.random((32, 3, 224, 224)).astype(np.float16)
+    y_np = SoftplusCompute(x_np)
+
+    x_ms = Tensor(x_np)
+    net = SoftplusNet()
+    y_ms = net(x_ms)
+
+    assert np.allclose(y_np, y_ms.asnumpy(), rtol=5e-3)
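
[Note, not part of the patch] A quick standalone cross-check that the analytic gradient asserted in test_softplus_grad_op.py agrees with a central finite difference of the forward function; pure NumPy, no MindSpore required:

    import numpy as np

    def softplus(x):
        return np.logaddexp(0.0, x)  # stable log(1 + exp(x))

    def softplus_grad(x):
        return 1.0 / (1.0 + np.exp(-x))  # sigmoid(x) == exp(x) / (1 + exp(x))

    x = np.linspace(-5.0, 5.0, 101)
    eps = 1e-5
    numeric = (softplus(x + eps) - softplus(x - eps)) / (2.0 * eps)
    assert np.allclose(numeric, softplus_grad(x), atol=1e-6)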