run radix sort of proposals layer on context stream (#31631)

2.0.1-rocm-post
zlsh80826 4 years ago committed by GitHub
parent e429deb0c4
commit 1c67cf0c98
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -66,7 +66,8 @@ static void SortDescending(const platform::CUDADeviceContext &ctx,
// Determine temporary device storage requirements // Determine temporary device storage requirements
size_t temp_storage_bytes = 0; size_t temp_storage_bytes = 0;
cub::DeviceRadixSort::SortPairsDescending<T, int>( cub::DeviceRadixSort::SortPairsDescending<T, int>(
nullptr, temp_storage_bytes, keys_in, keys_out, idx_in, idx_out, num); nullptr, temp_storage_bytes, keys_in, keys_out, idx_in, idx_out, num, 0,
sizeof(T) * 8, ctx.stream());
// Allocate temporary storage // Allocate temporary storage
auto place = BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace()); auto place = BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace());
auto d_temp_storage = memory::Alloc(place, temp_storage_bytes); auto d_temp_storage = memory::Alloc(place, temp_storage_bytes);
@ -74,7 +75,7 @@ static void SortDescending(const platform::CUDADeviceContext &ctx,
// Run sorting operation // Run sorting operation
cub::DeviceRadixSort::SortPairsDescending<T, int>( cub::DeviceRadixSort::SortPairsDescending<T, int>(
d_temp_storage->ptr(), temp_storage_bytes, keys_in, keys_out, idx_in, d_temp_storage->ptr(), temp_storage_bytes, keys_in, keys_out, idx_in,
idx_out, num); idx_out, num, 0, sizeof(T) * 8, ctx.stream());
} }
template <typename T> template <typename T>

