From d277d3f5ae5471b7084db8bc977709bc4c6e5f83 Mon Sep 17 00:00:00 2001 From: hangq Date: Thu, 27 Aug 2020 14:35:38 +0800 Subject: [PATCH] fix memory leak --- .../java/com/mindspore/lite/MSTensor.java | 8 +- .../lite/java/native/runtime/context.cpp | 7 +- .../lite/java/native/runtime/ms_tensor.cpp | 24 +- mindspore/lite/src/kernel_registry.cc | 4 +- mindspore/lite/src/lite_session.cc | 10 +- mindspore/lite/src/model.cc | 8 +- mindspore/lite/src/param_value_lite.h | 4 +- .../kernel/arm/fp32_grad/arithmetic_grad.cc | 2 +- .../kernel/opencl/subgraph_opencl_kernel.cc | 8 +- .../src/runtime/opencl/opencl_executor.cc | 2 +- mindspore/lite/src/scheduler.cc | 12 +- .../tflite/tflite_parsers_test_utils.cc | 2 +- mindspore/lite/tools/benchmark/benchmark.cc | 18 +- mindspore/lite/tools/benchmark/benchmark.h | 1 - mindspore/lite/tools/common/graph_util.cc | 6 +- mindspore/lite/tools/common/node_util.h | 2 +- mindspore/lite/tools/common/tensor_util.cc | 6 +- .../lite/tools/converter/anf_transform.cc | 34 +- .../lite/tools/converter/anf_transform.h | 12 +- mindspore/lite/tools/converter/converter.cc | 39 +-- mindspore/lite/tools/converter/converter.h | 4 - .../tools/converter/graphdef_transform.cc | 8 - .../fusion/matmul_biasadd_fusion_pass.h | 2 +- .../legacy_optimizer/graph/dtype_trans_pass.h | 2 +- mindspore/lite/tools/converter/model_parser.h | 18 +- .../parser/caffe/caffe_model_parser.cc | 16 +- .../parser/caffe/caffe_model_parser.h | 2 +- .../parser/onnx/onnx_model_parser.cc | 4 +- .../converter/parser/onnx/onnx_model_parser.h | 2 +- .../parser/tflite/tflite_arithmetic_parser.cc | 38 +-- .../parser/tflite/tflite_model_parser.cc | 48 +-- .../parser/tflite/tflite_model_parser.h | 3 +- .../quantizer/post_training_quantizer.cc | 311 ++++++++---------- .../quantizer/post_training_quantizer.h | 47 ++- .../fusion/constant_folding_fusion.cc | 68 ++-- .../fusion/constant_folding_fusion.h | 5 +- 36 files changed, 396 insertions(+), 391 deletions(-) diff --git a/mindspore/lite/java/java/app/src/main/java/com/mindspore/lite/MSTensor.java b/mindspore/lite/java/java/app/src/main/java/com/mindspore/lite/MSTensor.java index 0dcf8d2049..f74178bf37 100644 --- a/mindspore/lite/java/java/app/src/main/java/com/mindspore/lite/MSTensor.java +++ b/mindspore/lite/java/java/app/src/main/java/com/mindspore/lite/MSTensor.java @@ -1,12 +1,12 @@ /** * Copyright 2020 Huawei Technologies Co., Ltd - *

+ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - *

+ * * http://www.apache.org/licenses/LICENSE-2.0 - *

+ * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -52,7 +52,7 @@ public class MSTensor { this.setDataType(this.tensorPtr, dataType); } - public byte[] getBtyeData() { + public byte[] getByteData() { return this.getByteData(this.tensorPtr); } diff --git a/mindspore/lite/java/native/runtime/context.cpp b/mindspore/lite/java/native/runtime/context.cpp index eaedc4d26a..7898eacd2f 100644 --- a/mindspore/lite/java/native/runtime/context.cpp +++ b/mindspore/lite/java/native/runtime/context.cpp @@ -18,6 +18,7 @@ #include #include "common/ms_log.h" #include "include/context.h" +#include "include/thread_pool_config.h" extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_context_Context_createContext(JNIEnv *env, jobject thiz, jint device_type, @@ -44,13 +45,13 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_context_Context_creat } switch (cpu_bind_mode) { case -1: - context->cpu_bind_mode_ = mindspore::lite::MID_CPU; + context->cpu_bind_mode_ = MID_CPU; break; case 0: - context->cpu_bind_mode_ = mindspore::lite::NO_BIND; + context->cpu_bind_mode_ = NO_BIND; break; case 1: - context->cpu_bind_mode_ = mindspore::lite::HIGHER_CPU; + context->cpu_bind_mode_ = HIGHER_CPU; break; default: MS_LOGE("Invalid cpu_bind_mode : %d", cpu_bind_mode); diff --git a/mindspore/lite/java/native/runtime/ms_tensor.cpp b/mindspore/lite/java/native/runtime/ms_tensor.cpp index d71aca9f41..3a42f810a1 100644 --- a/mindspore/lite/java/native/runtime/ms_tensor.cpp +++ b/mindspore/lite/java/native/runtime/ms_tensor.cpp @@ -118,9 +118,9 @@ extern "C" JNIEXPORT jbyteArray JNICALL Java_com_mindspore_lite_MSTensor_getByte return env->NewByteArray(0); } - auto local_data_size = ms_tensor_ptr->Size(); - auto ret = env->NewByteArray(local_data_size); - env->SetByteArrayRegion(ret, 0, local_data_size, local_data); + auto local_element_num = ms_tensor_ptr->ElementsNum(); + auto ret = env->NewByteArray(local_element_num); + env->SetByteArrayRegion(ret, 0, local_element_num, local_data); return ret; } @@ -144,9 +144,9 @@ extern "C" JNIEXPORT jlongArray JNICALL Java_com_mindspore_lite_MSTensor_getLong MS_LOGE("data type is error : %d", ms_tensor_ptr->data_type()); return env->NewLongArray(0); } - auto local_data_size = ms_tensor_ptr->Size(); - auto ret = env->NewLongArray(local_data_size); - env->SetLongArrayRegion(ret, 0, local_data_size, local_data); + auto local_element_num = ms_tensor_ptr->ElementsNum(); + auto ret = env->NewLongArray(local_element_num); + env->SetLongArrayRegion(ret, 0, local_element_num, local_data); return ret; } @@ -170,9 +170,9 @@ extern "C" JNIEXPORT jintArray JNICALL Java_com_mindspore_lite_MSTensor_getIntDa MS_LOGE("data type is error : %d", ms_tensor_ptr->data_type()); return env->NewIntArray(0); } - auto local_data_size = ms_tensor_ptr->Size(); - auto ret = env->NewIntArray(local_data_size); - env->SetIntArrayRegion(ret, 0, local_data_size, local_data); + auto local_element_num = ms_tensor_ptr->ElementsNum(); + auto ret = env->NewIntArray(local_element_num); + env->SetIntArrayRegion(ret, 0, local_element_num, local_data); return ret; } @@ -196,9 +196,9 @@ extern "C" JNIEXPORT jfloatArray JNICALL Java_com_mindspore_lite_MSTensor_getFlo MS_LOGE("data type is error : %d", ms_tensor_ptr->data_type()); return env->NewFloatArray(0); } - auto local_data_size = ms_tensor_ptr->Size(); - auto ret = env->NewFloatArray(local_data_size); - env->SetFloatArrayRegion(ret, 0, local_data_size, local_data); + auto local_element_num = ms_tensor_ptr->ElementsNum(); + auto ret = env->NewFloatArray(local_element_num); + env->SetFloatArrayRegion(ret, 0, local_element_num, local_data); return ret; } diff --git a/mindspore/lite/src/kernel_registry.cc b/mindspore/lite/src/kernel_registry.cc index d5a8f1a4d7..d283877516 100644 --- a/mindspore/lite/src/kernel_registry.cc +++ b/mindspore/lite/src/kernel_registry.cc @@ -100,8 +100,8 @@ kernel::LiteKernel *KernelRegistry::GetKernel(const std::vector &out_tensors, const PrimitiveC *primitive, const Context *ctx, const kernel::KernelKey &key) { - MS_EXCEPTION_IF_NULL(primitive); - MS_EXCEPTION_IF_NULL(ctx); + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != ctx); auto parameter = kernel::PopulateParameter(primitive); if (parameter == nullptr) { MS_LOG(ERROR) << "PopulateParameter return nullptr, type: " diff --git a/mindspore/lite/src/lite_session.cc b/mindspore/lite/src/lite_session.cc index d08ecb7ae9..7c09737fad 100644 --- a/mindspore/lite/src/lite_session.cc +++ b/mindspore/lite/src/lite_session.cc @@ -33,9 +33,9 @@ namespace mindspore { namespace lite { int LiteSession::ConvertTensors(const lite::Model *model) { - MS_EXCEPTION_IF_NULL(model); + MS_ASSERT(nullptr != model); auto meta_graph = model->GetMetaGraph(); - MS_EXCEPTION_IF_NULL(meta_graph); + MS_ASSERT(nullptr != meta_graph); uint32_t tensorCount = meta_graph->allTensors()->size(); for (uint32_t i = 0; i < tensorCount; i++) { auto *srcTensor = meta_graph->allTensors()->GetAs(i); @@ -246,7 +246,7 @@ int LiteSession::CompileGraph(Model *model) { std::vector LiteSession::GetInputs() const { return this->input_vec_; } int LiteSession::RunGraph(const session::KernelCallBack &before, const session::KernelCallBack &after) { - MS_EXCEPTION_IF_NULL(this->context_); + MS_ASSERT(this->context_); if (before == nullptr && after == nullptr) { return executor->Run(this->inputs_, this->outputs_, this->kernels_, this->context_->allocator.get()); } else { @@ -255,7 +255,7 @@ int LiteSession::RunGraph(const session::KernelCallBack &before, const session:: } int LiteSession::Init(Context *context) { - MS_EXCEPTION_IF_NULL(context); + MS_ASSERT(nullptr != context); this->context_ = new (std::nothrow) Context(context->thread_num_, context->allocator, context->device_ctx_); if (this->context_ == nullptr) { MS_LOG(ERROR) << "new context failed"; @@ -276,7 +276,7 @@ int LiteSession::Init(Context *context) { } #endif executor = new Executor(); - MS_EXCEPTION_IF_NULL(executor); + MS_ASSERT(nullptr != executor); return RET_OK; } diff --git a/mindspore/lite/src/model.cc b/mindspore/lite/src/model.cc index 75e34db8b3..797f65b2bd 100644 --- a/mindspore/lite/src/model.cc +++ b/mindspore/lite/src/model.cc @@ -101,7 +101,7 @@ int ModelImpl::BuildOps() { MS_LOG(ERROR) << "mete_graph is nullptr"; return -1; } - MS_EXCEPTION_IF_NULL(meta_graph_->nodes()); + MS_ASSERT(nullptr != meta_graph_->nodes()); for (size_t i = 0; i < meta_graph_->nodes()->size(); i++) { auto cNode = meta_graph_->nodes()->GetAs(i); auto name = cNode->name()->str(); @@ -129,17 +129,17 @@ Model *Model::Import(const char *model_buf, size_t size) { Model::~Model() { delete (this->model_impl_); } mindspore::lite::PrimitiveC *Model::GetOp(const std::string &name) const { - MS_EXCEPTION_IF_NULL(model_impl_); + MS_ASSERT(nullptr != model_impl_); return const_cast(model_impl_->GetOp(name)); } void Model::FreeMetaGraph() { - MS_EXCEPTION_IF_NULL(model_impl_); + MS_ASSERT(nullptr != model_impl_); return model_impl_->FreeMetaGraph(); } const schema::MetaGraph *Model::GetMetaGraph() const { - MS_EXCEPTION_IF_NULL(model_impl_); + MS_ASSERT(nullptr != model_impl_); return model_impl_->meta_graph(); } diff --git a/mindspore/lite/src/param_value_lite.h b/mindspore/lite/src/param_value_lite.h index f747042575..a57c4b8140 100644 --- a/mindspore/lite/src/param_value_lite.h +++ b/mindspore/lite/src/param_value_lite.h @@ -31,8 +31,8 @@ class ParamValueLite : public Value { ParamValueLite() : tensor_addr_(nullptr), tensor_size_(0) {} virtual ~ParamValueLite() { if (tensor_addr_ != nullptr) { - auto tensor_mem = reinterpret_cast(tensor_addr_); - delete tensor_mem; + auto tensor_mem = reinterpret_cast(tensor_addr_); + delete[](tensor_mem); tensor_addr_ = nullptr; tensor_size_ = 0; } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/arithmetic_grad.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/arithmetic_grad.cc index 83edebb78f..7dd8f64357 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/arithmetic_grad.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/arithmetic_grad.cc @@ -277,7 +277,7 @@ kernel::LiteKernel *CpuArithmeticGradFp32KernelCreator(const std::vectorset_allocator(allocator_); } for (auto input_kernel : kernel->in_kernels()) { - MS_EXCEPTION_IF_NULL(input_kernel); + MS_ASSERT(nullptr != input_kernel); auto ret = input_kernel->DecOutTensorRefCount(); if (0 != ret) { MS_LOG(WARNING) << "DecOutTensorRefCount for kernel" << kernel->name() << " failed"; @@ -214,21 +214,21 @@ int SubGraphOpenCLKernel::MallocTensorWithReuse() { } } for (auto kernel : out_kernels_) { - MS_EXCEPTION_IF_NULL(kernel); + MS_ASSERT(nullptr != kernel); auto ret = kernel->DecOutTensorRefCount(); if (0 != ret) { MS_LOG(WARNING) << "DecOutTensorRefCount for kernel" << kernel->name() << " failed"; } } for (auto kernel : in_convert_ops_) { - MS_EXCEPTION_IF_NULL(kernel); + MS_ASSERT(nullptr != kernel); auto ret = kernel->DecOutTensorRefCount(); if (0 != ret) { MS_LOG(WARNING) << "DecOutTensorRefCount for kernel" << kernel->name() << " failed"; } } for (auto kernel : out_convert_ops_) { - MS_EXCEPTION_IF_NULL(kernel); + MS_ASSERT(nullptr != kernel); auto ret = kernel->DecOutTensorRefCount(); if (0 != ret) { MS_LOG(WARNING) << "DecOutTensorRefCount for kernel" << kernel->name() << " failed"; diff --git a/mindspore/lite/src/runtime/opencl/opencl_executor.cc b/mindspore/lite/src/runtime/opencl/opencl_executor.cc index 500f22f52a..170e937db0 100644 --- a/mindspore/lite/src/runtime/opencl/opencl_executor.cc +++ b/mindspore/lite/src/runtime/opencl/opencl_executor.cc @@ -65,7 +65,7 @@ int OpenCLExecutor::Run(std::vector &inputs, std::vectorin_kernels()) { - MS_EXCEPTION_IF_NULL(input_kernel); + MS_ASSERT(nullptr != input_kernel); ret = input_kernel->DecOutTensorRefCount(); if (0 != ret) { MS_LOG(WARNING) << "DecOutTensorRefCount for kernel" << kernel->name() << " failed"; diff --git a/mindspore/lite/src/scheduler.cc b/mindspore/lite/src/scheduler.cc index 6232875309..1198729422 100644 --- a/mindspore/lite/src/scheduler.cc +++ b/mindspore/lite/src/scheduler.cc @@ -77,10 +77,10 @@ int Scheduler::ReSizeKernels(const std::vector &kernels) { } int Scheduler::InferShape(const lite::Model *model, std::vector *tensors) { - MS_EXCEPTION_IF_NULL(model); - MS_EXCEPTION_IF_NULL(tensors); + MS_ASSERT(nullptr != model); + MS_ASSERT(nullptr != tensors); auto meta_graph = model->GetMetaGraph(); - MS_EXCEPTION_IF_NULL(meta_graph); + MS_ASSERT(nullptr != meta_graph); bool infer_shape_interrupt = false; uint32_t kernelCount = meta_graph->nodes()->size(); for (uint32_t i = 0; i < kernelCount; i++) { @@ -121,10 +121,10 @@ int Scheduler::InferShape(const lite::Model *model, std::vector *tensors, std::vector *kernels) { - MS_EXCEPTION_IF_NULL(model); - MS_EXCEPTION_IF_NULL(tensors); + MS_ASSERT(nullptr != model); + MS_ASSERT(nullptr != tensors); auto meta_graph = model->GetMetaGraph(); - MS_EXCEPTION_IF_NULL(meta_graph); + MS_ASSERT(nullptr != meta_graph); uint32_t kernelCount = meta_graph->nodes()->size(); auto graph_output_node_indexes = GetGraphOutputNodes(meta_graph); for (uint32_t i = 0; i < kernelCount; i++) { diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_parsers_test_utils.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_parsers_test_utils.cc index 66ec6ad1ec..6f7bc265a4 100644 --- a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_parsers_test_utils.cc +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_parsers_test_utils.cc @@ -23,7 +23,7 @@ namespace mindspore { schema::MetaGraphT *TestTfliteParser::LoadAndConvert(const string &model_path, const string &weight_path) { lite::TfliteModelParser parser; - meta_graph = parser.Parse(model_path, weight_path); + meta_graph = parser.ParseToFb(model_path, weight_path); if (meta_graph == nullptr) { MS_LOG(ERROR) << "Parse to metaGraph return nullptr"; return nullptr; diff --git a/mindspore/lite/tools/benchmark/benchmark.cc b/mindspore/lite/tools/benchmark/benchmark.cc index 8b288a0b99..be5c8520da 100644 --- a/mindspore/lite/tools/benchmark/benchmark.cc +++ b/mindspore/lite/tools/benchmark/benchmark.cc @@ -107,7 +107,7 @@ int Benchmark::ReadInputFile() { } auto inputData = cur_tensor->MutableData(); memcpy(inputData, binBuf, tensorDataSize); - delete binBuf; + delete[](binBuf); } } return RET_OK; @@ -455,6 +455,12 @@ int Benchmark::RunBenchmark(const std::string &deviceType) { } if (!_flags->calibDataPath.empty()) { status = MarkAccuracy(); + for (auto &data : calibData) { + data.second->shape.clear(); + data.second->data.clear(); + delete data.second; + } + calibData.clear(); if (status != 0) { MS_LOG(ERROR) << "Run MarkAccuracy error: " << status; std::cout << "Run MarkAccuracy error: " << status << std::endl; @@ -472,16 +478,6 @@ int Benchmark::RunBenchmark(const std::string &deviceType) { return status; } } - - if (cleanData) { - for (auto &data : calibData) { - data.second->shape.clear(); - data.second->data.clear(); - delete data.second; - } - calibData.clear(); - } - delete (session); delete (model); return RET_OK; diff --git a/mindspore/lite/tools/benchmark/benchmark.h b/mindspore/lite/tools/benchmark/benchmark.h index b01b29b74c..0df9f5424b 100644 --- a/mindspore/lite/tools/benchmark/benchmark.h +++ b/mindspore/lite/tools/benchmark/benchmark.h @@ -138,7 +138,6 @@ class MS_API Benchmark { std::vector msInputs; std::unordered_map> msOutputs; std::unordered_map calibData; - bool cleanData = true; }; int MS_API RunBenchmark(int argc, const char **argv); diff --git a/mindspore/lite/tools/common/graph_util.cc b/mindspore/lite/tools/common/graph_util.cc index 71bd873582..333a20436a 100644 --- a/mindspore/lite/tools/common/graph_util.cc +++ b/mindspore/lite/tools/common/graph_util.cc @@ -35,7 +35,7 @@ OpDefCopyer GetSimpleOpCopyer() { newCNode->quantType = inCNode->quantType; newCNode->primitive = std::make_unique(); newCNode->primitive->value.type = inCNode->primitive->value.type; - return std::move(newCNode); + return newCNode; }; } @@ -96,7 +96,7 @@ std::vector GetLinkedPreIdx(const schema::MetaGraphT &graphT, const size preNodeIdx.emplace_back(i); } } - return std::move(preNodeIdx); + return preNodeIdx; } std::vector GetLinkedPostIdx(const schema::MetaGraphT &graphT, const size_t &tensorIdx) { @@ -111,7 +111,7 @@ std::vector GetLinkedPostIdx(const schema::MetaGraphT &graphT, const siz postNodeIdx.emplace_back(i); } } - return std::move(postNodeIdx); + return postNodeIdx; } STATUS IsolateNode(schema::MetaGraphT *graphT, CNodeT *node) { diff --git a/mindspore/lite/tools/common/node_util.h b/mindspore/lite/tools/common/node_util.h index 7681405454..6f619b3a24 100644 --- a/mindspore/lite/tools/common/node_util.h +++ b/mindspore/lite/tools/common/node_util.h @@ -89,7 +89,7 @@ static STATUS TransFilterData(schema::TensorT *tensor, kTransFilterType type, in MS_LOG(ERROR) << "Dim size invalid"; return RET_ERROR; } - std::unique_ptr buf(new (std::nothrow) T[count]); + std::unique_ptr buf(new (std::nothrow) T[count]); if (buf == nullptr) { MS_LOG(ERROR) << "new buf failed"; return RET_ERROR; diff --git a/mindspore/lite/tools/common/tensor_util.cc b/mindspore/lite/tools/common/tensor_util.cc index f0dfa304ed..a864518544 100644 --- a/mindspore/lite/tools/common/tensor_util.cc +++ b/mindspore/lite/tools/common/tensor_util.cc @@ -24,7 +24,7 @@ std::unique_ptr GetTensorQuantParam(const std::unique_ptr MS_ASSERT(tensor != nullptr); auto &quantParams = tensor->quantParams; if (!quantParams.empty()) { - return std::move(CopyQuantParamT(quantParams.front())); + return CopyQuantParamT(quantParams.front()); } else { return nullptr; } @@ -39,7 +39,7 @@ std::unique_ptr CopyQuantParamT(const std::unique_ptrmax = srcQuantParam->max; dstQuantParam->narrowRange = srcQuantParam->narrowRange; dstQuantParam->numBits = srcQuantParam->numBits; - return std::move(dstQuantParam); + return dstQuantParam; } size_t GetElementSize(const TensorT &tensor) { return GetElementSize(TypeId(tensor.dataType)); } @@ -87,7 +87,7 @@ std::unique_ptr CopyTensorDefT(const std::unique_ptr &oldTenso if (!oldTensor->quantParams.empty()) { newTensor->quantParams.emplace_back(std::move(GetTensorQuantParam(oldTensor))); } - return std::move(newTensor); + return newTensor; } size_t GetRefCount(MetaGraphT *graphT, uint32_t tensorIdx) { diff --git a/mindspore/lite/tools/converter/anf_transform.cc b/mindspore/lite/tools/converter/anf_transform.cc index 2501121890..fd2b0ae0e1 100644 --- a/mindspore/lite/tools/converter/anf_transform.cc +++ b/mindspore/lite/tools/converter/anf_transform.cc @@ -23,6 +23,8 @@ #include "tools/optimizer/fusion/conv_scale_fusion.h" #include "tools/optimizer/fusion/conv_bn_fusion.h" #include "tools/optimizer/fusion/constant_folding_fusion.h" +#include "tools/converter/quantizer/post_training_quantizer.h" +#include "tools/converter/quantizer/quant_cast.h" using std::string; namespace mindspore { @@ -31,10 +33,9 @@ AnfTransform::AnfTransform() = default; AnfTransform::~AnfTransform() = default; -void AnfTransform::SetGraphDef(schema::MetaGraphT *_dstDef) { graphDefT = _dstDef; } - -FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph) { - // return old_graph; +FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const converter::Flags *config) { + MS_ASSERT(nullptr != old_graph); + // fusion const_fold auto optimizer = std::make_shared(); auto pm = std::make_shared("anf fusion pass manager", false); pm->AddPass(std::make_shared()); @@ -47,6 +48,31 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph) { pm->AddPass(std::make_shared()); optimizer->AddPassManager(pm); FuncGraphPtr new_graph = optimizer->Optimize(old_graph); + + // quant + if (config != nullptr && config->quantType == schema::QuantType_PostTraining) { + this->mQuantizer = std::make_unique(new_graph, config->configFile, 8); + if (mQuantizer == nullptr) { + MS_LOG(ERROR) << "New PostTrainingQuantizer failed"; + return nullptr; + } + } + if (mQuantizer != nullptr) { + mQuantizer->flags = *config; + auto status = mQuantizer->DoQuantize(new_graph); + if (status != RET_OK) { + MS_LOG(ERROR) << "Quant failed " << status; + return nullptr; + } + quant::QuantCast quant_cast; + quant_cast.SetInputDataDType(kNumberTypeFloat32); + status = quant_cast.Run(new_graph); + if (status != RET_OK) { + MS_LOG(ERROR) << "add QuantCast error"; + return nullptr; + } + } + return new_graph; } } // namespace lite diff --git a/mindspore/lite/tools/converter/anf_transform.h b/mindspore/lite/tools/converter/anf_transform.h index 3b393a15bc..491450740d 100644 --- a/mindspore/lite/tools/converter/anf_transform.h +++ b/mindspore/lite/tools/converter/anf_transform.h @@ -17,11 +17,12 @@ #ifndef MS_ANF_TRANSFORM_H #define MS_ANF_TRANSFORM_H +#include #include "schema/inner/model_generated.h" #include "tools/common/storage.h" #include "tools/converter/converter_flags.h" #include "ir/anf.h" - +#include "tools/converter/quantizer/quantizer.h" namespace mindspore { namespace lite { @@ -29,15 +30,12 @@ class AnfTransform { public: AnfTransform(); virtual ~AnfTransform(); - FuncGraphPtr Transform(const FuncGraphPtr &old_graph); - void SetGraphDef(schema::MetaGraphT *dstDef); - inline schema::MetaGraphT *GetOutput() { return graphDefT; } + FuncGraphPtr Transform(const FuncGraphPtr &old_graph, const converter::Flags *config = nullptr); - protected: - schema::MetaGraphT *graphDefT = nullptr; + private: + std::unique_ptr mQuantizer = nullptr; }; } // namespace lite } // namespace mindspore #endif - diff --git a/mindspore/lite/tools/converter/converter.cc b/mindspore/lite/tools/converter/converter.cc index 175e9cab46..c0dae10a15 100644 --- a/mindspore/lite/tools/converter/converter.cc +++ b/mindspore/lite/tools/converter/converter.cc @@ -70,41 +70,23 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) { MS_ASSERT(nullptr != modelParser); const std::string modelFile = flag->modelFile; const std::string weightFile = flag->weightFile; - auto meta_graph = modelParser->Parse(modelFile, weightFile, flag->quantType); - if (meta_graph == nullptr) { - MS_LOG(ERROR) << "Parse to metaGraph return nullptr"; - return nullptr; - } - graph = ModelParser::Fb2Anf(meta_graph); + graph = modelParser->Parse(modelFile, weightFile, flag->quantType); } if (graph == nullptr) { MS_LOG(ERROR) << "Parser/Import model return nullptr"; return nullptr; } - graph = anfTransform->Transform(graph); - - CreateQuantizer(graph, flag); - if (mQuantizer != nullptr) { - mQuantizer->flags = *flag; - auto status = mQuantizer->DoQuantize(graph); - if (status != RET_OK) { - MS_LOG(ERROR) << "Quant failed " << status; - return nullptr; - } - quant::QuantCast quant_cast; - quant_cast.SetInputDataDType(kNumberTypeFloat32); - status = quant_cast.Run(graph); - if (status != RET_OK) { - MS_LOG(ERROR) << "add QuantCast error"; - return nullptr; - } + graph = anfTransform->Transform(graph, flag); + if (graph == nullptr) { + MS_LOG(ERROR) << "Transform anf graph return nullptr"; + return nullptr; } // anf -- fb auto meta_graph = Export(graph); if (meta_graph == nullptr) { - MS_LOG(ERROR) << "Export to meta_graph return nullptr"; + MS_LOG(ERROR) << "Export to meta graph return nullptr"; return nullptr; } @@ -113,20 +95,13 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) { transform->CreateQuantizer(flag); auto status = transform->Transform(*flag); if (status != 0) { - MS_LOG(ERROR) << "FBTransform model failed " << status; + MS_LOG(ERROR) << "Transform meta graph failed " << status; return nullptr; } return meta_graph; } -void Converter::CreateQuantizer(FuncGraphPtr func_graph, const converter::Flags *flags) { - auto type = flags->quantType; - if (type == mindspore::schema::QuantType_PostTraining) { - MS_LOG(INFO) << "create post training quantizer."; - mQuantizer.reset(new quant::PostTrainingQuantizer(func_graph, flags->configFile, 8)); - } -} int RunConverter(int argc, const char **argv) { std::unique_ptr flags(new (std::nothrow) converter::Flags); if (flags == nullptr) { diff --git a/mindspore/lite/tools/converter/converter.h b/mindspore/lite/tools/converter/converter.h index e71b34ce97..7deb65c9b7 100644 --- a/mindspore/lite/tools/converter/converter.h +++ b/mindspore/lite/tools/converter/converter.h @@ -25,7 +25,6 @@ #include "tools/anf_importer/anf_importer.h" #include "tools/converter/converter_flags.h" #include "tools/converter/anf_transform.h" -#include "tools/converter/quantizer/quantizer.h" namespace mindspore { namespace lite { @@ -34,15 +33,12 @@ class Converter { Converter(); virtual ~Converter(); virtual schema::MetaGraphT *Convert(const lite::converter::Flags *flags); - void CreateQuantizer(FuncGraphPtr func_graph, const converter::Flags *flags); - void FreeFuncGraph(const FuncGraphPtr &func_graph); protected: ModelParser *modelParser = nullptr; AnfImporter *modelImporter = nullptr; GraphDefTransform *transform = nullptr; AnfTransform *anfTransform = nullptr; - std::unique_ptr mQuantizer = nullptr; }; int RunConverter(int argc, const char **argv); diff --git a/mindspore/lite/tools/converter/graphdef_transform.cc b/mindspore/lite/tools/converter/graphdef_transform.cc index 2407e7d3fb..47dd18de33 100644 --- a/mindspore/lite/tools/converter/graphdef_transform.cc +++ b/mindspore/lite/tools/converter/graphdef_transform.cc @@ -15,15 +15,12 @@ */ #include "tools/converter/graphdef_transform.h" -#include #include #include #include "schema/model_generated.h" #include "utils/log_adapter.h" -#include "src/common/op_utils.h" #include "tools/converter/converter_flags.h" #include "tools/converter/legacy_optimizer/graph/dtype_trans_pass.h" -// #include "tools/converter/legacy_optimizer/fusion/matmul_biasadd_fusion_pass.h" #include "tools/converter/legacy_optimizer/fusion/format_trans_fusion_pass.h" #include "tools/converter/legacy_optimizer/fusion/format_trans_transpose_fusion_pass.h" #include "tools/converter/legacy_optimizer/fusion/quant_cast_fusion_pass.h" @@ -37,7 +34,6 @@ #include "tools/converter/legacy_optimizer/graph/unused_node_remove_pass.h" #include "tools/converter/legacy_optimizer/graph/topological_sort_pass.h" #include "tools/converter/quantizer/aware_quantizer.h" -#include "tools/converter/converter.h" using std::string; namespace mindspore::lite { @@ -72,7 +68,6 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { weightHardCodePass->SetFmkType(ctx.fmk); weightFormatPass->SetQuantType(ctx.quantType); weightFormatPass->SetFmkType(ctx.fmk); -// weightFormatPass->SetDstFormat(Format_KHWC); weightFormatOptimizer.AddPass(weightHardCodePass); weightFormatOptimizer.AddPass(weightFormatPass); status = weightFormatOptimizer.Run(graphDefT); @@ -153,9 +148,6 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { formatTransOptimizer.AddPass(new EltwiseFormatTransPass()); formatTransOptimizer.AddPass(new (std::nothrow) FormatTransFusionPass()); formatTransOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); - // if (ctx.quantType == QuantType_AwareTraining) { - // formatTransOptimizer.AddPass(new (std::nothrow) FormatTransNodeQuantParamFillPass()); - // } status = formatTransOptimizer.Run(graphDefT); if (status != RET_OK && status != RET_NO_CHANGE) { MS_LOG(ERROR) << "Run formatTransOptimizer graphPasses Failed"; diff --git a/mindspore/lite/tools/converter/legacy_optimizer/fusion/matmul_biasadd_fusion_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/fusion/matmul_biasadd_fusion_pass.h index 3f3a42df84..ff23fcd47f 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/fusion/matmul_biasadd_fusion_pass.h +++ b/mindspore/lite/tools/converter/legacy_optimizer/fusion/matmul_biasadd_fusion_pass.h @@ -74,7 +74,7 @@ class MatMulBiasAddFusionPass : public FusionPass { std::transform(inParam->perm.begin(), inParam->perm.end(), transposeParam->perm.begin(), [](const int32_t ele) { return ele; }); newOpDef->primitive->value.value = transposeParam; - return std::move(newOpDef); + return newOpDef; }; }; } // namespace lite diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.h index 1c1c0a7284..2b1906b6fe 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.h +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.h @@ -72,7 +72,7 @@ class DTypeTransPass : public GraphPass { QuantDTypeCastParam->srcT = oldQuantDTypeCastParam->srcT; QuantDTypeCastParam->dstT = oldQuantDTypeCastParam->dstT; newCNode->primitive->value.value = QuantDTypeCastParam; - return std::move(newCNode); + return newCNode; }; }; } // namespace lite diff --git a/mindspore/lite/tools/converter/model_parser.h b/mindspore/lite/tools/converter/model_parser.h index 56d71b33ae..f216699acb 100644 --- a/mindspore/lite/tools/converter/model_parser.h +++ b/mindspore/lite/tools/converter/model_parser.h @@ -32,16 +32,16 @@ class ModelParser { virtual ~ModelParser() {} - virtual FuncGraphPtr ParseToAnf(const std::string &modelFile, const std::string &weightFile) { - auto *meta_graph = Parse(modelFile, weightFile); - if (meta_graph == nullptr) { - MS_LOG(ERROR) << "Parse to metaGraph return nullptr"; - return nullptr; - } - return Fb2Anf(Parse(modelFile, weightFile)); + FuncGraphPtr Parse(const std::string &modelFile, const std::string &weightFile, + const QuantType &quantType = QuantType_QUANT_NONE) { + auto *meta_graph = ParseToFb(modelFile, weightFile, quantType); + auto func_graph = this->Fb2Anf(meta_graph); + delete(meta_graph); + return func_graph; } - virtual schema::MetaGraphT *Parse(const std::string &modelFile, const std::string &weightFile, - const QuantType &quantType = QuantType_QUANT_NONE) = 0; + + virtual schema::MetaGraphT *ParseToFb(const std::string &modelFile, const std::string &weightFile, + const QuantType &quantType = QuantType_QUANT_NONE) = 0; public: static FuncGraphPtr Fb2Anf(schema::MetaGraphT *meta_graph) { diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc index 051bb03ccb..387337c67b 100644 --- a/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc @@ -31,7 +31,7 @@ CaffeModelParser::~CaffeModelParser() {} const std::set CaffeModelParser::skipedLayerType = {"Dropout"}; -schema::MetaGraphT *CaffeModelParser::Parse(const std::string &modelFile, +schema::MetaGraphT *CaffeModelParser::ParseToFb(const std::string &modelFile, const std::string &weightFile, const QuantType &quantType) { if (ValidateFileStr(modelFile, ".prototxt") != RET_OK) { @@ -49,7 +49,7 @@ schema::MetaGraphT *CaffeModelParser::Parse(const std::string &modelFile, return nullptr; } - std::unique_ptr subGraphDef = std::make_unique(); + auto metaGraph = std::make_unique(); TensorCache tensorCache; caffe::NetParameter proto; @@ -57,7 +57,7 @@ schema::MetaGraphT *CaffeModelParser::Parse(const std::string &modelFile, MS_LOG(ERROR) << "Read prototxt file failed, model path: " << modelFile; return nullptr; } - subGraphDef->name = proto.name(); + metaGraph->name = proto.name(); caffe::NetParameter weight; if (ReadProtoFromBinaryFile((const char *)weightFile.c_str(), &weight) != RET_OK) { @@ -71,22 +71,22 @@ schema::MetaGraphT *CaffeModelParser::Parse(const std::string &modelFile, return nullptr; } - status = ParseLayer(proto, weight, &tensorCache, subGraphDef.get()); + status = ParseLayer(proto, weight, &tensorCache, metaGraph.get()); if (status != RET_OK) { MS_LOG(ERROR) << "ParseLayer failed " << status; return nullptr; } - status = SetGraphTensorIndex(proto, &tensorCache, subGraphDef.get()); + status = SetGraphTensorIndex(proto, &tensorCache, metaGraph.get()); if (status != RET_OK) { MS_LOG(ERROR) << "Set inputTensor index and outputTensor index for graph failed!"; return nullptr; } - subGraphDef->name = GetModelName(modelFile); + metaGraph->name = GetModelName(modelFile); - SetAllTensors(tensorCache, subGraphDef.get()); + SetAllTensors(tensorCache, metaGraph.get()); - return subGraphDef.release(); + return metaGraph.release(); } STATUS CaffeModelParser::SetOpInputIdx(const caffe::LayerParameter &layer, diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.h b/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.h index e475c06875..ffea8e6aaa 100644 --- a/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.h +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.h @@ -33,7 +33,7 @@ class CaffeModelParser : public ModelParser { virtual ~CaffeModelParser(); - MetaGraphT *Parse(const std::string &modelFile, const std::string &weightFile, + schema::MetaGraphT *ParseToFb(const std::string &modelFile, const std::string &weightFile, const QuantType &quantType = QuantType_QUANT_NONE) override; private: diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc index b6edf4ebd4..926573940f 100755 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc @@ -507,14 +507,14 @@ void OnnxModelParser::FindGraphInputAndConst(const onnx::GraphProto &onnx_graph) } } -MetaGraphT *OnnxModelParser::Parse(const std::string &modelFile, +schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, const std::string &weightFile, const QuantType &quantType) { if (ValidateFileStr(modelFile, ".onnx") != RET_OK) { MS_LOG(ERROR) << "Input illegal: modelFile must be *.onnx"; return nullptr; } - std::unique_ptr dst_graph = std::make_unique(); + auto dst_graph = std::make_unique(); onnx::ModelProto onnx_model; if (ReadOnnxModelFromBinary(modelFile, &onnx_model) != RET_OK) { MS_LOG(ERROR) << "read onnx model fail"; diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h index 838b82f727..4ce0615bd5 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h @@ -40,7 +40,7 @@ class OnnxModelParser : public ModelParser { virtual ~OnnxModelParser(); - MetaGraphT *Parse(const std::string &modelFile, const std::string &weightFile, + schema::MetaGraphT *ParseToFb(const std::string &modelFile, const std::string &weightFile, const QuantType &quantType = QuantType_QUANT_NONE) override; private: diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_arithmetic_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_arithmetic_parser.cc index c615a795b2..8db54f0497 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_arithmetic_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_arithmetic_parser.cc @@ -44,7 +44,7 @@ STATUS TfliteDoubleInputOpParser::Parse(const std::unique_ptr const char *node_name = node_name_str.data()->c_str(); if (std::strcmp(node_name, "Add") == 0) { MS_LOG(DEBUG) << "parse TfliteAddParser"; - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; return RET_NULL_PTR; @@ -59,7 +59,7 @@ STATUS TfliteDoubleInputOpParser::Parse(const std::unique_ptr op->primitive->value.value = attr.release(); } else if (std::strcmp(node_name, "Sub") == 0) { MS_LOG(DEBUG) << "parse TfliteSubParser"; - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; return RET_NULL_PTR; @@ -74,7 +74,7 @@ STATUS TfliteDoubleInputOpParser::Parse(const std::unique_ptr op->primitive->value.value = attr.release(); } else if (std::strcmp(node_name, "Mul") == 0) { MS_LOG(DEBUG) << "parse TfliteMulParser"; - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; return RET_NULL_PTR; @@ -89,7 +89,7 @@ STATUS TfliteDoubleInputOpParser::Parse(const std::unique_ptr op->primitive->value.value = attr.release(); } else if (std::strcmp(node_name, "Div") == 0) { MS_LOG(DEBUG) << "parse TfliteDivParser"; - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; return RET_NULL_PTR; @@ -113,7 +113,7 @@ STATUS TfliteDoubleInputOpParser::Parse(const std::unique_ptr op->primitive->value.value = attr.release(); } else if (std::strcmp(node_name, "FloorMod") == 0) { MS_LOG(DEBUG) << "parse TfliteFloorModParser"; - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; return RET_NULL_PTR; @@ -131,7 +131,7 @@ STATUS TfliteDoubleInputOpParser::Parse(const std::unique_ptr op->primitive->value.value = attr.release(); } else if (std::strcmp(node_name, "SquaredDifference") == 0) { MS_LOG(DEBUG) << "parse TfliteSquaredDifferenceParser"; - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; return RET_NULL_PTR; @@ -140,7 +140,7 @@ STATUS TfliteDoubleInputOpParser::Parse(const std::unique_ptr op->primitive->value.value = attr.release(); } else if (std::strcmp(node_name, "Pow") == 0) { MS_LOG(DEBUG) << "parse TflitePowParser"; - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; return RET_NULL_PTR; @@ -152,7 +152,7 @@ STATUS TfliteDoubleInputOpParser::Parse(const std::unique_ptr op->primitive->value.value = attr.release(); } else if (std::strcmp(node_name, "Maximum") == 0) { MS_LOG(DEBUG) << "parse TfliteMaximumParser"; - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; return RET_NULL_PTR; @@ -161,7 +161,7 @@ STATUS TfliteDoubleInputOpParser::Parse(const std::unique_ptr op->primitive->value.value = attr.release(); } else if (std::strcmp(node_name, "Minimum") == 0) { MS_LOG(DEBUG) << "parse TfliteMinimumParser"; - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; return RET_NULL_PTR; @@ -202,7 +202,7 @@ STATUS TfliteSingleInputOpParser::Parse(const std::unique_ptr const char *node_name = node_name_str.data()->c_str(); if (std::strcmp(node_name, "Abs") == 0) { MS_LOG(DEBUG) << "parse TfliteAbsParser"; - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; return RET_NULL_PTR; @@ -211,7 +211,7 @@ STATUS TfliteSingleInputOpParser::Parse(const std::unique_ptr op->primitive->value.value = attr.release(); } else if (std::strcmp(node_name, "Exp") == 0) { MS_LOG(DEBUG) << "parse TfliteExpParser"; - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; return RET_NULL_PTR; @@ -220,7 +220,7 @@ STATUS TfliteSingleInputOpParser::Parse(const std::unique_ptr op->primitive->value.value = attr.release(); } else if (std::strcmp(node_name, "Sqrt") == 0) { MS_LOG(DEBUG) << "parse TfliteSqrtParser"; - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; return RET_NULL_PTR; @@ -229,7 +229,7 @@ STATUS TfliteSingleInputOpParser::Parse(const std::unique_ptr op->primitive->value.value = attr.release(); } else if (std::strcmp(node_name, "Rsqrt") == 0) { MS_LOG(DEBUG) << "parse TfliteRsqrtParser"; - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; return RET_NULL_PTR; @@ -238,7 +238,7 @@ STATUS TfliteSingleInputOpParser::Parse(const std::unique_ptr op->primitive->value.value = attr.release(); } else if (std::strcmp(node_name, "Square") == 0) { MS_LOG(DEBUG) << "parse TfliteSquareParser"; - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; return RET_NULL_PTR; @@ -247,7 +247,7 @@ STATUS TfliteSingleInputOpParser::Parse(const std::unique_ptr op->primitive->value.value = attr.release(); } else if (std::strcmp(node_name, "Sin") == 0) { MS_LOG(DEBUG) << "parse TfliteSinParser"; - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; return RET_NULL_PTR; @@ -265,7 +265,7 @@ STATUS TfliteSingleInputOpParser::Parse(const std::unique_ptr op->primitive->value.value = attr.release(); } else if (std::strcmp(node_name, "Log") == 0) { MS_LOG(DEBUG) << "parse TfliteLogParser"; - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; return RET_NULL_PTR; @@ -274,7 +274,7 @@ STATUS TfliteSingleInputOpParser::Parse(const std::unique_ptr op->primitive->value.value = attr.release(); } else if (std::strcmp(node_name, "Round") == 0) { MS_LOG(DEBUG) << "parse TfliteRoundParser"; - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; return RET_NULL_PTR; @@ -283,7 +283,7 @@ STATUS TfliteSingleInputOpParser::Parse(const std::unique_ptr op->primitive->value.value = attr.release(); } else if (std::strcmp(node_name, "Ceil") == 0) { MS_LOG(DEBUG) << "parse TfliteCeilParser"; - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; return RET_NULL_PTR; @@ -292,7 +292,7 @@ STATUS TfliteSingleInputOpParser::Parse(const std::unique_ptr op->primitive->value.value = attr.release(); } else if (std::strcmp(node_name, "flOOR") == 0) { MS_LOG(DEBUG) << "parse TfliteFloorParser"; - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; return RET_NULL_PTR; diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc index 0e9ddad40d..e4c2a5086f 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc @@ -28,26 +28,25 @@ namespace mindspore { namespace lite { TfliteModelParser::TfliteModelParser() = default; -TfliteModelParser::~TfliteModelParser() = default; +TfliteModelParser::~TfliteModelParser() { delete[](this->tfliteModelBuf); } std::unique_ptr TfliteModelParser::ReadTfliteModel(const char *model_path) { size_t size; - auto buf = ReadFile(model_path, &size); - if (buf == nullptr) { + tfliteModelBuf = ReadFile(model_path, &size); + if (tfliteModelBuf == nullptr) { MS_LOG(ERROR) << "the file buffer is nullptr"; return nullptr; } - flatbuffers::Verifier verify((const uint8_t *)buf, size); + flatbuffers::Verifier verify((const uint8_t *)tfliteModelBuf, size); if (!tflite::VerifyModelBuffer(verify)) { MS_LOG(ERROR) << "the buffer is invalid and fail to create graph"; return nullptr; } - return tflite::UnPackModel(buf); + return tflite::UnPackModel(tfliteModelBuf); } STATUS TfliteModelParser::CopyConstTensorData(const std::vector> &tflite_model_buffer, - const tflite::TensorT *tflite_tensor, - schema::TensorT *tensor) { + const tflite::TensorT *tflite_tensor, schema::TensorT *tensor) { auto count = 1; std::for_each(tflite_tensor->shape.begin(), tflite_tensor->shape.end(), [&](int32_t sha) { count *= sha; }); auto data_size = count * GetDataTypeSize(TypeId(tensor->dataType)); @@ -95,8 +94,7 @@ void TfliteModelParser::SetTensorQuantParam(const std::unique_ptr &tflite_model, const std::unique_ptr &tflite_subgraph, - const QuantType &quant_type, - schema::MetaGraphT *sub_graph) { + const QuantType &quant_type, schema::MetaGraphT *sub_graph) { int idx = 0; for (const auto &tflite_op : tflite_subgraph->operators) { auto tflite_op_type = (tflite_model->operator_codes[tflite_op->opcode_index])->builtin_code; @@ -107,7 +105,7 @@ STATUS TfliteModelParser::ConvertOp(const std::unique_ptr &tflit return RET_ERROR; } - std::unique_ptr op = std::make_unique(); + auto op = std::make_unique(); op->name = op_type + "-" + std::to_string(idx++); op->quantType = quant_type; MS_LOG(INFO) << "parse op: " << op->name.c_str(); @@ -227,7 +225,7 @@ STATUS TfliteModelParser::GetGraphInfo(const std::unique_ptr return RET_OK; } -STATUS TfliteModelParser::ConvertGroupDepthwiseOp(schema::MetaGraphT* sub_graph) { +STATUS TfliteModelParser::ConvertGroupDepthwiseOp(schema::MetaGraphT *sub_graph) { for (auto &op : sub_graph->nodes) { if (op->primitive->value.type == schema::PrimitiveType_DepthwiseConv2D) { auto attr = op->primitive->value.AsDepthwiseConv2D(); @@ -301,15 +299,10 @@ STATUS TfliteModelParser::ConvertGroupDepthwiseOp(schema::MetaGraphT* sub_graph) return RET_OK; } -MetaGraphT *TfliteModelParser::Parse(const std::string &model_file, - const std::string &weight_file, - const QuantType &quant_type) { - std::unique_ptr sub_graph = std::make_unique(); - sub_graph->name = "MS_model converted by TF-Lite"; - quantType = quant_type; - +schema::MetaGraphT *TfliteModelParser::ParseToFb(const std::string &model_file, const std::string &weight_file, + const QuantType &quant_type) { // load graph - std::unique_ptr tflite_model = ReadTfliteModel(model_file.c_str()); + auto tflite_model = ReadTfliteModel(model_file.c_str()); if (tflite_model == nullptr) { MS_LOG(ERROR) << "read tflite model failed"; return nullptr; @@ -321,31 +314,38 @@ MetaGraphT *TfliteModelParser::Parse(const std::string &model_file, } const auto &tflite_subgraph = tflite_model->subgraphs[0]; + auto meta_graph = std::make_unique(); + if (meta_graph == nullptr) { + MS_LOG(ERROR) << "new meta graph failed"; + return nullptr; + } + meta_graph->name = "MS_model converted by TF-Lite"; + quantType = quant_type; // convert op - if (ConvertOp(tflite_model, tflite_subgraph, quant_type, sub_graph.get()) != RET_OK) { + if (ConvertOp(tflite_model, tflite_subgraph, quant_type, meta_graph.get()) != RET_OK) { MS_LOG(ERROR) << "parse op failed."; return nullptr; } // convert tensor - if (ConvertTensor(tflite_subgraph, tflite_model->buffers, sub_graph.get()) != RET_OK) { + if (ConvertTensor(tflite_subgraph, tflite_model->buffers, meta_graph.get()) != RET_OK) { MS_LOG(ERROR) << "convert tensor failed"; return nullptr; } // set graph input/output - if (GetGraphInfo(tflite_subgraph, sub_graph.get()) != RET_OK) { + if (GetGraphInfo(tflite_subgraph, meta_graph.get()) != RET_OK) { MS_LOG(ERROR) << "convert tensors failed"; return nullptr; } // update for depthwiseConv - if (ConvertGroupDepthwiseOp(sub_graph.get()) != RET_OK) { + if (ConvertGroupDepthwiseOp(meta_graph.get()) != RET_OK) { MS_LOG(ERROR) << "convert group depthwise conv failed"; return nullptr; } - return sub_graph.release(); + return meta_graph.release(); } } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h index 71e28c3c88..38dbe95592 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h @@ -41,7 +41,7 @@ class TfliteModelParser : public ModelParser { ~TfliteModelParser() override; - MetaGraphT *Parse(const std::string &model_file, + schema::MetaGraphT *ParseToFb(const std::string &model_file, const std::string &weight_file, const QuantType &quantType = QuantType_QUANT_NONE) override; @@ -78,6 +78,7 @@ class TfliteModelParser : public ModelParser { std::map opMap; std::map tfliteOpMap; QuantType quantType = QuantType_QUANT_NONE; + char *tfliteModelBuf = nullptr; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc b/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc index 850af3bcf8..9d7a8b4478 100644 --- a/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc +++ b/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc @@ -41,200 +41,173 @@ using std::vector; namespace mindspore { namespace lite { namespace quant { -struct DivergInfo { - std::vector histogram; - CNodePtr cnode; - int bin_num; - float interval = 0; - float max; - float min; - float best_T = 0.0f; - size_t bit_num; - int quant_max = 255; - int quant_min = 0; - std::string method_x = kMethodKL; - - DivergInfo(CNodePtr cnode, int bins, size_t bits, int quant_max, int quant_min, const std::string &method_x) { - this->method_x = method_x; - this->cnode = cnode; - this->bin_num = bins; - this->bit_num = bits; - histogram.resize(bin_num); - max = -FLT_MAX; - min = FLT_MAX; - this->quant_max = quant_max; - this->quant_min = quant_min; - std::fill(histogram.begin(), histogram.end(), 1.0e-7); +STATUS DivergInfo::RecordMaxValue(const std::vector &datas) { + for (float data : datas) { + max = std::max(data, max); + min = std::min(data, min); } + return RET_OK; +} - STATUS RecordMaxValue(const std::vector &datas) { - for (float data : datas) { - max = std::max(data, max); - min = std::min(data, min); +void DivergInfo::UpdateInterval() { + auto max_value = std::max(fabs(this->max), fabs(this->min)); + this->interval = max_value / static_cast(bin_num); +} + +STATUS DivergInfo::UpdateHistogram(const std::vector &data) { + for (auto value : data) { + if (value == 0) { + continue; } - return RET_OK; + int bin_index = std::min(static_cast(std::fabs(value) / this->interval), bin_num - 1); + this->histogram[bin_index]++; } + return RET_OK; +} - void UpdateInterval() { - auto max_value = std::max(fabs(this->max), fabs(this->min)); - this->interval = max_value / static_cast(bin_num); +void DivergInfo::DumpHistogram() { + MS_LOG(INFO) << "Print node " << cnode->fullname_with_scope() << " histogram"; + for (float item : this->histogram) { + std::cout << item << " "; } + std::cout << std::endl; +} - STATUS UpdateHistogram(const std::vector &data) { - for (auto value : data) { - if (value == 0) { - continue; - } - int bin_index = std::min(static_cast(std::fabs(value) / this->interval), bin_num - 1); - this->histogram[bin_index]++; - } +STATUS DivergInfo::ComputeThreshold() { + if (method_x == kMethodMaxMin) { + this->best_T = std::max(fabs(this->max), fabs(this->min)); + MS_LOG(DEBUG) << "using MAX_MIN, T: " << this->best_T; return RET_OK; } - void DumpHistogram() { - MS_LOG(INFO) << "Print node " << cnode->fullname_with_scope() << " histogram"; - for (float item : this->histogram) { - std::cout << item << " "; - } - std::cout << std::endl; - } - - STATUS ComputeThreshold() { - if (method_x == kMethodMaxMin) { - this->best_T = std::max(fabs(this->max), fabs(this->min)); - MS_LOG(DEBUG) << "using MAX_MIN, T: " << this->best_T; - return RET_OK; + constexpr int quant_bint_nums = 128; + int threshold = quant_bint_nums; + float min_kl = FLT_MAX; + float after_threshold_sum = std::accumulate(this->histogram.begin() + quant_bint_nums, this->histogram.end(), 0.0f); + + for (int i = quant_bint_nums; i < this->bin_num; ++i) { + std::vector quantized_histogram(quant_bint_nums, 0); + std::vector reference_histogram(this->histogram.begin(), this->histogram.begin() + i); + std::vector expanded_histogram(i, 0); + reference_histogram[i - 1] += after_threshold_sum; + after_threshold_sum -= this->histogram[i]; + + const float bin_interval = static_cast(i) / static_cast(quant_bint_nums); + + // merge i bins to target bins + for (int j = 0; j < quant_bint_nums; ++j) { + const float start = j * bin_interval; + const float end = start + bin_interval; + const int left_upper = static_cast(std::ceil(start)); + if (left_upper > start) { + const double left_scale = left_upper - start; + quantized_histogram[j] += left_scale * this->histogram[left_upper - 1]; + } + const int right_lower = static_cast(std::floor(end)); + if (right_lower < end) { + const double right_scale = end - right_lower; + quantized_histogram[j] += right_scale * this->histogram[right_lower]; + } + std::for_each(this->histogram.begin() + left_upper, this->histogram.begin() + right_lower, + [&quantized_histogram, j](float item) { quantized_histogram[j] += item; }); } - - constexpr int quant_bint_nums = 128; - int threshold = quant_bint_nums; - float min_kl = FLT_MAX; - float after_threshold_sum = std::accumulate(this->histogram.begin() + quant_bint_nums, this->histogram.end(), 0.0f); - - for (int i = quant_bint_nums; i < this->bin_num; ++i) { - std::vector quantized_histogram(quant_bint_nums, 0); - std::vector reference_histogram(this->histogram.begin(), this->histogram.begin() + i); - std::vector expanded_histogram(i, 0); - reference_histogram[i - 1] += after_threshold_sum; - after_threshold_sum -= this->histogram[i]; - - const float bin_interval = static_cast(i) / static_cast(quant_bint_nums); - - // merge i bins to target bins - for (int j = 0; j < quant_bint_nums; ++j) { - const float start = j * bin_interval; - const float end = start + bin_interval; - const int left_upper = static_cast(std::ceil(start)); - if (left_upper > start) { - const double left_scale = left_upper - start; - quantized_histogram[j] += left_scale * this->histogram[left_upper - 1]; + // expand target bins to i bins in order to calculate KL with reference_histogram + for (int j = 0; j < quant_bint_nums; ++j) { + const float start = j * bin_interval; + const float end = start + bin_interval; + float count = 0; + const int left_upper = static_cast(std::ceil(start)); + float left_scale = 0.0f; + if (left_upper > start) { + left_scale = left_upper - start; + if (this->histogram[left_upper - 1] != 0) { + count += left_scale; } - const int right_lower = static_cast(std::floor(end)); - if (right_lower < end) { - const double right_scale = end - right_lower; - quantized_histogram[j] += right_scale * this->histogram[right_lower]; - } - std::for_each(this->histogram.begin() + left_upper, this->histogram.begin() + right_lower, - [&quantized_histogram, j](float item) { quantized_histogram[j] += item; }); } - // expand target bins to i bins in order to calculate KL with reference_histogram - for (int j = 0; j < quant_bint_nums; ++j) { - const float start = j * bin_interval; - const float end = start + bin_interval; - float count = 0; - const int left_upper = static_cast(std::ceil(start)); - float left_scale = 0.0f; - if (left_upper > start) { - left_scale = left_upper - start; - if (this->histogram[left_upper - 1] != 0) { - count += left_scale; - } - } - const int right_lower = static_cast(std::floor(end)); - double right_scale = 0.0f; - if (right_lower < end) { - right_scale = end - right_lower; - if (this->histogram[right_lower] != 0) { - count += right_scale; - } - } - std::for_each(this->histogram.begin() + left_upper, this->histogram.begin() + right_lower, - [&count](float item) { - if (item != 0) { - count += 1; - } - }); - if (count == 0) { - continue; - } - const float average_num = quantized_histogram[j] / count; - if (left_upper > start && this->histogram[left_upper - 1] != 0) { - expanded_histogram[left_upper - 1] += average_num * left_scale; + const int right_lower = static_cast(std::floor(end)); + double right_scale = 0.0f; + if (right_lower < end) { + right_scale = end - right_lower; + if (this->histogram[right_lower] != 0) { + count += right_scale; } - if (right_lower < end && this->histogram[right_lower] != 0) { - expanded_histogram[right_lower] += average_num * right_scale; + } + std::for_each(this->histogram.begin() + left_upper, this->histogram.begin() + right_lower, [&count](float item) { + if (item != 0) { + count += 1; } - for (int k = left_upper; k < right_lower; ++k) { - if (this->histogram[k] != 0) { - expanded_histogram[k] += average_num; - } + }); + if (count == 0) { + continue; + } + const float average_num = quantized_histogram[j] / count; + if (left_upper > start && this->histogram[left_upper - 1] != 0) { + expanded_histogram[left_upper - 1] += average_num * left_scale; + } + if (right_lower < end && this->histogram[right_lower] != 0) { + expanded_histogram[right_lower] += average_num * right_scale; + } + for (int k = left_upper; k < right_lower; ++k) { + if (this->histogram[k] != 0) { + expanded_histogram[k] += average_num; } } - auto KLDivergence = [](std::vector p, std::vector q) { - auto sum = 0.0f; - std::for_each(p.begin(), p.end(), [&sum](float item) { sum += item; }); - std::for_each(p.begin(), p.end(), [sum](float &item) { item /= sum; }); - sum = 0.0f; - std::for_each(q.begin(), q.end(), [&sum](float item) { sum += item; }); - std::for_each(q.begin(), q.end(), [sum](float &item) { item /= sum; }); - - float result = 0.0f; - const int size = p.size(); - for (int i = 0; i < size; ++i) { - if (p[i] != 0) { - if (q[i] == 0) { - result += 1.0f; - } else { - result += (p[i] * std::log((p[i]) / (q[i]))); - } + } + auto KLDivergence = [](std::vector p, std::vector q) { + auto sum = 0.0f; + std::for_each(p.begin(), p.end(), [&sum](float item) { sum += item; }); + std::for_each(p.begin(), p.end(), [sum](float &item) { item /= sum; }); + sum = 0.0f; + std::for_each(q.begin(), q.end(), [&sum](float item) { sum += item; }); + std::for_each(q.begin(), q.end(), [sum](float &item) { item /= sum; }); + + float result = 0.0f; + const int size = p.size(); + for (int i = 0; i < size; ++i) { + if (p[i] != 0) { + if (q[i] == 0) { + result += 1.0f; + } else { + result += (p[i] * std::log((p[i]) / (q[i]))); } } - return result; - }; - const float kl = KLDivergence(reference_histogram, expanded_histogram); - if (kl < min_kl) { - min_kl = kl; - threshold = i; } + return result; + }; + const float kl = KLDivergence(reference_histogram, expanded_histogram); + if (kl < min_kl) { + min_kl = kl; + threshold = i; } - this->best_T = (static_cast(threshold) + 0.5f) * this->interval; - MS_LOG(DEBUG) << cnode->fullname_with_scope() << " Best threshold bin index: " << threshold << " T: " << best_T - << " max: " << std::max(fabs(this->max), fabs(this->min)); - return RET_OK; } + this->best_T = (static_cast(threshold) + 0.5f) * this->interval; + MS_LOG(DEBUG) << cnode->fullname_with_scope() << " Best threshold bin index: " << threshold << " T: " << best_T + << " max: " << std::max(fabs(this->max), fabs(this->min)); + return RET_OK; +} - std::pair GetScale() { - float max_value = this->best_T; - float min_value = -max_value; +std::pair DivergInfo::GetScale() { + float max_value = this->best_T; + float min_value = -max_value; - MS_ASSERT(quant_max - quant_min != 0); - float scale = (max_value - min_value) / (quant_max - quant_min); - MS_ASSERT(scale != 0); - return std::make_pair(this->cnode, scale); - } + MS_ASSERT(quant_max - quant_min != 0); + float scale = (max_value - min_value) / (quant_max - quant_min); + MS_ASSERT(scale != 0); + return std::make_pair(this->cnode, scale); +} - std::pair GetZeropoint() { - int zero_point = 0; - if (quant_min == 0 && quant_max == 255) { - zero_point = 128; - } else if (quant_min == -127 && quant_max == 127) { - zero_point = 0; - } else { - MS_LOG(WARNING) << "unexpectd quant range, quant_min: " << quant_min << " quant_max: " << quant_max; - } - return std::make_pair(this->cnode, zero_point); +std::pair DivergInfo::GetZeropoint() { + int zero_point = 0; + if (quant_min == 0 && quant_max == 255) { + zero_point = 128; + } else if (quant_min == -127 && quant_max == 127) { + zero_point = 0; + } else { + MS_LOG(WARNING) << "unexpectd quant range, quant_min: " << quant_min << " quant_max: " << quant_max; } -}; + return std::make_pair(this->cnode, zero_point); +} + std::unordered_map Calibrator::GetScale( std::unordered_map> *diverg_info) { std::unordered_map result; @@ -359,7 +332,7 @@ STATUS Calibrator::AddQuantizedOp(CNodePtr node) { void Calibrator::AddImage(const string file) { auto exist = [](const string file) { - struct stat buf{}; + struct stat buf {}; return stat(file.c_str(), &buf) == 0; }; if (exist(file)) { diff --git a/mindspore/lite/tools/converter/quantizer/post_training_quantizer.h b/mindspore/lite/tools/converter/quantizer/post_training_quantizer.h index 7729fcd79a..e8e6be2d6d 100644 --- a/mindspore/lite/tools/converter/quantizer/post_training_quantizer.h +++ b/mindspore/lite/tools/converter/quantizer/post_training_quantizer.h @@ -23,6 +23,7 @@ #include #include #include +#include #include "src/lite_session.h" #include "tools/converter/quantizer/quantizer.h" #include "tools/converter/converter.h" @@ -90,13 +91,51 @@ class PostTrainingQuantizer : public Quantizer { STATUS DoQuantInput(double scale, int32_t zeropoint, struct MaxMin *max_min, std::shared_ptr); STATUS DoQuantOutput(double scale, int32_t zeropoint, struct MaxMin *max_min, std::shared_ptr); - STATUS DoWeightQuant(AnfNodePtr weight, std::shared_ptr primitive_c, bool perchannel, - bool depthwise); + STATUS DoWeightQuant(AnfNodePtr weight, std::shared_ptr primitive_c, bool perchannel, bool depthwise); STATUS DoBiasQuant(AnfNodePtr bias, std::shared_ptr primitive_c); }; -struct DivergInfo; +struct DivergInfo { + std::vector histogram; + CNodePtr cnode; + int bin_num; + float interval = 0; + float max; + float min; + float best_T = 0.0f; + size_t bit_num; + int quant_max = 255; + int quant_min = 0; + std::string method_x = kMethodKL; + + DivergInfo(CNodePtr cnode, int bins, size_t bits, int quant_max, int quant_min, const std::string &method_x) { + this->method_x = method_x; + this->cnode = cnode; + this->bin_num = bins; + this->bit_num = bits; + histogram.resize(bin_num); + max = -FLT_MAX; + min = FLT_MAX; + this->quant_max = quant_max; + this->quant_min = quant_min; + std::fill(histogram.begin(), histogram.end(), 1.0e-7); + } + + STATUS RecordMaxValue(const std::vector &datas); + + void UpdateInterval(); + + STATUS UpdateHistogram(const std::vector &data); + + void DumpHistogram(); + + STATUS ComputeThreshold(); + + std::pair GetScale(); + + std::pair GetZeropoint(); +}; class Calibrator { public: @@ -123,7 +162,7 @@ class Calibrator { STATUS UpdateDivergInverval(std::unordered_map> *diverg_info); - STATUS UpdateDataFrequency(const std::string& op_name, const std::vector& data, + STATUS UpdateDataFrequency(const std::string &op_name, const std::vector &data, std::unordered_map> *diverg_info); void Dump(); diff --git a/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc b/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc index 55537d41b8..bb1b9859e8 100644 --- a/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc +++ b/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc @@ -13,11 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + #include "tools/optimizer/fusion/constant_folding_fusion.h" #include #include #include -#include "schema/inner/model_generated.h" #include "tools/optimizer/common/gllo_utils.h" #include "tools/anf_exporter/anf_exporter.h" #include "src/kernel_registry.h" @@ -30,7 +30,7 @@ using mindspore::lite::PrimitiveC; using mindspore::lite::tensor::Tensor; namespace mindspore::opt { namespace { -const std::vector GetCNodeInputTensors(const CNodePtr &CNode) { +std::vector GetCNodeInputTensors(const CNodePtr &CNode) { MS_ASSERT(CNode != nullptr); auto tmp_meta_graph = std::make_unique(); auto tmp_fb_node = std::make_unique(); @@ -48,11 +48,11 @@ const std::vector GetCNodeInputTensors(const CNodePtr &CNode) { } auto lite_tensor_size = tensorT->data.size() * sizeof(uint8_t); // when tensorT as graph input - if (lite_tensor_size == 0) { + if (lite_tensor_size <= 0) { delete lite_tensor; return input_tensors; } - auto tensor_data = new (std::nothrow) char[lite_tensor_size / sizeof(char)]; + auto tensor_data = reinterpret_cast(malloc(lite_tensor_size / sizeof(char))); if (tensor_data == nullptr) { MS_LOG(ERROR) << "tensor_data is nullptr"; delete lite_tensor; @@ -61,16 +61,16 @@ const std::vector GetCNodeInputTensors(const CNodePtr &CNode) { auto ret = memcpy_s(tensor_data, lite_tensor_size, tensorT->data.data(), lite_tensor_size); if (ret != EOK) { delete lite_tensor; - delete tensor_data; + delete[](tensor_data); MS_LOG(EXCEPTION) << "memcpy error: " << ret; - return input_tensors; } lite_tensor->SetData(tensor_data); input_tensors.emplace_back(lite_tensor); } return input_tensors; } -const ParameterPtr CreateNewParamter(const FuncGraphPtr &func_graph, Tensor *tensor) { + +ParameterPtr CreateNewParamter(const FuncGraphPtr &func_graph, Tensor *tensor) { auto parameter = func_graph->add_parameter(); std::vector shape(tensor->shape()); auto type_id = static_cast(tensor->data_type()); @@ -102,17 +102,12 @@ const ParameterPtr CreateNewParamter(const FuncGraphPtr &func_graph, Tensor *ten parameter->set_default_param(param_value); return parameter; } -kernel::LiteKernel *GetLiteKernel(std::vector inputs, std::vector outputs, +kernel::LiteKernel *GetLiteKernel(std::vector inputs, std::vector outputs, OpParameter *parameter, mindspore::lite::PrimitiveC *primitive) { MS_ASSERT(nullptr != lite_primitive); auto data_type = inputs.front()->data_type(); kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, data_type, (schema::PrimitiveType)primitive->Type()}; lite::Context context; - auto parameter = kernel::PopulateParameter(primitive); - if (parameter == nullptr) { - MS_LOG(ERROR) << "PopulateParameter return nullptr, type: " << (schema::PrimitiveType)primitive->Type(); - return nullptr; - } auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); if (creator != nullptr) { auto lite_kernel = creator(inputs, outputs, parameter, &context, desc, primitive); @@ -121,16 +116,19 @@ kernel::LiteKernel *GetLiteKernel(std::vector inputs, std::vector *input_tensor) { - MS_ASSERT(input_tensor != nullptr); - for (size_t i = 0; i < input_tensor->size(); i++) { - if ((*input_tensor)[i] == nullptr) { - continue; +void FreeTensors(std::vector *input_tensor, std::vector *output_tensor) { + if (input_tensor != nullptr) { + for (size_t i = 0; i < input_tensor->size(); i++) { + delete (*input_tensor)[i]; + (*input_tensor)[i] = nullptr; + } + } + if (output_tensor != nullptr) { + for (size_t i = 0; i < output_tensor->size(); i++) { + delete (*output_tensor)[i]; + (*output_tensor)[i] = nullptr; } - delete (*input_tensor)[i]; - (*input_tensor)[i] = nullptr; } - return; } const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, @@ -148,7 +146,7 @@ const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const An auto input_cnode = input_node->cast(); auto input_tensors = GetCNodeInputTensors(input_cnode); if (input_tensors.empty() || input_tensors.size() != input_cnode->inputs().size() - 1) { - FreeInputTensor(&input_tensors); + FreeTensors(&input_tensors, nullptr); continue; } MS_LOG(INFO) << "Begin fold node:" << input_node->fullname_with_scope(); @@ -157,39 +155,47 @@ const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const An auto lite_primitive = GetValueNode>(input_cnode->input(0)); if (lite_primitive == nullptr) { MS_LOG(ERROR) << "lite_primitive is nullptr"; + FreeTensors(&input_tensors, &output_tensors); return nullptr; } // here, input_tensor's format need to be transposed nhwc according to fmkType, // but for the time being, we only transpose the tensor with 0/1/2/3D. // Others should be added in future. for (size_t j = 0; j < input_tensors.size(); ++j) { - input_tensors[j]->SetFormat(schema::Format_NHWC); - if (input_tensors[j]->shape().size() == 4) { - MS_LOG(WARNING) << "init input_tensor format to nhwc"; - } + input_tensors[j]->SetFormat(schema::Format_NHWC); + if (input_tensors[j]->shape().size() == 4) { + MS_LOG(WARNING) << "init input_tensor format to nhwc"; + } } lite_primitive->InferShape(input_tensors, output_tensors); - auto lite_kernel = GetLiteKernel(input_tensors, output_tensors, lite_primitive.get()); + auto parameter = kernel::PopulateParameter(lite_primitive.get()); + if (parameter == nullptr) { + MS_LOG(ERROR) << "PopulateParameter return nullptr, type: " + << schema::EnumNamePrimitiveType((schema::PrimitiveType)(lite_primitive->Type())); + return nullptr; + } + auto lite_kernel = GetLiteKernel(input_tensors, output_tensors, parameter, lite_primitive.get()); if (lite_kernel == nullptr) { MS_LOG(ERROR) << "constant_folding schedule node lite kernel nullptr"; - FreeInputTensor(&input_tensors); + FreeTensors(&input_tensors, &output_tensors); return nullptr; } auto ret = lite_kernel->Run(); if (0 != ret) { - FreeInputTensor(&input_tensors); + FreeTensors(&input_tensors, &output_tensors); MS_LOG(ERROR) << "run kernel failed, name: " << lite_kernel->name(); return nullptr; } auto new_parameter = CreateNewParamter(func_graph, output_tensors.front()); if (new_parameter == nullptr) { - FreeInputTensor(&input_tensors); + FreeTensors(&input_tensors, &output_tensors); MS_LOG(ERROR) << "CreateNewParamter failed, name: " << lite_kernel->name(); return nullptr; } new_parameter->set_name(input_node->fullname_with_scope()); any_node->set_input(i, new_parameter); - FreeInputTensor(&input_tensors); + FreeTensors(&input_tensors, &output_tensors); + delete (lite_kernel); } } return any_node; diff --git a/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.h b/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.h index ad6a64ecf6..29fde221bf 100644 --- a/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.h +++ b/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.h @@ -17,6 +17,10 @@ #ifndef MINDSPORE_LITE_SRC_PASS_FUSION_CONSTANT_FOLDING_FUSION_H_ #define MINDSPORE_LITE_SRC_PASS_FUSION_CONSTANT_FOLDING_FUSION_H_ +#include "schema/inner/model_generated.h" +#include "src/ir/tensor.h" +#include "src/lite_kernel.h" +#include "nnacl/op_base.h" #include "backend/optimizer/common/optimizer.h" namespace mindspore { @@ -30,4 +34,3 @@ class ConstFoldPass : public PatternProcessPass { } // namespace opt } // namespace mindspore #endif // MINDSPORE_LITE_SRC_PASS_FUSION_CONSTANT_FOLDING_FUSION_H_ -