mindspore/mindspore/lite/tools/converter/converter.cc

179 lines
5.8 KiB

/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/converter.h"
#include <memory>
#include <vector>
#include <utility>
#include "tools/converter/converter_flags.h"
#include "src/common/common.h"
#include "src/common/file_utils.h"
#include "ir/func_graph.h"
#include "src/common/log_adapter.h"
#include "tools/common/storage.h"
#include "parser/caffe/caffe_converter.h"
#include "parser/tflite/tflite_converter.h"
#include "parser/onnx/onnx_converter.h"
#include "tools/anf_exporter/anf_exporter.h"
#include "tools/anf_importer/import_from_protobuf.h"
#include "tools/converter/parser/onnx/onnx.pb.h"
#include "tools/converter/quantizer/post_training_quantizer.h"
#include "tools/converter/quantizer/quant_cast.h"
#include "include/version.h"
namespace mindspore {
namespace lite {
// Shorthand for the framework-type enum that RunConverter switches on.
using FmkType = converter::FmkType;
// Path separator; RunConverter uses it to cut the directory part off the model file path.
static const char *DELIM_SLASH = "/";
// Allocates the two transform helpers that Convert() relies on:
//   - transform:    applied to the exported meta graph,
//   - anfTransform: applied to the ANF function graph.
// Both are raw owning pointers and are released in ~Converter().
Converter::Converter() {
  transform = new GraphDefTransform;
  anfTransform = new AnfTransform;
}
// Releases every helper object held through a raw pointer member.
// `delete` on a null pointer is a no-op, so members that were never assigned
// a live object need no explicit check here.
Converter::~Converter() {
  delete this->modelParser;
  delete this->modelImporter;
  delete this->transform;
  delete this->anfTransform;
}
// Converter used for the FmkType_MS path: instead of a model parser it
// installs an AnfImporterFromProtobuf as the base class' modelImporter.
//
// The importer object created here is freed by ~Converter(). The caller
// (RunConverter) deletes `onnx_model` itself after Convert() has returned.
class MindsporeImporter : public Converter {
 public:
  // onnx_model: protobuf model handed to the importer.
  // func_graph: graph object handed to the importer; moved into it.
  MindsporeImporter(onnx::ModelProto *onnx_model, FuncGraphPtr func_graph) {
    this->modelImporter = new AnfImporterFromProtobuf(onnx_model, std::move(func_graph));
  }

  ~MindsporeImporter() override = default;
};
// Converts the model described by `flag` into a MetaGraphT.
//
// Steps:
//   1. Obtain the ANF function graph: via modelImporter when flag->fmk is
//      FmkType_MS, otherwise via modelParser from the model/weight files.
//   2. Run anfTransform on the function graph.
//   3. Export the function graph to a meta graph.
//   4. Create the quantizer and run the meta-graph transform.
//
// Params:
//   flag - conversion options; must not be null.
//
// Returns the converted graph on success. The caller owns the returned object
// and must delete it (RunConverter does so). Returns nullptr on any failure;
// nothing allocated by this function is left alive in that case.
MetaGraphT *Converter::Convert(const converter::Flags *flag) {
  if (flag == nullptr) {
    MS_LOG(ERROR) << "Input flag is nullptr";
    ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_INPUT_PARAM_INVALID);
    return nullptr;
  }
  // parse the model and weight file to generate inference data structure
  FuncGraphPtr graph = nullptr;
  if (flag->fmk == converter::FmkType_MS) {
    MS_ASSERT(nullptr != modelImporter);
    int status = modelImporter->Import(flag->quantType);
    ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
    graph = modelImporter->GetResult();
  } else {
    MS_ASSERT(nullptr != modelParser);
    // Bind references instead of copying the path strings.
    const std::string &modelFile = flag->modelFile;
    const std::string &weightFile = flag->weightFile;
    graph = modelParser->Parse(modelFile, weightFile, flag->quantType);
  }
  if (graph == nullptr) {
    MS_LOG(ERROR) << "Parser/Import model return nullptr";
    return nullptr;
  }

