|
|
|
@ -29,11 +29,12 @@
|
|
|
|
|
|
|
|
|
|
namespace mindspore {
|
|
|
|
|
enum OpMergeMode {
|
|
|
|
|
OP_MERGE_UNDEFINED = 0, // undefined behavior
|
|
|
|
|
OP_MERGE_IGNORE = 1, // indicate an input op merged into other op in compute node list
|
|
|
|
|
OP_MERGE_CONV = 2, // indicate `MindSpore Conv + BiasAdd` --> `ONNX Conv`
|
|
|
|
|
OP_MERGE_GEMM = 3, // indicate `MindSpore MatMul + BiasAdd` --> `ONNX Gemm`
|
|
|
|
|
OP_MERGE_BATCH_NORM = 4, // indicate `MindSpore BatchNorm(x)[0]` --> `ONNX BatchNormalization`
|
|
|
|
|
OP_MERGE_UNDEFINED = 0, // undefined behavior
|
|
|
|
|
OP_MERGE_IGNORE = 1, // indicate an input op merged into other op in compute node list
|
|
|
|
|
OP_MERGE_CONV = 2, // indicate `MindSpore Conv + BiasAdd` --> `ONNX Conv`
|
|
|
|
|
OP_MERGE_GEMM = 3, // indicate `MindSpore MatMul + BiasAdd` --> `ONNX Gemm`
|
|
|
|
|
OP_MERGE_BATCH_NORM = 4, // indicate `MindSpore BatchNorm(x)[0]` --> `ONNX BatchNormalization`
|
|
|
|
|
OP_MERGE_MAXPOOL_WITH_ARGMAX = 5, // indicate `MindSpore MaxPoolWithArgmax(x)[0]` --> `ONNX MaxPool`
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
struct OpMergedInfo {
|
|
|
|
@ -233,6 +234,13 @@ OPERATOR_ONNX_CONVERT_DEFINE(
|
|
|
|
|
.Attr("padding", "auto_pad", onnx::AttributeProto_AttributeType_STRING, SetPoolingPadMode)
|
|
|
|
|
.Attr("strides", "strides", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>))
|
|
|
|
|
|
|
|
|
|
OPERATOR_ONNX_CONVERT_DEFINE(
|
|
|
|
|
MaxPoolWithArgmax, MaxPool,
|
|
|
|
|
OpNameInfo()
|
|
|
|
|
.Attr("ksize", "kernel_shape", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>)
|
|
|
|
|
.Attr("padding", "auto_pad", onnx::AttributeProto_AttributeType_STRING, SetPoolingPadMode)
|
|
|
|
|
.Attr("strides", "strides", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>))
|
|
|
|
|
|
|
|
|
|
OPERATOR_ONNX_CONVERT_DEFINE(
|
|
|
|
|
AvgPool, AveragePool,
|
|
|
|
|
OpNameInfo()
|
|
|
|
@ -254,6 +262,7 @@ void RegisterOpConverters(const std::function<void(OpNameInfo &&)> &fn) {
|
|
|
|
|
|
|
|
|
|
fn(OP_CONVERT_FUNCTION_NAME(Flatten)());
|
|
|
|
|
fn(OP_CONVERT_FUNCTION_NAME(MaxPool)());
|
|
|
|
|
fn(OP_CONVERT_FUNCTION_NAME(MaxPoolWithArgmax)());
|
|
|
|
|
fn(OP_CONVERT_FUNCTION_NAME(AvgPool)());
|
|
|
|
|
|
|
|
|
|
fn(OP_CONVERT_FUNCTION_NAME(Squeeze)());
|
|
|
|
@ -328,6 +337,8 @@ class OnnxExporter {
|
|
|
|
|
onnx::GraphProto *graph_proto);
|
|
|
|
|
void ExportMergeBatchNorm(const FuncGraphPtr &func_graph, const CNodePtr &node,
|
|
|
|
|
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto);
|
|
|
|
|
void ExportMergeMaxPoolWithArgmax(const FuncGraphPtr &func_graph, const CNodePtr &node,
|
|
|
|
|
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto);
|
|
|
|
|
|
|
|
|
|
void ExportOutput(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr,
|
|
|
|
|
onnx::GraphProto *graph_proto);
|
|
|
|
@ -516,6 +527,12 @@ void OnnxExporter::MatchAndMark(const FuncGraphPtr &func_graph, const std::vecto
|
|
|
|
|
op_merged_infos[cnode].mode = OP_MERGE_BATCH_NORM;
|
|
|
|
|
op_merged_infos[cnode->input(1)].mode = OP_MERGE_IGNORE;
|
|
|
|
|
op_merged_infos[cnode->input(1)].referred_count -= 1;
|
|
|
|
|
} else if (cnode->IsApply(prim::kPrimTupleGetItem) &&
|
|
|
|
|
IsPrimitiveCNode(cnode->input(1), std::make_shared<Primitive>("MaxPoolWithArgmax")) &&
|
|
|
|
|
GetInt32Value(cnode->input(2)) == 0) {
|
|
|
|
|
op_merged_infos[cnode].mode = OP_MERGE_MAXPOOL_WITH_ARGMAX;
|
|
|
|
|
op_merged_infos[cnode->input(1)].mode = OP_MERGE_IGNORE;
|
|
|
|
|
op_merged_infos[cnode->input(1)].referred_count -= 1;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -563,6 +580,9 @@ void OnnxExporter::ExportNodes(const FuncGraphPtr &func_graph, std::map<AnfNodeP
|
|
|
|
|
case OP_MERGE_BATCH_NORM:
|
|
|
|
|
ExportMergeBatchNorm(func_graph, cnode, node_map_ptr, graph_proto);
|
|
|
|
|
break;
|
|
|
|
|
case OP_MERGE_MAXPOOL_WITH_ARGMAX:
|
|
|
|
|
ExportMergeMaxPoolWithArgmax(func_graph, cnode, node_map_ptr, graph_proto);
|
|
|
|
|
break;
|
|
|
|
|
default:
|
|
|
|
|
ExportCNode(func_graph, cnode, node_map_ptr, graph_proto);
|
|
|
|
|
break;
|
|
|
|
@ -811,6 +831,20 @@ void OnnxExporter::ExportMergeBatchNorm(const FuncGraphPtr &func_graph, const CN
|
|
|
|
|
(*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_batch_norm, inputs, graph_proto);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void OnnxExporter::ExportMergeMaxPoolWithArgmax(const FuncGraphPtr &func_graph, const CNodePtr &node,
|
|
|
|
|
std::map<AnfNodePtr, size_t> *node_map_ptr,
|
|
|
|
|
onnx::GraphProto *const graph_proto) {
|
|
|
|
|
auto maxpool_with_argmax_node = dyn_cast<CNode>(node->input(1));
|
|
|
|
|
|
|
|
|
|
PrimitivePtr prim_maxpool_with_argmax =
|
|
|
|
|
dyn_cast<Primitive>((dyn_cast<ValueNode>(maxpool_with_argmax_node->input(0)))->value());
|
|
|
|
|
std::vector<AnfNodePtr> inputs;
|
|
|
|
|
for (size_t i = 1; i < maxpool_with_argmax_node->inputs().size(); i++) {
|
|
|
|
|
inputs.push_back(maxpool_with_argmax_node->input(i));
|
|
|
|
|
}
|
|
|
|
|
(*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_maxpool_with_argmax, inputs, graph_proto);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void OnnxExporter::ExportOutput(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node,
|
|
|
|
|
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) {
|
|
|
|
|
if (node->inputs().size() != 2) {
|
|
|
|
|