diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/index_add_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/index_add_impl.cu
new file mode 100644
index 0000000000..16a0f9a6af
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/index_add_impl.cu
@@ -0,0 +1,106 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/cuda_impl/index_add_impl.cuh"
+#include "backend/kernel_compiler/gpu/cuda_impl/util.cuh"
+#include "runtime/device/gpu/cuda_common.h"
+#include "include/cuda_fp16.h"
+__global__ void InitErrorCode(IndexAddErrorCode *error_code) {
+  *error_code = IndexAddErrorCode::kOk;
+}
+
+__global__ void ValidateIndexValues(const int *index, const size_t src_axis_size, const size_t dst_axis_size,
+                                    IndexAddErrorCode *error_code) {
+  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < src_axis_size; pos += blockDim.x * gridDim.x) {
+    const int idx_value = index[pos];
+    if (idx_value < 0 || idx_value >= dst_axis_size) {
+      *error_code = IndexAddErrorCode::kIndexOutOfRange;
+      return;
+    }
+  }
+  return;
+}
+
+template <typename T>
+__global__ void IndexAddAtomic(T *dst, const int *index, const T *src, const size_t src_size, const size_t outer_size,
+                               const size_t src_axis_size, const size_t dst_axis_size, const size_t inner_size) {
+  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < src_size; pos += blockDim.x * gridDim.x) {
+    const size_t src_axis_idx = (pos / inner_size) % src_axis_size;
+    const size_t src_outer_idx = pos / (src_axis_size * inner_size);
+    const size_t dst_axis_idx = static_cast<size_t>(index[src_axis_idx]);
+    const size_t dst_inner_idx = pos % inner_size;
+    const size_t dst_idx = src_outer_idx * (dst_axis_size * inner_size) + dst_axis_idx * inner_size + dst_inner_idx;
+    MsAtomicAdd(&dst[dst_idx], src[pos]);
+  }
+  return;
+}
+
+template <typename T>
+__global__ void IndexAdd(T *dst, const int *index, const T *src, const size_t src_size, const size_t outer_size,
+                         const size_t src_axis_size, const size_t dst_axis_size, const size_t inner_size) {
+  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < src_size; pos += blockDim.x * gridDim.x) {
+    const size_t src_axis_idx = (pos / inner_size) % src_axis_size;
+    const size_t src_outer_idx = pos / (src_axis_size * inner_size);
+    const size_t dst_axis_idx = static_cast<size_t>(index[src_axis_idx]);
+    const size_t dst_inner_idx = pos % inner_size;
+    const size_t dst_idx = src_outer_idx * (dst_axis_size * inner_size) + dst_axis_idx * inner_size + dst_inner_idx;
+    dst[dst_idx] += src[pos];
+  }
+  return;
+}
+
+void ValidateIndexAddInputValues(const int *index, const size_t src_axis_size, const size_t dst_axis_size,
+                                 IndexAddErrorCode *error_code, cudaStream_t cuda_stream) {
+  InitErrorCode<<<1, 1, 0, cuda_stream>>>(error_code);
+  ValidateIndexValues<<<GET_BLOCKS(src_axis_size), GET_THREADS, 0, cuda_stream>>>(index, src_axis_size, dst_axis_size,
+                                                                                  error_code);
+}
+
+template <typename T>
+void CalIndexAdd(T *dst, const int *index, const T *src, const size_t outer_size, const size_t src_axis_size,
+                 const size_t dst_axis_size, const size_t inner_size, const bool use_lock, cudaStream_t cuda_stream) {
+  size_t src_size = outer_size * src_axis_size * inner_size;
+  if (use_lock) {
+    IndexAddAtomic<<<GET_BLOCKS(src_size), GET_THREADS, 0, cuda_stream>>>(dst, index, src, src_size, outer_size,
+                                                                          src_axis_size, dst_axis_size, inner_size);
+  } else {
+    IndexAdd<<<GET_BLOCKS(src_size), GET_THREADS, 0, cuda_stream>>>(dst, index, src, src_size, outer_size,
+                                                                    src_axis_size, dst_axis_size, inner_size);
+  }
+  return;
+}
+
+template void CalIndexAdd<double>(double *dst, const int *index, const double *src, const size_t outer_size,
+                                  const size_t src_axis_size, const size_t dst_axis_size, const size_t inner_size,
+                                  const bool use_lock, cudaStream_t cuda_stream);
+template void CalIndexAdd<float>(float *dst, const int *index, const float *src, const size_t outer_size,
+                                 const size_t src_axis_size, const size_t dst_axis_size, const size_t inner_size,
+                                 const bool use_lock, cudaStream_t cuda_stream);
+template void CalIndexAdd<half>(half *dst, const int *index, const half *src, const size_t outer_size,
+                                const size_t src_axis_size, const size_t dst_axis_size, const size_t inner_size,
+                                const bool use_lock, cudaStream_t cuda_stream);
+template void CalIndexAdd<int>(int *dst, const int *index, const int *src, const size_t outer_size,
+                               const size_t src_axis_size, const size_t dst_axis_size, const size_t inner_size,
+                               const bool use_lock, cudaStream_t cuda_stream);
+template void CalIndexAdd<int16_t>(int16_t *dst, const int *index, const int16_t *src, const size_t outer_size,
+                                   const size_t src_axis_size, const size_t dst_axis_size, const size_t inner_size,
+                                   const bool use_lock, cudaStream_t cuda_stream);
+template void CalIndexAdd<int8_t>(int8_t *dst, const int *index, const int8_t *src, const size_t outer_size,
+                                  const size_t src_axis_size, const size_t dst_axis_size, const size_t inner_size,
+                                  const bool use_lock, cudaStream_t cuda_stream);
+template void CalIndexAdd<uint8_t>(uint8_t *dst, const int *index, const uint8_t *src, const size_t outer_size,
+                                   const size_t src_axis_size, const size_t dst_axis_size, const size_t inner_size,
+                                   const bool use_lock, cudaStream_t cuda_stream);
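For reference, the IndexAddAtomic and IndexAdd kernels above treat both tensors as flat buffers and decompose every flat source position into (outer, axis, inner) coordinates before scattering into dst. A minimal NumPy mirror of that arithmetic (the helper name and signature are illustrative only, not part of the patch):

    import numpy as np

    def index_add_flat_reference(dst, index, src, dst_axis_size, src_axis_size, inner_size):
        # CPU mirror of the dst_idx computation in IndexAdd/IndexAddAtomic; returns the updated array.
        dst_flat = dst.reshape(-1)
        src_flat = src.reshape(-1)
        for pos in range(src_flat.size):
            src_axis_idx = (pos // inner_size) % src_axis_size
            src_outer_idx = pos // (src_axis_size * inner_size)
            dst_axis_idx = int(index[src_axis_idx])
            dst_inner_idx = pos % inner_size
            dst_idx = src_outer_idx * (dst_axis_size * inner_size) + dst_axis_idx * inner_size + dst_inner_idx
            dst_flat[dst_idx] += src_flat[pos]
        return dst_flat.reshape(dst.shape)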
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/index_add_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/index_add_impl.cuh
new file mode 100644
index 0000000000..a32adaeafe
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/index_add_impl.cuh
@@ -0,0 +1,31 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_INDEXADD_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_INDEXADD_H_
+enum class IndexAddErrorCode {
+  kOk = 0,
+  kIndexOutOfRange
+};
+
+void ValidateIndexAddInputValues(const int *index, const size_t src_axis_size, const size_t dst_axis_size,
+                                 IndexAddErrorCode *error_code, cudaStream_t cuda_stream);
+
+template <typename T>
+void CalIndexAdd(T *dst, const int *index, const T *src, const size_t outer_size, const size_t src_axis_size,
+                 const size_t dst_axis_size, const size_t inner_size, const bool use_lock, cudaStream_t cuda_stream);
+
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_INDEXADD_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/index_add_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/index_add_gpu_kernel.cc
new file mode 100644
index 0000000000..f5758205a0
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/index_add_gpu_kernel.cc
@@ -0,0 +1,71 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/math/index_add_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_ONE(IndexAdd,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat64)
+                        .AddInputAttr(kNumberTypeInt32)
+                        .AddInputAttr(kNumberTypeFloat64)
+                        .AddOutputAttr(kNumberTypeFloat64),
+                      IndexAddGpuKernel, double)
+MS_REG_GPU_KERNEL_ONE(IndexAdd,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeInt32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeFloat32),
+                      IndexAddGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(IndexAdd,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeInt32)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddOutputAttr(kNumberTypeFloat16),
+                      IndexAddGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(IndexAdd,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeInt32)
+                        .AddInputAttr(kNumberTypeInt32)
+                        .AddInputAttr(kNumberTypeInt32)
+                        .AddOutputAttr(kNumberTypeInt32),
+                      IndexAddGpuKernel, int)
+MS_REG_GPU_KERNEL_ONE(IndexAdd,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeInt16)
+                        .AddInputAttr(kNumberTypeInt32)
+                        .AddInputAttr(kNumberTypeInt16)
+                        .AddOutputAttr(kNumberTypeInt16),
+                      IndexAddGpuKernel, int16_t)
+MS_REG_GPU_KERNEL_ONE(IndexAdd,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeInt8)
+                        .AddInputAttr(kNumberTypeInt32)
+                        .AddInputAttr(kNumberTypeInt8)
+                        .AddOutputAttr(kNumberTypeInt8),
+                      IndexAddGpuKernel, int8_t)
+MS_REG_GPU_KERNEL_ONE(IndexAdd,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeUInt8)
+                        .AddInputAttr(kNumberTypeInt32)
+                        .AddInputAttr(kNumberTypeUInt8)
+                        .AddOutputAttr(kNumberTypeUInt8),
+                      IndexAddGpuKernel, uint8_t)
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/index_add_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/index_add_gpu_kernel.h
new file mode 100644
index 0000000000..5ceb1b1822
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/index_add_gpu_kernel.h
@@ -0,0 +1,155 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_INDEX_ADD_GPU_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_INDEX_ADD_GPU_KERNEL_H_
+
+#include <vector>
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/index_add_impl.cuh"
+namespace mindspore {
+namespace kernel {
+template <typename T>
+class IndexAddGpuKernel : public GpuKernel {
+ public:
+  IndexAddGpuKernel()
+      : dst_size_(0),
+        index_size_(0),
+        src_size_(0),
+        output_size_(0),
+        outer_size_(0),
+        src_axis_size_(0),
+        dst_axis_size_(0),
+        inner_size_(0),
+        use_lock_(true),
+        check_index_bound_(true) {}
+  ~IndexAddGpuKernel() override = default;
+
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+    T *dst = GetDeviceAddress<T>(inputs, 0);
+    int *index = GetDeviceAddress<int>(inputs, 1);
+    T *src = GetDeviceAddress<T>(inputs, 2);
+    T *dst_out = GetDeviceAddress<T>(outputs, 0);
+
+    if (check_index_bound_) {
+      IndexAddErrorCode *error_code_addr = GetDeviceAddress<IndexAddErrorCode>(workspace, 0);
+      IndexAddErrorCode error_code = IndexAddErrorCode::kOk;
+      ValidateIndexAddInputValues(index, src_axis_size_, dst_axis_size_, error_code_addr,
+                                  reinterpret_cast<cudaStream_t>(stream_ptr));
+      CHECK_CUDA_RET_WITH_ERROR(kernel_node_,
+                                cudaMemcpyAsync(&error_code, error_code_addr, sizeof(IndexAddErrorCode),
+                                                cudaMemcpyDeviceToHost, reinterpret_cast<cudaStream_t>(stream_ptr)),
+                                "Failed to copy error code to host.");
+      CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, cudaDeviceSynchronize(), "cudaDeviceSyncFailed");
+      LogExceptionIfNotOk(error_code);
+    }
+    CalIndexAdd(dst, index, src, outer_size_, src_axis_size_, dst_axis_size_, inner_size_, use_lock_,
+                reinterpret_cast<cudaStream_t>(stream_ptr));
+    CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
+                               cudaMemcpyAsync(&dst_out[0], &dst[0], dst_size_, cudaMemcpyDeviceToDevice,
+                                               reinterpret_cast<cudaStream_t>(stream_ptr)),
+                               "cudaMemcpyAsync output failed");
+    return true;
+  }
+
+  bool Init(const CNodePtr &kernel_node) override {
+    kernel_node_ = kernel_node;
+    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+    if (input_num != 3) {
+      MS_LOG(ERROR) << "Input number is " << input_num << ", but index add needs 3 inputs.";
+      return false;
+    }
+    std::vector<size_t> dst_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    std::vector<size_t> index_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
+    std::vector<size_t> src_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2);
+    int64_t src_rank = src_shape.size();
+    int64_t axis = GetAttr<int64_t>(kernel_node, "axis");
+    if (axis < 0) {
+      axis += src_rank;
+    }
+    outer_size_ = 1;
+    for (int64_t i = axis - 1; i >= 0; i--) {
+      outer_size_ *= src_shape[i];
+    }
+    inner_size_ = 1;
+    for (int64_t i = axis + 1; i < src_rank; i++) {
+      inner_size_ *= src_shape[i];
+    }
+    src_axis_size_ = src_shape[axis];
+    dst_axis_size_ = dst_shape[axis];
+    dst_size_ = sizeof(T);
+    for (auto x : dst_shape) {
+      dst_size_ *= x;
+    }
+    index_size_ = sizeof(int);
+    for (auto x : index_shape) {
+      index_size_ *= x;
+    }
+    src_size_ = sizeof(T);
+    for (auto x : src_shape) {
+      src_size_ *= x;
+    }
+    output_size_ = dst_size_;
+    InitSizeLists();
+    return true;
+  }
+
+ protected:
+  void InitSizeLists() override {
+    input_size_list_.push_back(dst_size_);
+    input_size_list_.push_back(index_size_);
+    input_size_list_.push_back(src_size_);
+    output_size_list_.push_back(output_size_);
+    workspace_size_list_.push_back(sizeof(IndexAddErrorCode));
+  }
+
+ private:
+  void LogExceptionIfNotOk(IndexAddErrorCode error_code) {
+    switch (error_code) {
+      case IndexAddErrorCode::kOk:
+        return;
+      case IndexAddErrorCode::kIndexOutOfRange:
+        MS_LOG(EXCEPTION) << "gpu IndexAdd op error: values of index tensor are out of range";
+        break;
+      default:
+        MS_LOG(EXCEPTION) << "gpu IndexAdd op unknown error";
+    }
+  }
+
+  size_t dst_size_;
+  size_t index_size_;
+  size_t src_size_;
+  size_t output_size_;
+  size_t outer_size_;
+  size_t src_axis_size_;
+  size_t dst_axis_size_;
+  size_t inner_size_;
+  bool use_lock_;
+  bool check_index_bound_;
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_INDEX_ADD_GPU_KERNEL_H_
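The Init() routine above collapses the operand into three logical dimensions: everything before `axis` (outer_size_), the indexed axis itself (src_axis_size_ / dst_axis_size_), and everything after it (inner_size_). A quick way to see the decomposition for a concrete shape (the helper name is ours, purely illustrative):

    import numpy as np

    def collapse(shape, axis):
        # Mirrors IndexAddGpuKernel::Init(): (outer, axis, inner) sizes for a given shape and axis.
        axis = axis if axis >= 0 else axis + len(shape)
        outer = int(np.prod(shape[:axis], dtype=np.int64))
        inner = int(np.prod(shape[axis + 1:], dtype=np.int64))
        return outer, shape[axis], inner

    # The (2, 3, 4, 4) tensor used in the tests, indexed along axis=1:
    print(collapse((2, 3, 4, 4), 1))  # (2, 3, 16)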
diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py
index 20bd7cfd52..33c4de8c15 100644
--- a/mindspore/ops/operations/__init__.py
+++ b/mindspore/ops/operations/__init__.py
@@ -56,7 +56,7 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A
                        Reciprocal, CumSum, HistogramFixedWidth, SquaredDifference, Xdivy, Xlogy,
                        Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e, TruncateDiv, TruncateMod,
                        Square, Sub, TensorAdd, Add, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps, Tan,
-                       MatrixInverse)
+                       MatrixInverse, IndexAdd)
 from .random_ops import (RandomChoiceWithMask, StandardNormal, Gamma, Poisson, UniformInt, UniformReal,
                          RandomCategorical, StandardLaplace, Multinomial, UniformCandidateSampler,
@@ -418,6 +418,7 @@ __all__ = [
     "SparseToDense",
     "MatrixInverse",
     "Range",
+    "IndexAdd",
 ]

 __all__.sort()
diff --git a/mindspore/ops/operations/math_ops.py b/mindspore/ops/operations/math_ops.py
index ea5887eae0..46c9b824eb 100644
--- a/mindspore/ops/operations/math_ops.py
+++ b/mindspore/ops/operations/math_ops.py
@@ -4199,3 +4199,68 @@ class MatrixInverse(PrimitiveWithInfer):
         validator.check_int(len(x_shape), 2, Rel.GE, self.name, None)
         validator.check_equal_int(x_shape[-1], x_shape[-2], self.name, None)
         return x_shape
+
+
+class IndexAdd(PrimitiveWithInfer):
+    """
+    Adds tensor `input_y` to the specified axis and indices of tensor `input_x`.
+
+    Args:
+        axis (int): The dimension along which to index.
+
+    Inputs:
+        - **input_x** (Tensor) - The input tensor to add to, with data type float64, float32, float16, int32, int16,
+          int8, uint8.
+        - **indices** (Tensor) - The indices of `input_x` on the `axis`th dimension to add to, with data type int32.
+          The `indices` must be 1D with the same size as the `axis`th dimension of `input_y`. The values
+          of `indices` must be in the range `[0, input_x.shape[axis])`.
+        - **input_y** (Tensor) - The input tensor with the values to add. Must have the same data type as `input_x`.
+          The shape must be the same as `input_x` except for the `axis`th dimension.
+
+    Outputs:
+        Tensor, has the same shape and dtype as `input_x`.
+
+    Supported Platforms:
+        ``GPU``
+
+    Examples:
+        >>> input_x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [6, 7, 8]]), mindspore.float32)
+        >>> input_y = Tensor(np.array([[0.5, 1.0], [1.0, 1.5], [2.0, 2.5]]), mindspore.float32)
+        >>> indices = Tensor(np.array([0, 2]), mindspore.int32)
+        >>> index_add = ops.IndexAdd(axis=1)
+        >>> output = index_add(input_x, indices, input_y)
+        >>> print(output)
+        [[ 1.5  2.   4. ]
+         [ 5.   5.   7.5]
+         [ 8.   7.  10.5]]
+    """
+
+    @prim_attr_register
+    def __init__(self, axis, use_lock=True, check_index_bound=True):
+        """Initialize IndexAdd"""
+        self.init_prim_io_names(inputs=['input_x', 'indices', 'input_y'], outputs=['output'])
+        self.axis = axis
+        validator.check_value_type('axis', axis, [int], self.name)
+
+    def infer_dtype(self, x_dtype, idx_type, y_dtype):
+        args = {'x': x_dtype, 'y': y_dtype}
+        valid_type = [mstype.float64, mstype.float32, mstype.float16, mstype.int32, mstype.int16, mstype.int8,
+                      mstype.uint8]
+        validator.check_tensors_dtypes_same_and_valid(args, valid_type, self.name)
+        valid_idx_type = [mstype.int32]
+        validator.check_tensor_dtype_valid("idx_type", idx_type, valid_idx_type, self.name)
+        return x_dtype
+
+    def infer_shape(self, x_shape, idx_shape, y_shape):
+        validator.check("x rank", len(x_shape), "y rank", len(y_shape), Rel.EQ, self.name)
+        validator.check("size of indices", idx_shape[0], "dimension of y[axis]", y_shape[self.axis],
+                        Rel.EQ, self.name)
+        x_rank = len(x_shape)
+        validator.check_int_range(self.axis, -x_rank - 1, x_rank, Rel.INC_BOTH, 'axis', self.name)
+        axis = self.axis if self.axis >= 0 else x_rank + self.axis
+        for dim in range(x_rank):
+            if dim == axis:
+                validator.check('x dim %d' % dim, x_shape[dim], "y dim %d" % dim, y_shape[dim], Rel.GE, self.name)
+            else:
+                validator.check('x dim %d' % dim, x_shape[dim], "y dim %d" % dim, y_shape[dim], Rel.EQ, self.name)
+        return x_shape
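The documented semantics can be reproduced on the host with NumPy's np.add.at; this sketch (illustrative, not part of the patch) recomputes the docstring example:

    import numpy as np

    def index_add_numpy(x, indices, y, axis):
        # NumPy reference: add y into x along `axis` at positions `indices`.
        out = x.copy()
        sel = tuple(indices if dim == axis else slice(None) for dim in range(x.ndim))
        np.add.at(out, sel, y)
        return out

    x = np.array([[1, 2, 3], [4, 5, 6], [6, 7, 8]], dtype=np.float32)
    y = np.array([[0.5, 1.0], [1.0, 1.5], [2.0, 2.5]], dtype=np.float32)
    idx = np.array([0, 2], dtype=np.int32)
    print(index_add_numpy(x, idx, y, axis=1))
    # [[ 1.5  2.   4. ]
    #  [ 5.   5.   7.5]
    #  [ 8.   7.  10.5]]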
diff --git a/tests/st/ops/gpu/test_index_add_op.py b/tests/st/ops/gpu/test_index_add_op.py
new file mode 100644
index 0000000000..fe2343edbf
--- /dev/null
+++ b/tests/st/ops/gpu/test_index_add_op.py
@@ -0,0 +1,259 @@
+# Copyright 2019 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.ops import operations as P
+
+
+class NetIndexAdd(nn.Cell):
+    def __init__(self, axis):
+        super(NetIndexAdd, self).__init__()
+        self.index_add = P.IndexAdd(axis)
+
+    def construct(self, x, idx, y):
+        z = self.index_add(x, idx, y)
+        return z
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_index_add():
+    x = np.arange(2 * 3 * 4 * 4).reshape(2, 3, 4, 4).astype(np.float32)
+    y0 = np.ones((1, 3, 4, 4), dtype=np.float32)
+    idx0 = np.array([1]).astype(np.int32)
+    axis0 = 0
+    expect = np.copy(x)
+    expect[idx0, :, :, :] = expect[idx0, :, :, :] + y0
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
+    net = NetIndexAdd(axis0)
+    output = net(Tensor(x), Tensor(idx0), Tensor(y0))
+    assert (output.asnumpy() == expect).all()
+    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
+    net = NetIndexAdd(axis0)
+    output = net(Tensor(x), Tensor(idx0), Tensor(y0))
+    assert (output.asnumpy() == expect).all()
+
+    y1 = np.ndarray((2, 2, 4, 4)).astype(np.float32)
+    y1.fill(0.1)
+    idx1 = np.array([0, 2]).astype(np.int32)
+    axis1 = 1
+    expect = np.copy(x)
+    expect[:, idx1, :, :] = expect[:, idx1, :, :] + y1
+    context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
+    net = NetIndexAdd(axis1)
+    output = net(Tensor(x), Tensor(idx1), Tensor(y1))
+    assert (output.asnumpy() == expect).all()
+    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
+    net = NetIndexAdd(axis1)
+    output = net(Tensor(x), Tensor(idx1), Tensor(y1))
+    assert (output.asnumpy() == expect).all()
+
+    y2 = np.ones((2, 3, 2, 4)).astype(np.float32)
+    y2.fill(5.5)
+    idx2 = np.array([1, 3]).astype(np.int32)
+    axis2 = 2
+    expect = np.copy(x)
+    expect[:, :, idx2, :] = expect[:, :, idx2, :] + y2
+    context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
+    net = NetIndexAdd(axis2)
+    output = net(Tensor(x), Tensor(idx2), Tensor(y2))
+    assert (output.asnumpy() == expect).all()
+    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
+    net = NetIndexAdd(axis2)
+    output = net(Tensor(x), Tensor(idx2), Tensor(y2))
+    assert (output.asnumpy() == expect).all()
+
+    y3 = np.ones((2, 3, 4, 3)).astype(np.float32)
+    y3.fill(1000.00)
+    idx3 = np.array([0, 2, 3]).astype(np.int32)
+    axis3 = 3
+    expect = np.copy(x)
+    expect[:, :, :, idx3] = expect[:, :, :, idx3] + y3
+    context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
+    net = NetIndexAdd(axis3)
+    output = net(Tensor(x), Tensor(idx3), Tensor(y3))
+    assert (output.asnumpy() == expect).all()
+    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
+    net = NetIndexAdd(axis3)
+    output = net(Tensor(x), Tensor(idx3), Tensor(y3))
+    assert (output.asnumpy() == expect).all()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_index_add_float16():
+    x = np.arange(2 * 3 * 4).reshape(2, 3, 4).astype(np.float16)
+    y = np.ones((2, 2, 4), dtype=np.float16)
+    idx = np.array([0, 2]).astype(np.int32)
+    axis = 1
+    expect = np.copy(x)
+    expect[:, idx, :] = expect[:, idx, :] + y
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
+    net = NetIndexAdd(axis)
+    output = net(Tensor(x), Tensor(idx), Tensor(y))
+    assert (output.asnumpy() == expect).all()
+    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
+    net = NetIndexAdd(axis)
+    output = net(Tensor(x), Tensor(idx), Tensor(y))
+    assert (output.asnumpy() == expect).all()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_index_add_int32():
+    x = np.arange(2 * 3 * 4).reshape(2, 3, 4).astype(np.int32)
+    y = np.ones((2, 2, 4), dtype=np.int32)
+    idx = np.array([0, 2]).astype(np.int32)
+    axis = 1
+    expect = np.copy(x)
+    expect[:, idx, :] = expect[:, idx, :] + y
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
+    net = NetIndexAdd(axis)
+    output = net(Tensor(x), Tensor(idx), Tensor(y))
+    assert (output.asnumpy() == expect).all()
+    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
+    net = NetIndexAdd(axis)
+    output = net(Tensor(x), Tensor(idx), Tensor(y))
+    assert (output.asnumpy() == expect).all()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_index_add_int8():
+    x = np.arange(2 * 3 * 4).reshape(2, 3, 4).astype(np.int8)
+    y = np.ones((2, 2, 4), dtype=np.int8)
+    idx = np.array([0, 2]).astype(np.int32)
+    axis = 1
+    expect = np.copy(x)
+    expect[:, idx, :] = expect[:, idx, :] + y
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
+    net = NetIndexAdd(axis)
+    output = net(Tensor(x), Tensor(idx), Tensor(y))
+    assert (output.asnumpy() == expect).all()
+    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
+    net = NetIndexAdd(axis)
+    output = net(Tensor(x), Tensor(idx), Tensor(y))
+    assert (output.asnumpy() == expect).all()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_index_add_uint8():
+    x = np.arange(2 * 3 * 4).reshape(2, 3, 4).astype(np.uint8)
+    y = np.ones((2, 2, 4), dtype=np.uint8)
+    idx = np.array([0, 2]).astype(np.int32)
+    axis = 1
+    expect = np.copy(x)
+    expect[:, idx, :] = expect[:, idx, :] + y
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
+    net = NetIndexAdd(axis)
+    output = net(Tensor(x), Tensor(idx), Tensor(y))
+    assert (output.asnumpy() == expect).all()
+    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
+    net = NetIndexAdd(axis)
+    output = net(Tensor(x), Tensor(idx), Tensor(y))
+    assert (output.asnumpy() == expect).all()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_index_add_float64():
+    x = np.arange(2 * 3 * 4).reshape(2, 3, 4).astype(np.float64)
+    y = np.ones((2, 2, 4), dtype=np.float64)
+    idx = np.array([0, 2]).astype(np.int32)
+    axis = 1
+    expect = np.copy(x)
+    expect[:, idx, :] = expect[:, idx, :] + y
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
+    net = NetIndexAdd(axis)
+    output = net(Tensor(x), Tensor(idx), Tensor(y))
+    assert (output.asnumpy() == expect).all()
+    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
+    net = NetIndexAdd(axis)
+    output = net(Tensor(x), Tensor(idx), Tensor(y))
+    assert (output.asnumpy() == expect).all()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_index_add_int16():
+    x = np.arange(2 * 3 * 4).reshape(2, 3, 4).astype(np.int16)
+    y = np.ones((2, 2, 4), dtype=np.int16)
+    idx = np.array([0, 2]).astype(np.int32)
+    axis = 1
+    expect = np.copy(x)
+    expect[:, idx, :] = expect[:, idx, :] + y
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
+    net = NetIndexAdd(axis)
+    output = net(Tensor(x), Tensor(idx), Tensor(y))
+    assert (output.asnumpy() == expect).all()
+    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
+    net = NetIndexAdd(axis)
+    output = net(Tensor(x), Tensor(idx), Tensor(y))
+    assert (output.asnumpy() == expect).all()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_index_add_invalid_inputs():
+    x = np.arange(2 * 3 * 4).reshape(2, 3, 4).astype(np.uint8)
+    y = np.ones((2, 2, 4), dtype=np.uint8)
+    with pytest.raises(TypeError):
+        # axis not int
+        net = NetIndexAdd(1.0)
+
+        # x and y don't have the same type
+        y = np.ones((2, 2, 4), dtype=np.float32)
+        idx = np.array([0, 1]).astype(np.int32)
+        net = NetIndexAdd(1)
+        _ = net(Tensor(x), Tensor(idx), Tensor(y))
+
+    with pytest.raises(ValueError):
+        # index size not the same as len(y[axis])
+        idx = np.array([0]).astype(np.int32)
+        net = NetIndexAdd(1)
+        _ = net(Tensor(x), Tensor(idx), Tensor(y))
+
+        # x and y don't have same rank
+        y = np.ones((2, 2), dtype=np.uint8)
+        idx = np.array([0, 1]).astype(np.int32)
+        net = NetIndexAdd(1)
+        _ = net(Tensor(x), Tensor(idx), Tensor(y))
+
+        # x and y don't have same shape on dimensions other than axis-th dimension
+        y = np.ones((2, 2, 5), dtype=np.uint8)
+        idx = np.array([0, 1]).astype(np.int32)
+        net = NetIndexAdd(1)
+        _ = net(Tensor(x), Tensor(idx), Tensor(y))
+
+    with pytest.raises(RuntimeError) as info:
+        # index value not in the range of 0 to len(x[axis])
+        idx = np.array([5, 6]).astype(np.int32)
+        net = NetIndexAdd(1)
+        _ = net(Tensor(x), Tensor(idx), Tensor(y))
+    assert "out of range" in str(info.value)
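One case the tests above do not exercise is duplicate values in `indices`. With the atomic kernel (use_lock=True, the default) every duplicate contribution is accumulated, which corresponds to np.add.at rather than to the buffered fancy-indexing pattern the tests use to build their expected values; whether duplicate indices are meant to be part of the op contract is an assumption on our part. A small host-side illustration of the difference:

    import numpy as np

    x = np.zeros((3, 4), dtype=np.float32)
    y = np.ones((2, 4), dtype=np.float32)
    idx = np.array([1, 1], dtype=np.int32)       # duplicate index

    buffered = x.copy()
    buffered[idx, :] = buffered[idx, :] + y      # second write overwrites: row 1 ends at 1.0
    accumulated = x.copy()
    np.add.at(accumulated, idx, y)               # both contributions land: row 1 ends at 2.0

    print(buffered[1, 0], accumulated[1, 0])     # 1.0 2.0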