diff --git a/mindspore/ccsrc/onnx/ir_exporter.cc b/mindspore/ccsrc/onnx/ir_exporter.cc index 687d7c23e2..b47f983695 100644 --- a/mindspore/ccsrc/onnx/ir_exporter.cc +++ b/mindspore/ccsrc/onnx/ir_exporter.cc @@ -91,9 +91,9 @@ class IrExportBuilder { void SetParamToTensorProto(const ParameterPtr ¶m, onnx::TensorProto *const tensor_proto); void SetTensorProto(const TypePtr &type, const BaseShapePtr &shape, onnx::TensorProto *const tensor_proto); void SetAttributeProto(const AnfNodePtr &node, onnx::NodeProto *const node_proto); - void SetShapeToNodeProto(const CNodePtr &node, const std::vector &inputs, - onnx::NodeProto *const node_proto); - void SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, onnx::NodeProto *const node_proto); + void SetShapeToNodeProto(const CNodePtr &node, onnx::NodeProto *const node_proto); + void SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, onnx::NodeProto *const node_proto, + std::string suffix = "0"); void SetValueToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto); void SetTypeToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto); void SetScalarToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto); @@ -112,10 +112,10 @@ class IrExportBuilder { private: onnx::ModelProto model_; - onnx::NodeProto *last_node_; + onnx::NodeProto *last_node_{nullptr}; std::list todo_; std::map node_index_map_; - size_t node_index_ = 0; + size_t node_index_{0}; }; using IrExporterPtr = std::shared_ptr; @@ -349,44 +349,34 @@ std::string IrExportBuilder::GetOpTypeName(const AnfNodePtr &node) { } void IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, - onnx::NodeProto *const node_proto) { + onnx::NodeProto *const node_proto, std::string suffix) { onnx::AttributeProto *attr_proto = node_proto->add_attribute(); attr_proto->set_ref_attr_name("shape"); - attr_proto->set_name("shape"); + if (suffix.compare("0") != 0) { + attr_proto->set_name("shape" + suffix); + } else { + attr_proto->set_name("shape"); + } onnx::TensorProto *tensor_proto = attr_proto->mutable_t(); SetTensorProto(type, shape, tensor_proto); } -void IrExportBuilder::SetShapeToNodeProto(const CNodePtr &node, const std::vector &inputs, - onnx::NodeProto *const node_proto) { +void IrExportBuilder::SetShapeToNodeProto(const CNodePtr &node, onnx::NodeProto *const node_proto) { // Get shape of cnode - // 1. prim kPrimTupleGetItem need to get shape of input node according to the index + // 1. prim ArgMaxWithValue need to get shape from tuple element // 2. some cnode doesn't has shape, such as LayerNorm // 3. other cnodes have shape - if (node->IsApply(prim::kPrimTupleGetItem)) { - // Get index of tuple get_item - int index_pos = inputs.size() - 1; - if (!inputs[index_pos]->isa()) { - MS_LOG(EXCEPTION) << "Index is not ValueNode: " << index_pos; - } - auto value = inputs[index_pos]->cast()->value(); - if (!value->isa()) { - MS_LOG(EXCEPTION) << "Index type is not supported: " << value->type_name(); - } - size_t index = GetValue(value); - - // Get type and shape of input node - auto tup_type = inputs[0]->Type(); - if (!tup_type->isa()) { - MS_LOG(EXCEPTION) << "Input data of kPrimTupleGetItem cnode must be tuple: " << tup_type->type_name(); + if (node->IsApply(prim::kPrimArgMaxWithValue)) { + auto type = node->Type(); + auto shape = node->Shape(); + if (!type->isa()) { + MS_LOG(EXCEPTION) << "Output data of ArgMaxWithValue cnode must be tuple: " << type->type_name(); } - auto type = tup_type->cast()->elements()[index]; - auto tup_shape = inputs[0]->Shape()->cast(); - if (index >= tup_shape->shape().size()) { - MS_LOG(EXCEPTION) << "Index exceed upper limit: " << tup_shape->shape().size(); + auto elements = type->cast()->elements(); + auto tuple_shape = shape->cast()->shape(); + for (size_t i = 0; i < elements.size(); i++) { + SetShapeToNodeProto(elements[i], tuple_shape[i], node_proto, std::to_string(i)); } - auto shape = tup_shape->shape()[index]; - SetShapeToNodeProto(type, shape, node_proto); } else { auto type = node->Type(); auto shape = node->Shape(); @@ -422,7 +412,7 @@ void IrExportBuilder::BuildCNode(const CNodePtr &node, onnx::GraphProto *const g std::string type_name = GetOpTypeName(op); node_proto->set_op_type(type_name); last_node_ = node_proto; - SetShapeToNodeProto(node, op_inputs, node_proto); + SetShapeToNodeProto(node, node_proto); (void)std::for_each(input_names.begin(), input_names.end(), [&node_proto](const string &name) { node_proto->add_input(name); }); diff --git a/mindspore/ccsrc/operator/ops.cc b/mindspore/ccsrc/operator/ops.cc index f5ec529877..3d93a33356 100755 --- a/mindspore/ccsrc/operator/ops.cc +++ b/mindspore/ccsrc/operator/ops.cc @@ -147,6 +147,7 @@ const PrimitivePtr kPrimAddN = std::make_shared("AddN"); const PrimitivePtr KPrimTransData = std::make_shared("TransData"); const PrimitivePtr kPrimNMSWithMask = std::make_shared("NMSWithMask"); const PrimitivePtr kPrimPad = std::make_shared("Pad"); +const PrimitivePtr kPrimArgMaxWithValue = std::make_shared("ArgMaxWithValue"); // Maths const PrimitivePtr kPrimTensorAdd = std::make_shared("TensorAdd"); diff --git a/mindspore/ccsrc/operator/ops.h b/mindspore/ccsrc/operator/ops.h index 291c4a92d5..c6bea7fe7a 100755 --- a/mindspore/ccsrc/operator/ops.h +++ b/mindspore/ccsrc/operator/ops.h @@ -156,6 +156,7 @@ extern const PrimitivePtr kPrimAddN; extern const PrimitivePtr KPrimTransData; extern const PrimitivePtr kPrimNMSWithMask; extern const PrimitivePtr kPrimPad; +extern const PrimitivePtr kPrimArgMaxWithValue; // Maths extern const PrimitivePtr kPrimTensorAdd; diff --git a/tests/ut/python/utils/test_serialize.py b/tests/ut/python/utils/test_serialize.py index f5046bb1ec..d931796d47 100644 --- a/tests/ut/python/utils/test_serialize.py +++ b/tests/ut/python/utils/test_serialize.py @@ -320,6 +320,13 @@ def test_export(): export(net, input_data, file_name="./me_export.pb", file_format="GEIR") +@non_graph_engine +def test_binary_export(): + net = MYNET() + input_data = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]).astype(np.float32)) + export(net, input_data, file_name="./me_binary_export.pb", file_format="BINARY") + + def teardown_module(): files = ['parameters.ckpt', 'new_ckpt.ckpt', 'empty.ckpt'] for item in files: