From bd0b46269132e16a3bf909ac59f3b4083a2de803 Mon Sep 17 00:00:00 2001
From: Peilin Wang
Date: Tue, 20 Oct 2020 15:17:52 -0400
Subject: [PATCH] new gpu op for cbg repeat_elements

fixed ci

fixed ci

addressed comments
---
 .../gpu/arrays/repeat_elements_gpu_kernel.cc  |  28 +
 .../gpu/arrays/repeat_elements_gpu_kernel.h   | 161 +++++
 .../gpu/cuda_impl/repeat_elements_impl.cu     | 318 +++++++++
 .../gpu/cuda_impl/repeat_elements_impl.cuh    |  52 ++
 mindspore/ops/operations/__init__.py          |   5 +-
 mindspore/ops/operations/array_ops.py         |  49 ++
 tests/st/ops/gpu/test_repeat_elements_op.py   | 656 ++++++++++++++++++
 7 files changed, 1267 insertions(+), 2 deletions(-)
 create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/repeat_elements_gpu_kernel.cc
 create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/repeat_elements_gpu_kernel.h
 create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/repeat_elements_impl.cu
 create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/repeat_elements_impl.cuh
 create mode 100644 tests/st/ops/gpu/test_repeat_elements_op.py

diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/repeat_elements_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/repeat_elements_gpu_kernel.cc
new file mode 100644
index 0000000000..3bd4a0a433
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/repeat_elements_gpu_kernel.cc
@@ -0,0 +1,28 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <cstdint>
+
+#include "backend/kernel_compiler/gpu/arrays/repeat_elements_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_ONE(RepeatElements, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+                      RepeatElementsGpuKernel, half)
+
+MS_REG_GPU_KERNEL_ONE(RepeatElements, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+                      RepeatElementsGpuKernel, int32_t)
+}  // namespace kernel
+}  // namespace mindspore
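Taken together, the two registrations above define the op's entire GPU dtype surface: kNumberTypeFloat16 pairs with the half instantiation and kNumberTypeInt32 with int32_t. A minimal usage sketch, assuming a GPU build of MindSpore; the Net wrapper and variable names are illustrative, not part of the patch:

import numpy as np
import mindspore
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.op = P.RepeatElements(rep=2, axis=0)

    def construct(self, x):
        return self.op(x)

net = Net()
print(net(Tensor(np.arange(6).reshape(2, 3), mindspore.int32)))   # int32_t kernel
print(net(Tensor(np.arange(6, dtype=np.float16).reshape(2, 3))))  # half kernel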
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/repeat_elements_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/repeat_elements_gpu_kernel.h
new file mode 100644
index 0000000000..6900ef4e65
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/repeat_elements_gpu_kernel.h
@@ -0,0 +1,161 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_REPEAT_ELEMENTS_GPU_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_REPEAT_ELEMENTS_GPU_KERNEL_H_
+
+#include "backend/kernel_compiler/gpu/cuda_impl/repeat_elements_impl.cuh"
+
+#include <cuda_runtime.h>
+
+#include <algorithm>
+#include <vector>
+
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+
+namespace mindspore {
+namespace kernel {
+template <typename T>
+class RepeatElementsGpuKernel : public GpuKernel {
+ public:
+  RepeatElementsGpuKernel() : rep_(1), axis_(0), input_size_(1), output_size_(0) {}
+  ~RepeatElementsGpuKernel() = default;
+
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+    T *input_device_address = GetDeviceAddress<T>(inputs, 0);
+    T *output_device_address = GetDeviceAddress<T>(outputs, 0);
+
+    switch (input_dim_) {
+      case 1:
+        CalRepeatElements1d(input_device_address, rep_, axis_, output_device_address, output_size_,
+                            reinterpret_cast<cudaStream_t>(stream_ptr));
+        break;
+      case 2:
+        CalRepeatElements2d(input_device_address, input_shape_[1], rep_, axis_, output_device_address,
+                            output_shape_[1], output_size_, reinterpret_cast<cudaStream_t>(stream_ptr));
+        break;
+      case 3:
+        CalRepeatElements3d(input_device_address, input_shape_[1], input_shape_[2], rep_, axis_,
+                            output_device_address, output_shape_[1], output_shape_[2], output_size_,
+                            reinterpret_cast<cudaStream_t>(stream_ptr));
+        break;
+      case 4:
+        CalRepeatElements4d(input_device_address, input_shape_[1], input_shape_[2], input_shape_[3], rep_, axis_,
+                            output_device_address, output_shape_[1], output_shape_[2], output_shape_[3], output_size_,
+                            reinterpret_cast<cudaStream_t>(stream_ptr));
+        break;
+      case 5:
+        CalRepeatElements5d(input_device_address, input_shape_[1], input_shape_[2], input_shape_[3], input_shape_[4],
+                            rep_, axis_, output_device_address, output_shape_[1], output_shape_[2], output_shape_[3],
+                            output_shape_[4], output_size_, reinterpret_cast<cudaStream_t>(stream_ptr));
+        break;
+      default:
+        int *input_shape_device_address = GetDeviceAddress<int>(workspace, 0);
+        int *output_shape_device_address = GetDeviceAddress<int>(workspace, 1);
+        int *input_shape_cumulative_product_device_address = GetDeviceAddress<int>(workspace, 2);
+        CHECK_CUDA_RET_WITH_EXCEPT(
+          cudaMemcpyAsync(input_shape_device_address, input_shape_.data(), workspace_size_list_[0],
+                          cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
+          "cudaMemcpyAsync input_shape failed");
+        CHECK_CUDA_RET_WITH_EXCEPT(
+          cudaMemcpyAsync(output_shape_device_address, output_shape_.data(), workspace_size_list_[1],
+                          cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
+          "cudaMemcpyAsync output_shape failed");
+        CHECK_CUDA_RET_WITH_EXCEPT(
+          cudaMemcpyAsync(input_shape_cumulative_product_device_address, input_shape_cumulative_product_.data(),
+                          workspace_size_list_[2], cudaMemcpyHostToDevice,
+                          reinterpret_cast<cudaStream_t>(stream_ptr)),
+          "cudaMemcpyAsync input_shape_cumulative_product failed");
+
+        CalRepeatElements(input_device_address, input_dim_, input_shape_device_address,
+                          input_shape_cumulative_product_device_address, rep_, axis_, output_device_address,
+                          output_shape_device_address, output_size_, reinterpret_cast<cudaStream_t>(stream_ptr));
+        break;
+    }
+
+    return true;
+  }
+
+  bool Init(const CNodePtr &kernel_node) override {
+    size_t input_count = AnfAlgo::GetInputTensorNum(kernel_node);
+    if (input_count != 1) {
+      MS_LOG(EXCEPTION) << input_count << " arguments were provided, but RepeatElementsGpuKernel expects 1.";
+    }
+
+    std::vector<size_t> temp_input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    input_dim_ = temp_input_shape.size();
+    for (size_t e : temp_input_shape) {
+      input_size_ *= e;
+      input_shape_.push_back(e);
+    }
+
+    int cumulative_product = 1;
+    for (size_t i = input_dim_ - 1; i > 0; i--) {
+      cumulative_product *= input_shape_[i];
+      input_shape_cumulative_product_.push_back(cumulative_product);
+    }
+    std::reverse(input_shape_cumulative_product_.begin(), input_shape_cumulative_product_.end());
+
+    axis_ = GetAttr<int>(kernel_node, "axis");
+    if (axis_ < 0) {
+      axis_ += input_dim_;
+    }
+
+    rep_ = GetAttr<int>(kernel_node, "rep");
+    output_size_ = input_size_ * rep_;
+    output_shape_ = input_shape_;
+    output_shape_[axis_] *= rep_;
+
+    InitSizeLists();
+
+    return true;
+  }
+
+ protected:
+  void InitSizeLists() override {
+    input_size_list_.push_back(input_size_ * sizeof(T));
+    output_size_list_.push_back(output_size_ * sizeof(T));
+
+    // workspaces for the input shape, the output shape and the cumulative products of the input shape
+    workspace_size_list_.push_back(input_dim_ * sizeof(int));
+    workspace_size_list_.push_back(input_dim_ * sizeof(int));
+    workspace_size_list_.push_back((input_dim_ - 1) * sizeof(int));
+  }
+
+ private:
+  int rep_;
+  int axis_;
+  int input_dim_;
+  std::vector<int> input_shape_;
+  std::vector<int> input_shape_cumulative_product_;
+  std::vector<int> output_shape_;
+
+  size_t input_size_;
+  size_t output_size_;
+
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_REPEAT_ELEMENTS_GPU_KERNEL_H_
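Before the CUDA side, a short plain-Python sketch of the host-side precomputation Init performs above; the function name precompute is hypothetical, not part of the patch. Given an input shape, rep and axis, it derives the output shape and the row-major coefficients (input_shape_cumulative_product_) consumed by the generic kernel:

def precompute(input_shape, rep, axis):
    if axis < 0:
        axis += len(input_shape)          # same normalization as Init
    output_shape = list(input_shape)
    output_shape[axis] *= rep
    coefficients = []                     # input_shape_cumulative_product_
    acc = 1
    for d in reversed(input_shape[1:]):   # dims input_dim_ - 1 down to 1
        acc *= d
        coefficients.append(acc)
    coefficients.reverse()
    return output_shape, coefficients

print(precompute([2, 3, 4], 2, 1))  # ([2, 6, 4], [12, 4])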
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/repeat_elements_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/repeat_elements_impl.cu
new file mode 100644
index 0000000000..c95f2b6e70
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/repeat_elements_impl.cu
@@ -0,0 +1,318 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <cuda_runtime.h>
+
+#include "repeat_elements_impl.cuh"
+#include "runtime/device/gpu/cuda_common.h"
+
+template <typename T>
+__global__ void RepeatElements1d(const T *input, const int rep, const int axis, T *output, const int output_size) {
+  for (size_t gt_id = blockIdx.x * blockDim.x + threadIdx.x; gt_id < output_size; gt_id += gridDim.x * blockDim.x) {
+    int copied_value_index = gt_id / rep;
+    output[gt_id] = input[copied_value_index];
+  }
+}
+
+template <typename T>
+__global__ void RepeatElements2d(const T *input, const int input_d1, const int rep, const int axis, T *output,
+                                 const int output_d1, const int output_size) {
+  for (size_t gt_id = blockIdx.x * blockDim.x + threadIdx.x; gt_id < output_size; gt_id += gridDim.x * blockDim.x) {
+    int global_array_index = gt_id;
+
+    int index_d1 = global_array_index % output_d1;
+    global_array_index -= index_d1;
+    global_array_index /= output_d1;
+
+    int index_d0 = global_array_index;
+
+    switch (axis) {
+      case 0:
+        index_d0 /= rep;
+        break;
+      case 1:
+        index_d1 /= rep;
+        break;
+    }
+
+    const int term0 = index_d0 * input_d1;
+    const int copied_value_index = term0 + index_d1;
+    output[gt_id] = input[copied_value_index];
+  }
+}
+
+template <typename T>
+__global__ void RepeatElements3d(const T *input, const int input_d1, const int input_d2, const int rep, const int axis,
+                                 T *output, const int output_d1, const int output_d2, const int output_size) {
+  for (size_t gt_id = blockIdx.x * blockDim.x + threadIdx.x; gt_id < output_size; gt_id += gridDim.x * blockDim.x) {
+    int global_array_index = gt_id;
+
+    int index_d2 = global_array_index % output_d2;
+    global_array_index -= index_d2;
+    global_array_index /= output_d2;
+
+    int index_d1 = global_array_index % output_d1;
+    global_array_index -= index_d1;
+    global_array_index /= output_d1;
+
+    int index_d0 = global_array_index;
+
+    switch (axis) {
+      case 0:
+        index_d0 /= rep;
+        break;
+      case 1:
+        index_d1 /= rep;
+        break;
+      case 2:
+        index_d2 /= rep;
+        break;
+      default:
+        asm("trap;");
+    }
+
+    const int term0 = index_d0 * input_d1 * input_d2;
+    const int term1 = index_d1 * input_d2;
+    const int copied_value_index = term0 + term1 + index_d2;
+    output[gt_id] = input[copied_value_index];
+  }
+}
+
+template <typename T>
+__global__ void RepeatElements4d(const T *input, const int input_d1, const int input_d2, const int input_d3,
+                                 const int rep, const int axis, T *output, const int output_d1, const int output_d2,
+                                 const int output_d3, const int output_size) {
+  for (size_t gt_id = blockIdx.x * blockDim.x + threadIdx.x; gt_id < output_size; gt_id += gridDim.x * blockDim.x) {
+    int global_array_index = gt_id;
+
+    int index_d3 = global_array_index % output_d3;
+    global_array_index -= index_d3;
+    global_array_index /= output_d3;
+
+    int index_d2 = global_array_index % output_d2;
+    global_array_index -= index_d2;
+    global_array_index /= output_d2;
+
+    int index_d1 = global_array_index % output_d1;
+    global_array_index -= index_d1;
+    global_array_index /= output_d1;
+
+    int index_d0 = global_array_index;
+
+    switch (axis) {
+      case 0:
+        index_d0 /= rep;
+        break;
+      case 1:
+        index_d1 /= rep;
+        break;
+      case 2:
+        index_d2 /= rep;
+        break;
+      case 3:
+        index_d3 /= rep;
+        break;
+    }
+
+    const int term0 = index_d0 * input_d1 * input_d2 * input_d3;
+    const int term1 = index_d1 * input_d2 * input_d3;
+    const int term2 = index_d2 * input_d3;
+    const int copied_value_index = term0 + term1 + term2 + index_d3;
+    output[gt_id] = input[copied_value_index];
+  }
+}
+
+template <typename T>
+__global__ void RepeatElements5d(const T *input, const int input_d1, const int input_d2, const int input_d3,
+                                 const int input_d4, const int rep, const int axis, T *output, const int output_d1,
+                                 const int output_d2, const int output_d3, const int output_d4,
+                                 const int output_size) {
+  for (size_t gt_id = blockIdx.x * blockDim.x + threadIdx.x; gt_id < output_size; gt_id += gridDim.x * blockDim.x) {
+    int global_array_index = gt_id;
+
+    int index_d4 = global_array_index % output_d4;
+    global_array_index -= index_d4;
+    global_array_index /= output_d4;
+
+    int index_d3 = global_array_index % output_d3;
+    global_array_index -= index_d3;
+    global_array_index /= output_d3;
+
+    int index_d2 = global_array_index % output_d2;
+    global_array_index -= index_d2;
+    global_array_index /= output_d2;
+
+    int index_d1 = global_array_index % output_d1;
+    global_array_index -= index_d1;
+    global_array_index /= output_d1;
+
+    int index_d0 = global_array_index;
+
+    switch (axis) {
+      case 0:
+        index_d0 /= rep;
+        break;
+      case 1:
+        index_d1 /= rep;
+        break;
+      case 2:
+        index_d2 /= rep;
+        break;
+      case 3:
+        index_d3 /= rep;
+        break;
+      case 4:
+        index_d4 /= rep;
+        break;
+    }
+
+    const int term0 = index_d0 * input_d1 * input_d2 * input_d3 * input_d4;
+    const int term1 = index_d1 * input_d2 * input_d3 * input_d4;
+    const int term2 = index_d2 * input_d3 * input_d4;
+    const int term3 = index_d3 * input_d4;
+    const int copied_value_index = term0 + term1 + term2 + term3 + index_d4;
+    output[gt_id] = input[copied_value_index];
+  }
+}
+
+template <typename T>
+__global__ void RepeatElements(const T *input, const int input_dim, const int* const input_shape,
+                               const int* const coefficients, const int rep, const int axis, T *output,
+                               const int* const output_shape, const int output_size) {
+  for (size_t gt_id = blockIdx.x * blockDim.x + threadIdx.x; gt_id < output_size; gt_id += gridDim.x * blockDim.x) {
+    int index_tuple[REPEAT_ELEMENTS_MAX_INPUT_DIM];
+
+    int global_array_index = gt_id;
+    for (size_t i = input_dim - 1; i > 0; i--) {
+      int coordinate = global_array_index % output_shape[i];
+      index_tuple[i] = coordinate;
+      global_array_index -= coordinate;
+      global_array_index /= output_shape[i];
+    }
+    index_tuple[0] = global_array_index;
+
+    index_tuple[axis] /= rep;
+
+    int copied_value_index = 0;
+    for (size_t i = 0; i < input_dim - 1; i++) {
+      copied_value_index += index_tuple[i] * coefficients[i];
+    }
+    copied_value_index += index_tuple[input_dim - 1];
+
+    output[gt_id] = input[copied_value_index];
+  }
+}
+
+template <typename T>
+void CalRepeatElements1d(const T *input, const int rep, const int axis, T *output, const int output_size,
+                         cudaStream_t cuda_stream) {
+  RepeatElements1d<<<GET_BLOCKS(output_size), GET_THREADS, 0, cuda_stream>>>(input, rep, axis, output, output_size);
+}
+
+template <typename T>
+void CalRepeatElements2d(const T *input, const int input_d1, const int rep, const int axis, T *output,
+                         const int output_d1, const int output_size, cudaStream_t cuda_stream) {
+  RepeatElements2d<<<GET_BLOCKS(output_size), GET_THREADS, 0, cuda_stream>>>(input, input_d1, rep, axis, output,
+                                                                             output_d1, output_size);
+}
+
+template <typename T>
+void CalRepeatElements3d(const T *input, const int input_d1, const int input_d2, const int rep, const int axis,
+                         T *output, const int output_d1, const int output_d2, const int output_size,
+                         cudaStream_t cuda_stream) {
+  RepeatElements3d<<<GET_BLOCKS(output_size), GET_THREADS, 0, cuda_stream>>>(input, input_d1, input_d2, rep, axis,
+                                                                             output, output_d1, output_d2,
+                                                                             output_size);
+}
+
+template <typename T>
+void CalRepeatElements4d(const T *input, const int input_d1, const int input_d2, const int input_d3, const int rep,
+                         const int axis, T *output, const int output_d1, const int output_d2, const int output_d3,
+                         const int output_size, cudaStream_t cuda_stream) {
+  RepeatElements4d<<<GET_BLOCKS(output_size), GET_THREADS, 0, cuda_stream>>>(input, input_d1, input_d2, input_d3, rep,
+                                                                             axis, output, output_d1, output_d2,
+                                                                             output_d3, output_size);
+}
+
+template <typename T>
+void CalRepeatElements5d(const T *input, const int input_d1, const int input_d2, const int input_d3,
+                         const int input_d4, const int rep, const int axis, T *output, const int output_d1,
+                         const int output_d2, const int output_d3, const int output_d4, const int output_size,
+                         cudaStream_t cuda_stream) {
+  RepeatElements5d<<<GET_BLOCKS(output_size), GET_THREADS, 0, cuda_stream>>>(input, input_d1, input_d2, input_d3,
+                                                                             input_d4, rep, axis, output, output_d1,
+                                                                             output_d2, output_d3, output_d4,
+                                                                             output_size);
+}
+
+template <typename T>
+void CalRepeatElements(const T *input, const int input_dim, const int* const input_shape,
+                       const int* const input_shape_cumulative_product, const int rep, const int axis, T *output,
+                       const int* const output_shape, const int output_size, cudaStream_t cuda_stream) {
+  RepeatElements<<<GET_BLOCKS(output_size), GET_THREADS, 0, cuda_stream>>>(input, input_dim, input_shape,
+                                                                           input_shape_cumulative_product, rep, axis,
+                                                                           output, output_shape, output_size);
+}
+
+// int32
+template void CalRepeatElements1d(const int *input, const int rep, const int axis, int *output,
+                                  const int output_size, cudaStream_t cuda_stream);
+
+template void CalRepeatElements2d(const int *input, const int input_d1, const int rep, const int axis, int *output,
+                                  const int output_d1, const int output_size, cudaStream_t cuda_stream);
+
+template void CalRepeatElements3d(const int *input, const int input_d1, const int input_d2, const int rep,
+                                  const int axis, int *output, const int output_d1, const int output_d2,
+                                  const int output_size, cudaStream_t cuda_stream);
+
+template void CalRepeatElements4d(const int *input, const int input_d1, const int input_d2, const int input_d3,
+                                  const int rep, const int axis, int *output, const int output_d1,
+                                  const int output_d2, const int output_d3, const int output_size,
+                                  cudaStream_t cuda_stream);
+
+template void CalRepeatElements5d(const int *input, const int input_d1, const int input_d2, const int input_d3,
+                                  const int input_d4, const int rep, const int axis, int *output,
+                                  const int output_d1, const int output_d2, const int output_d3,
+                                  const int output_d4, const int output_size, cudaStream_t cuda_stream);
+
+template void CalRepeatElements(const int *input, const int input_dim, const int* const input_shape,
+                                const int* const input_shape_cumulative_product, const int rep, const int axis,
+                                int *output, const int* const output_shape, const int output_size,
+                                cudaStream_t cuda_stream);
+
+// float16
+template void CalRepeatElements1d(const half *input, const int rep, const int axis, half *output,
+                                  const int output_size, cudaStream_t cuda_stream);
+
+template void CalRepeatElements2d(const half *input, const int input_d1, const int rep, const int axis,
+                                  half *output, const int output_d1, const int output_size,
+                                  cudaStream_t cuda_stream);
+
+template void CalRepeatElements3d(const half *input, const int input_d1, const int input_d2, const int rep,
+                                  const int axis, half *output, const int output_d1, const int output_d2,
+                                  const int output_size, cudaStream_t cuda_stream);
+
+template void CalRepeatElements4d(const half *input, const int input_d1, const int input_d2, const int input_d3,
+                                  const int rep, const int axis, half *output, const int output_d1,
+                                  const int output_d2, const int output_d3, const int output_size,
+                                  cudaStream_t cuda_stream);
+
+template void CalRepeatElements5d(const half *input, const int input_d1, const int input_d2, const int input_d3,
+                                  const int input_d4, const int rep, const int axis, half *output,
+                                  const int output_d1, const int output_d2, const int output_d3,
+                                  const int output_d4, const int output_size, cudaStream_t cuda_stream);
+
+template void CalRepeatElements(const half *input, const int input_dim, const int* const input_shape,
+                                const int* const input_shape_cumulative_product, const int rep, const int axis,
+                                half *output, const int* const output_shape, const int output_size,
+                                cudaStream_t cuda_stream);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/repeat_elements_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/repeat_elements_impl.cuh
new file mode 100644
index 0000000000..34221c3845
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/repeat_elements_impl.cuh
@@ -0,0 +1,52 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_REPEAT_ELEMENTS_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_REPEAT_ELEMENTS_H_
+
+#include <cuda_runtime.h>
+
+#define REPEAT_ELEMENTS_MAX_INPUT_DIM 100
+
+template <typename T>
+void CalRepeatElements1d(const T *input, const int rep, const int axis, T *output, const int output_size,
+                         cudaStream_t cuda_stream);
+
+template <typename T>
+void CalRepeatElements2d(const T *input, const int input_d1, const int rep, const int axis, T *output,
+                         const int output_d1, const int output_size, cudaStream_t cuda_stream);
+
+template <typename T>
+void CalRepeatElements3d(const T *input, const int input_d1, const int input_d2, const int rep, const int axis,
+                         T *output, const int output_d1, const int output_d2, const int output_size,
+                         cudaStream_t cuda_stream);
+
+template <typename T>
+void CalRepeatElements4d(const T *input, const int input_d1, const int input_d2, const int input_d3, const int rep,
+                         const int axis, T *output, const int output_d1, const int output_d2, const int output_d3,
+                         const int output_size, cudaStream_t cuda_stream);
+
+template <typename T>
+void CalRepeatElements5d(const T *input, const int input_d1, const int input_d2, const int input_d3,
+                         const int input_d4, const int rep, const int axis, T *output, const int output_d1,
+                         const int output_d2, const int output_d3, const int output_d4, const int output_size,
+                         cudaStream_t cuda_stream);
+
+template <typename T>
+void CalRepeatElements(const T *input, const int input_dim, const int* const input_shape,
+                       const int* const input_shape_cumulative_product, const int rep, const int axis, T *output,
+                       const int* const output_shape, const int output_size, cudaStream_t cuda_stream);
+
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_REPEAT_ELEMENTS_H_
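All of the kernels declared above share one idea: each output coordinate maps to an input coordinate by integer-dividing the repeated axis by rep. A pure-Python reference of that mapping, checked against np.repeat; repeat_elements_ref is a hypothetical name, not part of the patch:

import numpy as np

def repeat_elements_ref(x, rep, axis):
    out_shape = [d * rep if i == axis else d for i, d in enumerate(x.shape)]
    out = np.empty(out_shape, x.dtype)
    for out_idx in np.ndindex(*out.shape):
        in_idx = list(out_idx)
        in_idx[axis] //= rep  # the only difference between the two coordinate tuples
        out[out_idx] = x[tuple(in_idx)]
    return out

a = np.arange(24).reshape(2, 3, 4)
assert (repeat_elements_ref(a, 3, 1) == a.repeat(3, axis=1)).all()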
diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py
index 1e9a0b5713..3c8df78a68 100644
--- a/mindspore/ops/operations/__init__.py
+++ b/mindspore/ops/operations/__init__.py
@@ -33,7 +33,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
                         Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin, UnsortedSegmentProd,
                         UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace,
                         SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence, EmbeddingLookup,
-                        Unique, GatherD, Identity)
+                        Unique, GatherD, Identity, RepeatElements)
 from .comm_ops import (AllGather, AllReduce, _AlltoAll, ReduceScatter, Broadcast,
                        _MirrorOperator, ReduceOp, _VirtualDataset,
                        _VirtualDiv, _GetTensorSlice,
@@ -381,7 +381,8 @@ __all__ = [
     "Push",
     "Pull",
     "ReLUV2",
-    'SparseToDense',
+    "SparseToDense",
+    "RepeatElements",
 ]
 
 __all__.sort()
diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py
index 24c3784afa..95562dcedd 100644
--- a/mindspore/ops/operations/array_ops.py
+++ b/mindspore/ops/operations/array_ops.py
@@ -4022,3 +4022,52 @@ class Identity(PrimitiveWithInfer):
                'dtype': x['dtype'],
                'value': None}
         return out
+
+
+class RepeatElements(PrimitiveWithInfer):
+    """
+    Repeats elements of a tensor along an axis, like `np.repeat`.
+
+    Args:
+        rep (int): The number of times to repeat, must be positive. Required.
+        axis (int): The axis along which to repeat. Default: 0.
+
+    Inputs:
+        - **x** (Tensor) - The tensor whose elements are repeated. Must be of type int32 or float16.
+
+    Outputs:
+        Tensor, with values repeated along the specified axis. If x has shape
+        (s1, s2, ..., sn) and axis is i, the output has shape (s1, s2, ..., si * rep, ..., sn).
+
+    Examples:
+        >>> x = Tensor(np.array([[0, 1, 2], [3, 4, 5]]), mindspore.int32)
+        >>> repeat_elements = P.RepeatElements(rep=2, axis=0)
+        >>> output = repeat_elements(x)
+        [[0, 1, 2],
+         [0, 1, 2],
+         [3, 4, 5],
+         [3, 4, 5]]
+    """
+
+    @prim_attr_register
+    def __init__(self, rep, axis=0):
+        self.init_prim_io_names(inputs=["x"], outputs=["output"])
+
+        validator.check_value_type("rep", rep, [int], self.name)
+        self.rep = rep
+
+        validator.check_value_type("axis", axis, [int], self.name)
+        self.axis = axis
+
+    def infer_shape(self, x_shape):
+        validator.check("rep", self.rep, "", 0, Rel.GT, self.name)
+        validator.check("axis", self.axis, "dimension of x", len(x_shape), Rel.LT, self.name)
+        validator.check("axis", self.axis, "negative dimension of x", -len(x_shape), Rel.GE, self.name)
+
+        x_shape[self.axis] *= self.rep
+        return x_shape
+
+    def infer_dtype(self, x_dtype):
+        validator.check_subclass("x_dtype", x_dtype, mstype.tensor, self.name)
+        return x_dtype
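infer_shape above rejects a non-positive rep and an out-of-range axis at graph construction time, surfacing as ValueError, which test_repeat_elements_invalid_input at the end of the patch relies on. An equivalent plain-Python sketch of those checks (check_repeat_elements_args is a hypothetical helper, not part of the patch):

def check_repeat_elements_args(x_shape, rep, axis):
    if rep <= 0:
        raise ValueError("rep must be positive, got %d" % rep)
    if not -len(x_shape) <= axis < len(x_shape):
        raise ValueError("axis %d out of range for %d-d input" % (axis, len(x_shape)))
    out = list(x_shape)
    out[axis] *= rep  # negative axis indexes from the end, as in numpy
    return out

print(check_repeat_elements_args([2, 3], 2, 0))   # [4, 3]
print(check_repeat_elements_args([2, 3], 2, -1))  # [2, 6]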
diff --git a/tests/st/ops/gpu/test_repeat_elements_op.py b/tests/st/ops/gpu/test_repeat_elements_op.py
new file mode 100644
index 0000000000..74941a604d
--- /dev/null
+++ b/tests/st/ops/gpu/test_repeat_elements_op.py
@@ -0,0 +1,656 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+from mindspore import Tensor
+from mindspore.ops import operations as P
+import mindspore.nn as nn
+import mindspore.context as context
+
+
+class RepeatElementsNet(nn.Cell):
+    def __init__(self, rep, axis):
+        super(RepeatElementsNet, self).__init__()
+        self.repeat_elements = P.RepeatElements(rep, axis)
+
+    def construct(self, x):
+        return self.repeat_elements(x)
+
+
+def repeat_elements(x, rep, axis):
+    # Preserve float16 inputs; cast everything else to int32, the other dtype
+    # supported by the GPU kernel. Casting unconditionally would make the
+    # float16 tests silently exercise the int32 path.
+    if x.dtype != np.float16:
+        x = x.astype(np.int32)
+    repeat_elements_net = RepeatElementsNet(rep, axis)
+    return repeat_elements_net(Tensor(x)).asnumpy()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_repeat_elements_1d_one_element_rep_1():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    a = np.arange(1)
+
+    ms_out = repeat_elements(a, 1, 0)
+    np_out = a.repeat(1, 0)
+    np.testing.assert_array_equal(np_out, ms_out)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_repeat_elements_1d_one_element_rep_many():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    a = np.arange(1)
+
+    ms_out = repeat_elements(a, 5, 0)
+    np_out = a.repeat(5, 0)
+    np.testing.assert_array_equal(np_out, ms_out)
+
+    ms_out = repeat_elements(a, 513, 0)
+    np_out = a.repeat(513, 0)
+    np.testing.assert_array_equal(np_out, ms_out)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_repeat_elements_1d_rep_1():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    a = np.arange(24)
+
+    ms_out = repeat_elements(a, 1, 0)
+    np_out = a.repeat(1, 0)
+    np.testing.assert_array_equal(np_out, ms_out)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_repeat_elements_1d_rep_many():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    a = np.arange(24)
+
+    ms_out = repeat_elements(a, 231, 0)
+    np_out = a.repeat(231, 0)
+    np.testing.assert_array_equal(np_out, ms_out)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_repeat_elements_2d_one_element_rep_1():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    a = np.arange(1).reshape(1, 1)
+
+    ms_out = repeat_elements(a, 1, 0)
+    np_out = a.repeat(1, 0)
+    np.testing.assert_array_equal(np_out, ms_out)
+
+    ms_out = repeat_elements(a, 1, 1)
+    np_out = a.repeat(1, 1)
+    np.testing.assert_array_equal(np_out, ms_out)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_repeat_elements_2d_one_element_rep_many():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    a = np.arange(1).reshape(1, 1)
+
+    ms_out = repeat_elements(a, 13, 0)
+    np_out = a.repeat(13, 0)
+    np.testing.assert_array_equal(np_out, ms_out)
+
+    ms_out = repeat_elements(a, 13, 1)
+    np_out = a.repeat(13, 1)
+    np.testing.assert_array_equal(np_out, ms_out)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_repeat_elements_2d_rep_1():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    a = np.arange(24).reshape(12, 2)
+
+    ms_out = repeat_elements(a, 1, 0)
+    np_out = a.repeat(1, 0)
+    np.testing.assert_array_equal(np_out, ms_out)
+
+    ms_out = repeat_elements(a, 1, 1)
+    np_out = a.repeat(1, 1)
+    np.testing.assert_array_equal(np_out, ms_out)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard +def test_repeat_elements_2d_rep_many(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + a = np.arange(24).reshape(8, 3) + + ms_out = repeat_elements(a, 23, 0) + np_out = a.repeat(23, 0) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 23, 1) + np_out = a.repeat(23, 1) + np.testing.assert_array_equal(np_out, ms_out) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_repeat_elements_3d_one_element_rep_1(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + a = np.arange(1).reshape(1, 1, 1) + + ms_out = repeat_elements(a, 1, 0) + np_out = a.repeat(1, 0) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 1, 1) + np_out = a.repeat(1, 1) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 1, 2) + np_out = a.repeat(1, 2) + np.testing.assert_array_equal(np_out, ms_out) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_repeat_elements_3d_one_element_rep_many(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + a = np.arange(1).reshape(1, 1, 1) + + ms_out = repeat_elements(a, 43, 0) + np_out = a.repeat(43, 0) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 43, 1) + np_out = a.repeat(43, 1) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 43, 2) + np_out = a.repeat(43, 2) + np.testing.assert_array_equal(np_out, ms_out) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_repeat_elements_3d_rep_1(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + a = np.arange(60).reshape(6, 2, 5) + + ms_out = repeat_elements(a, 1, 0) + np_out = a.repeat(1, 0) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 1, 1) + np_out = a.repeat(1, 1) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 1, 2) + np_out = a.repeat(1, 2) + np.testing.assert_array_equal(np_out, ms_out) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_repeat_elements_3d_rep_many(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + a = np.arange(60).reshape(3, 4, 5) + + ms_out = repeat_elements(a, 14, 0) + np_out = a.repeat(14, 0) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 14, 1) + np_out = a.repeat(14, 1) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 14, 2) + np_out = a.repeat(14, 2) + np.testing.assert_array_equal(np_out, ms_out) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_repeat_elements_4d_one_element_rep_1(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + a = np.arange(1).reshape(1, 1, 1, 1) + + ms_out = repeat_elements(a, 1, 0) + np_out = a.repeat(1, 0) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 1, 1) + np_out = a.repeat(1, 1) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 1, 2) + np_out = a.repeat(1, 2) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 1, 3) + np_out = a.repeat(1, 3) + np.testing.assert_array_equal(np_out, ms_out) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_repeat_elements_4d_one_element_rep_many(): + 
context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + a = np.arange(1).reshape(1, 1, 1, 1) + + ms_out = repeat_elements(a, 17, 0) + np_out = a.repeat(17, 0) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 17, 1) + np_out = a.repeat(17, 1) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 17, 2) + np_out = a.repeat(17, 2) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 17, 3) + np_out = a.repeat(17, 3) + np.testing.assert_array_equal(np_out, ms_out) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_repeat_elements_4d_rep_1(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + a = np.arange(24).reshape(4, 3, 2, 1) + + ms_out = repeat_elements(a, 1, 0) + np_out = a.repeat(1, 0) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 1, 1) + np_out = a.repeat(1, 1) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 1, 2) + np_out = a.repeat(1, 2) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 1, 3) + np_out = a.repeat(1, 3) + np.testing.assert_array_equal(np_out, ms_out) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_repeat_elements_4d_rep_many(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + a = np.arange(24).reshape(2, 2, 2, 3) + + ms_out = repeat_elements(a, 23, 0) + np_out = a.repeat(23, 0) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 23, 1) + np_out = a.repeat(23, 1) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 23, 2) + np_out = a.repeat(23, 2) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 23, 3) + np_out = a.repeat(23, 3) + np.testing.assert_array_equal(np_out, ms_out) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_repeat_elements_5d_one_element_rep_1(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + a = np.arange(1).reshape(1, 1, 1, 1, 1) + + ms_out = repeat_elements(a, 1, 0) + np_out = a.repeat(1, 0) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 1, 1) + np_out = a.repeat(1, 1) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 1, 2) + np_out = a.repeat(1, 2) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 1, 3) + np_out = a.repeat(1, 3) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 1, 4) + np_out = a.repeat(1, 4) + np.testing.assert_array_equal(np_out, ms_out) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_repeat_elements_5d_one_element_rep_many(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + a = np.arange(1).reshape(1, 1, 1, 1, 1) + + ms_out = repeat_elements(a, 19, 0) + np_out = a.repeat(19, 0) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 19, 1) + np_out = a.repeat(19, 1) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 19, 2) + np_out = a.repeat(19, 2) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 19, 3) + np_out = a.repeat(19, 3) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 19, 4) + np_out = a.repeat(19, 4) + 
np.testing.assert_array_equal(np_out, ms_out) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_repeat_elements_5d_rep_1(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + a = np.arange(224).reshape(8, 2, 1, 7, 2) + + ms_out = repeat_elements(a, 1, 0) + np_out = a.repeat(1, 0) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 1, 1) + np_out = a.repeat(1, 1) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 1, 2) + np_out = a.repeat(1, 2) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 1, 3) + np_out = a.repeat(1, 3) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 1, 4) + np_out = a.repeat(1, 4) + np.testing.assert_array_equal(np_out, ms_out) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_repeat_elements_5d_rep_many(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + a = np.arange(224).reshape(1, 7, 4, 4, 2) + + ms_out = repeat_elements(a, 7, 0) + np_out = a.repeat(7, 0) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 7, 1) + np_out = a.repeat(7, 1) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 7, 2) + np_out = a.repeat(7, 2) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 7, 3) + np_out = a.repeat(7, 3) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 7, 4) + np_out = a.repeat(7, 4) + np.testing.assert_array_equal(np_out, ms_out) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_repeat_elements_large_one_element_rep_1(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + a = np.arange(1).reshape(1, 1, 1, 1, 1, 1, 1, 1) + + ms_out = repeat_elements(a, 1, 0) + np_out = a.repeat(1, 0) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 1, 1) + np_out = a.repeat(1, 1) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 1, 2) + np_out = a.repeat(1, 2) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 1, 3) + np_out = a.repeat(1, 3) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 1, 4) + np_out = a.repeat(1, 4) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 1, 5) + np_out = a.repeat(1, 5) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 1, 6) + np_out = a.repeat(1, 6) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 1, 7) + np_out = a.repeat(1, 7) + np.testing.assert_array_equal(np_out, ms_out) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_repeat_elements_large_one_element_rep_many(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + a = np.arange(1).reshape(1, 1, 1, 1, 1, 1, 1, 1) + + ms_out = repeat_elements(a, 42, 0) + np_out = a.repeat(42, 0) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 42, 1) + np_out = a.repeat(42, 1) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 42, 2) + np_out = a.repeat(42, 2) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 42, 3) + np_out = a.repeat(42, 3) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = 
repeat_elements(a, 42, 4) + np_out = a.repeat(42, 4) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 42, 5) + np_out = a.repeat(42, 5) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 42, 6) + np_out = a.repeat(42, 6) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 42, 7) + np_out = a.repeat(42, 7) + np.testing.assert_array_equal(np_out, ms_out) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_repeat_elements_large_rep_1(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + a = np.arange(1152).reshape(2, 3, 4, 8, 1, 1, 2, 3) + + ms_out = repeat_elements(a, 1, 0) + np_out = a.repeat(1, 0) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 1, 1) + np_out = a.repeat(1, 1) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 1, 2) + np_out = a.repeat(1, 2) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 1, 3) + np_out = a.repeat(1, 3) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 1, 4) + np_out = a.repeat(1, 4) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 1, 5) + np_out = a.repeat(1, 5) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 1, 6) + np_out = a.repeat(1, 6) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 1, 7) + np_out = a.repeat(1, 7) + np.testing.assert_array_equal(np_out, ms_out) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_repeat_elements_large_rep_many(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + a = np.arange(1152).reshape(4, 3, 4, 2, 1, 1, 4, 3) + + ms_out = repeat_elements(a, 4, 0) + np_out = a.repeat(4, 0) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 4, 1) + np_out = a.repeat(4, 1) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 4, 2) + np_out = a.repeat(4, 2) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 4, 3) + np_out = a.repeat(4, 3) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 4, 4) + np_out = a.repeat(4, 4) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 4, 5) + np_out = a.repeat(4, 5) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 4, 6) + np_out = a.repeat(4, 6) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 4, 7) + np_out = a.repeat(4, 7) + np.testing.assert_array_equal(np_out, ms_out) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_repeat_elements_half(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + a = np.arange(1152).astype(np.float16).reshape(4, 3, 4, 2, 1, 1, 4, 3) + + ms_out = repeat_elements(a, 4, 0) + np_out = a.repeat(4, 0) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 4, 1) + np_out = a.repeat(4, 1) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 4, 2) + np_out = a.repeat(4, 2) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 4, 3) + np_out = a.repeat(4, 3) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 4, 4) + np_out = a.repeat(4, 4) + 
np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 4, 5) + np_out = a.repeat(4, 5) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 4, 6) + np_out = a.repeat(4, 6) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 4, 7) + np_out = a.repeat(4, 7) + np.testing.assert_array_equal(np_out, ms_out) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_repeat_elements_net_multi_use(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + + rep = 3 + axis = 4 + repeat_elements_net = RepeatElementsNet(rep, axis) + + a = np.arange(64).reshape(2, 2, 2, 2, 2, 2) + ms_out = repeat_elements_net(Tensor(a.astype(np.int32))).asnumpy() + np_out = a.repeat(rep, axis) + np.testing.assert_array_equal(np_out, ms_out) + + a = np.arange(128).reshape(2, 2, 4, 2, 2, 2) + ms_out = repeat_elements_net(Tensor(a.astype(np.int32))).asnumpy() + np_out = a.repeat(rep, axis) + np.testing.assert_array_equal(np_out, ms_out) + + a = np.arange(18).reshape(1, 1, 3, 2, 3, 1) + ms_out = repeat_elements_net(Tensor(a.astype(np.int32))).asnumpy() + np_out = a.repeat(rep, axis) + np.testing.assert_array_equal(np_out, ms_out) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_repeat_elements_invalid_input(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + a = np.arange(64).reshape(2, 2, 2, 2, 2, 2) + with pytest.raises(ValueError): + _ = repeat_elements(a, 0, 0) + + with pytest.raises(ValueError): + _ = repeat_elements(a, 1, 6) + + with pytest.raises(ValueError): + _ = repeat_elements(a, 1, -7)
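One semantic corner the tests do not cover directly: Init normalizes negative axes (axis_ += input_dim_), so axis=-1 on a 3-d input should behave like axis=2, matching numpy. A quick check of that equivalence, runnable without MindSpore:

import numpy as np

a = np.arange(12).reshape(2, 2, 3)
# axis=-1 must behave exactly like axis=2 after normalization
np.testing.assert_array_equal(a.repeat(2, axis=-1), a.repeat(2, axis=2))
assert a.repeat(2, axis=-1).shape == (2, 2, 6)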