From cd899fba0a53fbdaebe058a907ccc13d4f312767 Mon Sep 17 00:00:00 2001 From: meixiaowei Date: Tue, 28 Apr 2020 14:25:59 +0800 Subject: [PATCH] ONNX adapter for the MaxPoolWithArgmax --- mindspore/ccsrc/onnx/onnx_exporter.cc | 44 ++++++++++++++++++++++--- tests/ut/python/utils/test_serialize.py | 25 ++++++++++++++ 2 files changed, 64 insertions(+), 5 deletions(-) diff --git a/mindspore/ccsrc/onnx/onnx_exporter.cc b/mindspore/ccsrc/onnx/onnx_exporter.cc index 168e625a89..1c5a7b93c3 100644 --- a/mindspore/ccsrc/onnx/onnx_exporter.cc +++ b/mindspore/ccsrc/onnx/onnx_exporter.cc @@ -29,11 +29,12 @@ namespace mindspore { enum OpMergeMode { - OP_MERGE_UNDEFINED = 0, // undefined behavior - OP_MERGE_IGNORE = 1, // indicate an input op merged into other op in compute node list - OP_MERGE_CONV = 2, // indicate `MindSpore Conv + BiasAdd` --> `ONNX Conv` - OP_MERGE_GEMM = 3, // indicate `MindSpore MatMul + BiasAdd` --> `ONNX Gemm` - OP_MERGE_BATCH_NORM = 4, // indicate `MindSpore BatchNorm(x)[0]` --> `ONNX BatchNormalization` + OP_MERGE_UNDEFINED = 0, // undefined behavior + OP_MERGE_IGNORE = 1, // indicate an input op merged into other op in compute node list + OP_MERGE_CONV = 2, // indicate `MindSpore Conv + BiasAdd` --> `ONNX Conv` + OP_MERGE_GEMM = 3, // indicate `MindSpore MatMul + BiasAdd` --> `ONNX Gemm` + OP_MERGE_BATCH_NORM = 4, // indicate `MindSpore BatchNorm(x)[0]` --> `ONNX BatchNormalization` + OP_MERGE_MAXPOOL_WITH_ARGMAX = 5, // indicate `MindSpore MaxPoolWithArgmax(x)[0]` --> `ONNX MaxPool` }; struct OpMergedInfo { @@ -233,6 +234,13 @@ OPERATOR_ONNX_CONVERT_DEFINE( .Attr("padding", "auto_pad", onnx::AttributeProto_AttributeType_STRING, SetPoolingPadMode) .Attr("strides", "strides", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>)) +OPERATOR_ONNX_CONVERT_DEFINE( + MaxPoolWithArgmax, MaxPool, + OpNameInfo() + .Attr("ksize", "kernel_shape", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>) + .Attr("padding", "auto_pad", onnx::AttributeProto_AttributeType_STRING, SetPoolingPadMode) + .Attr("strides", "strides", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>)) + OPERATOR_ONNX_CONVERT_DEFINE( AvgPool, AveragePool, OpNameInfo() @@ -254,6 +262,7 @@ void RegisterOpConverters(const std::function &fn) { fn(OP_CONVERT_FUNCTION_NAME(Flatten)()); fn(OP_CONVERT_FUNCTION_NAME(MaxPool)()); + fn(OP_CONVERT_FUNCTION_NAME(MaxPoolWithArgmax)()); fn(OP_CONVERT_FUNCTION_NAME(AvgPool)()); fn(OP_CONVERT_FUNCTION_NAME(Squeeze)()); @@ -328,6 +337,8 @@ class OnnxExporter { onnx::GraphProto *graph_proto); void ExportMergeBatchNorm(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map *node_map_ptr, onnx::GraphProto *graph_proto); + void ExportMergeMaxPoolWithArgmax(const FuncGraphPtr &func_graph, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *graph_proto); void ExportOutput(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map *node_map_ptr, onnx::GraphProto *graph_proto); @@ -516,6 +527,12 @@ void OnnxExporter::MatchAndMark(const FuncGraphPtr &func_graph, const std::vecto op_merged_infos[cnode].mode = OP_MERGE_BATCH_NORM; op_merged_infos[cnode->input(1)].mode = OP_MERGE_IGNORE; op_merged_infos[cnode->input(1)].referred_count -= 1; + } else if (cnode->IsApply(prim::kPrimTupleGetItem) && + IsPrimitiveCNode(cnode->input(1), std::make_shared("MaxPoolWithArgmax")) && + GetInt32Value(cnode->input(2)) == 0) { + op_merged_infos[cnode].mode = OP_MERGE_MAXPOOL_WITH_ARGMAX; + op_merged_infos[cnode->input(1)].mode = OP_MERGE_IGNORE; + op_merged_infos[cnode->input(1)].referred_count -= 1; } } } @@ -563,6 +580,9 @@ void OnnxExporter::ExportNodes(const FuncGraphPtr &func_graph, std::map *node_map_ptr, + onnx::GraphProto *const graph_proto) { + auto maxpool_with_argmax_node = dyn_cast(node->input(1)); + + PrimitivePtr prim_maxpool_with_argmax = + dyn_cast((dyn_cast(maxpool_with_argmax_node->input(0)))->value()); + std::vector inputs; + for (size_t i = 1; i < maxpool_with_argmax_node->inputs().size(); i++) { + inputs.push_back(maxpool_with_argmax_node->input(i)); + } + (*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_maxpool_with_argmax, inputs, graph_proto); +} + void OnnxExporter::ExportOutput(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node, std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { if (node->inputs().size() != 2) { diff --git a/tests/ut/python/utils/test_serialize.py b/tests/ut/python/utils/test_serialize.py index cc6f346b77..59a4b93833 100644 --- a/tests/ut/python/utils/test_serialize.py +++ b/tests/ut/python/utils/test_serialize.py @@ -362,6 +362,31 @@ def test_lenet5_onnx_export(): net = LeNet5() export(net, input, file_name='lenet5.onnx', file_format='ONNX') +class DefinedNet(nn.Cell): + """simple Net definition with maxpoolwithargmax.""" + def __init__(self, num_classes=10): + super(DefinedNet, self).__init__() + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=0, weight_init="zeros") + self.bn1 = nn.BatchNorm2d(64) + self.relu = nn.ReLU() + self.maxpool = P.MaxPoolWithArgmax(padding="same", ksize=2, strides=2) + self.flatten = nn.Flatten() + self.fc = nn.Dense(int(56*56*64), num_classes) + + def construct(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x, argmax = self.maxpool(x) + x = self.flatten(x) + x = self.fc(x) + return x + +def test_net_onnx_maxpoolwithargmax_export(): + input = Tensor(np.ones([1, 3, 224, 224]).astype(np.float32) * 0.01) + net = DefinedNet() + export(net, input, file_name='definedNet.onnx', file_format='ONNX') + @run_on_onnxruntime def test_lenet5_onnx_load_run():