refactor dpp

pull/7967/head
wangzhe 4 years ago
parent 5e039bfaad
commit e4a3b7b60c

@@ -38,9 +38,11 @@ typedef struct DetectionPostProcessParameter {
   void *decoded_boxes_;
   void *nms_candidate_;
   void *indexes_;
+  void *scores_;
+  void *all_class_indexes_;
+  void *all_class_scores_;
+  void *single_class_indexes_;
   void *selected_;
-  void *score_with_class_;
-  void *score_with_class_all_;
 } DetectionPostProcessParameter;
 #endif  // MINDSPORE_LITE_NNACL_DETECTION_POST_PROCESS_PARAMETER_H_
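
For context on the struct change: the removed score_with_class_ buffers presumably held {score, index} pairs of the ScoreWithIndex type dropped from the fp32 header below, while the new fields split that storage into a read-only score array plus mutable index arrays. A rough sketch of the two layouts, not code from this commit:

// Old-style scratch: one element per candidate, sorted as a unit.
struct ScoreWithIndex {
  float score;
  int index;
};
// New-style scratch, presumably backing scores_/indexes_ and their
// all_class_/single_class_ variants: scores stay const while only the
// index array is permuted by the sort hook introduced below.
struct DppScratchView {  // hypothetical name, for illustration only
  const float *scores;
  int *indexes;
};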

File diff suppressed because it is too large

@@ -34,18 +34,24 @@ typedef struct {
   float xmax;
 } BboxCorner;
-typedef struct {
-  float score;
-  int index;
-} ScoreWithIndex;
 #ifdef __cplusplus
 extern "C" {
 #endif
+int DecodeBoxes(const int num_boxes, const float *input_boxes, const float *anchors,
+                DetectionPostProcessParameter *param);
+int NmsMultiClassesFastCore(const int num_boxes, const int num_classes_with_bg, const float *input_scores,
+                            void (*)(const float *, int *, int, int), const DetectionPostProcessParameter *param,
+                            const int task_id, const int thread_num);
+int DetectionPostProcessFast(const int num_boxes, const int num_classes_with_bg, const float *input_scores,
+                             const float *decoded_boxes, float *output_boxes, float *output_classes,
+                             float *output_scores, float *output_num, void (*)(const float *, int *, int, int),
+                             const DetectionPostProcessParameter *param);
-int DetectionPostProcess(const int num_boxes, const int num_classes_with_bg, float *input_boxes,
-                         const float *input_scores, float *input_anchors, float *output_boxes, float *output_classes,
-                         float *output_scores, float *output_num, DetectionPostProcessParameter *param);
+int DetectionPostProcessRegular(const int num_boxes, const int num_classes_with_bg, const float *input_scores,
+                                float *output_boxes, float *output_classes, float *output_scores, float *output_num,
+                                void (*)(const float *, int *, int, int), const DetectionPostProcessParameter *param);
 #ifdef __cplusplus
 }
 #endif
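
The monolithic DetectionPostProcess entry point is split into DecodeBoxes plus fast and regular NMS variants, and each new routine takes a sort hook of type void (*)(const float *, int *, int, int) instead of sorting score structs itself. Below is a hypothetical caller-side callback matching that signature; the parameter meaning is assumed here to be (scores, indexes, num_to_keep, num_total), which this hunk does not spell out, and the real helper in the kernel code may differ:

#include <algorithm>

// Reorders `indexes` so its first num_to_keep entries point at the highest
// scores; `scores` itself is never modified, matching the const parameter.
static void PartialArgSort(const float *scores, int *indexes, int num_to_keep, int num_total) {
  std::partial_sort(indexes, indexes + num_to_keep, indexes + num_total,
                    [scores](int a, int b) { return scores[a] > scores[b]; });
}

// A call would then look roughly like:
//   NmsMultiClassesFastCore(num_boxes, num_classes_with_bg, input_scores,
//                           PartialArgSort, param, task_id, thread_num);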

@@ -30,18 +30,27 @@ class DetectionPostProcessBaseCPUKernel : public LiteKernel {
   DetectionPostProcessBaseCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
                                     const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
                                     const mindspore::lite::PrimitiveC *primitive)
-      : LiteKernel(parameter, inputs, outputs, ctx, primitive) {}
+      : LiteKernel(parameter, inputs, outputs, ctx, primitive), thread_num_(ctx->thread_num_) {
+    params_ = reinterpret_cast<DetectionPostProcessParameter *>(parameter);
+  }
   virtual ~DetectionPostProcessBaseCPUKernel();
   int Init() override;
   int ReSize() override;
   int Run() override;
-  protected:
-  float *input_boxes = nullptr;
-  float *input_scores = nullptr;
+  int thread_num_;
+  int num_boxes_;
+  int num_classes_with_bg_;
+  float *input_boxes_ = nullptr;
+  float *input_scores_ = nullptr;
+  DetectionPostProcessParameter *params_ = nullptr;
+ protected:
   virtual int GetInputData() = 0;
+ private:
+  void FreeAllocatedBuffer();
 };
 }  // namespace mindspore::kernel
 #endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_DETECTION_POST_PROCESS_BASE_H_
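
The base kernel now caches ctx->thread_num_ and performs the DetectionPostProcessParameter cast once in the constructor, so derived kernels and thread-pool callbacks can reach the shared scratch buffers through params_ without re-casting op_parameter_. A minimal sketch of that pattern against the class above, with a hypothetical NmsFastRun callback standing in for the real per-thread work (not shown here):

// Hypothetical callback; ParallelLaunch and RET_OK are the existing
// MindSpore Lite thread-pool API and status code used elsewhere in this commit.
int NmsFastRun(void *cdata, int task_id) {
  auto kernel = reinterpret_cast<DetectionPostProcessBaseCPUKernel *>(cdata);
  // ... process the task_id-th slice of boxes via kernel->params_ ...
  return RET_OK;
}
// Inside a kernel method:
//   ParallelLaunch(this->context_->thread_pool_, NmsFastRun, this, thread_num_);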

@@ -33,8 +33,8 @@ int DetectionPostProcessCPUKernel::GetInputData() {
     MS_LOG(ERROR) << "Input data type error";
     return RET_ERROR;
   }
-  input_boxes = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData());
-  input_scores = reinterpret_cast<float *>(in_tensors_.at(1)->MutableData());
+  input_boxes_ = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData());
+  input_scores_ = reinterpret_cast<float *>(in_tensors_.at(1)->MutableData());
   return RET_OK;
 }

@@ -27,8 +27,30 @@ using mindspore::schema::PrimitiveType_DetectionPostProcess;
 namespace mindspore::kernel {
+int DetectionPostProcessInt8CPUKernel::DequantizeInt8ToFp32(const int task_id) {
+  int num_unit_thread = MSMIN(thread_n_stride_, quant_size_ - task_id * thread_n_stride_);
+  int thread_offset = task_id * thread_n_stride_;
+  int ret = DoDequantizeInt8ToFp32(data_int8_ + thread_offset, data_fp32_ + thread_offset, quant_param_.scale,
+                                   quant_param_.zeroPoint, num_unit_thread);
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "QuantDTypeCast error task_id[" << task_id << "] error_code[" << ret << "]";
+    return RET_ERROR;
+  }
+  return RET_OK;
+}
+
+int DequantizeInt8ToFp32Run(void *cdata, int task_id) {
+  auto KernelData = reinterpret_cast<DetectionPostProcessInt8CPUKernel *>(cdata);
+  auto ret = KernelData->DequantizeInt8ToFp32(task_id);
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "QuantDTypeCastRun error task_id[" << task_id << "] error_code[" << ret << "]";
+    return RET_ERROR;
+  }
+  return RET_OK;
+}
+
 int DetectionPostProcessInt8CPUKernel::Dequantize(lite::Tensor *tensor, float **data) {
-  auto data_int8 = reinterpret_cast<int8_t *>(tensor->MutableData());
+  data_int8_ = reinterpret_cast<int8_t *>(tensor->MutableData());
   *data = reinterpret_cast<float *>(context_->allocator->Malloc(tensor->ElementsNum() * sizeof(float)));
   if (*data == nullptr) {
     MS_LOG(ERROR) << "Malloc data failed.";
@@ -38,8 +60,17 @@ int DetectionPostProcessInt8CPUKernel::Dequantize(lite::Tensor *tensor, float **
     MS_LOG(ERROR) << "null quant param";
     return RET_ERROR;
   }
-  auto quant_param = tensor->GetQuantParams().front();
-  DoDequantizeInt8ToFp32(data_int8, *data, quant_param.scale, quant_param.zeroPoint, tensor->ElementsNum());
+  quant_param_ = tensor->GetQuantParams().front();
+  data_fp32_ = *data;
+  quant_size_ = tensor->ElementsNum();
+  thread_n_stride_ = UP_DIV(quant_size_, op_parameter_->thread_num_);
+  auto ret = ParallelLaunch(this->context_->thread_pool_, DequantizeInt8ToFp32Run, this, op_parameter_->thread_num_);
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "QuantDTypeCastRun error error_code[" << ret << "]";
+    context_->allocator->Free(*data);
+    return RET_ERROR;
+  }
   return RET_OK;
 }
 int DetectionPostProcessInt8CPUKernel::GetInputData() {
@@ -47,11 +78,11 @@ int DetectionPostProcessInt8CPUKernel::GetInputData() {
     MS_LOG(ERROR) << "Input data type error";
     return RET_ERROR;
   }
-  int status = Dequantize(in_tensors_.at(0), &input_boxes);
+  int status = Dequantize(in_tensors_.at(0), &input_boxes_);
   if (status != RET_OK) {
     return status;
   }
-  status = Dequantize(in_tensors_.at(1), &input_scores);
+  status = Dequantize(in_tensors_.at(1), &input_scores_);
   if (status != RET_OK) {
     return status;
   }
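
The int8 path now dequantizes in parallel: Dequantize computes a fixed stride with UP_DIV and each task handles at most one stride's worth of elements, with MSMIN clamping the final slice. A standalone illustration of that slicing arithmetic, with UP_DIV re-implemented locally and a made-up element count:

#include <algorithm>
#include <cstdio>

static int UpDiv(int a, int b) { return (a + b - 1) / b; }  // same rounding-up division as UP_DIV

int main() {
  const int quant_size = 10;  // illustrative stand-in for tensor->ElementsNum()
  const int thread_num = 4;
  const int stride = UpDiv(quant_size, thread_num);           // 3
  for (int task_id = 0; task_id < thread_num; ++task_id) {
    const int offset = task_id * stride;
    const int count = std::min(stride, quant_size - offset);  // 3, 3, 3, 1
    if (count <= 0) continue;                                 // trailing tasks may receive no work
    std::printf("task %d dequantizes elements [%d, %d)\n", task_id, offset, offset + count);
  }
  return 0;
}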

@@ -34,9 +34,16 @@ class DetectionPostProcessInt8CPUKernel : public DetectionPostProcessBaseCPUKern
       : DetectionPostProcessBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {}
   ~DetectionPostProcessInt8CPUKernel() = default;
+  int8_t *data_int8_ = nullptr;
+  float *data_fp32_ = nullptr;
+  lite::QuantArg quant_param_;
+  int quant_size_;
+  int thread_n_stride_;
+  int DequantizeInt8ToFp32(const int task_id);
  private:
-  int GetInputData();
   int Dequantize(lite::Tensor *tensor, float **data);
+  int GetInputData();
 };
 }  // namespace mindspore::kernel
 #endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_DETECTION_POST_PROCESS_INT8_H_
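
The members added here hold the per-call dequantization state (source int8 pointer, destination fp32 pointer, quantization parameters, element count, stride) so that the free-function ParallelLaunch callback in the .cc can drive DequantizeInt8ToFp32 through a plain void *cdata pointer. Per element, DoDequantizeInt8ToFp32 presumably applies the usual asymmetric dequantization; a tiny standalone example with made-up scale and zero-point values:

#include <cstdint>
#include <cstdio>

int main() {
  const float scale = 0.1f;       // illustrative quant_param_.scale
  const int32_t zero_point = -5;  // illustrative quant_param_.zeroPoint
  const int8_t q = 20;
  const float real = scale * static_cast<float>(q - zero_point);  // 0.1 * 25 = 2.5
  std::printf("%f\n", real);
  return 0;
}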
