!7620 op NonMaxSuppression
Merge pull request !7620 from zhaozhenlong/lite/op/non_max_suppression
commit
b8fbabae34
@ -0,0 +1,27 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_NNACL_NON_MAX_SUPPRESSION_PARAMETER_H_
#define MINDSPORE_LITE_NNACL_NON_MAX_SUPPRESSION_PARAMETER_H_

#include "nnacl/op_base.h"

typedef struct NMSParameter {
  OpParameter op_parameter_;
  int center_point_box_;
} NMSParameter;

#endif  // MINDSPORE_LITE_NNACL_NON_MAX_SUPPRESSION_PARAMETER_H_
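Note (editorial, not part of the patch): center_point_box_ mirrors the ONNX NonMaxSuppression attribute, where 0 means each box is given as [y1, x1, y2, x2] corners and 1 means [x_center, y_center, width, height]. A minimal sketch of filling the struct, assuming only the two headers above (MakeNMSParameter is a hypothetical helper name):

#include <cstring>
#include "nnacl/op_base.h"
#include "nnacl/non_max_suppression_parameter.h"

// Zero-initialize the parameter and set the box-encoding flag.
inline NMSParameter MakeNMSParameter(int center_point_box) {
  NMSParameter param;
  memset(&param, 0, sizeof(NMSParameter));
  param.center_point_box_ = center_point_box;
  return param;
}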
@ -0,0 +1,59 @@
/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/ops/non_max_suppression.h"

namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
void NonMaxSuppression::SetCenterPointBox(int centerPointBox) {
  this->primitive_->value.AsNonMaxSuppression()->centerPointBox = centerPointBox;
}

int NonMaxSuppression::GetCenterPointBox() const {
  return this->primitive_->value.AsNonMaxSuppression()->centerPointBox;
}
#else
int NonMaxSuppression::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
  MS_ASSERT(nullptr != primitive);
  MS_ASSERT(nullptr != fbb);
  auto attr = primitive->value_as_NonMaxSuppression();
  if (attr == nullptr) {
    MS_LOG(ERROR) << "value_as_NonMaxSuppression return nullptr";
    return RET_ERROR;
  }
  auto val_offset = schema::CreateNonMaxSuppression(*fbb, attr->centerPointBox());
  auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_NonMaxSuppression, val_offset.o);
  fbb->Finish(prim_offset);
  return RET_OK;
}

int NonMaxSuppression::GetCenterPointBox() const {
  return this->primitive_->value_as_NonMaxSuppression()->centerPointBox();
}
#endif

int NonMaxSuppression::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
  auto input = inputs_.front();
  MS_ASSERT(input != nullptr);
  auto output = outputs_.front();
  MS_ASSERT(output != nullptr);
  output->set_data_type(kNumberTypeInt32);
  output->SetFormat(input->GetFormat());
  MS_LOG(INFO) << "NonMaxSuppression infer shape in runtime.";
  return RET_INFER_INVALID;
}
}  // namespace lite
}  // namespace mindspore
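Note (editorial): InferShape can only fix the output's data type and format here; the number of selected boxes depends on the input data, so it returns RET_INFER_INVALID and the CPU kernel resizes the output to {num_selected, 3} at run time. Each output row is an int32 triple, mirroring the NMSIndex struct in the kernel header further below (SelectedIndex is an illustrative name):

#include <cstdint>

// One int32 triple per kept box.
struct SelectedIndex {
  int32_t batch_index;
  int32_t class_index;
  int32_t box_index;
};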
@ -0,0 +1,43 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef LITE_MINDSPORE_LITE_NON_MAX_SUPPRESSION_H_
#define LITE_MINDSPORE_LITE_NON_MAX_SUPPRESSION_H_

#include <vector>
#include "src/ops/primitive_c.h"

namespace mindspore {
namespace lite {
class NonMaxSuppression : public PrimitiveC {
 public:
#ifdef PRIMITIVE_WRITEABLE
  MS_DECLARE_PARENT(NonMaxSuppression, PrimitiveC);
  NonMaxSuppression() = default;
  explicit NonMaxSuppression(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
  void SetCenterPointBox(int centerPointBox);
#else
  NonMaxSuppression() = default;

  int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
  int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
  int GetCenterPointBox() const;
};
}  // namespace lite
}  // namespace mindspore

#endif  // LITE_MINDSPORE_LITE_NON_MAX_SUPPRESSION_H_
@ -0,0 +1,41 @@
/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/ops/non_max_suppression.h"
#include "src/ops/primitive_c.h"
#include "src/ops/populate/populate_register.h"
#include "nnacl/non_max_suppression_parameter.h"

namespace mindspore {
namespace lite {

OpParameter *PopulateNonMaxSuppressionParameter(const mindspore::lite::PrimitiveC *primitive) {
  NMSParameter *param = reinterpret_cast<NMSParameter *>(malloc(sizeof(NMSParameter)));
  if (param == nullptr) {
    MS_LOG(ERROR) << "malloc param failed.";
    return nullptr;
  }
  memset(param, 0, sizeof(NMSParameter));
  param->op_parameter_.type_ = primitive->Type();
  auto prim =
    reinterpret_cast<mindspore::lite::NonMaxSuppression *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
  param->center_point_box_ = prim->GetCenterPointBox();
  return reinterpret_cast<OpParameter *>(param);
}
Registry NonMaxSuppressionParameterRegistry(schema::PrimitiveType_NonMaxSuppression,
                                            PopulateNonMaxSuppressionParameter);

}  // namespace lite
}  // namespace mindspore
@ -0,0 +1,248 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/runtime/kernel/arm/fp32/non_max_suppression.h"
#include <queue>
#include <functional>
#include <utility>
#include "nnacl/non_max_suppression_parameter.h"
#include "schema/model_generated.h"
#include "src/runtime/runtime_api.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"

using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_NULL_PTR;
using mindspore::schema::PrimitiveType_NonMaxSuppression;

namespace mindspore::kernel {
namespace {
constexpr size_t kMinInputsSize = 2;
constexpr size_t kMaxInputsSize = 5;
constexpr size_t kOutputNum = 1;
constexpr size_t kBoxTensorIndex = 0;
constexpr size_t kScoreTensorIndex = 1;
constexpr size_t kMaxOutputNumTensorIndex = 2;
constexpr size_t kIoUThresholdTensorIndex = 3;
constexpr size_t kScoreThresholdTensorIndex = 4;
constexpr int kBoxPointNum = 4;
}  // namespace

int NonMaxSuppressionCPUKernel::Init() {
  // boxes, scores, max_output_boxes, iou_threshold, score_threshold
  if (in_tensors_.size() < kMinInputsSize || in_tensors_.size() > kMaxInputsSize ||
      out_tensors_.size() != kOutputNum) {
    MS_LOG(ERROR) << "NonMaxSuppression input size should be in [" << kMinInputsSize << ", " << kMaxInputsSize << "]"
                  << ", got " << in_tensors_.size() << ", output size should be " << kOutputNum << ", got "
                  << out_tensors_.size();
    return RET_ERROR;
  }

  param_ = reinterpret_cast<NMSParameter *>(op_parameter_);
  if (param_ == nullptr) {
    MS_LOG(ERROR) << "cast to NMSParameter pointer got nullptr";
    return RET_NULL_PTR;
  }
  if (param_->center_point_box_ != 0 && param_->center_point_box_ != 1) {
    MS_LOG(ERROR) << "NonMaxSuppression center_point_box should be 0 or 1, got " << param_->center_point_box_;
    return RET_ERROR;
  }
  center_point_box_ = param_->center_point_box_;

  if (!InferShapeDone()) {
    return RET_OK;
  }
  return ReSize();
}

int NonMaxSuppressionCPUKernel::GetParams() {
  // optional input order: max_output_per_class, iou_threshold, score_threshold
  max_output_per_class_ = 0;
  if (in_tensors_.size() >= 3) {
    auto max_output_tensor = in_tensors_.at(kMaxOutputNumTensorIndex);
    if (max_output_tensor != nullptr && reinterpret_cast<int64_t *>(max_output_tensor->data_c()) != nullptr) {
      max_output_per_class_ = *(reinterpret_cast<int64_t *>(max_output_tensor->data_c()));
    }
  }
  iou_threshold_ = 0.0f;
  if (in_tensors_.size() >= 4) {
    auto iou_threshold_tensor = in_tensors_.at(kIoUThresholdTensorIndex);
    if (iou_threshold_tensor != nullptr && reinterpret_cast<float *>(iou_threshold_tensor->data_c()) != nullptr) {
      iou_threshold_ = *(reinterpret_cast<float *>(iou_threshold_tensor->data_c()));
    }
  }
  score_threshold_ = 0.0f;
  if (in_tensors_.size() >= 5) {
    auto score_threshold_tensor = in_tensors_.at(kScoreThresholdTensorIndex);
    if (score_threshold_tensor != nullptr && reinterpret_cast<float *>(score_threshold_tensor->data_c()) != nullptr) {
      score_threshold_ = *(reinterpret_cast<float *>(score_threshold_tensor->data_c()));
    }
  }
  return RET_OK;
}

int NonMaxSuppressionCPUKernel::PreProcess() { return GetParams(); }

int NonMaxSuppressionCPUKernel::Run() {
  auto box_tensor = in_tensors_.at(kBoxTensorIndex);
  if (box_tensor == nullptr) {
    return RET_ERROR;
  }
  auto box_dims = box_tensor->shape();  // batch, box_num, 4
  constexpr size_t kBoxTensorDims = 3;
  if (box_dims.size() != kBoxTensorDims) {
    return RET_ERROR;
  }
  constexpr size_t kBoxCoordIndex = 2;
  if (box_dims[kBoxCoordIndex] != kBoxPointNum) {
    return RET_ERROR;
  }

  auto score_tensor = in_tensors_.at(kScoreTensorIndex);
  if (score_tensor == nullptr) {
    return RET_ERROR;
  }
  auto score_dims = score_tensor->shape();  // batch, class, box_num
  constexpr size_t kScoreTensorDims = 3;
  if (score_dims.size() != kScoreTensorDims) {
    return RET_ERROR;
  }
  constexpr size_t kBatchIndex = 0;
  if (score_dims[kBatchIndex] != box_dims[kBatchIndex]) {
    MS_LOG(ERROR) << "Boxes tensor batch num should be equal to scores tensor's batch num.";
    return RET_ERROR;
  }
  constexpr size_t kScoreDimsBoxNumIndex = 2;
  constexpr size_t kBoxDimsBoxNumIndex = 1;
  if (score_dims[kScoreDimsBoxNumIndex] != box_dims[kBoxDimsBoxNumIndex]) {
    MS_LOG(ERROR) << "Boxes tensor spatial dimension should be equal to scores tensor's spatial dimension.";
    return RET_ERROR;
  }
  const float *scores = reinterpret_cast<const float *>(score_tensor->data_c());  // batch, class, num
  if (scores == nullptr) {
    MS_LOG(ERROR) << "score tensor data nullptr";
    return RET_ERROR;
  }

  int batch_num = score_dims[kBatchIndex];
  constexpr size_t kClassIndex = 1;
  int class_num = score_dims[kClassIndex];
  int box_num = score_dims[kScoreDimsBoxNumIndex];
  float *scores_data = reinterpret_cast<float *>(score_tensor->data_c());
  if (scores_data == nullptr) {
    MS_LOG(ERROR) << "score tensor data nullptr";
    return RET_ERROR;
  }
  float *box_data = reinterpret_cast<float *>(box_tensor->data_c());
  if (box_data == nullptr) {
    MS_LOG(ERROR) << "box tensor data nullptr";
    return RET_ERROR;
  }

  std::vector<NMSBox> selected_box_per_class;
  selected_box_per_class.reserve(std::min(static_cast<int32_t>(box_num), max_output_per_class_));
  std::vector<NMSIndex> selected_index;

  for (auto i = 0; i < batch_num; ++i) {
    int batch_offset = i * class_num * box_num;
    for (auto j = 0; j < class_num; ++j) {
      // per batch per class filter
      float *per_class_scores = scores_data + batch_offset + j * box_num;
      float *box = box_data + i * box_num * kBoxPointNum;
      std::vector<NMSBox> above_score_candidates;
      above_score_candidates.reserve(box_num);
      for (auto k = 0; k < box_num; ++k) {
        if (per_class_scores[k] > score_threshold_) {
          above_score_candidates.emplace_back(per_class_scores[k], k, center_point_box_, box[0], box[1], box[2],
                                              box[3]);
        }
        box += kBoxPointNum;
      }
      std::priority_queue<NMSBox, std::vector<NMSBox>, std::less<NMSBox>> sorted_candidates(
        std::less<NMSBox>(), std::move(above_score_candidates));

      selected_box_per_class.clear();
      while (!sorted_candidates.empty() &&
             static_cast<int32_t>(selected_box_per_class.size()) < max_output_per_class_) {
        auto cand = sorted_candidates.top();
        bool selected = true;
        auto IoUSuppressed = [this, &cand](const NMSBox &box) {
          float intersec_x1 = std::max(cand.x1_, box.x1_);
          float intersec_x2 = std::min(cand.x2_, box.x2_);
          float intersec_y1 = std::max(cand.y1_, box.y1_);
          float intersec_y2 = std::min(cand.y2_, box.y2_);
          const float intersec_area =
            std::max(intersec_x2 - intersec_x1, 0.0f) * std::max(intersec_y2 - intersec_y1, 0.0f);
          if (intersec_area <= 0.0f) {
            return false;
          }
          const float intersec_over_union = intersec_area / (cand.area_ + box.area_ - intersec_area);
          return intersec_over_union > this->iou_threshold_;
        };
        if (std::any_of(selected_box_per_class.begin(), selected_box_per_class.end(), IoUSuppressed)) {
          selected = false;
        }
        if (selected) {
          selected_box_per_class.push_back(cand);
          selected_index.emplace_back(
            NMSIndex{static_cast<int32_t>(i), static_cast<int32_t>(j), static_cast<int32_t>(cand.index_)});
        }
        sorted_candidates.pop();
      }
    }
  }
  auto output = out_tensors_.at(0);
  int selected_num = static_cast<int>(selected_index.size());
  const int output_last_dim = 3;
  output->set_shape({selected_num, output_last_dim});
  MS_ASSERT(output_last_dim * sizeof(int32_t) == sizeof(NMSIndex));
  int32_t *out_data = reinterpret_cast<int32_t *>(output->MutableData());
  memcpy(out_data, selected_index.data(), selected_index.size() * sizeof(NMSIndex));
  return RET_OK;
}

kernel::LiteKernel *CpuNonMaxSuppressionFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
                                                          const std::vector<lite::Tensor *> &outputs,
                                                          OpParameter *opParameter, const lite::InnerContext *ctx,
                                                          const kernel::KernelKey &desc,
                                                          const mindspore::lite::PrimitiveC *primitive) {
  if (opParameter == nullptr) {
    MS_LOG(ERROR) << "NonMaxSuppression opParameter nullptr.";
    return nullptr;
  }
  if (desc.type != schema::PrimitiveType_NonMaxSuppression) {
    MS_LOG(ERROR) << "NonMaxSuppression desc type should be " << schema::PrimitiveType_NonMaxSuppression << ", got "
                  << desc.type;
    free(opParameter);
    return nullptr;
  }
  auto *kernel = new (std::nothrow) NonMaxSuppressionCPUKernel(opParameter, inputs, outputs, ctx, primitive);
  if (kernel == nullptr) {
    MS_LOG(ERROR) << "NonMaxSuppression new kernel failed.";
    free(opParameter);
    return nullptr;
  }
  auto ret = kernel->Init();
  if (ret != RET_OK) {
    delete kernel;
    MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
                  << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
    return nullptr;
  }
  return kernel;
}

REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_NonMaxSuppression, CpuNonMaxSuppressionFp32KernelCreator)
}  // namespace mindspore::kernel
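Note (editorial): the IoUSuppressed lambda in Run() is the standard intersection-over-union test on corner-form boxes: a candidate is dropped when IoU with any already-kept box exceeds iou_threshold. A self-contained sketch of the same computation, assuming a simplified Box type for illustration:

#include <algorithm>

struct Box {
  float y1, x1, y2, x2;  // corner form, y1 <= y2 and x1 <= x2
};

inline float Area(const Box &b) { return (b.y2 - b.y1) * (b.x2 - b.x1); }

// True when IoU(a, b) exceeds the threshold, i.e. b would suppress a (or vice versa).
inline bool Suppressed(const Box &a, const Box &b, float iou_threshold) {
  float ix1 = std::max(a.x1, b.x1);
  float ix2 = std::min(a.x2, b.x2);
  float iy1 = std::max(a.y1, b.y1);
  float iy2 = std::min(a.y2, b.y2);
  float inter = std::max(ix2 - ix1, 0.0f) * std::max(iy2 - iy1, 0.0f);
  if (inter <= 0.0f) {
    return false;
  }
  return inter / (Area(a) + Area(b) - inter) > iou_threshold;
}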
@ -0,0 +1,93 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_NON_MAX_SUPPRESSION_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_NON_MAX_SUPPRESSION_H_

#include <vector>
#include <algorithm>
#include "src/lite_kernel.h"
#include "mindspore/lite/nnacl/non_max_suppression_parameter.h"

using mindspore::lite::RET_OK;

namespace mindspore::kernel {
class NonMaxSuppressionCPUKernel : public LiteKernel {
 public:
  NonMaxSuppressionCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
                             const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
                             const mindspore::lite::PrimitiveC *primitive)
      : LiteKernel(parameter, inputs, outputs, ctx, primitive) {}

  ~NonMaxSuppressionCPUKernel() override = default;

  int Init() override;
  int ReSize() override { return RET_OK; }
  int PreProcess() override;
  int Run() override;

 private:
  int GetParams();

 private:
  int center_point_box_;
  float iou_threshold_;
  float score_threshold_;
  int32_t max_output_per_class_;
  NMSParameter *param_ = nullptr;
};

typedef struct NMSIndex {
  int32_t batch_index_;
  int32_t class_index_;
  int32_t box_index_;
} NMSIndex;

class NMSBox {
 public:
  NMSBox() = default;
  ~NMSBox() = default;
  explicit NMSBox(float score, int box_index, int center_point_box, float y_a, float x_a, float y_b, float x_b)
      : score_(score), index_(box_index) {
    if (0 == center_point_box) {
      y1_ = std::min(y_a, y_b);
      y2_ = std::max(y_a, y_b);
      x1_ = std::min(x_a, x_b);
      x2_ = std::max(x_a, x_b);
    } else {
      // x_center, y_center, width, height
      float half_wid = x_b / 2;
      x1_ = x_a - half_wid;
      x2_ = x_a + half_wid;
      float half_height = y_b / 2;
      y1_ = y_a - half_height;
      y2_ = y_a + half_height;
    }
    area_ = (y2_ - y1_) * (x2_ - x1_);
  }
  inline bool operator<(const NMSBox &box) const { return score_ < box.score_; }

 public:
  float score_;
  int index_;
  float y1_;  // y1 x1 y2 x2 ascending order
  float y2_;
  float x1_;
  float x2_;
  float area_;
};
}  // namespace mindspore::kernel

#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_NON_MAX_SUPPRESSION_H_
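Note (editorial): NMSBox only orders on score_ so that std::priority_queue, a max-heap under std::less, pops the highest-scoring candidate first, which is exactly the traversal order NMS needs. A minimal standalone sketch of that ordering, using a simplified ScoredBox type assumed for illustration:

#include <functional>
#include <queue>
#include <utility>
#include <vector>

struct ScoredBox {
  float score;
  int index;
  bool operator<(const ScoredBox &other) const { return score < other.score; }
};

// top() always yields the highest remaining score, so boxes come out in
// descending-score order.
inline std::priority_queue<ScoredBox> MakeScoreQueue(std::vector<ScoredBox> boxes) {
  return std::priority_queue<ScoredBox>(std::less<ScoredBox>(), std::move(boxes));
}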
@ -0,0 +1,116 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "src/common/log_adapter.h"
#include "common/common_test.h"
#include "mindspore/lite/src/runtime/kernel/arm/fp32/l2_norm.h"
#include "mindspore/lite/src/runtime/kernel/arm/fp32/non_max_suppression.h"
#include "src/kernel_registry.h"
#include "src/lite_kernel.h"
using mindspore::schema::Format_NHWC;

namespace mindspore {
class TestNMSFp32 : public mindspore::CommonTest {
 public:
  TestNMSFp32() = default;
  void Init(const std::vector<int> &box_tensor_shape, float *box_data, const std::vector<int> &score_tensor_shape,
            float *score_data, int64_t max_output, float iou_threshold, float score_threshold, int center_box_point);
  void TearDown() override;

 public:
  float err_tol_ = 1e-5;
  lite::Tensor box_tensor_;
  lite::Tensor score_tensor_;
  lite::Tensor max_output_box_per_class_tensor_;
  lite::Tensor iou_threshold_tensor_;
  lite::Tensor score_threshold_tensor_;
  lite::Tensor out_tensor_;
  int64_t max_output_;  // int64, matching the type the kernel reads from the tensor
  float iou_threshold_;
  float score_threshold_;
  std::vector<lite::Tensor *> inputs_{&box_tensor_, &score_tensor_, &max_output_box_per_class_tensor_,
                                      &iou_threshold_tensor_, &score_threshold_tensor_};
  std::vector<lite::Tensor *> outputs_{&out_tensor_};
  NMSParameter param_;
  kernel::KernelKey desc_ = {kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, schema::PrimitiveType_NonMaxSuppression};
  lite::InnerContext ctx_ = lite::InnerContext();
  kernel::KernelCreator creator_ = nullptr;
  kernel::LiteKernel *kernel_ = nullptr;
};

void TestNMSFp32::TearDown() {
  box_tensor_.SetData(nullptr);
  score_tensor_.SetData(nullptr);
  max_output_box_per_class_tensor_.SetData(nullptr);
  iou_threshold_tensor_.SetData(nullptr);
  score_threshold_tensor_.SetData(nullptr);
  out_tensor_.FreeData();
}

void TestNMSFp32::Init(const std::vector<int> &box_tensor_shape, float *box_data,
                       const std::vector<int> &score_tensor_shape, float *score_data, int64_t max_output,
                       float iou_threshold, float score_threshold, int center_box_point) {
  box_tensor_.set_data_type(kNumberTypeFloat32);
  box_tensor_.SetFormat(Format_NHWC);
  box_tensor_.set_shape(box_tensor_shape);
  box_tensor_.SetData(box_data);

  score_tensor_.set_data_type(kNumberTypeFloat32);
  score_tensor_.SetFormat(Format_NHWC);
  score_tensor_.set_shape(score_tensor_shape);
  score_tensor_.SetData(score_data);

  max_output_ = max_output;
  max_output_box_per_class_tensor_.SetData(&max_output_);
  iou_threshold_ = iou_threshold;
  iou_threshold_tensor_.SetData(&iou_threshold_);
  score_threshold_ = score_threshold;
  score_threshold_tensor_.SetData(&score_threshold_);

  out_tensor_.set_data_type(kNumberTypeInt32);

  param_.center_point_box_ = center_box_point;
  ctx_ = lite::InnerContext();
  ASSERT_EQ(lite::RET_OK, ctx_.Init());
  creator_ = lite::KernelRegistry::GetInstance()->GetCreator(desc_);
  ASSERT_NE(creator_, nullptr);
  kernel_ = creator_(inputs_, outputs_, reinterpret_cast<OpParameter *>(&param_), &ctx_, desc_, nullptr);
  ASSERT_NE(kernel_, nullptr);
}

TEST_F(TestNMSFp32, TestCase1) {
  std::vector<int> box_tensor_shape{1, 6, 4};  // batch 1, num 6, box coord 4
  float box_data[24] = {0.0f, 0.0f, 1.0f, 1.0f, 0.0f, 0.1f, 1.0f, 1.0f, 0.1f, 0.0f, 1.0f, 1.0f,
                        0.0f, 3.0f, 1.0f, 1.0f, 0.0f, 3.1f, 1.0f, 1.0f, 0.0f, 6.0f, 1.0f, 1.0f};
  std::vector<int> score_tensor_shape{1, 1, 6};  // batch 1, class 1, num 6
  float score_data[6] = {0.9f, 0.8f, 0.7f, 0.95f, 0.6f, 0.5f};
  int64_t max_output = 3;
  float iou_threshold = 0.5f;
  float score_threshold = 0.0f;
  int center_box_point = 1;
  auto output_size = 9;

  Init(box_tensor_shape, box_data, score_tensor_shape, score_data, max_output, iou_threshold, score_threshold,
       center_box_point);
  auto ret = kernel_->PreProcess();
  EXPECT_EQ(0, ret);
  ret = kernel_->Run();
  EXPECT_EQ(0, ret);

  std::vector<int32_t> expect{0, 0, 3, 0, 0, 0, 0, 0, 5};
  CompareOutputData(reinterpret_cast<int32_t *>(out_tensor_.data_c()), expect.data(), output_size, err_tol_);
}

}  // namespace mindspore
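Note (editorial, worked trace of TestCase1): with center_box_point = 1 the six boxes are unit squares whose centers sit at (0, 0), (0, 0.1), (0.1, 0), (0, 3), (0, 3.1) and (0, 6), with scores 0.9, 0.8, 0.7, 0.95, 0.6 and 0.5. Walking them in descending score order: box 3 is kept first, box 0 does not overlap it and is kept, boxes 1, 2 and 4 each overlap an already-kept box with IoU of about 0.9 / 1.1 (roughly 0.82, above the 0.5 threshold) and are dropped, and box 5 is kept as the third result. That yields the expected (batch, class, box) triples (0, 0, 3), (0, 0, 0), (0, 0, 5).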