Merge pull request #8193 from qingqing01/ssd_target_assign
Add target assigner operator for SSD detection.emailweixu-patch-1
commit
ae0740ce66
@ -0,0 +1,202 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/operators/target_assign_op.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
class TargetAssignOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
|
||||
void InferShape(framework::InferShapeContext* ctx) const override {
|
||||
// checkout inputs
|
||||
PADDLE_ENFORCE(ctx->HasInput("EncodedGTBBox"),
|
||||
"Input(EncodedGTBBox) of TargetAssignOp should not be null");
|
||||
PADDLE_ENFORCE(ctx->HasInput("GTScoreLabel"),
|
||||
"Input(GTScoreLabel) of TargetAssignOp should not be null");
|
||||
PADDLE_ENFORCE(ctx->HasInput("MatchIndices"),
|
||||
"Input(MatchIndices) of TargetAssignOp should not be null");
|
||||
PADDLE_ENFORCE(ctx->HasInput("NegIndices"),
|
||||
"Input(NegIndices) of TargetAssignOp should not be null");
|
||||
|
||||
// checkout outputs
|
||||
PADDLE_ENFORCE(
|
||||
ctx->HasOutput("PredBBoxLabel"),
|
||||
"Output(PredBBoxLabel) of TargetAssignOp should not be null.");
|
||||
PADDLE_ENFORCE(
|
||||
ctx->HasOutput("PredBBoxWeight"),
|
||||
"Output(PredBBoxWeight) of TargetAssignOp should not be null.");
|
||||
PADDLE_ENFORCE(
|
||||
ctx->HasOutput("PredScoreLabel"),
|
||||
"Output(PredScoreLabel) of TargetAssignOp should not be null.");
|
||||
PADDLE_ENFORCE(
|
||||
ctx->HasOutput("PredScoreWeight"),
|
||||
"Output(PredScoreWeight) of TargetAssignOp should not be null.");
|
||||
|
||||
auto blabel_dims = ctx->GetInputDim("EncodedGTBBox");
|
||||
auto slabel_dims = ctx->GetInputDim("GTScoreLabel");
|
||||
auto mi_dims = ctx->GetInputDim("MatchIndices");
|
||||
auto neg_dims = ctx->GetInputDim("NegIndices");
|
||||
|
||||
PADDLE_ENFORCE_EQ(blabel_dims.size(), 3UL,
|
||||
"The rank of Input(EncodedGTBBox) must be 3.");
|
||||
PADDLE_ENFORCE_EQ(slabel_dims.size(), 2UL,
|
||||
"The rank of Input(GTScoreLabel) must be 2.");
|
||||
PADDLE_ENFORCE_EQ(mi_dims.size(), 2UL,
|
||||
"The rank of Input(MatchIndices) must be 2.");
|
||||
PADDLE_ENFORCE_EQ(neg_dims.size(), 2UL,
|
||||
"The rank of Input(NegIndices) must be 2.");
|
||||
|
||||
PADDLE_ENFORCE_EQ(blabel_dims[0], slabel_dims[0],
|
||||
"The 1st dimension (means the total number of "
|
||||
"ground-truth bounding boxes) of Input(EncodedGTBBox) "
|
||||
"and Input(GTScoreLabel) must be the same.");
|
||||
PADDLE_ENFORCE_EQ(blabel_dims[1], mi_dims[1],
|
||||
"The 2nd dimension (means the number of priod boxes) "
|
||||
"of Input(EncodedGTBBox) and "
|
||||
"Input(MatchIndices) must be the same.");
|
||||
PADDLE_ENFORCE_EQ(blabel_dims[2], 4,
|
||||
"The 3rd dimension of Input(EncodedGTBBox) must be 4.");
|
||||
|
||||
auto n = mi_dims[0];
|
||||
auto np = mi_dims[1];
|
||||
ctx->SetOutputDim("PredBBoxLabel", {n, np, 4});
|
||||
ctx->SetOutputDim("PredBBoxWeight", {n, np, 1});
|
||||
ctx->SetOutputDim("PredScoreLabel", {n, np, 1});
|
||||
ctx->SetOutputDim("PredScoreWeight", {n, np, 1});
|
||||
}
|
||||
|
||||
protected:
|
||||
framework::OpKernelType GetExpectedKernelType(
|
||||
const framework::ExecutionContext& ctx) const override {
|
||||
return framework::OpKernelType(
|
||||
framework::ToDataType(
|
||||
ctx.Input<framework::LoDTensor>("EncodedGTBBox")->type()),
|
||||
ctx.device_context());
|
||||
}
|
||||
};
|
||||
|
||||
class TargetAssignOpMaker : public framework::OpProtoAndCheckerMaker {
|
||||
public:
|
||||
TargetAssignOpMaker(OpProto* proto, OpAttrChecker* op_checker)
|
||||
: OpProtoAndCheckerMaker(proto, op_checker) {
|
||||
AddInput("EncodedGTBBox",
|
||||
"(LoDTensor), The encoded ground-truth bounding boxes with shape "
|
||||
"[Ng, Np, 4], where Ng is the total number of ground-truth boxes "
|
||||
"in this mini-batch, Np the number of predictions, 4 is the "
|
||||
"number of coordinate in [xmin, ymin, xmax, ymax] layout.");
|
||||
AddInput("GTScoreLabel",
|
||||
"(LoDTensor, default LoDTensor<int>), The input ground-truth "
|
||||
"labels with shape [Ng, 1], where the Ng is the same as it in "
|
||||
"the input of EncodedGTBBox.");
|
||||
AddInput("MatchIndices",
|
||||
"(Tensor, default Tensor<int>), The input matched indices "
|
||||
"with shape [N, Np], where N is the batch size, Np is the same "
|
||||
"as it in the input of EncodedGTBBox. If MatchIndices[i][j] "
|
||||
"is -1, the j-th prior box is not matched to any ground-truh "
|
||||
"box in i-th instance.");
|
||||
AddInput("NegIndices",
|
||||
"(LoDTensor, default LoDTensor<int>), The input negative example "
|
||||
"indices with shape [Neg, 1], where is the total number of "
|
||||
"negative example indices.");
|
||||
AddAttr<int>("background_label",
|
||||
"(int, default 0), Label index of background class.")
|
||||
.SetDefault(0);
|
||||
AddOutput("PredBBoxLabel",
|
||||
"(Tensor), The output encoded ground-truth labels "
|
||||
"with shape [N, Np, 4], N is the batch size and Np, 4 is the "
|
||||
"same as they in input of EncodedGTBBox. If MatchIndices[i][j] "
|
||||
"is -1, the PredBBoxLabel[i][j][:] is the encoded ground-truth "
|
||||
"box for background_label in i-th instance.");
|
||||
AddOutput("PredBBoxWeight",
|
||||
"(Tensor), The weight for PredBBoxLabel with the shape "
|
||||
"of [N, Np, 1]");
|
||||
AddOutput("PredScoreLabel",
|
||||
"(Tensor, default Tensor<int>), The output score labels for "
|
||||
"each predictions with shape [N, Np, 1]. If MatchIndices[i][j] "
|
||||
"is -1, PredScoreLabel[i][j] = background_label.");
|
||||
AddOutput("PredScoreWeight",
|
||||
"(Tensor), The weight for PredScoreLabel with the shape "
|
||||
"of [N, Np, 1]");
|
||||
AddComment(R"DOC(
|
||||
This operator is, for given the encoded boxes between prior boxes and
|
||||
ground-truth boxes and ground-truth class labels, to assign classification
|
||||
and regression targets to each prior box as well as weights to each
|
||||
prior box. The weights is used to specify which prior box would not contribute
|
||||
to training loss.
|
||||
|
||||
For each instance, the output `PredBBoxLabel`, `PredBBoxWeight`,
|
||||
`PredScoreLabel` and `PredScoreWeight` are assigned based on `MatchIndices`.
|
||||
Assumed that the row offset for each instance in `EncodedGTBBox` is called lod,
|
||||
this operato assigns classification/regression targets by performing the
|
||||
following steps:
|
||||
|
||||
1. Assigning all outpts based on `MatchIndices`:
|
||||
|
||||
If id = MatchIndices[i][j] > 0,
|
||||
|
||||
PredBBoxLabel[i][j] = EncodedGTBBox[lod[i] + id][j]
|
||||
PredBBoxWeight[i][j] = 1.
|
||||
PredScoreLabel[i][j] = GTScoreLabel[lod[i] + id]
|
||||
PredScoreWeight[i][j] = 1.
|
||||
|
||||
Otherwise,
|
||||
|
||||
PredBBoxLabel[j][j] = [0., 0., 0., 0.]
|
||||
PredBBoxWeight[i][j] = 0.
|
||||
PredScoreLabel[i][j] = background_label
|
||||
PredScoreWeight[i][j] = 0.
|
||||
|
||||
2. Assigning PredScoreWeight based on `NegIndices`:
|
||||
|
||||
Assumed that the row offset for each instance in `NegIndices` is caleed neg_lod,
|
||||
for i-th instance and all ids of NegIndices in this instance:
|
||||
|
||||
PredScoreLabel[i][id] = background_label
|
||||
PredScoreWeight[i][id] = 1.0
|
||||
|
||||
)DOC");
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct NegTargetAssignFunctor<platform::CPUDeviceContext, T> {
|
||||
void operator()(const platform::CPUDeviceContext& ctx, const int* neg_indices,
|
||||
const size_t* lod, const int num, const int num_prior_box,
|
||||
const int background_label, int* out_label, T* out_label_wt) {
|
||||
for (int i = 0; i < num; ++i) {
|
||||
for (size_t j = lod[i]; j < lod[i + 1]; ++j) {
|
||||
int id = neg_indices[j];
|
||||
out_label[i * num_prior_box + id] = background_label;
|
||||
out_label_wt[i * num_prior_box + id] = static_cast<T>(1.0);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template struct NegTargetAssignFunctor<platform::CPUDeviceContext, float>;
|
||||
template struct NegTargetAssignFunctor<platform::CPUDeviceContext, double>;
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
REGISTER_OP_WITHOUT_GRADIENT(target_assign, ops::TargetAssignOp,
|
||||
ops::TargetAssignOpMaker);
|
||||
REGISTER_OP_CPU_KERNEL(
|
||||
target_assign,
|
||||
ops::TargetAssignKernel<paddle::platform::CPUDeviceContext, float>,
|
||||
ops::TargetAssignKernel<paddle::platform::CPUDeviceContext, double>);
|
@ -0,0 +1,61 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/operators/target_assign_op.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
template <typename T>
|
||||
__global__ void NegTargetAssignKernel(const int* neg_indices, const size_t* lod,
|
||||
const int num, const int num_prior_box,
|
||||
const int background_label,
|
||||
int* out_label, T* out_label_wt) {
|
||||
int bidx = blockIdx.x;
|
||||
int st = lod[bidx];
|
||||
int ed = lod[bidx + 1];
|
||||
|
||||
int row_start = bidx * num_prior_box;
|
||||
for (int i = st + threadIdx.x; i < ed; i += blockDim.x) {
|
||||
int id = row_start + neg_indices[i];
|
||||
out_label[id] = background_label;
|
||||
out_label_wt[id] = 1.;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct NegTargetAssignFunctor<platform::CUDADeviceContext, T> {
|
||||
void operator()(const platform::CUDADeviceContext& ctx,
|
||||
const int* neg_indices, const size_t* lod, const int num,
|
||||
const int num_prior_box, const int background_label,
|
||||
int* out_label, T* out_label_wt) {
|
||||
const int block_size = 256;
|
||||
const int grid_size = num;
|
||||
NegTargetAssignKernel<T><<<grid_size, block_size, 0, ctx.stream()>>>(
|
||||
neg_indices, lod, num, num_prior_box, background_label, out_label,
|
||||
out_label_wt);
|
||||
}
|
||||
};
|
||||
|
||||
template struct NegTargetAssignFunctor<platform::CUDADeviceContext, float>;
|
||||
template struct NegTargetAssignFunctor<platform::CUDADeviceContext, double>;
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
REGISTER_OP_CUDA_KERNEL(
|
||||
target_assign,
|
||||
ops::TargetAssignKernel<paddle::platform::CUDADeviceContext, float>,
|
||||
ops::TargetAssignKernel<paddle::platform::CUDADeviceContext, double>);
|
@ -0,0 +1,160 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
#include "paddle/framework/op_registry.h"
|
||||
#include "paddle/platform/assert.h"
|
||||
#include "paddle/platform/for_range.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
template <typename T>
|
||||
struct TargetAssignFunctor {
|
||||
const T* gt_box_;
|
||||
const int* gt_label_;
|
||||
const int* match_indices_;
|
||||
const size_t* lod_;
|
||||
const int background_label_;
|
||||
const int64_t num_;
|
||||
const int64_t num_prior_box_;
|
||||
|
||||
T* out_box_;
|
||||
T* out_box_wt_;
|
||||
int* out_label_;
|
||||
T* out_label_wt_;
|
||||
|
||||
TargetAssignFunctor(const T* gt_box, const int* gt_label,
|
||||
const int* match_indices, const size_t* lod,
|
||||
const int background_label, const int64_t num,
|
||||
const int64_t np, T* out_box, T* out_box_wt,
|
||||
int* out_label, T* out_label_wt)
|
||||
: gt_box_(gt_box),
|
||||
gt_label_(gt_label),
|
||||
match_indices_(match_indices),
|
||||
lod_(lod),
|
||||
background_label_(background_label),
|
||||
num_(num),
|
||||
num_prior_box_(np),
|
||||
out_box_(out_box),
|
||||
out_box_wt_(out_box_wt),
|
||||
out_label_(out_label),
|
||||
out_label_wt_(out_label_wt) {}
|
||||
|
||||
HOSTDEVICE void operator()(size_t i) const {
|
||||
int row = i / num_prior_box_;
|
||||
int col = i - row * num_prior_box_;
|
||||
|
||||
size_t row_off = lod_[row];
|
||||
int offset = row * num_prior_box_ + col;
|
||||
|
||||
int id = match_indices_[offset];
|
||||
T* obox = out_box_ + offset * 4;
|
||||
int* olabel = out_label_ + offset;
|
||||
T* obox_wt = out_box_wt_ + offset;
|
||||
T* olabel_wt = out_label_wt_ + offset;
|
||||
|
||||
if (id > -1) {
|
||||
const T* gtbox = gt_box_ + ((row_off + id) * num_prior_box_ + col) * 4;
|
||||
|
||||
obox[0] = gtbox[0];
|
||||
obox[1] = gtbox[1];
|
||||
obox[2] = gtbox[2];
|
||||
obox[3] = gtbox[3];
|
||||
|
||||
olabel[0] = gt_label_[row_off + id];
|
||||
obox_wt[0] = static_cast<T>(1.);
|
||||
olabel_wt[0] = static_cast<T>(1.);
|
||||
} else {
|
||||
obox[0] = static_cast<T>(0.);
|
||||
obox[1] = static_cast<T>(0.);
|
||||
obox[2] = static_cast<T>(0.);
|
||||
obox[3] = static_cast<T>(0.);
|
||||
|
||||
olabel[0] = background_label_;
|
||||
obox_wt[0] = static_cast<T>(0.);
|
||||
olabel_wt[0] = static_cast<T>(0.);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename DeviceContext, typename T>
|
||||
struct NegTargetAssignFunctor {
|
||||
void operator()(const platform::DeviceContext& ctx, const int* neg_indices,
|
||||
const size_t* lod, const int num, const int num_prior_box,
|
||||
const int background_label, int* out_label,
|
||||
T* out_label_wt) const;
|
||||
};
|
||||
|
||||
template <typename DeviceContext, typename T>
|
||||
class TargetAssignKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& ctx) const override {
|
||||
auto* enc_gt_box = ctx.Input<framework::LoDTensor>("EncodedGTBBox");
|
||||
auto* gt_label = ctx.Input<framework::LoDTensor>("GTScoreLabel");
|
||||
auto* match_indices = ctx.Input<framework::Tensor>("MatchIndices");
|
||||
auto* neg_indices = ctx.Input<framework::LoDTensor>("NegIndices");
|
||||
|
||||
auto* out_box = ctx.Output<framework::Tensor>("PredBBoxLabel");
|
||||
auto* out_box_wt = ctx.Output<framework::Tensor>("PredBBoxWeight");
|
||||
auto* out_label = ctx.Output<framework::Tensor>("PredScoreLabel");
|
||||
auto* out_label_wt = ctx.Output<framework::Tensor>("PredScoreWeight");
|
||||
|
||||
PADDLE_ENFORCE_EQ(enc_gt_box->lod().size(), 1UL);
|
||||
PADDLE_ENFORCE_EQ(gt_label->lod().size(), 1UL);
|
||||
PADDLE_ENFORCE_EQ(neg_indices->lod().size(), 1UL);
|
||||
|
||||
int background_label = ctx.Attr<int>("background_label");
|
||||
|
||||
const T* box_data = enc_gt_box->data<T>();
|
||||
const int* label_data = gt_label->data<int>();
|
||||
const int* match_idx_data = match_indices->data<int>();
|
||||
const int* neg_idx_data = neg_indices->data<int>();
|
||||
|
||||
T* obox_data = out_box->mutable_data<T>(ctx.GetPlace());
|
||||
T* obox_wt_data = out_box_wt->mutable_data<T>(ctx.GetPlace());
|
||||
int* olabel_data = out_label->mutable_data<int>(ctx.GetPlace());
|
||||
T* olabel_wt_data = out_label_wt->mutable_data<T>(ctx.GetPlace());
|
||||
|
||||
int64_t num = match_indices->dims()[0];
|
||||
int64_t num_prior_box = match_indices->dims()[1];
|
||||
|
||||
auto gt_lod = enc_gt_box->lod().back();
|
||||
auto gt_label_lod = gt_label->lod().back();
|
||||
auto neg_lod = neg_indices->lod().back();
|
||||
for (size_t i = 0; i < gt_lod.size(); ++i) {
|
||||
PADDLE_ENFORCE_EQ(gt_lod.data()[i], gt_label_lod.data()[i]);
|
||||
}
|
||||
|
||||
size_t* gt_lod_data = gt_lod.data(ctx.GetPlace());
|
||||
size_t* neg_lod_data = neg_lod.data(ctx.GetPlace());
|
||||
|
||||
TargetAssignFunctor<T> functor(box_data, label_data, match_idx_data,
|
||||
gt_lod_data, background_label, num,
|
||||
num_prior_box, obox_data, obox_wt_data,
|
||||
olabel_data, olabel_wt_data);
|
||||
|
||||
auto& device_ctx = ctx.template device_context<DeviceContext>();
|
||||
platform::ForRange<DeviceContext> for_range(device_ctx,
|
||||
num * num_prior_box);
|
||||
for_range(functor);
|
||||
|
||||
NegTargetAssignFunctor<DeviceContext, T> neg_trg_functor;
|
||||
neg_trg_functor(device_ctx, neg_idx_data, neg_lod_data, num, num_prior_box,
|
||||
background_label, olabel_data, olabel_wt_data);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,122 @@
|
||||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
import numpy as np
|
||||
import random
|
||||
from op_test import OpTest
|
||||
|
||||
|
||||
def gen_match_and_neg_indices(num_prior, gt_lod, neg_lod):
|
||||
if len(gt_lod) != len(neg_lod):
|
||||
raise AssertionError("The input arguments are illegal.")
|
||||
|
||||
batch_size = len(gt_lod) - 1
|
||||
|
||||
match_indices = -1 * np.ones((batch_size, num_prior)).astype('int32')
|
||||
neg_indices = np.zeros((neg_lod[-1], 1)).astype('int32')
|
||||
|
||||
for n in range(batch_size):
|
||||
gt_num = gt_lod[n + 1] - gt_lod[n]
|
||||
ids = random.sample([i for i in range(num_prior)], gt_num)
|
||||
match_indices[n, ids] = [i for i in range(gt_num)]
|
||||
|
||||
ret_ids = set([i for i in range(num_prior)]) - set(ids)
|
||||
s = neg_lod[n]
|
||||
e = neg_lod[n + 1]
|
||||
l = e - s
|
||||
neg_ids = random.sample(ret_ids, l)
|
||||
neg_indices[s:e, :] = np.array(neg_ids).astype('int32').reshape(l, 1)
|
||||
|
||||
return match_indices, neg_indices
|
||||
|
||||
|
||||
def target_assign(encoded_box, gt_label, match_indices, neg_indices, gt_lod,
|
||||
neg_lod, background_label):
|
||||
batch_size, num_prior = match_indices.shape
|
||||
|
||||
# init target bbox
|
||||
trg_box = np.zeros((batch_size, num_prior, 4)).astype('float32')
|
||||
# init weight for target bbox
|
||||
trg_box_wt = np.zeros((batch_size, num_prior, 1)).astype('float32')
|
||||
# init target label
|
||||
trg_label = np.ones((batch_size, num_prior, 1)).astype('int32')
|
||||
trg_label = trg_label * background_label
|
||||
# init weight for target label
|
||||
trg_label_wt = np.zeros((batch_size, num_prior, 1)).astype('float32')
|
||||
|
||||
for i in range(batch_size):
|
||||
cur_indices = match_indices[i]
|
||||
col_ids = np.where(cur_indices > -1)
|
||||
col_val = cur_indices[col_ids]
|
||||
|
||||
gt_start = gt_lod[i]
|
||||
# target bbox
|
||||
for v, c in zip(col_val + gt_start, col_ids[0].tolist()):
|
||||
trg_box[i][c][:] = encoded_box[v][c][:]
|
||||
|
||||
# weight for target bbox
|
||||
trg_box_wt[i][col_ids] = 1.0
|
||||
|
||||
trg_label[i][col_ids] = gt_label[col_val + gt_start]
|
||||
|
||||
trg_label_wt[i][col_ids] = 1.0
|
||||
# set target label weight to 1.0 for the negative samples
|
||||
neg_ids = neg_indices[neg_lod[i]:neg_lod[i + 1]]
|
||||
trg_label_wt[i][neg_ids] = 1.0
|
||||
|
||||
return trg_box, trg_box_wt, trg_label, trg_label_wt
|
||||
|
||||
|
||||
class TestTargetAssginOp(OpTest):
|
||||
def setUp(self):
|
||||
self.op_type = "target_assign"
|
||||
|
||||
num_prior = 120
|
||||
num_class = 21
|
||||
gt_lod = [0, 5, 11, 23]
|
||||
neg_lod = [0, 4, 7, 13]
|
||||
batch_size = len(gt_lod) - 1
|
||||
num_gt = gt_lod[-1]
|
||||
background_label = 0
|
||||
|
||||
encoded_box = np.random.random((num_gt, num_prior, 4)).astype('float32')
|
||||
gt_label = np.random.randint(
|
||||
num_class, size=(num_gt, 1)).astype('int32')
|
||||
match_indices, neg_indices = gen_match_and_neg_indices(num_prior,
|
||||
gt_lod, neg_lod)
|
||||
trg_box, trg_box_wt, trg_label, trg_label_wt = target_assign(
|
||||
encoded_box, gt_label, match_indices, neg_indices, gt_lod, neg_lod,
|
||||
background_label)
|
||||
|
||||
self.inputs = {
|
||||
'EncodedGTBBox': (encoded_box, [gt_lod]),
|
||||
'GTScoreLabel': (gt_label, [gt_lod]),
|
||||
'MatchIndices': (match_indices),
|
||||
'NegIndices': (neg_indices, [neg_lod]),
|
||||
}
|
||||
self.attrs = {'background_label': background_label}
|
||||
self.outputs = {
|
||||
'PredBBoxLabel': (trg_box),
|
||||
'PredBBoxWeight': (trg_box_wt),
|
||||
'PredScoreLabel': (trg_label),
|
||||
'PredScoreWeight': (trg_label_wt),
|
||||
}
|
||||
|
||||
def test_check_output(self):
|
||||
self.check_output()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Loading…
Reference in new issue