enhance input check for roi_align, test=develop (#20238)

revert-20712-fix_depthwise_conv
wangguanzhong 6 years ago committed by GitHub
parent c20b11ba11
commit 6fbf441001
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -258,7 +258,11 @@ class GPUROIAlignOpKernel : public framework::OpKernel<T> {
roi_batch_id_list.Resize({rois_num});
auto cplace = platform::CPUPlace();
int* roi_batch_id_data = roi_batch_id_list.mutable_data<int>(cplace);
auto rois_lod = rois->lod().back();
auto lod = rois->lod();
PADDLE_ENFORCE_EQ(
lod.empty(), false,
"Input(ROIs) Tensor of ROIAlignOp does not contain LoD information.");
auto rois_lod = lod.back();
int rois_batch_size = rois_lod.size() - 1;
PADDLE_ENFORCE_EQ(
rois_batch_size, batch_size,

@ -166,7 +166,11 @@ class CPUROIAlignOpKernel : public framework::OpKernel<T> {
int* roi_batch_id_data =
roi_batch_id_list.mutable_data<int>(ctx.GetPlace());
auto rois_lod = rois->lod().back();
auto lod = rois->lod();
PADDLE_ENFORCE_EQ(
lod.empty(), false,
"Input(ROIs) Tensor of ROIAlignOp does not contain LoD information.");
auto rois_lod = lod.back();
int rois_batch_size = rois_lod.size() - 1;
PADDLE_ENFORCE_EQ(
rois_batch_size, batch_size,

Loading…
Cancel
Save