@ -144,7 +144,7 @@ class GPUCollectFpnProposalsOpKernel : public framework::OpKernel<T> {
size_t temp_storage_bytes = 0; size_t temp_storage_bytes = 0;
cub::DeviceRadixSort::SortPairsDescending<T, int>( cub::DeviceRadixSort::SortPairsDescending<T, int>(
nullptr, temp_storage_bytes, concat_scores.data<T>(), keys_out, idx_in, nullptr, temp_storage_bytes, concat_scores.data<T>(), keys_out, idx_in,
idx_out, total_roi_num); idx_out, total_roi_num, 0, sizeof(T) * 8, dev_ctx.stream());
// Allocate temporary storage // Allocate temporary storage
auto d_temp_storage = memory::Alloc(place, temp_storage_bytes); auto d_temp_storage = memory::Alloc(place, temp_storage_bytes);
@ -152,7 +152,8 @@ class GPUCollectFpnProposalsOpKernel : public framework::OpKernel<T> {
// sort score to get corresponding index // sort score to get corresponding index
cub::DeviceRadixSort::SortPairsDescending<T, int>( cub::DeviceRadixSort::SortPairsDescending<T, int>(
d_temp_storage->ptr(), temp_storage_bytes, concat_scores.data<T>(), d_temp_storage->ptr(), temp_storage_bytes, concat_scores.data<T>(),
keys_out, idx_in, idx_out, total_roi_num); keys_out, idx_in, idx_out, total_roi_num, 0, sizeof(T) * 8,
dev_ctx.stream());
index_out_t.Resize({real_post_num}); index_out_t.Resize({real_post_num});
Tensor sorted_rois; Tensor sorted_rois;
sorted_rois.mutable_data<T>({real_post_num, kBBoxSize}, dev_ctx.GetPlace()); sorted_rois.mutable_data<T>({real_post_num, kBBoxSize}, dev_ctx.GetPlace());
@ -176,7 +177,8 @@ class GPUCollectFpnProposalsOpKernel : public framework::OpKernel<T> {
temp_storage_bytes = 0; temp_storage_bytes = 0;
cub::DeviceRadixSort::SortPairs<int, int>( cub::DeviceRadixSort::SortPairs<int, int>(
nullptr, temp_storage_bytes, sorted_batch_id.data<int>(), out_id_data, nullptr, temp_storage_bytes, sorted_batch_id.data<int>(), out_id_data,
batch_idx_in, index_out_t.data<int>(), real_post_num); batch_idx_in, index_out_t.data<int>(), real_post_num, 0,
sizeof(int) * 8, dev_ctx.stream());
// Allocate temporary storage // Allocate temporary storage
d_temp_storage = memory::Alloc(place, temp_storage_bytes); d_temp_storage = memory::Alloc(place, temp_storage_bytes);
@ -184,7 +186,8 @@ class GPUCollectFpnProposalsOpKernel : public framework::OpKernel<T> {
// sort batch_id to get corresponding index // sort batch_id to get corresponding index
cub::DeviceRadixSort::SortPairs<int, int>( cub::DeviceRadixSort::SortPairs<int, int>(
d_temp_storage->ptr(), temp_storage_bytes, sorted_batch_id.data<int>(), d_temp_storage->ptr(), temp_storage_bytes, sorted_batch_id.data<int>(),
out_id_data, batch_idx_in, index_out_t.data<int>(), real_post_num); out_id_data, batch_idx_in, index_out_t.data<int>(), real_post_num, 0,
sizeof(int) * 8, dev_ctx.stream());
GPUGather<T>(dev_ctx, sorted_rois, index_out_t, fpn_rois); GPUGather<T>(dev_ctx, sorted_rois, index_out_t, fpn_rois);

@ -149,9 +149,9 @@ class GPUDistributeFpnProposalsOpKernel : public framework::OpKernel<T> {
// Determine temporary device storage requirements // Determine temporary device storage requirements
size_t temp_storage_bytes = 0; size_t temp_storage_bytes = 0;
cub::DeviceRadixSort::SortPairs<int, int>(nullptr, temp_storage_bytes, cub::DeviceRadixSort::SortPairs<int, int>(
target_lvls_data, keys_out, nullptr, temp_storage_bytes, target_lvls_data, keys_out, idx_in,
idx_in, idx_out, roi_num); idx_out, roi_num, 0, sizeof(int) * 8, dev_ctx.stream());
// Allocate temporary storage // Allocate temporary storage
auto d_temp_storage = memory::Alloc(place, temp_storage_bytes); auto d_temp_storage = memory::Alloc(place, temp_storage_bytes);
@ -159,14 +159,14 @@ class GPUDistributeFpnProposalsOpKernel : public framework::OpKernel<T> {
// sort target level to get corresponding index // sort target level to get corresponding index
cub::DeviceRadixSort::SortPairs<int, int>( cub::DeviceRadixSort::SortPairs<int, int>(
d_temp_storage->ptr(), temp_storage_bytes, target_lvls_data, keys_out, d_temp_storage->ptr(), temp_storage_bytes, target_lvls_data, keys_out,
idx_in, idx_out, roi_num); idx_in, idx_out, roi_num, 0, sizeof(int) * 8, dev_ctx.stream());
int* restore_idx_data = int* restore_idx_data =
restore_index->mutable_data<int>({roi_num, 1}, dev_ctx.GetPlace()); restore_index->mutable_data<int>({roi_num, 1}, dev_ctx.GetPlace());
// sort current index to get restore index // sort current index to get restore index
cub::DeviceRadixSort::SortPairs<int, int>( cub::DeviceRadixSort::SortPairs<int, int>(
d_temp_storage->ptr(), temp_storage_bytes, idx_out, keys_out, idx_in, d_temp_storage->ptr(), temp_storage_bytes, idx_out, keys_out, idx_in,
restore_idx_data, roi_num); restore_idx_data, roi_num, 0, sizeof(int) * 8, dev_ctx.stream());
int start = 0; int start = 0;
auto multi_rois_num = ctx.MultiOutput<Tensor>("MultiLevelRoIsNum"); auto multi_rois_num = ctx.MultiOutput<Tensor>("MultiLevelRoIsNum");

Loading…
Cancel
Save