|
|
|
@ -171,9 +171,9 @@ class CPUROIAlignOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
rois_batch_size - 1, batch_size,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The rois_batch_size and imgs "
|
|
|
|
|
"batch_size must be the same. But received rois_batch_size = %d, "
|
|
|
|
|
"batch_size = %d",
|
|
|
|
|
"The batch size of rois and the batch size of images "
|
|
|
|
|
" must be the same. But received the batch size of rois is %d, "
|
|
|
|
|
"and the batch size of images is %d",
|
|
|
|
|
rois_batch_size, batch_size));
|
|
|
|
|
auto* rois_lod = rois_lod_t->data<int64_t>();
|
|
|
|
|
for (int n = 0; n < rois_batch_size - 1; ++n) {
|
|
|
|
@ -183,9 +183,10 @@ class CPUROIAlignOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
auto lod = rois->lod();
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
lod.empty(), false,
|
|
|
|
|
"Input(ROIs) Tensor of ROIAlignOp does not contain LoD information.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(lod.empty(), false,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Input(ROIs) Tensor of ROIAlignOp "
|
|
|
|
|
"does not contain LoD information."));
|
|
|
|
|
auto rois_lod = lod.back();
|
|
|
|
|
int rois_batch_size = rois_lod.size() - 1;
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
@ -196,8 +197,14 @@ class CPUROIAlignOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
"batch_size = %d",
|
|
|
|
|
rois_batch_size, batch_size));
|
|
|
|
|
int rois_num_with_lod = rois_lod[rois_batch_size];
|
|
|
|
|
PADDLE_ENFORCE_EQ(rois_num, rois_num_with_lod,
|
|
|
|
|
"The rois_num from input and lod must be the same.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
rois_num, rois_num_with_lod,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The actual number of rois and the number of rois "
|
|
|
|
|
"provided from Input(RoIsLoD) in RoIAlign must be the same."
|
|
|
|
|
" But received actual number of rois is %d, and the number "
|
|
|
|
|
"of rois from RoIsLoD is %d",
|
|
|
|
|
rois_num, rois_num_with_lod));
|
|
|
|
|
for (int n = 0; n < rois_batch_size; ++n) {
|
|
|
|
|
for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) {
|
|
|
|
|
roi_batch_id_data[i] = n;
|
|
|
|
|