From 22dc34058c776255a9a4d2d2d6e72e54fea9b729 Mon Sep 17 00:00:00 2001 From: zhengjun10 Date: Sat, 5 Sep 2020 15:56:00 +0800 Subject: [PATCH] mindspore export inferface finetune lite --- mindspore/lite/test/models_mindspore.cfg | 4 +- .../anf_importer/import_from_protobuf.cc | 417 +++++++++++------- .../tools/anf_importer/import_from_protobuf.h | 17 +- 3 files changed, 264 insertions(+), 174 deletions(-) diff --git a/mindspore/lite/test/models_mindspore.cfg b/mindspore/lite/test/models_mindspore.cfg index f408c7ef61..51bcc21a4e 100644 --- a/mindspore/lite/test/models_mindspore.cfg +++ b/mindspore/lite/test/models_mindspore.cfg @@ -1,2 +1,2 @@ -ssd.pb -mobilenet_v2.pb +#ssd.pb +# mobilenet_v2.pb diff --git a/mindspore/lite/tools/anf_importer/import_from_protobuf.cc b/mindspore/lite/tools/anf_importer/import_from_protobuf.cc index 3f0a0667ac..8eeb0d83ce 100644 --- a/mindspore/lite/tools/anf_importer/import_from_protobuf.cc +++ b/mindspore/lite/tools/anf_importer/import_from_protobuf.cc @@ -46,9 +46,6 @@ using uint64 = uint64_t; namespace mindspore::lite { static constexpr char kConstantValueNode[] = "Constant"; -static constexpr char kCNodeShapeAttr[] = "shape"; -static constexpr char kCNodeShape1Attr[] = "shape1"; -static constexpr char kCNodeShape2Attr[] = "shape2"; enum ParseForm : int { FORM_PARSE_TYPE = 0, @@ -57,32 +54,143 @@ enum ParseForm : int { }; static std::map kParseTypeSwitchMap{ - {"type", FORM_PARSE_TYPE}, {"scalar", FORM_PARSE_SCALAR}, {"tensor", FORM_PARSE_TENSOR}}; + {"type", FORM_PARSE_TYPE}, {"scalar", FORM_PARSE_SCALAR}, {"tensor", FORM_PARSE_TENSOR}}; static std::unordered_map kDefaultValueSwitchMap{ - {onnx::TensorProto_DataType_BOOL, kNumberTypeBool}, {onnx::TensorProto_DataType_INT8, kNumberTypeInt8}, - {onnx::TensorProto_DataType_INT16, kNumberTypeInt16}, {onnx::TensorProto_DataType_INT32, kNumberTypeInt32}, - {onnx::TensorProto_DataType_INT64, kNumberTypeInt64}, {onnx::TensorProto_DataType_UINT8, kNumberTypeUInt8}, - {onnx::TensorProto_DataType_UINT16, kNumberTypeUInt16}, {onnx::TensorProto_DataType_UINT32, kNumberTypeUInt32}, - {onnx::TensorProto_DataType_UINT64, kNumberTypeUInt64}, {onnx::TensorProto_DataType_FLOAT16, kNumberTypeFloat16}, - {onnx::TensorProto_DataType_FLOAT, kNumberTypeFloat32}, {onnx::TensorProto_DataType_DOUBLE, kNumberTypeFloat64}, - {onnx::TensorProto_DataType_STRING, kObjectTypeString}, + {onnx::TensorProto_DataType_BOOL, kNumberTypeBool}, {onnx::TensorProto_DataType_INT8, kNumberTypeInt8}, + {onnx::TensorProto_DataType_INT16, kNumberTypeInt16}, {onnx::TensorProto_DataType_INT32, kNumberTypeInt32}, + {onnx::TensorProto_DataType_INT64, kNumberTypeInt64}, {onnx::TensorProto_DataType_UINT8, kNumberTypeUInt8}, + {onnx::TensorProto_DataType_UINT16, kNumberTypeUInt16}, {onnx::TensorProto_DataType_UINT32, kNumberTypeUInt32}, + {onnx::TensorProto_DataType_UINT64, kNumberTypeUInt64}, {onnx::TensorProto_DataType_FLOAT16, kNumberTypeFloat16}, + {onnx::TensorProto_DataType_FLOAT, kNumberTypeFloat32}, {onnx::TensorProto_DataType_DOUBLE, kNumberTypeFloat64}, + {onnx::TensorProto_DataType_STRING, kObjectTypeString}, }; -#define PARSE_ONNXATTR_IN_SCALAR_FORM(type, valuetype) \ - void ParseAttrInScalar_##type##_##valuetype(const PrimitivePtr &prim, const std::string &attr_name, \ - const onnx::TensorProto &attr_tensor) { \ - MS_EXCEPTION_IF_NULL(prim); \ - std::vector attr_value_vec; \ - for (int i = 0; i < attr_tensor.type##_data_size(); ++i) { \ - auto value = static_cast(attr_tensor.type##_data(i)); \ - attr_value_vec.push_back(MakeValue(value)); \ - } \ - if (attr_value_vec.size() == 1) { \ - prim->AddAttr(attr_name, attr_value_vec[0]); \ - } else { \ - prim->AddAttr(attr_name, std::make_shared(attr_value_vec)); \ - } \ +std::shared_ptr ParserScalarAttrValue(const std::string &attr_name, + const std::unordered_map &kv) { + std::string str = attr_name; + auto replace = [&](const string &orgStr, const string &newStr) { + std::string::size_type pos(0); + while ((pos = str.find(orgStr)) != std::string::npos) { + str.replace(pos, orgStr.length(), newStr); + } + return str; + }; + // remove "scalar:" + str = replace("scalar:", ""); + // remove "Tuple" + str = replace("Tuple", ""); + // remove "List" + str = replace("List", ""); + std::stack rules; + std::stack value; + int num = 0, count = 0; + for (size_t i = 0; i < str.length(); i++) { + if (str[i] == '[') { + rules.push("["); + } else if (str[i] == ']') { + // rules + std::vector vec; + while (rules.top() != "[") { + rules.pop(); + vec.push_back(value.top()); + value.pop(); + } + // pop "[" + rules.pop(); + // make tuple for names + std::string res = "dummy"; + // make tuple for values + reverse(vec.begin(), vec.end()); + auto vt = std::make_shared(vec); + if (rules.empty() && value.empty()) { + return vt; + } + rules.push(res); + value.push(vt); + } else if (str[i] == ',') { + continue; + } else { + count++; + if (str[i + 1] == '[' || str[i + 1] == ']' || str[i + 1] == ',') { + auto value_name = str.substr(i - count + 1, count); + value.push(kv.at(value_name)); + rules.push(value_name); + count = 0; + num++; + } + } + } + return {}; +} + +std::shared_ptr +ParserAttrShape(const std::string &attr_name, const std::unordered_map &kv) { + std::string str = attr_name; + auto replace = [&](const string &orgStr, const string &newStr) { + std::string::size_type pos(0); + while ((pos = str.find(orgStr)) != std::string::npos) { + str.replace(pos, orgStr.length(), newStr); + } + return str; + }; + // remove "scalar:" + str = replace("shape:", ""); + // remove "Tuple" + str = replace("Tuple", ""); + // remove "List" + str = replace("List", ""); + std::stack rules; + std::stack value; + int num = 0, count = 0; + for (size_t i = 0; i < str.length(); i++) { + if (str[i] == '[') { + rules.push("["); + } else if (str[i] == ']') { + // rules + std::vector vec; + while (rules.top() != "[") { + rules.pop(); + vec.push_back(value.top()); + value.pop(); + } + // pop "[" + rules.pop(); + // make tuple for names + std::string res = "dummy"; + // make tuple for values + reverse(vec.begin(), vec.end()); + auto vt = std::make_shared(vec); + if (rules.empty() && value.empty()) { + return vt; + } + rules.push(res); + value.push(vt); + } else if (str[i] == ',') { + continue; + } else { + count++; + if (str[i + 1] == '[' || str[i + 1] == ']' || str[i + 1] == ',') { + auto value_name = str.substr(i - count + 1, count); + value.push(kv.at(value_name)); + rules.push(value_name); + count = 0; + num++; + } + } + } + return {}; +} + +#define PARSE_ONNXATTR_IN_SCALAR_FORM(type, valuetype) \ + ValuePtr ParseAttrInScalar_##type##_##valuetype(const onnx::TensorProto &attr_tensor) { \ + if (attr_tensor.type##_data_size() == 1) { \ + auto value = static_cast(attr_tensor.type##_data(0)); \ + return MakeValue(value); \ + } else { \ + MS_LOG(ERROR) << "size of scalar tensor doesn't equal 1!"; \ + } \ + return {}; \ } PARSE_ONNXATTR_IN_SCALAR_FORM(double, double) @@ -193,45 +301,34 @@ bool AnfImporterFromProtobuf::ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim return true; } -bool AnfImporterFromProtobuf::ObtainCNodeAttrInScalarForm(const PrimitivePtr &prim, const std::string &attr_name, - const onnx::TensorProto &attr_tensor) { - MS_EXCEPTION_IF_NULL(prim); +ValuePtr AnfImporterFromProtobuf::ObtainCNodeAttrInScalarForm(const onnx::TensorProto &attr_tensor) { const int attr_tensor_type = attr_tensor.data_type(); switch (attr_tensor_type) { case onnx::TensorProto_DataType_STRING: { - ParseAttrInScalar_string_string(prim, attr_name, attr_tensor); - break; + return ParseAttrInScalar_string_string(attr_tensor); } case onnx::TensorProto_DataType_INT32: { - ParseAttrInScalar_int32_int32(prim, attr_name, attr_tensor); - break; + return ParseAttrInScalar_int32_int32(attr_tensor); } case onnx::TensorProto_DataType_INT64: { - ParseAttrInScalar_int64_int64(prim, attr_name, attr_tensor); - break; + return ParseAttrInScalar_int64_int64(attr_tensor); } case onnx::TensorProto_DataType_UINT64: { - ParseAttrInScalar_uint64_uint64(prim, attr_name, attr_tensor); - break; + return ParseAttrInScalar_uint64_uint64(attr_tensor); } case onnx::TensorProto_DataType_FLOAT: { - ParseAttrInScalar_float_float(prim, attr_name, attr_tensor); - break; + return ParseAttrInScalar_float_float(attr_tensor); } case onnx::TensorProto_DataType_DOUBLE: { - ParseAttrInScalar_double_double(prim, attr_name, attr_tensor); - break; + return ParseAttrInScalar_double_double(attr_tensor); } case onnx::TensorProto_DataType_BOOL: { - ParseAttrInScalar_int32_bool(prim, attr_name, attr_tensor); - auto value = prim->GetAttr(attr_name); - break; + return ParseAttrInScalar_int32_bool(attr_tensor); } - default: - MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " << attr_tensor_type; - return false; + default:MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " << attr_tensor_type; + return {}; } - return true; + return {}; } bool AnfImporterFromProtobuf::ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const std::string &attr_name, @@ -268,7 +365,6 @@ bool AnfImporterFromProtobuf::ObtainCNodeAttrInTensorForm(const PrimitivePtr &pr prim->set_attr(attr_name, MakeValue(attr_value)); } } - return ret == EOK; } @@ -280,22 +376,46 @@ bool AnfImporterFromProtobuf::GetAttrValueForCNode(const PrimitivePtr &prim, con return false; } const std::string &ref_attr_name = attr_proto.ref_attr_name(); - const onnx::TensorProto &attr_tensor = attr_proto.t(); - switch (kParseTypeSwitchMap[ref_attr_name]) { - case FORM_PARSE_TYPE: { - return ObtainCNodeAttrInTypeForm(prim, attr_name, attr_tensor); - } - case FORM_PARSE_SCALAR: { - return ObtainCNodeAttrInScalarForm(prim, attr_name, attr_tensor); + string type; + std::size_t pos(0); + if ((pos = ref_attr_name.find("scalar:")) != std::string::npos) { + type = ref_attr_name.substr(pos, string("scalar:").length() - 1); + } else if ((pos = ref_attr_name.find("type:")) != std::string::npos) { + type = ref_attr_name.substr(pos, string("type:").length() - 1); + } else if ((pos = ref_attr_name.find("tensor:")) != std::string::npos) { + type = ref_attr_name.substr(pos, string("tensor:").length() - 1); + } + std::unordered_map kv; + for (int i = 0; i < attr_proto.tensors_size(); i++) { + const onnx::TensorProto &attr_tensor = attr_proto.tensors(i); + switch (kParseTypeSwitchMap[type]) { + case FORM_PARSE_TYPE: { + return ObtainCNodeAttrInTypeForm(prim, attr_name, attr_tensor); + } + case FORM_PARSE_SCALAR: { + auto res = ObtainCNodeAttrInScalarForm(attr_tensor); + kv.insert(std::pair(attr_tensor.name(), res)); + break; + } + case FORM_PARSE_TENSOR: { + return ObtainCNodeAttrInTensorForm(prim, attr_name, attr_tensor); + } + default:MS_LOG(ERROR) << "parse attr type don't support input of ref_attr_name"; + return false; } - case FORM_PARSE_TENSOR: { - return ObtainCNodeAttrInTensorForm(prim, attr_name, attr_tensor); + } + if (kParseTypeSwitchMap[type] == FORM_PARSE_SCALAR) { + if (kv.size() == 1) { + std::unordered_map::iterator iter = kv.begin(); + prim->AddAttr(attr_name, iter->second); + } else { + auto res = ParserScalarAttrValue(ref_attr_name, kv); + prim->AddAttr(attr_name, res); } - default: - MS_LOG(ERROR) << "parse attr type don't support input of ref_attr_name"; - return false; } + return true; } + bool AnfImporterFromProtobuf::ObtainValueNodeInTensorForm(const std::string &value_node_name, const onnx::TensorProto &attr_tensor) { const int attr_tensor_type = attr_tensor.data_type(); @@ -321,53 +441,6 @@ bool AnfImporterFromProtobuf::ObtainValueNodeInTensorForm(const std::string &val return true; } -bool AnfImporterFromProtobuf::ObtainValueNodeInScalarForm(const std::string &value_node_name, - const onnx::TensorProto &attr_tensor) { - const int attr_tensor_type = attr_tensor.data_type(); - ValuePtr value_ptr = nullptr; - switch (attr_tensor_type) { - case onnx::TensorProto_DataType_INT32: { - std::vector add_data; - for (int i = 0; i < attr_tensor.int32_data_size(); ++i) { - add_data.push_back(attr_tensor.int32_data(i)); - } - if (add_data.size() == 1) { - value_ptr = MakeValue(add_data[0]); - } else if (!add_data.empty()) { - value_ptr = MakeValue >(add_data); - } - break; - } - case onnx::TensorProto_DataType_FLOAT: { - std::vector add_data; - for (int i = 0; i < attr_tensor.float_data_size(); ++i) { - add_data.push_back(attr_tensor.float_data(i)); - } - - if (add_data.size() == 1) { - value_ptr = MakeValue(add_data[0]); - } else if (!add_data.empty()) { - value_ptr = MakeValue >(add_data); - } - break; - } - case onnx::TensorProto_DataType_UNDEFINED: { - std::vector elems; - value_ptr = std::make_shared(elems); - break; - } - default: - MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " << attr_tensor_type; - return false; - } - auto new_value_node = NewValueNode(value_ptr); - MS_EXCEPTION_IF_NULL(new_value_node); - new_value_node->set_abstract(value_ptr->ToAbstract()); - anfnode_build_map_[value_node_name] = new_value_node; - - return true; -} - bool AnfImporterFromProtobuf::ObtainValueNodeInTypeForm(const std::string &value_node_name, const onnx::TensorProto &attr_tensor) { const int attr_tensor_type = attr_tensor.data_type(); @@ -382,23 +455,56 @@ bool AnfImporterFromProtobuf::ObtainValueNodeInTypeForm(const std::string &value return true; } -bool AnfImporterFromProtobuf::GetAttrValueForValueNode(const std::string &ref_attr_name, - const std::string &value_node_name, - const onnx::TensorProto &attr_tensor) { - switch (kParseTypeSwitchMap[ref_attr_name]) { - case FORM_PARSE_SCALAR: { - return ObtainValueNodeInScalarForm(value_node_name, attr_tensor); - } - case FORM_PARSE_TENSOR: { - return ObtainValueNodeInTensorForm(value_node_name, attr_tensor); +bool AnfImporterFromProtobuf::GetAttrValueForValueNode(const std::string &value_node_name, + const onnx::AttributeProto &attr_proto) { + if (!attr_proto.has_ref_attr_name()) { + MS_LOG(ERROR) << "CNode parse attr type has no ref_attr_name"; + return false; + } + const std::string &ref_attr_name = attr_proto.ref_attr_name(); + string type; + std::size_t pos(0); + if ((pos = ref_attr_name.find("scalar:")) != std::string::npos) { + type = ref_attr_name.substr(pos, string("scalar:").length() - 1); + } else if ((pos = ref_attr_name.find("type:")) != std::string::npos) { + type = ref_attr_name.substr(pos, string("type:").length() - 1); + } else if ((pos = ref_attr_name.find("tensor:")) != std::string::npos) { + type = ref_attr_name.substr(pos, string("tensor:").length() - 1); + } + std::unordered_map kv; + for (int i = 0; i < attr_proto.tensors_size(); i++) { + const onnx::TensorProto &attr_tensor = attr_proto.tensors(i); + switch (kParseTypeSwitchMap[type]) { + case FORM_PARSE_TYPE: { + return ObtainValueNodeInTypeForm(value_node_name, attr_tensor); + } + case FORM_PARSE_SCALAR: { + auto res = ObtainCNodeAttrInScalarForm(attr_tensor); + kv.insert(std::pair(attr_tensor.name(), res)); + break; + } + case FORM_PARSE_TENSOR: { + return ObtainValueNodeInTensorForm(value_node_name, attr_tensor); + } + default:MS_LOG(ERROR) << "parse attr type don't support input of ref_attr_name"; + return false; } - case FORM_PARSE_TYPE: { - return ObtainValueNodeInTypeForm(value_node_name, attr_tensor); + } + + ValueNodePtr new_value_node; + if (kParseTypeSwitchMap[type] == FORM_PARSE_SCALAR) { + if (kv.size() == 1) { + auto iter = kv.begin(); + new_value_node = NewValueNode(iter->second); + new_value_node->set_abstract(iter->second->ToAbstract()); + } else { + auto value_ptr = ParserScalarAttrValue(ref_attr_name, kv); + new_value_node = NewValueNode(value_ptr); + new_value_node->set_abstract(value_ptr->ToAbstract()); } - default: - MS_LOG(ERROR) << "parse ValueNode value don't support input of ref_attr_name"; - return false; + anfnode_build_map_[value_node_name] = new_value_node; } + return true; } bool AnfImporterFromProtobuf::BuildValueNodeForFuncGraph(const onnx::NodeProto &node_proto) { @@ -408,22 +514,23 @@ bool AnfImporterFromProtobuf::BuildValueNodeForFuncGraph(const onnx::NodeProto & MS_LOG(ERROR) << "parse ValueNode don't have ref_attr_name"; return false; } - const std::string &ref_attr_name = attr_proto.ref_attr_name(); - const onnx::TensorProto &attr_tensor = attr_proto.t(); - - return GetAttrValueForValueNode(ref_attr_name, value_node_name, attr_tensor); + return GetAttrValueForValueNode(value_node_name, attr_proto); } -abstract::AbstractTensorPtr AnfImporterFromProtobuf::GetAbstractForCNode(const onnx::AttributeProto &attr_proto) { - std::vector shape_vec; - const onnx::TensorProto &attr_tensor = attr_proto.t(); - for (int i = 0; i < attr_tensor.dims_size(); ++i) { - shape_vec.push_back(attr_tensor.dims(i)); +std::unordered_map +AnfImporterFromProtobuf::GetAbstractForCNode(const onnx::AttributeProto &attr_proto) { + std::unordered_map kv; + for (int i = 0; i < attr_proto.tensors_size(); i++) { + std::vector shape_vec; + const onnx::TensorProto &attr_tensor = attr_proto.tensors(i); + for (int j = 0; j < attr_tensor.dims_size(); ++j) { + shape_vec.push_back(attr_tensor.dims(j)); + } + auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[attr_tensor.data_type()]); + auto abstract_tensor = std::make_shared(type_ptr, shape_vec); + kv.insert(std::pair(attr_tensor.name(), abstract_tensor)); } - auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[attr_tensor.data_type()]); - auto abstract_tensor = std::make_shared(type_ptr, shape_vec); - MS_EXCEPTION_IF_NULL(abstract_tensor); - return abstract_tensor; + return kv; } CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, @@ -437,25 +544,16 @@ CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &out const std::string &node_name = node_proto.output(0); const std::string &fullname_with_scope = node_proto.domain(); const std::string &node_type = node_proto.op_type(); - PrimitivePtr prim = std::make_shared(node_type); + PrimitivePtr prim = std::make_shared(node_type); MS_EXCEPTION_IF_NULL(prim); prim->set_instance_name(node_type); - - abstract::AbstractTensorPtr abstract = nullptr; - abstract::AbstractTensorPtr abstract_first = nullptr; - abstract::AbstractTensorPtr abstract_second = nullptr; + std::unordered_map kv; + string shape_ref_attr_name; for (int i = 0; i < node_proto.attribute_size(); ++i) { const onnx::AttributeProto &attr_proto = node_proto.attribute(i); - if (attr_proto.name() == kCNodeShapeAttr) { - abstract = GetAbstractForCNode(attr_proto); - continue; - } - if (attr_proto.name() == kCNodeShape1Attr) { - abstract_first = GetAbstractForCNode(attr_proto); - continue; - } - if (attr_proto.name() == kCNodeShape2Attr) { - abstract_second = GetAbstractForCNode(attr_proto); + if (attr_proto.ref_attr_name().find("shape:") != string::npos) { + shape_ref_attr_name = attr_proto.ref_attr_name(); + kv = GetAbstractForCNode(attr_proto); continue; } if (!GetAttrValueForCNode(prim, attr_proto)) { @@ -463,6 +561,7 @@ CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &out return nullptr; } } + std::vector inputs; inputs.clear(); for (int i = 0; i < node_proto.input_size(); ++i) { @@ -481,26 +580,20 @@ CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &out inputs.insert(inputs.begin(), NewValueNode(primitivec_ptr)); CNodePtr cnode_ptr = outputFuncGraph->NewCNode(inputs); MS_EXCEPTION_IF_NULL(cnode_ptr); - if (node_type == "LayerNorm") { - AbstractBasePtrList elem; - elem.push_back(abstract); - elem.push_back(abstract_first); - elem.push_back(abstract_second); - cnode_ptr->set_abstract(std::make_shared(elem)); - } else if (node_type == "ArgMaxWithValue") { - AbstractBasePtrList elem; - elem.push_back(abstract); - elem.push_back(abstract_first); - cnode_ptr->set_abstract(std::make_shared(elem)); - } else if (nullptr == abstract) { + if (0 == kv.size()) { AbstractBasePtrList elem; for (size_t index = 1; index < cnode_ptr->inputs().size(); ++index) { elem.push_back(cnode_ptr->input(index)->abstract()); } cnode_ptr->set_abstract(std::make_shared(elem)); + } else if (1 == kv.size()) { + std::unordered_map::iterator iter = kv.begin(); + cnode_ptr->set_abstract(iter->second); } else { + auto abstract = ParserAttrShape(shape_ref_attr_name, kv); cnode_ptr->set_abstract(abstract); } + cnode_ptr->set_fullname_with_scope(fullname_with_scope); anfnode_build_map_[node_name] = cnode_ptr; return cnode_ptr; @@ -652,7 +745,7 @@ int AnfImporterFromProtobuf::Import(const schema::QuantType &quantType) { onnx::ModelProto *AnfImporterFromProtobuf::ReadOnnxFromBinary(const std::string &model_path) { auto onnx_model = new onnx::ModelProto; - if (ReadProtoFromBinaryFile((const char *)model_path.c_str(), onnx_model) != RET_OK) { + if (ReadProtoFromBinaryFile((const char *) model_path.c_str(), onnx_model) != RET_OK) { MS_LOG(ERROR) << "Read onnx model file failed, model path: " << model_path; return nullptr; } diff --git a/mindspore/lite/tools/anf_importer/import_from_protobuf.h b/mindspore/lite/tools/anf_importer/import_from_protobuf.h index 718f721fb0..0a7e0e69d1 100644 --- a/mindspore/lite/tools/anf_importer/import_from_protobuf.h +++ b/mindspore/lite/tools/anf_importer/import_from_protobuf.h @@ -42,9 +42,9 @@ class AnfImporterFromProtobuf : public AnfImporter { int Import(const schema::QuantType &quantType = schema::QuantType_QUANT_NONE) override; private: - int ConverterConstTensor() override{ return RET_ERROR; }; - int ConverterCNode() override{ return RET_ERROR; }; - int AddReturnCNode() override{ return RET_ERROR; }; + int ConverterConstTensor() override { return RET_ERROR; }; + int ConverterCNode() override { return RET_ERROR; }; + int AddReturnCNode() override { return RET_ERROR; }; bool ParseModelConfigureInfo(const onnx::ModelProto &model_proto); bool BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto, const schema::QuantType &quantType); @@ -59,18 +59,15 @@ class AnfImporterFromProtobuf : public AnfImporter { bool GetAttrValueForCNode(const PrimitivePtr &prim, const onnx::AttributeProto &attr_proto); bool ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const std::string &attr_name, const onnx::TensorProto &attr_tensor); - bool ObtainCNodeAttrInScalarForm(const PrimitivePtr &prim, const std::string &attr_name, - const onnx::TensorProto &attr_tensor); + ValuePtr ObtainCNodeAttrInScalarForm(const onnx::TensorProto &attr_tensor); bool ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const std::string &attr_name, const onnx::TensorProto &attr_tensor); bool BuildValueNodeForFuncGraph(const onnx::NodeProto &node_proto); bool ObtainValueNodeInTensorForm(const string &value_node_name, const onnx::TensorProto &attr_tensor); - - bool ObtainValueNodeInScalarForm(const string &value_node_name, const onnx::TensorProto &attr_tensor); - bool GetAttrValueForValueNode(const string &ref_attr_name, const std::string &value_node_name, - const onnx::TensorProto &attr_tensor); + bool GetAttrValueForValueNode(const std::string &value_node_name, const onnx::AttributeProto &attr_proto); bool ObtainValueNodeInTypeForm(const string &value_node_name, const onnx::TensorProto &attr_tensor); - abstract::AbstractTensorPtr GetAbstractForCNode(const onnx::AttributeProto &attr_proto); + std::unordered_map GetAbstractForCNode(const onnx::AttributeProto &attr_proto); private: std::string producer_name_;