|
|
@ -144,7 +144,7 @@ class GPUCollectFpnProposalsOpKernel : public framework::OpKernel<T> {
|
|
|
|
size_t temp_storage_bytes = 0;
|
|
|
|
size_t temp_storage_bytes = 0;
|
|
|
|
cub::DeviceRadixSort::SortPairsDescending<T, int>(
|
|
|
|
cub::DeviceRadixSort::SortPairsDescending<T, int>(
|
|
|
|
nullptr, temp_storage_bytes, concat_scores.data<T>(), keys_out, idx_in,
|
|
|
|
nullptr, temp_storage_bytes, concat_scores.data<T>(), keys_out, idx_in,
|
|
|
|
idx_out, total_roi_num);
|
|
|
|
idx_out, total_roi_num, 0, sizeof(T) * 8, dev_ctx.stream());
|
|
|
|
// Allocate temporary storage
|
|
|
|
// Allocate temporary storage
|
|
|
|
auto d_temp_storage = memory::Alloc(place, temp_storage_bytes);
|
|
|
|
auto d_temp_storage = memory::Alloc(place, temp_storage_bytes);
|
|
|
|
|
|
|
|
|
|
|
@ -152,7 +152,8 @@ class GPUCollectFpnProposalsOpKernel : public framework::OpKernel<T> {
|
|
|
|
// sort score to get corresponding index
|
|
|
|
// sort score to get corresponding index
|
|
|
|
cub::DeviceRadixSort::SortPairsDescending<T, int>(
|
|
|
|
cub::DeviceRadixSort::SortPairsDescending<T, int>(
|
|
|
|
d_temp_storage->ptr(), temp_storage_bytes, concat_scores.data<T>(),
|
|
|
|
d_temp_storage->ptr(), temp_storage_bytes, concat_scores.data<T>(),
|
|
|
|
keys_out, idx_in, idx_out, total_roi_num);
|
|
|
|
keys_out, idx_in, idx_out, total_roi_num, 0, sizeof(T) * 8,
|
|
|
|
|
|
|
|
dev_ctx.stream());
|
|
|
|
index_out_t.Resize({real_post_num});
|
|
|
|
index_out_t.Resize({real_post_num});
|
|
|
|
Tensor sorted_rois;
|
|
|
|
Tensor sorted_rois;
|
|
|
|
sorted_rois.mutable_data<T>({real_post_num, kBBoxSize}, dev_ctx.GetPlace());
|
|
|
|
sorted_rois.mutable_data<T>({real_post_num, kBBoxSize}, dev_ctx.GetPlace());
|
|
|
@ -176,7 +177,8 @@ class GPUCollectFpnProposalsOpKernel : public framework::OpKernel<T> {
|
|
|
|
temp_storage_bytes = 0;
|
|
|
|
temp_storage_bytes = 0;
|
|
|
|
cub::DeviceRadixSort::SortPairs<int, int>(
|
|
|
|
cub::DeviceRadixSort::SortPairs<int, int>(
|
|
|
|
nullptr, temp_storage_bytes, sorted_batch_id.data<int>(), out_id_data,
|
|
|
|
nullptr, temp_storage_bytes, sorted_batch_id.data<int>(), out_id_data,
|
|
|
|
batch_idx_in, index_out_t.data<int>(), real_post_num);
|
|
|
|
batch_idx_in, index_out_t.data<int>(), real_post_num, 0,
|
|
|
|
|
|
|
|
sizeof(int) * 8, dev_ctx.stream());
|
|
|
|
// Allocate temporary storage
|
|
|
|
// Allocate temporary storage
|
|
|
|
d_temp_storage = memory::Alloc(place, temp_storage_bytes);
|
|
|
|
d_temp_storage = memory::Alloc(place, temp_storage_bytes);
|
|
|
|
|
|
|
|
|
|
|
@ -184,7 +186,8 @@ class GPUCollectFpnProposalsOpKernel : public framework::OpKernel<T> {
|
|
|
|
// sort batch_id to get corresponding index
|
|
|
|
// sort batch_id to get corresponding index
|
|
|
|
cub::DeviceRadixSort::SortPairs<int, int>(
|
|
|
|
cub::DeviceRadixSort::SortPairs<int, int>(
|
|
|
|
d_temp_storage->ptr(), temp_storage_bytes, sorted_batch_id.data<int>(),
|
|
|
|
d_temp_storage->ptr(), temp_storage_bytes, sorted_batch_id.data<int>(),
|
|
|
|
out_id_data, batch_idx_in, index_out_t.data<int>(), real_post_num);
|
|
|
|
out_id_data, batch_idx_in, index_out_t.data<int>(), real_post_num, 0,
|
|
|
|
|
|
|
|
sizeof(int) * 8, dev_ctx.stream());
|
|
|
|
|
|
|
|
|
|
|
|
GPUGather<T>(dev_ctx, sorted_rois, index_out_t, fpn_rois);
|
|
|
|
GPUGather<T>(dev_ctx, sorted_rois, index_out_t, fpn_rois);
|
|
|
|
|
|
|
|
|
|
|
|