fix return code

pull/8203/head
yankai 4 years ago
parent 93daaf4de9
commit 3fd4e309a5

@ -37,6 +37,7 @@
#include "proto/onnx.pb.h"
#include "src/common/log_adapter.h"
#include "tools/common/protobuf_utils.h"
#include "tools/common/graph_util.h"
using string = std::string;
using int32 = int32_t;
@ -844,8 +845,14 @@ int AnfImporterFromProtobuf::Import(const schema::QuantType &quantType) {
onnx::ModelProto *AnfImporterFromProtobuf::ReadOnnxFromBinary(const std::string &model_path) {
auto onnx_model = new onnx::ModelProto;
if (RET_OK != ValidateFileStr(model_path, ".mindir")) {
MS_LOG(ERROR) << "Input illegal: modelFile must be *.mindir";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_INPUT_PARAM_INVALID);
return nullptr;
}
if (ReadProtoFromBinaryFile((const char *)model_path.c_str(), onnx_model) != RET_OK) {
MS_LOG(ERROR) << "Read onnx model file failed, model path: " << model_path;
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_GRAPH_FILE_ERR);
return nullptr;
}
return onnx_model;

@ -33,7 +33,6 @@
#include "proto/onnx.pb.h"
#include "tools/converter/quantizer/post_training_quantizer.h"
#include "tools/converter/quantizer/quant_cast.h"
#include "tools/common//graph_util.h"
#include "include/version.h"
namespace mindspore {
@ -128,14 +127,10 @@ int RunConverter(int argc, const char **argv) {
switch (flags->fmk) {
case FmkType::FmkType_MS: {
auto graph = std::make_shared<FuncGraph>();
if (RET_OK != ValidateFileStr(flags->modelFile, ".mindir")) {
MS_LOG(ERROR) << "Input illegal: modelFile must be *.mindir";
return RET_INPUT_PARAM_INVALID;
}
auto onnx_graph = AnfImporterFromProtobuf::ReadOnnxFromBinary(flags->modelFile);
if (onnx_graph == nullptr) {
MS_LOG(ERROR) << "Read MINDIR model from binary failed";
return RET_INPUT_PARAM_INVALID;
MS_LOG(ERROR) << "Read MINDIR from binary return nullptr";
break;
}
MindsporeImporter mindsporeImporter(onnx_graph, graph);
fb_graph = mindsporeImporter.Convert(flags.get());

@ -85,6 +85,7 @@ bool TransOpInsertPass::CanFusion(schema::MetaGraphT *graph, const std::unique_p
if (pre_type_ == PrimitiveType_NONE && post_type_ == PrimitiveType_NONE) {
return false;
}
auto total_node_count = input_node_indexes.size() + output_node_indexes.size();
size_t half_count = total_node_count / 2;
if (GetCNodeTType(*node) == schema::PrimitiveType_Activation) {

Loading…
Cancel
Save