enhance nms for mask rcnn

fix/ir_debug
jerrywgz 6 years ago
parent 3f815e079f
commit 88ee56d0b2

@ -93,5 +93,25 @@ void BboxOverlaps(const framework::Tensor& r_boxes,
}
}
template <class T>
void SliceOneClass(const platform::DeviceContext& ctx,
const framework::Tensor& items, const int class_id,
framework::Tensor* one_class_item) {
T* item_data = one_class_item->mutable_data<T>(ctx.GetPlace());
const T* items_data = items.data<T>();
const int64_t num_item = items.dims()[0];
const int class_num = items.dims()[1];
int item_size = 1;
if (items.dims().size() == 3) {
item_size = items.dims()[2];
}
for (int i = 0; i < num_item; ++i) {
for (int j = 0; j < item_size; ++j) {
item_data[i * item_size + j] =
items_data[i * class_num * item_size + class_id * item_size + j];
}
}
}
} // namespace operators
} // namespace paddle

File diff suppressed because it is too large Load Diff
Loading…
Cancel
Save