|
|
|
@ -14,21 +14,21 @@
|
|
|
|
|
* limitations under the License.
|
|
|
|
|
*/
|
|
|
|
|
|
|
|
|
|
#include "tools/converter/parser/onnx/onnx_deconv_parser.h"
|
|
|
|
|
#include <vector>
|
|
|
|
|
#include <memory>
|
|
|
|
|
#include <algorithm>
|
|
|
|
|
#include "tools/converter/parser/onnx/onnx_deconv_parser.h"
|
|
|
|
|
|
|
|
|
|
namespace mindspore {
|
|
|
|
|
namespace lite {
|
|
|
|
|
bool OnnxDeConvParser::ParseGroupDeConvolution(const std::unique_ptr<schema::DeConv2DT> &attr, schema::CNodeT *op) {
|
|
|
|
|
MS_LOG(DEBUG) << "onnx DeConvParser";
|
|
|
|
|
bool OnnxDeConvParser::ParseGroupDeConvolution(const std::unique_ptr<schema::DeConv2DT> &attr,
|
|
|
|
|
schema::CNodeT *op) {
|
|
|
|
|
if (attr == nullptr || attr->group != attr->channelOut) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
std::unique_ptr<schema::DeDepthwiseConv2DT> deDepthwiseConv2DParam = std::make_unique<schema::DeDepthwiseConv2DT>();
|
|
|
|
|
if (deDepthwiseConv2DParam == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "new DeDepthwiseConv2DT failed";
|
|
|
|
|
MS_LOG(WARNING) << "new op failed";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
deDepthwiseConv2DParam->format = attr->format;
|
|
|
|
@ -47,38 +47,53 @@ bool OnnxDeConvParser::ParseGroupDeConvolution(const std::unique_ptr<schema::DeC
|
|
|
|
|
deDepthwiseConv2DParam->dilateH = attr->dilateH;
|
|
|
|
|
deDepthwiseConv2DParam->hasBias = attr->hasBias;
|
|
|
|
|
deDepthwiseConv2DParam->activationType = attr->activationType;
|
|
|
|
|
if (op != nullptr) {
|
|
|
|
|
op->primitive = std::make_unique<schema::PrimitiveT>();
|
|
|
|
|
op->primitive->value.type = schema::PrimitiveType_DeDepthwiseConv2D;
|
|
|
|
|
op->primitive->value.value = deDepthwiseConv2DParam.release();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
op->primitive->value.type = schema::PrimitiveType_DeDepthwiseConv2D;
|
|
|
|
|
op->primitive->value.value = deDepthwiseConv2DParam.release();
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
|
|
|
|
|
STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph,
|
|
|
|
|
const onnx::NodeProto &onnx_node,
|
|
|
|
|
schema::CNodeT *op) {
|
|
|
|
|
MS_LOG(DEBUG) << "onnx DeConvParser";
|
|
|
|
|
if (op == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "op is null";
|
|
|
|
|
return RET_NULL_PTR;
|
|
|
|
|
}
|
|
|
|
|
op->primitive = std::make_unique<schema::PrimitiveT>();
|
|
|
|
|
if (op->primitive == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "op->primitive is null";
|
|
|
|
|
return RET_NULL_PTR;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<schema::DeConv2DT> attr = std::make_unique<schema::DeConv2DT>();
|
|
|
|
|
if (attr == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "new op failed";
|
|
|
|
|
return RET_NULL_PTR;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// set opdef each attr params
|
|
|
|
|
for (const auto &onnx_node_attr : onnx_node.attribute()) {
|
|
|
|
|
if (onnx_node_attr.name() == "group") {
|
|
|
|
|
attr->group = static_cast<int32_t>(onnx_node_attr.i());
|
|
|
|
|
} else if (onnx_node_attr.name() == "dilations") {
|
|
|
|
|
if (onnx_node_attr.ints().size() != 2) {
|
|
|
|
|
// MS_LOGE("dilations size %d is not 2", onnx_node_attr.ints().size());
|
|
|
|
|
MS_LOG(ERROR) << "dilations size " << onnx_node_attr.ints().size() << " is not 2";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
attr->dilateW = static_cast<int32_t>(onnx_node_attr.ints(0));
|
|
|
|
|
attr->dilateH = static_cast<int32_t>(onnx_node_attr.ints(1));
|
|
|
|
|
} else if (onnx_node_attr.name() == "kernels") {
|
|
|
|
|
if (onnx_node_attr.ints().size() != 2) {
|
|
|
|
|
// MS_LOGE("kernel_shape size %d is not 2", onnx_node_attr.ints().size());
|
|
|
|
|
MS_LOG(ERROR) << "kernel_shape size " << onnx_node_attr.ints().size() << " is not 2";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
attr->kernelH = static_cast<int32_t>(onnx_node_attr.ints(0));
|
|
|
|
|
attr->kernelW = static_cast<int32_t>(onnx_node_attr.ints(1));
|
|
|
|
|
} else if (onnx_node_attr.name() == "kernel_shape") {
|
|
|
|
|
if (onnx_node_attr.ints().size() != 2) {
|
|
|
|
|
// MS_LOGE("kernel_shape size %d is not 2", onnx_node_attr.ints().size());
|
|
|
|
|
MS_LOG(ERROR) << "kernel_shape size " << onnx_node_attr.ints().size() << " is not 2";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
attr->kernelW = static_cast<int32_t>(onnx_node_attr.ints(0));
|
|
|
|
@ -87,7 +102,7 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
|
|
|
|
|
attr->padMode = GetOnnxPadMode(onnx_node_attr);
|
|
|
|
|
} else if (onnx_node_attr.name() == "pads") {
|
|
|
|
|
if (onnx_node_attr.ints().size() != 4) {
|
|
|
|
|
// MS_LOGE("pads size %d is not 4", onnx_node_attr.ints().size());
|
|
|
|
|
MS_LOG(ERROR) << "pads size " << onnx_node_attr.ints().size() << " is not 4";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
attr->padUp = static_cast<int32_t>(onnx_node_attr.ints(0));
|
|
|
|
@ -96,7 +111,7 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
|
|
|
|
|
attr->padRight = static_cast<int32_t>(onnx_node_attr.ints(3));
|
|
|
|
|
} else if (onnx_node_attr.name() == "strides") {
|
|
|
|
|
if (onnx_node_attr.ints().size() != 2) {
|
|
|
|
|
// MS_LOGE("strides size %d is not 2", onnx_node_attr.ints().size());
|
|
|
|
|
MS_LOG(ERROR) << "strides size " << onnx_node_attr.ints().size() << " is not 2";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
attr->strideW = static_cast<int32_t>(onnx_node_attr.ints(0));
|
|
|
|
@ -105,7 +120,7 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
|
|
|
|
|
if (onnx_node_attr.s() == "NHWC") {
|
|
|
|
|
attr->format = schema::Format_NHWC;
|
|
|
|
|
} else {
|
|
|
|
|
// MS_LOGE("Unsupported format: %s", onnx_node_attr.s().c_str());
|
|
|
|
|
MS_LOG(ERROR) << "Unsupported format: " << onnx_node_attr.s().c_str();
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -116,7 +131,7 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
|
|
|
|
|
std::find_if(onnx_graph.initializer().begin(), onnx_graph.initializer().end(),
|
|
|
|
|
[onnx_conv_weight](const onnx::TensorProto &proto) { return proto.name() == onnx_conv_weight; });
|
|
|
|
|
if (nodeIter == onnx_graph.initializer().end()) {
|
|
|
|
|
// MS_LOGE("not find node: %s", onnx_conv_weight.c_str())
|
|
|
|
|
MS_LOG(ERROR) << "not find node: " << onnx_conv_weight.c_str();
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
std::vector<int> weight_shape;
|
|
|
|
@ -137,7 +152,6 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
op->primitive = std::make_unique<schema::PrimitiveT>();
|
|
|
|
|
op->primitive->value.type = schema::PrimitiveType_DeConv2D;
|
|
|
|
|
op->primitive->value.value = attr.release();
|
|
|
|
|
}
|
|
|
|
|