fix ResizeNearestNeighbor & add util.cuh to support atomicadd for half

pull/3760/head
TFbunny 5 years ago
parent 760cd6829b
commit 075acf80b1

@ -55,15 +55,15 @@ class ResizeNearestNeighborGpuKernel : public GpuKernel {
}
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 1) {
MS_LOG(ERROR) << "Output number is " << output_num << ", but ResizeNearestNeighbor needs 1 output.";
MS_LOG(ERROR) << "Output number is " << output_num << ", but ResizeNearestNeighbor has 1 output.";
return false;
}
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
shape_size_ = input_shape.size();
auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
if (shape_size_ != RESIZENEARESTNEIGHBOR_DIMENSION) {
MS_LOG(EXCEPTION) << "Input is " << shape_size_ << "-D, but ResizeNearestNeighbor supports only "
<< RESIZENEARESTNEIGHBOR_DIMENSION << "-D inputs.";
MS_LOG(ERROR) << "Input is " << shape_size_ << "-D, but ResizeNearestNeighbor supports only "
<< RESIZENEARESTNEIGHBOR_DIMENSION << "-D inputs.";
return false;
}
input_size_ = 1;

@ -38,10 +38,10 @@ class ResizeNearestNeighborGradGpuKernel : public GpuKernel {
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
T *input = GetDeviceAddress<T>(inputs, 0);
T *output = GetDeviceAddress<T>(outputs, 0);
int size = SizeToInt(output_size_ / sizeof(T));
float h_scale = Scaling(input_shape_[2], output_shape_[2], align_corners_);
float w_scale = Scaling(input_shape_[3], output_shape_[3], align_corners_);
CalResizeNearestNeighborGrad(size, input, input_shape_[0], input_shape_[1], input_shape_[2], input_shape_[3],
int input_size = SizeToInt(input_size_ / sizeof(T));
float h_scale = Scaling(output_shape_[2], input_shape_[2], align_corners_);
float w_scale = Scaling(output_shape_[3], input_shape_[3], align_corners_);
CalResizeNearestNeighborGrad(input_size, input, input_shape_[0], input_shape_[1], input_shape_[2], input_shape_[3],
output, output_shape_[0], output_shape_[1], output_shape_[2], output_shape_[3],
align_corners_, h_scale, w_scale, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
@ -55,15 +55,15 @@ class ResizeNearestNeighborGradGpuKernel : public GpuKernel {
}
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 1) {
MS_LOG(ERROR) << "Output number is " << output_num << ", but ResizeNearestNeighbor needs 1 output.";
MS_LOG(ERROR) << "Output number is " << output_num << ", but ResizeNearestNeighbor has 1 output.";
return false;
}
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
shape_size_ = input_shape.size();
auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
if (shape_size_ != RESIZENEARESTNEIGHBORGRAD_DIMENSION) {
MS_LOG(EXCEPTION) << "Input is " << shape_size_ << "-D, but ResizeNearestNeighbor supports only "
<< RESIZENEARESTNEIGHBORGRAD_DIMENSION << "-D inputs.";
MS_LOG(ERROR) << "Input is " << shape_size_ << "-D, but ResizeNearestNeighbor supports only "
<< RESIZENEARESTNEIGHBORGRAD_DIMENSION << "-D inputs.";
return false;
}
input_size_ = 1;

@ -18,64 +18,73 @@
#include <stdint.h>
#include <math.h>
#include <algorithm>
#include "backend/kernel_compiler/gpu/cuda_impl/util.cuh"
#include "backend/kernel_compiler/gpu/cuda_impl/resize_nearest_neighbor_grad_impl.cuh"
template <typename T>
__global__ void ResizeNearestNeighborGrad(const int size, const T *input, const int s1, const int s2, const int s3,
const int s4, T *output, const int d1, const int d2, const int d3,
const int d4, bool align_corners, float h_scale, float w_scale) {
__global__ void InitZero(T *output, const int output_size) {
for (size_t pos = threadIdx.x + blockIdx.x * blockDim.x; pos < (output_size); pos += gridDim.x * blockDim.x) {
output[pos] = static_cast<T>(0);
}
}
template <typename T>
__global__ void ResizeNearestNeighborGrad(const int input_size, const T *input, const int s1, const int s2,
const int s3, const int s4, T *output, const int d1, const int d2,
const int d3, const int d4, bool align_corners, float h_scale,
float w_scale) {
// initialization
// HalfPixelCenters false
int input_pos;
int output_pos;
int pos_array[RESIZENEARESTNEIGHBORGRAD_DIMENSION];
int in_height = s3;
int in_width = s4;
int out_height = d3;
int out_width = d4;
// for example 4-D: pos = pos_array[0] * output_shape[1] * output_shape[2] * output_shape[3] +
// pos_array[1] * output_shape[2] * output_shape[3] +
// pos_array[2] * output_shape[3] +
// pos_array[3]
T h_scale_ = static_cast<T>(h_scale);
T w_scale_ = static_cast<T>(w_scale);
T out_h_;
T out_w_;
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
pos_array[0] = pos / (d2 * d3 * d4) % d1;
pos_array[1] = pos / (d3 * d4) % d2;
pos_array[2] = pos / (d4) % d3;
pos_array[3] = pos % d4;
out_h_ = static_cast<T>(pos_array[2]);
out_w_ = static_cast<T>(pos_array[3]);
const int in_y =
min((align_corners) ? static_cast<int>(roundf(out_h_ * h_scale_)) : static_cast<int>(floorf(out_h_ * h_scale_)),
in_height - 1);
const int in_x =
min((align_corners) ? static_cast<int>(roundf(out_w_ * w_scale_)) : static_cast<int>(floorf(out_w_ * w_scale_)),
in_width - 1);
// pos_array[0] N, pos_array[1] C, in_y H, in_x W
input_pos = pos_array[0] * s2 * s3 * s4 + pos_array[1] * s3 * s4 + in_y * s4 + in_x;
output[pos] = input[input_pos];
int in_h;
int in_w;
for (size_t pos = threadIdx.x + blockIdx.x * blockDim.x; pos < (input_size); pos += gridDim.x * blockDim.x) {
pos_array[0] = pos / (s2 * s3 * s4) % s1;
pos_array[1] = pos / (s3 * s4) % s2;
pos_array[2] = pos / (s4) % s3;
pos_array[3] = pos % s4;
in_h = pos_array[2];
in_w = pos_array[3];
const int out_y =
min((align_corners) ? static_cast<int>(roundf(in_h * h_scale)) : static_cast<int>(floorf(in_h * h_scale)),
out_height - 1);
const int out_x =
min((align_corners) ? static_cast<int>(roundf(in_w * w_scale)) : static_cast<int>(floorf(in_w * w_scale)),
out_width - 1);
// pos_array[0] N, pos_array[1] C, out_y H, out_x W
output_pos = pos_array[0] * d2 * d3 * d4 + pos_array[1] * d3 * d4 + out_y * d4 + out_x;
ms_atomic_add(&output[output_pos], input[pos]);
}
return;
}
template <typename T>
void CalResizeNearestNeighborGrad(const int size, const T *input, const int s1, const int s2, const int s3,
void CalResizeNearestNeighborGrad(const int input_size, const T *input, const int s1, const int s2, const int s3,
const int s4, T *output, const int d1, const int d2, const int d3, const int d4,
bool align_corners, float h_scale, float w_scale, cudaStream_t cuda_stream) {
ResizeNearestNeighborGrad<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(
size, input, s1, s2, s3, s4, output, d1, d2, d3, d4, align_corners, h_scale, w_scale);
int output_size = d1 * d2 * d3 * d4;
InitZero<<<GET_BLOCKS(input_size), GET_THREADS, 0, cuda_stream>>>(output, output_size);
ResizeNearestNeighborGrad<<<GET_BLOCKS(input_size), GET_THREADS, 0, cuda_stream>>>(
input_size, input, s1, s2, s3, s4, output, d1, d2, d3, d4, align_corners, h_scale, w_scale);
return;
}
template void CalResizeNearestNeighborGrad<float>(const int size, const float *input, const int s1, const int s2,
template void CalResizeNearestNeighborGrad<float>(const int input_size, const float *input, const int s1, const int s2,
const int s3, const int s4, float *output, const int d1, const int d2,
const int d3, const int d4, bool align_corners, float h_scale,
float w_scale, cudaStream_t cuda_stream);
template void CalResizeNearestNeighborGrad<half>(const int size, const half *input, const int s1, const int s2,
template void CalResizeNearestNeighborGrad<half>(const int input_size, const half *input, const int s1, const int s2,
const int s3, const int s4, half *output, const int d1, const int d2,
const int d3, const int d4, bool align_corners, float h_scale,
float w_scale, cudaStream_t cuda_stream);
template void CalResizeNearestNeighborGrad<int>(const int size, const int *input, const int s1, const int s2,
template void CalResizeNearestNeighborGrad<int>(const int input_size, const int *input, const int s1, const int s2,
const int s3, const int s4, int *output, const int d1, const int d2,
const int d3, const int d4, bool align_corners, float h_scale,
float w_scale, cudaStream_t cuda_stream);

@ -21,7 +21,7 @@
#define RESIZENEARESTNEIGHBORGRAD_DIMENSION 4
template <typename T>
void CalResizeNearestNeighborGrad(const int size, const T *input, const int s1, const int s2, const int s3,
void CalResizeNearestNeighborGrad(const int input_size, const T *input, const int s1, const int s2, const int s3,
const int s4, T *output, const int d1, const int d2, const int d3, const int d4,
bool align_corners, float h_scale, float w_scale, cudaStream_t cuda_stream);

@ -34,22 +34,20 @@ __global__ void ResizeNearestNeighbor(const int size, const T *input, const int
// pos_array[1] * output_shape[2] * output_shape[3] +
// pos_array[2] * output_shape[3] +
// pos_array[3]
T h_scale_ = static_cast<T>(h_scale);
T w_scale_ = static_cast<T>(w_scale);
T out_h_;
T out_w_;
int out_h;
int out_w;
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
pos_array[0] = pos / (d2 * d3 * d4) % d1;
pos_array[1] = pos / (d3 * d4) % d2;
pos_array[2] = pos / (d4) % d3;
pos_array[3] = pos % d4;
out_h_ = static_cast<T>(pos_array[2]);
out_w_ = static_cast<T>(pos_array[3]);
out_h = pos_array[2];
out_w = pos_array[3];
const int in_y =
min((align_corners) ? static_cast<int>(roundf(out_h_ * h_scale_)) : static_cast<int>(floorf(out_h_ * h_scale_)),
min((align_corners) ? static_cast<int>(roundf(out_h * h_scale)) : static_cast<int>(floorf(out_h * h_scale)),
in_height - 1);
const int in_x =
min((align_corners) ? static_cast<int>(roundf(out_w_ * w_scale_)) : static_cast<int>(floorf(out_w_ * w_scale_)),
min((align_corners) ? static_cast<int>(roundf(out_w * w_scale)) : static_cast<int>(floorf(out_w * w_scale)),
in_width - 1);
// pos_array[0] N, pos_array[1] C, in_y H, in_x W
input_pos = pos_array[0] * s2 * s3 * s4 + pos_array[1] * s3 * s4 + in_y * s4 + in_x;

@ -0,0 +1,40 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cuda_fp16.h>
inline __device__ float ms_atomic_add(float *address, float val) { return atomicAdd(address, val); }
inline __device__ int ms_atomic_add(int *address, int val) { return atomicAdd(address, val); }
inline __device__ half ms_atomic_add(half *address, half val) {
unsigned int *aligned =
reinterpret_cast<unsigned int *>(reinterpret_cast<size_t>(address) - (reinterpret_cast<size_t>(address) & 2));
unsigned int old = *aligned;
unsigned int assumed;
unsigned short old_as_us; //NOLINT
do {
assumed = old;
old_as_us = static_cast<unsigned short>(reinterpret_cast<size_t>(address) & 2 ? old >> 16 : old & 0xffff); //NOLINT
half sum = __float2half_rn(__half2float(__ushort_as_half(old_as_us)) + static_cast<float>(val));
unsigned short sum_as_us = __half_as_ushort(sum); //NOLINT
unsigned int sum_as_ui =
reinterpret_cast<size_t>(address) & 2 ? (sum_as_us << 16) | (old & 0xffff) : (old & 0xffff0000) | sum_as_us;
old = atomicCAS(aligned, assumed, sum_as_ui);
} while (assumed != old);
__half_raw raw = {old_as_us};
return half(raw);
}

@ -43,15 +43,21 @@ class ResizeNearestNeighborGradAlignCornerF(nn.Cell):
@pytest.mark.env_onecard
def test_ResizeNearestNeighborGradAlignCornerT():
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
dy = np.array([[[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]]]).astype(np.float32)
size = (2, 2)
expect = np.array([[[[1, 0], [0, 1]]]]).astype(np.float32)
dy = np.array([[[[1, 2], [3, 4]]]]).astype(np.float32)
size = (4, 4)
expect = np.array([[[[1, 0, 0, 2], [0, 0, 0, 0], [0, 0, 0, 0], [3, 0, 0, 4]]]]).astype(np.float32)
rnn = ResizeNearestNeighborGradAlignCornerT()
output = rnn(Tensor(dy), size)
assert np.all(output.asnumpy() == expect)
dy = np.array([[[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]]]).astype(np.float16)
size = (2, 2)
expect = np.array([[[[1, 0], [0, 1]]]]).astype(np.float16)
dy = np.array([[[[1, 2], [3, 4]]]]).astype(np.float16)
size = (4, 4)
expect = np.array([[[[1, 0, 0, 2], [0, 0, 0, 0], [0, 0, 0, 0], [3, 0, 0, 4]]]]).astype(np.float16)
rnn = ResizeNearestNeighborGradAlignCornerT()
output = rnn(Tensor(dy), size)
assert np.all(output.asnumpy() == expect)
dy = np.array([[[[1, 2], [3, 4]]]]).astype(np.int32)
size = (4, 4)
expect = np.array([[[[1, 0, 0, 2], [0, 0, 0, 0], [0, 0, 0, 0], [3, 0, 0, 4]]]]).astype(np.int32)
rnn = ResizeNearestNeighborGradAlignCornerT()
output = rnn(Tensor(dy), size)
assert np.all(output.asnumpy() == expect)
@ -63,13 +69,19 @@ def test_ResizeNearestNeighborGradAlignCornerF():
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
dy = np.array([[[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]]]).astype(np.float32)
size = (2, 2)
expect = np.array([[[[1, 0], [0, 1]]]]).astype(np.float32)
expect = np.array([[[[4, 0], [0, 4]]]]).astype(np.float32)
rnn = ResizeNearestNeighborGradAlignCornerF()
output = rnn(Tensor(dy), size)
assert np.all(output.asnumpy() == expect)
dy = np.array([[[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]]]).astype(np.float16)
size = (2, 2)
expect = np.array([[[[1, 0], [0, 1]]]]).astype(np.float16)
expect = np.array([[[[4, 0], [0, 4]]]]).astype(np.float16)
rnn = ResizeNearestNeighborGradAlignCornerF()
output = rnn(Tensor(dy), size)
assert np.all(output.asnumpy() == expect)
dy = np.array([[[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]]]).astype(np.int32)
size = (2, 2)
expect = np.array([[[[4, 0], [0, 4]]]]).astype(np.int32)
rnn = ResizeNearestNeighborGradAlignCornerF()
output = rnn(Tensor(dy), size)
assert np.all(output.asnumpy() == expect)

@ -53,6 +53,11 @@ def test_ResizeNearestNeighborAlignCornerT():
rnn = ResizeNearestNeighborAlignCornerT((4, 4))
output = rnn(input_tensor)
assert np.all(output.asnumpy() == expect)
input_tensor = Tensor(np.array([[[[1, 0], [0, 1]]]]).astype(np.int32))
expect = np.array([[[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]]]).astype(np.int32)
rnn = ResizeNearestNeighborAlignCornerT((4, 4))
output = rnn(input_tensor)
assert np.all(output.asnumpy() == expect)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@ -69,3 +74,8 @@ def test_ResizeNearestNeighborAlignCornerF():
rnn = ResizeNearestNeighborAlignCornerF((4, 4))
output = rnn(input_tensor)
assert np.all(output.asnumpy() == expect)
input_tensor = Tensor(np.array([[[[1, 0], [0, 1]]]]).astype(np.int32))
expect = np.array([[[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]]]).astype(np.int32)
rnn = ResizeNearestNeighborAlignCornerF((4, 4))
output = rnn(input_tensor)
assert np.all(output.asnumpy() == expect)

Loading…
Cancel
Save