!6226 RoiAlign gpu kennel not matched with D-chip

Merge pull request !6226 from JonathanY/rcnn
pull/6226/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 475ed7a321

@ -20,11 +20,14 @@
inline __device__ int roi_cast_int(float x) { return __float2int_rd(x); }
inline __device__ int roi_cast_int(half x) { return __half2int_rd(x); }
inline __device__ int roi_round_int(float x) { return __float2int_rn(x + 0.00007); }
inline __device__ int roi_round_int(half x) { return __half2int_rn(x + static_cast<half>(0.00007)); }
template <typename T>
__device__ void bilinear_interpolate(const int height, const int width, T y, T x, int *x_low, int *y_low, int *x_high,
int *y_high, T *w1, T *w2, T *w3, T *w4) {
// return 0 if out of map boundary
constexpr float eps = 0.00007;
if (y < static_cast<T>(-1.0) || y > static_cast<T>(height) || x < static_cast<T>(-1.0) || x > static_cast<T>(width)) {
*w1 = *w2 = *w3 = *w4 = 0;
*x_low = *x_high = *y_low = *y_high = -1;
@ -36,8 +39,8 @@ __device__ void bilinear_interpolate(const int height, const int width, T y, T x
x = x <= static_cast<T>(.0) ? static_cast<T>(.0) : x;
// top left point
*y_low = roi_cast_int(y);
*x_low = roi_cast_int(x);
*y_low = (y <= static_cast<T>(eps) ? 0 : roi_cast_int(y));
*x_low = (x <= static_cast<T>(eps) ? 0 : roi_cast_int(x));
// bottom right point
if (*y_low >= height - 1) {
@ -83,7 +86,7 @@ __device__ void bin_box(int thread_idx, const T *roi_boxes, int roi_cols, const
const T *roi_box = roi_boxes + (*n) * roi_cols;
int roi_batch_ind = 0;
if (roi_cols == 5) {
roi_batch_ind = roi_box[0];
roi_batch_ind = roi_round_int(roi_box[0]);
roi_box++;
}
@ -124,13 +127,17 @@ __global__ void ROIAlignKernel(size_t size, const T *input, const T *roi_boxes,
roi_box[2] > static_cast<T>(-0.001) && roi_box[3] > static_cast<T>(-0.001)) {
continue;
}
int offset, c, ph, pw, roi_bin_grid_h, roi_bin_grid_w;
int offset = -1;
int c, ph, pw, roi_bin_grid_h, roi_bin_grid_w;
T bin_size_h, bin_size_w, roi_start_h, roi_start_w;
bin_box(thread_idx, roi_boxes, roi_cols, spatial_scale, sample_num, roi_end_mode, channels, height, width,
pooled_height, pooled_width, &offset, &n, &c, &ph, &pw, &roi_bin_grid_h, &roi_bin_grid_w, &bin_size_h,
&bin_size_w, &roi_start_h, &roi_start_w);
if (offset < 0 || offset >= size) continue;
// (n, c, ph, pw) is the base param of pooled map
const T count_points_in_grid_cell = roi_bin_grid_h * roi_bin_grid_w;
@ -147,7 +154,8 @@ __global__ void ROIAlignKernel(size_t size, const T *input, const T *roi_boxes,
int x_low = 0, y_low = 0, x_high = 0, y_high = 0;
T w1, w2, w3, w4;
bilinear_interpolate(height, width, y, x, &x_low, &y_low, &x_high, &y_high, &w1, &w2, &w3, &w4);
if (x_low != -1 || x_high != -1 || y_low != -1 || y_high != -1) {
if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0 && y_low < height && y_high < height &&
x_low < width && x_high < width) {
T v1 = input[offset + y_low * width + x_low];
T v2 = input[offset + y_low * width + x_high];
T v3 = input[offset + y_high * width + x_low];
@ -185,6 +193,14 @@ template void ROIAlign<half>(const half *x, const half *roi_boxes, int roi_rows,
const int height, const int width, const int pooled_height, const int pooled_width,
cudaStream_t cuda_stream);
template <typename T>
__global__ void ROIAlignGradInitKernel(size_t size_init, T *dx) {
for (int thread_idx = blockIdx.x * blockDim.x + threadIdx.x; thread_idx < size_init;
thread_idx += blockDim.x * gridDim.x) {
dx[thread_idx] = static_cast<T>(.0);
}
}
template <typename T>
__global__ void ROIAlignGradKernel(size_t size, const T *dy, const T *roi_boxes, int roi_cols, T *dx,
const T spatial_scale, const int sample_num, int roi_end_mode, const int channels,
@ -200,13 +216,16 @@ __global__ void ROIAlignGradKernel(size_t size, const T *dy, const T *roi_boxes,
continue;
}
int offset, c, ph, pw, roi_bin_grid_h, roi_bin_grid_w;
int offset = -1;
int c, ph, pw, roi_bin_grid_h, roi_bin_grid_w;
T bin_size_h, bin_size_w, roi_start_h, roi_start_w;
bin_box(thread_idx, roi_boxes, roi_cols, spatial_scale, sample_num, roi_end_mode, channels, height, width,
pooled_height, pooled_width, &offset, &n, &c, &ph, &pw, &roi_bin_grid_h, &roi_bin_grid_w, &bin_size_h,
&bin_size_w, &roi_start_h, &roi_start_w);
if (offset < 0 || offset >= size) continue;
// (n, c, ph, pw) is the base param of pooled map
const T count_points_in_grid_cell = roi_bin_grid_h * roi_bin_grid_w;
@ -226,7 +245,8 @@ __global__ void ROIAlignGradKernel(size_t size, const T *dy, const T *roi_boxes,
int x_low = 0, y_low = 0, x_high = 0, y_high = 0;
T w1, w2, w3, w4;
bilinear_interpolate(height, width, y, x, &x_low, &y_low, &x_high, &y_high, &w1, &w2, &w3, &w4);
if (x_low != -1 || x_high != -1 || y_low != -1 || y_high != -1) {
if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0 && y_low < height && y_high < height &&
x_low < width && x_high < width) {
T g1 = top_diff_this_bin * w1 / count_points_in_grid_cell;
T g2 = top_diff_this_bin * w2 / count_points_in_grid_cell;
T g3 = top_diff_this_bin * w3 / count_points_in_grid_cell;
@ -236,12 +256,11 @@ __global__ void ROIAlignGradKernel(size_t size, const T *dy, const T *roi_boxes,
T *dx_2 = dx + offset + y_low * width + x_high;
T *dx_3 = dx + offset + y_high * width + x_low;
T *dx_4 = dx + offset + y_high * width + x_high;
if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
MsAtomicAdd(dx_1, g1);
MsAtomicAdd(dx_2, g2);
MsAtomicAdd(dx_3, g3);
MsAtomicAdd(dx_4, g4);
}
MsAtomicAdd(dx_1, g1);
MsAtomicAdd(dx_2, g2);
MsAtomicAdd(dx_3, g3);
MsAtomicAdd(dx_4, g4);
}
}
}
@ -249,9 +268,12 @@ __global__ void ROIAlignGradKernel(size_t size, const T *dy, const T *roi_boxes,
}
template <typename T>
void ROIAlignGrad(const T *dy, const T *roi_boxes, int roi_rows, int roi_cols, T *dx, const T spatial_scale,
const int sample_num, int roi_end_mode, const int channels, const int height, const int width,
const int pooled_height, const int pooled_width, cudaStream_t cuda_stream) {
void ROIAlignGrad(const T *dy, const T *roi_boxes, int batch_size, int roi_rows, int roi_cols, T *dx,
const T spatial_scale, const int sample_num, int roi_end_mode, const int channels, const int height,
const int width, const int pooled_height, const int pooled_width, cudaStream_t cuda_stream) {
size_t size_init = batch_size * channels * height * width;
ROIAlignGradInitKernel<<<GET_BLOCKS(size_init), GET_THREADS, 0, cuda_stream>>>(size_init, dx);
size_t size = roi_rows * channels * pooled_height * pooled_width;
ROIAlignGradKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(
size, dy, roi_boxes, roi_cols, dx, spatial_scale, sample_num, roi_end_mode, channels, height, width, pooled_height,
@ -259,12 +281,12 @@ void ROIAlignGrad(const T *dy, const T *roi_boxes, int roi_rows, int roi_cols, T
return;
}
template void ROIAlignGrad<float>(const float *dy, const float *roi_boxes, int roi_rows, int roi_cols, float *dx,
const float spatial_scale, const int sample_num, int roi_end_mode, const int channels,
const int height, const int width, const int pooled_height, const int pooled_width,
cudaStream_t cuda_stream);
template void ROIAlignGrad<float>(const float *dy, const float *roi_boxes, int batch_size, int roi_rows, int roi_cols,
float *dx, const float spatial_scale, const int sample_num, int roi_end_mode,
const int channels, const int height, const int width, const int pooled_height,
const int pooled_width, cudaStream_t cuda_stream);
template void ROIAlignGrad<half>(const half *dy, const half *roi_boxes, int roi_rows, int roi_cols, half *dx,
const half spatial_scale, const int sample_num, int roi_end_mode, const int channels,
const int height, const int width, const int pooled_height, const int pooled_width,
cudaStream_t cuda_stream);
template void ROIAlignGrad<half>(const half *dy, const half *roi_boxes, int batch_size, int roi_rows, int roi_cols,
half *dx, const half spatial_scale, const int sample_num, int roi_end_mode,
const int channels, const int height, const int width, const int pooled_height,
const int pooled_width, cudaStream_t cuda_stream);

@ -22,8 +22,8 @@ void ROIAlign(const T *x, const T *roi_boxes, int roi_rows, int roi_cols, T *out
const int pooled_height, const int pooled_width, cudaStream_t cuda_stream);
template <typename T>
void ROIAlignGrad(const T *dy, const T *roi_boxes, int roi_rows, int roi_cols, T *dx, const T spatial_scale,
const int sample_num, int roi_end_mode, const int channels, const int height, const int width,
const int pooled_height, const int pooled_width, cudaStream_t cuda_stream);
void ROIAlignGrad(const T *dy, const T *roi_boxes, int batch_size, int roi_rows, int roi_cols, T *dx,
const T spatial_scale, const int sample_num, int roi_end_mode, const int channels, const int height,
const int width, const int pooled_height, const int pooled_width, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_ROI_ALIGN_IMPL_H_

@ -40,8 +40,8 @@ class ROIAlignGradGpuFwdKernel : public GpuKernel {
T *dx = GetDeviceAddress<T>(outputs, 0);
ROIAlignGrad(dy, rois, roi_rows_, roi_cols_, dx, spatial_scale_, sample_num_, roi_end_mode_, channels_, height_,
width_, pooled_height_, pooled_width_, reinterpret_cast<cudaStream_t>(stream_ptr));
ROIAlignGrad(dy, rois, batch_size_, roi_rows_, roi_cols_, dx, spatial_scale_, sample_num_, roi_end_mode_, channels_,
height_, width_, pooled_height_, pooled_width_, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}

Loading…
Cancel
Save