From 06fb28c703e49219773a6d750fa2aa92824f79c9 Mon Sep 17 00:00:00 2001 From: caojian05 Date: Mon, 28 Dec 2020 15:00:16 +0800 Subject: [PATCH] add CPU ops: Greater/GreaterEqual/Range/GatherNd for center net --- .../cpu/arithmetic_cpu_kernel.cc | 27 +++ .../cpu/arithmetic_cpu_kernel.h | 26 +++ .../backend/kernel_compiler/cpu/cpu_kernel.h | 7 +- .../cpu/gathernd_cpu_kernel.cc | 104 ++++++++++ .../kernel_compiler/cpu/gathernd_cpu_kernel.h | 67 +++++++ .../kernel_compiler/cpu/range_cpu_kernel.cc | 56 ++++++ .../kernel_compiler/cpu/range_cpu_kernel.h | 54 +++++ mindspore/nn/layer/math.py | 2 +- mindspore/ops/operations/array_ops.py | 2 +- mindspore/ops/operations/math_ops.py | 4 +- tests/st/ops/cpu/test_gathernd_op.py | 188 ++++++++++++++++++ tests/st/ops/cpu/test_greater_equal_op.py | 70 +++++++ tests/st/ops/cpu/test_greater_op.py | 70 +++++++ tests/st/ops/cpu/test_range_op.py | 62 ++++++ 14 files changed, 734 insertions(+), 5 deletions(-) create mode 100644 mindspore/ccsrc/backend/kernel_compiler/cpu/gathernd_cpu_kernel.cc create mode 100644 mindspore/ccsrc/backend/kernel_compiler/cpu/gathernd_cpu_kernel.h create mode 100644 mindspore/ccsrc/backend/kernel_compiler/cpu/range_cpu_kernel.cc create mode 100644 mindspore/ccsrc/backend/kernel_compiler/cpu/range_cpu_kernel.h create mode 100644 tests/st/ops/cpu/test_gathernd_op.py create mode 100644 tests/st/ops/cpu/test_greater_equal_op.py create mode 100644 tests/st/ops/cpu/test_greater_op.py create mode 100644 tests/st/ops/cpu/test_range_op.py diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.cc index 8715800dd9..ed24d7fc41 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.cc @@ -167,6 +167,24 @@ void ArithmeticCPUKernel::SquaredDifference(const T *input1, const T *input2, T } } +template +void ArithmeticCPUKernel::Greater(const T *input1, const T *input2, bool *out, size_t start, size_t end) { + for (size_t i = start; i < end; i++) { + std::vector idx; + GenIndex(i, &idx); + out[i] = input1[idx[0]] > input2[idx[1]]; + } +} + +template +void ArithmeticCPUKernel::GreaterEqual(const T *input1, const T *input2, bool *out, size_t start, size_t end) { + for (size_t i = start; i < end; i++) { + std::vector idx; + GenIndex(i, &idx); + out[i] = input1[idx[0]] >= input2[idx[1]]; + } +} + void ArithmeticCPUKernel::InitKernel(const CNodePtr &kernel_node) { MS_EXCEPTION_IF_NULL(kernel_node); std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node); @@ -190,6 +208,10 @@ void ArithmeticCPUKernel::InitKernel(const CNodePtr &kernel_node) { operate_type_ = EQUAL; } else if (kernel_name == prim::kPrimNotEqual->name()) { operate_type_ = NOTEQUAL; + } else if (kernel_name == prim::kPrimGreater->name()) { + operate_type_ = GREATER; + } else if (kernel_name == prim::kPrimGreaterEqual->name()) { + operate_type_ = GREATEREQUAL; } else if (kernel_name == prim::kPrimAssignAdd->name()) { operate_type_ = ASSIGNADD; } else if (kernel_name == prim::kPrimSquaredDifference->name()) { @@ -301,6 +323,11 @@ void ArithmeticCPUKernel::LaunchKernelLogic(const std::vector &input threads.emplace_back(std::thread(&ArithmeticCPUKernel::Equal, this, input1, input2, output, start, end)); } else if (operate_type_ == NOTEQUAL) { threads.emplace_back(std::thread(&ArithmeticCPUKernel::NotEqual, this, input1, input2, output, start, end)); + } else if (operate_type_ == GREATER) { + threads.emplace_back(std::thread(&ArithmeticCPUKernel::Greater, this, input1, input2, output, start, end)); + } else if (operate_type_ == GREATEREQUAL) { + threads.emplace_back( + std::thread(&ArithmeticCPUKernel::GreaterEqual, this, input1, input2, output, start, end)); } else { MS_LOG(EXCEPTION) << "Not support " << operate_type_; } diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.h index 6bcb9a4a68..ae63e43857 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.h @@ -63,6 +63,10 @@ class ArithmeticCPUKernel : public CPUKernel { void NotEqual(const T *input1, const T *input2, bool *out, size_t start, size_t end); template void SquaredDifference(const T *input1, const T *input2, T *out, size_t start, size_t end); + template + void Greater(const T *input1, const T *input2, bool *out, size_t start, size_t end); + template + void GreaterEqual(const T *input1, const T *input2, bool *out, size_t start, size_t end); std::vector input_shape0_; std::vector input_shape1_; std::vector input_element_num0_; @@ -213,6 +217,28 @@ MS_REG_CPU_KERNEL( SquaredDifference, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), ArithmeticCPUKernel); +MS_REG_CPU_KERNEL( + Greater, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool), + ArithmeticCPUKernel); +MS_REG_CPU_KERNEL( + Greater, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool), + ArithmeticCPUKernel); +MS_REG_CPU_KERNEL( + Greater, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool), + ArithmeticCPUKernel); +MS_REG_CPU_KERNEL( + GreaterEqual, + KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool), + ArithmeticCPUKernel); +MS_REG_CPU_KERNEL( + GreaterEqual, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool), + ArithmeticCPUKernel); +MS_REG_CPU_KERNEL( + GreaterEqual, + KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool), + ArithmeticCPUKernel); } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h index a2b5d1f9d4..359840be4a 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h @@ -53,6 +53,9 @@ const char END[] = "end"; const char SIZE[] = "size"; const char USE_NESTEROV[] = "use_nesterov"; const char GROUP[] = "group"; +const char START[] = "start"; +const char LIMIT[] = "limit"; +const char DELTA[] = "delta"; enum OperateType { ADD = 0, @@ -79,7 +82,9 @@ enum OperateType { EQUAL, NOTEQUAL, FLOOR, - SQUAREDDIFFERENCE + SQUAREDDIFFERENCE, + GREATER, + GREATEREQUAL, }; class CPUKernel : public kernel::KernelMod { diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/gathernd_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/gathernd_cpu_kernel.cc new file mode 100644 index 0000000000..14ed76a5d1 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/gathernd_cpu_kernel.cc @@ -0,0 +1,104 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/kernel_compiler/cpu/gathernd_cpu_kernel.h" +#include "runtime/device/cpu/cpu_device_address.h" + +namespace mindspore { +namespace kernel { + +void GatherNdCPUKernel::InitKernel(const CNodePtr &kernel_node) { + input_shapes_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + indices_shapes_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); + output_shapes_ = AnfAlgo::GetOutputInferShape(kernel_node, 0); + + dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0); + + // ReShape() + size_t dim_of_indices = 1; + for (size_t i = 0; i < indices_shapes_.size() - IntToSize(1); ++i) { + dim_of_indices *= indices_shapes_[i]; + } + + size_t dim_after_indices = 1; + size_t dim_indices_last = indices_shapes_[indices_shapes_.size() - IntToSize(1)]; + for (size_t i = dim_indices_last; i < input_shapes_.size(); i++) { + dim_after_indices *= input_shapes_[i]; + } + + dims_.emplace_back(dim_of_indices); + dims_.emplace_back(dim_after_indices); + dims_.emplace_back(dim_indices_last); + + batch_strides_.resize(dim_indices_last, 0); + batch_indices_.resize(dim_indices_last, 0); + + if (dim_indices_last > 0) { + batch_strides_[dim_indices_last - 1] = input_shapes_[dim_indices_last - 1]; + batch_indices_[dim_indices_last - 1] = dims_[1]; + } + + for (size_t i = dim_indices_last - 1; i > 0; --i) { + batch_strides_[i - 1] = input_shapes_[i - 1]; + batch_indices_[i - 1] = batch_indices_[i] * input_shapes_[i]; + } +} + +bool GatherNdCPUKernel::Launch(const std::vector &inputs, + const std::vector & /*workspace*/, + const std::vector &outputs) { + if (dtype_ == kNumberTypeInt32) { + return LaunchKernel(inputs, outputs); + } else if (dtype_ == kNumberTypeInt64) { + return LaunchKernel(inputs, outputs); + } else if (dtype_ == kNumberTypeFloat32) { + return LaunchKernel(inputs, outputs); + } else if (dtype_ == kNumberTypeFloat64) { + return LaunchKernel(inputs, outputs); + } else { + MS_LOG(EXCEPTION) << "Only support int, float, but actual data type is " << TypeIdLabel(dtype_); + } +} + +template +bool GatherNdCPUKernel::LaunchKernel(const std::vector &inputs, const std::vector &outputs) { + auto input_addr = reinterpret_cast(inputs[0]->addr); + auto indices_addr = reinterpret_cast(inputs[1]->addr); + auto output_addr = reinterpret_cast(outputs[0]->addr); + + // + size_t output_dim0 = dims_[0]; + size_t output_dim1 = dims_[1]; + size_t indices_dim1 = dims_[2]; + + int num = output_dim0 * output_dim1; + + for (int write_index = 0; write_index < num; write_index++) { + int i = write_index / output_dim1 % output_dim0; + int j = write_index % output_dim1; + + int read_index = 0; + for (size_t k = 0; k < indices_dim1; k++) { + size_t ind = indices_dim1 * i + k; + int indices_i = indices_addr[ind]; + read_index += indices_i * batch_indices_[k]; + } + read_index += j; + output_addr[write_index] = input_addr[read_index]; + } + return true; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/gathernd_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/gathernd_cpu_kernel.h new file mode 100644 index 0000000000..2b65ddccee --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/gathernd_cpu_kernel.h @@ -0,0 +1,67 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_GATHERND_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_GATHERND_CPU_KERNEL_H_ +#include +#include +#include "backend/kernel_compiler/cpu/cpu_kernel.h" +#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +class GatherNdCPUKernel : public CPUKernel { + public: + GatherNdCPUKernel() = default; + ~GatherNdCPUKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + template + bool LaunchKernel(const std::vector &inputs, const std::vector &outputs); + + private: + std::vector input_shapes_; + std::vector indices_shapes_; + std::vector output_shapes_; + + std::vector dims_; + std::vector batch_indices_; + std::vector batch_strides_; + + TypeId dtype_{kTypeUnknown}; +}; + +MS_REG_CPU_KERNEL( + GatherNd, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + GatherNdCPUKernel); +MS_REG_CPU_KERNEL( + GatherNd, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64), + GatherNdCPUKernel); +MS_REG_CPU_KERNEL( + GatherNd, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), + GatherNdCPUKernel); +MS_REG_CPU_KERNEL( + GatherNd, + KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64), + GatherNdCPUKernel); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_GATHERND_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/range_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/range_cpu_kernel.cc new file mode 100644 index 0000000000..906ab15814 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/range_cpu_kernel.cc @@ -0,0 +1,56 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/kernel_compiler/cpu/range_cpu_kernel.h" +#include "runtime/device/cpu/cpu_device_address.h" + +namespace mindspore { +namespace kernel { +void RangeCPUKernel::InitKernel(const CNodePtr &kernel_node) { + MS_EXCEPTION_IF_NULL(kernel_node); + dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0); + + start_ = AnfAlgo::GetNodeAttr(kernel_node, START); + limit_ = AnfAlgo::GetNodeAttr(kernel_node, LIMIT); + delta_ = AnfAlgo::GetNodeAttr(kernel_node, DELTA); +} + +bool RangeCPUKernel::Launch(const std::vector &inputs, + const std::vector & /*workspace*/, + const std::vector &outputs) { + if (dtype_ == kNumberTypeInt32) { + return LaunchKernel(inputs, outputs); + } else if (dtype_ == kNumberTypeInt64) { + return LaunchKernel(inputs, outputs); + } else if (dtype_ == kNumberTypeFloat32) { + return LaunchKernel(inputs, outputs); + } else if (dtype_ == kNumberTypeFloat64) { + return LaunchKernel(inputs, outputs); + } else { + MS_LOG(EXCEPTION) << "Only support int, float, but actual data type is " << TypeIdLabel(dtype_); + } +} + +template +bool RangeCPUKernel::LaunchKernel(const std::vector &inputs, const std::vector &outputs) { + auto output_addr = reinterpret_cast(outputs[0]->addr); + size_t elem_num = outputs[0]->size / sizeof(T); + for (size_t i = 0; i < elem_num; i++) { + output_addr[i] = start_ + i * delta_; + } + return true; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/range_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/range_cpu_kernel.h new file mode 100644 index 0000000000..f4e846b5ce --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/range_cpu_kernel.h @@ -0,0 +1,54 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RANGE_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RANGE_CPU_KERNEL_H_ +#include +#include +#include "backend/kernel_compiler/cpu/cpu_kernel.h" +#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +class RangeCPUKernel : public CPUKernel { + public: + RangeCPUKernel() = default; + ~RangeCPUKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + template + bool LaunchKernel(const std::vector &inputs, const std::vector &outputs); + + private: + TypeId dtype_{kTypeUnknown}; + int64_t start_; + int64_t limit_; + int64_t delta_; +}; + +MS_REG_CPU_KERNEL(Range, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), RangeCPUKernel); +MS_REG_CPU_KERNEL(Range, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), RangeCPUKernel); +MS_REG_CPU_KERNEL(Range, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + RangeCPUKernel); +MS_REG_CPU_KERNEL(Range, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), + RangeCPUKernel); + +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RANGE_CPU_KERNEL_H_ diff --git a/mindspore/nn/layer/math.py b/mindspore/nn/layer/math.py index 878a2c1f42..3904a495c8 100644 --- a/mindspore/nn/layer/math.py +++ b/mindspore/nn/layer/math.py @@ -116,7 +116,7 @@ class Range(Cell): Tensor, the dtype is int if the dtype of `start`, `limit` and `delta` all are int. Otherwise, dtype is float. Supported Platforms: - ``Ascend`` + ``Ascend`` ``CPU`` Examples: >>> net = nn.Range(1, 8, 2) diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index be56f49468..6edf59067c 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -3078,7 +3078,7 @@ class GatherNd(PrimitiveWithInfer): Tensor, has the same type as `input_x` and the shape is indices_shape[:-1] + x_shape[indices_shape[-1]:]. Supported Platforms: - ``Ascend`` ``GPU`` + ``Ascend`` ``GPU`` ``CPU`` Examples: >>> input_x = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32) diff --git a/mindspore/ops/operations/math_ops.py b/mindspore/ops/operations/math_ops.py index 87999a1eab..88aee88cc5 100644 --- a/mindspore/ops/operations/math_ops.py +++ b/mindspore/ops/operations/math_ops.py @@ -2698,7 +2698,7 @@ class Greater(_LogicBinaryOp): Tensor, the shape is the same as the one after broadcasting,and the data type is bool. Supported Platforms: - ``Ascend`` ``GPU`` + ``Ascend`` ``GPU`` ``CPU`` Examples: >>> input_x = Tensor(np.array([1, 2, 3]), mindspore.int32) @@ -2739,7 +2739,7 @@ class GreaterEqual(_LogicBinaryOp): Tensor, the shape is the same as the one after broadcasting,and the data type is bool. Supported Platforms: - ``Ascend`` ``GPU`` + ``Ascend`` ``GPU`` ``CPU`` Examples: >>> input_x = Tensor(np.array([1, 2, 3]), mindspore.int32) diff --git a/tests/st/ops/cpu/test_gathernd_op.py b/tests/st/ops/cpu/test_gathernd_op.py new file mode 100644 index 0000000000..772770083d --- /dev/null +++ b/tests/st/ops/cpu/test_gathernd_op.py @@ -0,0 +1,188 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P + +context.set_context(mode=context.GRAPH_MODE, device_target='CPU') + + +class OpNetWrapper(nn.Cell): + def __init__(self, op): + super(OpNetWrapper, self).__init__() + self.op = op + + def construct(self, *inputs): + return self.op(*inputs) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_case1_basic_func(): + op = P.GatherNd() + op_wrapper = OpNetWrapper(op) + + indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32) + params = Tensor(np.array([[0, 1], [2, 3]]), mindspore.float32) + outputs = op_wrapper(params, indices) + print(outputs) + expected = [0, 3] + assert np.allclose(outputs.asnumpy(), np.array(expected)) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_case2_indices_to_matrix(): + op = P.GatherNd() + op_wrapper = OpNetWrapper(op) + + indices = Tensor(np.array([[1], [0]]), mindspore.int32) + params = Tensor(np.array([[0, 1], [2, 3]]), mindspore.float32) + outputs = op_wrapper(params, indices) + print(outputs) + expected = [[2, 3], [0, 1]] + assert np.allclose(outputs.asnumpy(), np.array(expected)) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_case3_indices_to_3d_tensor(): + op = P.GatherNd() + op_wrapper = OpNetWrapper(op) + + indices = Tensor(np.array([[1]]), mindspore.int32) # (1, 1) + params = Tensor(np.array([[[0, 1], [2, 3]], + [[4, 5], [6, 7]]]), mindspore.float32) # (2, 2, 2) + outputs = op_wrapper(params, indices) + print(outputs) + expected = [[[4, 5], [6, 7]]] # (1, 2, 2) + assert np.allclose(outputs.asnumpy(), np.array(expected)) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_case4(): + op = P.GatherNd() + op_wrapper = OpNetWrapper(op) + + indices = Tensor(np.array([[0, 1], [1, 0]]), mindspore.int32) # (2, 2) + params = Tensor(np.array([[[0, 1], [2, 3]], + [[4, 5], [6, 7]]]), mindspore.float32) # (2, 2, 2) + outputs = op_wrapper(params, indices) + print(outputs) + expected = [[2, 3], [4, 5]] # (2, 2) + assert np.allclose(outputs.asnumpy(), np.array(expected)) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_case5(): + op = P.GatherNd() + op_wrapper = OpNetWrapper(op) + + indices = Tensor(np.array([[0, 0, 1], [1, 0, 1]]), mindspore.int32) # (2, 3) + params = Tensor(np.array([[[0, 1], [2, 3]], + [[4, 5], [6, 7]]]), mindspore.float32) # (2, 2, 2) + outputs = op_wrapper(params, indices) + print(outputs) + expected = [1, 5] # (2,) + assert np.allclose(outputs.asnumpy(), np.array(expected)) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_case6(): + op = P.GatherNd() + op_wrapper = OpNetWrapper(op) + + indices = Tensor(np.array([[[0, 0]], [[0, 1]]]), mindspore.int32) # (2, 1, 2) + params = Tensor(np.array([[[0, 1], [2, 3]], + [[4, 5], [6, 7]]]), mindspore.float32) # (2, 2, 2) + outputs = op_wrapper(params, indices) + print(outputs) + expected = [[[0, 1]], [[2, 3]]] # (2, 1, 2) + assert np.allclose(outputs.asnumpy(), np.array(expected)) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_case7(): + op = P.GatherNd() + op_wrapper = OpNetWrapper(op) + + indices = Tensor(np.array([[[1]], [[0]]]), mindspore.int32) # (2, 1, 1) + params = Tensor(np.array([[[0, 1], [2, 3]], + [[4, 5], [6, 7]]]), mindspore.float32) # (2, 2, 2) + outputs = op_wrapper(params, indices) + print(outputs) + expected = [[[[4, 5], [6, 7]]], [[[0, 1], [2, 3]]]] # (2, 1, 2, 2) + assert np.allclose(outputs.asnumpy(), np.array(expected)) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_case8(): + op = P.GatherNd() + op_wrapper = OpNetWrapper(op) + + indices = Tensor(np.array([[[0, 1], [1, 0]], [[0, 0], [1, 1]]]), mindspore.int32) # (2, 2, 2) + params = Tensor(np.array([[[0, 1], [2, 3]], + [[4, 5], [6, 7]]]), mindspore.float32) # (2, 2, 2) + outputs = op_wrapper(params, indices) + print(outputs) + expected = [[[2, 3], [4, 5]], [[0, 1], [6, 7]]] # (2, 2, 2) + assert np.allclose(outputs.asnumpy(), np.array(expected)) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_case9(): + op = P.GatherNd() + op_wrapper = OpNetWrapper(op) + + indices = Tensor(np.array([[[0, 0, 1], [1, 0, 1]], [[0, 1, 1], [1, 1, 0]]]), mindspore.int32) # (2, 2, 3) + params = Tensor(np.array([[[0, 1], [2, 3]], + [[4, 5], [6, 7]]]), mindspore.int64) # (2, 2, 2) + outputs = op_wrapper(params, indices) + print(outputs) + expected = [[1, 5], [3, 6]] # (2, 2, 2) + assert np.allclose(outputs.asnumpy(), np.array(expected)) + + +if __name__ == '__main__': + test_case1_basic_func() + test_case2_indices_to_matrix() + test_case3_indices_to_3d_tensor() + test_case4() + test_case5() + test_case6() + test_case7() + test_case8() + test_case9() diff --git a/tests/st/ops/cpu/test_greater_equal_op.py b/tests/st/ops/cpu/test_greater_equal_op.py new file mode 100644 index 0000000000..ce6827632f --- /dev/null +++ b/tests/st/ops/cpu/test_greater_equal_op.py @@ -0,0 +1,70 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P + +context.set_context(mode=context.GRAPH_MODE, device_target='CPU') + + +class OpNetWrapper(nn.Cell): + def __init__(self, op): + super(OpNetWrapper, self).__init__() + self.op = op + + def construct(self, *inputs): + return self.op(*inputs) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_int32(): + op = P.GreaterEqual() + op_wrapper = OpNetWrapper(op) + + input_x = Tensor(np.array([1, 2, 3]).astype(np.int32)) + input_y = Tensor(np.array([3, 2, 1]).astype(np.int32)) + outputs = op_wrapper(input_x, input_y) + + print(outputs) + assert outputs.shape == (3,) + assert np.allclose(outputs.asnumpy(), [False, True, True]) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_float32(): + op = P.GreaterEqual() + op_wrapper = OpNetWrapper(op) + + input_x = Tensor(np.array([1, 2, -1]).astype(np.float32)) + input_y = Tensor(np.array([-3, 2, -1]).astype(np.float32)) + outputs = op_wrapper(input_x, input_y) + + print(outputs) + assert outputs.shape == (3,) + assert np.allclose(outputs.asnumpy(), [True, True, True]) + + +if __name__ == '__main__': + test_int32() + test_float32() diff --git a/tests/st/ops/cpu/test_greater_op.py b/tests/st/ops/cpu/test_greater_op.py new file mode 100644 index 0000000000..91ca9b78c8 --- /dev/null +++ b/tests/st/ops/cpu/test_greater_op.py @@ -0,0 +1,70 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P + +context.set_context(mode=context.GRAPH_MODE, device_target='CPU') + + +class OpNetWrapper(nn.Cell): + def __init__(self, op): + super(OpNetWrapper, self).__init__() + self.op = op + + def construct(self, *inputs): + return self.op(*inputs) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_int32(): + op = P.Greater() + op_wrapper = OpNetWrapper(op) + + input_x = Tensor(np.array([1, 2, 3]).astype(np.int32)) + input_y = Tensor(np.array([3, 2, 1]).astype(np.int32)) + outputs = op_wrapper(input_x, input_y) + + print(outputs) + assert outputs.shape == (3,) + assert np.allclose(outputs.asnumpy(), [False, False, True]) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_float32(): + op = P.Greater() + op_wrapper = OpNetWrapper(op) + + input_x = Tensor(np.array([1, 2, -1]).astype(np.float32)) + input_y = Tensor(np.array([-3, 2, -1]).astype(np.float32)) + outputs = op_wrapper(input_x, input_y) + + print(outputs) + assert outputs.shape == (3,) + assert np.allclose(outputs.asnumpy(), [True, False, False]) + + +if __name__ == '__main__': + test_int32() + test_float32() diff --git a/tests/st/ops/cpu/test_range_op.py b/tests/st/ops/cpu/test_range_op.py new file mode 100644 index 0000000000..4acaae5baf --- /dev/null +++ b/tests/st/ops/cpu/test_range_op.py @@ -0,0 +1,62 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn + +context.set_context(mode=context.GRAPH_MODE, device_target='CPU') + + +class OpNetWrapper(nn.Cell): + def __init__(self, op): + super(OpNetWrapper, self).__init__() + self.op = op + + def construct(self, *inputs): + return self.op(*inputs) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_int(): + op = nn.Range(0, 100, 10) + op_wrapper = OpNetWrapper(op) + + outputs = op_wrapper() + print(outputs) + assert outputs.shape == (10,) + assert np.allclose(outputs.asnumpy(), range(0, 100, 10)) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_float(): + op = nn.Range(10., 100., 20.) + op_wrapper = OpNetWrapper(op) + + outputs = op_wrapper() + print(outputs) + assert outputs.shape == (5,) + assert np.allclose(outputs.asnumpy(), [10., 30., 50., 70., 90.]) + + +if __name__ == '__main__': + test_int() + test_float()