From d54fe5a60a65c2f98da8051fae38155fb1885c36 Mon Sep 17 00:00:00 2001
From: linqingke
Date: Tue, 28 Jul 2020 21:38:43 +0800
Subject: [PATCH] fix sgd bug.

---
 .../gpu/arrays/scatter_nd_gpu_kernel.h             | 12 ++++++------
 .../kernel_compiler/gpu/cuda_impl/iou_impl.cu      |  6 +++---
 .../kernel_compiler/gpu/cuda_impl/scatter_nd.cu    |  6 ++++--
 .../kernel_compiler/gpu/cuda_impl/sgd_impl.cu      |  6 +++---
 mindspore/ops/operations/other_ops.py              |  4 ++++
 5 files changed, 20 insertions(+), 14 deletions(-)

diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/scatter_nd_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/scatter_nd_gpu_kernel.h
index 51f4323b1d..7cc0d1f858 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/scatter_nd_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/scatter_nd_gpu_kernel.h
@@ -69,7 +69,10 @@ class ScatterNdGpuFwdKernel : public GpuKernel {
       memcpy_flag_ = true;
     }
 
-    ScatterNd(indices, update, output, block_size_, input_size_, output_size_, indices_dim_0_, indices_dim_1_,
+    const size_t input_size = input_size_ / sizeof(T);
+    const size_t output_size = output_size_ / sizeof(T);
+
+    ScatterNd(indices, update, output, block_size_, input_size, output_size, indices_dim_0_, indices_dim_1_,
               indices_stride_, work_shape_, reinterpret_cast<cudaStream_t>(stream_ptr));
     return true;
   }
