From f6e87143c6040c9acaf7306bcf0354f094da356d Mon Sep 17 00:00:00 2001 From: zhouyuanshen Date: Wed, 25 Nov 2020 20:23:17 +0800 Subject: [PATCH] add support to op L2Loss on gpu --- .../kernel_compiler/gpu/cuda_impl/l2_loss.cu | 38 +++++++ .../kernel_compiler/gpu/cuda_impl/l2_loss.cuh | 21 ++++ .../gpu/nn/l2_loss_gpu_kernel.cc | 26 +++++ .../gpu/nn/l2_loss_gpu_kernel.h | 71 +++++++++++++ mindspore/ops/operations/nn_ops.py | 6 +- tests/st/ops/gpu/test_l2loss_op.py | 100 ++++++++++++++++++ 6 files changed, 258 insertions(+), 4 deletions(-) create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/l2_loss.cu create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/l2_loss.cuh create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/nn/l2_loss_gpu_kernel.cc create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/nn/l2_loss_gpu_kernel.h create mode 100644 tests/st/ops/gpu/test_l2loss_op.py diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/l2_loss.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/l2_loss.cu new file mode 100644 index 0000000000..41103cc92b --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/l2_loss.cu @@ -0,0 +1,38 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "l2_loss.cuh" +#include "runtime/device/gpu/cuda_common.h" +#include "backend/kernel_compiler/gpu/cuda_impl/util.cuh" + +template +__global__ void L2LossKernel(const size_t input_size, const T *input , T *output) { + T ret = 0; + for (size_t id = blockIdx.x * blockDim.x + threadIdx.x; id < input_size; id += blockDim.x * gridDim.x) { + ret = (input[id] * input[id]); + ret /= static_cast(2); + MsAtomicAdd(output, ret); + } + return; +} + +template +void L2Loss(const size_t input_size, const T *input , T *output, cudaStream_t stream) { + L2LossKernel<<>>(input_size, input, output); +} + +template void L2Loss(const size_t input_size, const float *input , float *output, cudaStream_t stream); +template void L2Loss(const size_t input_size, const half *input , half *output, cudaStream_t stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/l2_loss.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/l2_loss.cuh new file mode 100644 index 0000000000..428451c84f --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/l2_loss.cuh @@ -0,0 +1,21 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_L2_LOSS_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_L2_LOSS_H_ +template +void L2Loss(const size_t input_size, const T *input , T *output, cudaStream_t stream); +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_L2_LOSS_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/l2_loss_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/l2_loss_gpu_kernel.cc new file mode 100644 index 0000000000..bdd0850e15 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/l2_loss_gpu_kernel.cc @@ -0,0 +1,26 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/nn/l2_loss_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(L2Loss, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + L2LossGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(L2Loss, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + L2LossGpuKernel, half) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/l2_loss_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/l2_loss_gpu_kernel.h new file mode 100644 index 0000000000..b7b1cc585d --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/l2_loss_gpu_kernel.h @@ -0,0 +1,71 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_L2_LOSS_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_L2_LOSS_GPU_KERNEL_H_ + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/l2_loss.cuh" +namespace mindspore { +namespace kernel { +template +class L2LossGpuKernel : public GpuKernel { + public: + L2LossGpuKernel() : input_size_(1) {} + ~L2LossGpuKernel() override = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspaces, + const std::vector &outputs, void *stream_ptr) override { + T *input = GetDeviceAddress(inputs, 0); + T *output = GetDeviceAddress(outputs, 0); + + L2Loss(input_size_, input, output, reinterpret_cast(stream_ptr)); + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + for (size_t i = 0; i < input_shape.size(); i++) { + input_size_ *= input_shape[i]; + } + + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(input_size_ * sizeof(T)); + output_size_list_.push_back(sizeof(T)); + } + + private: + size_t input_size_; + + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_L2_LOSS_GPU_KERNEL_H_ diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index 44a8b68505..748d75642a 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -2157,9 +2157,7 @@ class L2Loss(PrimitiveWithInfer): Set `input_x` as x and output as loss. .. math:: - loss = sum(x ** 2) / nelement(x) - - :math:`nelement(x)` represents the number of `input_x`. + loss = sum(x ** 2) / 2 Inputs: - **input_x** (Tensor) - A input Tensor. Data type must be float16 or float32. @@ -2168,7 +2166,7 @@ class L2Loss(PrimitiveWithInfer): Tensor, has the same dtype as `input_x`. The output tensor is the value of loss which is a scalar tensor. Supported Platforms: - ``Ascend`` + ``Ascend`` ``GPU`` Examples >>> input_x = Tensor(np.array([1, 2, 3]), mindspore.float16) diff --git a/tests/st/ops/gpu/test_l2loss_op.py b/tests/st/ops/gpu/test_l2loss_op.py new file mode 100644 index 0000000000..822d5b9ffc --- /dev/null +++ b/tests/st/ops/gpu/test_l2loss_op.py @@ -0,0 +1,100 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +import mindspore as ms +from mindspore import Tensor +from mindspore.ops import operations as P + + +class L2LossNet(nn.Cell): + def __init__(self): + super(L2LossNet, self).__init__() + self.l2_loss = P.L2Loss() + + def construct(self, x): + return self.l2_loss(x) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_gather_pynative_fp32_22(): + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + error = 1e-4 + x = Tensor(np.array([[1., 2.], [3., 4.]]), ms.float32) + expect = np.array(15, np.float32) + output = P.L2Loss()(x) + diff = output.asnumpy() - expect + assert np.all(diff < error) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_gather_pynative_fp16_22(): + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + error = 1e-4 + x = Tensor(np.array([[1., 2.], [3., 4.]]), ms.float16) + expect = np.array(15, np.float16) + output = P.L2Loss()(x) + diff = output.asnumpy() - expect + assert np.all(diff < error) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_gather_pynative_fp32_14(): + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + error = 1e-4 + x = Tensor(np.array([1., 2., 3., 4.]), ms.float32) + expect = np.array(15, np.float32) + output = P.L2Loss()(x) + diff = output.asnumpy() - expect + assert np.all(diff < error) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_gather_pynative_fp16_14(): + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + error = 1e-4 + x = Tensor(np.array([1., 2., 3., 4.]), ms.float16) + expect = np.array(15, np.float16) + output = P.L2Loss()(x) + diff = output.asnumpy() - expect + assert np.all(diff < error) + +def test_gather_graph_fp32_14(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + error = 1e-4 + x = Tensor(np.array([1., 2., 3., 4.]), ms.float32) + expect = np.array(15, np.float32) + l2_loss = L2LossNet() + output = l2_loss(x) + diff = output.asnumpy() - expect + assert np.all(diff < error) + +def test_gather_graph_fp16_14(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + error = 1e-4 + x = Tensor(np.array([1., 2., 3., 4.]), ms.float16) + expect = np.array(15, np.float16) + l2_loss = L2LossNet() + output = l2_loss(x) + diff = output.asnumpy() - expect + assert np.all(diff < error)