From ad8a786b07ebbdce8b3d8060e62c826842164b9f Mon Sep 17 00:00:00 2001 From: TFbunny Date: Thu, 23 Jul 2020 17:38:58 -0400 Subject: [PATCH] add GPU support to RandomChoiceWithMask --- .../cuda_impl/random_choice_with_mask_impl.cu | 265 ++++++++++++++++++ .../random_choice_with_mask_impl.cuh | 34 +++ .../random_choice_with_mask_gpu_kernel.cc | 26 ++ .../random_choice_with_mask_gpu_kernel.h | 129 +++++++++ mindspore/ops/operations/random_ops.py | 7 +- .../ops/gpu/test_random_choice_with_mask.py | 86 ++++++ 6 files changed, 544 insertions(+), 3 deletions(-) create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_choice_with_mask_impl.cu create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_choice_with_mask_impl.cuh create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/random/random_choice_with_mask_gpu_kernel.cc create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/random/random_choice_with_mask_gpu_kernel.h create mode 100644 tests/st/ops/gpu/test_random_choice_with_mask.py diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_choice_with_mask_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_choice_with_mask_impl.cu new file mode 100644 index 0000000000..6ce1fda22b --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_choice_with_mask_impl.cu @@ -0,0 +1,265 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/cuda_impl/random_choice_with_mask_impl.cuh" +#include + +int RcwmRoundUpPower2(int v) { + v--; + v |= v >> 1; + v |= v >> 2; + v |= v >> 4; + v |= v >> 8; + v |= v >> 16; + v++; + return v; +} + +template +__inline__ __device__ void Swap(T *lhs, T *rhs) { + T tmp = lhs[0]; + lhs[0] = rhs[0]; + rhs[0] = tmp; +} + +template +__global__ void InitArray(const int input_size, const int ceil_power2, const T *input, S *mask_buff, S *rank_buff) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < ceil_power2; pos += blockDim.x * gridDim.x) { + mask_buff[pos] = (pos < input_size) ? static_cast(input[pos]) : 0; + rank_buff[pos] = (pos < input_size && input[pos] != false) ? pos : (ceil_power2 + 1); + } +} + +template +__device__ void WarpReduce(volatile T *sdata, size_t tid) { + if (blockSize >= 64) sdata[tid] += sdata[tid + 32]; + if (blockSize >= 32) sdata[tid] += sdata[tid + 16]; + if (blockSize >= 16) sdata[tid] += sdata[tid + 8]; + if (blockSize >= 8) sdata[tid] += sdata[tid + 4]; + if (blockSize >= 4) sdata[tid] += sdata[tid + 2]; + if (blockSize >= 2) sdata[tid] += sdata[tid + 1]; +} + +template +__global__ void ReductionSum(T *g_idata, T *g_odata, size_t n) { + __shared__ T sdata[blockSize]; + + size_t tid = threadIdx.x; + size_t i = blockIdx.x * (blockSize) + tid; + size_t gridSize = blockSize * gridDim.x; + sdata[tid] = 0; + + while (i < n) { + sdata[tid] += g_idata[i]; + i += gridSize; + } + + __syncthreads(); + + if (blockSize >= 1024) { + if (tid < 512) { + sdata[tid] += sdata[tid + 512]; + } + __syncthreads(); + } + if (blockSize >= 512) { + if (tid < 256) { + sdata[tid] += sdata[tid + 256]; + } + __syncthreads(); + } + if (blockSize >= 256) { + if (tid < 128) { + sdata[tid] += sdata[tid + 128]; + } + __syncthreads(); + } + if (blockSize >= 128) { + if (tid < 64) { + sdata[tid] += sdata[tid + 64]; + } + __syncthreads(); + } + + if (tid < 32) WarpReduce(sdata, tid); + if (tid == 0) g_odata[blockIdx.x] = sdata[0]; +} + +template +__global__ void Reshape2Index(const int input_size, const int input_shape_size, const int d1, const int d2, + const int d3, const int d4, const int d5, const T *input, S *output_index) { + int pos_array[MAX_DIMENSION]; + int index_pos; + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < input_size; pos += blockDim.x * gridDim.x) { + pos_array[0] = pos / (d2 * d3 * d4 * d5) % d1; + pos_array[1] = pos / (d3 * d4 * d5) % d2; + pos_array[2] = pos / (d4 * d5) % d3; + pos_array[3] = pos / (d5) % d4; + pos_array[4] = pos % d5; + + index_pos = pos * input_shape_size; + if (input[pos] == false) { + for (int i = 0; i < input_shape_size; i++) { + output_index[index_pos++] = 0; + } + } else { + for (int i = MAX_DIMENSION - input_shape_size; i < MAX_DIMENSION; i++) { + output_index[index_pos++] = pos_array[i]; + } + } + } +} + +template +__global__ void Copy(const T *src, T *dst, const int n) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < n; pos += blockDim.x * gridDim.x) { + dst[pos] = src[pos]; + } +} + +template +__global__ void Sort(const int ceil_power2, T *rank_buff) { + for (size_t i = 2; i <= ceil_power2; i <<= 1) { + for (size_t j = (i >> 1); j > 0; j >>= 1) { + for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < ceil_power2; tid += blockDim.x * gridDim.x) { + size_t tid_comp = tid ^ j; + if (tid_comp > tid) { + if ((tid & i) == 0) { + if (rank_buff[tid] > rank_buff[tid_comp]) { + Swap(&rank_buff[tid], &rank_buff[tid_comp]); + } + } else { + if (rank_buff[tid] < rank_buff[tid_comp]) { + Swap(&rank_buff[tid], &rank_buff[tid_comp]); + } + } + } + } + __syncthreads(); + } + } +} + +__global__ void SrandInit(const int ceil_power2, curandState *globalState, const int seedc) { + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < ceil_power2; i += blockDim.x * gridDim.x) { + curand_init(seedc, i, 0, &globalState[i]); + } +} + +template +__global__ void Shuffle(const int ceil_power2, curandState *globalState, T *rank_buff) { + int limit = ceil_power2 + 1; + int value; + for (size_t i = 2; i <= ceil_power2; i <<= 1) { + for (size_t j = (i >> 1); j > 0; j >>= 1) { + for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < ceil_power2; tid += blockDim.x * gridDim.x) { + size_t tid_comp = tid ^ j; + if (tid_comp > tid) { + value = static_cast(curand(&globalState[tid])); + if (value & 1) { + if (rank_buff[tid] != limit && rank_buff[tid_comp] != limit) { + Swap(&rank_buff[tid], &rank_buff[tid_comp]); + } + } + } + } + __syncthreads(); + } + } +} + +template +__global__ void MoveToOutput(const int input_shape_size, const int count, const T *input, S *output_index, + T *output_mask, S *index_buff, S *rank_buff, S *Tnum_buff) { + int Tnum = static_cast(Tnum_buff[0]); + int idx = 0; + int pos; + if (count <= Tnum) { + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < count; i += blockDim.x * gridDim.x) { + idx = rank_buff[i]; + pos = i; + output_mask[pos] = input[idx]; + pos *= input_shape_size; + idx *= input_shape_size; + for (size_t j = 0; j < input_shape_size; j++) { + output_index[pos] = index_buff[idx]; + pos++; + idx++; + } + } + } else { + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < count; i += blockDim.x * gridDim.x) { + if (i < Tnum) { + idx = rank_buff[i]; + pos = i; + output_mask[pos] = input[idx]; + pos *= input_shape_size; + idx *= input_shape_size; + for (size_t j = 0; j < input_shape_size; j++) { + output_index[pos] = index_buff[idx]; + pos++; + idx++; + } + } else { + pos = i; + output_mask[pos] = static_cast(0); + pos *= input_shape_size; + for (size_t j = 0; j < input_shape_size; j++) { + output_index[pos] = static_cast(0); + pos++; + } + } + } + } +} + +template +void CalRandomChoiceWithMask(const int &input_size, const int &input_shape_size, const int &d1, const int &d2, + const int &d3, const int &d4, const int &d5, const int &seedc, const int &count, + const T *input, S *output_index, T *output_mask, S *index_buff, S *mask_buff, S *rank_buff, + S *Tnum_buff, S *tmp_buff, curandState *globalState, cudaStream_t stream) { + int ceil_power2 = RcwmRoundUpPower2(input_size); + + InitArray<<>>(input_size, ceil_power2, input, mask_buff, rank_buff); + + size_t BLOCKNUM; + size_t n = ceil_power2; + Copy<<>>(mask_buff, tmp_buff, ceil_power2); + do { + BLOCKNUM = std::ceil(static_cast(n) / BLOCKSIZE); + ReductionSum<<>>(tmp_buff, Tnum_buff, n); + Copy<<>>(Tnum_buff, tmp_buff, BLOCKNUM); + n = BLOCKNUM; + } while (n > BLOCKSIZE); + if (n > 1) ReductionSum<<<1, BLOCKSIZE, 0, stream>>>(Tnum_buff, Tnum_buff, n); + + Reshape2Index<<>>(input_size, input_shape_size, d1, d2, d3, d4, d5, + input, index_buff); + + Sort<<>>(ceil_power2, rank_buff); + + SrandInit<<>>(ceil_power2, globalState, seedc); + Shuffle<<>>(ceil_power2, globalState, rank_buff); + + MoveToOutput<<>>(input_shape_size, count, input, output_index, output_mask, + index_buff, rank_buff, Tnum_buff); +} + +template void CalRandomChoiceWithMask(const int &input_size, const int &input_shape_size, const int &d1, const int &d2, + const int &d3, const int &d4, const int &d5, const int &seedc, const int &count, + const bool *input, int *output_index, bool *output_mask, int *index_buff, + int *mask_buff, int *rank_buff, int *Tnum_buff, int *tmp_buff, + curandState *globalState, cudaStream_t stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_choice_with_mask_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_choice_with_mask_impl.cuh new file mode 100644 index 0000000000..bb654e4b58 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_choice_with_mask_impl.cuh @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_RANDOM_CHOICE_WITH_MASK_IMPL_CUH_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_RANDOM_CHOICE_WITH_MASK_IMPL_CUH_ + +#include +#include +#include "runtime/device/gpu/cuda_common.h" +#define BLOCKSIZE 256 +#define MAX_DIMENSION 5 + +template +void CalRandomChoiceWithMask(const int &input_size, const int &input_shape_size, const int &d1, const int &d2, + const int &d3, const int &d4, const int &d5, const int &seedc, const int &count, + const T *input, S *output_index, T *output_mask, S *index_buff, S *mask_buff, S *rank_buff, + S *Tnum_buff, S *tmp_buff, curandState *globalState, cudaStream_t stream); + +int RcwmRoundUpPower2(int v); + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_RANDOM_CHOICE_WITH_MASK_IMPL_CUH_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/random/random_choice_with_mask_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/random/random_choice_with_mask_gpu_kernel.cc new file mode 100644 index 0000000000..9d810878b0 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/random/random_choice_with_mask_gpu_kernel.cc @@ -0,0 +1,26 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/random/random_choice_with_mask_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_TWO( + RandomChoiceWithMask, + KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool), + RandomChoiceWithMaskGpuKernel, bool, int) +} +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/random/random_choice_with_mask_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/random/random_choice_with_mask_gpu_kernel.h new file mode 100644 index 0000000000..c4c3380723 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/random/random_choice_with_mask_gpu_kernel.h @@ -0,0 +1,129 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RANDOM_RANDOM_CHOICE_WITH_MASK_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RANDOM_RANDOM_CHOICE_WITH_MASK_GPU_KERNEL_H_ + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/random_choice_with_mask_impl.cuh" + +namespace mindspore { +namespace kernel { +template +class RandomChoiceWithMaskGpuKernel : public GpuKernel { + public: + RandomChoiceWithMaskGpuKernel() : input_shape_size_(0), seedc_(0), input_size_(1), count_(0), ceil_power2_(0) {} + ~RandomChoiceWithMaskGpuKernel() override = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspaces, + const std::vector &outputs, void *stream_ptr) override { + T *input = GetDeviceAddress(inputs, 0); + S *output_index = GetDeviceAddress(outputs, 0); + T *output_mask = GetDeviceAddress(outputs, 1); + S *index_buff = GetDeviceAddress(workspaces, 0); + S *mask_buff = GetDeviceAddress(workspaces, 1); + S *rank_buff = GetDeviceAddress(workspaces, 2); + S *Tnum_buff = GetDeviceAddress(workspaces, 3); + S *tmp_buff = GetDeviceAddress(workspaces, 4); + void *States = GetDeviceAddress(workspaces, 5); + curandState *devStates = reinterpret_cast(States); + CalRandomChoiceWithMask(input_size_, input_shape_size_, input_shape_5D_[0], input_shape_5D_[1], input_shape_5D_[2], + input_shape_5D_[3], input_shape_5D_[4], seedc_, count_, input, output_index, output_mask, + index_buff, mask_buff, rank_buff, Tnum_buff, tmp_buff, devStates, + reinterpret_cast(stream_ptr)); + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 1) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but RandomChoiceWithMask needs 1 input."; + return false; + } + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 2) { + MS_LOG(ERROR) << "Output number is " << output_num << ", but RandomChoiceWithMask has 2 outputs."; + return false; + } + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + input_shape_size_ = input_shape.size(); + if (input_shape_size_ < 1 || input_shape_size_ > MAX_DIMENSION) { + MS_LOG(ERROR) << "Input is " << input_shape_size_ + << "-D, but RandomChoiceWithMask supports only 1-D to 5-D inputs."; + return false; + } + // convert size_t to int + for (auto i = 0; i < input_shape_size_; i++) { + input_shape_5D_.push_back(input_shape[i]); + } + // convert shape to 5D + while (input_shape_5D_.size() != MAX_DIMENSION) { + input_shape_5D_.insert(input_shape_5D_.begin(), 1); + } + // init seedc_ + int seed = GetAttr(kernel_node, "seed"); + int seed2 = GetAttr(kernel_node, "seed2"); + if (seed2 != 0) + seedc_ = seed2; + else if (seed != 0) + seedc_ = seed; + else + seedc_ = time(NULL); + // init memory + for (size_t i = 0; i < input_shape.size(); i++) { + input_size_ *= input_shape[i]; + } + count_ = GetAttr(kernel_node, "count"); + // upper ceiling for input for ceil_power2 + ceil_power2_ = RcwmRoundUpPower2(input_size_); + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(input_size_ * sizeof(T)); + output_size_list_.push_back(count_ * input_shape_size_ * sizeof(S)); + output_size_list_.push_back(count_ * sizeof(T)); + workspace_size_list_.push_back(input_size_ * input_shape_size_ * sizeof(S)); + workspace_size_list_.push_back(ceil_power2_ * sizeof(S)); + workspace_size_list_.push_back(ceil_power2_ * sizeof(S)); + int blocknum = std::ceil(static_cast(ceil_power2_) / BLOCKSIZE); + workspace_size_list_.push_back(blocknum * sizeof(S)); + workspace_size_list_.push_back(ceil_power2_ * sizeof(S)); + workspace_size_list_.push_back(ceil_power2_ * sizeof(curandState)); + } + + private: + int input_shape_size_; + int seedc_; + int input_size_; + int count_; + int ceil_power2_; + std::vector input_shape_5D_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RANDOM_RANDOM_CHOICE_WITH_MASK_GPU_KERNEL_H_ diff --git a/mindspore/ops/operations/random_ops.py b/mindspore/ops/operations/random_ops.py index 065c4eaf27..b536bc696e 100644 --- a/mindspore/ops/operations/random_ops.py +++ b/mindspore/ops/operations/random_ops.py @@ -348,13 +348,13 @@ class RandomChoiceWithMask(PrimitiveWithInfer): seed2 (int): Random seed2. Default: 0. Inputs: - - **input_x** (Tensor[bool]) - The input tensor. + - **input_x** (Tensor[bool]) - The input tensor. The input tensor rank should be >= 1 and <= 5. Outputs: Two tensors, the first one is the index tensor and the other one is the mask tensor. - - **index** (Tensor) - The output has shape between 2-D and 5-D. - - **mask** (Tensor) - The output has shape 1-D. + - **index** (Tensor) - The output shape is 2-D. + - **mask** (Tensor) - The output shape is 1-D. Examples: >>> rnd_choice_mask = P.RandomChoiceWithMask() @@ -372,6 +372,7 @@ class RandomChoiceWithMask(PrimitiveWithInfer): def infer_shape(self, x_shape): validator.check_integer("input_x rank", len(x_shape), 1, Rel.GE, self.name) + validator.check_integer("input_x rank", len(x_shape), 5, Rel.LE, self.name) return ([self.count, len(x_shape)], [self.count]) def infer_dtype(self, x_dtype): diff --git a/tests/st/ops/gpu/test_random_choice_with_mask.py b/tests/st/ops/gpu/test_random_choice_with_mask.py new file mode 100644 index 0000000000..3ca12e7dd6 --- /dev/null +++ b/tests/st/ops/gpu/test_random_choice_with_mask.py @@ -0,0 +1,86 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P + +class RCWM_count_in(nn.Cell): + def __init__(self): + super(RCWM_count_in, self).__init__() + self.RCWM_count_in = P.RandomChoiceWithMask(count=4, seed=1) + + def construct(self, x): + return self.RCWM_count_in(x) + +class RCWM_count_out(nn.Cell): + def __init__(self): + super(RCWM_count_out, self).__init__() + self.RCWM_count_out = P.RandomChoiceWithMask(count=10, seed=1) + + def construct(self, x): + return self.RCWM_count_out(x) + +class RCWM_3D(nn.Cell): + def __init__(self): + super(RCWM_3D, self).__init__() + self.RCWM_3D = P.RandomChoiceWithMask(count=10, seed=1) + + def construct(self, x): + return self.RCWM_3D(x) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_RCWM_3D(): + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + input_tensor = Tensor(np.ones([3, 4, 5]).astype(np.bool)) + expect1 = [[0, 1, 1], [0, 2, 1], [0, 2, 2], [1, 0, 1], [0, 1, 3], [0, 3, 0], [1, 3, 2], \ + [0, 0, 0], [1, 1, 2], [1, 3, 4]] + expect2 = [True, True, True, True, True, True, True, True, True, True] + rcwm = RCWM_3D() + output1, output2 = rcwm(input_tensor) + assert np.all(output1.asnumpy() == np.array(expect1)), "output: {}, expect: {}".format(output1, expect1) + assert np.all(output2.asnumpy() == np.array(expect2)), "output: {}, expect: {}".format(output2, expect2) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_RCWM_count_out(): + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + input_tensor = Tensor(np.array([[1, 0, 1, 0], [0, 0, 0, 1], [1, 1, 1, 1], [0, 0, 0, 1]]).astype(np.bool)) + expect1 = [[0, 2], [2, 2], [2, 1], [2, 0], [0, 0], [3, 3], [2, 3], [1, 3], [0, 0], [0, 0]] + expect2 = [True, True, True, True, True, True, True, True, False, False] + rcwm = RCWM_count_out() + output1, output2 = rcwm(input_tensor) + assert np.all(output1.asnumpy() == np.array(expect1)), "output: {}, expect: {}".format(output1, expect1) + assert np.all(output2.asnumpy() == np.array(expect2)), "output: {}, expect: {}".format(output2, expect2) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_RCWM_count_in(): + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + input_tensor = Tensor(np.array([[1, 0, 1, 0], [0, 0, 0, 1], [1, 1, 1, 1], [0, 0, 0, 1]]).astype(np.bool)) + expect1 = [[0, 2], [2, 2], [2, 1], [2, 0]] + expect2 = [True, True, True, True] + rcwm = RCWM_count_in() + output1, output2 = rcwm(input_tensor) + assert np.all(output1.asnumpy() == np.array(expect1)), "output: {}, expect: {}".format(output1, expect1) + assert np.all(output2.asnumpy() == np.array(expect2)), "output: {}, expect: {}".format(output2, expect2)