merge empty lod tensor, test=develop (#19228)

* merge_empty_lod_tensor, test=develop

* fix multiclass_nms, test=develop

* refine API.spec, test=develop

* add unittest case for fetch, test=develop

* add lod tensor test, test=develop

* return index for multiclass_nms, test=develop

* add api for multiclass_nms2

* update API.spec, test=develop

* refine api doc, test=develop

* fix test_detection.py, test=develop

* polish code, test=develop

* add more unittest case, test=develop
Branch: new_fix
Author: wangguanzhong (committed by GitHub, 6 years ago)
parent c6756ed225 · commit 25dcd74d34

@ -396,7 +396,7 @@ paddle.fluid.layers.density_prior_box (ArgSpec(args=['input', 'image', 'densitie
paddle.fluid.layers.multi_box_head (ArgSpec(args=['inputs', 'image', 'base_size', 'num_classes', 'aspect_ratios', 'min_ratio', 'max_ratio', 'min_sizes', 'max_sizes', 'steps', 'step_w', 'step_h', 'offset', 'variance', 'flip', 'clip', 'kernel_size', 'pad', 'stride', 'name', 'min_max_aspect_ratios_order'], varargs=None, keywords=None, defaults=(None, None, None, None, None, None, None, 0.5, [0.1, 0.1, 0.2, 0.2], True, False, 1, 0, 1, None, False)), ('document', 'fd58078fdfffd899b91f992ba224628f'))
paddle.fluid.layers.bipartite_match (ArgSpec(args=['dist_matrix', 'match_type', 'dist_threshold', 'name'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', '080ce0d54d3f1950ad5a3a8e5ae529e9'))
paddle.fluid.layers.target_assign (ArgSpec(args=['input', 'matched_indices', 'negative_indices', 'mismatch_value', 'name'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', 'e9685f32d21bec8c013626c0254502c5'))
paddle.fluid.layers.detection_output (ArgSpec(args=['loc', 'scores', 'prior_box', 'prior_box_var', 'background_label', 'nms_threshold', 'nms_top_k', 'keep_top_k', 'score_threshold', 'nms_eta'], varargs=None, keywords=None, defaults=(0, 0.3, 400, 200, 0.01, 1.0)), ('document', 'efae414c1137c7944d6174dd08c5347a'))
paddle.fluid.layers.detection_output (ArgSpec(args=['loc', 'scores', 'prior_box', 'prior_box_var', 'background_label', 'nms_threshold', 'nms_top_k', 'keep_top_k', 'score_threshold', 'nms_eta', 'return_index'], varargs=None, keywords=None, defaults=(0, 0.3, 400, 200, 0.01, 1.0, False)), ('document', '5485bcaceb0cde2695565a2ffd5bbd40'))
paddle.fluid.layers.ssd_loss (ArgSpec(args=['location', 'confidence', 'gt_box', 'gt_label', 'prior_box', 'prior_box_var', 'background_label', 'overlap_threshold', 'neg_pos_ratio', 'neg_overlap', 'loc_loss_weight', 'conf_loss_weight', 'match_type', 'mining_type', 'normalize', 'sample_size'], varargs=None, keywords=None, defaults=(None, 0, 0.5, 3.0, 0.5, 1.0, 1.0, 'per_prediction', 'max_negative', True, None)), ('document', '8edacd4b9bd02dd68931b9fa6bfe0cbd'))
paddle.fluid.layers.rpn_target_assign (ArgSpec(args=['bbox_pred', 'cls_logits', 'anchor_box', 'anchor_var', 'gt_boxes', 'is_crowd', 'im_info', 'rpn_batch_size_per_im', 'rpn_straddle_thresh', 'rpn_fg_fraction', 'rpn_positive_overlap', 'rpn_negative_overlap', 'use_random'], varargs=None, keywords=None, defaults=(256, 0.0, 0.5, 0.7, 0.3, True)), ('document', '651d98d51879dfa1bc1cd40391786a41'))
paddle.fluid.layers.retinanet_target_assign (ArgSpec(args=['bbox_pred', 'cls_logits', 'anchor_box', 'anchor_var', 'gt_boxes', 'gt_labels', 'is_crowd', 'im_info', 'num_classes', 'positive_overlap', 'negative_overlap'], varargs=None, keywords=None, defaults=(1, 0.5, 0.4)), ('document', 'fa1d1c9d5e0111684c0db705f86a2595'))
@ -412,7 +412,8 @@ paddle.fluid.layers.polygon_box_transform (ArgSpec(args=['input', 'name'], varar
paddle.fluid.layers.yolov3_loss (ArgSpec(args=['x', 'gt_box', 'gt_label', 'anchors', 'anchor_mask', 'class_num', 'ignore_thresh', 'downsample_ratio', 'gt_score', 'use_label_smooth', 'name'], varargs=None, keywords=None, defaults=(None, True, None)), ('document', '400403175718d5a632402cdae88b01b8'))
paddle.fluid.layers.yolo_box (ArgSpec(args=['x', 'img_size', 'anchors', 'class_num', 'conf_thresh', 'downsample_ratio', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', 'ed56ff21536ca5c8ad418d0cfaf6a7b9'))
paddle.fluid.layers.box_clip (ArgSpec(args=['input', 'im_info', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '9ddee76cb808db83768bf68010e39b2b'))
paddle.fluid.layers.multiclass_nms (ArgSpec(args=['bboxes', 'scores', 'score_threshold', 'nms_top_k', 'keep_top_k', 'nms_threshold', 'normalized', 'nms_eta', 'background_label', 'name'], varargs=None, keywords=None, defaults=(0.3, True, 1.0, 0, None)), ('document', '51a388c4d067ea93a6a60492db40c7af'))
paddle.fluid.layers.multiclass_nms (ArgSpec(args=['bboxes', 'scores', 'score_threshold', 'nms_top_k', 'keep_top_k', 'nms_threshold', 'normalized', 'nms_eta', 'background_label', 'name'], varargs=None, keywords=None, defaults=(0.3, True, 1.0, 0, None)), ('document', 'f6e333d76922c6e564413b4d216c245c'))
paddle.fluid.layers.multiclass_nms2 (ArgSpec(args=['bboxes', 'scores', 'score_threshold', 'nms_top_k', 'keep_top_k', 'nms_threshold', 'normalized', 'nms_eta', 'background_label', 'return_index', 'name'], varargs=None, keywords=None, defaults=(0.3, True, 1.0, 0, False, None)), ('document', 'be156186ee7a2ee56ab30b964acb15e5'))
paddle.fluid.layers.retinanet_detection_output (ArgSpec(args=['bboxes', 'scores', 'anchors', 'im_info', 'score_threshold', 'nms_top_k', 'keep_top_k', 'nms_threshold', 'nms_eta'], varargs=None, keywords=None, defaults=(0.05, 1000, 100, 0.3, 1.0)), ('document', '078d28607ce261a0cba2b965a79f6bb8'))
paddle.fluid.layers.distribute_fpn_proposals (ArgSpec(args=['fpn_rois', 'min_level', 'max_level', 'refer_level', 'refer_scale', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '6c023b9401214ae387a8b2d92638e5e4'))
paddle.fluid.layers.box_decoder_and_assign (ArgSpec(args=['prior_box', 'prior_box_var', 'target_box', 'box_score', 'box_clip', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '3619a7847709f5868f5e929065947b38'))
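
For orientation, a minimal Python-side sketch of how the two updated entries are called with the new return_index argument (shapes and variable names are illustrative and mirror the unit tests further down):

    import paddle.fluid as fluid
    from paddle.fluid import layers

    with fluid.program_guard(fluid.Program()):
        pb = layers.data(name='prior_box', shape=[10, 4],
                         append_batch_size=False, dtype='float32')
        pbv = layers.data(name='prior_box_var', shape=[10, 4],
                          append_batch_size=False, dtype='float32')
        loc = layers.data(name='target_box', shape=[2, 10, 4],
                          append_batch_size=False, dtype='float32')
        scores = layers.data(name='scores', shape=[2, 10, 20],
                             append_batch_size=False, dtype='float32')
        # detection_output can now also return the index of each kept box.
        out, index = layers.detection_output(
            scores=scores, loc=loc, prior_box=pb, prior_box_var=pbv,
            return_index=True)

        bboxes = layers.data(name='bboxes', shape=[-1, 10, 4], dtype='float32')
        cls_scores = layers.data(name='cls_scores', shape=[-1, 10],
                                 dtype='float32')
        # multiclass_nms2 matches multiclass_nms, plus an optional Index output.
        nmsed, nms_index = layers.multiclass_nms2(
            bboxes, cls_scores, score_threshold=0.3, nms_top_k=400,
            keep_top_k=200, nms_threshold=0.7, return_index=True)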

@ -61,12 +61,17 @@ void FetchOpHandle::RunImpl() {
var_handle->name());
auto &t = var->Get<framework::LoDTensor>();
if (platform::is_gpu_place(t.place())) {
if (t.IsInitialized() && t.numel() > 0) {
if (platform::is_gpu_place(t.place())) {
#ifdef PADDLE_WITH_CUDA
TensorCopy(t, cpu, &tensors_[i]);
TensorCopy(t, cpu, &tensors_[i]);
#endif
} else {
tensors_[i].ShareDataWith(t);
}
} else {
tensors_[i].ShareDataWith(t);
tensors_[i].clear();
tensors_[i].Resize({0});
}
tensors_[i].set_lod(t.lod());
}
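
With this change, a fetched LoDTensor that is uninitialized or has zero elements is no longer shared or copied; it is cleared and resized to {0}, so the fetch result comes back as an empty array. This is essentially the scenario exercised by TestFetchNullVar further down; a condensed sketch:

    import numpy
    import paddle.fluid as fluid
    from paddle.fluid import layers

    val = numpy.array([]).astype(numpy.int32)      # zero-element input
    x = layers.create_tensor(dtype="int32", persistable=True, name="x")
    layers.assign(input=val, output=x)
    exe = fluid.Executor(fluid.CPUPlace())
    exe.run(fluid.default_main_program(), feed={}, fetch_list=[])
    fetched_x = fluid.executor._fetch_var("x")     # empty array, same dtype
    assert fetched_x.size == 0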

@ -326,17 +326,28 @@ void LoDTensor::MergeLoDTensor(
PADDLE_ENFORCE(!lod_tensors.empty());
framework::DDim new_dim = lod_tensors[0]->dims();
auto new_type = lod_tensors[0]->type();
proto::VarType::Type new_type = proto::VarType::FP32;
framework::DataLayout new_layout = lod_tensors[0]->layout();
for (auto *t : lod_tensors) {
if (t->numel() && t->IsInitialized()) {
new_dim = t->dims();
new_type = t->type();
new_layout = t->layout();
break;
}
}
LoD new_lod = lod_tensors[0]->lod();
for (size_t i = 1; i < lod_tensors.size(); ++i) {
auto *t = lod_tensors[i];
PADDLE_ENFORCE_EQ(new_type, t->type());
PADDLE_ENFORCE_EQ(new_layout, t->layout());
PADDLE_ENFORCE_EQ(framework::product(new_dim) / new_dim[0],
framework::product(t->dims()) / t->dims()[0]);
new_dim[0] += t->dims()[0];
if (t->numel() && t->IsInitialized()) {
PADDLE_ENFORCE_EQ(new_type, t->type());
PADDLE_ENFORCE_EQ(new_layout, t->layout());
PADDLE_ENFORCE_EQ(framework::product(new_dim) / new_dim[0],
framework::product(t->dims()) / t->dims()[0]);
new_dim[0] += t->dims()[0];
}
auto &lod = t->lod();
PADDLE_ENFORCE_EQ(new_lod.size(), lod.size());
@ -356,6 +367,9 @@ void LoDTensor::MergeLoDTensor(
int begin = 0;
for (auto *src : lod_tensors) {
int end = begin + src->dims()[0];
if (end == begin) {
continue;
}
auto dst = Slice(begin, end);
framework::TensorCopy(*src, dst_place, &dst);
begin = end;
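
The merge therefore tolerates empty or uninitialized members: the type, layout, and shape checks and the accumulated row count only consider non-empty tensors, and empty tensors contribute no rows when copying. A rough numpy analogy of this rule (not the C++ API):

    import numpy as np

    pieces = [
        np.arange(8, dtype=np.float32).reshape(2, 4),
        np.empty((0, 4), dtype=np.float32),   # an empty piece is simply skipped
        np.arange(12, dtype=np.float32).reshape(3, 4),
    ]
    merged = np.concatenate([p for p in pieces if p.size > 0], axis=0)
    print(merged.shape)  # (5, 4)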

@ -185,7 +185,15 @@ TEST(LoD, MergeLoDTensor) {
dst_ptr[i] = i;
}
std::vector<const LoDTensor*> lods{&lod_tensor0, &lod_tensor1};
LoDTensor lod_tensor2;
LoD lod2;
lod2.push_back(std::vector<size_t>({0}));
lod2.push_back(std::vector<size_t>({0}));
lod_tensor2.set_lod(lod2);
lod_tensor2.Resize({0});
dst_ptr = lod_tensor2.mutable_data<float>(place);
std::vector<const LoDTensor*> lods{&lod_tensor0, &lod_tensor1, &lod_tensor2};
LoDTensor lod_tensor;
lod_tensor.MergeLoDTensor(lods, place);

@ -328,7 +328,8 @@ class MultiClassNMSKernel : public framework::OpKernel<T> {
void MultiClassOutput(const platform::DeviceContext& ctx,
const Tensor& scores, const Tensor& bboxes,
const std::map<int, std::vector<int>>& selected_indices,
const int scores_size, Tensor* outs) const {
const int scores_size, Tensor* outs,
int* oindices = nullptr, const int offset = 0) const {
int64_t class_num = scores.dims()[1];
int64_t predict_dim = scores.dims()[1];
int64_t box_size = bboxes.dims()[1];
@ -358,9 +359,15 @@ class MultiClassNMSKernel : public framework::OpKernel<T> {
if (scores_size == 3) {
bdata = bboxes_data + idx * box_size;
odata[count * out_dim + 1] = sdata[idx]; // score
if (oindices != nullptr) {
oindices[count] = offset + idx;
}
} else {
bdata = bbox.data<T>() + idx * box_size;
odata[count * out_dim + 1] = *(scores_data + idx * class_num + label);
if (oindices != nullptr) {
oindices[count] = offset + idx * class_num + label;
}
}
// xmin, ymin, xmax, ymax or multi-points coordinates
std::memcpy(odata + count * out_dim + 2, bdata, box_size * sizeof(T));
@ -373,7 +380,8 @@ class MultiClassNMSKernel : public framework::OpKernel<T> {
auto* boxes = ctx.Input<LoDTensor>("BBoxes");
auto* scores = ctx.Input<LoDTensor>("Scores");
auto* outs = ctx.Output<LoDTensor>("Out");
bool return_index = ctx.HasOutput("Index") ? true : false;
auto index = ctx.Output<LoDTensor>("Index");
auto score_dims = scores->dims();
auto score_size = score_dims.size();
auto& dev_ctx = ctx.template device_context<platform::CPUDeviceContext>();
@ -406,35 +414,55 @@ class MultiClassNMSKernel : public framework::OpKernel<T> {
int num_kept = batch_starts.back();
if (num_kept == 0) {
T* od = outs->mutable_data<T>({1, 1}, ctx.GetPlace());
od[0] = -1;
batch_starts = {0, 1};
if (return_index) {
outs->mutable_data<T>({0, out_dim}, ctx.GetPlace());
index->mutable_data<int>({0, 1}, ctx.GetPlace());
} else {
T* od = outs->mutable_data<T>({1, 1}, ctx.GetPlace());
od[0] = -1;
batch_starts = {0, 1};
}
} else {
outs->mutable_data<T>({num_kept, out_dim}, ctx.GetPlace());
int offset = 0;
int* oindices = nullptr;
for (int i = 0; i < n; ++i) {
if (score_size == 3) {
scores_slice = scores->Slice(i, i + 1);
boxes_slice = boxes->Slice(i, i + 1);
scores_slice.Resize({score_dims[1], score_dims[2]});
boxes_slice.Resize({score_dims[2], box_dim});
if (return_index) {
offset = i * score_dims[2];
}
} else {
auto boxes_lod = boxes->lod().back();
scores_slice = scores->Slice(boxes_lod[i], boxes_lod[i + 1]);
boxes_slice = boxes->Slice(boxes_lod[i], boxes_lod[i + 1]);
if (return_index) {
offset = boxes_lod[i] * score_dims[1];
}
}
int64_t s = batch_starts[i];
int64_t e = batch_starts[i + 1];
if (e > s) {
Tensor out = outs->Slice(s, e);
if (return_index) {
int* output_idx =
index->mutable_data<int>({num_kept, 1}, ctx.GetPlace());
oindices = output_idx + s;
}
MultiClassOutput(dev_ctx, scores_slice, boxes_slice, all_indices[i],
score_dims.size(), &out);
score_dims.size(), &out, oindices, offset);
}
}
}
framework::LoD lod;
lod.emplace_back(batch_starts);
if (return_index) {
index->set_lod(lod);
}
outs->set_lod(lod);
}
};
@ -519,13 +547,45 @@ This operator support multi-class and batched inputs. It applying NMS
independently for each class. The output is a 2-D LoDTensor; for each
image, the offsets in the first dimension of the LoDTensor are called LoD, and the
number of offsets is N + 1, where N is the batch size. If LoD[i + 1] - LoD[i] == 0,
means there is no detected bbox for this image. If there is no detected boxes
for all images, all the elements in LoD are set to {1}, and the Out only
contains one value which is -1.
means there is no detected bbox for this image.
)DOC");
}
};
class MultiClassNMS2Op : public MultiClassNMSOp {
public:
MultiClassNMS2Op(const std::string& type,
const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: MultiClassNMSOp(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext* ctx) const override {
MultiClassNMSOp::InferShape(ctx);
auto box_dims = ctx->GetInputDim("BBoxes");
auto score_dims = ctx->GetInputDim("Scores");
auto score_size = score_dims.size();
if (score_size == 3) {
ctx->SetOutputDim("Index", {box_dims[1], 1});
} else {
ctx->SetOutputDim("Index", {-1, 1});
}
}
};
class MultiClassNMS2OpMaker : public MultiClassNMSOpMaker {
public:
void Make() override {
MultiClassNMSOpMaker::Make();
AddOutput("Index",
"(LoDTensor) A 2-D LoDTensor with shape [No, 1] represents the "
"index of selected bbox. The index is the absolute index cross "
"batches.")
.AsIntermediate();
}
};
} // namespace operators
} // namespace paddle
@ -535,3 +595,8 @@ REGISTER_OPERATOR(multiclass_nms, ops::MultiClassNMSOp,
paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(multiclass_nms, ops::MultiClassNMSKernel<float>,
ops::MultiClassNMSKernel<double>);
REGISTER_OPERATOR(multiclass_nms2, ops::MultiClassNMS2Op,
ops::MultiClassNMS2OpMaker,
paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(multiclass_nms2, ops::MultiClassNMSKernel<float>,
ops::MultiClassNMSKernel<double>);
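
The new Index output holds, for every kept detection, an absolute index into the flattened input, as computed in MultiClassOutput above: i * num_boxes + idx for 3-D (batched) inputs and (lod_offset + idx) * class_num + label for 2-D (LoD) inputs. A small decoding sketch (the helper names are illustrative, not part of the operator):

    # num_boxes is Scores.dims()[2] for the 3-D input case;
    # class_num is Scores.dims()[1] for the 2-D (LoD) input case.
    def decode_batched_index(index, num_boxes):
        return index // num_boxes, index % num_boxes    # (image_id, box_id)

    def decode_lod_index(index, class_num):
        return index // class_num, index % class_num    # (global box row, class_id)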

[File diff suppressed because it is too large]

@ -47,7 +47,15 @@ class TestDetection(unittest.TestCase):
dtype='float32')
out = layers.detection_output(
scores=scores, loc=loc, prior_box=pb, prior_box_var=pbv)
out2, index = layers.detection_output(
scores=scores,
loc=loc,
prior_box=pb,
prior_box_var=pbv,
return_index=True)
self.assertIsNotNone(out)
self.assertIsNotNone(out2)
self.assertIsNotNone(index)
self.assertEqual(out.shape[-1], 6)
print(str(program))
@ -523,6 +531,21 @@ class TestMulticlassNMS(unittest.TestCase):
self.assertIsNotNone(output)
class TestMulticlassNMS2(unittest.TestCase):
def test_multiclass_nms2(self):
program = Program()
with program_guard(program):
bboxes = layers.data(
name='bboxes', shape=[-1, 10, 4], dtype='float32')
scores = layers.data(name='scores', shape=[-1, 10], dtype='float32')
output = layers.multiclass_nms2(bboxes, scores, 0.3, 400, 200, 0.7)
output2, index = layers.multiclass_nms2(
bboxes, scores, 0.3, 400, 200, 0.7, return_index=True)
self.assertIsNotNone(output)
self.assertIsNotNone(output2)
self.assertIsNotNone(index)
class TestCollectFpnPropsals(unittest.TestCase):
def test_collect_fpn_proposals(self):
program = Program()

@ -22,17 +22,25 @@ import unittest
class TestFetchVar(op_test.OpTest):
def set_input(self):
self.val = numpy.array([1, 3, 5]).astype(numpy.int32)
def test_fetch_var(self):
val = numpy.array([1, 3, 5]).astype(numpy.int32)
self.set_input()
x = layers.create_tensor(dtype="int32", persistable=True, name="x")
layers.assign(input=val, output=x)
layers.assign(input=self.val, output=x)
exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_main_program(), feed={}, fetch_list=[])
fetched_x = fluid.executor._fetch_var("x")
self.assertTrue(
numpy.array_equal(fetched_x, val),
"fetch_x=%s val=%s" % (fetched_x, val))
self.assertEqual(fetched_x.dtype, val.dtype)
numpy.array_equal(fetched_x, self.val),
"fetch_x=%s val=%s" % (fetched_x, self.val))
self.assertEqual(fetched_x.dtype, self.val.dtype)
class TestFetchNullVar(TestFetchVar):
def set_input(self):
self.val = numpy.array([]).astype(numpy.int32)
if __name__ == '__main__':

@ -156,12 +156,14 @@ def multiclass_nms(boxes, scores, background, score_threshold, nms_threshold,
def lod_multiclass_nms(boxes, scores, background, score_threshold,
nms_threshold, nms_top_k, keep_top_k, box_lod,
normalized):
num_class = boxes.shape[1]
det_outs = []
lod = []
head = 0
for n in range(len(box_lod[0])):
box = boxes[head:head + box_lod[0][n]]
score = scores[head:head + box_lod[0][n]]
offset = head
head = head + box_lod[0][n]
nmsed_outs, nmsed_num = multiclass_nms(
box,
@ -173,19 +175,21 @@ def lod_multiclass_nms(boxes, scores, background, score_threshold,
keep_top_k,
normalized,
shared=False)
lod.append(nmsed_num)
if nmsed_num == 0:
continue
lod.append(nmsed_num)
tmp_det_out = []
for c, indices in nmsed_outs.items():
for idx in indices:
xmin, ymin, xmax, ymax = box[idx, c, :]
tmp_det_out.append([c, score[idx][c], xmin, ymin, xmax, ymax])
tmp_det_out.append([
c, score[idx][c], xmin, ymin, xmax, ymax,
offset * num_class + idx * num_class + c
])
sorted_det_out = sorted(
tmp_det_out, key=lambda tup: tup[0], reverse=False)
det_outs.extend(sorted_det_out)
if len(lod) == 0:
lod.append(1)
return det_outs, lod
@ -199,8 +203,9 @@ def batched_multiclass_nms(boxes,
keep_top_k,
normalized=True):
batch_size = scores.shape[0]
num_boxes = scores.shape[2]
det_outs = []
index_outs = []
lod = []
for n in range(batch_size):
nmsed_outs, nmsed_num = multiclass_nms(
@ -213,21 +218,21 @@ def batched_multiclass_nms(boxes,
keep_top_k,
normalized,
shared=True)
lod.append(nmsed_num)
if nmsed_num == 0:
continue
lod.append(nmsed_num)
tmp_det_out = []
for c, indices in nmsed_outs.items():
for idx in indices:
xmin, ymin, xmax, ymax = boxes[n][idx][:]
tmp_det_out.append(
[c, scores[n][c][idx], xmin, ymin, xmax, ymax])
tmp_det_out.append([
c, scores[n][c][idx], xmin, ymin, xmax, ymax,
idx + n * num_boxes
])
sorted_det_out = sorted(
tmp_det_out, key=lambda tup: tup[0], reverse=False)
det_outs.extend(sorted_det_out)
if len(lod) == 0:
lod += [1]
return det_outs, lod
@ -262,11 +267,13 @@ class TestMulticlassNMSOp(OpTest):
boxes[:, :, 0:2] = boxes[:, :, 0:2] * 0.5
boxes[:, :, 2:4] = boxes[:, :, 2:4] * 0.5 + 0.5
nmsed_outs, lod = batched_multiclass_nms(boxes, scores, background,
score_threshold, nms_threshold,
nms_top_k, keep_top_k)
nmsed_outs = [-1] if not nmsed_outs else nmsed_outs
nmsed_outs = np.array(nmsed_outs).astype('float32')
det_outs, lod = batched_multiclass_nms(boxes, scores, background,
score_threshold, nms_threshold,
nms_top_k, keep_top_k)
lod = [1] if not det_outs else lod
det_outs = [[-1, 0]] if not det_outs else det_outs
det_outs = np.array(det_outs)
nmsed_outs = det_outs[:, :-1].astype('float32')
self.op_type = 'multiclass_nms'
self.inputs = {'BBoxes': boxes, 'Scores': scores}
@ -324,11 +331,12 @@ class TestMulticlassNMSLoDInput(OpTest):
boxes[:, :, 2] = boxes[:, :, 2] * 10 + 10
boxes[:, :, 3] = boxes[:, :, 3] * 10 + 10
nmsed_outs, lod = lod_multiclass_nms(
det_outs, lod = lod_multiclass_nms(
boxes, scores, background, score_threshold, nms_threshold,
nms_top_k, keep_top_k, box_lod, normalized)
nmsed_outs = [-1] if not nmsed_outs else nmsed_outs
nmsed_outs = np.array(nmsed_outs).astype('float32')
det_outs = np.array(det_outs).astype('float32')
nmsed_outs = det_outs[:, :-1].astype('float32') if len(
det_outs) else det_outs
self.op_type = 'multiclass_nms'
self.inputs = {
'BBoxes': (boxes, box_lod),
@ -359,5 +367,137 @@ class TestIOU(unittest.TestCase):
self.assertTrue(np.allclose(calc_output, expt_output))
class TestMulticlassNMS2Op(TestMulticlassNMSOp):
def setUp(self):
self.set_argument()
N = 7
M = 1200
C = 21
BOX_SIZE = 4
background = 0
nms_threshold = 0.3
nms_top_k = 400
keep_top_k = 200
score_threshold = self.score_threshold
scores = np.random.random((N * M, C)).astype('float32')
def softmax(x):
shiftx = x - np.max(x).clip(-64.)
exps = np.exp(shiftx)
return exps / np.sum(exps)
scores = np.apply_along_axis(softmax, 1, scores)
scores = np.reshape(scores, (N, M, C))
scores = np.transpose(scores, (0, 2, 1))
boxes = np.random.random((N, M, BOX_SIZE)).astype('float32')
boxes[:, :, 0:2] = boxes[:, :, 0:2] * 0.5
boxes[:, :, 2:4] = boxes[:, :, 2:4] * 0.5 + 0.5
det_outs, lod = batched_multiclass_nms(boxes, scores, background,
score_threshold, nms_threshold,
nms_top_k, keep_top_k)
det_outs = np.array(det_outs)
nmsed_outs = det_outs[:, :-1].astype('float32') if len(
det_outs) else det_outs
index_outs = det_outs[:, -1:].astype('int') if len(
det_outs) else det_outs
self.op_type = 'multiclass_nms2'
self.inputs = {'BBoxes': boxes, 'Scores': scores}
self.outputs = {
'Out': (nmsed_outs, [lod]),
'Index': (index_outs, [lod])
}
self.attrs = {
'background_label': 0,
'nms_threshold': nms_threshold,
'nms_top_k': nms_top_k,
'keep_top_k': keep_top_k,
'score_threshold': score_threshold,
'nms_eta': 1.0,
'normalized': True,
}
def test_check_output(self):
self.check_output()
class TestMulticlassNMS2OpNoOutput(TestMulticlassNMS2Op):
def set_argument(self):
# Here set 2.0 to test the case there is no outputs.
# In practical use, 0.0 < score_threshold < 1.0
self.score_threshold = 2.0
class TestMulticlassNMS2LoDInput(TestMulticlassNMSLoDInput):
def setUp(self):
self.set_argument()
M = 1200
C = 21
BOX_SIZE = 4
box_lod = [[1200]]
background = 0
nms_threshold = 0.3
nms_top_k = 400
keep_top_k = 200
score_threshold = self.score_threshold
normalized = False
scores = np.random.random((M, C)).astype('float32')
def softmax(x):
shiftx = x - np.max(x).clip(-64.)
exps = np.exp(shiftx)
return exps / np.sum(exps)
scores = np.apply_along_axis(softmax, 1, scores)
boxes = np.random.random((M, C, BOX_SIZE)).astype('float32')
boxes[:, :, 0] = boxes[:, :, 0] * 10
boxes[:, :, 1] = boxes[:, :, 1] * 10
boxes[:, :, 2] = boxes[:, :, 2] * 10 + 10
boxes[:, :, 3] = boxes[:, :, 3] * 10 + 10
det_outs, lod = lod_multiclass_nms(
boxes, scores, background, score_threshold, nms_threshold,
nms_top_k, keep_top_k, box_lod, normalized)
det_outs = np.array(det_outs)
nmsed_outs = det_outs[:, :-1].astype('float32') if len(
det_outs) else det_outs
index_outs = det_outs[:, -1:].astype('int') if len(
det_outs) else det_outs
self.op_type = 'multiclass_nms2'
self.inputs = {
'BBoxes': (boxes, box_lod),
'Scores': (scores, box_lod),
}
self.outputs = {
'Out': (nmsed_outs, [lod]),
'Index': (index_outs, [lod])
}
self.attrs = {
'background_label': 0,
'nms_threshold': nms_threshold,
'nms_top_k': nms_top_k,
'keep_top_k': keep_top_k,
'score_threshold': score_threshold,
'nms_eta': 1.0,
'normalized': normalized,
}
def test_check_output(self):
self.check_output()
class TestMulticlassNMS2LoDNoOutput(TestMulticlassNMS2LoDInput):
def set_argument(self):
# Here set 2.0 to test the case there is no outputs.
# In practical use, 0.0 < score_threshold < 1.0
self.score_threshold = 2.0
if __name__ == '__main__':
unittest.main()
