commit f679568d86 (parent ca6756b5fe)
@@ -0,0 +1,33 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/kernel_compiler/gpu/arrays/gathernd_gpu_kernel.h"

namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_TWO(
  GatherNd,
  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
  GatherNdGpuFwdKernel, float, int)
MS_REG_GPU_KERNEL_TWO(
  GatherNd,
  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
  GatherNdGpuFwdKernel, half, int)
MS_REG_GPU_KERNEL_TWO(
  GatherNd, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
  GatherNdGpuFwdKernel, int, int)
}  // namespace kernel
}  // namespace mindspore
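Note: each MS_REG_GPU_KERNEL_TWO invocation above binds one (data type, index type) pair of the templated kernel to the GatherNd op. Extending coverage is one more registration plus a matching explicit instantiation of GatherNd<...> in gathernd.cu. A hypothetical sketch for float64, not part of this change:

MS_REG_GPU_KERNEL_TWO(
  GatherNd,
  KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
  GatherNdGpuFwdKernel, double, int)  // would also need GatherNd<double, int> instantiated in the .cu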
@@ -0,0 +1,162 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_GATHERND_GPU_KERNEL_H
#define MINDSPORE_GATHERND_GPU_KERNEL_H

#include <vector>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/gathernd.cuh"

namespace mindspore {
namespace kernel {
template <typename T, typename S>
class GatherNdGpuFwdKernel : public GpuKernel {
 public:
  GatherNdGpuFwdKernel() : dev_batch_strides_(nullptr), dev_batch_indices_(nullptr) {}
  ~GatherNdGpuFwdKernel() {
    if (dev_batch_strides_ != nullptr) {
      device::gpu::GPUMemoryAllocator::GetInstance().FreeTensorMem(static_cast<void *>(dev_batch_strides_));
    }
    if (dev_batch_indices_ != nullptr) {
      device::gpu::GPUMemoryAllocator::GetInstance().FreeTensorMem(static_cast<void *>(dev_batch_indices_));
    }
  }

  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
    VARIABLE_NOT_USED(workspace);
    T *input_addr = GetDeviceAddress<T>(inputs, 0);
    S *indices_addr = GetDeviceAddress<S>(inputs, 1);
    T *output_addr = GetDeviceAddress<T>(outputs, 0);

    // Note the argument order: batch_strides_ holds per-dimension sizes (used as bounds) and
    // batch_indices_ holds flat strides, so they map onto GatherNd's (batch_indices, batch_strides)
    // parameters respectively.
    GatherNd(input_addr, indices_addr, output_addr, dims_[0], dims_[1], dims_[2], dev_batch_strides_,
             dev_batch_indices_, reinterpret_cast<cudaStream_t>(stream_ptr));
    return true;
  }
  bool Init(const CNodePtr &kernel_node) override {
    InitResource();
    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
    if (input_num != 2) {
      MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but GatherNdGpuFwdKernel needs 2 inputs.";
    }
    input_shapes_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
    indices_shapes_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
    output_shapes_ = AnfAlgo::GetOutputInferShape(kernel_node, 0);

    Reshape();

    size_t dim_indices_last = dims_[dims_.size() - 1];
    batch_strides_.resize(dim_indices_last, 0);
    batch_indices_.resize(dim_indices_last, 0);

    if (dim_indices_last > 0) {
      batch_strides_[dim_indices_last - 1] = input_shapes_[dim_indices_last - 1];
      batch_indices_[dim_indices_last - 1] = dims_[1];
      // Keep the loop inside the guard: dim_indices_last - 1 underflows when the index depth is 0.
      for (size_t i = dim_indices_last - 1; i > 0; --i) {
        batch_strides_[i - 1] = input_shapes_[i - 1];
        batch_indices_[i - 1] = batch_indices_[i] * input_shapes_[i];
      }
    }

    size_t strides_len = sizeof(S) * batch_strides_.size();
    void *dev_batch_strides_work = device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(strides_len);
    if (dev_batch_strides_work == nullptr) {
      MS_LOG(EXCEPTION) << "Failed to alloc dev_batch_strides_work, size: " << strides_len;
    }
    dev_batch_strides_ = static_cast<S *>(dev_batch_strides_work);

    size_t indices_len = sizeof(S) * batch_indices_.size();
    void *dev_batch_indices_work = device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(indices_len);
    if (dev_batch_indices_work == nullptr) {
      MS_LOG(EXCEPTION) << "Failed to alloc dev_batch_indices_work, size: " << indices_len;
    }
    dev_batch_indices_ = static_cast<S *>(dev_batch_indices_work);

    CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpy(dev_batch_strides_, &batch_strides_[0], strides_len, cudaMemcpyHostToDevice),
                               "cudaMemcpy failed in GatherNdGpuFwdKernel::Init.");
    CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpy(dev_batch_indices_, &batch_indices_[0], indices_len, cudaMemcpyHostToDevice),
                               "cudaMemcpy failed in GatherNdGpuFwdKernel::Init.");

    InitSizeLists();
    return true;
  }

 protected:
  void InitSizeLists() override {
    size_t size = GetSize(input_shapes_);
    input_size_list_.push_back(size);

    size = GetSize(indices_shapes_);
    input_size_list_.push_back(size);

    size = GetSize(output_shapes_);
    output_size_list_.push_back(size);
  }

 private:
  void Reshape() {
    // dims_[0]: number of index tuples, dims_[1]: elements per gathered slice,
    // dims_[2]: index depth (the last dimension of indices).
    size_t dim_of_indices = 1;
    for (size_t i = 0; i < indices_shapes_.size() - IntToSize(1); i++) {
      dim_of_indices *= indices_shapes_[i];
    }

    size_t dim_after_indices = 1;
    size_t dim_indices_last = indices_shapes_[indices_shapes_.size() - IntToSize(1)];
    for (size_t i = dim_indices_last; i < input_shapes_.size(); i++) {
      dim_after_indices *= input_shapes_[i];
    }
    dims_.emplace_back(dim_of_indices);
    dims_.emplace_back(dim_after_indices);
    dims_.emplace_back(dim_indices_last);
    return;
  }
  size_t GetSize(const std::vector<size_t> &shape) const {
    if (shape.size() == 0) {
      return 0;
    }
    size_t result = sizeof(T);
    for (size_t i = 0; i < shape.size(); i++) {
      result *= shape[i];
    }
    return result;
  }

  std::vector<size_t> input_shapes_;
  std::vector<size_t> indices_shapes_;
  std::vector<size_t> output_shapes_;

  std::vector<size_t> dims_;

  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;

  std::vector<S> batch_strides_;
  std::vector<S> batch_indices_;

  S *dev_batch_strides_;
  S *dev_batch_indices_;
};
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_GATHERND_GPU_KERNEL_H
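For reviewers, a minimal host-side sketch of the index arithmetic that Init() and the CUDA kernel implement together, under an assumed input of shape [2][3] and indices of shape [2][2] (index depth 2, slice size 1). All names and values here are illustrative only:

#include <cassert>
#include <cstddef>
#include <vector>

// Mirrors GatherNdKernel: each index tuple is dotted with the flat strides of
// the input and bounds-checked against the per-dimension sizes.
int main() {
  std::vector<int> input = {0, 1, 2, 3, 4, 5};  // logical shape [2][3], row-major
  std::vector<int> indices = {0, 1, 1, 2};      // two tuples: (0,1) and (1,2)
  std::vector<int> bounds = {2, 3};             // what batch_strides_ holds on the host
  std::vector<int> strides = {3, 1};            // what batch_indices_ holds on the host
  const size_t indices_dim1 = 2, num_tuples = 2;

  std::vector<int> output(num_tuples, 0);
  for (size_t i = 0; i < num_tuples; ++i) {
    bool out_of_bound = false;
    int read_index = 0;
    for (size_t k = 0; k < indices_dim1; ++k) {
      int idx = indices[i * indices_dim1 + k];
      out_of_bound |= !(idx < bounds[k]);
      read_index += idx * strides[k];
    }
    output[i] = out_of_bound ? 0 : input[read_index];  // out-of-range tuples gather 0
  }
  assert(output[0] == 1 && output[1] == 5);
  return 0;
}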
@@ -0,0 +1,33 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/kernel_compiler/gpu/arrays/scatter_nd_gpu_kernel.h"

namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_TWO(
  ScatterNd,
  KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
  ScatterNdGpuFwdKernel, float, int)
MS_REG_GPU_KERNEL_TWO(
  ScatterNd,
  KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
  ScatterNdGpuFwdKernel, half, int)
MS_REG_GPU_KERNEL_TWO(
  ScatterNd, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
  ScatterNdGpuFwdKernel, int, int)
}  // namespace kernel
}  // namespace mindspore
@@ -0,0 +1,175 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_KERNEL_GPU_SCATTER_ND_GPU_KERNEL_H
#define MINDSPORE_CCSRC_KERNEL_GPU_SCATTER_ND_GPU_KERNEL_H

#include <vector>
#include "backend/kernel_compiler/gpu/cuda_impl/scatter_nd.cuh"
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"

namespace mindspore {
namespace kernel {
template <typename T, typename S>
class ScatterNdGpuFwdKernel : public GpuKernel {
 public:
  ScatterNdGpuFwdKernel()
      : input_size_(1),
        indices_size_(1),
        output_size_(1),
        block_size_(1),
        indices_stride_(nullptr),
        work_shape_(nullptr),
        indices_dim_0_(0),
        indices_dim_1_(0) {}
  ~ScatterNdGpuFwdKernel() {
    if (indices_stride_ != nullptr) {
      device::gpu::GPUMemoryAllocator::GetInstance().FreeTensorMem(static_cast<void *>(indices_stride_));
    }
    if (work_shape_ != nullptr) {
      device::gpu::GPUMemoryAllocator::GetInstance().FreeTensorMem(static_cast<void *>(work_shape_));
    }
  }

  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
    VARIABLE_NOT_USED(workspace);
    S *indices = GetDeviceAddress<S>(inputs, 0);
    T *update = GetDeviceAddress<T>(inputs, 1);
    T *output = GetDeviceAddress<T>(outputs, 0);

    ScatterNd(indices, update, output, block_size_, input_size_, output_size_, indices_dim_0_, indices_dim_1_,
              indices_stride_, work_shape_, reinterpret_cast<cudaStream_t>(stream_ptr));
    return true;
  }

  bool Init(const CNodePtr &kernel_node) override {
    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
    if (input_num != 2) {
      MS_LOG(ERROR) << "Input number is " << input_num << ", but ScatterNdGpuFwdKernel needs 2 inputs.";
      return false;
    }
    size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
    if (output_num != 1) {
      MS_LOG(ERROR) << "Output number is " << output_num << ", but ScatterNdGpuFwdKernel needs 1 output.";
      return false;
    }

    input_shapes_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
    indices_shapes_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
    output_shapes_ = AnfAlgo::GetOutputInferShape(kernel_node, 0);

    vec_work_shape_ = GetAttr<std::vector<S>>(kernel_node, "shape");

    GetSize();

    size_t indices_len = sizeof(S) * vec_indices_stride_.size();
    void *indices_stride_work = device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(indices_len);
    if (indices_stride_work == nullptr) {
      MS_LOG(EXCEPTION) << "Failed to alloc indices_stride_work, size: " << indices_len;
    }
    indices_stride_ = static_cast<S *>(indices_stride_work);

    size_t vec_work_len = sizeof(S) * vec_work_shape_.size();
    void *work_shape_work = device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(vec_work_len);
    if (work_shape_work == nullptr) {
      MS_LOG(EXCEPTION) << "Failed to alloc work_shape_work, size: " << vec_work_len;
    }
    work_shape_ = static_cast<S *>(work_shape_work);

    CHECK_CUDA_RET_WITH_EXCEPT(
      cudaMemcpy(indices_stride_, &vec_indices_stride_[0], indices_len, cudaMemcpyHostToDevice),
      "cudaMemcpy failed in ScatterNdGpuFwdKernel::Init.");
    CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpy(work_shape_, &vec_work_shape_[0], vec_work_len, cudaMemcpyHostToDevice),
                               "cudaMemcpy failed in ScatterNdGpuFwdKernel::Init.");
    InitSizeLists();

    return true;
  }

 protected:
  void InitSizeLists() override {
    input_size_list_.push_back(indices_size_);
    input_size_list_.push_back(input_size_);
    output_size_list_.push_back(output_size_);
    return;
  }

  void GetSize() {
    indices_size_ = sizeof(S);
    for (size_t i = 0; i < indices_shapes_.size(); i++) {
      indices_size_ *= indices_shapes_[i];
    }
    input_size_ = sizeof(T);
    for (size_t i = 0; i < input_shapes_.size(); i++) {
      input_size_ *= input_shapes_[i];
    }
    output_size_ = sizeof(T);
    for (size_t i = 0; i < output_shapes_.size(); i++) {
      output_size_ *= output_shapes_[i];
    }

    // calculate indices dim 0/1
    indices_dim_0_ = indices_shapes_[0];
    indices_dim_1_ = indices_shapes_[1];

    // calculate block_size
    for (size_t i = indices_dim_1_; i < output_shapes_.size(); i++) {
      block_size_ *= output_shapes_[i];
    }

    // calculate indices_stride
    for (size_t i = 0; i < indices_dim_1_; i++) {
      vec_indices_stride_.push_back(0);
    }

    vec_indices_stride_[indices_dim_1_ - 1] = block_size_;

    for (size_t i = indices_dim_1_ - 1; i > 0; --i) {
      vec_indices_stride_[i - 1] = vec_indices_stride_[i] * output_shapes_[i];
    }
  }

 private:
  std::vector<size_t> input_shapes_;
  std::vector<size_t> indices_shapes_;
  std::vector<size_t> output_shapes_;
  std::vector<S> vec_indices_stride_;
  std::vector<S> vec_work_shape_;

  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;

  size_t input_size_;
  size_t indices_size_;
  size_t output_size_;
  size_t block_size_;

  S *indices_stride_;
  S *work_shape_;
  size_t indices_dim_0_;
  size_t indices_dim_1_;
};
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_KERNEL_GPU_SCATTER_ND_GPU_KERNEL_H
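A host-side sketch of the scatter addressing, under an assumed attr "shape" of [4, 3] and indices of shape [2, 1] (so block_size_ = 3 and a single-level stride of {3}); names and values are illustrative:

#include <cassert>
#include <vector>

// Mirrors ScatterNdKernel's addressing: update row i lands at offset
// sum_k indices[i][k] * stride[k], plus the within-block offset j.
int main() {
  std::vector<int> indices = {3, 1};             // write rows 3 and 1
  std::vector<int> update = {1, 2, 3, 4, 5, 6};  // shape [2][3]
  std::vector<int> output(4 * 3, 0);             // attr "shape" = [4, 3]
  std::vector<int> stride = {3};                 // vec_indices_stride_
  const int block = 3, dim1 = 1;

  for (int read = 0; read < static_cast<int>(update.size()); ++read) {
    int i = read / block, j = read % block, write = 0;
    for (int k = 0; k < dim1; ++k) write += indices[i * dim1 + k] * stride[k];
    output[write + j] = update[read];
  }
  assert(output[3 * 3 + 0] == 1 && output[1 * 3 + 2] == 6);
  return 0;
}

If two index tuples collide, the GPU kernel's last writer wins and the winner is not deterministic across launches; callers needing accumulation semantics want a different op.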
@@ -0,0 +1,81 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/kernel_compiler/gpu/cuda_impl/boundingbox_decode_impl.cuh"

template <typename T>
__global__ void BoundingBoxDecodeKernel(const size_t size, const T *rois, const T *deltas, T *bboxes, const float m1,
                                        const float m2, const float m3, const float m4, const float s1, const float s2,
                                        const float s3, const float s4, const int max_height, const int max_width,
                                        const float ratio_clip) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
    const size_t left_x = i * 4;
    const size_t left_y = i * 4 + 1;
    const size_t right_x = i * 4 + 2;
    const size_t right_y = i * 4 + 3;

    T dx = deltas[left_x] * s1 + m1;
    T dy = deltas[left_y] * s2 + m2;
    T dw = deltas[right_x] * s3 + m3;
    T dh = deltas[right_y] * s4 + m4;

    T max_ratio = abs(log(ratio_clip));

    dw = dw > max_ratio ? max_ratio : (dw < (-max_ratio) ? (-max_ratio) : dw);
    dh = dh > max_ratio ? max_ratio : (dh < (-max_ratio) ? (-max_ratio) : dh);

    T px = (rois[left_x] + rois[right_x]) * 0.5f;
    T py = (rois[left_y] + rois[right_y]) * 0.5f;
    T pw = rois[right_x] - rois[left_x] + 1.0f;
    T ph = rois[right_y] - rois[left_y] + 1.0f;

    T gx = px + pw * dx;
    T gy = py + ph * dy;
    T gw = pw * exp(dw);
    T gh = ph * exp(dh);

    T x1 = gx - gw * 0.5f + 0.5f;
    T y1 = gy - gh * 0.5f + 0.5f;
    T x2 = gx + gw * 0.5f - 0.5f;
    T y2 = gy + gh * 0.5f - 0.5f;

    x1 = x1 > max_width ? max_width : (x1 < 0 ? 0 : x1);
    y1 = y1 > max_height ? max_height : (y1 < 0 ? 0 : y1);
    x2 = x2 > max_width ? max_width : (x2 < 0 ? 0 : x2);
    y2 = y2 > max_height ? max_height : (y2 < 0 ? 0 : y2);

    bboxes[left_x] = x1;
    bboxes[left_y] = y1;
    bboxes[right_x] = x2;
    bboxes[right_y] = y2;
  }
}

template <typename T>
void BoundingBoxDecode(const size_t size, const T *rois, const T *deltas, T *bboxes, const float &m1, const float &m2,
                       const float &m3, const float &m4, const float &s1, const float &s2, const float &s3,
                       const float &s4, const int &max_height, const int &max_width, const float &ratio_clip,
                       cudaStream_t cuda_stream) {
  BoundingBoxDecodeKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, rois, deltas, bboxes, m1, m2, m3, m4,
                                                                             s1, s2, s3, s4, max_height, max_width,
                                                                             ratio_clip);
}

template void BoundingBoxDecode<float>(const size_t size, const float *rois, const float *deltas, float *bboxes,
                                       const float &m1, const float &m2, const float &m3, const float &m4,
                                       const float &s1, const float &s2, const float &s3, const float &s4,
                                       const int &max_height, const int &max_width, const float &ratio_clip,
                                       cudaStream_t cuda_stream);
@@ -0,0 +1,27 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_BOUNDINGBOX_DECODE_IMPL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_BOUNDINGBOX_DECODE_IMPL_H_

#include "runtime/device/gpu/cuda_common.h"
template <typename T>
void BoundingBoxDecode(const size_t size, const T *rois, const T *deltas, T *bboxes, const float &m1, const float &m2,
                       const float &m3, const float &m4, const float &s1, const float &s2, const float &s3,
                       const float &s4, const int &max_height, const int &max_width, const float &ratio_clip,
                       cudaStream_t cuda_stream);

#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_BOUNDINGBOX_DECODE_IMPL_H_
@@ -0,0 +1,62 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/kernel_compiler/gpu/cuda_impl/boundingbox_encode_impl.cuh"

template <typename T>
__global__ void BoundingBoxEncodeKernel(const size_t size, const T *anchor_box, const T *groundtruth_box, T *deltas,
                                        const float m1, const float m2, const float m3, const float m4, const float s1,
                                        const float s2, const float s3, const float s4) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
    const size_t left_x = i * 4;
    const size_t left_y = i * 4 + 1;
    const size_t right_x = i * 4 + 2;
    const size_t right_y = i * 4 + 3;

    T px = (anchor_box[left_x] + anchor_box[right_x]) * 0.5f;
    T py = (anchor_box[left_y] + anchor_box[right_y]) * 0.5f;
    T pw = anchor_box[right_x] - anchor_box[left_x] + 1.0f;
    T ph = anchor_box[right_y] - anchor_box[left_y] + 1.0f;

    T gx = (groundtruth_box[left_x] + groundtruth_box[right_x]) * 0.5f;
    T gy = (groundtruth_box[left_y] + groundtruth_box[right_y]) * 0.5f;
    T gw = groundtruth_box[right_x] - groundtruth_box[left_x] + 1.0f;
    T gh = groundtruth_box[right_y] - groundtruth_box[left_y] + 1.0f;

    T dx = (gx - px) / pw;
    T dy = (gy - py) / ph;
    T dw = log(gw / pw);
    T dh = log(gh / ph);

    deltas[left_x] = (dx - m1) / s1;
    deltas[left_y] = (dy - m2) / s2;
    deltas[right_x] = (dw - m3) / s3;
    deltas[right_y] = (dh - m4) / s4;
  }
}

template <typename T>
void BoundingBoxEncode(const size_t size, const T *anchor_box, const T *groundtruth_box, T *deltas, const float &m1,
                       const float &m2, const float &m3, const float &m4, const float &s1, const float &s2,
                       const float &s3, const float &s4, cudaStream_t cuda_stream) {
  BoundingBoxEncodeKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, anchor_box, groundtruth_box, deltas,
                                                                             m1, m2, m3, m4, s1, s2, s3, s4);
}

template void BoundingBoxEncode<float>(const size_t size, const float *anchor_box, const float *groundtruth_box,
                                       float *deltas, const float &m1, const float &m2, const float &m3,
                                       const float &m4, const float &s1, const float &s2, const float &s3,
                                       const float &s4, cudaStream_t cuda_stream);
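Encode and decode are inverses up to the decode-side clamps, which a small scalar round trip makes concrete. A host-side sketch with zero means, unit stds, and boxes far from the clamp limits (assumed values, not taken from this change):

#include <cassert>
#include <cmath>

int main() {
  // Corner-format boxes; the +1.0f matches the kernels' width/height convention.
  float ax1 = 10, ay1 = 10, ax2 = 29, ay2 = 19;  // anchor
  float gx1 = 12, gy1 = 8, gx2 = 35, gy2 = 23;   // ground truth

  float pw = ax2 - ax1 + 1.0f, ph = ay2 - ay1 + 1.0f;
  float px = (ax1 + ax2) * 0.5f, py = (ay1 + ay2) * 0.5f;
  float gw = gx2 - gx1 + 1.0f, gh = gy2 - gy1 + 1.0f;
  float cx = (gx1 + gx2) * 0.5f, cy = (gy1 + gy2) * 0.5f;

  // Encode (BoundingBoxEncodeKernel body with means = 0, stds = 1).
  float dx = (cx - px) / pw, dy = (cy - py) / ph;
  float dw = std::log(gw / pw), dh = std::log(gh / ph);

  // Decode (BoundingBoxDecodeKernel body; no ratio/border clamping triggered here).
  float rx = px + pw * dx, ry = py + ph * dy;
  float rw = pw * std::exp(dw), rh = ph * std::exp(dh);
  float x1 = rx - rw * 0.5f + 0.5f, y1 = ry - rh * 0.5f + 0.5f;

  assert(std::fabs(x1 - gx1) < 1e-4f && std::fabs(y1 - gy1) < 1e-4f);
  return 0;
}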
@@ -0,0 +1,26 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_BOUNDINGBOX_ENCODE_IMPL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_BOUNDINGBOX_ENCODE_IMPL_H_

#include "runtime/device/gpu/cuda_common.h"
template <typename T>
void BoundingBoxEncode(const size_t size, const T *anchor_box, const T *groundtruth_box, T *deltas, const float &m1,
                       const float &m2, const float &m3, const float &m4, const float &s1, const float &s2,
                       const float &s3, const float &s4, cudaStream_t cuda_stream);

#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_BOUNDINGBOX_ENCODE_IMPL_H_
@@ -0,0 +1,65 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/kernel_compiler/gpu/cuda_impl/gathernd.cuh"
#include "runtime/device/gpu/cuda_common.h"
template <typename T, typename S>
__global__ void GatherNdKernel(T *input, S *indices, T *output, const size_t output_dim0, const size_t output_dim1,
                               const size_t indices_dim1, S *batch_indices, S *batch_strides) {
  int num = output_dim0 * output_dim1;
  int i, j;
  for (int write_index = blockIdx.x * blockDim.x + threadIdx.x; write_index < num;
       write_index += blockDim.x * gridDim.x) {
    i = write_index / output_dim1 % output_dim0;
    j = write_index % output_dim1;

    bool out_of_bound = false;
    int read_index = 0;
    int indices_i = 0;
    for (size_t k = 0; k < indices_dim1; k++) {
      size_t ind = indices_dim1 * i + k;
      indices_i = indices[ind];
      out_of_bound |= !(indices_i < batch_indices[k]);
      read_index += indices_i * batch_strides[k];
    }
    read_index += j;

    if (!out_of_bound) {
      output[write_index] = input[read_index];
    } else {
      output[write_index] = 0;
    }
  }
  return;
}
template <typename T, typename S>
void GatherNd(T *input, S *indices, T *output, const size_t &output_dim0, const size_t &output_dim1,
              const size_t &indices_dim1, S *batch_indices, S *batch_strides, cudaStream_t stream) {
  int size = output_dim0 * output_dim1;
  GatherNdKernel<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(input, indices, output, output_dim0, output_dim1,
                                                               indices_dim1, batch_indices, batch_strides);
  return;
}

template void GatherNd<float, int>(float *input, int *indices, float *output, const size_t &output_dim0,
                                   const size_t &output_dim1, const size_t &indices_dim1, int *batch_indices,
                                   int *batch_strides, cudaStream_t stream);
template void GatherNd<half, int>(half *input, int *indices, half *output, const size_t &output_dim0,
                                  const size_t &output_dim1, const size_t &indices_dim1, int *batch_indices,
                                  int *batch_strides, cudaStream_t stream);
template void GatherNd<int, int>(int *input, int *indices, int *output, const size_t &output_dim0,
                                 const size_t &output_dim1, const size_t &indices_dim1, int *batch_indices,
                                 int *batch_strides, cudaStream_t stream);
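All of the kernels in this change use the same grid-stride loop, which decouples the launch geometry from the element count: any grid covers all n elements, with each thread handling every (blockDim.x * gridDim.x)-th index. A minimal standalone CUDA sketch of the pattern (GET_BLOCKS and GET_THREADS come from cuda_common.h; the fill kernel here is purely illustrative):

#include <cuda_runtime.h>

__global__ void FillKernel(float *out, float value, size_t n) {
  // Grid-stride loop: correct for any n, even when n > blocks * threads.
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x) {
    out[i] = value;
  }
}

void Fill(float *out, float value, size_t n, cudaStream_t stream) {
  const int threads = 256;
  const int blocks = static_cast<int>((n + threads - 1) / threads);
  FillKernel<<<blocks, threads, 0, stream>>>(out, value, n);
}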
@@ -0,0 +1,26 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_GATHERND_GPU_CU_H
#define MINDSPORE_GATHERND_GPU_CU_H

#include "runtime/device/gpu/cuda_common.h"

template <typename T, typename S>
void GatherNd(T *input, S *indices, T *output, const size_t &output_dim0, const size_t &output_dim1,
              const size_t &indices_dim1, S *batch_indices, S *batch_strides, cudaStream_t stream);

#endif  // MINDSPORE_GATHERND_GPU_CU_H
@@ -0,0 +1,68 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/kernel_compiler/gpu/cuda_impl/scatter_nd.cuh"
#include "runtime/device/gpu/cuda_common.h"
template <typename T, typename S>
__global__ void ScatterNdKernel(S *indices, T *update, T *output, const size_t block_size, const size_t input_size,
                                const size_t output_size, const size_t indices_dim_0, const size_t indices_dim_1,
                                S *indices_stride, S *work_shape) {
  int i, j;
  for (int read_index = blockIdx.x * blockDim.x + threadIdx.x; read_index < input_size;
       read_index += blockDim.x * gridDim.x) {
    int write_index = 0;
    bool out_bound = false;

    i = read_index / block_size;
    j = read_index % block_size;

    for (size_t k = 0; k < indices_dim_1; k++) {
      S indices_i = indices[i * indices_dim_1 + k];
      out_bound |= indices_i >= work_shape[k];
      write_index += indices_i * indices_stride[k];
    }

    write_index += j;
    out_bound |= write_index >= output_size;

    if (!out_bound) {
      output[write_index] = update[read_index];
    }
  }
}

template <typename T, typename S>
void ScatterNd(S *indices, T *update, T *output, const size_t &block_size, const size_t &input_size,
               const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1, S *indices_stride,
               S *work_shape, cudaStream_t stream) {
  ScatterNdKernel<<<GET_BLOCKS(input_size), GET_THREADS, 0, stream>>>(indices, update, output, block_size, input_size,
                                                                      output_size, indices_dim_0, indices_dim_1,
                                                                      indices_stride, work_shape);
  return;
}

template void ScatterNd<float, int>(int *indices, float *update, float *output, const size_t &block_size,
                                    const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0,
                                    const size_t &indices_dim_1, int *indices_stride, int *work_shape,
                                    cudaStream_t stream);
template void ScatterNd<half, int>(int *indices, half *update, half *output, const size_t &block_size,
                                   const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0,
                                   const size_t &indices_dim_1, int *indices_stride, int *work_shape,
                                   cudaStream_t stream);
template void ScatterNd<int, int>(int *indices, int *update, int *output, const size_t &block_size,
                                  const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0,
                                  const size_t &indices_dim_1, int *indices_stride, int *work_shape,
                                  cudaStream_t stream);
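Two review observations on ScatterNdKernel. First, it writes only the offsets named by indices and silently drops out-of-range tuples, so untouched output elements keep whatever the buffer held; if the framework does not zero outputs elsewhere, the launcher would need a pre-launch clear along these lines (a hypothetical addition, not in this change):

// Hypothetical guard before the kernel launch in ScatterNd(...), using the
// byte size the host side computes in GetSize():
cudaMemsetAsync(output, 0, output_size, stream);

Second, output_size_ as computed in GetSize() is a byte count (it starts from sizeof(T)), while write_index is an element offset, so the write_index >= output_size backstop is looser than intended; the per-dimension work_shape check is what actually keeps in-range tuples within the buffer.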
@@ -0,0 +1,26 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_SCATTER_ND_GPU_CU_H
#define MINDSPORE_SCATTER_ND_GPU_CU_H

#include "runtime/device/gpu/cuda_common.h"

template <typename T, typename S>
void ScatterNd(S *indices, T *update, T *output, const size_t &block_size, const size_t &input_size,
               const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1, S *indices_stride,
               S *work_shape, cudaStream_t stream);
#endif  // MINDSPORE_SCATTER_ND_GPU_CU_H
@@ -0,0 +1,57 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <iostream>
#include "backend/kernel_compiler/gpu/cuda_impl/sgd_impl.cuh"

template <typename T>
__global__ void SGDKernel(const int size, const T dampening, const T weight_decay, const bool nesterov, const T *grad,
                          const T *momentum, const T *lr, T *param, T *accum, T *stat) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (size); i += blockDim.x * gridDim.x) {
    T grad_new = grad[i];
    if (weight_decay != static_cast<T>(0)) {
      grad_new += param[i] * weight_decay;
    }

    if (momentum[0] != static_cast<T>(0)) {
      if (stat[i] == static_cast<T>(0)) {
        accum[i] = grad_new;
        stat[i] = 0;
      } else {
        accum[i] = accum[i] * momentum[0] + (1.0 - dampening) * grad_new;
      }

      if (nesterov) {
        grad_new += accum[i] * momentum[0];
      } else {
        grad_new = accum[i];
      }
    }

    param[i] -= lr[0] * grad_new;
  }
}

template <typename T>
void SGD(const int size, const T dampening, const T weight_decay, const bool nesterov, const T *lr, const T *momentum,
         const T *grad, T *param, T *accum, T *stat, cudaStream_t cuda_stream) {
  SGDKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, dampening, weight_decay, nesterov, grad, momentum,
                                                               lr, param, accum, stat);
}

template void SGD(const int size, const float dampening, const float weight_decay, const bool nesterov,
                  const float *lr, const float *momentum, const float *grad, float *param, float *accum, float *stat,
                  cudaStream_t cuda_stream);
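The update rule above, written out for a single element on the host as a reference sketch (not part of the change; the concrete lr/momentum values are assumed for the check):

#include <cassert>
#include <cmath>

// One SGD-with-momentum step, mirroring SGDKernel's branches.
float SgdStep(float param, float grad, float *accum, float *stat, float lr, float momentum,
              float dampening, float weight_decay, bool nesterov) {
  if (weight_decay != 0.0f) grad += param * weight_decay;
  if (momentum != 0.0f) {
    if (*stat == 0.0f) {
      *accum = grad;  // stat == 0: take the (weight-decayed) gradient as the new accumulator
    } else {
      *accum = *accum * momentum + (1.0f - dampening) * grad;
    }
    grad = nesterov ? grad + *accum * momentum : *accum;
  }
  return param - lr * grad;
}

int main() {
  float accum = 0.0f, stat = 1.0f;
  float p = SgdStep(1.0f, 0.5f, &accum, &stat, 0.1f, 0.9f, 0.0f, 0.0f, false);
  // stat != 0 branch: accum = 0 * 0.9 + 1.0 * 0.5 = 0.5; param = 1.0 - 0.1 * 0.5 = 0.95
  assert(std::fabs(p - 0.95f) < 1e-6f);
  return 0;
}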
@@ -0,0 +1,25 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SGD_IMPL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SGD_IMPL_H_

#include "runtime/device/gpu/cuda_common.h"

template <typename T>
void SGD(const int size, const T dampening, const T weight_decay, const bool nesterov, const T *lr, const T *momentum,
         const T *grad, T *param, T *accum, T *stat, cudaStream_t cuda_stream);
#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SGD_IMPL_H_
@@ -0,0 +1,32 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/kernel_compiler/gpu/nn/sgd_gpu_kernel.h"

namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(SGD,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddOutputAttr(kNumberTypeFloat32),
                      SGDGpuKernel, float)
}  // namespace kernel
}  // namespace mindspore
@@ -0,0 +1,88 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_SGD_KERNEL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_NN_SGD_KERNEL_H_

#include <vector>
#include "backend/kernel_compiler/gpu/cuda_impl/sgd_impl.cuh"
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"

namespace mindspore {
namespace kernel {
template <typename T>
class SGDGpuKernel : public GpuKernel {
 public:
  SGDGpuKernel() : size_(1), dampening_(0.0), weight_decay_(0.0), nesterov_(false) {}
  ~SGDGpuKernel() override = default;

  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
              const std::vector<AddressPtr> &outputs, void *stream) override {
    T *param = GetDeviceAddress<T>(inputs, 0);
    T *grad = GetDeviceAddress<T>(inputs, 1);
    T *lr = GetDeviceAddress<T>(inputs, 2);
    T *accum = GetDeviceAddress<T>(inputs, 3);
    T *momentum = GetDeviceAddress<T>(inputs, 4);
    T *stat = GetDeviceAddress<T>(inputs, 5);

    SGD(size_, dampening_, weight_decay_, nesterov_, lr, momentum, grad, param, accum, stat,
        reinterpret_cast<cudaStream_t>(stream));
    return true;
  }
  bool Init(const CNodePtr &kernel_node) override {
    dampening_ = GetAttr<float>(kernel_node, "dampening");
    weight_decay_ = GetAttr<float>(kernel_node, "weight_decay");
    nesterov_ = GetAttr<bool>(kernel_node, "nesterov");

    auto input_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
    for (auto &dim : input_shape) {
      size_ *= dim;
    }
    InitSizeLists();
    return true;
  }

 protected:
  void InitSizeLists() override {
    size_t input_size = size_ * sizeof(T);
    input_size_list_.push_back(input_size);  // parameter
    input_size_list_.push_back(input_size);  // gradient
    input_size_list_.push_back(sizeof(T));   // lr
    input_size_list_.push_back(input_size);  // accum
    input_size_list_.push_back(sizeof(T));   // momentum
    input_size_list_.push_back(input_size);  // stat
    output_size_list_.push_back(input_size);
  }

 private:
  size_t size_;
  float dampening_;
  float weight_decay_;
  bool nesterov_;

  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;
};
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_KERNEL_GPU_NN_SGD_KERNEL_H_
@@ -0,0 +1,26 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/kernel_compiler/gpu/other/boundingbox_decode_gpu_kernel.h"

namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(
  BoundingBoxDecode,
  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
  BoundingBoxDecodeGpuKernel, float)
}  // namespace kernel
}  // namespace mindspore
@@ -0,0 +1,152 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_KERNEL_GPU_OTHER_BOUNDINGBOX_DECODE_GPU_KERNEL_H
#define MINDSPORE_CCSRC_KERNEL_GPU_OTHER_BOUNDINGBOX_DECODE_GPU_KERNEL_H

#include <vector>
#include "backend/kernel_compiler/gpu/cuda_impl/boundingbox_decode_impl.cuh"
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"

namespace mindspore {
namespace kernel {
template <typename T>
class BoundingBoxDecodeGpuKernel : public GpuKernel {
 public:
  BoundingBoxDecodeGpuKernel() : rois_size_(0), deltas_size_(0), bboxes_size_(0), wh_ratio_clip_(0.016) {}

  ~BoundingBoxDecodeGpuKernel() override = default;

  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
    T *rois_addr = GetDeviceAddress<T>(inputs, 0);
    T *deltas_addr = GetDeviceAddress<T>(inputs, 1);
    T *bboxes_addr = GetDeviceAddress<T>(outputs, 0);

    if (inputs[0]->size != inputs[1]->size) {
      MS_LOG(ERROR) << "Rois box size must be equal to deltas box size " << inputs[1]->size << ", but got "
                    << inputs[0]->size << ".";
      return false;
    }

    const size_t coordinate = 4;
    const size_t block_size = inputs[0]->size / sizeof(T);
    if ((block_size % coordinate) != 0) {
      MS_LOG(ERROR) << "The size of the box must be a multiple of 4.";
      return false;
    }

    BoundingBoxDecode(block_size / coordinate, rois_addr, deltas_addr, bboxes_addr, means_[0], means_[1], means_[2],
                      means_[3], stds_[0], stds_[1], stds_[2], stds_[3], max_shape_[0], max_shape_[1], wh_ratio_clip_,
                      reinterpret_cast<cudaStream_t>(stream_ptr));
    return true;
  }

  bool Init(const CNodePtr &kernel_node) override {
    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
    if (input_num != 2) {
      MS_LOG(ERROR) << "Input number is " << input_num << ", but BoundingBoxDecode needs 2 inputs.";
|  |       return false; | ||||||
|  |     } | ||||||
|  |     rois_size_ = sizeof(T); | ||||||
|  |     deltas_size_ = sizeof(T); | ||||||
|  |     bboxes_size_ = sizeof(T); | ||||||
|  | 
 | ||||||
|  |     auto rois_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | ||||||
|  |     for (size_t i = 0; i < rois_shape.size(); i++) { | ||||||
|  |       rois_size_ *= rois_shape[i]; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     auto deltas_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); | ||||||
|  |     for (size_t i = 0; i < deltas_shape.size(); i++) { | ||||||
|  |       deltas_size_ *= deltas_shape[i]; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); | ||||||
|  |     for (size_t i = 0; i < output_shape.size(); i++) { | ||||||
|  |       bboxes_size_ *= output_shape[i]; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     InitSizeLists(); | ||||||
|  | 
 | ||||||
|  |     const size_t coordinate_size = 4; | ||||||
|  |     if (AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("means")->isa<ValueTuple>() || | ||||||
|  |         AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("means")->isa<ValueList>()) { | ||||||
|  |       means_ = GetAttr<std::vector<float>>(kernel_node, "means"); | ||||||
|  |     } else if (AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("means")->isa<FloatImm>()) { | ||||||
|  |       float mean = GetAttr<float>(kernel_node, "means"); | ||||||
|  |       for (size_t i = 0; i < coordinate_size; i++) { | ||||||
|  |         means_.emplace_back(mean); | ||||||
|  |       } | ||||||
|  |     } else { | ||||||
|  |       MS_LOG(EXCEPTION) << "Attribute means type is invalid."; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     if (AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("stds")->isa<ValueTuple>() || | ||||||
|  |         AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("stds")->isa<ValueList>()) { | ||||||
|  |       stds_ = GetAttr<std::vector<float>>(kernel_node, "stds"); | ||||||
|  |     } else if (AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("stds")->isa<FloatImm>()) { | ||||||
|  |       float std = GetAttr<float>(kernel_node, "stds"); | ||||||
|  |       for (size_t i = 0; i < coordinate_size; i++) { | ||||||
|  |         stds_.emplace_back(std); | ||||||
|  |       } | ||||||
|  |     } else { | ||||||
|  |       MS_LOG(EXCEPTION) << "Attribute stds type is invalid."; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     max_shape_ = GetAttr<std::vector<int>>(kernel_node, "max_shape"); | ||||||
|  |     wh_ratio_clip_ = GetAttr<float>(kernel_node, "wh_ratio_clip"); | ||||||
|  | 
 | ||||||
|  |     if (means_.size() < coordinate_size || stds_.size() < coordinate_size) { | ||||||
|  |       MS_LOG(EXCEPTION) << "The size of means or stds is less than 4."; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     if (max_shape_.size() < 2) { | ||||||
|  |       MS_LOG(EXCEPTION) << "The size of max_shape is less than 2."; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     return true; | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |  protected: | ||||||
|  |   void InitSizeLists() override { | ||||||
|  |     input_size_list_.push_back(rois_size_); | ||||||
|  |     input_size_list_.push_back(deltas_size_); | ||||||
|  |     output_size_list_.push_back(bboxes_size_); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |  private: | ||||||
|  |   size_t rois_size_; | ||||||
|  |   size_t deltas_size_; | ||||||
|  |   size_t bboxes_size_; | ||||||
|  |   std::vector<float> means_; | ||||||
|  |   std::vector<float> stds_; | ||||||
|  |   std::vector<int> max_shape_; | ||||||
|  |   float wh_ratio_clip_; | ||||||
|  | 
 | ||||||
|  |   std::vector<size_t> input_size_list_; | ||||||
|  |   std::vector<size_t> output_size_list_; | ||||||
|  |   std::vector<size_t> workspace_size_list_; | ||||||
|  | }; | ||||||
|  | }  // namespace kernel
 | ||||||
|  | }  // namespace mindspore
 | ||||||
|  | 
 | ||||||
|  | #endif  // MINDSPORE_CCSRC_KERNEL_GPU_OTHER_BOUNDINGBOX_DECODE_GPU_KERNEL_H
 | ||||||
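The decode arithmetic itself lives in boundingbox_decode_impl.cuh, which this diff does not show. Assuming it follows the standard delta2bbox transform (denormalize deltas with means/stds, clamp the log-size deltas by wh_ratio_clip, shift and scale the ROI boxes, clip to max_shape), a per-box NumPy sketch would look roughly like this; the +1 and ±0.5 pixel conventions are assumptions borrowed from the mmdetection-style transform:

import numpy as np

def decode_reference(rois, deltas, means, stds, max_shape, wh_ratio_clip):
    # Assumed per-box decode: denormalize deltas, shift/scale ROIs, clip to the image.
    d = deltas * np.asarray(stds) + np.asarray(means)       # undo (delta - mean) / std
    dx, dy, dw, dh = d[:, 0], d[:, 1], d[:, 2], d[:, 3]
    max_ratio = abs(np.log(wh_ratio_clip))                  # bound on log width/height change
    dw = np.clip(dw, -max_ratio, max_ratio)
    dh = np.clip(dh, -max_ratio, max_ratio)
    px = (rois[:, 0] + rois[:, 2]) * 0.5                    # ROI centers ...
    py = (rois[:, 1] + rois[:, 3]) * 0.5
    pw = rois[:, 2] - rois[:, 0] + 1.0                      # ... and sizes (+1 convention assumed)
    ph = rois[:, 3] - rois[:, 1] + 1.0
    gx = px + pw * dx                                       # decoded center and size
    gy = py + ph * dy
    gw = pw * np.exp(dw)
    gh = ph * np.exp(dh)
    x1 = np.clip(gx - gw * 0.5 + 0.5, 0, max_shape[1] - 1)  # max_shape = (H, W)
    y1 = np.clip(gy - gh * 0.5 + 0.5, 0, max_shape[0] - 1)
    x2 = np.clip(gx + gw * 0.5 - 0.5, 0, max_shape[1] - 1)
    y2 = np.clip(gy + gh * 0.5 - 0.5, 0, max_shape[0] - 1)
    return np.stack([x1, y1, x2, y2], axis=-1)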
| @ -0,0 +1,26 @@ | |||||||
|  | /**
 | ||||||
|  |  * Copyright 2020 Huawei Technologies Co., Ltd | ||||||
|  |  * | ||||||
|  |  * Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  |  * you may not use this file except in compliance with the License. | ||||||
|  |  * You may obtain a copy of the License at | ||||||
|  |  * | ||||||
|  |  * http://www.apache.org/licenses/LICENSE-2.0
 | ||||||
|  |  * | ||||||
|  |  * Unless required by applicable law or agreed to in writing, software | ||||||
|  |  * distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  |  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  |  * See the License for the specific language governing permissions and | ||||||
|  |  * limitations under the License. | ||||||
|  |  */ | ||||||
|  | 
 | ||||||
|  | #include "backend/kernel_compiler/gpu/other/boundingbox_encode_gpu_kernel.h" | ||||||
|  | 
 | ||||||
|  | namespace mindspore { | ||||||
|  | namespace kernel { | ||||||
|  | MS_REG_GPU_KERNEL_ONE( | ||||||
|  |   BoundingBoxEncode, | ||||||
|  |   KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | ||||||
|  |   BoundingBoxEncodeGpuKernel, float) | ||||||
|  | }  // namespace kernel
 | ||||||
|  | }  // namespace mindspore
 | ||||||
| @ -0,0 +1,143 @@ | |||||||
|  | /**
 | ||||||
|  |  * Copyright 2020 Huawei Technologies Co., Ltd | ||||||
|  |  * | ||||||
|  |  * Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  |  * you may not use this file except in compliance with the License. | ||||||
|  |  * You may obtain a copy of the License at | ||||||
|  |  * | ||||||
|  |  * http://www.apache.org/licenses/LICENSE-2.0
 | ||||||
|  |  * | ||||||
|  |  * Unless required by applicable law or agreed to in writing, software | ||||||
|  |  * distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  |  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  |  * See the License for the specific language governing permissions and | ||||||
|  |  * limitations under the License. | ||||||
|  |  */ | ||||||
|  | 
 | ||||||
|  | #ifndef MINDSPORE_CCSRC_KERNEL_GPU_OTHER_BOUNDINGBOX_ENCODE_GPU_KERNEL_H | ||||||
|  | #define MINDSPORE_CCSRC_KERNEL_GPU_OTHER_BOUNDINGBOX_ENCODE_GPU_KERNEL_H | ||||||
|  | 
 | ||||||
|  | #include <vector> | ||||||
|  | #include "backend/kernel_compiler/gpu/cuda_impl/boundingbox_encode_impl.cuh" | ||||||
|  | #include "backend/kernel_compiler/gpu/gpu_kernel.h" | ||||||
|  | #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" | ||||||
|  | 
 | ||||||
|  | namespace mindspore { | ||||||
|  | namespace kernel { | ||||||
|  | template <typename T> | ||||||
|  | class BoundingBoxEncodeGpuKernel : public GpuKernel { | ||||||
|  |  public: | ||||||
|  |   BoundingBoxEncodeGpuKernel() : anchor_size_(0), groundtruth_size_(0), deltas_size_(0) {} | ||||||
|  | 
 | ||||||
|  |   ~BoundingBoxEncodeGpuKernel() override = default; | ||||||
|  | 
 | ||||||
|  |   const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | ||||||
|  |   const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | ||||||
|  |   const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } | ||||||
|  | 
 | ||||||
|  |   bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | ||||||
|  |               const std::vector<AddressPtr> &outputs, void *stream_ptr) override { | ||||||
|  |     T *anchor_addr = GetDeviceAddress<T>(inputs, 0); | ||||||
|  |     T *groundtruth_addr = GetDeviceAddress<T>(inputs, 1); | ||||||
|  |     T *deltas_addr = GetDeviceAddress<T>(outputs, 0); | ||||||
|  | 
 | ||||||
|  |     if (inputs[0]->size != inputs[1]->size) { | ||||||
|  |       MS_LOG(ERROR) << "Anchor box size must equal with groundtruth box size -" << inputs[1]->size << ", but got" | ||||||
|  |                     << inputs[0]->size; | ||||||
|  |       return false; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     const size_t coordinate = 4; | ||||||
|  |     const size_t block_size = inputs[0]->size / sizeof(T); | ||||||
|  |     if ((block_size % coordinate) != 0) { | ||||||
|  |       MS_LOG(ERROR) << "The size of the box must be a multiple of 4."; | ||||||
|  |       return false; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     BoundingBoxEncode(block_size / coordinate, anchor_addr, groundtruth_addr, deltas_addr, means_[0], means_[1], | ||||||
|  |                       means_[2], means_[3], stds_[0], stds_[1], stds_[2], stds_[3], | ||||||
|  |                       reinterpret_cast<cudaStream_t>(stream_ptr)); | ||||||
|  |     return true; | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   bool Init(const CNodePtr &kernel_node) override { | ||||||
|  |     size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | ||||||
|  |     if (input_num != 2) { | ||||||
|  |       MS_LOG(ERROR) << "Input number is " << input_num << ", but BoundingBoxEncode needs 2 inputs."; | ||||||
|  |       return false; | ||||||
|  |     } | ||||||
|  |     anchor_size_ = sizeof(T); | ||||||
|  |     groundtruth_size_ = sizeof(T); | ||||||
|  |     deltas_size_ = sizeof(T); | ||||||
|  | 
 | ||||||
|  |     auto anchor_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | ||||||
|  |     for (size_t i = 0; i < anchor_shape.size(); i++) { | ||||||
|  |       anchor_size_ *= anchor_shape[i]; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     auto groundtruth_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); | ||||||
|  |     for (size_t i = 0; i < groundtruth_shape.size(); i++) { | ||||||
|  |       groundtruth_size_ *= groundtruth_shape[i]; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); | ||||||
|  |     for (size_t i = 0; i < output_shape.size(); i++) { | ||||||
|  |       deltas_size_ *= output_shape[i]; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     InitSizeLists(); | ||||||
|  | 
 | ||||||
|  |     const size_t coordinate_size = 4; | ||||||
|  |     if (AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("means")->isa<ValueTuple>() || | ||||||
|  |         AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("means")->isa<ValueList>()) { | ||||||
|  |       means_ = GetAttr<std::vector<float>>(kernel_node, "means"); | ||||||
|  |     } else if (AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("means")->isa<FloatImm>()) { | ||||||
|  |       float mean = GetAttr<float>(kernel_node, "means"); | ||||||
|  |       for (size_t i = 0; i < coordinate_size; i++) { | ||||||
|  |         means_.emplace_back(mean); | ||||||
|  |       } | ||||||
|  |     } else { | ||||||
|  |       MS_LOG(EXCEPTION) << "Attribute means type is invalid."; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     if (AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("stds")->isa<ValueTuple>() || | ||||||
|  |         AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("stds")->isa<ValueList>()) { | ||||||
|  |       stds_ = GetAttr<std::vector<float>>(kernel_node, "stds"); | ||||||
|  |     } else if (AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("stds")->isa<FloatImm>()) { | ||||||
|  |       float std = GetAttr<float>(kernel_node, "stds"); | ||||||
|  |       for (size_t i = 0; i < coordinate_size; i++) { | ||||||
|  |         stds_.emplace_back(std); | ||||||
|  |       } | ||||||
|  |     } else { | ||||||
|  |       MS_LOG(EXCEPTION) << "Attribute stds type is invalid."; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     if (means_.size() < coordinate_size || stds_.size() < coordinate_size) { | ||||||
|  |       MS_LOG(EXCEPTION) << "The size of means or stds is less than 4."; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     return true; | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |  protected: | ||||||
|  |   void InitSizeLists() override { | ||||||
|  |     input_size_list_.push_back(anchor_size_); | ||||||
|  |     input_size_list_.push_back(groundtruth_size_); | ||||||
|  |     output_size_list_.push_back(deltas_size_); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |  private: | ||||||
|  |   size_t anchor_size_; | ||||||
|  |   size_t groundtruth_size_; | ||||||
|  |   size_t deltas_size_; | ||||||
|  |   std::vector<float> means_; | ||||||
|  |   std::vector<float> stds_; | ||||||
|  | 
 | ||||||
|  |   std::vector<size_t> input_size_list_; | ||||||
|  |   std::vector<size_t> output_size_list_; | ||||||
|  |   std::vector<size_t> workspace_size_list_; | ||||||
|  | }; | ||||||
|  | }  // namespace kernel
 | ||||||
|  | }  // namespace mindspore
 | ||||||
|  | 
 | ||||||
|  | #endif  // MINDSPORE_CCSRC_KERNEL_GPU_OTHER_BOUNDINGBOX_ENCODE_GPU_KERNEL_H
 | ||||||
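For context, encoding is the inverse transform: it converts (anchor, groundtruth) pairs into normalized regression targets. A hedged NumPy sketch under the same assumed box conventions as the decode sketch above (the authoritative math is in boundingbox_encode_impl.cuh, not shown in this diff):

import numpy as np

def encode_reference(anchor, groundtruth, means, stds):
    # Assumed per-box encode: center offsets plus log size ratios, then normalization.
    px = (anchor[:, 0] + anchor[:, 2]) * 0.5
    py = (anchor[:, 1] + anchor[:, 3]) * 0.5
    pw = anchor[:, 2] - anchor[:, 0] + 1.0                  # +1 convention assumed, as in decode
    ph = anchor[:, 3] - anchor[:, 1] + 1.0
    gx = (groundtruth[:, 0] + groundtruth[:, 2]) * 0.5
    gy = (groundtruth[:, 1] + groundtruth[:, 3]) * 0.5
    gw = groundtruth[:, 2] - groundtruth[:, 0] + 1.0
    gh = groundtruth[:, 3] - groundtruth[:, 1] + 1.0
    deltas = np.stack([(gx - px) / pw, (gy - py) / ph,
                       np.log(gw / pw), np.log(gh / ph)], axis=-1)
    return (deltas - np.asarray(means)) / np.asarray(stds)  # normalize the targets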
| @ -0,0 +1,60 @@ | |||||||
|  | # Copyright 2020 Huawei Technologies Co., Ltd | ||||||
|  | # | ||||||
|  | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  | # you may not use this file except in compliance with the License. | ||||||
|  | # You may obtain a copy of the License at | ||||||
|  | # | ||||||
|  | # http://www.apache.org/licenses/LICENSE-2.0 | ||||||
|  | # | ||||||
|  | # Unless required by applicable law or agreed to in writing, software | ||||||
|  | # distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  | # See the License for the specific language governing permissions and | ||||||
|  | # limitations under the License. | ||||||
|  | # ============================================================================ | ||||||
|  | 
 | ||||||
|  | import numpy as np | ||||||
|  | import pytest | ||||||
|  | 
 | ||||||
|  | import mindspore | ||||||
|  | import mindspore.context as context | ||||||
|  | import mindspore.nn as nn | ||||||
|  | from mindspore import Tensor | ||||||
|  | from mindspore.ops import operations as P | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class NetBoundingBoxDecode(nn.Cell): | ||||||
|  |     def __init__(self, means=(0.0, 0.0, 0.0, 0.0), stds=(1.0, 1.0, 1.0, 1.0)): | ||||||
|  |         super(NetBoundingBoxDecode, self).__init__() | ||||||
|  |         self.decode = P.BoundingBoxDecode(max_shape=(768, 1280), means=means, stds=stds, | ||||||
|  |                                           wh_ratio_clip=0.016) | ||||||
|  | 
 | ||||||
|  |     def construct(self, anchor, groundtruth): | ||||||
|  |         return self.decode(anchor, groundtruth) | ||||||
|  | 
 | ||||||
|  | @pytest.mark.level0 | ||||||
|  | @pytest.mark.platform_x86_gpu_training | ||||||
|  | @pytest.mark.env_onecard | ||||||
|  | def test_boundingbox_decode(): | ||||||
|  |     anchor = np.array([[4, 1, 2, 1], [2, 2, 2, 3]], np.float32) | ||||||
|  |     deltas = np.array([[3, 1, 2, 2], [1, 2, 1, 4]], np.float32) | ||||||
|  |     means = (0.1, 0.1, 0.2, 0.2) | ||||||
|  |     stds = (2.0, 2.0, 3.0, 3.0) | ||||||
|  |     anchor_box = Tensor(anchor, mindspore.float32) | ||||||
|  |     deltas_box = Tensor(deltas, mindspore.float32) | ||||||
|  |     expect_deltas = np.array([[28.6500, 0.0000, 0.0000, 33.8500], | ||||||
|  |                               [0.0000, 0.0000, 15.8663, 72.7000]], np.float32) | ||||||
|  | 
 | ||||||
|  |     error = np.ones(shape=[2, 4]) * 1.0e-4 | ||||||
|  | 
 | ||||||
|  |     context.set_context(mode=context.GRAPH_MODE, device_target='GPU') | ||||||
|  |     boundingbox_decode = NetBoundingBoxDecode(means, stds) | ||||||
|  |     output = boundingbox_decode(anchor_box, deltas_box) | ||||||
|  |     diff = output.asnumpy() - expect_deltas | ||||||
|  |     assert np.all(abs(diff) < error) | ||||||
|  | 
 | ||||||
|  |     context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') | ||||||
|  |     boundingbox_decode = NetBoundingBoxDecode(means, stds) | ||||||
|  |     output = boundingbox_decode(anchor_box, deltas_box) | ||||||
|  |     diff = output.asnumpy() - expect_deltas | ||||||
|  |     assert np.all(abs(diff) < error) | ||||||
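A companion test for BoundingBoxEncode could follow the same graph-mode/PyNative pattern. The sketch below reuses the imports above, assumes P.BoundingBoxEncode accepts the same means/stds attributes the kernel reads, and only checks that the two execution modes agree, rather than inventing expected values:

class NetBoundingBoxEncode(nn.Cell):
    def __init__(self, means=(0.0, 0.0, 0.0, 0.0), stds=(1.0, 1.0, 1.0, 1.0)):
        super(NetBoundingBoxEncode, self).__init__()
        self.encode = P.BoundingBoxEncode(means=means, stds=stds)

    def construct(self, anchor, groundtruth):
        return self.encode(anchor, groundtruth)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_boundingbox_encode_modes_agree():
    anchor = Tensor(np.array([[4, 1, 2, 1], [2, 2, 2, 3]], np.float32))
    groundtruth = Tensor(np.array([[3, 1, 2, 2], [1, 2, 1, 4]], np.float32))
    net = NetBoundingBoxEncode(means=(0.1, 0.1, 0.2, 0.2), stds=(2.0, 2.0, 3.0, 3.0))

    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
    graph_out = net(anchor, groundtruth).asnumpy()

    context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
    pynative_out = net(anchor, groundtruth).asnumpy()

    assert np.all(np.abs(graph_out - pynative_out) < 1.0e-4)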
Some files were not shown because too many files have changed in this diff.