From 075acf80b1a915b0132ed873da443b32eb3cd175 Mon Sep 17 00:00:00 2001 From: TFbunny Date: Wed, 29 Jul 2020 10:21:55 -0400 Subject: [PATCH] fix ResizeNearestNeighbor & add util.cuh to support atomicadd for half --- .../resize_nearest_neighbor_gpu_kernel.h | 6 +- .../resize_nearest_neighbor_grad_gpu_kernel.h | 14 ++-- .../resize_nearest_neighbor_grad_impl.cu | 75 +++++++++++-------- .../resize_nearest_neighbor_grad_impl.cuh | 2 +- .../cuda_impl/resize_nearest_neighbor_impl.cu | 14 ++-- .../kernel_compiler/gpu/cuda_impl/util.cuh | 40 ++++++++++ .../test_resize_nearest_neighbor_grad_op.py | 28 +++++-- .../gpu/test_resize_nearest_neighbor_op.py | 10 +++ 8 files changed, 129 insertions(+), 60 deletions(-) create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/util.cuh diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/resize_nearest_neighbor_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/resize_nearest_neighbor_gpu_kernel.h index 4650b033e5..ac0f6b402e 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/resize_nearest_neighbor_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/resize_nearest_neighbor_gpu_kernel.h @@ -55,15 +55,15 @@ class ResizeNearestNeighborGpuKernel : public GpuKernel { } size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); if (output_num != 1) { - MS_LOG(ERROR) << "Output number is " << output_num << ", but ResizeNearestNeighbor needs 1 output."; + MS_LOG(ERROR) << "Output number is " << output_num << ", but ResizeNearestNeighbor has 1 output."; return false; } auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); shape_size_ = input_shape.size(); auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); if (shape_size_ != RESIZENEARESTNEIGHBOR_DIMENSION) { - MS_LOG(EXCEPTION) << "Input is " << shape_size_ << "-D, but ResizeNearestNeighbor supports only " - << RESIZENEARESTNEIGHBOR_DIMENSION << "-D inputs."; + MS_LOG(ERROR) << "Input is " << shape_size_ << "-D, but ResizeNearestNeighbor supports only " + << RESIZENEARESTNEIGHBOR_DIMENSION << "-D inputs."; return false; } input_size_ = 1; diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/resize_nearest_neighbor_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/resize_nearest_neighbor_grad_gpu_kernel.h index e32ee44894..6d32d8da73 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/resize_nearest_neighbor_grad_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/resize_nearest_neighbor_grad_gpu_kernel.h @@ -38,10 +38,10 @@ class ResizeNearestNeighborGradGpuKernel : public GpuKernel { const std::vector &outputs, void *stream_ptr) override { T *input = GetDeviceAddress(inputs, 0); T *output = GetDeviceAddress(outputs, 0); - int size = SizeToInt(output_size_ / sizeof(T)); - float h_scale = Scaling(input_shape_[2], output_shape_[2], align_corners_); - float w_scale = Scaling(input_shape_[3], output_shape_[3], align_corners_); - CalResizeNearestNeighborGrad(size, input, input_shape_[0], input_shape_[1], input_shape_[2], input_shape_[3], + int input_size = SizeToInt(input_size_ / sizeof(T)); + float h_scale = Scaling(output_shape_[2], input_shape_[2], align_corners_); + float w_scale = Scaling(output_shape_[3], input_shape_[3], align_corners_); + CalResizeNearestNeighborGrad(input_size, input, input_shape_[0], input_shape_[1], input_shape_[2], input_shape_[3], output, output_shape_[0], output_shape_[1], output_shape_[2], output_shape_[3], align_corners_, h_scale, w_scale, reinterpret_cast(stream_ptr)); return true; @@ -55,15 +55,15 @@ class ResizeNearestNeighborGradGpuKernel : public GpuKernel { } size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); if (output_num != 1) { - MS_LOG(ERROR) << "Output number is " << output_num << ", but ResizeNearestNeighbor needs 1 output."; + MS_LOG(ERROR) << "Output number is " << output_num << ", but ResizeNearestNeighbor has 1 output."; return false; } auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); shape_size_ = input_shape.size(); auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); if (shape_size_ != RESIZENEARESTNEIGHBORGRAD_DIMENSION) { - MS_LOG(EXCEPTION) << "Input is " << shape_size_ << "-D, but ResizeNearestNeighbor supports only " - << RESIZENEARESTNEIGHBORGRAD_DIMENSION << "-D inputs."; + MS_LOG(ERROR) << "Input is " << shape_size_ << "-D, but ResizeNearestNeighbor supports only " + << RESIZENEARESTNEIGHBORGRAD_DIMENSION << "-D inputs."; return false; } input_size_ = 1; diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/resize_nearest_neighbor_grad_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/resize_nearest_neighbor_grad_impl.cu index 546960b139..edb509a38d 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/resize_nearest_neighbor_grad_impl.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/resize_nearest_neighbor_grad_impl.cu @@ -18,64 +18,73 @@ #include #include #include +#include "backend/kernel_compiler/gpu/cuda_impl/util.cuh" #include "backend/kernel_compiler/gpu/cuda_impl/resize_nearest_neighbor_grad_impl.cuh" template -__global__ void ResizeNearestNeighborGrad(const int size, const T *input, const int s1, const int s2, const int s3, - const int s4, T *output, const int d1, const int d2, const int d3, - const int d4, bool align_corners, float h_scale, float w_scale) { +__global__ void InitZero(T *output, const int output_size) { + for (size_t pos = threadIdx.x + blockIdx.x * blockDim.x; pos < (output_size); pos += gridDim.x * blockDim.x) { + output[pos] = static_cast(0); + } +} + +template +__global__ void ResizeNearestNeighborGrad(const int input_size, const T *input, const int s1, const int s2, + const int s3, const int s4, T *output, const int d1, const int d2, + const int d3, const int d4, bool align_corners, float h_scale, + float w_scale) { // initialization // HalfPixelCenters false - int input_pos; + int output_pos; int pos_array[RESIZENEARESTNEIGHBORGRAD_DIMENSION]; - int in_height = s3; - int in_width = s4; + int out_height = d3; + int out_width = d4; // for example 4-D: pos = pos_array[0] * output_shape[1] * output_shape[2] * output_shape[3] + // pos_array[1] * output_shape[2] * output_shape[3] + // pos_array[2] * output_shape[3] + // pos_array[3] - T h_scale_ = static_cast(h_scale); - T w_scale_ = static_cast(w_scale); - T out_h_; - T out_w_; - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { - pos_array[0] = pos / (d2 * d3 * d4) % d1; - pos_array[1] = pos / (d3 * d4) % d2; - pos_array[2] = pos / (d4) % d3; - pos_array[3] = pos % d4; - out_h_ = static_cast(pos_array[2]); - out_w_ = static_cast(pos_array[3]); - const int in_y = - min((align_corners) ? static_cast(roundf(out_h_ * h_scale_)) : static_cast(floorf(out_h_ * h_scale_)), - in_height - 1); - const int in_x = - min((align_corners) ? static_cast(roundf(out_w_ * w_scale_)) : static_cast(floorf(out_w_ * w_scale_)), - in_width - 1); - // pos_array[0] N, pos_array[1] C, in_y H, in_x W - input_pos = pos_array[0] * s2 * s3 * s4 + pos_array[1] * s3 * s4 + in_y * s4 + in_x; - output[pos] = input[input_pos]; + int in_h; + int in_w; + + for (size_t pos = threadIdx.x + blockIdx.x * blockDim.x; pos < (input_size); pos += gridDim.x * blockDim.x) { + pos_array[0] = pos / (s2 * s3 * s4) % s1; + pos_array[1] = pos / (s3 * s4) % s2; + pos_array[2] = pos / (s4) % s3; + pos_array[3] = pos % s4; + in_h = pos_array[2]; + in_w = pos_array[3]; + const int out_y = + min((align_corners) ? static_cast(roundf(in_h * h_scale)) : static_cast(floorf(in_h * h_scale)), + out_height - 1); + const int out_x = + min((align_corners) ? static_cast(roundf(in_w * w_scale)) : static_cast(floorf(in_w * w_scale)), + out_width - 1); + // pos_array[0] N, pos_array[1] C, out_y H, out_x W + output_pos = pos_array[0] * d2 * d3 * d4 + pos_array[1] * d3 * d4 + out_y * d4 + out_x; + ms_atomic_add(&output[output_pos], input[pos]); } - return; } template -void CalResizeNearestNeighborGrad(const int size, const T *input, const int s1, const int s2, const int s3, +void CalResizeNearestNeighborGrad(const int input_size, const T *input, const int s1, const int s2, const int s3, const int s4, T *output, const int d1, const int d2, const int d3, const int d4, bool align_corners, float h_scale, float w_scale, cudaStream_t cuda_stream) { - ResizeNearestNeighborGrad<<>>( - size, input, s1, s2, s3, s4, output, d1, d2, d3, d4, align_corners, h_scale, w_scale); + int output_size = d1 * d2 * d3 * d4; + InitZero<<>>(output, output_size); + ResizeNearestNeighborGrad<<>>( + input_size, input, s1, s2, s3, s4, output, d1, d2, d3, d4, align_corners, h_scale, w_scale); return; } -template void CalResizeNearestNeighborGrad(const int size, const float *input, const int s1, const int s2, +template void CalResizeNearestNeighborGrad(const int input_size, const float *input, const int s1, const int s2, const int s3, const int s4, float *output, const int d1, const int d2, const int d3, const int d4, bool align_corners, float h_scale, float w_scale, cudaStream_t cuda_stream); -template void CalResizeNearestNeighborGrad(const int size, const half *input, const int s1, const int s2, +template void CalResizeNearestNeighborGrad(const int input_size, const half *input, const int s1, const int s2, const int s3, const int s4, half *output, const int d1, const int d2, const int d3, const int d4, bool align_corners, float h_scale, float w_scale, cudaStream_t cuda_stream); -template void CalResizeNearestNeighborGrad(const int size, const int *input, const int s1, const int s2, +template void CalResizeNearestNeighborGrad(const int input_size, const int *input, const int s1, const int s2, const int s3, const int s4, int *output, const int d1, const int d2, const int d3, const int d4, bool align_corners, float h_scale, float w_scale, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/resize_nearest_neighbor_grad_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/resize_nearest_neighbor_grad_impl.cuh index d1acdaab51..c7f85e694a 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/resize_nearest_neighbor_grad_impl.cuh +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/resize_nearest_neighbor_grad_impl.cuh @@ -21,7 +21,7 @@ #define RESIZENEARESTNEIGHBORGRAD_DIMENSION 4 template -void CalResizeNearestNeighborGrad(const int size, const T *input, const int s1, const int s2, const int s3, +void CalResizeNearestNeighborGrad(const int input_size, const T *input, const int s1, const int s2, const int s3, const int s4, T *output, const int d1, const int d2, const int d3, const int d4, bool align_corners, float h_scale, float w_scale, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/resize_nearest_neighbor_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/resize_nearest_neighbor_impl.cu index 4280e33fd3..2cca9bd7a8 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/resize_nearest_neighbor_impl.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/resize_nearest_neighbor_impl.cu @@ -34,22 +34,20 @@ __global__ void ResizeNearestNeighbor(const int size, const T *input, const int // pos_array[1] * output_shape[2] * output_shape[3] + // pos_array[2] * output_shape[3] + // pos_array[3] - T h_scale_ = static_cast(h_scale); - T w_scale_ = static_cast(w_scale); - T out_h_; - T out_w_; + int out_h; + int out_w; for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { pos_array[0] = pos / (d2 * d3 * d4) % d1; pos_array[1] = pos / (d3 * d4) % d2; pos_array[2] = pos / (d4) % d3; pos_array[3] = pos % d4; - out_h_ = static_cast(pos_array[2]); - out_w_ = static_cast(pos_array[3]); + out_h = pos_array[2]; + out_w = pos_array[3]; const int in_y = - min((align_corners) ? static_cast(roundf(out_h_ * h_scale_)) : static_cast(floorf(out_h_ * h_scale_)), + min((align_corners) ? static_cast(roundf(out_h * h_scale)) : static_cast(floorf(out_h * h_scale)), in_height - 1); const int in_x = - min((align_corners) ? static_cast(roundf(out_w_ * w_scale_)) : static_cast(floorf(out_w_ * w_scale_)), + min((align_corners) ? static_cast(roundf(out_w * w_scale)) : static_cast(floorf(out_w * w_scale)), in_width - 1); // pos_array[0] N, pos_array[1] C, in_y H, in_x W input_pos = pos_array[0] * s2 * s3 * s4 + pos_array[1] * s3 * s4 + in_y * s4 + in_x; diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/util.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/util.cuh new file mode 100644 index 0000000000..9da273a661 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/util.cuh @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +inline __device__ float ms_atomic_add(float *address, float val) { return atomicAdd(address, val); } + +inline __device__ int ms_atomic_add(int *address, int val) { return atomicAdd(address, val); } + +inline __device__ half ms_atomic_add(half *address, half val) { + unsigned int *aligned = + reinterpret_cast(reinterpret_cast(address) - (reinterpret_cast(address) & 2)); + unsigned int old = *aligned; + unsigned int assumed; + unsigned short old_as_us; //NOLINT + do { + assumed = old; + old_as_us = static_cast(reinterpret_cast(address) & 2 ? old >> 16 : old & 0xffff); //NOLINT + half sum = __float2half_rn(__half2float(__ushort_as_half(old_as_us)) + static_cast(val)); + unsigned short sum_as_us = __half_as_ushort(sum); //NOLINT + unsigned int sum_as_ui = + reinterpret_cast(address) & 2 ? (sum_as_us << 16) | (old & 0xffff) : (old & 0xffff0000) | sum_as_us; + old = atomicCAS(aligned, assumed, sum_as_ui); + } while (assumed != old); + __half_raw raw = {old_as_us}; + return half(raw); +} diff --git a/tests/st/ops/gpu/test_resize_nearest_neighbor_grad_op.py b/tests/st/ops/gpu/test_resize_nearest_neighbor_grad_op.py index 70de771e7d..0203c87204 100644 --- a/tests/st/ops/gpu/test_resize_nearest_neighbor_grad_op.py +++ b/tests/st/ops/gpu/test_resize_nearest_neighbor_grad_op.py @@ -43,15 +43,21 @@ class ResizeNearestNeighborGradAlignCornerF(nn.Cell): @pytest.mark.env_onecard def test_ResizeNearestNeighborGradAlignCornerT(): context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") - dy = np.array([[[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]]]).astype(np.float32) - size = (2, 2) - expect = np.array([[[[1, 0], [0, 1]]]]).astype(np.float32) + dy = np.array([[[[1, 2], [3, 4]]]]).astype(np.float32) + size = (4, 4) + expect = np.array([[[[1, 0, 0, 2], [0, 0, 0, 0], [0, 0, 0, 0], [3, 0, 0, 4]]]]).astype(np.float32) rnn = ResizeNearestNeighborGradAlignCornerT() output = rnn(Tensor(dy), size) assert np.all(output.asnumpy() == expect) - dy = np.array([[[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]]]).astype(np.float16) - size = (2, 2) - expect = np.array([[[[1, 0], [0, 1]]]]).astype(np.float16) + dy = np.array([[[[1, 2], [3, 4]]]]).astype(np.float16) + size = (4, 4) + expect = np.array([[[[1, 0, 0, 2], [0, 0, 0, 0], [0, 0, 0, 0], [3, 0, 0, 4]]]]).astype(np.float16) + rnn = ResizeNearestNeighborGradAlignCornerT() + output = rnn(Tensor(dy), size) + assert np.all(output.asnumpy() == expect) + dy = np.array([[[[1, 2], [3, 4]]]]).astype(np.int32) + size = (4, 4) + expect = np.array([[[[1, 0, 0, 2], [0, 0, 0, 0], [0, 0, 0, 0], [3, 0, 0, 4]]]]).astype(np.int32) rnn = ResizeNearestNeighborGradAlignCornerT() output = rnn(Tensor(dy), size) assert np.all(output.asnumpy() == expect) @@ -63,13 +69,19 @@ def test_ResizeNearestNeighborGradAlignCornerF(): context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") dy = np.array([[[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]]]).astype(np.float32) size = (2, 2) - expect = np.array([[[[1, 0], [0, 1]]]]).astype(np.float32) + expect = np.array([[[[4, 0], [0, 4]]]]).astype(np.float32) rnn = ResizeNearestNeighborGradAlignCornerF() output = rnn(Tensor(dy), size) assert np.all(output.asnumpy() == expect) dy = np.array([[[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]]]).astype(np.float16) size = (2, 2) - expect = np.array([[[[1, 0], [0, 1]]]]).astype(np.float16) + expect = np.array([[[[4, 0], [0, 4]]]]).astype(np.float16) + rnn = ResizeNearestNeighborGradAlignCornerF() + output = rnn(Tensor(dy), size) + assert np.all(output.asnumpy() == expect) + dy = np.array([[[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]]]).astype(np.int32) + size = (2, 2) + expect = np.array([[[[4, 0], [0, 4]]]]).astype(np.int32) rnn = ResizeNearestNeighborGradAlignCornerF() output = rnn(Tensor(dy), size) assert np.all(output.asnumpy() == expect) diff --git a/tests/st/ops/gpu/test_resize_nearest_neighbor_op.py b/tests/st/ops/gpu/test_resize_nearest_neighbor_op.py index 79101af06e..1438d69637 100644 --- a/tests/st/ops/gpu/test_resize_nearest_neighbor_op.py +++ b/tests/st/ops/gpu/test_resize_nearest_neighbor_op.py @@ -53,6 +53,11 @@ def test_ResizeNearestNeighborAlignCornerT(): rnn = ResizeNearestNeighborAlignCornerT((4, 4)) output = rnn(input_tensor) assert np.all(output.asnumpy() == expect) + input_tensor = Tensor(np.array([[[[1, 0], [0, 1]]]]).astype(np.int32)) + expect = np.array([[[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]]]).astype(np.int32) + rnn = ResizeNearestNeighborAlignCornerT((4, 4)) + output = rnn(input_tensor) + assert np.all(output.asnumpy() == expect) @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training @@ -69,3 +74,8 @@ def test_ResizeNearestNeighborAlignCornerF(): rnn = ResizeNearestNeighborAlignCornerF((4, 4)) output = rnn(input_tensor) assert np.all(output.asnumpy() == expect) + input_tensor = Tensor(np.array([[[[1, 0], [0, 1]]]]).astype(np.int32)) + expect = np.array([[[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]]]).astype(np.int32) + rnn = ResizeNearestNeighborAlignCornerF((4, 4)) + output = rnn(input_tensor) + assert np.all(output.asnumpy() == expect)