From a304304c30705e117492592cf3df2963582702ec Mon Sep 17 00:00:00 2001
From: wilfChen
Date: Thu, 7 May 2020 15:39:27 +0800
Subject: [PATCH] gpu support Gelu & GeluGrad kernels

---
 .../ccsrc/kernel/gpu/cuda_impl/gelu_impl.cu   | 65 ++++++++++++++
 .../ccsrc/kernel/gpu/cuda_impl/gelu_impl.cuh  | 27 ++++++
 .../ccsrc/kernel/gpu/nn/gelu_grad_kernel.cc   | 29 ++++++
 .../ccsrc/kernel/gpu/nn/gelu_grad_kernel.h    | 75 ++++++++++++++++
 mindspore/ccsrc/kernel/gpu/nn/gelu_kernel.cc  | 24 +++++
 mindspore/ccsrc/kernel/gpu/nn/gelu_kernel.h   | 72 +++++++++++++++
 tests/st/ops/gpu/test_gelu_grad_op.py         | 61 +++++++++++++
 tests/st/ops/gpu/test_gelu_op.py              | 88 +++++++++++++++++++
 8 files changed, 441 insertions(+)
 create mode 100644 mindspore/ccsrc/kernel/gpu/cuda_impl/gelu_impl.cu
 create mode 100644 mindspore/ccsrc/kernel/gpu/cuda_impl/gelu_impl.cuh
 create mode 100644 mindspore/ccsrc/kernel/gpu/nn/gelu_grad_kernel.cc
 create mode 100644 mindspore/ccsrc/kernel/gpu/nn/gelu_grad_kernel.h
 create mode 100644 mindspore/ccsrc/kernel/gpu/nn/gelu_kernel.cc
 create mode 100644 mindspore/ccsrc/kernel/gpu/nn/gelu_kernel.h
 create mode 100644 tests/st/ops/gpu/test_gelu_grad_op.py
 create mode 100644 tests/st/ops/gpu/test_gelu_op.py

diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/gelu_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/gelu_impl.cu
new file mode 100644
index 0000000000..bb476179d5
--- /dev/null
+++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/gelu_impl.cu
@@ -0,0 +1,65 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+#include "kernel/gpu/cuda_impl/gelu_impl.cuh"
+#include "device/gpu/cuda_common.h"
+
+template <typename T>
+__global__ void GeluKernel(size_t size, T* input_addr, T* output_addr) {
+  // formula:
+  // gelu(x) = 0.5 * x * (1.0 + tanh(y))
+  // tanh(y) = 2 / (1 + exp(-2y)) - 1
+  // y = sqrt(2/pi) * (x + 0.044715 * x^3)
+  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
+    float x = input_addr[pos];
+    float tanh_res = tanh(0.7978845608 * (x + 0.044715 * x * x * x));
+    output_addr[pos] = 0.5 * x * (1.0 + tanh_res);
+  }
+}
+
+template <typename T>
+void Gelu(size_t size, T* input_addr, T* output_addr, cudaStream_t cuda_stream) {
+  GeluKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input_addr, output_addr);
+  return;
+}
+
+
+template <typename T>
+__global__ void GeluGradKernel(size_t size, T* dy_addr, T* x_addr, T* dx_addr) {
+  // formula:
+  // dx = dy * y'
+  // y' = 0.5 * (1 + tanh(tanh_para)) +
+  //      0.5 * x * (1 - tanh(tanh_para) * tanh(tanh_para)) * mul_right
+  // tanh_para = sqrt(2/pi) * (x + 0.044715 * x^3)
+  // mul_right = sqrt(2/pi) * (1 + 3 * 0.044715 * x^2)
+  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
+    T x = x_addr[pos];
+    T tanh_res = tanh(0.7978845608 * (x + 0.044715 * x * x * x));
+    T mul_right = 0.7978845608 + 0.1070322244 * x * x;
+    T y_res = 0.5 * (1 + tanh_res) + 0.5 * x * (1 - tanh_res * tanh_res) * mul_right;
+    dx_addr[pos] = dy_addr[pos] * y_res;
+  }
+}
+
+template <typename T>
+void GeluGradKernel(size_t size, T* dy_addr, T* x_addr, T* dx_addr, cudaStream_t cuda_stream) {
+  GeluGradKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, dy_addr, x_addr, dx_addr);
+}
+
+
+template void Gelu<float>(size_t size, float* input_addr, float* output_addr, cudaStream_t cuda_stream);
+template void GeluGradKernel<float>(size_t size, float* dy_addr, float* x_addr, float* dx_addr, cudaStream_t cuda_stream);
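
A quick NumPy sanity check of the two magic constants and the derivative formula above (not part of the patch; the gelu/gelu_grad helpers are illustrative names). 0.7978845608 is sqrt(2/pi) and 0.1070322244 is 3 * 0.044715 * sqrt(2/pi), which makes mul_right exactly the derivative of tanh_para; the analytic gradient then agrees with a central finite difference:

# Sanity check for the constants and gradient formula in gelu_impl.cu.
# Not part of the patch; a sketch using plain NumPy only.
import numpy as np

SQRT_2_OVER_PI = np.sqrt(2.0 / np.pi)
assert abs(SQRT_2_OVER_PI - 0.7978845608) < 1e-9                  # constant in GeluKernel
assert abs(3 * 0.044715 * SQRT_2_OVER_PI - 0.1070322244) < 1e-9   # mul_right coefficient

def gelu(x):
    return 0.5 * x * (1.0 + np.tanh(SQRT_2_OVER_PI * (x + 0.044715 * x ** 3)))

def gelu_grad(x):
    # same expression as GeluGradKernel, without the dy factor
    t = np.tanh(SQRT_2_OVER_PI * (x + 0.044715 * x ** 3))
    mul_right = SQRT_2_OVER_PI + 0.1070322244 * x * x
    return 0.5 * (1 + t) + 0.5 * x * (1 - t * t) * mul_right

x = np.linspace(-3.0, 3.0, 101)
eps = 1e-4
numeric = (gelu(x + eps) - gelu(x - eps)) / (2 * eps)   # central difference
assert np.allclose(gelu_grad(x), numeric, atol=1e-6)
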
diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/gelu_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/gelu_impl.cuh
new file mode 100644
index 0000000000..7a8e1fae8a
--- /dev/null
+++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/gelu_impl.cuh
@@ -0,0 +1,27 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_GELU_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_GELU_H_
+
+#include "device/gpu/cuda_common.h"
+template <typename T>
+void Gelu(size_t input_size, T* input_addr, T* output_addr, cudaStream_t cuda_stream);
+
+template <typename T>
+void GeluGradKernel(size_t size, T* dy_addr, T* x_addr, T* dx_addr, cudaStream_t cuda_stream);
+
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_GELU_H_
diff --git a/mindspore/ccsrc/kernel/gpu/nn/gelu_grad_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/gelu_grad_kernel.cc
new file mode 100644
index 0000000000..2b6c53aa28
--- /dev/null
+++ b/mindspore/ccsrc/kernel/gpu/nn/gelu_grad_kernel.cc
@@ -0,0 +1,29 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "kernel/gpu/nn/gelu_grad_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_ONE(GeluGrad,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeFloat32),
+                      GeLUGpuGradKernel, float)
+} // namespace kernel
+} // namespace mindspore
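
The three AddInputAttr entries match GeluGrad's (dy, x, y) signature; the forward output y is part of the op's interface even though the CUDA kernel above only reads dy and x. The op is normally inserted by autodiff (see tests/st/ops/gpu/test_gelu_grad_op.py below), but it can also be invoked directly; a sketch, assuming the internal mindspore.ops.operations._grad_ops module exposes GeluGrad (an internal API, so this is an assumption and not part of the patch):

# Hypothetical direct invocation of the registered grad kernel; not patch code.
import numpy as np
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops.operations import _grad_ops as G  # internal module, assumed available

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
dy = Tensor(np.ones(4).astype(np.float32))
x = Tensor(np.linspace(-1.0, 1.0, 4).astype(np.float32))
y = Tensor(np.zeros(4).astype(np.float32))  # forward output; unused by this GPU kernel
print(G.GeluGrad()(dy, x, y))
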
diff --git a/mindspore/ccsrc/kernel/gpu/nn/gelu_grad_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/gelu_grad_kernel.h
new file mode 100644
index 0000000000..7ce6d4d491
--- /dev/null
+++ b/mindspore/ccsrc/kernel/gpu/nn/gelu_grad_kernel.h
@@ -0,0 +1,75 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_GELU_GRAD_KERNEL_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_NN_GELU_GRAD_KERNEL_H_
+
+#include <vector>
+#include "kernel/gpu/gpu_kernel.h"
+#include "kernel/gpu/gpu_kernel_factory.h"
+#include "kernel/gpu/kernel_constants.h"
+#include "kernel/gpu/cuda_impl/gelu_impl.cuh"
+
+namespace mindspore {
+namespace kernel {
+template <typename T>
+class GeLUGpuGradKernel : public GpuKernel {
+ public:
+  GeLUGpuGradKernel() : input_size_(0) {}
+  ~GeLUGpuGradKernel() override = default;
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
+              const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) override {
+    // inputs are {dy, x, y}; the forward output y is registered but not read here
+    T *dy_addr = GetDeviceAddress<T>(inputs, 0);
+    T *x_addr = GetDeviceAddress<T>(inputs, 1);
+    T *dx_addr = GetDeviceAddress<T>(outputs, 0);
+
+    GeluGradKernel(input_size_ / sizeof(T), dy_addr, x_addr, dx_addr, reinterpret_cast<cudaStream_t>(stream_ptr));
+    return true;
+  }
+
+  bool Init(const CNodePtr &kernel_node) override {
+    InitResource();
+    input_size_ = sizeof(T);
+    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    for (auto dim : input_shape) {
+      input_size_ *= dim;
+    }
+    InitSizeLists();
+    return true;
+  }
+
+ protected:
+  void InitSizeLists() override {
+    input_size_list_.push_back(input_size_);
+    input_size_list_.push_back(input_size_);
+    input_size_list_.push_back(input_size_);
+    output_size_list_.push_back(input_size_);
+  }
+
+ private:
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+  size_t input_size_;
+};
+} // namespace kernel
+} // namespace mindspore
+
+#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_GELU_GRAD_KERNEL_H_
diff --git a/mindspore/ccsrc/kernel/gpu/nn/gelu_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/gelu_kernel.cc
new file mode 100644
index 0000000000..604dee04c4
--- /dev/null
+++ b/mindspore/ccsrc/kernel/gpu/nn/gelu_kernel.cc
@@ -0,0 +1,24 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "kernel/gpu/nn/gelu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_ONE(Gelu, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+                      GeluGpuKernel, float)
+} // namespace kernel
+} // namespace mindspore
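
A minimal usage sketch for the forward registration (not part of the patch; assumes a GPU build of MindSpore, and note that this patch registers float32 only):

# Calling the newly registered GPU op directly; a sketch, not patch code.
import numpy as np
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
gelu = P.Gelu()
print(gelu(Tensor(np.array([-1.0, 0.0, 1.0]).astype(np.float32))))
# expected roughly [-0.1588, 0.0, 0.8412] with the tanh approximation
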
diff --git a/mindspore/ccsrc/kernel/gpu/nn/gelu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/gelu_kernel.h
new file mode 100644
index 0000000000..f0dd37dec4
--- /dev/null
+++ b/mindspore/ccsrc/kernel/gpu/nn/gelu_kernel.h
@@ -0,0 +1,72 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_GELU_GPU_KERNEL_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_NN_GELU_GPU_KERNEL_H_
+
+#include <vector>
+#include "kernel/gpu/gpu_kernel.h"
+#include "kernel/gpu/gpu_kernel_factory.h"
+#include "kernel/gpu/kernel_constants.h"
+#include "kernel/gpu/cuda_impl/gelu_impl.cuh"
+
+namespace mindspore {
+namespace kernel {
+template <typename T>
+class GeluGpuKernel : public GpuKernel {
+ public:
+  GeluGpuKernel() : input_size_(0) {}
+  ~GeluGpuKernel() override = default;
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
+              const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) override {
+    T *input_addr = GetDeviceAddress<T>(inputs, 0);
+    T *output_addr = GetDeviceAddress<T>(outputs, 0);
+
+    Gelu(input_size_ / sizeof(T), input_addr, output_addr, reinterpret_cast<cudaStream_t>(stream_ptr));
+    return true;
+  }
+
+  bool Init(const CNodePtr &kernel_node) override {
+    InitResource();
+    input_size_ = sizeof(T);
+    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    for (auto dim : input_shape) {
+      input_size_ *= dim;
+    }
+    InitSizeLists();
+    return true;
+  }
+
+ protected:
+  void InitSizeLists() override {
+    input_size_list_.push_back(input_size_);
+    output_size_list_.push_back(input_size_);
+  }
+
+ private:
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+  size_t input_size_;
+};
+} // namespace kernel
+} // namespace mindspore
+
+#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_GELU_GPU_KERNEL_H_
diff --git a/tests/st/ops/gpu/test_gelu_grad_op.py b/tests/st/ops/gpu/test_gelu_grad_op.py
new file mode 100644
index 0000000000..5891868d23
--- /dev/null
+++ b/tests/st/ops/gpu/test_gelu_grad_op.py
@@ -0,0 +1,61 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================ + +import pytest +import numpy as np +from mindspore import Tensor +from mindspore.ops import operations as P +import mindspore.nn as nn +import mindspore.context as context +from mindspore.ops import composite as C + +context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + +class GeluNet(nn.Cell): + def __init__(self): + super(GeluNet, self).__init__() + self.gelu = P.Gelu() + + def construct(self, x): + return self.gelu(x) + +class Grad(nn.Cell): + def __init__(self, network): + super(Grad, self).__init__() + self.grad = C.GradOperation(name="get_all", get_all=True, sens_param=True) + self.network = network + + def construct(self, input_data, sens): + gout = self.grad(self.network)(input_data, sens) + return gout + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_gelugrad(): + x_ms = Tensor(np.array([0.58401114, 0.68800163, 0.9760397, 0.14702141, 0.46563736, 0.9607501, + 0.14567593, 0.12261796, 0.37054458, 0.46421242]).astype(np.float32)) + dy_ms = Tensor(np.array([0.5559598, 0.96994054, 0.24770357, 0.34646875, 0.2984393, 0.03287048, + 0.55681044, 0.966908, 0.06015943, 0.6099489 ]).astype(np.float32)) + + net = GeluNet() + grad = Grad(net) + + output = grad(x_ms, dy_ms) + print(output) + expect = [0.50963277, 0.9414753, 0.2667653, 0.21358444, 0.25243032, 0.0352667, + 0.34266686, 0.57757664, 0.04707306, 0.51536125] + assert np.allclose(output[0].asnumpy(), expect) \ No newline at end of file diff --git a/tests/st/ops/gpu/test_gelu_op.py b/tests/st/ops/gpu/test_gelu_op.py new file mode 100644 index 0000000000..9238bbc71c --- /dev/null +++ b/tests/st/ops/gpu/test_gelu_op.py @@ -0,0 +1,88 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +import pytest +from mindspore import Tensor +from mindspore.ops import operations as P +import mindspore.nn as nn +import numpy as np +import mindspore.context as context + +context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + +class GeluNet(nn.Cell): + def __init__(self): + super(GeluNet, self).__init__() + self.gelu = P.Gelu() + + def construct(self, x): + return self.gelu(x) + + +def GeluCompute(x): + return 0.5 * x * (1.0 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x * x * x))) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_gelu_1d(): + x_np = np.random.random((50,)).astype(np.float32) + y_np = GeluCompute(x_np) + + x_ms = Tensor(x_np) + net = GeluNet() + y_ms = net(x_ms) + + assert np.allclose(y_np, y_ms.asnumpy()) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_gelu_2d(): + x_np = np.random.random((50, 40)).astype(np.float32) + y_np = GeluCompute(x_np) + + x_ms = Tensor(x_np) + net = GeluNet() + y_ms = net(x_ms) + + assert np.allclose(y_np, y_ms.asnumpy()) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_gelu_4d(): + x_np = np.random.random((32, 3, 224, 224)).astype(np.float32) + y_np = GeluCompute(x_np) + + x_ms = Tensor(x_np) + net = GeluNet() + y_ms = net(x_ms) + + assert np.allclose(y_np, y_ms.asnumpy()) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_gelu_neg(): + x_np = np.random.random((32, 3, 224, 224)).astype(np.float32) * -1 + y_np = GeluCompute(x_np) + + x_ms = Tensor(x_np) + net = GeluNet() + y_ms = net(x_ms) + + assert np.allclose(y_np, y_ms.asnumpy())