Merge pull request #15952 from jerrywgz/fpn_ops
add distribute fpn proposals op, test=develop
commit b0e3c02410
paddle/fluid/operators/detection/distribute_fpn_proposals_op.cc
@@ -0,0 +1,93 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/detection/distribute_fpn_proposals_op.h"

namespace paddle {
namespace operators {

class DistributeFpnProposalsOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("FpnRois"),
                   "Input(FpnRois) shouldn't be null");
    PADDLE_ENFORCE_GE(
        ctx->Outputs("MultiFpnRois").size(), 1UL,
        "Outputs(MultiFpnRois) of DistributeFpnProposalsOp should not be "
        "empty");
    size_t min_level = static_cast<size_t>(ctx->Attrs().Get<int>("min_level"));
    size_t max_level = static_cast<size_t>(ctx->Attrs().Get<int>("max_level"));
    PADDLE_ENFORCE_GE(max_level, min_level,
                      "max_level must not be lower than min_level");
    // Set the output shape
    size_t num_out_rois = max_level - min_level + 1;
    std::vector<framework::DDim> outs_dims;
    outs_dims.reserve(num_out_rois);
    for (size_t i = 0; i < num_out_rois; ++i) {
      framework::DDim out_dim = {-1, 4};
      outs_dims.push_back(out_dim);
    }
    ctx->SetOutputsDim("MultiFpnRois", outs_dims);
    ctx->SetOutputDim("RestoreIndex", {1, -1});
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("FpnRois"));
    return framework::OpKernelType(data_type, platform::CPUPlace());
  }
};

class DistributeFpnProposalsOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("FpnRois",
             "(LoDTensor) The RoIs at all levels, with shape (-1, 4)");
    AddOutput("MultiFpnRois",
              "(LoDTensor) The RoIs distributed to each FPN level")
        .AsDuplicable();
    AddOutput("RestoreIndex",
              "(Tensor) An array of non-negative indices which is "
              "used to restore the order of FpnRois");
    AddAttr<int>("min_level",
                 "The lowest level of the FPN layers the"
                 " proposals come from");
    AddAttr<int>("max_level",
                 "The highest level of the FPN layers the"
                 " proposals come from");
    AddAttr<int>("refer_level",
                 "The referring level of the FPN layer with"
                 " the specified scale");
    AddAttr<int>("refer_scale",
                 "The referring scale of the FPN layer with"
                 " the specified level");
    AddComment(R"DOC(
This operator distributes all proposals into different FPN levels
according to each proposal's scale, the referring scale and the
referring level. In addition, it returns a restore index with which
the distributed proposals, concatenated over all levels, can be
gathered back into the original input order.
)DOC");
  }
};
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(distribute_fpn_proposals, ops::DistributeFpnProposalsOp,
                  ops::DistributeFpnProposalsOpMaker,
                  paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(distribute_fpn_proposals,
                       ops::DistributeFpnProposalsOpKernel<float>,
                       ops::DistributeFpnProposalsOpKernel<double>);
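For reference, the assignment rule that the DOC block above describes is the standard FPN heuristic (Lin et al., "Feature Pyramid Networks for Object Detection", CVPR 2017). Writing $w \times h$ for a proposal's size, $s_0$ for refer_scale and $k_0$ for refer_level, both the CPU and GPU kernels below compute

\[ k = \min\!\bigl(k_{\max},\; \max\!\bigl(k_{\min},\; \lfloor k_0 + \log_2(\sqrt{wh}/s_0) \rfloor\bigr)\bigr), \]

so a proposal whose scale equals refer_scale is assigned to refer_level, and each halving of its scale moves it one level lower.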
paddle/fluid/operators/detection/distribute_fpn_proposals_op.cu
@@ -0,0 +1,221 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <paddle/fluid/memory/allocation/allocator.h>
#include "cub/cub.cuh"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/operators/detection/distribute_fpn_proposals_op.h"
#include "paddle/fluid/operators/gather.cu.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/for_range.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;

static constexpr int kNumCUDAThreads = 512;
static constexpr int kNumMaximumNumBlocks = 4096;

#define CUDA_1D_KERNEL_LOOP(i, n)                              \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
       i += blockDim.x * gridDim.x)

int const BBoxSize = 4;

struct RangeInitFunctor {
  int start_;
  int delta_;
  int* out_;
  __device__ void operator()(size_t i) { out_[i] = start_ + i * delta_; }
};

static inline int NumBlocks(const int N) {
  return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads,
                  kNumMaximumNumBlocks);
}

// transfer a length-based lod (lod_size entries) to an offset-based lod
// (lod_size + 1 entries)
static inline void TransLoD(const int* length_lod, const int lod_size,
                            int* offset_lod) {
  int offset = 0;
  for (int i = 0; i < lod_size; ++i) {
    offset_lod[i] = offset;
    offset += length_lod[i];
  }
  offset_lod[lod_size] = offset;
}

template <typename T>
static __device__ inline T RoIArea(const T* box, bool normalized) {
  if (box[2] < box[0] || box[3] < box[1]) {
    // If the coordinates are invalid
    // (e.g. xmax < xmin or ymax < ymin), return 0.
    return static_cast<T>(0.);
  } else {
    const T w = box[2] - box[0];
    const T h = box[3] - box[1];
    if (normalized) {
      return w * h;
    } else {
      // If the coordinates are in pixels rather than in range [0, 1],
      // add a one-pixel offset to width and height.
      return (w + 1) * (h + 1);
    }
  }
}

template <class T>
static __global__ void GPUDistFpnProposalsHelper(
    const int nthreads, const T* rois, const int lod_size,
    const int refer_level, const int refer_scale, const int max_level,
    const int min_level, int* roi_batch_id_data, int* sub_lod_list,
    int* target_lvls) {
  CUDA_1D_KERNEL_LOOP(i, nthreads) {
    const T* offset_roi = rois + i * BBoxSize;
    int roi_batch_ind = roi_batch_id_data[i];
    // get the target level of the current RoI
    T roi_area = RoIArea(offset_roi, false);
    T roi_scale = sqrt(roi_area);
    int tgt_lvl = floor(log2(roi_scale / refer_scale) + refer_level);
    tgt_lvl = min(max_level, max(tgt_lvl, min_level));
    target_lvls[i] = tgt_lvl;
    // count the RoIs that fall into the same batch and the same target level;
    // sub_lod_list is indexed by (level - min_level, batch id)
    platform::CudaAtomicAdd(
        sub_lod_list + (tgt_lvl - min_level) * lod_size + roi_batch_ind, 1);
  }
}

template <typename DeviceContext, typename T>
class GPUDistributeFpnProposalsOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* fpn_rois = ctx.Input<paddle::framework::LoDTensor>("FpnRois");

    auto multi_fpn_rois = ctx.MultiOutput<LoDTensor>("MultiFpnRois");
    auto* restore_index = ctx.Output<Tensor>("RestoreIndex");

    const int min_level = ctx.Attr<int>("min_level");
    const int max_level = ctx.Attr<int>("max_level");
    const int refer_level = ctx.Attr<int>("refer_level");
    const int refer_scale = ctx.Attr<int>("refer_scale");
    int num_level = max_level - min_level + 1;

    // check that the fpn_rois is not empty
    PADDLE_ENFORCE_EQ(fpn_rois->lod().size(), 1UL,
                      "DistributeFpnProposalsOp needs 1 level of LoD");

    auto fpn_rois_lod = fpn_rois->lod().back();
    int lod_size = fpn_rois_lod.size() - 1;
    int roi_num = fpn_rois_lod[lod_size];

    auto& dev_ctx = ctx.template device_context<DeviceContext>();

    // get batch id by lod on the CPU
    Tensor roi_batch_id_list;
    roi_batch_id_list.Resize({roi_num});
    int* roi_batch_id_data =
        roi_batch_id_list.mutable_data<int>(platform::CPUPlace());
    for (int n = 0; n < lod_size; ++n) {
      for (size_t i = fpn_rois_lod[n]; i < fpn_rois_lod[n + 1]; ++i) {
        roi_batch_id_data[i] = n;
      }
    }
    // copy the batch id list to the GPU
    Tensor roi_batch_id_list_gpu;
    framework::TensorCopySync(roi_batch_id_list, dev_ctx.GetPlace(),
                              &roi_batch_id_list_gpu);

    Tensor sub_lod_list;
    sub_lod_list.Resize({num_level, lod_size});
    int* sub_lod_list_data = sub_lod_list.mutable_data<int>(dev_ctx.GetPlace());
    // the counts are accumulated with atomic adds, so zero them first
    math::SetConstant<DeviceContext, int> set_zero;
    set_zero(dev_ctx, &sub_lod_list, static_cast<int>(0));
    Tensor target_lvls;
    target_lvls.Resize({roi_num});
    int* target_lvls_data = target_lvls.mutable_data<int>(dev_ctx.GetPlace());

    int blocks = NumBlocks(roi_num);
    int threads = kNumCUDAThreads;

    // get target levels and sub_lod list
    GPUDistFpnProposalsHelper<T><<<blocks, threads>>>(
        roi_num, fpn_rois->data<T>(), lod_size, refer_level, refer_scale,
        max_level, min_level, roi_batch_id_list_gpu.data<int>(),
        sub_lod_list_data, target_lvls_data);

    Tensor index_in_t;
    int* idx_in = index_in_t.mutable_data<int>({roi_num}, dev_ctx.GetPlace());
    platform::ForRange<platform::CUDADeviceContext> for_range(dev_ctx, roi_num);
    for_range(RangeInitFunctor{0, 1, idx_in});

    Tensor keys_out_t;
    int* keys_out = keys_out_t.mutable_data<int>({roi_num}, dev_ctx.GetPlace());
    Tensor index_out_t;
    int* idx_out = index_out_t.mutable_data<int>({roi_num}, dev_ctx.GetPlace());

    // Determine temporary device storage requirements
    size_t temp_storage_bytes = 0;
    cub::DeviceRadixSort::SortPairs<int, int>(
        nullptr, temp_storage_bytes, target_lvls_data, keys_out, idx_in,
        idx_out, roi_num);
    // Allocate temporary storage
    auto place = boost::get<platform::CUDAPlace>(dev_ctx.GetPlace());
    auto d_temp_storage = memory::Alloc(place, temp_storage_bytes,
                                        memory::Allocator::kScratchpad);

    // Run the sorting operation: a stable ascending sort of the target
    // levels groups the RoI indices (idx_out) level by level
    cub::DeviceRadixSort::SortPairs<int, int>(
        d_temp_storage->ptr(), temp_storage_bytes, target_lvls_data, keys_out,
        idx_in, idx_out, roi_num);

    int* restore_idx_data =
        restore_index->mutable_data<int>({1, roi_num}, dev_ctx.GetPlace());
    // sorting the shuffled indices themselves yields the inverse
    // permutation, i.e. the restore index
    cub::DeviceRadixSort::SortPairs<int, int>(
        d_temp_storage->ptr(), temp_storage_bytes, idx_out, keys_out, idx_in,
        restore_idx_data, roi_num);

    // copy the per-level, per-batch RoI counts back to the host
    Tensor sub_lod_list_cpu;
    framework::TensorCopySync(sub_lod_list, platform::CPUPlace(),
                              &sub_lod_list_cpu);

    Tensor offset_lod;
    int* offset_lod_data =
        offset_lod.mutable_data<int>({lod_size + 1}, platform::CPUPlace());
    int start = 0;
    for (int i = 0; i < num_level; ++i) {
      Tensor sub_lod = sub_lod_list_cpu.Slice(i, i + 1);
      int* sub_lod_data = sub_lod.data<int>();
      // transfer the length-based lod to an offset-based lod
      TransLoD(sub_lod_data, lod_size, offset_lod_data);
      int sub_rois_num = offset_lod_data[lod_size];

      multi_fpn_rois[i]->mutable_data<T>({sub_rois_num, kBoxDim},
                                         dev_ctx.GetPlace());
      if (sub_rois_num > 0) {
        Tensor sub_idx = index_out_t.Slice(start, start + sub_rois_num);
        GPUGather<T>(dev_ctx, *fpn_rois, sub_idx, multi_fpn_rois[i]);
      }
      start += sub_rois_num;

      framework::LoD lod;
      std::vector<size_t> offset(offset_lod_data,
                                 offset_lod_data + lod_size + 1);
      lod.emplace_back(offset);
      multi_fpn_rois[i]->set_lod(lod);
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
    distribute_fpn_proposals,
    ops::GPUDistributeFpnProposalsOpKernel<paddle::platform::CUDADeviceContext,
                                           float>,
    ops::GPUDistributeFpnProposalsOpKernel<paddle::platform::CUDADeviceContext,
                                           double>);
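Conceptually, the two cub::DeviceRadixSort passes compute a stable grouping of the RoIs by target level plus the inverse permutation that serves as RestoreIndex. A NumPy sketch of the same idea (illustrative only, not Paddle API):

import numpy as np

target_lvls = np.array([4, 2, 5, 2, 3])         # per-RoI target level
order = np.argsort(target_lvls, kind='stable')  # pass 1: group RoIs by level
restore = np.argsort(order, kind='stable')      # pass 2: inverse permutation
grouped = target_lvls[order]                    # levels are now contiguous
assert (grouped[restore] == target_lvls).all()  # restore recovers input order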
paddle/fluid/operators/detection/distribute_fpn_proposals_op.h
@@ -0,0 +1,147 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <algorithm>
#include <cmath>
#include <cstring>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/gather.h"
#include "paddle/fluid/operators/math/math_function.h"

namespace paddle {
namespace operators {

const int kBoxDim = 4;

template <typename T>
static inline T BBoxArea(const T* box, bool normalized) {
  if (box[2] < box[0] || box[3] < box[1]) {
    // If the coordinates are invalid
    // (e.g. xmax < xmin or ymax < ymin), return 0.
    return static_cast<T>(0.);
  } else {
    const T w = box[2] - box[0];
    const T h = box[3] - box[1];
    if (normalized) {
      return w * h;
    } else {
      // If the coordinates are in pixels rather than in range [0, 1],
      // add a one-pixel offset to width and height.
      return (w + 1) * (h + 1);
    }
  }
}

template <typename T>
class DistributeFpnProposalsOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* fpn_rois = context.Input<paddle::framework::LoDTensor>("FpnRois");

    auto multi_fpn_rois =
        context.MultiOutput<paddle::framework::LoDTensor>("MultiFpnRois");

    auto* restore_index =
        context.Output<paddle::framework::Tensor>("RestoreIndex");

    const int min_level = context.Attr<int>("min_level");
    const int max_level = context.Attr<int>("max_level");
    const int refer_level = context.Attr<int>("refer_level");
    const int refer_scale = context.Attr<int>("refer_scale");
    const int num_level = max_level - min_level + 1;

    // check that the fpn_rois is not empty
    PADDLE_ENFORCE_EQ(fpn_rois->lod().size(), 1UL,
                      "DistributeFpnProposalsOp needs 1 level of LoD");

    auto fpn_rois_lod = fpn_rois->lod().back();
    int fpn_rois_num = fpn_rois_lod[fpn_rois_lod.size() - 1];
    std::vector<int> target_level;
    // record the number of rois in each level
    std::vector<int> num_rois_level(num_level, 0);
    std::vector<int> num_rois_level_integral(num_level + 1, 0);
    for (size_t i = 0; i < fpn_rois_lod.size() - 1; ++i) {
      auto fpn_rois_slice =
          fpn_rois->Slice(fpn_rois_lod[i], fpn_rois_lod[i + 1]);
      const T* rois_data = fpn_rois_slice.data<T>();
      for (int j = 0; j < fpn_rois_slice.dims()[0]; ++j) {
        // get the target level of the current RoI
        T roi_scale = std::sqrt(BBoxArea(rois_data, false));
        int tgt_lvl =
            std::floor(std::log2(roi_scale / refer_scale) + refer_level);
        tgt_lvl = std::min(max_level, std::max(tgt_lvl, min_level));
        target_level.push_back(tgt_lvl);
        num_rois_level[tgt_lvl - min_level]++;
        rois_data += kBoxDim;
      }
    }
    // define the output rois
    // pointers into each level's fpn rois
    std::vector<T*> multi_fpn_rois_data(num_level);
    // lod0 records the offset information of each level's rois
    std::vector<std::vector<size_t>> multi_fpn_rois_lod0;
    for (int i = 0; i < num_level; ++i) {
      // allocate memory for each level's rois
      multi_fpn_rois[i]->mutable_data<T>({num_rois_level[i], kBoxDim},
                                         context.GetPlace());
      multi_fpn_rois_data[i] = multi_fpn_rois[i]->data<T>();
      std::vector<size_t> lod0(1, 0);
      multi_fpn_rois_lod0.push_back(lod0);
      // cumulative start offset of each level's rois
      num_rois_level_integral[i + 1] =
          num_rois_level_integral[i] + num_rois_level[i];
    }
    restore_index->mutable_data<int>({1, fpn_rois_num}, context.GetPlace());
    int* restore_index_data = restore_index->data<int>();
    std::vector<int> restore_index_inter(fpn_rois_num, -1);
    // distribute the rois into different fpn levels by target level
    for (size_t i = 0; i < fpn_rois_lod.size() - 1; ++i) {
      auto fpn_rois_slice =
          fpn_rois->Slice(fpn_rois_lod[i], fpn_rois_lod[i + 1]);
      const T* rois_data = fpn_rois_slice.data<T>();
      size_t cur_offset = fpn_rois_lod[i];
      for (int j = 0; j < num_level; j++) {
        multi_fpn_rois_lod0[j].push_back(multi_fpn_rois_lod0[j][i]);
      }
      for (int j = 0; j < fpn_rois_slice.dims()[0]; ++j) {
        int lvl = target_level[cur_offset + j];
        memcpy(multi_fpn_rois_data[lvl - min_level], rois_data,
               kBoxDim * sizeof(T));
        multi_fpn_rois_data[lvl - min_level] += kBoxDim;
        int index_in_shuffle = num_rois_level_integral[lvl - min_level] +
                               multi_fpn_rois_lod0[lvl - min_level][i + 1];
        restore_index_inter[index_in_shuffle] = cur_offset + j;
        multi_fpn_rois_lod0[lvl - min_level][i + 1]++;
        rois_data += kBoxDim;
      }
    }
    // invert the shuffle mapping: for each original RoI, record its
    // position in the distributed output
    for (int i = 0; i < fpn_rois_num; ++i) {
      restore_index_data[restore_index_inter[i]] = i;
    }
    // merge lod information into LoDTensor
    for (int i = 0; i < num_level; ++i) {
      framework::LoD lod;
      lod.emplace_back(multi_fpn_rois_lod0[i]);
      multi_fpn_rois[i]->set_lod(lod);
    }
  }
};
}  // namespace operators
}  // namespace paddle
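A note on the lod bookkeeping in this kernel: fpn_rois->lod().back() and the per-level lod0 vectors are offset-based (cumulative row boundaries), while the unit test below feeds a length-based LoD; the conversion between the two, which TransLoD performs in the CUDA file, is just a prefix sum. An illustrative sketch (not Paddle API):

import numpy as np

lengths = [100, 200]  # RoIs per image, as in the length-based LoD [[100, 200]]
offsets = np.concatenate(([0], np.cumsum(lengths)))
print(offsets)  # [  0 100 300]: rows [0, 100) belong to image 0,
                # rows [100, 300) to image 1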
python/paddle/fluid/tests/unittests/test_distribute_fpn_proposals_op.py
@@ -0,0 +1,117 @@
#    Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np
from op_test import OpTest


class TestDistributeFPNProposalsOp(OpTest):
    def set_data(self):
        self.init_test_case()
        self.make_rois()
        self.rois_fpn, self.rois_idx_restore = self.calc_rois_distribute()
        self.inputs = {'FpnRois': (self.rois[:, 1:5], self.rois_lod)}
        self.attrs = {
            'max_level': self.roi_max_level,
            'min_level': self.roi_min_level,
            'refer_scale': self.canonical_scale,
            'refer_level': self.canonical_level
        }
        output = [('out%d' % i, self.rois_fpn[i])
                  for i in range(len(self.rois_fpn))]
        self.outputs = {
            'MultiFpnRois': output,
            'RestoreIndex': self.rois_idx_restore
        }

    def init_test_case(self):
        self.roi_max_level = 5
        self.roi_min_level = 2
        self.canonical_scale = 224
        self.canonical_level = 4
        self.images_shape = [512, 512]

    def boxes_area(self, boxes):
        w = (boxes[:, 2] - boxes[:, 0] + 1)
        h = (boxes[:, 3] - boxes[:, 1] + 1)
        areas = w * h
        assert np.all(areas >= 0), 'Negative areas found'
        return areas

    def map_rois_to_fpn_levels(self, rois, lvl_min, lvl_max):
        s = np.sqrt(self.boxes_area(rois))
        s0 = self.canonical_scale
        lvl0 = self.canonical_level
        target_lvls = np.floor(lvl0 + np.log2(s / s0 + 1e-6))
        target_lvls = np.clip(target_lvls, lvl_min, lvl_max)
        return target_lvls

    def get_sub_lod(self, sub_lvl):
        sub_lod = []
        max_batch_id = sub_lvl[-1]
        for i in range(max_batch_id.astype(np.int32) + 1):
            sub_lod.append(np.where(sub_lvl == i)[0].size)
        return sub_lod

    def add_multilevel_roi(self, rois, target_lvls, lvl_min, lvl_max):
        rois_idx_order = np.empty((0, ))
        rois_fpn = []
        for lvl in range(lvl_min, lvl_max + 1):
            idx_lvl = np.where(target_lvls == lvl)[0]
            if len(idx_lvl) == 0:
                rois_fpn.append((np.empty(shape=(0, 4)), [[0, 0]]))
                continue
            sub_lod = self.get_sub_lod(rois[idx_lvl, 0])
            rois_fpn.append((rois[idx_lvl, 1:], [sub_lod]))
            rois_idx_order = np.concatenate((rois_idx_order, idx_lvl))
        rois_idx_restore = np.argsort(rois_idx_order).astype(
            np.int32, copy=False)
        return rois_fpn, rois_idx_restore

    def calc_rois_distribute(self):
        lvl_min = self.roi_min_level
        lvl_max = self.roi_max_level
        target_lvls = self.map_rois_to_fpn_levels(self.rois[:, 1:5], lvl_min,
                                                  lvl_max)
        rois_fpn, rois_idx_restore = self.add_multilevel_roi(
            self.rois, target_lvls, lvl_min, lvl_max)
        return rois_fpn, rois_idx_restore

    def make_rois(self):
        self.rois_lod = [[100, 200]]
        rois = []
        lod = self.rois_lod[0]
        bno = 0
        for roi_num in lod:
            for i in range(roi_num):
                xywh = np.random.rand(4)
                xy1 = xywh[0:2] * 20
                wh = xywh[2:4] * (self.images_shape - xy1)
                xy2 = xy1 + wh
                roi = [bno, xy1[0], xy1[1], xy2[0], xy2[1]]
                rois.append(roi)
            bno += 1
        self.rois = np.array(rois).astype("float32")

    def setUp(self):
        self.op_type = "distribute_fpn_proposals"
        self.set_data()

    def test_check_output(self):
        self.check_output()


if __name__ == '__main__':
    unittest.main()
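As a quick sanity check of the mapping above (a hypothetical spot-check, not part of the test file): with the test's canonical scale 224 at level 4 and levels clipped to [2, 5], a 112-pixel box should land on level 3.

import numpy as np

refer_scale, refer_level = 224, 4         # canonical scale / level from the test
roi = np.array([0.0, 0.0, 111.0, 111.0])  # a 112 x 112 box in pixel coordinates
w, h = roi[2] - roi[0] + 1, roi[3] - roi[1] + 1
lvl = np.floor(refer_level + np.log2(np.sqrt(w * h) / refer_scale + 1e-6))
print(np.clip(lvl, 2, 5))                 # -> 3.0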