Merge pull request #15952 from jerrywgz/fpn_ops
add distribute fpn proposals op, test=developalign_pyramid
commit
b0e3c02410
@ -0,0 +1,93 @@
|
||||
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/operators/detection/distribute_fpn_proposals_op.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
class DistributeFpnProposalsOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
|
||||
void InferShape(framework::InferShapeContext* ctx) const override {
|
||||
PADDLE_ENFORCE(ctx->HasInput("FpnRois"),
|
||||
"Input(FpnRois) shouldn't be null");
|
||||
PADDLE_ENFORCE_GE(
|
||||
ctx->Outputs("MultiFpnRois").size(), 1UL,
|
||||
"Outputs(MultiFpnRois) of DistributeOp should not be empty");
|
||||
size_t min_level = static_cast<size_t>(ctx->Attrs().Get<int>("min_level"));
|
||||
size_t max_level = static_cast<size_t>(ctx->Attrs().Get<int>("max_level"));
|
||||
PADDLE_ENFORCE_GE(max_level, min_level,
|
||||
"max_level must not lower than min_level");
|
||||
// Set the output shape
|
||||
size_t num_out_rois = max_level - min_level + 1;
|
||||
std::vector<framework::DDim> outs_dims;
|
||||
outs_dims.reserve(num_out_rois);
|
||||
for (size_t i = 0; i < num_out_rois; ++i) {
|
||||
framework::DDim out_dim = {-1, 4};
|
||||
outs_dims.push_back(out_dim);
|
||||
}
|
||||
ctx->SetOutputsDim("MultiFpnRois", outs_dims);
|
||||
ctx->SetOutputDim("RestoreIndex", {1, -1});
|
||||
}
|
||||
|
||||
protected:
|
||||
framework::OpKernelType GetExpectedKernelType(
|
||||
const framework::ExecutionContext& ctx) const override {
|
||||
auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("FpnRois"));
|
||||
return framework::OpKernelType(data_type, platform::CPUPlace());
|
||||
}
|
||||
};
|
||||
|
||||
// Declares the inputs, outputs, attributes and user-facing documentation of
// the `distribute_fpn_proposals` operator.
class DistributeFpnProposalsOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    // Input: all RoIs, four coordinates per row.
    AddInput("FpnRois", "(LoDTensor) The rois at all levels in shape (-1, 4)");

    // Outputs: a duplicable list of per-level RoI tensors and the index
    // array used to undo the re-ordering.
    AddOutput("MultiFpnRois", "(LoDTensor) Output with distribute operator")
        .AsDuplicable();
    AddOutput("RestoreIndex",
              "(Tensor) An array of positive number "
              "which is used to restore the order of FpnRois");

    // Attributes controlling the level assignment.
    AddAttr<int>("min_level",
                 "The lowest level of FPN layer "
                 "where the proposals come from");
    AddAttr<int>("max_level",
                 "The highest level of FPN layer "
                 "where the proposals come from");
    AddAttr<int>("refer_level",
                 "The referring level of FPN layer "
                 "with specified scale");
    AddAttr<int>("refer_scale",
                 "The referring scale of FPN layer "
                 "with specified level");

    AddComment(R"DOC(
This operator distribute all proposals into different fpn level,
with respect to scale of the proposals, the referring scale and
the referring level. Besides, to restore the order of proposals,
we return an array which indicate the original index of rois in
current proposals.
)DOC");
  }
};
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
REGISTER_OPERATOR(distribute_fpn_proposals, ops::DistributeFpnProposalsOp,
|
||||
ops::DistributeFpnProposalsOpMaker,
|
||||
paddle::framework::EmptyGradOpMaker);
|
||||
REGISTER_OP_CPU_KERNEL(distribute_fpn_proposals,
|
||||
ops::DistributeFpnProposalsOpKernel<float>,
|
||||
ops::DistributeFpnProposalsOpKernel<double>);
|
@ -0,0 +1,221 @@
|
||||
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include <paddle/fluid/memory/allocation/allocator.h>
|
||||
#include "cub/cub.cuh"
|
||||
#include "paddle/fluid/memory/memcpy.h"
|
||||
#include "paddle/fluid/operators/detection/distribute_fpn_proposals_op.h"
|
||||
#include "paddle/fluid/operators/gather.cu.h"
|
||||
#include "paddle/fluid/platform/cuda_primitives.h"
|
||||
#include "paddle/fluid/platform/for_range.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
// Short aliases for the framework tensor types used in this file.
using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;

