!7967 optimize detection_post_process op

Merge pull request !7967 from wangzhe/dpp_refactor
pull/7996/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 691b5fdea0

@@ -38,9 +38,11 @@ typedef struct DetectionPostProcessParameter {
void *decoded_boxes_;
void *nms_candidate_;
void *indexes_;
void *scores_;
void *all_class_indexes_;
void *all_class_scores_;
void *single_class_indexes_;
void *selected_;
void *score_with_class_;
void *score_with_class_all_;
} DetectionPostProcessParameter;
#endif // MINDSPORE_LITE_NNACL_DETECTION_POST_PROCESS_PARAMETER_H_
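The struct above is now purely a bag of scratch buffers shared by the post-process routines. A minimal sketch of how a caller might size two of them before running the op; the element types and counts are assumptions read off the field names, not part of this patch, and the real kernels allocate through the lite context allocator rather than plain malloc.

#include <cstdint>
#include <cstdlib>

// Hypothetical sizing of two scratch buffers; types and counts are assumptions.
void AllocDppScratch(DetectionPostProcessParameter *param, int num_boxes) {
  param->decoded_boxes_ = malloc(num_boxes * 4 * sizeof(float));  // one decoded box = 4 floats
  param->nms_candidate_ = malloc(num_boxes * sizeof(uint8_t));    // assumed per-box candidate flag
}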

File diff suppressed because it is too large.

@@ -34,18 +34,24 @@ typedef struct {
float xmax;
} BboxCorner;
typedef struct {
float score;
int index;
} ScoreWithIndex;
#ifdef __cplusplus
extern "C" {
#endif
int DecodeBoxes(const int num_boxes, const float *input_boxes, const float *anchors,
DetectionPostProcessParameter *param);
int NmsMultiClassesFastCore(const int num_boxes, const int num_classes_with_bg, const float *input_scores,
void (*)(const float *, int *, int, int), const DetectionPostProcessParameter *param,
const int task_id, const int thread_num);
int DetectionPostProcessFast(const int num_boxes, const int num_classes_with_bg, const float *input_scores,
const float *decoded_boxes, float *output_boxes, float *output_classes,
float *output_scores, float *output_num, void (*)(const float *, int *, int, int),
const DetectionPostProcessParameter *param);
int DetectionPostProcess(const int num_boxes, const int num_classes_with_bg, float *input_boxes,
const float *input_scores, float *input_anchors, float *output_boxes, float *output_classes,
float *output_scores, float *output_num, DetectionPostProcessParameter *param);
int DetectionPostProcessRegular(const int num_boxes, const int num_classes_with_bg, const float *input_scores,
float *output_boxes, float *output_classes, float *output_scores, float *output_num,
void (*)(const float *, int *, int, int), const DetectionPostProcessParameter *param);
#ifdef __cplusplus
}
#endif
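All of the refactored entry points above take an unnamed sorting callback of type void (*)(const float *, int *, int, int). A minimal sketch of what such a callback could look like on the C++ caller side, assuming the four arguments are (scores, index buffer, number of entries to keep sorted, total number of entries); the name PartialArgSort and the argument meanings are assumptions, not something this patch states.

#include <algorithm>

// Hypothetical callback matching void (*)(const float *, int *, int, int):
// reorder `indexes` so its first `num_to_sort` entries are the highest-scoring
// ones, in descending score order. Argument meanings are assumed.
void PartialArgSort(const float *scores, int *indexes, int num_to_sort, int num_total) {
  std::partial_sort(indexes, indexes + num_to_sort, indexes + num_total,
                    [scores](int a, int b) {
                      if (scores[a] == scores[b]) {
                        return a < b;  // deterministic tie-break
                      }
                      return scores[a] > scores[b];
                    });
}

// It would then be handed to the fast path, e.g.:
//   DetectionPostProcessFast(num_boxes, num_classes_with_bg, input_scores, decoded_boxes,
//                            output_boxes, output_classes, output_scores, output_num,
//                            PartialArgSort, param);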

@@ -30,18 +30,27 @@ class DetectionPostProcessBaseCPUKernel : public LiteKernel {
DetectionPostProcessBaseCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
const mindspore::lite::PrimitiveC *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {}
: LiteKernel(parameter, inputs, outputs, ctx, primitive), thread_num_(ctx->thread_num_) {
params_ = reinterpret_cast<DetectionPostProcessParameter *>(parameter);
}
virtual ~DetectionPostProcessBaseCPUKernel();
int Init() override;
int ReSize() override;
int Run() override;
protected:
float *input_boxes = nullptr;
float *input_scores = nullptr;
int thread_num_;
int num_boxes_;
int num_classes_with_bg_;
float *input_boxes_ = nullptr;
float *input_scores_ = nullptr;
DetectionPostProcessParameter *params_ = nullptr;
protected:
virtual int GetInputData() = 0;
private:
void FreeAllocatedBuffer();
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_DETECTION_POST_PROCESS_BASE_H_
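With thread_num_ cached from the context and params_ held as a member, the base kernel has what it needs to fan the per-class NMS out over the thread pool, in the same ParallelLaunch trampoline shape as the int8 dequantization further down. A sketch of that wiring, where the trampoline name, the DoNmsFast helper, and the PartialArgSort callback from the previous sketch are assumptions, not code from this patch:

// Hypothetical trampoline: ParallelLaunch callbacks only receive (cdata, task_id),
// so the kernel object rides along as cdata and a public helper does the per-task work.
int NmsFastRun(void *cdata, int task_id) {
  auto kernel = reinterpret_cast<mindspore::kernel::DetectionPostProcessBaseCPUKernel *>(cdata);
  return kernel->DoNmsFast(task_id);  // DoNmsFast is an assumed helper, not in this patch
}

// The assumed helper would forward the cached state plus the task split to NNACL:
//   int DetectionPostProcessBaseCPUKernel::DoNmsFast(int task_id) {
//     return NmsMultiClassesFastCore(num_boxes_, num_classes_with_bg_, input_scores_,
//                                    PartialArgSort /* assumed callback */, params_,
//                                    task_id, thread_num_);
//   }
// and Run() would launch it with:
//   ParallelLaunch(this->context_->thread_pool_, NmsFastRun, this, thread_num_);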

@@ -33,8 +33,8 @@ int DetectionPostProcessCPUKernel::GetInputData() {
MS_LOG(ERROR) << "Input data type error";
return RET_ERROR;
}
input_boxes = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData());
input_scores = reinterpret_cast<float *>(in_tensors_.at(1)->MutableData());
input_boxes_ = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData());
input_scores_ = reinterpret_cast<float *>(in_tensors_.at(1)->MutableData());
return RET_OK;
}

@@ -27,8 +27,30 @@ using mindspore::schema::PrimitiveType_DetectionPostProcess;
namespace mindspore::kernel {
int DetectionPostProcessInt8CPUKernel::DequantizeInt8ToFp32(const int task_id) {
int num_unit_thread = MSMIN(thread_n_stride_, quant_size_ - task_id * thread_n_stride_);
int thread_offset = task_id * thread_n_stride_;
int ret = DoDequantizeInt8ToFp32(data_int8_ + thread_offset, data_fp32_ + thread_offset, quant_param_.scale,
quant_param_.zeroPoint, num_unit_thread);
if (ret != RET_OK) {
MS_LOG(ERROR) << "QuantDTypeCast error task_id[" << task_id << "] error_code[" << ret << "]";
return RET_ERROR;
}
return RET_OK;
}
int DequantizeInt8ToFp32Run(void *cdata, int task_id) {
auto KernelData = reinterpret_cast<DetectionPostProcessInt8CPUKernel *>(cdata);
auto ret = KernelData->DequantizeInt8ToFp32(task_id);
if (ret != RET_OK) {
MS_LOG(ERROR) << "QuantDTypeCastRun error task_id[" << task_id << "] error_code[" << ret << "]";
return RET_ERROR;
}
return RET_OK;
}
int DetectionPostProcessInt8CPUKernel::Dequantize(lite::Tensor *tensor, float **data) {
auto data_int8 = reinterpret_cast<int8_t *>(tensor->MutableData());
data_int8_ = reinterpret_cast<int8_t *>(tensor->MutableData());
*data = reinterpret_cast<float *>(context_->allocator->Malloc(tensor->ElementsNum() * sizeof(float)));
if (*data == nullptr) {
MS_LOG(ERROR) << "Malloc data failed.";
@@ -38,8 +60,17 @@ int DetectionPostProcessInt8CPUKernel::Dequantize(lite::Tensor *tensor, float **
MS_LOG(ERROR) << "null quant param";
return RET_ERROR;
}
auto quant_param = tensor->GetQuantParams().front();
DoDequantizeInt8ToFp32(data_int8, *data, quant_param.scale, quant_param.zeroPoint, tensor->ElementsNum());
quant_param_ = tensor->GetQuantParams().front();
data_fp32_ = *data;
quant_size_ = tensor->ElementsNum();
thread_n_stride_ = UP_DIV(quant_size_, op_parameter_->thread_num_);
auto ret = ParallelLaunch(this->context_->thread_pool_, DequantizeInt8ToFp32Run, this, op_parameter_->thread_num_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "QuantDTypeCastRun error error_code[" << ret << "]";
context_->allocator->Free(*data);
return RET_ERROR;
}
return RET_OK;
}
int DetectionPostProcessInt8CPUKernel::GetInputData() {
@@ -47,11 +78,11 @@ int DetectionPostProcessInt8CPUKernel::GetInputData() {
MS_LOG(ERROR) << "Input data type error";
return RET_ERROR;
}
int status = Dequantize(in_tensors_.at(0), &input_boxes);
int status = Dequantize(in_tensors_.at(0), &input_boxes_);
if (status != RET_OK) {
return status;
}
status = Dequantize(in_tensors_.at(1), &input_scores);
status = Dequantize(in_tensors_.at(1), &input_scores_);
if (status != RET_OK) {
return status;
}
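The per-task arithmetic in DequantizeInt8ToFp32 above (UP_DIV for the stride, MSMIN for the tail chunk) is the standard way work is striped across ParallelLaunch tasks. Below is a self-contained sketch of the same split together with the usual affine int8-to-fp32 dequantization that DoDequantizeInt8ToFp32 is expected to perform; it is an illustration under those assumptions, not the NNACL implementation.

#include <cstdint>

static inline int UpDiv(int a, int b) { return (a + b - 1) / b; }

// Dequantize the chunk owned by one task: offset/count mirror thread_offset and
// num_unit_thread above; the inner loop is the standard affine dequantization.
void DequantizeChunk(const int8_t *src, float *dst, float scale, int zero_point,
                     int total, int task_id, int thread_num) {
  int stride = UpDiv(total, thread_num);   // thread_n_stride_
  int offset = task_id * stride;           // thread_offset
  int count = total - offset;              // elements remaining from this offset
  if (count > stride) count = stride;      // MSMIN(stride, remaining)
  if (count <= 0) return;                  // trailing tasks may get no work
  for (int i = 0; i < count; ++i) {
    dst[offset + i] = scale * (src[offset + i] - zero_point);
  }
}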

@@ -34,9 +34,16 @@ class DetectionPostProcessInt8CPUKernel : public DetectionPostProcessBaseCPUKern
: DetectionPostProcessBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {}
~DetectionPostProcessInt8CPUKernel() = default;
int8_t *data_int8_ = nullptr;
float *data_fp32_ = nullptr;
lite::QuantArg quant_param_;
int quant_size_;
int thread_n_stride_;
int DequantizeInt8ToFp32(const int task_id);
private:
int GetInputData();
int Dequantize(lite::Tensor *tensor, float **data);
int GetInputData();
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_DETECTION_POST_PROCESS_INT8_H_
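Dequantize() now hands GetInputData() fp32 buffers taken from the context allocator, so something has to return them after Run(); in this patch that plausibly falls to the base kernel's FreeAllocatedBuffer(). A small sketch of that release side, where the helper name and the exact call site are assumptions:

// Hypothetical helper mirroring the error path inside Dequantize(): whatever the
// context allocator served for input_boxes_/input_scores_ is handed back once.
template <typename AllocatorPtr>
void ReleaseIfSet(AllocatorPtr allocator, float **buffer) {
  if (*buffer != nullptr) {
    allocator->Free(*buffer);  // same allocator that served Malloc() in Dequantize()
    *buffer = nullptr;
  }
}

// e.g. ReleaseIfSet(context_->allocator, &input_boxes_);   // allocator access as exposed by the context
//      ReleaseIfSet(context_->allocator, &input_scores_);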
