|
|
|
|
@ -39,14 +39,40 @@ class XPUROIAlignOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
int width = in_dims[3];
|
|
|
|
|
int rois_num = rois->dims()[0];
|
|
|
|
|
const T* input_data = in->data<T>();
|
|
|
|
|
auto rois_lod = rois->lod().back();
|
|
|
|
|
int rois_batch_size = rois_lod.size() - 1;
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
rois_batch_size, batch_size,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The rois_batch_size and imgs batch_size of roi_align_xpu OP must "
|
|
|
|
|
"be the same. But received rois_batch_size %d , batch_size %d",
|
|
|
|
|
rois_batch_size, batch_size));
|
|
|
|
|
|
|
|
|
|
framework::Tensor _roi_batch_list;
|
|
|
|
|
_roi_batch_list.Resize({rois_num});
|
|
|
|
|
int* rois_lod = _roi_batch_list.mutable_data<int>(ctx.GetPlace());
|
|
|
|
|
int rois_batch_size = 1;
|
|
|
|
|
if (ctx.HasInput("RoisNum")) {
|
|
|
|
|
auto* rois_num_t = ctx.Input<framework::Tensor>("RoisNum");
|
|
|
|
|
rois_batch_size = rois_num_t->numel();
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
rois_batch_size, batch_size,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The batch size of rois and the batch size of images "
|
|
|
|
|
" must be the same. But received the batch size of rois is %d, "
|
|
|
|
|
"and the batch size of images is %d",
|
|
|
|
|
rois_batch_size, batch_size));
|
|
|
|
|
auto* rois_num_data = rois_num_t->data<int>();
|
|
|
|
|
rois_lod[0] = 0;
|
|
|
|
|
for (int n = 0; n < rois_batch_size; ++n) {
|
|
|
|
|
rois_lod[n + 1] = rois_lod[n] + rois_num_data[n];
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
auto _rois_lod = rois->lod().back();
|
|
|
|
|
rois_batch_size = _rois_lod.size() - 1;
|
|
|
|
|
for (int n = 0; n < _rois_lod.size(); ++n) {
|
|
|
|
|
rois_lod[n] = _rois_lod[n];
|
|
|
|
|
}
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
rois_batch_size, batch_size,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The rois_batch_size and imgs batch_size of roi_align_xpu OP "
|
|
|
|
|
"must "
|
|
|
|
|
"be the same. But received rois_batch_size %d , batch_size %d",
|
|
|
|
|
rois_batch_size, batch_size));
|
|
|
|
|
}
|
|
|
|
|
int rois_num_with_lod = rois_lod[rois_batch_size];
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
rois_num, rois_num_with_lod,
|
|
|
|
|
|