From ace4f2fb7104cffabf6c6507f0380c3839bc50cf Mon Sep 17 00:00:00 2001
From: zhaoting
Date: Fri, 25 Dec 2020 14:35:20 +0800
Subject: [PATCH] add CPU TopK

---
 .../kernel_compiler/cpu/topk_cpu_kernel.cc    | 87 +++++++++++++++++++
 .../kernel_compiler/cpu/topk_cpu_kernel.h     | 59 +++++++++++++
 mindspore/ops/operations/nn_ops.py            |  2 +-
 tests/st/ops/cpu/test_topk_op.py              | 82 +++++++++++++++++
 4 files changed, 229 insertions(+), 1 deletion(-)
 create mode 100644 mindspore/ccsrc/backend/kernel_compiler/cpu/topk_cpu_kernel.cc
 create mode 100644 mindspore/ccsrc/backend/kernel_compiler/cpu/topk_cpu_kernel.h
 create mode 100644 tests/st/ops/cpu/test_topk_op.py

diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/topk_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/topk_cpu_kernel.cc
new file mode 100644
index 0000000000..01f29b01a9
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/topk_cpu_kernel.cc
@@ -0,0 +1,87 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <algorithm>
+#include <functional>
+#include <numeric>
+#include <vector>
+#include "backend/kernel_compiler/cpu/topk_cpu_kernel.h"
+#include "runtime/device/cpu/cpu_device_address.h"
+
+namespace mindspore {
+namespace kernel {
+template <typename T>
+void TopKCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) {
+  if (inputs.size() != 2 || outputs.size() != 2) {
+    MS_LOG(EXCEPTION) << "TopK needs 2 inputs and 2 outputs, but get inputs: " << inputs.size()
+                      << "outputs: " << outputs.size();
+  }
+  if (inputs[0]->size != outer_size_ * inner_size_ * sizeof(T)) {
+    MS_LOG(EXCEPTION) << "Error input data size!";
+  }
+  if (inputs[1]->size != sizeof(int)) {
+    MS_LOG(EXCEPTION) << "Input K must be int!";
+  }
+  auto input = reinterpret_cast<T *>(inputs[0]->addr);
+  int k = reinterpret_cast<int *>(inputs[1]->addr)[0];
+  auto output = reinterpret_cast<T *>(outputs[0]->addr);
+  auto indices = reinterpret_cast<int *>(outputs[1]->addr);
+  if (k < 1) {
+    MS_LOG(EXCEPTION) << "Input k must > 0!";
+  }
+  int k_num = std::min<int>(inner_size_, k);
+  if (outputs[0]->size != outer_size_ * k_num * sizeof(T)) {
+    MS_LOG(EXCEPTION) << "Error output data size!";
+  }
+  for (size_t i = 0; i < outer_size_; ++i) {
+    std::vector<size_t> idx(inner_size_);
+    auto base_input = i * inner_size_;
+    std::iota(idx.begin(), idx.end(), base_input);
+    std::sort(idx.begin(), idx.end(),
+              [&input](size_t index_1, size_t index_2) { return input[index_1] > input[index_2]; });
+    auto base_output = i * k_num;
+    if (!sorted_) {
+      std::sort(idx.begin(), idx.begin() + k_num);
+    }
+    for (int j = 0; j < k_num; ++j) {
+      indices[base_output + j] = idx[j] - base_input;
+      output[base_output + j] = input[idx[j]];
+    }
+  }
+}
+
+void TopKCPUKernel::InitKernel(const CNodePtr &kernel_node) {
+  MS_EXCEPTION_IF_NULL(kernel_node);
+  auto x_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+  for (size_t i = 0; i < x_shape_.size() - 1; ++i) {
+    outer_size_ *= x_shape_[i];
+  }
+  inner_size_ = x_shape_[x_shape_.size() - 1];
+  sorted_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "sorted");
+  dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
+}
+
+bool TopKCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &,
+                           const std::vector<kernel::AddressPtr> &outputs) {
+  if (dtype_ == kNumberTypeFloat16) {
+    LaunchKernel<float16>(inputs, outputs);
+  } else if (dtype_ == kNumberTypeFloat32) {
+    LaunchKernel<float>(inputs, outputs);
+  }
+  return true;
+}
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/topk_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/topk_cpu_kernel.h
new file mode 100644
index 0000000000..17fc9a1ad0
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/topk_cpu_kernel.h
@@ -0,0 +1,59 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TOPK_CPU_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TOPK_CPU_KERNEL_H_
+#include <vector>
+#include <memory>
+#include <algorithm>
+#include "backend/kernel_compiler/cpu/cpu_kernel.h"
+#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
+
+namespace mindspore {
+namespace kernel {
+class TopKCPUKernel : public CPUKernel {
+ public:
+  TopKCPUKernel() = default;
+  ~TopKCPUKernel() override = default;
+  void InitKernel(const CNodePtr &kernel_node) override;
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs) override;
+
+ private:
+  template <typename T>
+  void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
+  size_t outer_size_{1};
+  size_t inner_size_{1};
+  bool sorted_{false};
+  TypeId dtype_{kTypeUnknown};
+};
+
+MS_REG_CPU_KERNEL(TopK,
+                  KernelAttr()
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeInt32)
+                    .AddOutputAttr(kNumberTypeFloat32)
+                    .AddOutputAttr(kNumberTypeInt32),
+                  TopKCPUKernel)
+MS_REG_CPU_KERNEL(TopK,
+                  KernelAttr()
+                    .AddInputAttr(kNumberTypeFloat16)
+                    .AddInputAttr(kNumberTypeInt32)
+                    .AddOutputAttr(kNumberTypeFloat16)
+                    .AddOutputAttr(kNumberTypeInt32),
+                  TopKCPUKernel)
+}  // namespace kernel
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TOPK_CPU_KERNEL_H_
diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py
index 41d772586e..b31ab8217d 100644
--- a/mindspore/ops/operations/nn_ops.py
+++ b/mindspore/ops/operations/nn_ops.py
@@ -1885,7 +1885,7 @@ class TopK(PrimitiveWithInfer):
         - **indices** (Tensor) - The indices of values within the last dimension of input.
 
     Supported Platforms:
