|
|
|
@ -17,17 +17,18 @@
|
|
|
|
|
#include <vector>
|
|
|
|
|
#include <memory>
|
|
|
|
|
#include <algorithm>
|
|
|
|
|
#include "mindspore/lite/tools/converter/parser/onnx/onnx_conv_parser.h"
|
|
|
|
|
#include "tools/converter/parser/onnx/onnx_conv_parser.h"
|
|
|
|
|
|
|
|
|
|
namespace mindspore {
|
|
|
|
|
namespace lite {
|
|
|
|
|
bool OnnxConvParser::ParseGroupConvolution(schema::CNodeT *op, schema::Conv2DT *attr) {
|
|
|
|
|
MS_LOG(DEBUG) << "onnx DepthwiseConvParser";
|
|
|
|
|
if (attr == nullptr || attr->group != attr->channelIn) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
std::unique_ptr<schema::DepthwiseConv2DT> depthwiseConv2DParam(new (std::nothrow) schema::DepthwiseConv2DT());
|
|
|
|
|
if (depthwiseConv2DParam == nullptr) {
|
|
|
|
|
// MS_LOGW("new DepthwiseConv2DT failed");
|
|
|
|
|
MS_LOG(ERROR) << "new DepthwiseConv2DT failed";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
depthwiseConv2DParam->format = attr->format;
|
|
|
|
@ -48,12 +49,12 @@ bool OnnxConvParser::ParseGroupConvolution(schema::CNodeT *op, schema::Conv2DT *
|
|
|
|
|
depthwiseConv2DParam->activationType = attr->activationType;
|
|
|
|
|
op->primitive = std::make_unique<schema::PrimitiveT>();
|
|
|
|
|
op->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D;
|
|
|
|
|
delete (op->primitive->value.value);
|
|
|
|
|
op->primitive->value.value = depthwiseConv2DParam.release();
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
|
|
|
|
|
MS_LOG(DEBUG) << "onnx ConvParser";
|
|
|
|
|
auto attr = new schema::Conv2DT();
|
|
|
|
|
// set opdef each attr params
|
|
|
|
|
for (const auto &onnx_node_attr : onnx_node.attribute()) {
|
|
|
|
@ -61,30 +62,32 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
|
|
|
|
|
attr->group = static_cast<int32_t>(onnx_node_attr.i());
|
|
|
|
|
} else if (onnx_node_attr.name() == "dilations") {
|
|
|
|
|
if (onnx_node_attr.ints().size() != 2) {
|
|
|
|
|
// MS_LOGE("dilations size %d is not 2", onnx_node_attr.ints().size());
|
|
|
|
|
MS_LOG(ERROR) << "dilations size " << onnx_node_attr.ints().size() << " is not 2";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
attr->dilateW = static_cast<int32_t>(onnx_node_attr.ints(0));
|
|
|
|
|
attr->dilateH = static_cast<int32_t>(onnx_node_attr.ints(1));
|
|
|
|
|
// TODO(wangzhe) verify the change
|
|
|
|
|
attr->dilateH = static_cast<int32_t>(onnx_node_attr.ints(0));
|
|
|
|
|
attr->dilateW = static_cast<int32_t>(onnx_node_attr.ints(1));
|
|
|
|
|
} else if (onnx_node_attr.name() == "kernels") {
|
|
|
|
|
if (onnx_node_attr.ints().size() != 2) {
|
|
|
|
|
// MS_LOGE("kernel_shape size %d is not 2", onnx_node_attr.ints().size());
|
|
|
|
|
MS_LOG(ERROR) << "kernel_shape size " << onnx_node_attr.ints().size() << " is not 2";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
attr->kernelH = static_cast<int32_t>(onnx_node_attr.ints(0));
|
|
|
|
|
attr->kernelW = static_cast<int32_t>(onnx_node_attr.ints(1));
|
|
|
|
|
} else if (onnx_node_attr.name() == "kernel_shape") {
|
|
|
|
|
if (onnx_node_attr.ints().size() != 2) {
|
|
|
|
|
// MS_LOGE("kernel_shape size %d is not 2", onnx_node_attr.ints().size());
|
|
|
|
|
MS_LOG(ERROR) << "kernel_shape size " << onnx_node_attr.ints().size() << " is not 2";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
attr->kernelW = static_cast<int32_t>(onnx_node_attr.ints(0));
|
|
|
|
|
attr->kernelH = static_cast<int32_t>(onnx_node_attr.ints(1));
|
|
|
|
|
// TODO(wangzhe) verify the change
|
|
|
|
|
attr->kernelH = static_cast<int32_t>(onnx_node_attr.ints(0));
|
|
|
|
|
attr->kernelW = static_cast<int32_t>(onnx_node_attr.ints(1));
|
|
|
|
|
} else if (onnx_node_attr.name() == "auto_pad") {
|
|
|
|
|
attr->padMode = GetOnnxPadMode(onnx_node_attr);
|
|
|
|
|
} else if (onnx_node_attr.name() == "pads") {
|
|
|
|
|
if (onnx_node_attr.ints().size() != 4) {
|
|
|
|
|
// MS_LOGE("pads size %d is not 4", onnx_node_attr.ints().size());
|
|
|
|
|
MS_LOG(ERROR) << "pads size " << onnx_node_attr.ints().size() << " is not 4";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
attr->padUp = static_cast<int32_t>(onnx_node_attr.ints(0));
|
|
|
|
@ -93,16 +96,17 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
|
|
|
|
|
attr->padRight = static_cast<int32_t>(onnx_node_attr.ints(3));
|
|
|
|
|
} else if (onnx_node_attr.name() == "strides") {
|
|
|
|
|
if (onnx_node_attr.ints().size() != 2) {
|
|
|
|
|
// MS_LOGE("strides size %d is not 2", onnx_node_attr.ints().size());
|
|
|
|
|
MS_LOG(ERROR) << "strides size " << onnx_node_attr.ints().size() << " is not 2";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
attr->strideW = static_cast<int32_t>(onnx_node_attr.ints(0));
|
|
|
|
|
attr->strideH = static_cast<int32_t>(onnx_node_attr.ints(1));
|
|
|
|
|
// TODO(wangzhe) verify the change
|
|
|
|
|
attr->strideH = static_cast<int32_t>(onnx_node_attr.ints(0));
|
|
|
|
|
attr->strideW = static_cast<int32_t>(onnx_node_attr.ints(1));
|
|
|
|
|
} else if (onnx_node_attr.name() == "order") {
|
|
|
|
|
if (onnx_node_attr.s() == "NHWC") {
|
|
|
|
|
attr->format = schema::Format_NHWC;
|
|
|
|
|
} else {
|
|
|
|
|
// MS_LOGE("Unsupported format: %s", onnx_node_attr.s().c_str());
|
|
|
|
|
MS_LOG(ERROR) << "Unsupported format: " << onnx_node_attr.s();
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -114,7 +118,7 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
|
|
|
|
|
std::find_if(onnx_graph.initializer().begin(), onnx_graph.initializer().end(),
|
|
|
|
|
[onnx_conv_weight](const onnx::TensorProto &proto) { return proto.name() == onnx_conv_weight; });
|
|
|
|
|
if (nodeIter == onnx_graph.initializer().end()) {
|
|
|
|
|
// MS_LOGE("not find node: %s", onnx_conv_weight.c_str())
|
|
|
|
|
MS_LOG(ERROR) << "not find node: " << onnx_conv_weight;
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
std::vector<int> weight_shape;
|
|
|
|
@ -129,7 +133,7 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
|
|
|
|
|
std::find_if(onnx_graph.node().begin(), onnx_graph.node().end(),
|
|
|
|
|
[onnx_conv_weight](const onnx::NodeProto &proto) { return proto.output(0) == onnx_conv_weight; });
|
|
|
|
|
if (nodeIter == onnx_graph.node().end()) {
|
|
|
|
|
// MS_LOGE("can not find node: %s", onnx_conv_weight.c_str())
|
|
|
|
|
MS_LOG(ERROR) << "can not find node: " << onnx_conv_weight;
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
std::vector<int> dims;
|
|
|
|
@ -139,6 +143,7 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
|
|
|
|
|
dims.insert(dims.begin(), iter->ints().begin(), iter->ints().end());
|
|
|
|
|
}
|
|
|
|
|
attr->channelOut = dims[0];
|
|
|
|
|
// TODO(wangzhe) verify this code
|
|
|
|
|
attr->channelIn = dims[3] * attr->group;
|
|
|
|
|
}
|
|
|
|
|
attr->format = schema::Format_NCHW;
|
|
|
|
@ -156,7 +161,7 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
|
|
|
|
|
if (attr->group != 1) {
|
|
|
|
|
if (!ParseGroupConvolution(op, attr)) {
|
|
|
|
|
delete attr;
|
|
|
|
|
// MS_LOGE("Convert Convolution to Depthwise failed");
|
|
|
|
|
MS_LOG(ERROR) << "Convert Convolution to Depthwise failed";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -169,4 +174,3 @@ OnnxNodeRegistrar g_onnxConvReluParser("ConvRelu", new OnnxConvParser());
|
|
|
|
|
OnnxNodeRegistrar g_onnxInt8ConvReluParser("Int8ConvRelu", new OnnxConvParser());
|
|
|
|
|
} // namespace lite
|
|
|
|
|
} // namespace mindspore
|
|
|
|
|
|
|
|
|
|