|
|
|
@ -305,10 +305,10 @@ class GenerateMaskLabelsKernel : public framework::OpKernel<T> {
|
|
|
|
|
PADDLE_ENFORCE_EQ(gt_segms->lod()[0].size() - 1, n);
|
|
|
|
|
|
|
|
|
|
int mask_dim = num_classes * resolution * resolution;
|
|
|
|
|
|
|
|
|
|
mask_rois->mutable_data<T>({rois->numel(), kBoxDim}, ctx.GetPlace());
|
|
|
|
|
roi_has_mask_int32->mutable_data<int>({rois->numel(), 1}, ctx.GetPlace());
|
|
|
|
|
mask_int32->mutable_data<int>({rois->numel(), mask_dim}, ctx.GetPlace());
|
|
|
|
|
int roi_num = rois->lod().back()[n];
|
|
|
|
|
mask_rois->mutable_data<T>({roi_num, kBoxDim}, ctx.GetPlace());
|
|
|
|
|
roi_has_mask_int32->mutable_data<int>({roi_num, 1}, ctx.GetPlace());
|
|
|
|
|
mask_int32->mutable_data<int>({roi_num, mask_dim}, ctx.GetPlace());
|
|
|
|
|
|
|
|
|
|
framework::LoD lod;
|
|
|
|
|
std::vector<size_t> lod0(1, 0);
|
|
|
|
|