|
|
@ -258,7 +258,11 @@ class GPUROIAlignOpKernel : public framework::OpKernel<T> {
|
|
|
|
roi_batch_id_list.Resize({rois_num});
|
|
|
|
roi_batch_id_list.Resize({rois_num});
|
|
|
|
auto cplace = platform::CPUPlace();
|
|
|
|
auto cplace = platform::CPUPlace();
|
|
|
|
int* roi_batch_id_data = roi_batch_id_list.mutable_data<int>(cplace);
|
|
|
|
int* roi_batch_id_data = roi_batch_id_list.mutable_data<int>(cplace);
|
|
|
|
auto rois_lod = rois->lod().back();
|
|
|
|
auto lod = rois->lod();
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
|
|
|
lod.empty(), false,
|
|
|
|
|
|
|
|
"Input(ROIs) Tensor of ROIAlignOp does not contain LoD information.");
|
|
|
|
|
|
|
|
auto rois_lod = lod.back();
|
|
|
|
int rois_batch_size = rois_lod.size() - 1;
|
|
|
|
int rois_batch_size = rois_lod.size() - 1;
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
rois_batch_size, batch_size,
|
|
|
|
rois_batch_size, batch_size,
|
|
|
|