diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_reshape_populater.cc b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_reshape_populater.cc new file mode 100644 index 0000000000..de5007ca5f --- /dev/null +++ b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_reshape_populater.cc @@ -0,0 +1,35 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/common/anf_exporter/anf_populater/anf_reshape_populater.h" +#include +#include +#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" +#include "ir/func_graph.h" +#include "ir/primitive.h" + +namespace mindspore::lite { +int mindspore::lite::AnfReshapePopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, + std::vector *outputs) { + auto attr = std::make_unique(); + node->nodeType = schema::NodeType_CNode; + node->primitive = std::make_unique(); + node->primitive->value.type = schema::PrimitiveType_Flatten; + node->primitive->value.value = attr.release(); + return 0; +} + +AnfNodePopulaterRegistrar anfReshapeParser("Reshape", new AnfReshapePopulater()); +} // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_reshape_populater.h b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_reshape_populater.h new file mode 100644 index 0000000000..776aab0f94 --- /dev/null +++ b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_reshape_populater.h @@ -0,0 +1,30 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_ANF_RESHAPE_PARSER_H +#define MINDSPORE_ANF_RESHAPE_PARSER_H +#include "src/common/anf_exporter/anf_populater/anf_node_populater.h" +#include +namespace mindspore::lite { +class AnfReshapePopulater : public AnfNodePopulater { + public: + AnfReshapePopulater() = default; + ~AnfReshapePopulater() override = default; + int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector *outputs) override; +}; +} // namespace mindspore::lite + +#endif // MINDSPORE_ANF_RESHAPE_PARSER_H diff --git a/mindspore/lite/src/common/anf_importer/import_from_protobuf.cc b/mindspore/lite/src/common/anf_importer/import_from_protobuf.cc index f4505fbfb1..5daaeb12db 100644 --- a/mindspore/lite/src/common/anf_importer/import_from_protobuf.cc +++ b/mindspore/lite/src/common/anf_importer/import_from_protobuf.cc @@ -15,25 +15,28 @@ */ #include "src/common/anf_importer/import_from_protobuf.h" + #include #include + +#include #include #include -#include -#include #include +#include #include +#include #include -#include -#include "ir/func_graph.h" -#include "ir/anf.h" + +#include "frontend/operator/ops.h" #include "google/protobuf/io/zero_copy_stream_impl.h" -#include "src/param_value_lite.h" +#include "include/errorcode.h" +#include "ir/anf.h" +#include "ir/func_graph.h" #include "src/ir/tensor.h" -#include "frontend/operator/ops.h" +#include "src/param_value_lite.h" #include "tools/converter/parser/onnx/onnx.pb.h" #include "utils/log_adapter.h" -#include "include/errorcode.h" using string = std::string; using int32 = int32_t; @@ -54,26 +57,27 @@ enum ParseForm : int { }; static std::map kParseTypeSwitchMap{ - {"type", FORM_PARSE_TYPE}, - {"scalar", FORM_PARSE_SCALAR}, - {"tensor", FORM_PARSE_TENSOR}}; + {"type", FORM_PARSE_TYPE}, + {"scalar", FORM_PARSE_SCALAR}, + {"tensor", FORM_PARSE_TENSOR}}; static std::unordered_map kDefaultValueSwitchMap{ - {onnx::TensorProto_DataType_BOOL, kNumberTypeBool}, - {onnx::TensorProto_DataType_INT8, kNumberTypeInt8}, - {onnx::TensorProto_DataType_INT16, kNumberTypeInt16}, - {onnx::TensorProto_DataType_INT32, kNumberTypeInt32}, - {onnx::TensorProto_DataType_INT64, kNumberTypeInt64}, - {onnx::TensorProto_DataType_UINT8, kNumberTypeUInt8}, - {onnx::TensorProto_DataType_UINT16, kNumberTypeUInt16}, - {onnx::TensorProto_DataType_UINT32, kNumberTypeUInt32}, - {onnx::TensorProto_DataType_UINT64, kNumberTypeUInt64}, - {onnx::TensorProto_DataType_FLOAT16, kNumberTypeFloat16}, - {onnx::TensorProto_DataType_FLOAT, kNumberTypeFloat32}, - {onnx::TensorProto_DataType_DOUBLE, kNumberTypeFloat64}, - {onnx::TensorProto_DataType_STRING, kObjectTypeString}, + {onnx::TensorProto_DataType_BOOL, kNumberTypeBool}, + {onnx::TensorProto_DataType_INT8, kNumberTypeInt8}, + {onnx::TensorProto_DataType_INT16, kNumberTypeInt16}, + {onnx::TensorProto_DataType_INT32, kNumberTypeInt32}, + {onnx::TensorProto_DataType_INT64, kNumberTypeInt64}, + {onnx::TensorProto_DataType_UINT8, kNumberTypeUInt8}, + {onnx::TensorProto_DataType_UINT16, kNumberTypeUInt16}, + {onnx::TensorProto_DataType_UINT32, kNumberTypeUInt32}, + {onnx::TensorProto_DataType_UINT64, kNumberTypeUInt64}, + {onnx::TensorProto_DataType_FLOAT16, kNumberTypeFloat16}, + {onnx::TensorProto_DataType_FLOAT, kNumberTypeFloat32}, + {onnx::TensorProto_DataType_DOUBLE, kNumberTypeFloat64}, + {onnx::TensorProto_DataType_STRING, kObjectTypeString}, }; +#if 0 std::shared_ptr ParserScalarAttrValue(const std::string &attr_name, const std::unordered_map &kv) { std::string str = attr_name; @@ -190,16 +194,17 @@ ParserAttrShape(const std::string &attr_name, const std::unordered_map(attr_tensor.type##_data(0)); \ - return MakeValue(value); \ -} else { \ - MS_LOG(ERROR) << "size of scalar tensor doesn't equal 1!"; \ -} \ -return{}; \ -} +#define PARSE_ONNXATTR_IN_SCALAR_FORM(type, valuetype) \ + ValuePtr ParseAttrInScalar_##type##_##valuetype( \ + const onnx::TensorProto &attr_tensor) { \ + if (attr_tensor.type##_data_size() == 1) { \ + auto value = static_cast(attr_tensor.type##_data(0)); \ + return MakeValue(value); \ + } else { \ + MS_LOG(ERROR) << "size of scalar tensor doesn't equal 1!"; \ + } \ + return {}; \ + } PARSE_ONNXATTR_IN_SCALAR_FORM(double, double) PARSE_ONNXATTR_IN_SCALAR_FORM(float, float) @@ -634,8 +639,508 @@ bool AnfImporterFromProtobuf::ImportNodesForGraph(const FuncGraphPtr &outputFunc BuildReturnForFuncGraph(outputFuncGraph, importProto, cnode_ptr); return true; } +#endif + +#define PARSE_ONNXATTR_IN_SCALAR_FORM(type, valuetype) \ + void ParseAttrInScalar_##type##_##valuetype( \ + const PrimitivePtr &prim, const std::string &attr_name, \ + const onnx::TensorProto &attr_tensor) { \ + MS_EXCEPTION_IF_NULL(prim); \ + std::vector attr_value_vec; \ + for (int i = 0; i < attr_tensor.type##_data_size(); ++i) { \ + auto value = static_cast(attr_tensor.type##_data(i)); \ + attr_value_vec.push_back(MakeValue(value)); \ + } \ + if (attr_value_vec.size() == 1) { \ + prim->AddAttr(attr_name, attr_value_vec[0]); \ + } else { \ + prim->AddAttr(attr_name, std::make_shared(attr_value_vec)); \ + } \ + } + +PARSE_ONNXATTR_IN_SCALAR_FORM(double, double) +PARSE_ONNXATTR_IN_SCALAR_FORM(float, float) +PARSE_ONNXATTR_IN_SCALAR_FORM(string, string) +PARSE_ONNXATTR_IN_SCALAR_FORM(int32, int32) +PARSE_ONNXATTR_IN_SCALAR_FORM(int32, bool) +PARSE_ONNXATTR_IN_SCALAR_FORM(int64, int64) +PARSE_ONNXATTR_IN_SCALAR_FORM(uint64, uint64) + +bool AnfImporterFromProtobuf::BuildParameterForFuncGraph( + const ParameterPtr &node, const onnx::ValueInfoProto &value_proto) { + MS_EXCEPTION_IF_NULL(node); + if (!value_proto.has_type() || !value_proto.has_name()) { + MS_LOG(ERROR) << "onnx ValueInfoProto has no type or name! "; + return false; + } + node->set_name(value_proto.name()); + const auto &type_proto = value_proto.type(); + if (!type_proto.has_tensor_type()) { + MS_LOG(ERROR) << "onnx TypeProto has no tesor_type! "; + return false; + } + const onnx::TypeProto_Tensor &tensor_typeproto = type_proto.tensor_type(); + if (!tensor_typeproto.has_elem_type() || !tensor_typeproto.has_shape()) { + MS_LOG(ERROR) << "onnx TypeProto_Tensor has no elem_type or shape! "; + return false; + } + const onnx::TensorShapeProto &tensor_shape = tensor_typeproto.shape(); + std::vector shape; + for (int i = 0; i < tensor_shape.dim_size(); ++i) { + shape.push_back(tensor_shape.dim(i).dim_value()); + } + + if (kDefaultValueSwitchMap.find(tensor_typeproto.elem_type()) == + kDefaultValueSwitchMap.end()) { + MS_LOG(ERROR) << "onnx TypeProto_Tensor elem_type is not support yet!"; + return false; + } + + auto type_ptr = + TypeIdToType(kDefaultValueSwitchMap[tensor_typeproto.elem_type()]); + auto abstract_tensor = + std::make_shared(type_ptr, shape); + node->set_abstract(abstract_tensor); + + if (default_para_map_.find(value_proto.name()) != default_para_map_.end()) { + tensor::Tensor *tensor_info = new tensor::Tensor( + kDefaultValueSwitchMap[tensor_typeproto.elem_type()], shape); + MS_EXCEPTION_IF_NULL(tensor_info); + tensor_info->MallocData(); + const onnx::TensorProto initialize_proto = + default_para_map_[value_proto.name()]; + std::string initial_data = initialize_proto.raw_data(); + auto *tensor_data_buf = reinterpret_cast(tensor_info->Data()); + MS_EXCEPTION_IF_NULL(tensor_data_buf); + memcpy_s(tensor_data_buf, tensor_info->Size(), initial_data.data(), + initial_data.size()); + + ParamValueLitePtr param_value = std::make_shared(); + MS_EXCEPTION_IF_NULL(param_value); + param_value->set_tensor_addr(tensor_data_buf); + param_value->set_tensor_size(tensor_info->Size()); + node->set_default_param(param_value); + } + anfnode_build_map_[value_proto.name()] = node; + return true; +} + +bool AnfImporterFromProtobuf::ImportParametersForGraph( + const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto) { + MS_EXCEPTION_IF_NULL(outputFuncGraph); + MS_LOG(INFO) << "Parameters had default paramerer size is: " + << importProto.initializer_size(); + + for (int i = 0; i < importProto.initializer_size(); ++i) { + const onnx::TensorProto &initializer_proto = importProto.initializer(i); + if (!initializer_proto.has_name()) { + MS_LOG(ERROR) + << "initializer vector of onnx GraphProto has no name at index: " + << i; + return false; + } + default_para_map_[initializer_proto.name()] = initializer_proto; + } + + MS_LOG(INFO) << "all parameters size: " << importProto.input_size(); + for (int i = 0; i < importProto.input_size(); ++i) { + const onnx::ValueInfoProto &input_proto = importProto.input(i); + if (!BuildParameterForFuncGraph(outputFuncGraph->add_parameter(), + input_proto)) { + MS_LOG(ERROR) << "Build parameter for funcgraph fail at index: " << i; + return false; + } + } + return true; +} + +bool AnfImporterFromProtobuf::ObtainCNodeAttrInTypeForm( + const PrimitivePtr &prim, const std::string &attr_name, + const onnx::TensorProto &attr_tensor) { + MS_EXCEPTION_IF_NULL(prim); + const int attr_tensor_type = attr_tensor.data_type(); + if (kDefaultValueSwitchMap.find(attr_tensor_type) == + kDefaultValueSwitchMap.end()) { + MS_LOG(ERROR) << "Obtain attr in type-form has not support input type:" + << attr_tensor_type; + return false; + } + prim->AddAttr(attr_name, + TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type])); + return true; +} + +bool AnfImporterFromProtobuf::ObtainCNodeAttrInScalarForm( + const PrimitivePtr &prim, const std::string &attr_name, + const onnx::TensorProto &attr_tensor) { + MS_EXCEPTION_IF_NULL(prim); + const int attr_tensor_type = attr_tensor.data_type(); + switch (attr_tensor_type) { + case onnx::TensorProto_DataType_STRING: { + ParseAttrInScalar_string_string(prim, attr_name, attr_tensor); + break; + } + case onnx::TensorProto_DataType_INT32: { + ParseAttrInScalar_int32_int32(prim, attr_name, attr_tensor); + break; + } + case onnx::TensorProto_DataType_INT64: { + ParseAttrInScalar_int64_int64(prim, attr_name, attr_tensor); + break; + } + case onnx::TensorProto_DataType_UINT64: { + ParseAttrInScalar_uint64_uint64(prim, attr_name, attr_tensor); + break; + } + case onnx::TensorProto_DataType_FLOAT: { + ParseAttrInScalar_float_float(prim, attr_name, attr_tensor); + break; + } + case onnx::TensorProto_DataType_DOUBLE: { + ParseAttrInScalar_double_double(prim, attr_name, attr_tensor); + break; + } + case onnx::TensorProto_DataType_BOOL: { + ParseAttrInScalar_int32_bool(prim, attr_name, attr_tensor); + auto value = prim->GetAttr(attr_name); + break; + } + default: + MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " + << attr_tensor_type; + return false; + } + return true; +} + +bool AnfImporterFromProtobuf::ObtainCNodeAttrInTensorForm( + const PrimitivePtr &prim, const std::string &attr_name, + const onnx::TensorProto &attr_tensor) { + MS_EXCEPTION_IF_NULL(prim); + MS_LOG(ERROR) << "parse attr type don't support attr type is tensor"; + return false; +} + +bool AnfImporterFromProtobuf::GetAttrValueForCNode( + const PrimitivePtr &prim, const onnx::AttributeProto &attr_proto) { + MS_EXCEPTION_IF_NULL(prim); + const std::string &attr_name = attr_proto.name(); + if (!attr_proto.has_ref_attr_name()) { + MS_LOG(ERROR) << "CNode parse attr type has no ref_attr_name"; + return false; + } + const std::string &ref_attr_name = attr_proto.ref_attr_name(); + const onnx::TensorProto &attr_tensor = attr_proto.t(); + switch (kParseTypeSwitchMap[ref_attr_name]) { + case FORM_PARSE_TYPE: { + return ObtainCNodeAttrInTypeForm(prim, attr_name, attr_tensor); + } + case FORM_PARSE_SCALAR: { + return ObtainCNodeAttrInScalarForm(prim, attr_name, attr_tensor); + } + case FORM_PARSE_TENSOR: { + return ObtainCNodeAttrInTensorForm(prim, attr_name, attr_tensor); + } + default: + MS_LOG(ERROR) << "parse attr type don't support input of ref_attr_name"; + return false; + } +} +bool AnfImporterFromProtobuf::ObtainValueNodeInTensorForm( + const std::string &value_node_name, const onnx::TensorProto &attr_tensor) { + const int attr_tensor_type = attr_tensor.data_type(); + std::vector shape; + for (int i = 0; i < attr_tensor.dims_size(); ++i) { + shape.push_back(attr_tensor.dims(i)); + } + tensor::TensorPtr tensor_info = std::make_shared( + kDefaultValueSwitchMap[attr_tensor_type], shape); + tensor_info->MallocData(); + const std::string &tensor_buf = attr_tensor.raw_data(); + auto *tensor_data_buf = reinterpret_cast(tensor_info->Data()); + memcpy_s(tensor_data_buf, tensor_info->Size(), tensor_buf.data(), + tensor_buf.size()); + auto new_value_node = NewValueNode(MakeValue(tensor_info)); + MS_EXCEPTION_IF_NULL(new_value_node); + auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type]); + auto abstract_tensor = + std::make_shared(type_ptr, shape); + new_value_node->set_abstract(abstract_tensor); + anfnode_build_map_[value_node_name] = new_value_node; + return true; +} + +bool AnfImporterFromProtobuf::ObtainValueNodeInScalarForm( + const std::string &value_node_name, const onnx::TensorProto &attr_tensor) { + const int attr_tensor_type = attr_tensor.data_type(); + ValuePtr value_ptr = nullptr; + switch (attr_tensor_type) { + case onnx::TensorProto_DataType_INT32: { + std::vector add_data; + for (int i = 0; i < attr_tensor.int32_data_size(); ++i) { + add_data.push_back(attr_tensor.int32_data(i)); + } + if (add_data.size() == 1) { + value_ptr = MakeValue(add_data[0]); + } else if (!add_data.empty()) { + value_ptr = MakeValue>(add_data); + } + break; + } + case onnx::TensorProto_DataType_FLOAT: { + std::vector add_data; + for (int i = 0; i < attr_tensor.float_data_size(); ++i) { + add_data.push_back(attr_tensor.float_data(i)); + } + + if (add_data.size() == 1) { + value_ptr = MakeValue(add_data[0]); + } else if (!add_data.empty()) { + value_ptr = MakeValue>(add_data); + } + break; + } + case onnx::TensorProto_DataType_UNDEFINED: { + std::vector elems; + value_ptr = std::make_shared(elems); + break; + } + default: + MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " + << attr_tensor_type; + return false; + } + auto new_value_node = NewValueNode(value_ptr); + MS_EXCEPTION_IF_NULL(new_value_node); + new_value_node->set_abstract(value_ptr->ToAbstract()); + anfnode_build_map_[value_node_name] = new_value_node; + + return true; +} + +bool AnfImporterFromProtobuf::ObtainValueNodeInTypeForm( + const std::string &value_node_name, const onnx::TensorProto &attr_tensor) { + const int attr_tensor_type = attr_tensor.data_type(); + if (kDefaultValueSwitchMap.find(attr_tensor_type) == + kDefaultValueSwitchMap.end()) { + MS_LOG(ERROR) + << "Obtain ValueNode attr in type-form has not support input type: " + << attr_tensor_type; + return false; + } + auto new_value_node = + NewValueNode(TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type])); + abstract::AbstractTypePtr abs_type = + std::make_shared(std::make_shared()); + new_value_node->set_abstract(abs_type); + anfnode_build_map_[value_node_name] = new_value_node; + return true; +} + +bool AnfImporterFromProtobuf::GetAttrValueForValueNode( + const std::string &ref_attr_name, const std::string &value_node_name, + const onnx::TensorProto &attr_tensor) { + switch (kParseTypeSwitchMap[ref_attr_name]) { + case FORM_PARSE_SCALAR: { + return ObtainValueNodeInScalarForm(value_node_name, attr_tensor); + } + case FORM_PARSE_TENSOR: { + return ObtainValueNodeInTensorForm(value_node_name, attr_tensor); + } + case FORM_PARSE_TYPE: { + return ObtainValueNodeInTypeForm(value_node_name, attr_tensor); + } + default: + MS_LOG(ERROR) + << "parse ValueNode value don't support input of ref_attr_name"; + return false; + } +} + +bool AnfImporterFromProtobuf::BuildValueNodeForFuncGraph( + const onnx::NodeProto &node_proto) { + const std::string &value_node_name = node_proto.output(0); + const onnx::AttributeProto &attr_proto = node_proto.attribute(0); + if (!attr_proto.has_ref_attr_name()) { + MS_LOG(ERROR) << "parse ValueNode don't have ref_attr_name"; + return false; + } + const std::string &ref_attr_name = attr_proto.ref_attr_name(); + const onnx::TensorProto &attr_tensor = attr_proto.t(); + + return GetAttrValueForValueNode(ref_attr_name, value_node_name, attr_tensor); +} + +abstract::AbstractTensorPtr AnfImporterFromProtobuf::GetAbstractForCNode( + const onnx::AttributeProto &attr_proto) { + std::vector shape_vec; + const onnx::TensorProto &attr_tensor = attr_proto.t(); + for (int i = 0; i < attr_tensor.dims_size(); ++i) { + shape_vec.push_back(attr_tensor.dims(i)); + } + auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[attr_tensor.data_type()]); + auto abstract_tensor = + std::make_shared(type_ptr, shape_vec); + MS_EXCEPTION_IF_NULL(abstract_tensor); + return abstract_tensor; +} + +CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph( + const FuncGraphPtr &outputFuncGraph, const onnx::NodeProto &node_proto) { + MS_EXCEPTION_IF_NULL(outputFuncGraph); + if (!node_proto.has_op_type()) { + MS_LOG(ERROR) << "Get CNode op_type failed!"; + return nullptr; + } + const std::string &node_name = node_proto.output(0); + const std::string &fullname_with_scope = node_proto.domain(); + const std::string &node_type = node_proto.op_type(); + PrimitivePtr prim = std::make_shared(node_type); + MS_EXCEPTION_IF_NULL(prim); + prim->set_instance_name(node_type); + + abstract::AbstractTensorPtr abstract = nullptr; + abstract::AbstractTensorPtr abstract_first = nullptr; + abstract::AbstractTensorPtr abstract_second = nullptr; + for (int i = 0; i < node_proto.attribute_size(); ++i) { + const onnx::AttributeProto &attr_proto = node_proto.attribute(i); + if (attr_proto.name() == kCNodeShapeAttr) { + abstract = GetAbstractForCNode(attr_proto); + continue; + } + if (attr_proto.name() == kCNodeShape1Attr) { + abstract_first = GetAbstractForCNode(attr_proto); + continue; + } + if (attr_proto.name() == kCNodeShape2Attr) { + abstract_second = GetAbstractForCNode(attr_proto); + continue; + } + if (!GetAttrValueForCNode(prim, attr_proto)) { + MS_LOG(ERROR) << "Get CNode attr failed!"; + return nullptr; + } + } + + std::vector inputs; + inputs.clear(); + inputs.push_back(NewValueNode(prim)); + for (int i = 0; i < node_proto.input_size(); ++i) { + const std::string &input_name = node_proto.input(i); + if (anfnode_build_map_.find(input_name) == anfnode_build_map_.end()) { + MS_LOG(ERROR) << node_name << " input " << i << input_name + << "can't find in nodes have parsed"; + return nullptr; + } + inputs.push_back(anfnode_build_map_[input_name]); + } + CNodePtr cnode_ptr = outputFuncGraph->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(cnode_ptr); + if (node_type == "LayerNorm") { + AbstractBasePtrList elem; + elem.push_back(abstract); + elem.push_back(abstract_first); + elem.push_back(abstract_second); + cnode_ptr->set_abstract(std::make_shared(elem)); + } else if (node_type == "ArgMaxWithValue") { + AbstractBasePtrList elem; + elem.push_back(abstract); + elem.push_back(abstract_first); + cnode_ptr->set_abstract(std::make_shared(elem)); + } else if (nullptr == abstract) { + AbstractBasePtrList elem; + for (size_t index = 1; index < cnode_ptr->inputs().size(); ++index) { + elem.push_back(cnode_ptr->input(index)->abstract()); + } + cnode_ptr->set_abstract(std::make_shared(elem)); + } else { + cnode_ptr->set_abstract(abstract); + } + cnode_ptr->set_fullname_with_scope(fullname_with_scope); + anfnode_build_map_[node_name] = cnode_ptr; + return cnode_ptr; +} + +bool AnfImporterFromProtobuf::BuildReturnForFuncGraph( + const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto, + const CNodePtr &cnode_ptr) { + MS_EXCEPTION_IF_NULL(outputFuncGraph); + MS_EXCEPTION_IF_NULL(cnode_ptr); + std::vector inputs; + if (importProto.output_size() > 1) { + inputs.clear(); + inputs.push_back(NewValueNode(prim::kPrimMakeTuple)); + AbstractBasePtrList elem; + for (int out_size = 0; out_size < importProto.output_size(); ++out_size) { + const onnx::ValueInfoProto &output_node = importProto.output(out_size); + const std::string &out_tuple = output_node.name(); + inputs.push_back(anfnode_build_map_[out_tuple]); + elem.push_back(anfnode_build_map_[out_tuple]->abstract()); + } + auto maketuple_ptr = outputFuncGraph->NewCNode(inputs); + maketuple_ptr->set_abstract( + std::make_shared(elem)); + inputs.clear(); + inputs.push_back(NewValueNode(prim::kPrimReturn)); + inputs.push_back(maketuple_ptr); + auto return_node = outputFuncGraph->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(return_node); + outputFuncGraph->set_return(return_node); + MS_LOG(INFO) << "Construct funcgraph finined, all success."; + } else { + const onnx::ValueInfoProto &output_node = importProto.output(0); + const onnx::TypeProto &output_typeproto = output_node.type(); + int output_type = output_typeproto.tensor_type().elem_type(); + std::vector output_shape; + for (int i = 0; i < output_typeproto.tensor_type().shape().dim_size(); + ++i) { + output_shape.push_back( + output_typeproto.tensor_type().shape().dim(i).dim_value()); + } + auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[output_type]); + auto abstract_tensor = + std::make_shared(type_ptr, output_shape); + + inputs.clear(); + inputs.push_back(NewValueNode(prim::kPrimReturn)); + inputs.push_back(cnode_ptr); + auto return_node = outputFuncGraph->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(return_node); + return_node->set_abstract(abstract_tensor); + outputFuncGraph->set_return(return_node); + MS_LOG(INFO) << "Construct funcgraph finined, all success!"; + } + return true; +} + +bool AnfImporterFromProtobuf::ImportNodesForGraph( + const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto) { + MS_EXCEPTION_IF_NULL(outputFuncGraph); + MS_LOG(INFO) << "The CNdoe size : " << importProto.node_size(); + CNodePtr cnode_ptr = nullptr; + for (int i = 0; i < importProto.node_size(); ++i) { + const onnx::NodeProto &node_proto = importProto.node(i); + const std::string &node_type = node_proto.op_type(); + if (node_type == kConstantValueNode) { + if (!BuildValueNodeForFuncGraph(node_proto)) { + MS_LOG(ERROR) << "Build ValueNode for funcgraph fail at index: : " << i; + return false; + } + continue; + } + cnode_ptr = BuildCNodeForFuncGraph(outputFuncGraph, node_proto); + if (cnode_ptr == nullptr) { + MS_LOG(ERROR) << "Build CNode for funcgraph fail at index: : " << i; + return false; + } + } + + BuildReturnForFuncGraph(outputFuncGraph, importProto, cnode_ptr); + return true; +} -bool AnfImporterFromProtobuf::BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto) { +bool AnfImporterFromProtobuf::BuildFuncGraph( + const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto) { MS_EXCEPTION_IF_NULL(outputFuncGraph); GraphDebugInfoPtr debug_info_ptr = outputFuncGraph->debug_info(); MS_EXCEPTION_IF_NULL(debug_info_ptr); @@ -651,7 +1156,8 @@ bool AnfImporterFromProtobuf::BuildFuncGraph(const FuncGraphPtr &outputFuncGraph return ImportNodesForGraph(outputFuncGraph, importProto); } -bool AnfImporterFromProtobuf::ParseModelConfigureInfo(const onnx::ModelProto &model_proto) { +bool AnfImporterFromProtobuf::ParseModelConfigureInfo( + const onnx::ModelProto &model_proto) { if (!model_proto.has_producer_name()) { MS_LOG(ERROR) << "Parse model producer name from pb file failed!"; return false; @@ -672,7 +1178,6 @@ bool AnfImporterFromProtobuf::ParseModelConfigureInfo(const onnx::ModelProto &mo return true; } - int AnfImporterFromProtobuf::Import() { FuncGraphPtr dstGraph = std::make_shared(); MS_EXCEPTION_IF_NULL(dstGraph); @@ -689,9 +1194,9 @@ int AnfImporterFromProtobuf::Import() { return RET_OK; } - -onnx::ModelProto *AnfImporterFromProtobuf::ReadOnnxFromBinary(const std::string &model_path) { - std::unique_ptr onnx_file(new(std::nothrow) char[PATH_MAX]{0}); +onnx::ModelProto *AnfImporterFromProtobuf::ReadOnnxFromBinary( + const std::string &model_path) { + std::unique_ptr onnx_file(new (std::nothrow) char[PATH_MAX]{0}); if (realpath(model_path.c_str(), onnx_file.get()) == nullptr) { MS_LOG(ERROR) << "open file failed."; return nullptr; @@ -707,11 +1212,10 @@ onnx::ModelProto *AnfImporterFromProtobuf::ReadOnnxFromBinary(const std::string delete onnx_model; return nullptr; } - (void) close(fd); + (void)close(fd); MS_LOG(INFO) << "enter ReadProtoFromBinary success!" << std::endl; return onnx_model; } FuncGraphPtr AnfImporterFromProtobuf::GetResult() { return this->func_graph_; } } // namespace mindspore::lite - diff --git a/mindspore/lite/src/common/anf_importer/import_from_protobuf.h b/mindspore/lite/src/common/anf_importer/import_from_protobuf.h index b4fbea9eaf..4513c79f17 100644 --- a/mindspore/lite/src/common/anf_importer/import_from_protobuf.h +++ b/mindspore/lite/src/common/anf_importer/import_from_protobuf.h @@ -47,6 +47,7 @@ class AnfImporterFromProtobuf : public AnfImporter { bool ParseModelConfigureInfo(const onnx::ModelProto &model_proto); bool BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto); +#if 0 bool ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto); bool ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, @@ -76,6 +77,30 @@ class AnfImporterFromProtobuf : public AnfImporter { const onnx::TensorProto &attr_tensor); std::unordered_map GetAbstractForCNode(const onnx::AttributeProto &attr_proto); +#endif + bool ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto); + bool ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto); + bool BuildParameterForFuncGraph(const ParameterPtr &node, const onnx::ValueInfoProto &value_proto); + CNodePtr BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::NodeProto &node_proto); + bool BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto, + const CNodePtr &cnode_ptr); + bool GetAttrValueForCNode(const PrimitivePtr &prim, const onnx::AttributeProto &attr_proto); + bool ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const std::string &attr_name, + const onnx::TensorProto &attr_tensor); + bool ObtainCNodeAttrInScalarForm(const PrimitivePtr &prim, const std::string &attr_name, + const onnx::TensorProto &attr_tensor); + bool ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const std::string &attr_name, + const onnx::TensorProto &attr_tensor); + bool BuildValueNodeForFuncGraph(const onnx::NodeProto &node_proto); + bool ObtainValueNodeInTensorForm(const string &value_node_name, const onnx::TensorProto &attr_tensor); + + bool ObtainValueNodeInScalarForm(const string &value_node_name, const onnx::TensorProto &attr_tensor); + bool GetAttrValueForValueNode(const string &ref_attr_name, const std::string &value_node_name, + const onnx::TensorProto &attr_tensor); + bool ObtainValueNodeInTypeForm(const string &value_node_name, const onnx::TensorProto &attr_tensor); + abstract::AbstractTensorPtr GetAbstractForCNode(const onnx::AttributeProto &attr_proto); + + private: std::string producer_name_;