diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unique_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unique_gpu_kernel.cc
new file mode 100644
index 0000000000..c141f18ba1
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unique_gpu_kernel.cc
@@ -0,0 +1,33 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/arrays/unique_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_TWO(
+  Unique,
+  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32),
+  UniqueGpuKernel, float, int)
+MS_REG_GPU_KERNEL_TWO(
+  Unique,
+  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32),
+  UniqueGpuKernel, half, int)
+MS_REG_GPU_KERNEL_TWO(
+  Unique, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+  UniqueGpuKernel, int, int)
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unique_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unique_gpu_kernel.h
new file mode 100644
index 0000000000..237bb1cbec
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unique_gpu_kernel.h
@@ -0,0 +1,104 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_UNIQUEGPUKERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_UNIQUEGPUKERNEL_H_
+
+#include <vector>
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/unique_impl.cuh"
+namespace mindspore {
+namespace kernel {
+template <typename T, typename S>
+class UniqueGpuKernel : public GpuKernel {
+ public:
+  UniqueGpuKernel() : input_size_(0), output_size_(0), workspace_size_(0), num_elements_(1), post_output_size_(0) {}
+  ~UniqueGpuKernel() override = default;
+
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+    T *input = GetDeviceAddress<T>(inputs, 0);
+    S *input_index = GetDeviceAddress<S>(workspace, 0);
+    S *sorted_index = GetDeviceAddress<S>(workspace, 1);
+    T *output = GetDeviceAddress<T>(outputs, 0);
+    S *index = GetDeviceAddress<S>(outputs, 1);
+    stream_ptr_ = stream_ptr;
+    post_output_size_ = CalUnique(input, num_elements_, input_index, sorted_index, output, index,
+                                  reinterpret_cast<cudaStream_t>(stream_ptr));
+    return true;
+  }
+
+  bool Init(const CNodePtr &kernel_node) override {
+    kernel_node_ = kernel_node;
+    std::vector<size_t> shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    for (auto x : shape) {
+      num_elements_ *= x;
+    }
+    input_size_ = num_elements_ * sizeof(T);
+    output_size_ = input_size_;
+    workspace_size_ = num_elements_ * sizeof(S);
+    InitSizeLists();
+    return true;
+  }
+
+  void PostExecute() override {
+    CHECK_CUDA_RET_WITH_EXCEPT(cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(stream_ptr_)),
+                               "cudaStreamSynchronize failed");
+    std::vector<TypeId> type_ids;
+    std::vector<std::vector<size_t>> shapes;
+    size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node_);
+    for (size_t i = 0; i < output_num; ++i) {
+      std::vector<size_t> shape = AnfAlgo::GetOutputInferShape(kernel_node_, i);
+      if (i == 0) {
+        // Output 0 holds the unique values; shrink its inferred shape to the
+        // actual unique count computed at launch time.
+        shape[0] = post_output_size_;
+      }
+      TypeId type_id = AnfAlgo::GetOutputInferDataType(kernel_node_, i);
+      type_ids.emplace_back(type_id);
+      shapes.emplace_back(shape);
+    }
+    AnfAlgo::SetOutputInferTypeAndShape(type_ids, shapes, kernel_node_.get());
+  }
+
+ protected:
+  void InitSizeLists() override {
+    input_size_list_.push_back(input_size_);
+    output_size_list_.push_back(output_size_);
+    output_size_list_.push_back(num_elements_ * sizeof(S));
+    workspace_size_list_.push_back(workspace_size_);
+    workspace_size_list_.push_back(workspace_size_);
+  }
+
+ private:
+  void *stream_ptr_;
+  size_t input_size_;
+  size_t output_size_;
+  size_t workspace_size_;
+  int num_elements_;
+  int post_output_size_;
+  CNodePtr kernel_node_;
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_UNIQUEGPUKERNEL_H_
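Note on the header above: Unique has a data-dependent output shape, so InitSizeLists() reserves worst-case buffers (all num_elements_ inputs distinct) and the real unique count is only known once CalUnique has run; PostExecute() then synchronizes the stream and trims the inferred shape of output 0 to post_output_size_. A minimal numpy illustration of that sizing contract (numpy stands in for the device buffers; this is a sketch, not MindSpore API):

    import numpy as np

    x = np.array([4, 5, 1, 2, 3, 3, 4, 5], dtype=np.float32)
    reserved = x.size            # slots reserved at Init(): worst case, 8
    actual = np.unique(x).size   # post_output_size_ found at launch: 5
    assert actual <= reserved    # equality only when all elements are distinct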
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unique_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unique_impl.cu
new file mode 100644
index 0000000000..9b08d55204
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unique_impl.cu
@@ -0,0 +1,74 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <thrust/adjacent_difference.h>
+#include <thrust/copy.h>
+#include <thrust/device_ptr.h>
+#include <thrust/execution_policy.h>
+#include <thrust/fill.h>
+#include <thrust/functional.h>
+#include <thrust/scan.h>
+#include <thrust/scatter.h>
+#include <thrust/sequence.h>
+#include <thrust/sort.h>
+#include <thrust/unique.h>
+#include "unique_impl.cuh"
+#include "runtime/device/gpu/cuda_common.h"
+#include "include/cuda_fp16.h"
+
+template <typename T, typename S>
+int CalUnique(const T *input, int num_elements, S *input_index, S *sorted_index, T *output, S *index,
+              cudaStream_t cuda_stream) {
+  auto policy = thrust::cuda::par.on(cuda_stream);
+  // sorted_index <- 0, 1, ..., num_elements - 1 (original positions).
+  thrust::sequence(policy,
+                   thrust::device_pointer_cast(sorted_index),
+                   thrust::device_pointer_cast(sorted_index) + num_elements);
+  // Work on a copy of the input so the sort does not clobber it.
+  thrust::copy(thrust::device_pointer_cast(input),
+               thrust::device_pointer_cast(input) + num_elements,
+               thrust::device_pointer_cast(output));
+  // Stable-sort the copy, dragging the original positions along as values.
+  thrust::stable_sort_by_key(policy,
+                             thrust::device_pointer_cast(output),
+                             thrust::device_pointer_cast(output) + num_elements,
+                             thrust::device_pointer_cast(sorted_index));
+  // input_index[i] <- 1 where sorted element i differs from its predecessor.
+  thrust::adjacent_difference(policy,
+                              thrust::device_pointer_cast(output),
+                              thrust::device_pointer_cast(output) + num_elements,
+                              thrust::device_pointer_cast(input_index),
+                              thrust::not_equal_to<T>());
+  // The first element always starts group 0.
+  thrust::fill(policy,
+               thrust::device_pointer_cast(input_index),
+               thrust::device_pointer_cast(input_index) + 1,
+               0);
+  // Running sum turns the 0/1 flags into per-element group ids.
+  thrust::inclusive_scan(policy,
+                         thrust::device_pointer_cast(input_index),
+                         thrust::device_pointer_cast(input_index) + num_elements,
+                         thrust::device_pointer_cast(input_index));
+  // Scatter group ids back to input order: index[sorted_index[i]] = input_index[i].
+  thrust::scatter(policy,
+                  thrust::device_pointer_cast(input_index),
+                  thrust::device_pointer_cast(input_index) + num_elements,
+                  thrust::device_pointer_cast(sorted_index),
+                  thrust::device_pointer_cast(index));
+  // Deduplicate the sorted copy in place; the distance to the new end is the unique count.
+  thrust::device_ptr<T> output_end;
+  output_end = thrust::unique(policy,
+                              thrust::device_pointer_cast(output),
+                              thrust::device_pointer_cast(output) + num_elements);
+  int output_size = thrust::distance(thrust::device_pointer_cast(output), output_end);
+  return output_size;
+}
+
+template int CalUnique<float, int>(const float *input, int num_elements, int *input_index, int *sorted_index,
+                                   float *output, int *index, cudaStream_t cuda_stream);
+template int CalUnique<half, int>(const half *input, int num_elements, int *input_index, int *sorted_index,
+                                  half *output, int *index, cudaStream_t cuda_stream);
+template int CalUnique<int, int>(const int *input, int num_elements, int *input_index, int *sorted_index,
+                                 int *output, int *index, cudaStream_t cuda_stream);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unique_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unique_impl.cuh
new file mode 100644
index 0000000000..92dc4740df
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unique_impl.cuh
@@ -0,0 +1,22 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_UNIQUE_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_UNIQUE_H_
+template <typename T, typename S>
+int CalUnique(const T *input, int num_elements, S *input_index, S *sorted_index, T *output, S *index,
+              cudaStream_t cuda_stream);
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_UNIQUE_H_
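For reference, the thrust pipeline in CalUnique maps onto numpy step by step: stable sort a copy while carrying original positions, flag boundaries between runs of equal values, prefix-sum the flags into group ids, and scatter those ids back to input order. The sketch below emulates that algorithm on the host; it is illustrative only, and cal_unique_reference is a made-up name, not MindSpore API:

    import numpy as np

    def cal_unique_reference(x):
        # Emulates CalUnique: returns (unique values, inverse index).
        n = x.size
        order = np.argsort(x, kind='stable')       # stable_sort_by_key: 'order' plays
        sorted_x = x[order]                        # the role of sorted_index
        flags = np.zeros(n, dtype=np.int32)
        flags[1:] = sorted_x[1:] != sorted_x[:-1]  # adjacent_difference with not_equal_to
        group_id = np.cumsum(flags)                # inclusive_scan: 0/1 flags -> group ids
        index = np.empty(n, dtype=np.int32)
        index[order] = group_id                    # scatter back to original positions
        keep = np.ones(n, dtype=bool)
        keep[1:] = sorted_x[1:] != sorted_x[:-1]   # unique: keep the first of each run
        return sorted_x[keep], index

    x = np.array([4, 5, 1, 2, 3, 3, 4, 5], dtype=np.float32)
    print(cal_unique_reference(x))
    # -> (array([1., 2., 3., 4., 5.]), array([3, 4, 0, 1, 2, 2, 3, 4], dtype=int32))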
diff --git a/tests/st/ops/gpu/test_unique_op.py b/tests/st/ops/gpu/test_unique_op.py
new file mode 100644
index 0000000000..4aa45a1c95
--- /dev/null
+++ b/tests/st/ops/gpu/test_unique_op.py
@@ -0,0 +1,226 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.ops import operations as P
+
+
+class NetUnique(nn.Cell):
+    def __init__(self):
+        super(NetUnique, self).__init__()
+        self.unique = P.Unique()
+
+    def construct(self, x):
+        x_unique, x_idx = self.unique(x)
+        return x_unique, x_idx
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_unique_1d():
+    x = Tensor(np.array([4, 5, 1, 2, 3, 3, 4, 5]).astype(np.float32))
+    exp_output = np.array([1, 2, 3, 4, 5]).astype(np.float32)
+    exp_idx = np.array([3, 4, 0, 1, 2, 2, 3, 4]).astype(np.int32)
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    net = NetUnique()
+    x_unique, x_idx = net(x)
+    assert (x_unique.asnumpy() == exp_output).all()
+    assert (x_idx.asnumpy() == exp_idx).all()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_unique_1d_float():
+    x = Tensor(np.array([0.4, 0.5, 1.23, 2.2, 12.43, 12.43, 0.4, 0.5]).astype(np.float32))
+    exp_output = np.array([0.4, 0.5, 1.23, 2.2, 12.43]).astype(np.float32)
+    exp_idx = np.array([0, 1, 2, 3, 4, 4, 0, 1]).astype(np.int32)
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    net = NetUnique()
+    x_unique, x_idx = net(x)
+    assert (x_unique.asnumpy() == exp_output).all()
+    assert (x_idx.asnumpy() == exp_idx).all()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_unique_1d_sorted():
+    x = Tensor(np.array([1, 1, 2, 4, 4, 4, 7, 8, 8]).astype(np.float32))
+    exp_output = np.array([1, 2, 4, 7, 8]).astype(np.float32)
+    exp_idx = np.array([0, 0, 1, 2, 2, 2, 3, 4, 4]).astype(np.int32)
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    net = NetUnique()
+    x_unique, x_idx = net(x)
+    assert (x_unique.asnumpy() == exp_output).all()
+    assert (x_idx.asnumpy() == exp_idx).all()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_unique_zeros():
+    x = Tensor(np.zeros(1000).astype(np.float32))
+    exp_output = np.zeros(1).astype(np.float32)
+    exp_idx = np.zeros(1000).astype(np.int32)
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    net = NetUnique()
+    x_unique, x_idx = net(x)
+    assert (x_unique.asnumpy() == exp_output).all()
+    assert (x_idx.asnumpy() == exp_idx).all()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_unique_large():
+    x_np1 = np.arange(100)
+    x_np2 = np.arange(100, 200)
+    x_np3 = np.arange(200, 300)
+    x_np = np.concatenate((x_np1, x_np2, x_np3, x_np1, x_np2, x_np3, x_np1, x_np2, x_np3))
+    x = Tensor(x_np.astype(np.float32))
+    exp_output = np.arange(300).astype(np.float32)
+    exp_idx = np.concatenate((x_np1, x_np2, x_np3, x_np1, x_np2, x_np3, x_np1, x_np2, x_np3)).astype(np.int32)
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    net = NetUnique()
+    x_unique, x_idx = net(x)
+    assert (x_unique.asnumpy() == exp_output).all()
+    assert (x_idx.asnumpy() == exp_idx).all()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_unique_1d_half():
+    x = Tensor(np.array([0.4, 0.5, 1.23, 2.2, 12.43, 12.43, 0.4, 0.5]).astype(np.float16))
+    exp_output = np.array([0.4, 0.5, 1.23, 2.2, 12.43]).astype(np.float16)
+    exp_idx = np.array([0, 1, 2, 3, 4, 4, 0, 1]).astype(np.int32)
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    net = NetUnique()
+    x_unique, x_idx = net(x)
+    assert (x_unique.asnumpy() == exp_output).all()
+    assert (x_idx.asnumpy() == exp_idx).all()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_unique_1d_sorted_half():
+    x = Tensor(np.array([1, 1, 2, 4, 4, 4, 7, 8, 8]).astype(np.float16))
+    exp_output = np.array([1, 2, 4, 7, 8]).astype(np.float16)
+    exp_idx = np.array([0, 0, 1, 2, 2, 2, 3, 4, 4]).astype(np.int32)
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    net = NetUnique()
+    x_unique, x_idx = net(x)
+    assert (x_unique.asnumpy() == exp_output).all()
+    assert (x_idx.asnumpy() == exp_idx).all()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_unique_zeros_half():
+    x = Tensor(np.zeros(1000).astype(np.float16))
+    exp_output = np.zeros(1).astype(np.float16)
+    exp_idx = np.zeros(1000).astype(np.int32)
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    net = NetUnique()
+    x_unique, x_idx = net(x)
+    assert (x_unique.asnumpy() == exp_output).all()
+    assert (x_idx.asnumpy() == exp_idx).all()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_unique_large_half():
+    x_np1 = np.arange(100)
+    x_np2 = np.arange(100, 200)
+    x_np3 = np.arange(200, 300)
+    x_np = np.concatenate((x_np1, x_np2, x_np3, x_np1, x_np2, x_np3, x_np1, x_np2, x_np3))
+    x = Tensor(x_np.astype(np.float16))
+    exp_output = np.arange(300).astype(np.float16)
+    exp_idx = np.concatenate((x_np1, x_np2, x_np3, x_np1, x_np2, x_np3, x_np1, x_np2, x_np3)).astype(np.int32)
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    net = NetUnique()
+    x_unique, x_idx = net(x)
+    assert (x_unique.asnumpy() == exp_output).all()
+    assert (x_idx.asnumpy() == exp_idx).all()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_unique_1d_int32():
+    x = Tensor(np.array([4, 5, 1, 2, 3, 3, 4, 5]).astype(np.int32))
+    exp_output = np.array([1, 2, 3, 4, 5]).astype(np.int32)
+    exp_idx = np.array([3, 4, 0, 1, 2, 2, 3, 4]).astype(np.int32)
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    net = NetUnique()
+    x_unique, x_idx = net(x)
+    assert (x_unique.asnumpy() == exp_output).all()
+    assert (x_idx.asnumpy() == exp_idx).all()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_unique_1d_sorted_int32():
+    x = Tensor(np.array([1, 1, 2, 4, 4, 4, 7, 8, 8]).astype(np.int32))
+    exp_output = np.array([1, 2, 4, 7, 8]).astype(np.int32)
+    exp_idx = np.array([0, 0, 1, 2, 2, 2, 3, 4, 4]).astype(np.int32)
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    net = NetUnique()
+    x_unique, x_idx = net(x)
+    assert (x_unique.asnumpy() == exp_output).all()
+    assert (x_idx.asnumpy() == exp_idx).all()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_unique_zeros_int32():
+    x = Tensor(np.zeros(1000).astype(np.int32))
+    exp_output = np.zeros(1).astype(np.int32)
+    exp_idx = np.zeros(1000).astype(np.int32)
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    net = NetUnique()
+    x_unique, x_idx = net(x)
+    assert (x_unique.asnumpy() == exp_output).all()
+    assert (x_idx.asnumpy() == exp_idx).all()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_unique_large_int32():
+    x_np1 = np.arange(100)
+    x_np2 = np.arange(100, 200)
+    x_np3 = np.arange(200, 300)
+    x_np = np.concatenate((x_np1, x_np2, x_np3, x_np1, x_np2, x_np3, x_np1, x_np2, x_np3))
+    x = Tensor(x_np.astype(np.int32))
+    exp_output = np.arange(300).astype(np.int32)
+    exp_idx = np.concatenate((x_np1, x_np2, x_np3, x_np1, x_np2, x_np3, x_np1, x_np2, x_np3)).astype(np.int32)
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    net = NetUnique()
+    x_unique, x_idx = net(x)
+    assert (x_unique.asnumpy() == exp_output).all()
+    assert (x_idx.asnumpy() == exp_idx).all()
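Because the kernel returns unique values in ascending order together with an inverse index, the expected outputs in these tests line up with numpy's own unique. One way to derive the expectations (a sketch for checking, not part of the patch):

    import numpy as np

    x = np.array([4, 5, 1, 2, 3, 3, 4, 5], dtype=np.float32)
    exp_output, exp_idx = np.unique(x, return_inverse=True)
    # exp_output: [1. 2. 3. 4. 5.]   exp_idx: [3 4 0 1 2 2 3 4]
    # (cast exp_idx to np.int32 to match the kernel's index dtype)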