From 16f0688230436c1ec1c0a19ede96c9f5aa99551a Mon Sep 17 00:00:00 2001 From: wilfChen Date: Wed, 6 May 2020 19:43:37 +0800 Subject: [PATCH] gpu support broadcast kernels --- .../kernel/gpu/cuda_impl/broadcast_impl.cu | 138 ++++++++++++++++++ .../kernel/gpu/cuda_impl/broadcast_impl.cuh | 40 +++++ .../kernel/gpu/math/binary_op_gpu_kernel.cc | 8 - .../kernel/gpu/math/binary_op_gpu_kernel.h | 19 +-- .../kernel/gpu/math/broadcast_gpu_kernel.cc | 40 +++++ .../kernel/gpu/math/broadcast_gpu_kernel.h | 132 +++++++++++++++++ tests/st/ops/gpu/test_broadcast_op.py | 81 ++++++++++ 7 files changed, 433 insertions(+), 25 deletions(-) create mode 100644 mindspore/ccsrc/kernel/gpu/cuda_impl/broadcast_impl.cu create mode 100644 mindspore/ccsrc/kernel/gpu/cuda_impl/broadcast_impl.cuh create mode 100644 mindspore/ccsrc/kernel/gpu/math/broadcast_gpu_kernel.cc create mode 100644 mindspore/ccsrc/kernel/gpu/math/broadcast_gpu_kernel.h create mode 100644 tests/st/ops/gpu/test_broadcast_op.py diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/broadcast_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/broadcast_impl.cu new file mode 100644 index 0000000000..17adb738e1 --- /dev/null +++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/broadcast_impl.cu @@ -0,0 +1,138 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "kernel/gpu/cuda_impl/broadcast_impl.cuh" +#include "device/gpu/cuda_common.h" + +template +struct GreaterFunc { + __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return lhs > rhs ? true : false; } +}; + +template +struct LessFunc { + __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return lhs < rhs ? true : false; } +}; + +template +struct MinimumFunc { + __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return lhs < rhs ? lhs : rhs; } +}; + +template +struct MaximumFunc { + __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return lhs > rhs ? lhs : rhs; } +}; + +template +struct PowerFunc { + __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return pow(lhs, rhs); } +}; + +__device__ __forceinline__ int Index(const int &index, const int &dim) { return dim == 1 ? 0 : index; } + +template +__device__ __forceinline__ void BroadcastOperator(const int &l0, const int &l1, const int &l2, const int &l3, + const int &r0, const int &r1, const int &r2, const int &r3, + const int &d0, const int &d1, const int &d2, const int &d3, + const T *input0, const T *input1, S *output) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < d0 * d1 * d2 * d3; pos += blockDim.x * gridDim.x) { + int i = pos / (d1 * d2 * d3) % d0; + int j = pos / (d2 * d3) % d1; + int k = pos / d3 % d2; + int l = pos % d3; + + int l_index = Index(i, l0) * l1 * l2 * l3 + Index(j, l1) * l2 * l3 + Index(k, l2) * l3 + Index(l, l3); + int r_index = Index(i, r0) * r1 * r2 * r3 + Index(j, r1) * r2 * r3 + Index(k, r2) * r3 + Index(l, r3); + output[pos] = Func()(input0[l_index], input1[r_index]); + } +} + +template +__global__ void BroadcastKernel(const int l0, const int l1, const int l2, const int l3, const int r0, const int r1, + const int r2, const int r3, const int d0, const int d1, const int d2, const int d3, + enum BroadcastOpType op, const T *input0, const T *input1, S *output) { + switch (op) { + case BROADCAST_TYPE_GREATER: + return BroadcastOperator>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1, + output); + case BROADCAST_TYPE_LESS: + return BroadcastOperator>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1, + output); + case BROADCAST_TYPE_MINIMUM: + return BroadcastOperator>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1, + output); + case BROADCAST_TYPE_MAXIMUM: + return BroadcastOperator>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1, + output); + case BROADCAST_TYPE_POWER: + return BroadcastOperator>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1, + output); + } +} + +template +void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1, const int &r2, + const int &r3, const int &d0, const int &d1, const int &d2, const int &d3, enum BroadcastOpType op, + const T *input0, const T *input1, S *output, cudaStream_t stream) { + int size = d0 * d1 * d2 * d3; + BroadcastKernel<<>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, op, + input0, input1, output); +} + +template +__device__ __forceinline__ void NoBroadcastOperator(const int &nums, const T *input0, const T *input1, S *output) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < nums; pos += blockDim.x * gridDim.x) { + output[pos] = Func()(input0[pos], input1[pos]); + } +} + +template +__global__ void NoBroadcastKernel(const int nums, enum BroadcastOpType op, const T *input0, const T *input1, + S *output) { + switch (op) { + case BROADCAST_TYPE_GREATER: + return NoBroadcastOperator>(nums, input0, input1, output); + case BROADCAST_TYPE_LESS: + return NoBroadcastOperator>(nums, input0, input1, output); + case BROADCAST_TYPE_MINIMUM: + return NoBroadcastOperator>(nums, input0, input1, output); + case BROADCAST_TYPE_MAXIMUM: + return NoBroadcastOperator>(nums, input0, input1, output); + case BROADCAST_TYPE_POWER: + return NoBroadcastOperator>(nums, input0, input1, output); + } +} + +template +void NoBroadcast(const int &nums, enum BroadcastOpType op, const T *input0, const T *input1, S *output, + cudaStream_t stream) { + NoBroadcastKernel<<>>(nums, op, input0, input1, output); +} + +template void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1, + const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3, + enum BroadcastOpType op, const float *input0, const float *input1, bool *output, + cudaStream_t stream); +template void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1, + const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3, + enum BroadcastOpType op, const float *input0, const float *input1, float *output, + cudaStream_t stream); + +template void NoBroadcast(const int &nums, enum BroadcastOpType op, const float *input0, const float *input1, + bool *output, cudaStream_t stream); +template void NoBroadcast(const int &nums, enum BroadcastOpType op, const float *input0, const float *input1, + float *output, cudaStream_t stream); diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/broadcast_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/broadcast_impl.cuh new file mode 100644 index 0000000000..e67a8a6e4d --- /dev/null +++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/broadcast_impl.cuh @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BROADCAST_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BROADCAST_H_ + +#include "device/gpu/cuda_common.h" + +enum BroadcastOpType { + BROADCAST_TYPE_GREATER = 0, + BROADCAST_TYPE_LESS = 1, + BROADCAST_TYPE_MAXIMUM = 2, + BROADCAST_TYPE_MINIMUM = 3, + BROADCAST_TYPE_POWER = 4, + BROADCAST_TYPE_INVALID = 0xffffffff, +}; + +template +void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1, const int &r2, + const int &r3, const int &d0, const int &d1, const int &d2, const int &d3, enum BroadcastOpType op, + const T *input0, const T *input1, S *output, cudaStream_t stream); + +template +void NoBroadcast(const int &size, enum BroadcastOpType op, const T *input0, const T *input1, S *output, + cudaStream_t stream); + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BROADCAST_H_ diff --git a/mindspore/ccsrc/kernel/gpu/math/binary_op_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/math/binary_op_gpu_kernel.cc index 4fe2acb726..56a0905e4e 100644 --- a/mindspore/ccsrc/kernel/gpu/math/binary_op_gpu_kernel.cc +++ b/mindspore/ccsrc/kernel/gpu/math/binary_op_gpu_kernel.cc @@ -38,13 +38,5 @@ MS_REG_GPU_KERNEL_ONE( MS_REG_GPU_KERNEL_ONE( Sub, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), BinaryOpGpuKernel, half) -MS_REG_GPU_KERNEL_ONE( - Maximum, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - BinaryOpGpuKernel, float) -MS_REG_GPU_KERNEL_ONE( - Maximum, - KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - BinaryOpGpuKernel, half) } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/math/binary_op_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/math/binary_op_gpu_kernel.h index 3bf141fc0b..9d6a45ac0d 100644 --- a/mindspore/ccsrc/kernel/gpu/math/binary_op_gpu_kernel.h +++ b/mindspore/ccsrc/kernel/gpu/math/binary_op_gpu_kernel.h @@ -27,16 +27,9 @@ #include "kernel/gpu/kernel_constants.h" namespace mindspore { namespace kernel { -enum BinaryOpType { - BINARY_OP_ADD = 0, - BINARY_OP_SUB, - BINARY_OP_MUL, - BINARY_OP_DIV, - BINARY_OP_MAX, - BINARY_OP_INVALID_TYPE = 255 -}; +enum BinaryOpType { BINARY_OP_ADD = 0, BINARY_OP_SUB, BINARY_OP_MUL, BINARY_OP_DIV, BINARY_OP_INVALID_TYPE = 255 }; static const std::map kBinaryOpTypeMap = { - {"Sub", BINARY_OP_SUB}, {"Mul", BINARY_OP_MUL}, {"RealDiv", BINARY_OP_DIV}, {"Maximum", BINARY_OP_MAX}}; + {"Sub", BINARY_OP_SUB}, {"Mul", BINARY_OP_MUL}, {"RealDiv", BINARY_OP_DIV}}; template class BinaryOpGpuKernel : public GpuKernel { public: @@ -88,10 +81,6 @@ class BinaryOpGpuKernel : public GpuKernel { inputB_addr = workspace_addr; break; } - case BINARY_OP_MAX: { - inputB_addr = input_addr2; - break; - } default: { MS_LOG(EXCEPTION) << "Binary operation " << binary_op_type_ << " is not supported."; } @@ -209,10 +198,6 @@ class BinaryOpGpuKernel : public GpuKernel { tensor_op_ = CUDNN_OP_TENSOR_ADD; break; } - case BINARY_OP_MAX: { - tensor_op_ = CUDNN_OP_TENSOR_MAX; - break; - } default: { MS_LOG(EXCEPTION) << "Binary operation " << binary_op_type_ << " is not supported."; } diff --git a/mindspore/ccsrc/kernel/gpu/math/broadcast_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/math/broadcast_gpu_kernel.cc new file mode 100644 index 0000000000..1761597c7b --- /dev/null +++ b/mindspore/ccsrc/kernel/gpu/math/broadcast_gpu_kernel.cc @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "kernel/gpu/math/broadcast_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_TWO( + Greater, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool), + BroadcastOpGpuKernel, float, bool) +MS_REG_GPU_KERNEL_TWO( + Less, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool), + BroadcastOpGpuKernel, float, bool) +MS_REG_GPU_KERNEL_TWO( + Maximum, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + BroadcastOpGpuKernel, float, float) +MS_REG_GPU_KERNEL_TWO( + Minimum, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + BroadcastOpGpuKernel, float, float) +MS_REG_GPU_KERNEL_TWO( + Pow, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + BroadcastOpGpuKernel, float, float) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/math/broadcast_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/math/broadcast_gpu_kernel.h new file mode 100644 index 0000000000..dfb0487ee4 --- /dev/null +++ b/mindspore/ccsrc/kernel/gpu/math/broadcast_gpu_kernel.h @@ -0,0 +1,132 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_BROADCAST_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_BROADCAST_GPU_KERNEL_H_ + +#include +#include +#include +#include +#include "kernel/gpu/gpu_kernel.h" +#include "kernel/gpu/gpu_kernel_factory.h" +#include "kernel/gpu/cuda_impl/broadcast_impl.cuh" +#include "kernel/gpu/kernel_constants.h" +namespace mindspore { +namespace kernel { + +template +class BroadcastOpGpuKernel : public GpuKernel { + public: + BroadcastOpGpuKernel() + : op_type_(BROADCAST_TYPE_INVALID), need_broadcast_(false), input1_num_(1), input2_num_(1), output_num_(1) {} + ~BroadcastOpGpuKernel() override = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs, uintptr_t stream_ptr) override { + T *lhs = GetDeviceAddress(inputs, 0); + T *rhs = GetDeviceAddress(inputs, 1); + S *output = GetDeviceAddress(outputs, 0); + + if (need_broadcast_) { + Broadcast(lhs_shape_[0], lhs_shape_[1], lhs_shape_[2], lhs_shape_[3], rhs_shape_[0], rhs_shape_[1], rhs_shape_[2], + rhs_shape_[3], output_shape_[0], output_shape_[1], output_shape_[2], output_shape_[3], op_type_, lhs, + rhs, output, reinterpret_cast(stream_ptr)); + } else { + NoBroadcast(output_num_, op_type_, lhs, rhs, output, reinterpret_cast(stream_ptr)); + } + + return true; + } + bool Init(const CNodePtr &kernel_node) override { + GetOpType(kernel_node); + auto shape1 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + auto shape2 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); + auto shape3 = AnfAlgo::GetOutputInferShape(kernel_node, 0); + need_broadcast_ = IsBroadcast(shape1, shape2); + if (need_broadcast_ && shape1.size() > 4) { + MS_LOG(EXCEPTION) << "Broadcast operation not support dim greater than 4"; + } + + for (size_t i = 0; i < shape1.size(); i++) { + lhs_shape_[i] = shape1[i]; + rhs_shape_[i] = shape2[i]; + output_shape_[i] = shape3[i]; + + input1_num_ *= shape1[i]; + input2_num_ *= shape2[i]; + output_num_ *= shape3[i]; + } + + InitSizeLists(); + return true; + } + + protected: + void InitResource() override { return; } + void InitSizeLists() override { + input_size_list_.push_back(input1_num_ * sizeof(T)); + input_size_list_.push_back(input2_num_ * sizeof(T)); + output_size_list_.push_back(output_num_ * sizeof(S)); + } + + private: + void GetOpType(const CNodePtr &kernel_node) { + std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node); + + static std::map kBroadcastTypeMap = { + {"Greater", BROADCAST_TYPE_GREATER}, {"Less", BROADCAST_TYPE_LESS}, {"Maximum", BROADCAST_TYPE_MAXIMUM}, + {"Minimum", BROADCAST_TYPE_MINIMUM}, {"Pow", BROADCAST_TYPE_POWER}, + }; + + auto iter = kBroadcastTypeMap.find(kernel_name); + if (iter == kBroadcastTypeMap.end()) { + MS_LOG(EXCEPTION) << "operation " << kernel_name << " is not supported."; + } else { + op_type_ = iter->second; + } + } + + bool IsBroadcast(const std::vector &lhs, const std::vector &rhs) { + for (size_t i = 0; i < lhs.size(); i++) { + if (lhs[i] != rhs[i]) { + return true; + } + } + return false; + } + + BroadcastOpType op_type_; + bool need_broadcast_; + int input1_num_; + int input2_num_; + int output_num_; + int lhs_shape_[4] = {1, 1, 1, 1}; + int rhs_shape_[4] = {1, 1, 1, 1}; + int output_shape_[4] = {1, 1, 1, 1}; + + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_BINARYOP_GPU_KERNEL_H_ diff --git a/tests/st/ops/gpu/test_broadcast_op.py b/tests/st/ops/gpu/test_broadcast_op.py new file mode 100644 index 0000000000..2baa72ad6f --- /dev/null +++ b/tests/st/ops/gpu/test_broadcast_op.py @@ -0,0 +1,81 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import pytest +from mindspore.ops import operations as P +from mindspore.nn import Cell +from mindspore.common.tensor import Tensor +import mindspore.common.dtype as mstype +import mindspore.context as context +import numpy as np + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_nobroadcast(): + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + + x1_np = np.random.rand(10, 20).astype(np.float32) + x2_np = np.random.rand(10, 20).astype(np.float32) + + output_ms = P.Minimum()(Tensor(x1_np), Tensor(x2_np)) + output_np = np.minimum(x1_np, x2_np) + assert np.allclose(output_ms.asnumpy(), output_np) + + output_ms = P.Maximum()(Tensor(x1_np), Tensor(x2_np)) + output_np = np.maximum(x1_np, x2_np) + assert np.allclose(output_ms.asnumpy(), output_np) + + output_ms = P.Greater()(Tensor(x1_np), Tensor(x2_np)) + output_np = x1_np > x2_np + assert np.allclose(output_ms.asnumpy(), output_np) + + output_ms = P.Less()(Tensor(x1_np), Tensor(x2_np)) + output_np = x1_np < x2_np + assert np.allclose(output_ms.asnumpy(), output_np) + + output_ms = P.Pow()(Tensor(x1_np), Tensor(x2_np)) + output_np = np.power(x1_np, x2_np) + assert np.allclose(output_ms.asnumpy(), output_np) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_broadcast(): + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + + x1_np = np.random.rand(3, 1, 5, 1).astype(np.float32) + x2_np = np.random.rand(1, 4, 1, 6).astype(np.float32) + + output_ms = P.Minimum()(Tensor(x1_np), Tensor(x2_np)) + output_np = np.minimum(x1_np, x2_np) + assert np.allclose(output_ms.asnumpy(), output_np) + + output_ms = P.Maximum()(Tensor(x1_np), Tensor(x2_np)) + output_np = np.maximum(x1_np, x2_np) + assert np.allclose(output_ms.asnumpy(), output_np) + + output_ms = P.Greater()(Tensor(x1_np), Tensor(x2_np)) + output_np = x1_np > x2_np + assert np.allclose(output_ms.asnumpy(), output_np) + + output_ms = P.Less()(Tensor(x1_np), Tensor(x2_np)) + output_np = x1_np < x2_np + assert np.allclose(output_ms.asnumpy(), output_np) + + output_ms = P.Pow()(Tensor(x1_np), Tensor(x2_np)) + output_np = np.power(x1_np, x2_np) + assert np.allclose(output_ms.asnumpy(), output_np)