!5333 fix memory leak

Merge pull request !5333 from hangq/primitive
pull/5333/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 66271e4155

@ -1,12 +1,12 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* <p>
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* <p>
*
* http://www.apache.org/licenses/LICENSE-2.0
* <p>
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@ -52,7 +52,7 @@ public class MSTensor {
this.setDataType(this.tensorPtr, dataType);
}
public byte[] getBtyeData() {
public byte[] getByteData() {
return this.getByteData(this.tensorPtr);
}

@ -18,6 +18,7 @@
#include <jni.h>
#include "common/ms_log.h"
#include "include/context.h"
#include "include/thread_pool_config.h"
extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_context_Context_createContext(JNIEnv *env, jobject thiz,
jint device_type,
@ -44,13 +45,13 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_context_Context_creat
}
switch (cpu_bind_mode) {
case -1:
context->cpu_bind_mode_ = mindspore::lite::MID_CPU;
context->cpu_bind_mode_ = MID_CPU;
break;
case 0:
context->cpu_bind_mode_ = mindspore::lite::NO_BIND;
context->cpu_bind_mode_ = NO_BIND;
break;
case 1:
context->cpu_bind_mode_ = mindspore::lite::HIGHER_CPU;
context->cpu_bind_mode_ = HIGHER_CPU;
break;
default:
MS_LOGE("Invalid cpu_bind_mode : %d", cpu_bind_mode);

@ -118,9 +118,9 @@ extern "C" JNIEXPORT jbyteArray JNICALL Java_com_mindspore_lite_MSTensor_getByte
return env->NewByteArray(0);
}
auto local_data_size = ms_tensor_ptr->Size();
auto ret = env->NewByteArray(local_data_size);
env->SetByteArrayRegion(ret, 0, local_data_size, local_data);
auto local_element_num = ms_tensor_ptr->ElementsNum();
auto ret = env->NewByteArray(local_element_num);
env->SetByteArrayRegion(ret, 0, local_element_num, local_data);
return ret;
}
@ -144,9 +144,9 @@ extern "C" JNIEXPORT jlongArray JNICALL Java_com_mindspore_lite_MSTensor_getLong
MS_LOGE("data type is error : %d", ms_tensor_ptr->data_type());
return env->NewLongArray(0);
}
auto local_data_size = ms_tensor_ptr->Size();
auto ret = env->NewLongArray(local_data_size);
env->SetLongArrayRegion(ret, 0, local_data_size, local_data);
auto local_element_num = ms_tensor_ptr->ElementsNum();
auto ret = env->NewLongArray(local_element_num);
env->SetLongArrayRegion(ret, 0, local_element_num, local_data);
return ret;
}
@ -170,9 +170,9 @@ extern "C" JNIEXPORT jintArray JNICALL Java_com_mindspore_lite_MSTensor_getIntDa
MS_LOGE("data type is error : %d", ms_tensor_ptr->data_type());
return env->NewIntArray(0);
}
auto local_data_size = ms_tensor_ptr->Size();
auto ret = env->NewIntArray(local_data_size);
env->SetIntArrayRegion(ret, 0, local_data_size, local_data);
auto local_element_num = ms_tensor_ptr->ElementsNum();
auto ret = env->NewIntArray(local_element_num);
env->SetIntArrayRegion(ret, 0, local_element_num, local_data);
return ret;
}
@ -196,9 +196,9 @@ extern "C" JNIEXPORT jfloatArray JNICALL Java_com_mindspore_lite_MSTensor_getFlo
MS_LOGE("data type is error : %d", ms_tensor_ptr->data_type());
return env->NewFloatArray(0);
}
auto local_data_size = ms_tensor_ptr->Size();
auto ret = env->NewFloatArray(local_data_size);
env->SetFloatArrayRegion(ret, 0, local_data_size, local_data);
auto local_element_num = ms_tensor_ptr->ElementsNum();
auto ret = env->NewFloatArray(local_element_num);
env->SetFloatArrayRegion(ret, 0, local_element_num, local_data);
return ret;
}

