!6005 Optimize ROI Align kernel

Merge pull request !6005 from JonathanY/rcnn
pull/6005/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit cd2646c98e

@ -116,7 +116,15 @@ __global__ void ROIAlignKernel(size_t size, const T *input, const T *roi_boxes,
const int height, const int width, const int pooled_height, const int pooled_width) {
for (int thread_idx = blockIdx.x * blockDim.x + threadIdx.x; thread_idx < size;
thread_idx += blockDim.x * gridDim.x) {
int offset, n, c, ph, pw, roi_bin_grid_h, roi_bin_grid_w;
int n = thread_idx / pooled_width / pooled_height / channels;
const T *roi_box = roi_boxes + n * roi_cols;
if (roi_box[0] < static_cast<T>(0.001) && roi_box[1] < static_cast<T>(0.001) &&
roi_box[2] < static_cast<T>(0.001) && roi_box[3] < static_cast<T>(0.001) &&
roi_box[0] > static_cast<T>(-0.001) && roi_box[1] > static_cast<T>(-0.001) &&
roi_box[2] > static_cast<T>(-0.001) && roi_box[3] > static_cast<T>(-0.001)) {
continue;
}
int offset, c, ph, pw, roi_bin_grid_h, roi_bin_grid_w;
T bin_size_h, bin_size_w, roi_start_h, roi_start_w;
bin_box(thread_idx, roi_boxes, roi_cols, spatial_scale, sample_num, roi_end_mode, channels, height, width,
@ -183,7 +191,16 @@ __global__ void ROIAlignGradKernel(size_t size, const T *dy, const T *roi_boxes,
const int height, const int width, const int pooled_height, const int pooled_width) {
for (int thread_idx = blockIdx.x * blockDim.x + threadIdx.x; thread_idx < size;
thread_idx += blockDim.x * gridDim.x) {
int offset, n, c, ph, pw, roi_bin_grid_h, roi_bin_grid_w;
int n = thread_idx / pooled_width / pooled_height / channels;
const T *roi_box = roi_boxes + n * roi_cols;
if (roi_box[0] < static_cast<T>(0.001) && roi_box[1] < static_cast<T>(0.001) &&
roi_box[2] < static_cast<T>(0.001) && roi_box[3] < static_cast<T>(0.001) &&
roi_box[0] > static_cast<T>(-0.001) && roi_box[1] > static_cast<T>(-0.001) &&
roi_box[2] > static_cast<T>(-0.001) && roi_box[3] > static_cast<T>(-0.001)) {
continue;
}
int offset, c, ph, pw, roi_bin_grid_h, roi_bin_grid_w;
T bin_size_h, bin_size_w, roi_start_h, roi_start_w;
bin_box(thread_idx, roi_boxes, roi_cols, spatial_scale, sample_num, roi_end_mode, channels, height, width,

Loading…
Cancel
Save