// Launch configuration: threads per block and the cap on blocks per grid
// applied by NumBlocks().
static constexpr int kNumCUDAThreads = 512;
static constexpr int kNumMaxinumNumBlocks = 4096;

// Iterates `i` over [0, n), starting at this thread's global index and
// advancing by the total number of launched threads.
#define CUDA_1D_KERNEL_LOOP(i, n)                              \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
       i += blockDim.x * gridDim.x)

// Number of values stored per box (used as the row stride of the RoI data).
const int BBoxSize = 4;
|
||||
|
||||
// Device functor that fills `out_` with an arithmetic sequence:
// out_[i] = start_ + i * delta_.
// Member order matters: the struct is brace-initialised positionally.
struct RangeInitFunctor {
  int start_;  // first value of the sequence
  int delta_;  // step between consecutive values
  int* out_;   // destination buffer, written at index i

  __device__ void operator()(size_t i) {
    // The sum is evaluated in size_t and narrowed to int on store.
    out_[i] = static_cast<int>(start_ + i * delta_);
  }
};
|
||||
|
||||
// Returns the grid size for N work items: enough blocks of kNumCUDAThreads
// threads to cover N, capped at kNumMaxinumNumBlocks.
static inline int NumBlocks(const int N) {
  const int blocks_to_cover = (N + kNumCUDAThreads - 1) / kNumCUDAThreads;
  if (kNumMaxinumNumBlocks < blocks_to_cover) {
    return kNumMaxinumNumBlocks;
  }
  return blocks_to_cover;
}
|
||||
|
||||
// Converts a length-based LoD into an offset-based one (exclusive prefix
// sum): offset_lod[k] is the sum of length_lod[0..k-1].
// Reads and writes exactly `lod_size` entries of each array.
static inline void TransLoD(const int* length_lod, const int lod_size,
                            int* offset_lod) {
  int running_total = 0;
  int pos = 0;
  while (pos < lod_size) {
    offset_lod[pos] = running_total;
    running_total += length_lod[pos];
    ++pos;
  }
}
|
||||
|
||||
// Area of a box stored as [xmin, ymin, xmax, ymax].
// Returns 0 when the coordinates are invalid (xmax < xmin or ymax < ymin).
// When `normalized` is false, one is added to both the width and the height
// before multiplying.
template <typename T>
static __device__ inline T RoIArea(const T* box, bool normalized) {
  const bool is_degenerate = box[2] < box[0] || box[3] < box[1];
  if (is_degenerate) {
    return static_cast<T>(0.);
  }

  const T width = box[2] - box[0];
  const T height = box[3] - box[1];
  if (!normalized) {
    // Coordinate values are not within range [0, 1].
    return (width + 1) * (height + 1);
  }
  return width * height;
}
|
||||
|
||||
// For every RoI, computes its target FPN level and counts it in the
// per-(level, image) table.
//
//   nthreads          number of RoIs to process
//   rois              RoI coordinates, BBoxSize values per RoI
//   lod_size          row length of `sub_lod_list` (entries per level)
//   refer_level/refer_scale  reference level and scale for the assignment
//   max_level/min_level      inclusive clamp range for the target level
//   roi_batch_id_data image index of each RoI
//   sub_lod_list      counters laid out as [level - min_level][image]; the
//                     kernel only increments, so the caller must zero it
//   target_lvls       output, target level of each RoI
template <class T>
static __global__ void GPUDistFpnProposalsHelper(
    const int nthreads, const T* rois, const int lod_size,
    const int refer_level, const int refer_scale, const int max_level,
    const int min_level, int* roi_batch_id_data, int* sub_lod_list,
    int* target_lvls) {
  CUDA_1D_KERNEL_LOOP(i, nthreads) {
    const T* offset_roi = rois + i * BBoxSize;
    const int roi_batch_ind = roi_batch_id_data[i];

    // get the target level of current rois
    const T roi_area = RoIArea(offset_roi, false);
    const T roi_scale = sqrt(roi_area);
    // The small epsilon keeps log2 finite for zero-area (invalid) boxes and
    // matches the reference formula used by the unit test.
    int tgt_lvl = floor(
        log2(roi_scale / refer_scale + static_cast<T>(1e-6)) + refer_level);
    tgt_lvl = min(max_level, max(tgt_lvl, min_level));
    target_lvls[i] = tgt_lvl;

    // compute number of rois in the same batch and same target level.
    // Rows are indexed relative to min_level; using the absolute level would
    // write past the end of the table whenever min_level > 0.
    platform::CudaAtomicAdd(
        sub_lod_list + (tgt_lvl - min_level) * lod_size + roi_batch_ind, 1);
  }
}
|
||||
|
||||
// CUDA kernel of `distribute_fpn_proposals`.
//
// Steps:
//   1. Build the image index of every RoI on the host and copy it to the GPU.
//   2. Launch GPUDistFpnProposalsHelper to get each RoI's target level and
//      the per-(level, image) RoI counts.
//   3. Radix-sort the RoI indices by target level (ascending, stable), then
//      sort again by original index to obtain RestoreIndex.
//   4. For every level, copy its counts to the host, turn them into an
//      offset-based LoD and gather the level's RoIs into MultiFpnRois[i].
template <typename DeviceContext, typename T>
class GPUDistributeFpnProposalsOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* fpn_rois = ctx.Input<paddle::framework::LoDTensor>("FpnRois");

    auto multi_fpn_rois = ctx.MultiOutput<LoDTensor>("MultiFpnRois");
    auto* restore_index = ctx.Output<Tensor>("RestoreIndex");

    const int min_level = ctx.Attr<int>("min_level");
    const int max_level = ctx.Attr<int>("max_level");
    const int refer_level = ctx.Attr<int>("refer_level");
    const int refer_scale = ctx.Attr<int>("refer_scale");
    const int num_level = max_level - min_level + 1;

    // check that the fpn_rois is not empty
    PADDLE_ENFORCE_EQ(fpn_rois->lod().size(), 1UL,
                      "DistributeFpnProposalsOp need 1 level of LoD");

    auto fpn_rois_lod = fpn_rois->lod().back();
    const int lod_size = fpn_rois_lod.size() - 1;
    const int roi_num = fpn_rois_lod[lod_size];

    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    auto place = boost::get<platform::CUDAPlace>(dev_ctx.GetPlace());

    // get batch id by lod in CPU
    Tensor roi_batch_id_list;
    roi_batch_id_list.Resize({roi_num});
    int* roi_batch_id_data =
        roi_batch_id_list.mutable_data<int>(platform::CPUPlace());
    for (int n = 0; n < lod_size; ++n) {
      for (size_t i = fpn_rois_lod[n]; i < fpn_rois_lod[n + 1]; ++i) {
        roi_batch_id_data[i] = n;
      }
    }
    // copy batch id list to GPU
    Tensor roi_batch_id_list_gpu;
    framework::TensorCopySync(roi_batch_id_list, dev_ctx.GetPlace(),
                              &roi_batch_id_list_gpu);

    // Per-(level, image) RoI counters. The helper kernel only increments
    // them atomically, so they must be zeroed first.
    Tensor sub_lod_list;
    sub_lod_list.Resize({num_level, lod_size});
    int* sub_lod_list_data = sub_lod_list.mutable_data<int>(dev_ctx.GetPlace());
    math::SetConstant<platform::CUDADeviceContext, int> set_zero;
    set_zero(dev_ctx, &sub_lod_list, static_cast<int>(0));

    Tensor target_lvls;
    target_lvls.Resize({roi_num});
    int* target_lvls_data = target_lvls.mutable_data<int>(dev_ctx.GetPlace());

    const int blocks = NumBlocks(roi_num);
    const int threads = kNumCUDAThreads;

    // get target levels and sub_lod list; issued on the context's stream so
    // it is ordered with the other work submitted through dev_ctx.
    GPUDistFpnProposalsHelper<T><<<blocks, threads, 0, dev_ctx.stream()>>>(
        roi_num, fpn_rois->data<T>(), lod_size, refer_level, refer_scale,
        max_level, min_level, roi_batch_id_list_gpu.data<int>(),
        sub_lod_list_data, target_lvls_data);

    // idx_in = [0, 1, ..., roi_num - 1]
    Tensor index_in_t;
    int* idx_in = index_in_t.mutable_data<int>({roi_num}, dev_ctx.GetPlace());
    platform::ForRange<platform::CUDADeviceContext> for_range(dev_ctx, roi_num);
    for_range(RangeInitFunctor{0, 1, idx_in});

    Tensor keys_out_t;
    int* keys_out = keys_out_t.mutable_data<int>({roi_num}, dev_ctx.GetPlace());
    Tensor index_out_t;
    int* idx_out = index_out_t.mutable_data<int>({roi_num}, dev_ctx.GetPlace());

    // Determine temporary device storage requirements
    size_t temp_storage_bytes = 0;
    cub::DeviceRadixSort::SortPairs<int, int>(
        nullptr, temp_storage_bytes, target_lvls_data, keys_out, idx_in,
        idx_out, roi_num, 0, sizeof(int) * 8, dev_ctx.stream());
    // Allocate temporary storage
    auto d_temp_storage = memory::Alloc(place, temp_storage_bytes,
                                        memory::Allocator::kScratchpad);

    // Run sorting operation
    // sort target level to get corresponding index. The order is ascending
    // because the outputs are filled from min_level upwards; the radix sort
    // is stable, so RoIs of one level keep their original relative order.
    cub::DeviceRadixSort::SortPairs<int, int>(
        d_temp_storage->ptr(), temp_storage_bytes, target_lvls_data, keys_out,
        idx_in, idx_out, roi_num, 0, sizeof(int) * 8, dev_ctx.stream());

    int* restore_idx_data =
        restore_index->mutable_data<int>({roi_num, 1}, dev_ctx.GetPlace());
    // sort current index to get restore index: entry k becomes the position
    // of original RoI k inside the level-ordered sequence.
    cub::DeviceRadixSort::SortPairs<int, int>(
        d_temp_storage->ptr(), temp_storage_bytes, idx_out, keys_out, idx_in,
        restore_idx_data, roi_num, 0, sizeof(int) * 8, dev_ctx.stream());

    // Host-side scratch buffers. `level_lengths` carries one extra trailing
    // zero so TransLoD can emit the final (total) offset without reading
    // past the copied counts.
    std::vector<int> level_lengths(lod_size + 1, 0);
    std::vector<int> level_offsets(lod_size + 1, 0);
    int start = 0;  // first entry of the current level inside idx_out
    for (int i = 0; i < num_level; ++i) {
      Tensor sub_lod = sub_lod_list.Slice(i, i + 1);
      int* sub_lod_data = sub_lod.data<int>();

      // The counts live in device memory: bring them to the host before
      // touching them from CPU code.
      memory::Copy(platform::CPUPlace(), level_lengths.data(), place,
                   sub_lod_data, sizeof(int) * lod_size, dev_ctx.stream());
      dev_ctx.Wait();

      // transfer length-based lod to offset-based lod
      TransLoD(level_lengths.data(), lod_size + 1, level_offsets.data());
      const int sub_rois_num = level_offsets[lod_size];

      multi_fpn_rois[i]->mutable_data<T>({sub_rois_num, kBoxDim},
                                         dev_ctx.GetPlace());
      if (sub_rois_num > 0) {
        // Each level owns a consecutive range of the sorted index list.
        Tensor sub_idx = index_out_t.Slice(start, start + sub_rois_num);
        GPUGather<T>(dev_ctx, *fpn_rois, sub_idx, multi_fpn_rois[i]);
        start += sub_rois_num;
      }

      framework::LoD lod;
      std::vector<size_t> offset(level_offsets.begin(), level_offsets.end());
      lod.emplace_back(offset);
      multi_fpn_rois[i]->set_lod(lod);
    }
  }
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;

