!7620 op NonMaxSuppression

Merge pull request !7620 from zhaozhenlong/lite/op/non_max_suppression
pull/7620/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit b8fbabae34

@ -0,0 +1,27 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_NNACL_NON_MAX_SUPPRESSION_PARAMETER_H_
#define MINDSPORE_LITE_NNACL_NON_MAX_SUPPRESSION_PARAMETER_H_
#include "nnacl/op_base.h"
typedef struct NMSParameter {
OpParameter op_parameter_;
int center_point_box_;  // 0: boxes given as [y1, x1, y2, x2] corners; 1: boxes given as [x_center, y_center, width, height]
} NMSParameter;
#endif // MINDSPORE_LITE_NNACL_NON_MAX_SUPPRESSION_PARAMETER_H_

@ -1081,9 +1081,6 @@ table FftImag {
}
table NonMaxSuppression {
maxOutBoxPerClass : int = 0;  // maximum number of boxes selected per class
iouThreshold : float = 0;     // boxes with IoU above this threshold are suppressed
scoreThreshold : float = 0;   // boxes must score above this threshold to become candidates
centerPointBox : int = 0;     // 0: [y1, x1, y2, x2] corners; 1: [x_center, y_center, width, height]
}

@ -0,0 +1,59 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/non_max_suppression.h"
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
void NonMaxSuppression::SetCenterPointBox(int centerPointBox) {
this->primitive_->value.AsNonMaxSuppression()->centerPointBox = centerPointBox;
}
int NonMaxSuppression::GetCenterPointBox() const {
return this->primitive_->value.AsNonMaxSuppression()->centerPointBox;
}
#else
int NonMaxSuppression::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto attr = primitive->value_as_NonMaxSuppression();
if (attr == nullptr) {
MS_LOG(ERROR) << "value_as_NonMaxSuppression return nullptr";
return RET_ERROR;
}
auto val_offset = schema::CreateNonMaxSuppression(*fbb, attr->centerPointBox());
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_NonMaxSuppression, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
int NonMaxSuppression::GetCenterPointBox() const {
return this->primitive_->value_as_NonMaxSuppression()->centerPointBox();
}
#endif
int NonMaxSuppression::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
auto input = inputs_.front();
MS_ASSERT(input != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
output->set_data_type(kNumberTypeInt32);
output->SetFormat(input->GetFormat());
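// The number of selected boxes depends on the input data, so the output shape ([num_selected, 3]) is only known at runtime.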
MS_LOG(INFO) << "NonMaxSuppression infer shape in runtime.";
return RET_INFER_INVALID;
}
} // namespace lite
} // namespace mindspore

@ -0,0 +1,43 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LITE_MINDSPORE_LITE_NON_MAX_SUPPRESSION_H_
#define LITE_MINDSPORE_LITE_NON_MAX_SUPPRESSION_H_
#include <vector>
#include "src/ops/primitive_c.h"
namespace mindspore {
namespace lite {
class NonMaxSuppression : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(NonMaxSuppression, PrimitiveC);
NonMaxSuppression() = default;
explicit NonMaxSuppression(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetCenterPointBox(int centerPointBox);
#else
NonMaxSuppression() = default;
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
int GetCenterPointBox() const;
};
} // namespace lite
} // namespace mindspore
#endif  // LITE_MINDSPORE_LITE_NON_MAX_SUPPRESSION_H_

@ -0,0 +1,41 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/non_max_suppression.h"
#include "src/ops/primitive_c.h"
#include "src/ops/populate/populate_register.h"
#include "nnacl/non_max_suppression_parameter.h"
namespace mindspore {
namespace lite {
OpParameter *PopulateNonMaxSuppressionParameter(const mindspore::lite::PrimitiveC *primitive) {
NMSParameter *param = reinterpret_cast<NMSParameter *>(malloc(sizeof(NMSParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc param failed.";
return nullptr;
}
memset(param, 0, sizeof(NMSParameter));
param->op_parameter_.type_ = primitive->Type();
auto prim =
reinterpret_cast<mindspore::lite::NonMaxSuppression *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
param->center_point_box_ = prim->GetCenterPointBox();
return reinterpret_cast<OpParameter *>(param);
}
Registry NonMaxSuppressionParameterRegistry(schema::PrimitiveType_NonMaxSuppression, PopulateNonMaxSuppressionParameter);
} // namespace lite
} // namespace mindspore

@ -136,6 +136,7 @@
#include "src/ops/custom_extract_features.h"
#include "src/ops/upsample.h"
#include "src/ops/layer_norm.h"
#include "src/ops/non_max_suppression.h"
#ifdef SUPPORT_TRAIN
#include "src/ops/neg_grad.h"
@ -726,6 +727,8 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) {
return new Upsample(primitive);
case schema::PrimitiveType_LayerNorm:
return new LayerNorm(primitive);
case schema::PrimitiveType_NonMaxSuppression:
return new NonMaxSuppression(primitive);
#ifdef SUPPORT_TRAIN
case schema::PrimitiveType_ActivationGrad:

@ -0,0 +1,248 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/arm/fp32/non_max_suppression.h"
#include <queue>
#include <functional>
#include <utility>
#include "nnacl/non_max_suppression_parameter.h"
#include "schema/model_generated.h"
#include "src/runtime/runtime_api.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_NULL_PTR;
using mindspore::schema::PrimitiveType_NonMaxSuppression;
namespace mindspore::kernel {
namespace {
constexpr size_t kMinInputsSize = 2;
constexpr size_t kMaxInputsSize = 5;
constexpr size_t kOutputNum = 1;
constexpr size_t kBoxTensorIndex = 0;
constexpr size_t kScoreTensorIndex = 1;
constexpr size_t kMaxOutputNumTensorIndex = 2;
constexpr size_t kIoUThresholdTensorIndex = 3;
constexpr size_t kScoreThresholdTensorIndex = 4;
constexpr int kBoxPointNum = 4;
} // namespace
int NonMaxSuppressionCPUKernel::Init() {
// boxes, scores, max_output_boxes, iou_threshold, score_threshold
if (in_tensors_.size() < kMinInputsSize || in_tensors_.size() > kMaxInputsSize || out_tensors_.size() != kOutputNum) {
MS_LOG(ERROR) << "NonMaxSuppression input size should be in [" << kMinInputsSize << ", " << kMaxInputsSize << "]"
<< ", got " << in_tensors_.size() << ", output size should be" << kOutputNum << ", got "
<< out_tensors_.size();
return RET_ERROR;
}
param_ = reinterpret_cast<NMSParameter *>(op_parameter_);
if (param_ == nullptr) {
MS_LOG(ERROR) << "cast to NMSParameter pointer got nullptr";
return RET_NULL_PTR;
}
if (param_->center_point_box_ != 0 && param_->center_point_box_ != 1) {
MS_LOG(ERROR) << "NonMaxSuppression center_point_box should be 0 or 1, got " << param_->center_point_box_;
return RET_ERROR;
}
center_point_box_ = param_->center_point_box_;
if (!InferShapeDone()) {
return RET_OK;
}
return ReSize();
}
int NonMaxSuppressionCPUKernel::GetParams() {
// optional input order: max_output_per_class, iou_threshold, score_threshold
max_output_per_class_ = 0;
if (in_tensors_.size() >= 3) {
auto max_output_tensor = in_tensors_.at(kMaxOutputNumTensorIndex);
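// ONNX defines max_output_boxes_per_class as an int64 scalar, hence the int64_t read.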
if (max_output_tensor != nullptr && reinterpret_cast<int64_t *>(max_output_tensor->data_c()) != nullptr) {
max_output_per_class_ = *(reinterpret_cast<int64_t *>(max_output_tensor->data_c()));
}
}
iou_threshold_ = 0.0f;
if (in_tensors_.size() >= 4) {
auto iou_threshold_tensor = in_tensors_.at(kIoUThresholdTensorIndex);
if (iou_threshold_tensor != nullptr && reinterpret_cast<float *>(iou_threshold_tensor->data_c()) != nullptr) {
iou_threshold_ = *(reinterpret_cast<float *>(iou_threshold_tensor->data_c()));
}
}
score_threshold_ = 0.0f;
if (in_tensors_.size() >= 5) {
auto score_threshold_tensor = in_tensors_.at(kScoreThresholdTensorIndex);
if (score_threshold_tensor != nullptr && reinterpret_cast<float *>(score_threshold_tensor->data_c()) != nullptr) {
score_threshold_ = *(reinterpret_cast<float *>(score_threshold_tensor->data_c()));
}
}
return RET_OK;
}
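// PreProcess only refreshes the optional threshold inputs; the output tensor is resized inside Run() once the selected boxes are known.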
int NonMaxSuppressionCPUKernel::PreProcess() { return GetParams(); }
int NonMaxSuppressionCPUKernel::Run() {
auto box_tensor = in_tensors_.at(kBoxTensorIndex);
if (box_tensor == nullptr) {
return RET_ERROR;
}
auto box_dims = box_tensor->shape(); // batch, box_num, 4
constexpr size_t kBoxTensorDims = 3;
if (box_dims.size() != kBoxTensorDims) {
return RET_ERROR;
}
constexpr size_t kBoxCoordIndex = 2;
if (box_dims[kBoxCoordIndex] != kBoxPointNum) {
return RET_ERROR;
}
auto score_tensor = in_tensors_.at(kScoreTensorIndex);
if (score_tensor == nullptr) {
return RET_ERROR;
}
auto score_dims = score_tensor->shape(); // batch, class, box_num
constexpr size_t kScoreTensorDims = 3;
if (score_dims.size() != kScoreTensorDims) {
return RET_ERROR;
}
constexpr size_t kBatchIndex = 0;
if (score_dims[kBatchIndex] != box_dims[kBatchIndex]) {
MS_LOG(ERROR) << "Boxes tensor batch num should be equal to scores tensor's batch num.";
return RET_ERROR;
}
constexpr size_t kScoreDimsBoxNumIndex = 2;
constexpr size_t kBoxDimsBoxNumIndex = 1;
if (score_dims[kScoreDimsBoxNumIndex] != box_dims[kBoxDimsBoxNumIndex]) {
MS_LOG(ERROR) << "Boxes tensor spatial dimension should be equal to scores tensor's spatial dimension.";
return RET_ERROR;
}
int batch_num = score_dims[kBatchIndex];
constexpr size_t kClassIndex = 1;
int class_num = score_dims[kClassIndex];
int box_num = score_dims[kScoreDimsBoxNumIndex];
float *scores_data = reinterpret_cast<float *>(score_tensor->data_c());  // layout: batch, class, box_num
if (scores_data == nullptr) {
MS_LOG(ERROR) << "score tensor data nullptr";
return RET_ERROR;
}
float *box_data = reinterpret_cast<float *>(box_tensor->data_c());
if (box_data == nullptr) {
MS_LOG(ERROR) << "box tensor data nullptr";
return RET_ERROR;
}
std::vector<NMSBox> selected_box_per_class;
selected_box_per_class.reserve(std::min(static_cast<int32_t>(box_num), max_output_per_class_));
std::vector<NMSIndex> selected_index;
for (auto i = 0; i < batch_num; ++i) {
int batch_offset = i * class_num * box_num;
for (auto j = 0; j < class_num; ++j) {
// per batch per class filter
float *per_class_scores = scores_data + batch_offset + j * box_num;
float *box = box_data + i * box_num * kBoxPointNum;
std::vector<NMSBox> above_score_candidates;
above_score_candidates.reserve(box_num);
for (auto k = 0; k < box_num; ++k) {
if (per_class_scores[k] > score_threshold_) {
above_score_candidates.emplace_back(per_class_scores[k], k, center_point_box_, box[0], box[1], box[2],
box[3]);
}
box += kBoxPointNum;
}
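// Greedy NMS: visit candidates in descending score order and keep a box only if its IoU with every already selected box stays below the threshold.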
std::priority_queue<NMSBox, std::vector<NMSBox>, std::less<NMSBox>> sorted_candidates(
std::less<NMSBox>(), std::move(above_score_candidates));
selected_box_per_class.clear();
while (!sorted_candidates.empty() && static_cast<int32_t>(selected_box_per_class.size()) < max_output_per_class_) {
auto cand = sorted_candidates.top();
bool selected = true;
auto IoUSuppressed = [this, &cand](const NMSBox &box) {
// IoU = intersection area / (area(cand) + area(box) - intersection area)
float intersec_x1 = std::max(cand.x1_, box.x1_);
float intersec_x2 = std::min(cand.x2_, box.x2_);
float intersec_y1 = std::max(cand.y1_, box.y1_);
float intersec_y2 = std::min(cand.y2_, box.y2_);
const float intersec_area =
std::max(intersec_x2 - intersec_x1, 0.0f) * std::max(intersec_y2 - intersec_y1, 0.0f);
if (intersec_area <= 0.0f) {
return false;
}
const float intersec_over_union = intersec_area / (cand.area_ + box.area_ - intersec_area);
return intersec_over_union > this->iou_threshold_;
};
if (std::any_of(selected_box_per_class.begin(), selected_box_per_class.end(), IoUSuppressed)) {
selected = false;
}
if (selected) {
selected_box_per_class.push_back(cand);
selected_index.emplace_back(
NMSIndex{static_cast<int32_t>(i), static_cast<int32_t>(j), static_cast<int32_t>(cand.index_)});
}
sorted_candidates.pop();
}
}
}
auto output = out_tensors_.at(0);
int selected_num = static_cast<int>(selected_index.size());
const int output_last_dim = 3;
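// each output row is [batch_index, class_index, box_index], matching the NMSIndex layout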
output->set_shape({selected_num, output_last_dim});
MS_ASSERT(output_last_dim * sizeof(int32_t) == sizeof(NMSIndex));
int32_t *out_data = reinterpret_cast<int32_t *>(output->MutableData());
memcpy(out_data, selected_index.data(), selected_index.size() * sizeof(NMSIndex));
return RET_OK;
}
kernel::LiteKernel *CpuNonMaxSuppressionFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs,
OpParameter *opParameter, const lite::InnerContext *ctx,
const kernel::KernelKey &desc,
const mindspore::lite::PrimitiveC *primitive) {
if (opParameter == nullptr) {
MS_LOG(ERROR) << "NonMaxSuppression opParameter nullptr.";
return nullptr;
}
if (desc.type != schema::PrimitiveType_NonMaxSuppression) {
MS_LOG(ERROR) << "OneHot desc type should be " << schema::PrimitiveType_NonMaxSuppression << " got " << desc.type;
free(opParameter);
return nullptr;
}
auto *kernel = new (std::nothrow) NonMaxSuppressionCPUKernel(opParameter, inputs, outputs, ctx, primitive);
if (kernel == nullptr) {
MS_LOG(ERROR) << "OneHot new kernel failed.";
free(opParameter);
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
delete kernel;
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
return nullptr;
}
return kernel;
}
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_NonMaxSuppression, CpuNonMaxSuppressionFp32KernelCreator)
} // namespace mindspore::kernel

@ -0,0 +1,93 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_NON_MAX_SUPPRESSION_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_NON_MAX_SUPPRESSION_H_
#include <vector>
#include <algorithm>
#include "src/lite_kernel.h"
#include "mindspore/lite/nnacl/non_max_suppression_parameter.h"
using mindspore::lite::RET_OK;
namespace mindspore::kernel {
class NonMaxSuppressionCPUKernel : public LiteKernel {
public:
NonMaxSuppressionCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
const mindspore::lite::PrimitiveC *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {}
~NonMaxSuppressionCPUKernel() override = default;
int Init() override;
int ReSize() override { return RET_OK; };
int PreProcess() override;
int Run() override;
private:
int GetParams();
private:
int center_point_box_;
float iou_threshold_;
float score_threshold_;
int32_t max_output_per_class_;
NMSParameter *param_ = nullptr;
};
typedef struct NMSIndex {
int32_t batch_index_;
int32_t class_index_;
int32_t box_index_;
} NMSIndex;
class NMSBox {
public:
NMSBox() = default;
~NMSBox() = default;
explicit NMSBox(float score, int box_index, int center_point_box, float y_a, float x_a, float y_b, float x_b)
: score_(score), index_(box_index) {
if (0 == center_point_box) {
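// corner format [y1, x1, y2, x2]: sort the coordinates so (x1_, y1_) is the min corner and (x2_, y2_) the max corner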
y1_ = std::min(y_a, y_b);
y2_ = std::max(y_a, y_b);
x1_ = std::min(x_a, x_b);
x2_ = std::max(x_a, x_b);
} else {
// x_center, y_center, width, height
float half_wid = x_b / 2;
x1_ = x_a - half_wid;
x2_ = x_a + half_wid;
float half_height = y_b / 2;
y1_ = y_a - half_height;
y2_ = y_a + half_height;
}
area_ = (y2_ - y1_) * (x2_ - x1_);
}
inline bool operator<(const NMSBox &box) const { return score_ < box.score_; }
public:
float score_;
int index_;
float y1_; // y1 x1 y2 x2 ascending order
float y2_;
float x1_;
float x2_;
float area_;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_NON_MAX_SUPPRESSION_H_

@ -0,0 +1,116 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/common/log_adapter.h"
#include "common/common_test.h"
#include "mindspore/lite/src/runtime/kernel/arm/fp32/l2_norm.h"
#include "mindspore/lite/src/runtime/kernel/arm/fp32/non_max_suppression.h"
#include "src/kernel_registry.h"
#include "src/lite_kernel.h"
using mindspore::schema::Format_NHWC;
namespace mindspore {
class TestNMSFp32 : public mindspore::CommonTest {
public:
TestNMSFp32() = default;
void Init(const std::vector<int> &box_tensor_shape, float *box_data, const std::vector<int> &score_tensor_shape,
float *score_data, int64_t max_output, float iou_threshold, float score_threshold, int center_box_point);
void TearDown() override;
public:
float err_tol_ = 1e-5;
lite::Tensor box_tensor_;
lite::Tensor score_tensor_;
lite::Tensor max_output_box_per_class_tensor_;
lite::Tensor iou_threshold_tensor_;
lite::Tensor score_threshold_tensor_;
lite::Tensor out_tensor_;
int64_t max_output_;  // int64 to match the data type the kernel reads from the max_output tensor
float iou_threshold_;
float score_threshold_;
std::vector<lite::Tensor *> inputs_{&box_tensor_, &score_tensor_, &max_output_box_per_class_tensor_,
&iou_threshold_tensor_, &score_threshold_tensor_};
std::vector<lite::Tensor *> outputs_{&out_tensor_};
NMSParameter param_;
kernel::KernelKey desc_ = {kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, schema::PrimitiveType_NonMaxSuppression};
lite::InnerContext ctx_ = lite::InnerContext();
kernel::KernelCreator creator_ = nullptr;
kernel::LiteKernel *kernel_ = nullptr;
};
void TestNMSFp32::TearDown() {
box_tensor_.SetData(nullptr);
score_tensor_.SetData(nullptr);
max_output_box_per_class_tensor_.SetData(nullptr);
iou_threshold_tensor_.SetData(nullptr);
score_threshold_tensor_.SetData(nullptr);
out_tensor_.FreeData();
}
void TestNMSFp32::Init(const std::vector<int> &box_tensor_shape, float *box_data,
const std::vector<int> &score_tensor_shape, float *score_data, int64_t max_output,
float iou_threshold, float score_threshold, int center_box_point) {
box_tensor_.set_data_type(kNumberTypeFloat32);
box_tensor_.SetFormat(Format_NHWC);
box_tensor_.set_shape(box_tensor_shape);
box_tensor_.SetData(box_data);
score_tensor_.set_data_type(kNumberTypeFloat32);
score_tensor_.SetFormat(Format_NHWC);
score_tensor_.set_shape(score_tensor_shape);
score_tensor_.SetData(score_data);
max_output_ = max_output;
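// the scalar input tensors borrow the addresses of these members; TearDown() resets them so stack memory is never freed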
max_output_box_per_class_tensor_.SetData(&max_output_);
iou_threshold_ = iou_threshold;
iou_threshold_tensor_.SetData(&iou_threshold_);
score_threshold_ = score_threshold;
score_threshold_tensor_.SetData(&score_threshold_);
out_tensor_.set_data_type(kNumberTypeInt32);
param_.center_point_box_ = center_box_point;
ctx_ = lite::InnerContext();
ASSERT_EQ(lite::RET_OK, ctx_.Init());
creator_ = lite::KernelRegistry::GetInstance()->GetCreator(desc_);
ASSERT_NE(creator_, nullptr);
kernel_ = creator_(inputs_, outputs_, reinterpret_cast<OpParameter *>(&param_), &ctx_, desc_, nullptr);
ASSERT_NE(kernel_, nullptr);
}
TEST_F(TestNMSFp32, TestCase1) {
std::vector<int> box_tensor_shape{1, 6, 4}; // batch 1, num 6, box coord 4
float box_data[24] = {0.0f, 0.0f, 1.0f, 1.0f, 0.0f, 0.1f, 1.0f, 1.0f, 0.1f, 0.0f, 1.0f, 1.0f,
0.0f, 3.0f, 1.0f, 1.0f, 0.0f, 3.1f, 1.0f, 1.0f, 0.0f, 6.0f, 1.0f, 1.0f};
std::vector<int> score_tensor_shape{1, 1, 6}; // batch 1, class 1, num 6
float score_data[6] = {0.9f, 0.8f, 0.7f, 0.95f, 0.6f, 0.5f};
int64_t max_output = 3;
float iou_threshold = 0.5f;
float score_threshold = 0.0f;
int center_box_point = 1;
auto output_size = 9;
Init(box_tensor_shape, box_data, score_tensor_shape, score_data, max_output, iou_threshold, score_threshold,
center_box_point);
auto ret = kernel_->PreProcess();
EXPECT_EQ(0, ret);
ret = kernel_->Run();
EXPECT_EQ(0, ret);
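// expected rows are [batch_index, class_index, box_index]: boxes 3, 0 and 5 survive suppression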
std::vector<int32_t> expect{0, 0, 3, 0, 0, 0, 0, 0, 5};
CompareOutputData(reinterpret_cast<int32_t *>(out_tensor_.data_c()), expect.data(), output_size, err_tol_);
}
} // namespace mindspore

@ -37,29 +37,6 @@ STATUS OnnxNonMaxSuppressionParser::Parse(const onnx::GraphProto &onnx_graph, co
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
if (onnx_node.input_size() > 2) {
auto it = std::find_if(onnx_graph.initializer().begin(), onnx_graph.initializer().end(),
[&](const onnx::TensorProto &it) { return it.name() == onnx_node.input(2); });
if (it != onnx_graph.initializer().end()) {
attr->maxOutBoxPerClass = it->int64_data(0);
}
}
if (onnx_node.input_size() > 3) {
auto it = std::find_if(onnx_graph.initializer().begin(), onnx_graph.initializer().end(),
[&](const onnx::TensorProto &it) { return it.name() == onnx_node.input(3); });
if (it != onnx_graph.initializer().end()) {
attr->iouThreshold = it->float_data(0);
}
}
if (onnx_node.input_size() > 4) {
auto it = std::find_if(onnx_graph.initializer().begin(), onnx_graph.initializer().end(),
[&](const onnx::TensorProto &it) { return it.name() == onnx_node.input(4); });
if (it != onnx_graph.initializer().end()) {
attr->scoreThreshold = it->float_data(0);
}
}
for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name();
@ -70,7 +47,7 @@ STATUS OnnxNonMaxSuppressionParser::Parse(const onnx::GraphProto &onnx_graph, co
}
}
op->primitive->value.type = schema::PrimitiveType_Elu;
op->primitive->value.type = schema::PrimitiveType_NonMaxSuppression;
op->primitive->value.value = attr.release();
return RET_OK;
}