-        ``Ascend`` ``GPU``
+        ``Ascend`` ``GPU`` ``CPU``
 
     Examples:
         >>> topk = ops.TopK(sorted=True)
diff --git a/tests/st/ops/cpu/test_topk_op.py b/tests/st/ops/cpu/test_topk_op.py
new file mode 100644
index 0000000000..1f701a8750
--- /dev/null
+++ b/tests/st/ops/cpu/test_topk_op.py
@@ -0,0 +1,82 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore.context as context
+from mindspore import Tensor
+from mindspore.ops import operations as P
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_topk():
+    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
+
+    x_np = np.random.rand(3, 4).astype(np.float32)
+    k = 4
+    ms_output = P.TopK(True)(Tensor(x_np), k)
+    np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k]
+    assert np.allclose(ms_output[0].asnumpy(), np_output)
+
+    x_np = np.random.rand(3, 4).astype(np.float32)
+    k = 4
+    ms_output = P.TopK(False)(Tensor(x_np), k)
+    assert np.allclose(ms_output[0].asnumpy(), x_np)
+
+    x_np = np.random.rand(2, 3, 4).astype(np.float32)
+    k = 2
+    ms_output = P.TopK(True)(Tensor(x_np), k)
+    np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k]
+    assert np.allclose(ms_output[0].asnumpy(), np_output)
+
+    x_np = np.random.rand(512, 1024).astype(np.float32)
+    k = 512
+    ms_output = P.TopK(True)(Tensor(x_np), k)
+    np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k]
+    assert np.allclose(ms_output[0].asnumpy(), np_output)
+
+    # sorted elements num greater than max thread per block
+    x_np = np.random.rand(512, 2048).astype(np.float32)
+    k = 1
+    ms_output = P.TopK(True)(Tensor(x_np), k)
+    np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k]
+    assert np.allclose(ms_output[0].asnumpy(), np_output)
+
+    x_np = np.random.rand(512, 2048).astype(np.float32)
+    k = 2048
+    ms_output = P.TopK(True)(Tensor(x_np), k)
+    np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k]
+    assert np.allclose(ms_output[0].asnumpy(), np_output)
+
+    # sorted elements num greater than max share memory per block
+    x_np = np.random.rand(512, 40960).astype(np.float32)
+    k = 1
+    ms_output = P.TopK(True)(Tensor(x_np), k)
+    np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k]
+    assert np.allclose(ms_output[0].asnumpy(), np_output)
+
+    x_np = np.random.rand(512, 40960).astype(np.float32)
+    k = 40960
+    ms_output = P.TopK(True)(Tensor(x_np), k)
+    np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k]
+    assert np.allclose(ms_output[0].asnumpy(), np_output)
+
+    x_np = np.random.rand(512, 40960).astype(np.float32)
+    k = 40960
+    ms_output = P.TopK(False)(Tensor(x_np), k)
+    assert np.allclose(ms_output[0].asnumpy(), x_np)