|
|
|
@ -131,33 +131,33 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
|
|
|
|
|
|
|
|
|
|
const auto &onnx_conv_weight = onnx_node.input(1);
|
|
|
|
|
if (onnx_node.op_type() == "Conv") {
|
|
|
|
|
auto nodeIter =
|
|
|
|
|
auto node_iter =
|
|
|
|
|
std::find_if(onnx_graph.initializer().begin(), onnx_graph.initializer().end(),
|
|
|
|
|
[onnx_conv_weight](const onnx::TensorProto &proto) { return proto.name() == onnx_conv_weight; });
|
|
|
|
|
if (nodeIter == onnx_graph.initializer().end()) {
|
|
|
|
|
if (node_iter == onnx_graph.initializer().end()) {
|
|
|
|
|
MS_LOG(WARNING) << "not find node: " << onnx_conv_weight;
|
|
|
|
|
} else {
|
|
|
|
|
std::vector<int> weight_shape;
|
|
|
|
|
auto size = (*nodeIter).dims_size();
|
|
|
|
|
auto size = (*node_iter).dims_size();
|
|
|
|
|
weight_shape.reserve(size);
|
|
|
|
|
for (int i = 0; i < size; ++i) {
|
|
|
|
|
weight_shape.emplace_back((*nodeIter).dims(i));
|
|
|
|
|
weight_shape.emplace_back((*node_iter).dims(i));
|
|
|
|
|
}
|
|
|
|
|
attr->channelOut = weight_shape[0];
|
|
|
|
|
attr->channelIn = weight_shape[1] * attr->group;
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
auto nodeIter =
|
|
|
|
|
auto node_iter =
|
|
|
|
|
std::find_if(onnx_graph.node().begin(), onnx_graph.node().end(),
|
|
|
|
|
[onnx_conv_weight](const onnx::NodeProto &proto) { return proto.output(0) == onnx_conv_weight; });
|
|
|
|
|
if (nodeIter == onnx_graph.node().end()) {
|
|
|
|
|
if (node_iter == onnx_graph.node().end()) {
|
|
|
|
|
MS_LOG(ERROR) << "can not find node: " << onnx_conv_weight;
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
std::vector<int> dims;
|
|
|
|
|
auto iter = std::find_if((*nodeIter).attribute().begin(), (*nodeIter).attribute().end(),
|
|
|
|
|
auto iter = std::find_if((*node_iter).attribute().begin(), (*node_iter).attribute().end(),
|
|
|
|
|
[](const onnx::AttributeProto &attr) { return attr.name() == "shape"; });
|
|
|
|
|
if (iter != (*nodeIter).attribute().end()) {
|
|
|
|
|
if (iter != (*node_iter).attribute().end()) {
|
|
|
|
|
if (iter->ints().begin() == nullptr || iter->ints().end() == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "dims insert failed";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|