// CUDA kernels for single and double precision RoIs.
REGISTER_OP_CUDA_KERNEL(
    distribute_fpn_proposals,
    ops::GPUDistributeFpnProposalsOpKernel<
        paddle::platform::CUDADeviceContext, float>,
    ops::GPUDistributeFpnProposalsOpKernel<
        paddle::platform::CUDADeviceContext, double>);
|
@ -0,0 +1,147 @@
|
||||
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <cstring>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
#include "paddle/fluid/operators/detail/safe_ref.h"
|
||||
#include "paddle/fluid/operators/gather.h"
|
||||
#include "paddle/fluid/operators/math/math_function.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
// Number of values per RoI row; used as the row stride when walking RoI data
// and as the second dimension of the per-level outputs.
const int kBoxDim = 4;
|
||||
|
||||
// Area of a box stored as [xmin, ymin, xmax, ymax].
// Returns 0 when the coordinates are invalid (xmax < xmin or ymax < ymin).
// When `normalized` is false, one is added to both the width and the height
// before multiplying.
template <typename T>
static inline T BBoxArea(const T* box, bool normalized) {
  const T xmin = box[0];
  const T ymin = box[1];
  const T xmax = box[2];
  const T ymax = box[3];

  if (xmax < xmin || ymax < ymin) {
    return static_cast<T>(0.);
  }

  const T width = xmax - xmin;
  const T height = ymax - ymin;
  // Coordinates outside [0, 1] take the "+1" form.
  return normalized ? width * height : (width + 1) * (height + 1);
}
|
||||
|
||||
template <typename T>
|
||||
class DistributeFpnProposalsOpKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& context) const override {
|
||||
auto* fpn_rois = context.Input<paddle::framework::LoDTensor>("FpnRois");
|
||||
|
||||
auto multi_fpn_rois =
|
||||
context.MultiOutput<paddle::framework::LoDTensor>("MultiFpnRois");
|
||||
|
||||
auto* restore_index =
|
||||
context.Output<paddle::framework::Tensor>("RestoreIndex");
|
||||
|
||||
const int min_level = context.Attr<int>("min_level");
|
||||
const int max_level = context.Attr<int>("max_level");
|
||||
const int refer_level = context.Attr<int>("refer_level");
|
||||
const int refer_scale = context.Attr<int>("refer_scale");
|
||||
const int num_level = max_level - min_level + 1;
|
||||
|
||||
// check that the fpn_rois is not empty
|
||||
PADDLE_ENFORCE_EQ(fpn_rois->lod().size(), 1UL,
|
||||
"DistributeFpnProposalsOp need 1 level of LoD");
|
||||
|
||||
auto fpn_rois_lod = fpn_rois->lod().back();
|
||||
int fpn_rois_num = fpn_rois_lod[fpn_rois_lod.size() - 1];
|
||||
std::vector<int> target_level;
|
||||
// std::vector<int> target_level(fpn_rois_num, -1);
|
||||
// record the number of rois in each level
|
||||
std::vector<int> num_rois_level(num_level, 0);
|
||||
std::vector<int> num_rois_level_integral(num_level + 1, 0);
|
||||
for (int i = 0; i < fpn_rois_lod.size() - 1; ++i) {
|
||||
Tensor fpn_rois_slice =
|
||||
fpn_rois->Slice(fpn_rois_lod[i], fpn_rois_lod[i + 1]);
|
||||
const T* rois_data = fpn_rois_slice.data<T>();
|
||||
for (int j = 0; j < fpn_rois_slice.dims()[0]; ++j) {
|
||||
// get the target level of current rois
|
||||
T roi_scale = std::sqrt(BBoxArea(rois_data, false));
|
||||
int tgt_lvl =
|
||||
std::floor(std::log2(roi_scale / refer_scale) + refer_level);
|
||||
tgt_lvl = std::min(max_level, std::max(tgt_lvl, min_level));
|
||||
target_level.push_back(tgt_lvl);
|
||||
num_rois_level[tgt_lvl - min_level]++;
|
||||
rois_data += kBoxDim;
|
||||
}
|
||||
}
|
||||
// define the output rois
|
||||
// pointer which point to each level fpn rois
|
||||
std::vector<T*> multi_fpn_rois_data(num_level);
|
||||
// lod0 which will record the offset information of each level rois
|
||||
std::vector<std::vector<size_t>> multi_fpn_rois_lod0;
|
||||
for (int i = 0; i < num_level; ++i) {
|
||||
// allocate memory for each level rois
|
||||
multi_fpn_rois[i]->mutable_data<T>({num_rois_level[i], kBoxDim},
|
||||
context.GetPlace());
|
||||
multi_fpn_rois_data[i] = multi_fpn_rois[i]->data<T>();
|
||||
std::vector<size_t> lod0(1, 0);
|
||||
multi_fpn_rois_lod0.push_back(lod0);
|
||||
// statistic start point for each level rois
|
||||
num_rois_level_integral[i + 1] =
|
||||
num_rois_level_integral[i] + num_rois_level[i];
|
||||
}
|
||||
restore_index->mutable_data<int>({1, fpn_rois_num}, context.GetPlace());
|
||||
int* restore_index_data = restore_index->data<int>();
|
||||
std::vector<int> restore_index_inter(fpn_rois_num, -1);
|
||||
// distribute the rois into different fpn level by target level
|
||||
for (int i = 0; i < fpn_rois_lod.size() - 1; ++i) {
|
||||
Tensor fpn_rois_slice =
|
||||
fpn_rois->Slice(fpn_rois_lod[i], fpn_rois_lod[i + 1]);
|
||||
const T* rois_data = fpn_rois_slice.data<T>();
|
||||
size_t cur_offset = fpn_rois_lod[i];
|
||||
// std::vector<size_t > lod_offset[num_level];
|
||||
for (int j = 0; j < num_level; j++) {
|
||||
multi_fpn_rois_lod0[j].push_back(multi_fpn_rois_lod0[j][i]);
|
||||
}
|
||||
for (int j = 0; j < fpn_rois_slice.dims()[0]; ++j) {
|
||||
int lvl = target_level[cur_offset + j];
|
||||
memcpy(multi_fpn_rois_data[lvl - min_level], rois_data,
|
||||
kBoxDim * sizeof(T));
|
||||
multi_fpn_rois_data[lvl - min_level] += kBoxDim;
|
||||
int index_in_shuffle = num_rois_level_integral[lvl - min_level] +
|
||||
multi_fpn_rois_lod0[lvl - min_level][i + 1];
|
||||
restore_index_inter[index_in_shuffle] = cur_offset + j;
|
||||
multi_fpn_rois_lod0[lvl - min_level][i + 1]++;
|
||||
rois_data += kBoxDim;
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < fpn_rois_num; ++i) {
|
||||
restore_index_data[restore_index_inter[i]] = i;
|
||||
}
|
||||
// merge lod information into LoDTensor
|
||||
for (int i = 0; i < num_level; ++i) {
|
||||
framework::LoD lod;
|
||||
lod.emplace_back(multi_fpn_rois_lod0[i]);
|
||||
multi_fpn_rois[i]->set_lod(lod);
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,117 @@
|
||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import numpy as np
|
||||
import math
|
||||
import sys
|
||||
from op_test import OpTest
|
||||
|
||||
|
||||
class TestDistributeFPNProposalsOp(OpTest):
    """Checks distribute_fpn_proposals against a NumPy reference.

    Random RoIs are generated for a small batch, distributed to FPN levels
    with the reference implementation below, and compared with the operator
    outputs (per-level RoIs with their LoD, plus the restore index).
    """

    def set_data(self):
        # Build inputs, attributes and expected outputs for OpTest.
        self.init_test_case()
        self.make_rois()
        self.rois_fpn, self.rois_idx_restore = self.calc_rois_distribute()
        # Column 0 of self.rois is the image index; the op only gets the box.
        self.inputs = {'FpnRois': (self.rois[:, 1:5], self.rois_lod)}
        self.attrs = {
            'max_level': self.roi_max_level,
            'min_level': self.roi_min_level,
            'refer_scale': self.canonical_scale,
            'refer_level': self.canonical_level
        }
        output = [('out%d' % i, self.rois_fpn[i])
                  for i in range(len(self.rois_fpn))]
        self.outputs = {
            'MultiFpnRois': output,
            'RestoreIndex': self.rois_idx_restore
        }

    def init_test_case(self):
        self.roi_max_level = 5
        self.roi_min_level = 2
        self.canonical_scale = 224
        self.canonical_level = 4
        self.images_shape = [512, 512]

    def boxes_area(self, boxes):
        # Area of [x1, y1, x2, y2] boxes, with one added to width and height.
        w = (boxes[:, 2] - boxes[:, 0] + 1)
        h = (boxes[:, 3] - boxes[:, 1] + 1)
        areas = w * h
        assert np.all(areas >= 0), 'Negative areas founds'
        return areas

    def map_rois_to_fpn_levels(self, rois, lvl_min, lvl_max):
        # level = floor(lvl0 + log2(sqrt(area) / s0 + 1e-6)), clipped to
        # [lvl_min, lvl_max].
        s = np.sqrt(self.boxes_area(rois))
        s0 = self.canonical_scale
        lvl0 = self.canonical_level
        target_lvls = np.floor(lvl0 + np.log2(s / s0 + 1e-6))
        target_lvls = np.clip(target_lvls, lvl_min, lvl_max)
        return target_lvls

    def get_sub_lod(self, sub_lvl):
        # Number of RoIs per image for one level. `sub_lvl` holds the image
        # index of every RoI of that level. One entry is produced for every
        # image of the batch, so a level that has no RoI from the last
        # image(s) still gets a full-length LoD.
        num_images = len(self.rois_lod[0])
        return [
            int(np.count_nonzero(sub_lvl == i)) for i in range(num_images)
        ]

    def add_multilevel_roi(self, rois, target_lvls, lvl_min, lvl_max):
        # Split `rois` by level and compute the index that restores the
        # original order from the level-by-level concatenation.
        rois_idx_order = np.empty((0, ))
        rois_fpn = []
        num_images = len(self.rois_lod[0])
        for lvl in range(lvl_min, lvl_max + 1):
            idx_lvl = np.where(target_lvls == lvl)[0]
            if len(idx_lvl) == 0:
                # Empty level: zero RoIs for every image of the batch.
                rois_fpn.append((np.empty(shape=(0, 4)), [[0] * num_images]))
                continue
            sub_lod = self.get_sub_lod(rois[idx_lvl, 0])
            rois_fpn.append((rois[idx_lvl, 1:], [sub_lod]))
            rois_idx_order = np.concatenate((rois_idx_order, idx_lvl))
        rois_idx_restore = np.argsort(rois_idx_order).astype(
            np.int32, copy=False)
        return rois_fpn, rois_idx_restore

    def calc_rois_distribute(self):
        lvl_min = self.roi_min_level
        lvl_max = self.roi_max_level
        target_lvls = self.map_rois_to_fpn_levels(self.rois[:, 1:5], lvl_min,
                                                  lvl_max)
        rois_fpn, rois_idx_restore = self.add_multilevel_roi(
            self.rois, target_lvls, lvl_min, lvl_max)
        return rois_fpn, rois_idx_restore

    def make_rois(self):
        # Rows are [image_index, x1, y1, x2, y2]; the length-based LoD gives
        # the number of RoIs per image.
        self.rois_lod = [[100, 200]]
        rois = []
        lod = self.rois_lod[0]
        bno = 0
        for roi_num in lod:
            for i in range(roi_num):
                xywh = np.random.rand(4)
                xy1 = xywh[0:2] * 20
                wh = xywh[2:4] * (self.images_shape - xy1)
                xy2 = xy1 + wh
                roi = [bno, xy1[0], xy1[1], xy2[0], xy2[1]]
                rois.append(roi)
            bno += 1
        self.rois = np.array(rois).astype("float32")

    def setUp(self):
        self.op_type = "distribute_fpn_proposals"
        self.set_data()

    def test_check_output(self):
        self.check_output()
|
Loading…
Reference in new issue