  graph = anfTransform->Transform(graph, flag);
  if (graph == nullptr) {
    MS_LOG(ERROR) << "Transform anf graph return nullptr";
    return nullptr;
  }

  // anf -- fb
  auto meta_graph = Export(graph);
  if (meta_graph == nullptr) {
    MS_LOG(ERROR) << "Export to meta graph return nullptr";
    return nullptr;
  }

  // transform
  transform->SetGraphDef(meta_graph);
  transform->CreateQuantizer(flag);
  auto status = transform->Transform(*flag);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "Transform meta graph failed " << status;
    ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
    // The graph is not handed to the caller on this path, so it has to be
    // released here; otherwise it would leak.
    delete meta_graph;
    return nullptr;
  }
  return meta_graph;
}
int RunConverter(int argc, const char **argv) {
std::unique_ptr<converter::Flags> flags(new (std::nothrow) converter::Flags);
if (flags == nullptr) {
MS_LOG(ERROR) << "new flags error ";
std::cout << "NEW FLAGS ERROR:" << RET_MEMORY_FAILED << std::endl;
return RET_MEMORY_FAILED;
}
auto status = flags->Init(argc, argv);
if (status != RET_OK) {
if (status != RET_SUCCESS_EXIT) {
MS_LOG(ERROR) << "converter::Flags Init failed: " << status;
std::cout << "CONVERTER::FLAGS INIT FAILED:" << status << std::endl;
}
return status;
}
// Load graph
std::string modelName = flags->modelFile.substr(flags->modelFile.find_last_of(DELIM_SLASH) + 1);
MS_LOG(INFO) << "start reading model file";
MetaGraphT *fb_graph = nullptr;
switch (flags->fmk) {
case FmkType::FmkType_MS: {
auto graph = std::make_shared<FuncGraph>();
auto onnx_graph = AnfImporterFromProtobuf::ReadOnnxFromBinary(flags->modelFile);
MindsporeImporter mindsporeImporter(onnx_graph, graph);
fb_graph = mindsporeImporter.Convert(flags.get());
delete onnx_graph;
break;
}
case FmkType::FmkType_CAFFE: {
CaffeConverter caffeConverter;
fb_graph = caffeConverter.Convert(flags.get());
} break;
case FmkType::FmkType_TFLITE: {
TfliteConverter tfLiteConverter;
fb_graph = tfLiteConverter.Convert(flags.get());
} break;
case FmkType::FmkType_ONNX: {
OnnxConverter onnxConverter;
fb_graph = onnxConverter.Convert(flags.get());
} break;
default: {
MS_LOG(ERROR) << "Unsupported fmkType: " << flags->fmk;
std::cout << "UNSUPPORTED FMKTYPE " << flags->fmk << ":" << RET_INPUT_PARAM_INVALID << std::endl;
return RET_INPUT_PARAM_INVALID;
}
}
NoSupportOp::GetInstance()->PrintOps();
status = ReturnCode::GetSingleReturnCode()->GetReturnCode();
if (fb_graph == nullptr) {
MS_LOG(ERROR) << "Convert model return nullptr";
std::cout << "CONVERT RESULT FAILED:" << status << std::endl;
return status;
}
// save graph to file
Storage storage;
fb_graph->version = Version();
status = storage.Save(*fb_graph, flags->outputFile);
if (status != 0) {
MS_LOG(ERROR) << "Save graph to file failed";
std::cout << "SAVE GRAPH FAILED:" << status << std::endl;
return status;
}
delete fb_graph;
MS_LOG(INFO) << "CONVERT RESULT: SUCCESS!";
std::cout << "CONVERT RESULT SUCCESS:" << status << std::endl;
return status;
}
} // namespace lite
} // namespace mindspore