@@ -138,7 +141,7 @@ class ScatterNdGpuFwdKernel : public GpuKernel {
 
     // calculate indices dim 0/1
     indices_dim_0_ = indices_shapes_[0];
-    indices_dim_1_ = indices_shapes_[1];
+    indices_dim_1_ = indices_shapes_[indices_shapes_.size() - 1];
 
     // calculate block_size
     for (size_t i = indices_dim_1_; i < output_shapes_.size(); i++) {
@@ -146,10 +149,7 @@ class ScatterNdGpuFwdKernel : public GpuKernel {
     }
 
     // calculate indices_stride
-    for (size_t i = 0; i < indices_dim_1_; i++) {
-      vec_indices_stride_.push_back(0);
-    }
-
+    vec_indices_stride_.resize(indices_dim_1_, 0);
     vec_indices_stride_[indices_dim_1_ - 1] = block_size_;
 
     for (size_t i = indices_dim_1_ - 1; i > 0; --i) {
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/iou_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/iou_impl.cu
index f5e9f50dde..a3cdd7e131 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/iou_impl.cu
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/iou_impl.cu
@@ -50,12 +50,12 @@ __global__ void IOUKernel(const size_t size, const T *box1, const T *box2, T *io
 
     T area1 = (location_coordinate[0][2] - location_coordinate[0][0] + 1) * (location_coordinate[0][3] -
               location_coordinate[0][1] + 1);
+    T area2 = (location_coordinate[1][2] - location_coordinate[1][0] + 1) * (location_coordinate[1][3] -
+              location_coordinate[1][1] + 1);
     if (mode == 0) {
-      T area2 = (location_coordinate[1][2] - location_coordinate[1][0] + 1) * (location_coordinate[1][3] -
-                location_coordinate[1][1] + 1);
       iou_results[i] = overlaps / (area1 + area2 - overlaps + epsilon);
     } else {
-      iou_results[i] = overlaps / (area1 + epsilon);
+      iou_results[i] = overlaps / (area2 + epsilon);
     }
   }
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/scatter_nd.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/scatter_nd.cu
index 5f9672c41f..80258e718d 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/scatter_nd.cu
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/scatter_nd.cu
@@ -15,7 +15,9 @@
  */
 
 #include "backend/kernel_compiler/gpu/cuda_impl/scatter_nd.cuh"
+#include "backend/kernel_compiler/gpu/cuda_impl/util.cuh"
 #include "runtime/device/gpu/cuda_common.h"
+
 template <typename T, typename S>
 __global__ void ScatterNdKernel(S *indices, T *update, T *output, const size_t block_size, const size_t input_size,
                                 const size_t output_size, const size_t indices_dim_0, const size_t indices_dim_1,
@@ -39,7 +41,7 @@ __global__ void ScatterNdKernel(S *indices, T *update, T *output, const size_t b
     out_bound |= write_index >= output_size;
 
     if (!out_bound) {
-      output[write_index] = update[read_index];
+      ms_atomic_add(&output[write_index], update[read_index]);
     }
   }
 }
@@ -48,7 +50,7 @@ template <typename T, typename S>
 void ScatterNd(S *indices, T *update, T *output, const size_t &block_size, const size_t &input_size,
                const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1, S *indices_stride,
                S *work_shape, cudaStream_t stream) {
-  ScatterNdKernel<<<GET_BLOCKS(output_size), GET_THREADS, 0, stream>>>(indices, update, output, block_size, input_size,
+  ScatterNdKernel<<<GET_BLOCKS(input_size), GET_THREADS, 0, stream>>>(indices, update, output, block_size, input_size,
                                                                        output_size, indices_dim_0, indices_dim_1,
                                                                        indices_stride, work_shape);
   return;
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sgd_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sgd_impl.cu
index 4c452b116d..66451c2390 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sgd_impl.cu
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sgd_impl.cu
@@ -22,12 +22,12 @@ __global__ void SGDKernel(const int size, const T dampening, const T weight_deca
                           const T *momentum, const T *lr, T *param, T *accum, T *stat) {
   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (size); i += blockDim.x * gridDim.x) {
     T grad_new = grad[i];
-    if (weight_decay != static_cast<T>(0)) {
+    if (weight_decay > static_cast<T>(0)) {
       grad_new += param[i] * weight_decay;
     }
 
-    if (momentum[0] != static_cast<T>(0)) {
-      if (stat[i] == static_cast<T>(0)) {
+    if (momentum[0] > static_cast<T>(0)) {
+      if (stat[i] > static_cast<T>(0)) {
         accum[i] = grad_new;
         stat[i] = 0;
       } else {
diff --git a/mindspore/ops/operations/other_ops.py b/mindspore/ops/operations/other_ops.py
index 1af25ed1f2..d8775b7c28 100644
--- a/mindspore/ops/operations/other_ops.py
+++ b/mindspore/ops/operations/other_ops.py
@@ -101,6 +101,8 @@ class BoundingBoxEncode(PrimitiveWithInfer):
     def infer_shape(self, anchor_box, groundtruth_box):
         validator.check('anchor_box shape[0]', anchor_box[0], 'groundtruth_box shape[0]', groundtruth_box[0], Rel.EQ,
                         self.name)
+        validator.check("anchor_box rank", len(anchor_box), "", 2, Rel.EQ, self.name)
+        validator.check("groundtruth_box rank", len(groundtruth_box), "", 2, Rel.EQ, self.name)
         validator.check_integer('anchor_box shape[1]', anchor_box[1], 4, Rel.EQ, self.name)
         validator.check_integer('groundtruth_box shape[1]', groundtruth_box[1], 4, Rel.EQ, self.name)
         return anchor_box
@@ -152,6 +154,8 @@ class BoundingBoxDecode(PrimitiveWithInfer):
     def infer_shape(self, anchor_box, deltas):
         validator.check('anchor_box shape[0]', anchor_box[0], 'deltas shape[0]', deltas[0], Rel.EQ, self.name)
+        validator.check("anchor_box rank", len(anchor_box), "", 2, Rel.EQ, self.name)
+        validator.check("deltas rank", len(deltas), "", 2, Rel.EQ, self.name)
         validator.check_integer('anchor_box shape[1]', anchor_box[1], 4, Rel.EQ, self.name)
         validator.check_integer('deltas shape[1]', deltas[1], 4, Rel.EQ, self.name)
         return anchor_box