!8319 [MSLITE] Fix bug of onnx model parser.

From: @wang_shaocong
Reviewed-by: 
Signed-off-by:
pull/8319/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 9743f07b27

@ -29,6 +29,7 @@
#include "tools/common/graph_util.h"
#include "include/errorcode.h"
#include "schema/inner/model_generated.h"
#include "src/ops/primitive_c.h"
namespace mindspore {
namespace lite {
@ -263,7 +264,15 @@ bool FusionPass::MatchTree(schema::MetaGraphT *graph, size_t nodeIdx, const std:
return true;
}
for (auto preNodeIdx : preNodeIdxes) {
MS_ASSERT(subGraph->nodes.size() > preNodeIdx);
MS_ASSERT(graph->nodes.size() > preNodeIdx);
// Case of multiple outputs is not supported.
if (GetInputNodeIdx(*graph, preNodeIdx).size() > kDoubleNum ||
GetOutputNodeIdx(*graph, preNodeIdx).size() > kSingleNum) {
sinkIdes.erase((sinkIdes.end() - 1));
pathSinkIdes.erase((pathSinkIdes.end() - 1));
target->UnSetPath();
return false;
}
// match left
if (MatchTree(graph, preNodeIdx, target->left, sinkIdes, pathSinkIdes)) {
// match right

@ -75,6 +75,7 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
attr->dilateW = 1;
attr->group = 1;
attr->padMode = schema::PadMode_NOTSET;
attr->format = schema::Format::Format_NCHW;
// set opdef each attr params
for (const auto &onnx_node_attr : onnx_node.attribute()) {
if (onnx_node_attr.name() == "group") {
@ -161,7 +162,6 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
attr->channelOut = dims[0];
attr->channelIn = dims[3] * attr->group;
}
attr->format = schema::Format::Format_NCHW;
attr->hasBias = onnx_node.input().size() == 3;
if (onnx_node.op_type() == "ConvRelu" || onnx_node.op_type() == "Int8ConvRelu") {
attr->activationType = schema::ActivationType_RELU;

@ -244,6 +244,16 @@ STATUS OnnxModelParser::ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node,
MS_LOG(ERROR) << "memcpy_s failed";
return RET_ERROR;
}
// set quantParams to Int8GivenTensor.
std::unique_ptr<schema::QuantParamT> quant_param = std::make_unique<schema::QuantParamT>();
for (const auto &onnx_node_attr : onnx_node.attribute()) {
if (onnx_node_attr.name() == "Y_scale") {
quant_param->scale = onnx_node_attr.f();
} else if (onnx_node_attr.name() == "Y_zero_point") {
quant_param->zeroPoint = static_cast<int32_t>(onnx_node_attr.i());
}
}
tensor->quantParams.emplace_back(std::move(quant_param));
} else {
MS_LOG(ERROR) << "unsupported data type " << tensor->dataType;
return RET_ERROR;
@ -256,9 +266,8 @@ STATUS OnnxModelParser::ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node,
}
STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *dst_op, schema::TensorT *dst_tensor,
TensorCache *tensor_cache, const QuantType &quantType,
schema::MetaGraphT *dst_graph) {
schema::CNodeT *dst_op, TensorCache *tensor_cache,
const QuantType &quantType, schema::MetaGraphT *dst_graph) {
// change op_type() to name(), that is unique
static bool interrupt = false;
dst_op->name = onnx_node.op_type() + "_" + onnx_node.output(0);
@ -267,7 +276,6 @@ STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph,
MS_LOG(DEBUG) << "onnx op name " << onnx_node.op_type() << ", dst op name: " << dst_op->name << ", input size "
<< onnx_node.input_size();
// get the real op type
SetOpQuantParams(onnx_graph, onnx_node, dst_op, dst_tensor, tensor_cache);
if (onnx_node.op_type() == "Loop") {
NoSupportOp::GetInstance()->InsertOp(onnx_node.op_type());
interrupt = true;
@ -305,6 +313,13 @@ STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph,
MS_LOG(ERROR) << "SetOpInputIndex failed";
return RET_ERROR;
}
if (dst_op->primitive->value.type == schema::PrimitiveType_Conv2D) {
auto &weight_tensor = tensor_cache->GetCachedTensor().at(dst_op->inputIndex.at(kWeightIndex));
weight_tensor->format = dst_op->primitive->value.AsConv2D()->format;
} else if (dst_op->primitive->value.type == schema::PrimitiveType_DeConv2D) {
auto &weight_tensor = tensor_cache->GetCachedTensor().at(dst_op->inputIndex.at(kWeightIndex));
weight_tensor->format = dst_op->primitive->value.AsDeConv2D()->format;
}
// set op output index
std::vector<string> node_outputs;
(void)node_outputs.insert(node_outputs.begin(), onnx_node.output().begin(), onnx_node.output().end());
@ -314,6 +329,13 @@ STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph,
MS_LOG(ERROR) << "SetOpOutputIndex failed";
return RET_ERROR;
}
auto &output_tensor = tensor_cache->GetCachedTensor().at(dst_op->outputIndex.front());
if (output_tensor == nullptr) {
interrupt = true;
MS_LOG(ERROR) << "Output tensor of node " << onnx_node.op_type() << "is nullptr.";
return RET_ERROR;
}
SetOpQuantParams(onnx_graph, onnx_node, dst_op, output_tensor, tensor_cache);
return RET_OK;
}
@ -572,9 +594,7 @@ int OnnxModelParser::ParseGraph(schema::MetaGraphT *dst_graph, schema::SubGraphT
}
std::unique_ptr<schema::CNodeT> dst_op = std::make_unique<schema::CNodeT>();
std::unique_ptr<schema::TensorT> dst_tensor = std::make_unique<schema::TensorT>();
status_node =
ParseOnnxNodeToDstOp(onnx_graph, onnx_node, dst_op.get(), dst_tensor.get(), &tensor_cache, quantType, dst_graph);
status_node = ParseOnnxNodeToDstOp(onnx_graph, onnx_node, dst_op.get(), &tensor_cache, quantType, dst_graph);
if (status_node != RET_OK) {
status = (status == RET_OK ? status_node : status);
continue;

@ -66,8 +66,8 @@ class OnnxModelParser : public ModelParser {
TensorCache *tensor_cache, int *index);
STATUS ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *dst_op, schema::TensorT *dst_tensor, TensorCache *tensor_cache,
const QuantType &quantType, schema::MetaGraphT *dst_graph);
schema::CNodeT *dst_op, TensorCache *tensor_cache, const QuantType &quantType,
schema::MetaGraphT *dst_graph);
void ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::SubGraphT *sub_graph, schema::MetaGraphT *graph, TensorCache *tensor_cache,

@ -24,7 +24,7 @@
#include "include/errorcode.h"
#include "src/common/log_adapter.h"
#include "schema/inner/model_generated.h"
#include "ir/dtype/type_id.h"
namespace mindspore {
namespace lite {
class OnnxNodeParser {

@ -14,14 +14,14 @@
* limitations under the License.
*/
#include "tools/converter/parser/onnx/onnx_unuseful_node_parser.h"
#include "tools/converter/parser/onnx/onnx_quantize_parser.h"
#include <memory>
namespace mindspore {
namespace lite {
STATUS OnnxUnusefulNodeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx UnusefulNodeParser";
STATUS OnnxQuantizeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx QuantizeDequantizeParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
@ -32,30 +32,27 @@ STATUS OnnxUnusefulNodeParser::Parse(const onnx::GraphProto &onnx_graph, const o
return RET_NULL_PTR;
}
std::unique_ptr<schema::QuantDTypeCastT> attr = std::make_unique<schema::QuantDTypeCastT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed.";
return RET_NULL_PTR;
}
if (onnx_node.op_type() == "Int8Quantize") {
std::unique_ptr<schema::OnnxInt8QuantizeT> attr = std::make_unique<schema::OnnxInt8QuantizeT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
op->primitive->value.type = schema::PrimitiveType_OnnxInt8Quantize;
op->primitive->value.value = attr.release();
attr->srcT = kNumberTypeFloat32;
attr->dstT = kNumberTypeInt8;
} else if (onnx_node.op_type() == "Int8Dequantize") {
std::unique_ptr<schema::OnnxInt8DequantizeT> attr = std::make_unique<schema::OnnxInt8DequantizeT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
op->primitive->value.type = schema::PrimitiveType_OnnxInt8Dequantize;
op->primitive->value.value = attr.release();
attr->srcT = kNumberTypeInt8;
attr->dstT = kNumberTypeFloat32;
} else {
MS_LOG(ERROR) << "Unsupported nodeType: " << onnx_node.op_type().c_str();
return RET_ERROR;
}
op->primitive->value.type = schema::PrimitiveType_QuantDTypeCast;
op->primitive->value.value = attr.release();
return RET_OK;
}
OnnxNodeRegistrar g_onnxInt8QuantizeParser("Int8Quantize", new OnnxUnusefulNodeParser());
OnnxNodeRegistrar g_onnxInt8DequantizeParser("Int8Dequantize", new OnnxUnusefulNodeParser());
OnnxNodeRegistrar g_onnxInt8QuantizeParser("Int8Quantize", new OnnxQuantizeParser());
OnnxNodeRegistrar g_onnxInt8DequantizeParser("Int8Dequantize", new OnnxQuantizeParser());
} // namespace lite
} // namespace mindspore

@ -14,21 +14,21 @@
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX__UNUSEFUL_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX__UNUSEFUL_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_QUANTIZE_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_QUANTIZE_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
namespace mindspore {
namespace lite {
class OnnxUnusefulNodeParser : public OnnxNodeParser {
class OnnxQuantizeParser : public OnnxNodeParser {
public:
OnnxUnusefulNodeParser() : OnnxNodeParser("UnusefulNode") {}
~OnnxUnusefulNodeParser() override = default;
OnnxQuantizeParser() : OnnxNodeParser("Quantize") {}
~OnnxQuantizeParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX__UNUSEFUL_PARSER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_QUANTIZE_PARSER_H

@ -79,7 +79,11 @@ lite::STATUS WeightFormatHardCodePass::HardCodeONNX(const AnfNodePtr &conv_node,
// dedepth (C x K/group x kH x kW) group = channelIn ==> (C, multiplier, H, W)
if (op_type == schema::PrimitiveType_Conv2D || op_type == schema::PrimitiveType_DepthwiseConv2D ||
op_type == schema::PrimitiveType_DeConv2D || op_type == schema::PrimitiveType_DeDepthwiseConv2D) {
param_value->set_format(schema::Format::Format_KCHW);
if (param_value->format() == schema::Format::Format_NHWC) {
param_value->set_format(schema::Format::Format_KHWC);
} else {
param_value->set_format(schema::Format::Format_KCHW);
}
} else {
MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type)
<< ", node: " << conv_node->fullname_with_scope();

Loading…
Cancel
Save