|
|
@ -293,19 +293,24 @@ class CPUPRROIPoolOpKernel : public framework::OpKernel<T> {
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
PADDLE_ENFORCE_EQ(rois->lod().empty(), false,
|
|
|
|
PADDLE_ENFORCE_EQ(rois->lod().empty(), false,
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
"the lod of Input ROIs should not be empty when "
|
|
|
|
"The lod of Input ROIs should not be empty when "
|
|
|
|
"BatchRoINums is None!"));
|
|
|
|
"BatchRoINums is None!"));
|
|
|
|
auto rois_lod = rois->lod().back();
|
|
|
|
auto rois_lod = rois->lod().back();
|
|
|
|
int rois_batch_size = rois_lod.size() - 1;
|
|
|
|
int rois_batch_size = rois_lod.size() - 1;
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
PADDLE_ENFORCE_EQ(rois_batch_size, batch_size,
|
|
|
|
rois_batch_size, batch_size,
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
platform::errors::InvalidArgument("the rois_batch_size and input(X) "
|
|
|
|
"The rois_batch_size and input(X)'s "
|
|
|
|
"batch_size should be the same."));
|
|
|
|
"batch_size should be the same but received"
|
|
|
|
|
|
|
|
"rois_batch_size: %d and batch_size: %d",
|
|
|
|
|
|
|
|
rois_batch_size, batch_size));
|
|
|
|
int rois_num_with_lod = rois_lod[rois_batch_size];
|
|
|
|
int rois_num_with_lod = rois_lod[rois_batch_size];
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
rois_num_with_lod, rois_num,
|
|
|
|
rois_num_with_lod, rois_num,
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
platform::errors::InvalidArgument("The rois_num from input should be "
|
|
|
|
"the rois_num from input and lod must be the same"));
|
|
|
|
"equal to the rois_num from lod, "
|
|
|
|
|
|
|
|
"but received rois_num from input: "
|
|
|
|
|
|
|
|
"%d and the rois_num from lod: %d.",
|
|
|
|
|
|
|
|
rois_num_with_lod, rois_num));
|
|
|
|
|
|
|
|
|
|
|
|
// calculate batch id index for each roi according to LoD
|
|
|
|
// calculate batch id index for each roi according to LoD
|
|
|
|
for (int n = 0; n < rois_batch_size; ++n) {
|
|
|
|
for (int n = 0; n < rois_batch_size; ++n) {
|
|
|
|