|
|
|
@ -53,42 +53,36 @@ bool OnnxDeConvParser::ParseGroupDeConvolution(const std::unique_ptr<schema::DeC
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
lite::PrimitiveC *OnnxDeConvParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
|
|
|
|
|
const onnx::NodeProto &onnx_node) {
|
|
|
|
|
MS_LOG(DEBUG) << "onnx DeConvParser";
|
|
|
|
|
auto attr = std::make_unique<schema::DeConv2DT>();
|
|
|
|
|
if (attr == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "new op failed";
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int OnnxDeConvParser::ParseParameters(const onnx::NodeProto &onnx_node,
|
|
|
|
|
const std::unique_ptr<schema::DeConv2DT> &attr) {
|
|
|
|
|
attr->padMode = schema::PadMode_NOTSET;
|
|
|
|
|
attr->group = 1;
|
|
|
|
|
attr->strideW = 1;
|
|
|
|
|
attr->strideH = 1;
|
|
|
|
|
attr->dilateW = 1;
|
|
|
|
|
attr->dilateH = 1;
|
|
|
|
|
|
|
|
|
|
for (const auto &onnx_node_attr : onnx_node.attribute()) {
|
|
|
|
|
if (onnx_node_attr.name() == "group") {
|
|
|
|
|
attr->group = static_cast<int32_t>(onnx_node_attr.i());
|
|
|
|
|
} else if (onnx_node_attr.name() == "dilations") {
|
|
|
|
|
if (onnx_node_attr.ints().size() != 2) {
|
|
|
|
|
MS_LOG(ERROR) << "dilations size " << onnx_node_attr.ints().size() << " is not 2";
|
|
|
|
|
return nullptr;
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
attr->dilateH = static_cast<int32_t>(onnx_node_attr.ints(0));
|
|
|
|
|
attr->dilateW = static_cast<int32_t>(onnx_node_attr.ints(1));
|
|
|
|
|
} else if (onnx_node_attr.name() == "kernels") {
|
|
|
|
|
if (onnx_node_attr.ints().size() != 2) {
|
|
|
|
|
MS_LOG(ERROR) << "kernel_shape size " << onnx_node_attr.ints().size() << " is not 2";
|
|
|
|
|
return nullptr;
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
attr->kernelH = static_cast<int32_t>(onnx_node_attr.ints(0));
|
|
|
|
|
attr->kernelW = static_cast<int32_t>(onnx_node_attr.ints(1));
|
|
|
|
|
} else if (onnx_node_attr.name() == "kernel_shape") {
|
|
|
|
|
if (onnx_node_attr.ints().size() != 2) {
|
|
|
|
|
MS_LOG(ERROR) << "kernel_shape size " << onnx_node_attr.ints().size() << " is not 2";
|
|
|
|
|
return nullptr;
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
attr->kernelH = static_cast<int32_t>(onnx_node_attr.ints(0));
|
|
|
|
|
attr->kernelW = static_cast<int32_t>(onnx_node_attr.ints(1));
|
|
|
|
@ -97,7 +91,7 @@ lite::PrimitiveC *OnnxDeConvParser::ParseLitePrimitive(const onnx::GraphProto &o
|
|
|
|
|
} else if (onnx_node_attr.name() == "pads") {
|
|
|
|
|
if (onnx_node_attr.ints().size() != 4) {
|
|
|
|
|
MS_LOG(ERROR) << "pads size " << onnx_node_attr.ints().size() << " is not 4";
|
|
|
|
|
return nullptr;
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
attr->padUp = static_cast<int32_t>(onnx_node_attr.ints(0));
|
|
|
|
|
attr->padLeft = static_cast<int32_t>(onnx_node_attr.ints(1));
|
|
|
|
@ -106,7 +100,7 @@ lite::PrimitiveC *OnnxDeConvParser::ParseLitePrimitive(const onnx::GraphProto &o
|
|
|
|
|
} else if (onnx_node_attr.name() == "strides") {
|
|
|
|
|
if (onnx_node_attr.ints().size() != 2) {
|
|
|
|
|
MS_LOG(ERROR) << "strides size " << onnx_node_attr.ints().size() << " is not 2";
|
|
|
|
|
return nullptr;
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
attr->strideH = static_cast<int32_t>(onnx_node_attr.ints(0));
|
|
|
|
|
attr->strideW = static_cast<int32_t>(onnx_node_attr.ints(1));
|
|
|
|
@ -115,13 +109,30 @@ lite::PrimitiveC *OnnxDeConvParser::ParseLitePrimitive(const onnx::GraphProto &o
|
|
|
|
|
attr->format = schema::Format::Format_NHWC;
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(ERROR) << "Unsupported format: " << onnx_node_attr.s().c_str();
|
|
|
|
|
return nullptr;
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
} else if (onnx_node_attr.name() == "output_padding") {
|
|
|
|
|
MS_LOG(ERROR) << "output_padding param hasn't been supported";
|
|
|
|
|
return nullptr;
|
|
|
|
|
attr->outputPaddingH = static_cast<int32_t>(onnx_node_attr.ints(0));
|
|
|
|
|
attr->outputPaddingW = static_cast<int32_t>(onnx_node_attr.ints(1));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
lite::PrimitiveC *OnnxDeConvParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
|
|
|
|
|
const onnx::NodeProto &onnx_node) {
|
|
|
|
|
MS_LOG(DEBUG) << "onnx DeConvParser";
|
|
|
|
|
auto attr = std::make_unique<schema::DeConv2DT>();
|
|
|
|
|
if (attr == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "new op failed";
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto status = ParseParameters(onnx_node, attr);
|
|
|
|
|
if (status != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "Parse parameters failed.";
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const auto &onnx_conv_weight = onnx_node.input(1);
|
|
|
|
|
auto node_iter =
|
|
|
|
|