diff --git a/mindspore/ccsrc/transform/onnx/ir_exporter.cc b/mindspore/ccsrc/transform/onnx/ir_exporter.cc index 78858eea8a..6feb4db1be 100644 --- a/mindspore/ccsrc/transform/onnx/ir_exporter.cc +++ b/mindspore/ccsrc/transform/onnx/ir_exporter.cc @@ -92,14 +92,17 @@ class IrExportBuilder { void SetTensorProto(const TypePtr &type, const BaseShapePtr &shape, onnx::TensorProto *const tensor_proto); void SetAttributeProto(const AnfNodePtr &node, onnx::NodeProto *const node_proto); void SetShapeToNodeProto(const CNodePtr &node, onnx::NodeProto *const node_proto); - void SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, onnx::NodeProto *const node_proto, - std::string suffix = "0"); + void SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, onnx::AttributeProto *const attr_proto, + std::string *const seq_string); void SetValueToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto); void SetTypeToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto); void SetScalarToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto); void SetTensorToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto); - void SetScalarToProto(const ValuePtr &value, onnx::TensorProto *const tensor_proto); - void SetSequenceToAttributeProto(const ValueSequeuePtr &value, onnx::AttributeProto *const attr_proto); + void SetScalarToProto(const ValuePtr &value, onnx::TensorProto *const tensor_proto, const std::string &value_name); + void SetSequenceToAttributeProto(const ValueSequeuePtr &value, onnx::AttributeProto *const attr_proto, + std::string *const seq_string); + void SetSeqElemToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto, + std::string *const seq_string); onnx::TensorProto_DataType GetOnnxDataType(TypeId type_id); onnx::TensorProto_DataType GetOnnxDataBitsIntType(int bits); @@ -107,8 +110,10 @@ class IrExportBuilder { std::string GetNodeName(const AnfNodePtr &node); std::string GetUniqueNodeName(const AnfNodePtr &node); std::string GetOpTypeName(const AnfNodePtr &node); - size_t AllocateIndex() { return ++node_index_; } - void ResetIndex() { node_index_ = 0; } + size_t GetNodeIndex() { return ++node_index_; } + void ResetNodeIndex() { node_index_ = 0; } + size_t GetTupleIndex() { return ++shape_index_; } + void ResetTupleIndex() { shape_index_ = 0; } private: onnx::ModelProto model_; @@ -116,6 +121,7 @@ class IrExportBuilder { std::list todo_; std::map node_index_map_; size_t node_index_{0}; + size_t shape_index_{0}; }; using IrExporterPtr = std::shared_ptr; @@ -148,7 +154,7 @@ void IrExportBuilder::BuildModelInfo() { void IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph) { onnx::GraphProto *graph_proto = model_.mutable_graph(); graph_proto->set_name(func_graph->ToString()); - ResetIndex(); + ResetNodeIndex(); todo_.clear(); todo_.push_back(func_graph); while (!todo_.empty()) { @@ -179,7 +185,7 @@ void IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, onnx::Grap input_proto->set_name(param_name); SetValueInfoProto(param, input_proto); if (!param->has_default()) { - MS_LOG(DEBUG) << "Parameter: '" << item->ToString() << "' has no default"; + MS_LOG(DEBUG) << "Parameter: '" << item->ToString() << "' has no default."; continue; } @@ -234,13 +240,20 @@ void IrExportBuilder::SetValueInfoProto(const TypePtr &type, const BaseShapePtr auto elem_type = tensor->element(); const auto &dims = shape->cast()->shape(); type_proto->mutable_tensor_type()->set_elem_type(GetOnnxDataType(elem_type->type_id())); - for (const auto &dim : dims) { - MS_LOG(DEBUG) << "SetValueInfoProto dim: " << dim; - type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(dim); + if (dims.size() == 0) { + MS_LOG(DEBUG) << "SetValueInfoProto set default dim 1."; + type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1); + } else { + for (const auto &dim : dims) { + MS_LOG(DEBUG) << "SetValueInfoProto dim: " << dim; + type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(dim); + } } } else if (type->isa()) { auto tup_shape = shape->cast(); - type_proto->set_denotation(std::to_string(tup_shape->shape().size())); + type_proto->set_denotation(type->type_name() + ":" + std::to_string(tup_shape->shape().size())); + } else if (type->isa() || type->isa()) { + type_proto->set_denotation(type->type_name()); } else { MS_LOG(EXCEPTION) << "Value type: " << type->type_name() << " is not supported!"; } @@ -250,9 +263,10 @@ void IrExportBuilder::SetTensorToAttributeProto(const ValuePtr &value, onnx::Att if (value == nullptr || attr_proto == nullptr) { MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!"; } - attr_proto->set_ref_attr_name("tensor"); - attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); - onnx::TensorProto *tensor_proto = attr_proto->mutable_t(); + attr_proto->set_ref_attr_name("tensor:value0"); + attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSORS); + onnx::TensorProto *tensor_proto = attr_proto->add_tensors(); + tensor_proto->set_name("value0"); auto data = value->cast(); tensor_proto->set_raw_data(data->data_c(), static_cast(data->data().nbytes())); auto dtype = data->data_type(); @@ -286,6 +300,7 @@ void IrExportBuilder::SetParamToTensorProto(const ParameterPtr ¶m, onnx::Ten void IrExportBuilder::BuildNodes(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto) { std::vector nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude); + bool is_only_return = true; for (const AnfNodePtr &node : nodes) { if (!node->isa()) { MS_LOG(DEBUG) << "Node: '" << node->ToString() << "' is not cnode"; @@ -293,9 +308,13 @@ void IrExportBuilder::BuildNodes(const FuncGraphPtr &func_graph, onnx::GraphProt } auto cnode = node->cast(); if (cnode == func_graph->get_return()) { + if (is_only_return) { + MS_LOG(EXCEPTION) << "Only has return node, can't convert to binary model!"; + } BuildOutput(cnode, graph_proto); } else { BuildCNode(cnode, graph_proto); + is_only_return = false; } } } @@ -305,24 +324,11 @@ void IrExportBuilder::BuildOutput(const CNodePtr &node, onnx::GraphProto *const MS_LOG(EXCEPTION) << "Number of inputs of return node is not equal to 2."; } AnfNodePtr arg = node->input(1); - // Using make_tuple to set multi-output - if (IsPrimitiveCNode(arg, prim::kPrimMakeTuple)) { - auto tuple_node = arg->cast(); - for (size_t i = 1; i < tuple_node->size(); i++) { - auto input_node = arg->cast()->input(i); - onnx::ValueInfoProto *output_proto = graph_proto->add_output(); - auto output_name = GetUniqueNodeName(tuple_node->input(i)); - output_proto->set_name(output_name); - last_node_->add_output(output_name); - SetValueInfoProto(tuple_node->input(i), output_proto); - } - } else { - onnx::ValueInfoProto *output_proto = graph_proto->add_output(); - std::string output_name = GetUniqueNodeName(node); - output_proto->set_name(output_name); - last_node_->add_output(output_name); - SetValueInfoProto(arg, output_proto); - } + onnx::ValueInfoProto *output_proto = graph_proto->add_output(); + std::string output_name = GetUniqueNodeName(node); + output_proto->set_name(output_name); + last_node_->set_output(0, output_name); + SetValueInfoProto(arg, output_proto); } std::string IrExportBuilder::GetOpTypeName(const AnfNodePtr &node) { @@ -345,45 +351,43 @@ std::string IrExportBuilder::GetOpTypeName(const AnfNodePtr &node) { } void IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, - onnx::NodeProto *const node_proto, std::string suffix) { - onnx::AttributeProto *attr_proto = node_proto->add_attribute(); - attr_proto->set_ref_attr_name("shape"); - if (suffix.compare("0") != 0) { - attr_proto->set_name("shape" + suffix); - } else { - attr_proto->set_name("shape"); - } - onnx::TensorProto *tensor_proto = attr_proto->mutable_t(); - SetTensorProto(type, shape, tensor_proto); -} - -void IrExportBuilder::SetShapeToNodeProto(const CNodePtr &node, onnx::NodeProto *const node_proto) { - // Get shape of cnode - // 1. prim ArgMaxWithValue need to get shape from tuple element - // 2. some cnode doesn't has shape, such as LayerNorm - // 3. other cnodes have shape - if (node->IsApply(prim::kPrimArgMaxWithValue) || node->IsApply(prim::kPrimLayerNorm)) { - auto type = node->Type(); - auto shape = node->Shape(); - if (!type->isa()) { - MS_LOG(EXCEPTION) << "Output data of ArgMaxWithValue cnode must be tuple: " << type->type_name(); - } + onnx::AttributeProto *const attr_proto, std::string *const seq_string) { + if (type->isa() && seq_string != nullptr) { + *seq_string += "Tuple["; auto elements = type->cast()->elements(); auto tuple_shape = shape->cast()->shape(); for (size_t i = 0; i < elements.size(); i++) { - SetShapeToNodeProto(elements[i], tuple_shape[i], node_proto, std::to_string(i)); + SetShapeToNodeProto(elements[i], tuple_shape[i], attr_proto, seq_string); } + *seq_string += "],"; + } else if (type->isa() && shape->isa() && seq_string != nullptr) { + string shape_name = "shape" + std::to_string(GetTupleIndex()); + *seq_string += shape_name + ","; + onnx::TensorProto *tensor_proto = attr_proto->add_tensors(); + tensor_proto->set_name(shape_name); + SetTensorProto(type, shape, tensor_proto); + } else if ((type->isa() || type->isa()) && seq_string != nullptr) { + *seq_string += type->type_name() + ","; } else { - auto type = node->Type(); - auto shape = node->Shape(); - if (!type->isa() || !shape->isa()) { - MS_LOG(DEBUG) << "Cnode has no shape: " << node->ToString(); - return; - } - SetShapeToNodeProto(type, shape, node_proto); + MS_LOG(EXCEPTION) << "Type of cnode need to be supported: " << type->type_name(); } } +void IrExportBuilder::SetShapeToNodeProto(const CNodePtr &node, onnx::NodeProto *const node_proto) { + // Get shape of cnode + // 1. need to get shape from tuple element + // 2. save shape in TensorProto + // 3. save tuple string in ref_attr_name + auto type = node->Type(); + auto shape = node->Shape(); + ResetTupleIndex(); + std::string seq_string = "shape:"; + onnx::AttributeProto *attr_proto = node_proto->add_attribute(); + SetShapeToNodeProto(type, shape, attr_proto, &seq_string); + attr_proto->set_ref_attr_name(seq_string); + MS_LOG(DEBUG) << "CNode shape: " << seq_string; +} + void IrExportBuilder::BuildCNode(const CNodePtr &node, onnx::GraphProto *const graph_proto) { auto inputs_size = node->size(); if (inputs_size < 1) { @@ -445,15 +449,19 @@ std::string IrExportBuilder::GetUniqueNodeName(const AnfNodePtr &node) { std::string node_name = ""; if (node->isa()) { node_name = GetNodeName(node); - } else if (node->isa() || node->isa()) { + } else if (node->isa()) { auto iter = node_index_map_.find(node); if (iter != node_index_map_.end()) { node_name = GetNodeName(node) + ":" + std::to_string(iter->second); } else { - auto node_idx = AllocateIndex(); + auto node_idx = GetNodeIndex(); node_index_map_[node] = node_idx; node_name = GetNodeName(node) + ":" + std::to_string(node_idx); } + } else if (node->isa()) { + auto node_idx = GetNodeIndex(); + node_index_map_[node] = node_idx; + node_name = GetNodeName(node) + ":" + std::to_string(node_idx); } else { MS_LOG(EXCEPTION) << "Can not support type of node:" << node->ToString(); } @@ -487,17 +495,21 @@ void IrExportBuilder::SetTypeToAttributeProto(const ValuePtr &value, onnx::Attri if (value == nullptr || attr_proto == nullptr) { MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!"; } - attr_proto->set_ref_attr_name("type"); - attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); - onnx::TensorProto *tensor_proto = attr_proto->mutable_t(); + attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSORS); + onnx::TensorProto *tensor_proto = attr_proto->add_tensors(); if (value->isa()) { + attr_proto->set_ref_attr_name("type:value0"); + tensor_proto->set_name("value0"); auto int_value = value->cast(); tensor_proto->set_data_type(GetOnnxDataBitsIntType(int_value->nbits())); } else if (value->isa()) { + attr_proto->set_ref_attr_name("type:value0"); + tensor_proto->set_name("value0"); auto float_value = value->cast(); tensor_proto->set_data_type(GetOnnxDataBitsFloatType(float_value->nbits())); } else if (value->isa()) { - tensor_proto->set_name("tensor"); + attr_proto->set_ref_attr_name("type:tensor0"); + tensor_proto->set_name("tensor0"); auto elem_type = value->cast()->element(); if (elem_type->isa()) { auto int_value = elem_type->cast(); @@ -521,10 +533,18 @@ void IrExportBuilder::SetValueToAttributeProto(const ValuePtr &value, onnx::Attr SetScalarToAttributeProto(value, attr_proto); } else if (value->isa() || value->isa()) { SetTypeToAttributeProto(value, attr_proto); - } else if (value->isa()) { - SetSequenceToAttributeProto(value->cast(), attr_proto); + } else if (value->isa() || value->isa()) { + ResetTupleIndex(); + std::string seq_string = "scalar:"; + attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSORS); + SetSequenceToAttributeProto(value->cast(), attr_proto, &seq_string); + attr_proto->set_ref_attr_name(seq_string); + MS_LOG(DEBUG) << "Attr string: " << seq_string; } else if (value->isa()) { SetTensorToAttributeProto(value, attr_proto); + } else if (value->isa()) { + attr_proto->set_ref_attr_name("none"); + MS_LOG(DEBUG) << "Attr string: " << value->type_name(); } else { MS_LOG(EXCEPTION) << "Unsupported type: " << value->type_name(); } @@ -534,16 +554,18 @@ void IrExportBuilder::SetScalarToAttributeProto(const ValuePtr &value, onnx::Att if (value == nullptr || attr_proto == nullptr) { MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!"; } - attr_proto->set_ref_attr_name("scalar"); - attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); - onnx::TensorProto *tensor_proto = attr_proto->mutable_t(); - SetScalarToProto(value, tensor_proto); + attr_proto->set_ref_attr_name("scalar:value0"); + attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSORS); + onnx::TensorProto *tensor_proto = attr_proto->add_tensors(); + SetScalarToProto(value, tensor_proto, "value0"); } -void IrExportBuilder::SetScalarToProto(const ValuePtr &value, onnx::TensorProto *const tensor_proto) { +void IrExportBuilder::SetScalarToProto(const ValuePtr &value, onnx::TensorProto *const tensor_proto, + const std::string &value_name) { if (value == nullptr || tensor_proto == nullptr) { MS_LOG(EXCEPTION) << "ValuePtr or TensorProto is null!"; } + tensor_proto->set_name(value_name); if (value->isa()) { tensor_proto->set_data_type(onnx::TensorProto_DataType_STRING); tensor_proto->add_string_data(GetValue(value)); @@ -562,44 +584,74 @@ void IrExportBuilder::SetScalarToProto(const ValuePtr &value, onnx::TensorProto } else if (value->isa()) { tensor_proto->set_data_type(onnx::TensorProto_DataType_INT64); tensor_proto->add_int64_data(value->cast()->value()); - } else if (value->isa()) { + } else if (value->isa()) { + tensor_proto->set_data_type(onnx::TensorProto_DataType_UINT8); + tensor_proto->add_int32_data(value->cast()->value()); + } else if (value->isa()) { + tensor_proto->set_data_type(onnx::TensorProto_DataType_UINT16); + tensor_proto->add_int32_data(value->cast()->value()); + } else if (value->isa()) { + tensor_proto->set_data_type(onnx::TensorProto_DataType_UINT32); + tensor_proto->add_uint64_data(value->cast()->value()); + } else if (value->isa()) { + tensor_proto->set_data_type(onnx::TensorProto_DataType_UINT64); + tensor_proto->add_uint64_data(value->cast()->value()); + } else if (value->isa()) { tensor_proto->set_data_type(onnx::TensorProto_DataType_FLOAT); tensor_proto->add_float_data(GetValue(value)); + } else if (value->isa()) { + tensor_proto->set_data_type(onnx::TensorProto_DataType_DOUBLE); + tensor_proto->add_double_data(GetValue(value)); } else { MS_LOG(EXCEPTION) << "Unsupported scalar type: " << value->type_name(); } } -void IrExportBuilder::SetSequenceToAttributeProto(const ValueSequeuePtr &value, - onnx::AttributeProto *const attr_proto) { +void IrExportBuilder::SetSeqElemToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto, + std::string *const seq_string) { + string value_name = "value" + std::to_string(GetTupleIndex()); + if (seq_string != nullptr) { + *seq_string += value_name + ","; + } + onnx::TensorProto *tensor_proto = attr_proto->add_tensors(); + SetScalarToProto(value, tensor_proto, value_name); +} + +void IrExportBuilder::SetSequenceToAttributeProto(const ValueSequeuePtr &value, onnx::AttributeProto *const attr_proto, + std::string *const seq_string) { if (value == nullptr || attr_proto == nullptr) { MS_LOG(EXCEPTION) << "ValueSequeuePtr or AttributeProto is null!"; } - attr_proto->set_ref_attr_name("scalar"); - attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); - onnx::TensorProto *tensor_proto = attr_proto->mutable_t(); - if (value->isa()) { + if (value->isa() && seq_string != nullptr) { + *seq_string += "Tuple["; const ValueTuplePtr &tuple_value = value->cast(); if (tuple_value->value().size() == 0) { MS_LOG(DEBUG) << "SetSequenceToAttributeProto tuple size is 0"; return; } - auto type_id = tuple_value->value()[0]->type()->type_id(); - tensor_proto->set_data_type(GetOnnxDataType(type_id)); for (const auto &item : tuple_value->value()) { - SetScalarToProto(item, tensor_proto); + if (item->isa()) { + SetSequenceToAttributeProto(item->cast(), attr_proto, seq_string); + } else { + SetSeqElemToAttributeProto(item, attr_proto, seq_string); + } } - } else if (value->isa()) { + *seq_string += "],"; + } else if (value->isa() && seq_string != nullptr) { + *seq_string += "List["; const ValueListPtr &list_value = value->cast(); if (list_value->value().size() == 0) { - MS_LOG(DEBUG) << "SetSequenceToAttributeProto list size is 0"; + MS_LOG(DEBUG) << "SetSequenceToAttributeProto list size is 0."; return; } - auto type_id = list_value->value()[0]->type()->type_id(); - tensor_proto->set_data_type(GetOnnxDataType(type_id)); for (const auto &item : list_value->value()) { - SetScalarToProto(item, tensor_proto); + if (item->isa()) { + SetSequenceToAttributeProto(item->cast(), attr_proto, seq_string); + } else { + SetSeqElemToAttributeProto(item, attr_proto, seq_string); + } } + *seq_string += "],"; } }