@ -100,8 +100,8 @@ kernel::LiteKernel *KernelRegistry::GetKernel(const std::vector<tensor::Tensor *
const std::vector<tensor::Tensor *> &out_tensors,
const PrimitiveC *primitive, const Context *ctx,
const kernel::KernelKey &key) {
MS_EXCEPTION_IF_NULL(primitive);
MS_EXCEPTION_IF_NULL(ctx);
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != ctx);
auto parameter = kernel::PopulateParameter(primitive);
if (parameter == nullptr) {
MS_LOG(ERROR) << "PopulateParameter return nullptr, type: "

@ -33,9 +33,9 @@
namespace mindspore {
namespace lite {
int LiteSession::ConvertTensors(const lite::Model *model) {
MS_EXCEPTION_IF_NULL(model);
MS_ASSERT(nullptr != model);
auto meta_graph = model->GetMetaGraph();
MS_EXCEPTION_IF_NULL(meta_graph);
MS_ASSERT(nullptr != meta_graph);
uint32_t tensorCount = meta_graph->allTensors()->size();
for (uint32_t i = 0; i < tensorCount; i++) {
auto *srcTensor = meta_graph->allTensors()->GetAs<schema::Tensor>(i);
@ -246,7 +246,7 @@ int LiteSession::CompileGraph(Model *model) {
std::vector<mindspore::tensor::MSTensor *> LiteSession::GetInputs() const { return this->input_vec_; }
int LiteSession::RunGraph(const session::KernelCallBack &before, const session::KernelCallBack &after) {
MS_EXCEPTION_IF_NULL(this->context_);
MS_ASSERT(this->context_);
if (before == nullptr && after == nullptr) {
return executor->Run(this->inputs_, this->outputs_, this->kernels_, this->context_->allocator.get());
} else {
@ -255,7 +255,7 @@ int LiteSession::RunGraph(const session::KernelCallBack &before, const session::
}
int LiteSession::Init(Context *context) {
MS_EXCEPTION_IF_NULL(context);
MS_ASSERT(nullptr != context);
this->context_ = new (std::nothrow) Context(context->thread_num_, context->allocator, context->device_ctx_);
if (this->context_ == nullptr) {
MS_LOG(ERROR) << "new context failed";
@ -276,7 +276,7 @@ int LiteSession::Init(Context *context) {
}
#endif
executor = new Executor();
MS_EXCEPTION_IF_NULL(executor);
MS_ASSERT(nullptr != executor);
return RET_OK;
}

@ -101,7 +101,7 @@ int ModelImpl::BuildOps() {
MS_LOG(ERROR) << "mete_graph is nullptr";
return -1;
}
MS_EXCEPTION_IF_NULL(meta_graph_->nodes());
MS_ASSERT(nullptr != meta_graph_->nodes());
for (size_t i = 0; i < meta_graph_->nodes()->size(); i++) {
auto cNode = meta_graph_->nodes()->GetAs<schema::CNode>(i);
auto name = cNode->name()->str();
@ -129,17 +129,17 @@ Model *Model::Import(const char *model_buf, size_t size) {
Model::~Model() { delete (this->model_impl_); }
mindspore::lite::PrimitiveC *Model::GetOp(const std::string &name) const {
MS_EXCEPTION_IF_NULL(model_impl_);
MS_ASSERT(nullptr != model_impl_);
return const_cast<PrimitiveC *>(model_impl_->GetOp(name));
}
void Model::FreeMetaGraph() {
MS_EXCEPTION_IF_NULL(model_impl_);
MS_ASSERT(nullptr != model_impl_);
return model_impl_->FreeMetaGraph();
}
const schema::MetaGraph *Model::GetMetaGraph() const {
MS_EXCEPTION_IF_NULL(model_impl_);
MS_ASSERT(nullptr != model_impl_);
return model_impl_->meta_graph();
}

@ -31,8 +31,8 @@ class ParamValueLite : public Value {
ParamValueLite() : tensor_addr_(nullptr), tensor_size_(0) {}
virtual ~ParamValueLite() {
if (tensor_addr_ != nullptr) {
auto tensor_mem = reinterpret_cast<char*>(tensor_addr_);
delete tensor_mem;
auto tensor_mem = reinterpret_cast<char *>(tensor_addr_);
delete[](tensor_mem);
tensor_addr_ = nullptr;
tensor_size_ = 0;
}

@ -277,7 +277,7 @@ kernel::LiteKernel *CpuArithmeticGradFp32KernelCreator(const std::vector<lite::t
OpParameter *opParameter, const lite::Context *ctx,
const kernel::KernelKey &desc,
const mindspore::lite::PrimitiveC *primitive) {
MS_EXCEPTION_IF_NULL(opParameter);
MS_ASSERT(nullptr != opParameter);
if (opParameter == nullptr) {
return nullptr;
}

@ -206,7 +206,7 @@ int SubGraphOpenCLKernel::MallocTensorWithReuse() {
output->set_allocator(allocator_);
}
for (auto input_kernel : kernel->in_kernels()) {
MS_EXCEPTION_IF_NULL(input_kernel);
MS_ASSERT(nullptr != input_kernel);
auto ret = input_kernel->DecOutTensorRefCount();
if (0 != ret) {
MS_LOG(WARNING) << "DecOutTensorRefCount for kernel" << kernel->name() << " failed";
@ -214,21 +214,21 @@ int SubGraphOpenCLKernel::MallocTensorWithReuse() {
}
}
for (auto kernel : out_kernels_) {
MS_EXCEPTION_IF_NULL(kernel);
MS_ASSERT(nullptr != kernel);
auto ret = kernel->DecOutTensorRefCount();
if (0 != ret) {
MS_LOG(WARNING) << "DecOutTensorRefCount for kernel" << kernel->name() << " failed";
}
}
for (auto kernel : in_convert_ops_) {
MS_EXCEPTION_IF_NULL(kernel);
MS_ASSERT(nullptr != kernel);
auto ret = kernel->DecOutTensorRefCount();
if (0 != ret) {
MS_LOG(WARNING) << "DecOutTensorRefCount for kernel" << kernel->name() << " failed";
}
}
for (auto kernel : out_convert_ops_) {
MS_EXCEPTION_IF_NULL(kernel);
MS_ASSERT(nullptr != kernel);
auto ret = kernel->DecOutTensorRefCount();
if (0 != ret) {
MS_LOG(WARNING) << "DecOutTensorRefCount for kernel" << kernel->name() << " failed";

@ -65,7 +65,7 @@ int OpenCLExecutor::Run(std::vector<tensor::Tensor *> &inputs, std::vector<tenso
}
}
for (auto input_kernel : kernel->in_kernels()) {
MS_EXCEPTION_IF_NULL(input_kernel);
MS_ASSERT(nullptr != input_kernel);
ret = input_kernel->DecOutTensorRefCount();
if (0 != ret) {
MS_LOG(WARNING) << "DecOutTensorRefCount for kernel" << kernel->name() << " failed";

@ -77,10 +77,10 @@ int Scheduler::ReSizeKernels(const std::vector<kernel::LiteKernel *> &kernels) {
}
int Scheduler::InferShape(const lite::Model *model, std::vector<tensor::Tensor *> *tensors) {
MS_EXCEPTION_IF_NULL(model);
MS_EXCEPTION_IF_NULL(tensors);
MS_ASSERT(nullptr != model);
MS_ASSERT(nullptr != tensors);
auto meta_graph = model->GetMetaGraph();
MS_EXCEPTION_IF_NULL(meta_graph);
MS_ASSERT(nullptr != meta_graph);
bool infer_shape_interrupt = false;
uint32_t kernelCount = meta_graph->nodes()->size();
for (uint32_t i = 0; i < kernelCount; i++) {
@ -121,10 +121,10 @@ int Scheduler::InferShape(const lite::Model *model, std::vector<tensor::Tensor *
int Scheduler::InitOp2Kernel(const lite::Model *model, std::vector<tensor::Tensor *> *tensors,
std::vector<kernel::LiteKernel *> *kernels) {
MS_EXCEPTION_IF_NULL(model);
MS_EXCEPTION_IF_NULL(tensors);
MS_ASSERT(nullptr != model);
MS_ASSERT(nullptr != tensors);
auto meta_graph = model->GetMetaGraph();
MS_EXCEPTION_IF_NULL(meta_graph);
MS_ASSERT(nullptr != meta_graph);
uint32_t kernelCount = meta_graph->nodes()->size();
auto graph_output_node_indexes = GetGraphOutputNodes(meta_graph);
for (uint32_t i = 0; i < kernelCount; i++) {

@ -23,7 +23,7 @@ namespace mindspore {
schema::MetaGraphT *TestTfliteParser::LoadAndConvert(const string &model_path, const string &weight_path) {
lite::TfliteModelParser parser;
meta_graph = parser.Parse(model_path, weight_path);
meta_graph = parser.ParseToFb(model_path, weight_path);
if (meta_graph == nullptr) {
MS_LOG(ERROR) << "Parse to metaGraph return nullptr";
return nullptr;

@ -107,7 +107,7 @@ int Benchmark::ReadInputFile() {
}
auto inputData = cur_tensor->MutableData();
memcpy(inputData, binBuf, tensorDataSize);
delete binBuf;
delete[](binBuf);
}
}
return RET_OK;
@ -455,6 +455,12 @@ int Benchmark::RunBenchmark(const std::string &deviceType) {
}
if (!_flags->calibDataPath.empty()) {
status = MarkAccuracy();
for (auto &data : calibData) {
data.second->shape.clear();
data.second->data.clear();
delete data.second;
}
calibData.clear();
if (status != 0) {
MS_LOG(ERROR) << "Run MarkAccuracy error: " << status;
std::cout << "Run MarkAccuracy error: " << status << std::endl;
@ -472,16 +478,6 @@ int Benchmark::RunBenchmark(const std::string &deviceType) {
return status;
}
}
if (cleanData) {
for (auto &data : calibData) {
data.second->shape.clear();
data.second->data.clear();
delete data.second;
}
calibData.clear();
}
delete (session);
delete (model);
return RET_OK;

@ -138,7 +138,6 @@ class MS_API Benchmark {
std::vector<mindspore::tensor::MSTensor *> msInputs;
std::unordered_map<std::string, std::vector<mindspore::tensor::MSTensor *>> msOutputs;
std::unordered_map<std::string, CheckTensor *> calibData;
bool cleanData = true;
};
int MS_API RunBenchmark(int argc, const char **argv);

@ -35,7 +35,7 @@ OpDefCopyer GetSimpleOpCopyer() {
newCNode->quantType = inCNode->quantType;
newCNode->primitive = std::make_unique<schema::PrimitiveT>();
newCNode->primitive->value.type = inCNode->primitive->value.type;
return std::move(newCNode);
return newCNode;
};
}
@ -96,7 +96,7 @@ std::vector<size_t> GetLinkedPreIdx(const schema::MetaGraphT &graphT, const size
preNodeIdx.emplace_back(i);
}
}
return std::move(preNodeIdx);
return preNodeIdx;
}
std::vector<size_t> GetLinkedPostIdx(const schema::MetaGraphT &graphT, const size_t &tensorIdx) {
@ -111,7 +111,7 @@ std::vector<size_t> GetLinkedPostIdx(const schema::MetaGraphT &graphT, const siz
postNodeIdx.emplace_back(i);
}
}
return std::move(postNodeIdx);
return postNodeIdx;
}
STATUS IsolateNode(schema::MetaGraphT *graphT, CNodeT *node) {

@ -89,7 +89,7 @@ static STATUS TransFilterData(schema::TensorT *tensor, kTransFilterType type, in
MS_LOG(ERROR) << "Dim size invalid";
return RET_ERROR;
}
std::unique_ptr<T> buf(new (std::nothrow) T[count]);
std::unique_ptr<T[]> buf(new (std::nothrow) T[count]);
if (buf == nullptr) {
MS_LOG(ERROR) << "new buf failed";
return RET_ERROR;

@ -24,7 +24,7 @@ std::unique_ptr<QuantParamT> GetTensorQuantParam(const std::unique_ptr<TensorT>
MS_ASSERT(tensor != nullptr);
auto &quantParams = tensor->quantParams;
if (!quantParams.empty()) {
return std::move(CopyQuantParamT(quantParams.front()));
return CopyQuantParamT(quantParams.front());
} else {
return nullptr;
}
@ -39,7 +39,7 @@ std::unique_ptr<schema::QuantParamT> CopyQuantParamT(const std::unique_ptr<schem
dstQuantParam->max = srcQuantParam->max;
dstQuantParam->narrowRange = srcQuantParam->narrowRange;
dstQuantParam->numBits = srcQuantParam->numBits;
return std::move(dstQuantParam);
return dstQuantParam;
}
size_t GetElementSize(const TensorT &tensor) { return GetElementSize(TypeId(tensor.dataType)); }
@ -87,7 +87,7 @@ std::unique_ptr<TensorT> CopyTensorDefT(const std::unique_ptr<TensorT> &oldTenso
if (!oldTensor->quantParams.empty()) {
newTensor->quantParams.emplace_back(std::move(GetTensorQuantParam(oldTensor)));
}
return std::move(newTensor);
return newTensor;
}
size_t GetRefCount(MetaGraphT *graphT, uint32_t tensorIdx) {

@ -24,6 +24,8 @@
#include "tools/optimizer/fusion/conv_scale_fusion.h"
#include "tools/optimizer/fusion/conv_bn_fusion.h"
#include "tools/optimizer/fusion/constant_folding_fusion.h"
#include "tools/converter/quantizer/post_training_quantizer.h"
#include "tools/converter/quantizer/quant_cast.h"
using std::string;
namespace mindspore {
@ -32,10 +34,9 @@ AnfTransform::AnfTransform() = default;
AnfTransform::~AnfTransform() = default;
void AnfTransform::SetGraphDef(schema::MetaGraphT *_dstDef) { graphDefT = _dstDef; }
FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph) {
// return old_graph;
FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const converter::Flags *config) {
MS_ASSERT(nullptr != old_graph);
// fusion const_fold
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>("anf fusion pass manager", false);
pm->AddPass(std::make_shared<opt::ConvBiasaddFusion>());
@ -54,6 +55,31 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph) {
pm->AddPass(std::make_shared<opt::ConstFoldPass>());
optimizer->AddPassManager(pm);
FuncGraphPtr new_graph = optimizer->Optimize(old_graph);
// quant
if (config != nullptr && config->quantType == schema::QuantType_PostTraining) {
this->mQuantizer = std::make_unique<quant::PostTrainingQuantizer>(new_graph, config->configFile, 8);
if (mQuantizer == nullptr) {
MS_LOG(ERROR) << "New PostTrainingQuantizer failed";
return nullptr;
}
}
if (mQuantizer != nullptr) {
mQuantizer->flags = *config;
auto status = mQuantizer->DoQuantize(new_graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "Quant failed " << status;
return nullptr;
}
quant::QuantCast quant_cast;
quant_cast.SetInputDataDType(kNumberTypeFloat32);
status = quant_cast.Run(new_graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "add QuantCast error";
return nullptr;
}
}
return new_graph;
}
} // namespace lite

@ -17,11 +17,12 @@
#ifndef MS_ANF_TRANSFORM_H
#define MS_ANF_TRANSFORM_H
#include <memory>
#include "schema/inner/model_generated.h"
#include "tools/common/storage.h"
#include "tools/converter/converter_flags.h"
#include "ir/anf.h"
#include "tools/converter/quantizer/quantizer.h"
namespace mindspore {
namespace lite {
@ -29,15 +30,12 @@ class AnfTransform {
public:
AnfTransform();
virtual ~AnfTransform();
FuncGraphPtr Transform(const FuncGraphPtr &old_graph);
void SetGraphDef(schema::MetaGraphT *dstDef);
inline schema::MetaGraphT *GetOutput() { return graphDefT; }
FuncGraphPtr Transform(const FuncGraphPtr &old_graph, const converter::Flags *config = nullptr);
protected:
schema::MetaGraphT *graphDefT = nullptr;
private:
std::unique_ptr<quant::Quantizer> mQuantizer = nullptr;
};
} // namespace lite
} // namespace mindspore
#endif

@ -70,41 +70,23 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) {
MS_ASSERT(nullptr != modelParser);
const std::string modelFile = flag->modelFile;
const std::string weightFile = flag->weightFile;
auto meta_graph = modelParser->Parse(modelFile, weightFile, flag->quantType);
if (meta_graph == nullptr) {
MS_LOG(ERROR) << "Parse to metaGraph return nullptr";
return nullptr;
}
graph = ModelParser::Fb2Anf(meta_graph);
graph = modelParser->Parse(modelFile, weightFile, flag->quantType);
}
if (graph == nullptr) {
MS_LOG(ERROR) << "Parser/Import model return nullptr";
return nullptr;
}
graph = anfTransform->Transform(graph);
CreateQuantizer(graph, flag);
if (mQuantizer != nullptr) {
mQuantizer->flags = *flag;
auto status = mQuantizer->DoQuantize(graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "Quant failed " << status;
return nullptr;
}
quant::QuantCast quant_cast;
quant_cast.SetInputDataDType(kNumberTypeFloat32);
status = quant_cast.Run(graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "add QuantCast error";
return nullptr;
}
graph = anfTransform->Transform(graph, flag);
if (graph == nullptr) {
MS_LOG(ERROR) << "Transform anf graph return nullptr";
return nullptr;
}
// anf -- fb
auto meta_graph = Export(graph);
if (meta_graph == nullptr) {
MS_LOG(ERROR) << "Export to meta_graph return nullptr";
MS_LOG(ERROR) << "Export to meta graph return nullptr";
return nullptr;
}
@ -113,20 +95,13 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) {
transform->CreateQuantizer(flag);
auto status = transform->Transform(*flag);
if (status != 0) {
MS_LOG(ERROR) << "FBTransform model failed " << status;
MS_LOG(ERROR) << "Transform meta graph failed " << status;
return nullptr;
}
return meta_graph;
}
void Converter::CreateQuantizer(FuncGraphPtr func_graph, const converter::Flags *flags) {
auto type = flags->quantType;
if (type == mindspore::schema::QuantType_PostTraining) {
MS_LOG(INFO) << "create post training quantizer.";
mQuantizer.reset(new quant::PostTrainingQuantizer(func_graph, flags->configFile, 8));
}
}
int RunConverter(int argc, const char **argv) {
std::unique_ptr<converter::Flags> flags(new (std::nothrow) converter::Flags);
if (flags == nullptr) {

@ -25,7 +25,6 @@
#include "tools/anf_importer/anf_importer.h"
#include "tools/converter/converter_flags.h"
#include "tools/converter/anf_transform.h"
#include "tools/converter/quantizer/quantizer.h"
namespace mindspore {
namespace lite {
@ -34,15 +33,12 @@ class Converter {
Converter();
virtual ~Converter();
virtual schema::MetaGraphT *Convert(const lite::converter::Flags *flags);
void CreateQuantizer(FuncGraphPtr func_graph, const converter::Flags *flags);
void FreeFuncGraph(const FuncGraphPtr &func_graph);
protected:
ModelParser *modelParser = nullptr;
AnfImporter *modelImporter = nullptr;
GraphDefTransform *transform = nullptr;
AnfTransform *anfTransform = nullptr;
std::unique_ptr<quant::Quantizer> mQuantizer = nullptr;
};
int RunConverter(int argc, const char **argv);

@ -15,15 +15,12 @@
*/
#include "tools/converter/graphdef_transform.h"
#include <iostream>
#include <memory>
#include <string>
#include "schema/model_generated.h"
#include "utils/log_adapter.h"
#include "src/common/op_utils.h"
#include "tools/converter/converter_flags.h"
#include "tools/converter/legacy_optimizer/graph/dtype_trans_pass.h"
// #include "tools/converter/legacy_optimizer/fusion/matmul_biasadd_fusion_pass.h"
#include "tools/converter/legacy_optimizer/fusion/format_trans_fusion_pass.h"
#include "tools/converter/legacy_optimizer/fusion/format_trans_transpose_fusion_pass.h"
#include "tools/converter/legacy_optimizer/fusion/quant_cast_fusion_pass.h"
@ -37,7 +34,6 @@
#include "tools/converter/legacy_optimizer/graph/unused_node_remove_pass.h"
#include "tools/converter/legacy_optimizer/graph/topological_sort_pass.h"
#include "tools/converter/quantizer/aware_quantizer.h"
#include "tools/converter/converter.h"
using std::string;
namespace mindspore::lite {
@ -72,7 +68,6 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
weightHardCodePass->SetFmkType(ctx.fmk);
weightFormatPass->SetQuantType(ctx.quantType);
weightFormatPass->SetFmkType(ctx.fmk);
// weightFormatPass->SetDstFormat(Format_KHWC);
weightFormatOptimizer.AddPass(weightHardCodePass);
weightFormatOptimizer.AddPass(weightFormatPass);
status = weightFormatOptimizer.Run(graphDefT);
@ -153,9 +148,6 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
formatTransOptimizer.AddPass(new EltwiseFormatTransPass());
formatTransOptimizer.AddPass(new (std::nothrow) FormatTransFusionPass());
formatTransOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
// if (ctx.quantType == QuantType_AwareTraining) {
// formatTransOptimizer.AddPass(new (std::nothrow) FormatTransNodeQuantParamFillPass());
// }
status = formatTransOptimizer.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run formatTransOptimizer graphPasses Failed";

@ -74,7 +74,7 @@ class MatMulBiasAddFusionPass : public FusionPass {
std::transform(inParam->perm.begin(), inParam->perm.end(), transposeParam->perm.begin(),
[](const int32_t ele) { return ele; });
newOpDef->primitive->value.value = transposeParam;
return std::move(newOpDef);
return newOpDef;
};
};
} // namespace lite

@ -72,7 +72,7 @@ class DTypeTransPass : public GraphPass {
QuantDTypeCastParam->srcT = oldQuantDTypeCastParam->srcT;
QuantDTypeCastParam->dstT = oldQuantDTypeCastParam->dstT;
newCNode->primitive->value.value = QuantDTypeCastParam;
return std::move(newCNode);
return newCNode;
};
};
} // namespace lite

@ -32,16 +32,16 @@ class ModelParser {
virtual ~ModelParser() {}
virtual FuncGraphPtr ParseToAnf(const std::string &modelFile, const std::string &weightFile) {
auto *meta_graph = Parse(modelFile, weightFile);
if (meta_graph == nullptr) {
MS_LOG(ERROR) << "Parse to metaGraph return nullptr";
return nullptr;
}
return Fb2Anf(Parse(modelFile, weightFile));
FuncGraphPtr Parse(const std::string &modelFile, const std::string &weightFile,
const QuantType &quantType = QuantType_QUANT_NONE) {
auto *meta_graph = ParseToFb(modelFile, weightFile, quantType);
auto func_graph = this->Fb2Anf(meta_graph);
delete(meta_graph);
return func_graph;
}
virtual schema::MetaGraphT *Parse(const std::string &modelFile, const std::string &weightFile,
const QuantType &quantType = QuantType_QUANT_NONE) = 0;
virtual schema::MetaGraphT *ParseToFb(const std::string &modelFile, const std::string &weightFile,
const QuantType &quantType = QuantType_QUANT_NONE) = 0;
public:
static FuncGraphPtr Fb2Anf(schema::MetaGraphT *meta_graph) {

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save