From dd72f44b276b7db958dd1a1fee37b08a4be5879e Mon Sep 17 00:00:00 2001
From: Peilin Wang
Date: Sun, 7 Mar 2021 23:51:48 -0500
Subject: [PATCH] Add TensorScatterUpdate GPU op: kernel registration, CUDA
 implementation, and ST tests

---
 .../tensor_scatter_update_gpu_kernel.cc       |  61 +++++
 .../arrays/tensor_scatter_update_gpu_kernel.h | 212 ++++++++++++++++++
 .../gpu/cuda_impl/tensor_scatter_update.cu    |  84 +++++++
 .../gpu/cuda_impl/tensor_scatter_update.cuh   |  26 +++
 .../st/ops/gpu/test_tensor_scatter_update.py  |  79 +++++++
 5 files changed, 462 insertions(+)
 create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/tensor_scatter_update_gpu_kernel.cc
 create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/tensor_scatter_update_gpu_kernel.h
 create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/tensor_scatter_update.cu
 create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/tensor_scatter_update.cuh
 create mode 100644 tests/st/ops/gpu/test_tensor_scatter_update.py

diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/tensor_scatter_update_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/tensor_scatter_update_gpu_kernel.cc
new file mode 100644
index 0000000000..70ac26aa97
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/tensor_scatter_update_gpu_kernel.cc
@@ -0,0 +1,61 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/arrays/tensor_scatter_update_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_TWO(TensorScatterUpdate,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeInt32)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddOutputAttr(kNumberTypeFloat16),
+                      TensorScatterUpdateGpuFwdKernel, half, int)
+
+MS_REG_GPU_KERNEL_TWO(TensorScatterUpdate,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeInt32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeFloat32),
+                      TensorScatterUpdateGpuFwdKernel, float, int)
+
+MS_REG_GPU_KERNEL_TWO(TensorScatterUpdate,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeInt8)
+                        .AddInputAttr(kNumberTypeInt32)
+                        .AddInputAttr(kNumberTypeInt8)
+                        .AddOutputAttr(kNumberTypeInt8),
+                      TensorScatterUpdateGpuFwdKernel, char, int)
+
+MS_REG_GPU_KERNEL_TWO(TensorScatterUpdate,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeInt32)
+                        .AddInputAttr(kNumberTypeInt32)
+                        .AddInputAttr(kNumberTypeInt32)
+                        .AddOutputAttr(kNumberTypeInt32),
+                      TensorScatterUpdateGpuFwdKernel, int, int)
+
+MS_REG_GPU_KERNEL_TWO(TensorScatterUpdate,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeUInt8)
+                        .AddInputAttr(kNumberTypeInt32)
+                        .AddInputAttr(kNumberTypeUInt8)
+                        .AddOutputAttr(kNumberTypeUInt8),
+                      TensorScatterUpdateGpuFwdKernel, uchar, int)
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/tensor_scatter_update_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/tensor_scatter_update_gpu_kernel.h
new file mode 100644
index 0000000000..f3daf6e539
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/tensor_scatter_update_gpu_kernel.h
@@ -0,0 +1,212 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_TENSOR_SCATTER_UPDATE_GPU_KERNEL_H
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_TENSOR_SCATTER_UPDATE_GPU_KERNEL_H
+
+#include <vector>
+#include <algorithm>
+#include "backend/kernel_compiler/gpu/cuda_impl/tensor_scatter_update.cuh"
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+
+namespace mindspore {
+namespace kernel {
+template <typename T, typename S>
+class TensorScatterUpdateGpuFwdKernel : public GpuKernel {
+ public:
+  TensorScatterUpdateGpuFwdKernel()
+      : input_size_(1),
+        update_size_(1),
+        indices_size_(1),
+        output_size_(1),
+        block_size_(1),
+        indices_stride_(nullptr),
+        work_shape_(nullptr),
+        indices_dim_0_(0),
+        indices_dim_1_(0),
+        memcpy_flag_(false) {}
+  ~TensorScatterUpdateGpuFwdKernel() {
+    if (indices_stride_ != nullptr) {
+      device::gpu::GPUMemoryAllocator::GetInstance().FreeTensorMem(static_cast<void *>(indices_stride_));
+    }
+    if (work_shape_ != nullptr) {
+      device::gpu::GPUMemoryAllocator::GetInstance().FreeTensorMem(static_cast<void *>(work_shape_));
+    }
+  }
+
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+    VARIABLE_NOT_USED(workspace);
+    T *input = GetDeviceAddress<T>(inputs, 0);
+    S *indices = GetDeviceAddress<S>(inputs, 1);
+    T *update = GetDeviceAddress<T>(inputs, 2);
+    T *output = GetDeviceAddress<T>(outputs, 0);
+
+    if (!memcpy_flag_) {
+      // Copy the index strides and the working shape to device memory once, on the first launch.
+      const size_t indices_len = sizeof(S) * vec_indices_stride_.size();
+      const size_t vec_work_len = sizeof(S) * vec_work_shape_.size();
+      CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
+                                 cudaMemcpyAsync(indices_stride_, &vec_indices_stride_[0], indices_len,
+                                                 cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
+                                 "cudaMemcpy failed in TensorScatterUpdateGpuFwdKernel::Launch.");
+      CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
+                                 cudaMemcpyAsync(work_shape_, &vec_work_shape_[0], vec_work_len,
+                                                 cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
+                                 "cudaMemcpy failed in TensorScatterUpdateGpuFwdKernel::Launch.");
+      memcpy_flag_ = true;
+    }
+
+    // Initialize the output with a copy of the input, then scatter the updates into the
+    // output buffer. Scattering into the output (rather than into the input and copying
+    // afterwards) leaves the input tensor unmodified.
+    CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
+                               cudaMemcpyAsync(output, input, input_size_, cudaMemcpyDeviceToDevice,
+                                               reinterpret_cast<cudaStream_t>(stream_ptr)),
+                               "cudaMemcpyAsync output failed in TensorScatterUpdateGpuFwdKernel::Launch.");
+
+    const size_t update_size = update_size_ / sizeof(T);
+    const size_t output_size = output_size_ / sizeof(T);
+
+    TensorScatterUpdate(input, indices, update, output, block_size_, update_size, output_size, indices_dim_0_,
+                        indices_dim_1_, indices_stride_, work_shape_, reinterpret_cast<cudaStream_t>(stream_ptr));
+    return true;
+  }
+
+  bool Init(const CNodePtr &kernel_node) override {
+    kernel_node_ = kernel_node;
+    memcpy_flag_ = false;
+    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+    if (input_num != 3) {
+      MS_LOG(ERROR) << "Input number is " << input_num << ", but TensorScatterUpdate needs 3 inputs.";
+      return false;
+    }
+    size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
+    if (output_num != 1) {
+      MS_LOG(ERROR) << "Output number is " << output_num << ", but 
TensorScatterUpdate has 1 output.";
+      return false;
+    }
+
+    update_shapes_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2);
+    indices_shapes_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
+    input_shapes_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    output_shapes_ = AnfAlgo::GetOutputInferShape(kernel_node, 0);
+
+    std::vector<size_t> shape_me = input_shapes_;
+    (void)std::transform(shape_me.begin(), shape_me.end(), std::back_inserter(vec_work_shape_),
+                         [](const size_t &value) { return static_cast<S>(value); });
+
+    GetSize();
+
+    const size_t indices_len = sizeof(S) * vec_indices_stride_.size();
+    void *indices_stride_work = device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(indices_len);
+    if (indices_stride_work == nullptr) {
+      MS_LOG(EXCEPTION) << "Failed to alloc indices_stride_work, size: " << indices_len;
+    }
+    indices_stride_ = static_cast<S *>(indices_stride_work);
+
+    const size_t vec_work_len = sizeof(S) * vec_work_shape_.size();
+    void *work_shape_work = device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(vec_work_len);
+    if (work_shape_work == nullptr) {
+      MS_LOG(EXCEPTION) << "Failed to alloc work_shape_work, size: " << vec_work_len;
+    }
+    work_shape_ = static_cast<S *>(work_shape_work);
+
+    InitSizeLists();
+
+    return true;
+  }
+
+ protected:
+  void InitSizeLists() override {
+    input_size_list_.push_back(input_size_);
+    input_size_list_.push_back(indices_size_);
+    input_size_list_.push_back(update_size_);
+    output_size_list_.push_back(output_size_);
+  }
+
+  void GetSize() {
+    input_size_ = sizeof(T);
+    for (size_t i = 0; i < input_shapes_.size(); i++) {
+      input_size_ *= input_shapes_[i];
+    }
+
+    indices_size_ = sizeof(S);
+    for (size_t i = 0; i < indices_shapes_.size(); i++) {
+      indices_size_ *= indices_shapes_[i];
+    }
+    update_size_ = sizeof(T);
+    for (size_t i = 0; i < update_shapes_.size(); i++) {
+      update_size_ *= update_shapes_[i];
+    }
+    output_size_ = sizeof(T);
+    for (size_t i = 0; i < output_shapes_.size(); i++) {
+      output_size_ *= output_shapes_[i];
+    }
+
+    // calculate indices dim 0/1
+    indices_dim_0_ = indices_shapes_[0];
+    indices_dim_1_ = indices_shapes_[indices_shapes_.size() - 1];
+
+    // calculate block_size: the element count of one inner slice addressed by an index row
+    for (size_t i = indices_dim_1_; i < output_shapes_.size(); i++) {
+      block_size_ *= output_shapes_[i];
+    }
+
+    // calculate indices_stride: the row-major stride of each indexed dimension
+    vec_indices_stride_.resize(indices_dim_1_, 0);
+    vec_indices_stride_[indices_dim_1_ - 1] = block_size_;
+
+    for (size_t i = indices_dim_1_ - 1; i > 0; --i) {
+      vec_indices_stride_[i - 1] = vec_indices_stride_[i] * output_shapes_[i];
+    }
+  }
+
+ private:
+  std::vector<size_t> update_shapes_;
+  std::vector<size_t> indices_shapes_;
+  std::vector<size_t> input_shapes_;
+  std::vector<size_t> output_shapes_;
+  std::vector<S> vec_indices_stride_;
+  std::vector<S> vec_work_shape_;
+
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+
+  size_t input_size_;
+  size_t update_size_;
+  size_t indices_size_;
+  size_t output_size_;
+  size_t block_size_;
+
+  S *indices_stride_;
+  S *work_shape_;
+  size_t indices_dim_0_;
+  size_t indices_dim_1_;
+  bool memcpy_flag_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_TENSOR_SCATTER_UPDATE_GPU_KERNEL_H
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/tensor_scatter_update.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/tensor_scatter_update.cu
new file mode 100644
index 0000000000..8470cf29c9
--- /dev/null
+++ 
b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/tensor_scatter_update.cu
@@ -0,0 +1,84 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/cuda_impl/tensor_scatter_update.cuh"
+#include "backend/kernel_compiler/gpu/cuda_impl/util.cuh"
+#include "runtime/device/gpu/cuda_common.h"
+
+template <typename T, typename S>
+__global__ void TensorScatterUpdateKernel(T *input, S *indices, T *update, T *output, const size_t block_size,
+                                          const size_t update_size, const size_t output_size,
+                                          const size_t indices_dim_0, const size_t indices_dim_1, S *indices_stride,
+                                          S *work_shape) {
+  // Grid-stride loop over the elements of `update`: decompose each read index into an
+  // indices row (i) and an offset within the inner block (j), then compute the flattened
+  // write position in `output` from the indexed coordinates and their strides. `input` is
+  // unused here; the launcher has already copied it into `output`.
+  for (size_t read_index = blockIdx.x * blockDim.x + threadIdx.x; read_index < update_size;
+       read_index += blockDim.x * gridDim.x) {
+    size_t write_index = 0;
+    bool out_bound = false;
+
+    size_t i = read_index / block_size;
+    size_t j = read_index % block_size;
+
+    for (size_t k = 0; k < indices_dim_1; k++) {
+      S indices_i = indices[i * indices_dim_1 + k];
+      out_bound |= indices_i >= work_shape[k];
+      write_index += indices_i * indices_stride[k];
+    }
+
+    write_index += j;
+    out_bound |= write_index >= output_size;
+
+    if (!out_bound) {
+      output[write_index] = update[read_index];
+    }
+  }
+}
+
+template <typename T, typename S>
+void TensorScatterUpdate(T *input, S *indices, T *update, T *output, const size_t &block_size,
+                         const size_t &update_size, const size_t &output_size, const size_t &indices_dim_0,
+                         const size_t &indices_dim_1, S *indices_stride, S *work_shape, cudaStream_t stream) {
+  TensorScatterUpdateKernel<<<GET_BLOCKS(update_size), GET_THREADS, 0, stream>>>(
+    input, indices, update, output, block_size, update_size, output_size, indices_dim_0, indices_dim_1,
+    indices_stride, work_shape);
+  return;
+}
+
+template void TensorScatterUpdate<half, int>(half *input, int *indices, half *update, half *output,
+                                             const size_t &block_size, const size_t &update_size,
+                                             const size_t &output_size, const size_t &indices_dim_0,
+                                             const size_t &indices_dim_1, int *indices_stride, int *work_shape,
+                                             cudaStream_t stream);
+template void TensorScatterUpdate<float, int>(float *input, int *indices, float *update, float *output,
+                                              const size_t &block_size, const size_t &update_size,
+                                              const size_t &output_size, const size_t &indices_dim_0,
+                                              const size_t &indices_dim_1, int *indices_stride, int *work_shape,
+                                              cudaStream_t stream);
+template void TensorScatterUpdate<char, int>(char *input, int *indices, char *update, char *output,
+                                             const size_t &block_size, const size_t &update_size,
+                                             const size_t &output_size, const size_t &indices_dim_0,
+                                             const size_t &indices_dim_1, int *indices_stride, int *work_shape,
+                                             cudaStream_t stream);
+template void TensorScatterUpdate<int, int>(int *input, int *indices, int *update, int *output,
+                                            const size_t &block_size, const size_t &update_size,
+                                            const size_t &output_size, const size_t &indices_dim_0,
+                                            const size_t &indices_dim_1, int *indices_stride, int *work_shape,
+                                            cudaStream_t stream);
+template void TensorScatterUpdate<unsigned char, int>(unsigned char *input, int *indices, unsigned char *update,
+                                                      unsigned char *output, const size_t &block_size,
+                                                      const size_t &update_size, const size_t &output_size,
+                                                      const size_t &indices_dim_0, const size_t &indices_dim_1,
+                                                      int *indices_stride, int *work_shape, cudaStream_t stream);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/tensor_scatter_update.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/tensor_scatter_update.cuh
new file mode 100644
index 0000000000..c1dd9976d2
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/tensor_scatter_update.cuh
@@ -0,0 +1,26 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_TENSOR_SCATTER_UPDATE_IMPL_CUH
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_TENSOR_SCATTER_UPDATE_IMPL_CUH
+
+#include "runtime/device/gpu/cuda_common.h"
+
+template <typename T, typename S>
+void TensorScatterUpdate(T *input, S *indices, T *update, T *output, const size_t &block_size,
+                         const size_t &update_size, const size_t &output_size, const size_t &indices_dim_0,
+                         const size_t &indices_dim_1, S *indices_stride, S *work_shape, cudaStream_t stream);
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_TENSOR_SCATTER_UPDATE_IMPL_CUH
diff --git a/tests/st/ops/gpu/test_tensor_scatter_update.py b/tests/st/ops/gpu/test_tensor_scatter_update.py
new file mode 100644
index 0000000000..333a20f104
--- /dev/null
+++ b/tests/st/ops/gpu/test_tensor_scatter_update.py
@@ -0,0 +1,79 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import numpy as np
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.ops import operations as P
+
+
+class Net(nn.Cell):
+    def __init__(self):
+        super(Net, self).__init__()
+        self.scatter = P.TensorScatterUpdate()
+
+    def construct(self, x, indices, update):
+        return self.scatter(x, indices, update)
+
+
+def scatter_net(x, indices, update):
+    scatter = Net()
+    return scatter(Tensor(x), Tensor(indices), Tensor(update)).asnumpy()
+
+
+def test_scatter():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+
+    # Case 1: 2-D input, each index row addresses a single element.
+    arr_input = np.arange(21).reshape(3, 7).astype(np.float32)
+    arr_indices = np.array([[0, 1], [1, 1], [0, 5], [0, 2], [2, 1]]).astype(np.int32)
+    arr_update = np.array([3.2, 1.1, 5.3, -2.2, -1.0]).astype(np.float32)
+    out = scatter_net(arr_input, arr_indices, arr_update)
+    expected = np.array([[0, 3.2, -2.2, 3, 4, 5.3, 6],
+                         [7, 1.1, 9, 10, 11, 12, 13],
+                         [14, -1, 16, 17, 18, 19, 20]]).astype(np.float32)
+    np.testing.assert_allclose(out, expected, rtol=1e-6)
+
+    # Case 2: 3-D input, full-depth index rows address single elements.
+    arr_input = np.arange(24).reshape(4, 2, 3).astype(np.float32)
+    arr_indices = np.array([[0, 0, 0], [1, 1, 1], [0, 1, 1], [3, 0, 1]]).astype(np.int32)
+    arr_update = np.array([-1, -2, -3, -4]).astype(np.float32)
+    out = scatter_net(arr_input, arr_indices, arr_update)
+    expected = np.array([[[-1, 1, 2],
+                          [3, -3, 5]],
+                         [[6, 7, 8],
+                          [9, -2, 11]],
+                         [[12, 13, 14],
+                          [15, 16, 17]],
+                         [[18, -4, 20],
+                          [21, 22, 23]]]).astype(np.float32)
+    np.testing.assert_allclose(out, expected, rtol=1e-6)
+
+    # Case 3: batched (2, 5, 2) indices; the duplicate index [2, 2] receives the
+    # same value (33) from both batches, so the result is deterministic.
+    arr_input = np.arange(25).reshape(5, 5).astype(np.float32)
+    arr_indices = np.array([[[0, 0],
+                             [1, 1],
+                             [2, 2],
+                             [3, 3],
+                             [4, 4]],
+                            [[0, 4],
+                             [1, 3],
+                             [2, 2],
+                             [3, 1],
+                             [4, 0]]]).astype(np.int32)
+    arr_update = np.array([[11, 22, 33, 44, 55], [66, 77, 33, 99, 100]]).astype(np.float32)
+    out = scatter_net(arr_input, arr_indices, arr_update)
+    expected = np.array([[11, 1, 2, 3, 66],
+                         [5, 22, 7, 77, 9],
+                         [10, 11, 33, 13, 14],
+                         [15, 99, 17, 44, 19],
+                         [100, 21, 22, 23, 55]]).astype(np.float32)
+    np.testing.assert_allclose(out, expected, rtol=1e-6)
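
-- 
Note (not part of the patch): the expected arrays in test_tensor_scatter_update.py
can be cross-checked against a small NumPy model of the op's semantics. Each row of
`indices` names a coordinate prefix in `x`, and the matching slice of `update` is
written there; all other elements keep their input values. This is a minimal sketch
assuming only NumPy; tensor_scatter_update_ref is a hypothetical helper named here
purely for illustration.

import numpy as np

def tensor_scatter_update_ref(x, indices, update):
    # Return a copy of x with update[i, ...] written at indices[i, :];
    # x itself is left untouched, mirroring the GPU kernel, which scatters
    # into the output buffer after copying the input into it.
    out = x.copy()
    index_depth = indices.shape[-1]
    idx = indices.reshape(-1, index_depth)  # one coordinate tuple per row
    upd = update.reshape((idx.shape[0],) + x.shape[index_depth:])
    for coord, value in zip(idx, upd):
        out[tuple(coord)] = value
    return out

# Matches the first test case: scattering five scalars into a (3, 7) input.
x = np.arange(21).reshape(3, 7).astype(np.float32)
indices = np.array([[0, 1], [1, 1], [0, 5], [0, 2], [2, 1]], dtype=np.int32)
update = np.array([3.2, 1.1, 5.3, -2.2, -1.0], dtype=np.float32)
print(tensor_scatter_update_ref(x, indices, update))

The same decomposition drives the CUDA kernel: for an output of shape (4, 2, 3)
indexed at depth 3, block_size is 1 and indices_stride is [6, 3, 1], so write_index
is the usual row-major flattening of each coordinate tuple.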