!4509 repair onnx converter

Merge pull request !4509 from wangzhe/master
pull/4509/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit e16e27a71e

@ -102,6 +102,10 @@ int Reshape::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso
auto data = reinterpret_cast<int32_t *>(shape_tensor->Data()); auto data = reinterpret_cast<int32_t *>(shape_tensor->Data());
CalShape<int32_t>(data, inputs_, &out_shape, shape_size); CalShape<int32_t>(data, inputs_, &out_shape, shape_size);
} break; } break;
case kNumberTypeInt64: {
auto data = reinterpret_cast<int64_t *>(shape_tensor->Data());
CalShape<int64_t>(data, inputs_, &out_shape, shape_size);
} break;
case kNumberTypeFloat: { case kNumberTypeFloat: {
auto data = reinterpret_cast<float *>(shape_tensor->Data()); auto data = reinterpret_cast<float *>(shape_tensor->Data());
CalShape<float>(data, inputs_, &out_shape, shape_size); CalShape<float>(data, inputs_, &out_shape, shape_size);

@ -223,7 +223,6 @@ if(BUILD_CONVERTER)
${LITE_DIR}/tools/converter/graphdef_transform.cc ${LITE_DIR}/tools/converter/graphdef_transform.cc
${LITE_DIR}/tools/converter/converter_flags.cc ${LITE_DIR}/tools/converter/converter_flags.cc
${LITE_DIR}/tools/converter/converter.cc ${LITE_DIR}/tools/converter/converter.cc
${LITE_DIR}/tools/converter/parser/onnx/onnx.pb.cc
${LITE_DIR}/test/st/converter_test.cc ${LITE_DIR}/test/st/converter_test.cc
${LITE_DIR}/test/ut/tools/optimizer/fusion/conv_activation_fusion_test.cc ${LITE_DIR}/test/ut/tools/optimizer/fusion/conv_activation_fusion_test.cc
${LITE_DIR}/test/ut/tools/optimizer/fusion/conv_biasadd_fusion_test.cc ${LITE_DIR}/test/ut/tools/optimizer/fusion/conv_biasadd_fusion_test.cc
@ -351,6 +350,7 @@ if (BUILD_CONVERTER)
anf_importer_mid anf_importer_mid
tflite_parser_mid tflite_parser_mid
caffe_parser_mid caffe_parser_mid
onnx_parser_mid
node_mid node_mid
graph_pass_mid graph_pass_mid
fusion_mid fusion_mid

@ -71,7 +71,6 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/../common/flag_parser.cc ${CMAKE_CURRENT_SOURCE_DIR}/../common/flag_parser.cc
${CMAKE_CURRENT_SOURCE_DIR}/../common/storage.cc ${CMAKE_CURRENT_SOURCE_DIR}/../common/storage.cc
# ${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/anf_exporter/anf_exporter.cc # ${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/anf_exporter/anf_exporter.cc
${CMAKE_CURRENT_SOURCE_DIR}/parser/onnx/onnx.pb.cc
../optimizer/common/node_pass_extends.cc ../optimizer/common/node_pass_extends.cc
../optimizer/common/pass_manager_extends.cc ../optimizer/common/pass_manager_extends.cc
@ -86,6 +85,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
add_subdirectory(parser/caffe) add_subdirectory(parser/caffe)
add_subdirectory(parser/tflite) add_subdirectory(parser/tflite)
add_subdirectory(parser/onnx)
add_subdirectory(legacy_optimizer) add_subdirectory(legacy_optimizer)
add_subdirectory(quantizer) add_subdirectory(quantizer)
@ -98,6 +98,7 @@ add_executable(converter_lite
target_link_libraries(converter_lite PRIVATE target_link_libraries(converter_lite PRIVATE
tflite_parser_mid tflite_parser_mid
caffe_parser_mid caffe_parser_mid
onnx_parser_mid
anf_importer_mid anf_importer_mid
node_mid node_mid
graph_pass_mid graph_pass_mid

@ -27,6 +27,7 @@
#include "tools/common/storage.h" #include "tools/common/storage.h"
#include "parser/caffe/caffe_converter.h" #include "parser/caffe/caffe_converter.h"
#include "parser/tflite/tflite_converter.h" #include "parser/tflite/tflite_converter.h"
#include "parser/onnx/onnx_converter.h"
#include "src/common/anf_exporter/anf_exporter.h" #include "src/common/anf_exporter/anf_exporter.h"
#include "src/common/anf_importer/import_from_protobuf.h" #include "src/common/anf_importer/import_from_protobuf.h"
#include "tools/converter/parser/onnx/onnx.pb.h" #include "tools/converter/parser/onnx/onnx.pb.h"
@ -185,6 +186,10 @@ int RunConverter(int argc, const char **argv) {
TfliteConverter tfLiteConverter; TfliteConverter tfLiteConverter;
fb_graph = tfLiteConverter.Convert(flags); fb_graph = tfLiteConverter.Convert(flags);
} break; } break;
case FmkType::FmkType_ONNX: {
OnnxConverter onnxConverter;
fb_graph = onnxConverter.Convert(flags);
} break;
default: { default: {
MS_LOG(ERROR) << "Unsupported fmkType: " << flags->fmk; MS_LOG(ERROR) << "Unsupported fmkType: " << flags->fmk;
return 1; return 1;

@ -14,13 +14,11 @@
* limitations under the License. * limitations under the License.
*/ */
#include "tools/converter/converter_flags.h" #include "tools/converter/converter_flags.h"
#include <regex> #include <regex>
#include <string> #include <string>
#include "ir/dtype/type_id.h" #include "ir/dtype/type_id.h"
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
namespace converter { namespace converter {
@ -89,8 +87,10 @@ int Flags::Init(int argc, const char **argv) {
this->fmk = FmkType_MS; this->fmk = FmkType_MS;
} else if (this->fmkIn == "TFLITE") { } else if (this->fmkIn == "TFLITE") {
this->fmk = FmkType_TFLITE; this->fmk = FmkType_TFLITE;
} else if (this->fmkIn == "ONNX") {
this->fmk = FmkType_ONNX;
} else { } else {
std::cerr << "INPUT ILLEGAL: fmk must be TFLITE|CAFFE|MS"; std::cerr << "INPUT ILLEGAL: fmk must be TFLITE|CAFFE|MS|ONNX";
return 1; return 1;
} }

@ -138,6 +138,12 @@ STATUS FormatTransPass::DoNodeInoutFormatTrans(schema::MetaGraphT *graph) {
} }
beforeNodeType = kNCHW2NHWC; beforeNodeType = kNCHW2NHWC;
afterNodeType = kNHWC2NCHW; afterNodeType = kNHWC2NCHW;
} else if (fmkType == converter::FmkType_ONNX) {
if (!IsContain(GetNhwcOpList(), GetCNodeTType(**iter))) {
continue;
}
beforeNodeType = kNCHW2NHWC;
afterNodeType = kNHWC2NCHW;
} else { } else {
MS_LOG(ERROR) << "Unsupported fmk: " << fmkType; MS_LOG(ERROR) << "Unsupported fmk: " << fmkType;
return RET_ERROR; return RET_ERROR;
@ -197,4 +203,3 @@ void FormatTransPass::SetFmk(converter::FmkType fmkType) { this->fmkType = fmkTy
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

@ -189,7 +189,7 @@ int WeightFormatPass::ShapeFormatTrans(GraphNode *graphNode) {
if (opType == schema::PrimitiveType_Conv2D) { if (opType == schema::PrimitiveType_Conv2D) {
weightTensor->format = schema::Format_KCHW; weightTensor->format = schema::Format_KCHW;
} else if (opType == schema::PrimitiveType_DepthwiseConv2D) { } else if (opType == schema::PrimitiveType_DepthwiseConv2D) {
weightTensor->format = schema::Format_CKHW; weightTensor->format = schema::Format_KCHW;
} else if (opType == schema::PrimitiveType_DeConv2D) { } else if (opType == schema::PrimitiveType_DeConv2D) {
weightTensor->format = schema::Format_CKHW; weightTensor->format = schema::Format_CKHW;
} else { } else {

@ -15,14 +15,15 @@
*/ */
#include <memory> #include <memory>
#include "mindspore/lite/tools/converter/parser/onnx/onnx_argmax_parser.h" #include "tools/converter/parser/onnx/onnx_argmax_parser.h"
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
STATUS OnnxArgMaxParser::Parse(const onnx::GraphProto &onnx_graph, STATUS OnnxArgMaxParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) { schema::CNodeT *op) {
unique_ptr<schema::ArgMaxT> attr(new schema::ArgMaxT()); MS_LOG(DEBUG) << "onnx ArgMaxParser";
std::unique_ptr<schema::ArgMaxT> attr(new schema::ArgMaxT());
for (const auto &onnx_node_attr : onnx_node.attribute()) { for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name(); const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "axis") { if (attribute_name == "axis") {

@ -17,8 +17,8 @@
#ifndef MS_ONNX_ARGMAX_PARSER_H #ifndef MS_ONNX_ARGMAX_PARSER_H
#define MS_ONNX_ARGMAX_PARSER_H #define MS_ONNX_ARGMAX_PARSER_H
#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {

@ -17,8 +17,8 @@
#ifndef MS_ONNX_ARITHMETIC_OPREATION_PARSER_H #ifndef MS_ONNX_ARITHMETIC_OPREATION_PARSER_H
#define MS_ONNX_ARITHMETIC_OPREATION_PARSER_H #define MS_ONNX_ARITHMETIC_OPREATION_PARSER_H
#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {

@ -14,14 +14,15 @@
* limitations under the License. * limitations under the License.
*/ */
#include "mindspore/lite/tools/converter/parser/onnx/onnx_batchnorm_parser.h" #include "tools/converter/parser/onnx/onnx_batchnorm_parser.h"
#include <memory> #include <memory>
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
STATUS OnnxBatchNormParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, STATUS OnnxBatchNormParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) { schema::CNodeT *op) {
unique_ptr<schema::FusedBatchNormT> attr(new schema::FusedBatchNormT()); MS_LOG(DEBUG) << "onnx BatchNormParser";
std::unique_ptr<schema::FusedBatchNormT> attr(new schema::FusedBatchNormT());
for (const auto &onnx_node_attr : onnx_node.attribute()) { for (const auto &onnx_node_attr : onnx_node.attribute()) {
if (onnx_node_attr.name() == "epsilon") { if (onnx_node_attr.name() == "epsilon") {
attr->epsilon = onnx_node_attr.f(); attr->epsilon = onnx_node_attr.f();

@ -17,8 +17,8 @@
#ifndef MS_ONNX_ADD_PARSER_H #ifndef MS_ONNX_ADD_PARSER_H
#define MS_ONNX_ADD_PARSER_H #define MS_ONNX_ADD_PARSER_H
#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {

@ -15,7 +15,7 @@
*/ */
#include <memory> #include <memory>
#include "mindspore/lite/tools/converter/parser/onnx/onnx_biasadd_parser.h" #include "tools/converter/parser/onnx/onnx_biasadd_parser.h"
// using namespace mindspore::predict; // using namespace mindspore::predict;
// using namespace onnx; // using namespace onnx;
@ -25,7 +25,8 @@ namespace lite {
STATUS OnnxBiasAddParser::Parse(const onnx::GraphProto &onnx_graph, STATUS OnnxBiasAddParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) { schema::CNodeT *op) {
unique_ptr<schema::BiasAddT> attr(new schema::BiasAddT()); MS_LOG(DEBUG) << "onnx BiasAddParser";
std::unique_ptr<schema::BiasAddT> attr(new schema::BiasAddT());
// use channel dim as axis // use channel dim as axis
attr->axis = {1}; attr->axis = {1};
if (op != nullptr) { if (op != nullptr) {

@ -17,8 +17,8 @@
#ifndef MS_ONNX_BIASADD_PARSER_H #ifndef MS_ONNX_BIASADD_PARSER_H
#define MS_ONNX_BIASADD_PARSER_H #define MS_ONNX_BIASADD_PARSER_H
#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {

@ -15,12 +15,13 @@
*/ */
#include <memory> #include <memory>
#include "mindspore/lite/tools/converter/parser/onnx/onnx_cast_parser.h" #include "tools/converter/parser/onnx/onnx_cast_parser.h"
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
STATUS OnnxCastParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { STATUS OnnxCastParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
unique_ptr<schema::CastT> attr(new schema::CastT()); MS_LOG(DEBUG) << "onnx CastParser";
std::unique_ptr<schema::CastT> attr(new schema::CastT());
for (const auto &onnx_node_attr : onnx_node.attribute()) { for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name(); const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "to") { if (attribute_name == "to") {

@ -17,8 +17,8 @@
#ifndef MS_ONNX_CAST_PARSER_H #ifndef MS_ONNX_CAST_PARSER_H
#define MS_ONNX_CAST_PARSER_H #define MS_ONNX_CAST_PARSER_H
#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {

@ -15,24 +15,32 @@
*/ */
#include <memory> #include <memory>
#include "mindspore/lite/tools/converter/parser/onnx/onnx_clip_parser.h" #include "tools/converter/parser/onnx/onnx_clip_parser.h"
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
STATUS OnnxClipParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { STATUS OnnxClipParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
unique_ptr<schema::ClipT> attr(new schema::ClipT()); MS_LOG(DEBUG) << "onnx ClipParser";
float min = -1, max = -1;
for (const auto &onnx_node_attr : onnx_node.attribute()) { for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name(); const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "max") { if (attribute_name == "max") {
attr->max = onnx_node_attr.f(); max = onnx_node_attr.f();
} else if (attribute_name == "min") { } else if (attribute_name == "min") {
attr->min = onnx_node_attr.f(); min = onnx_node_attr.f();
} }
} }
if (op != nullptr) { if (min == 0 && max == 6) {
op->primitive = std::make_unique<schema::PrimitiveT>(); std::unique_ptr<schema::ActivationT> attr(new schema::ActivationT());
op->primitive->value.type = schema::PrimitiveType_Clip; attr->type = schema::ActivationType_RELU6;
op->primitive->value.value = attr.release(); if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Activation;
op->primitive->value.value = attr.release();
}
} else {
MS_LOG(ERROR) << "only support convert clip(0,6) to relu6, other value is not supported";
return RET_PARAM_INVALID;
} }
return RET_OK; return RET_OK;
} }
@ -40,4 +48,3 @@ STATUS OnnxClipParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
OnnxNodeRegistrar g_onnxClipParser("Clip", new OnnxClipParser()); OnnxNodeRegistrar g_onnxClipParser("Clip", new OnnxClipParser());
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

@ -17,8 +17,8 @@
#ifndef MS_ONNX_CLIP_PARSER_H #ifndef MS_ONNX_CLIP_PARSER_H
#define MS_ONNX_CLIP_PARSER_H #define MS_ONNX_CLIP_PARSER_H
#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {

@ -15,14 +15,15 @@
*/ */
#include <memory> #include <memory>
#include "mindspore/lite/tools/converter/parser/onnx/onnx_concat_parser.h" #include "tools/converter/parser/onnx/onnx_concat_parser.h"
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
STATUS OnnxConcatParser::Parse(const onnx::GraphProto &onnx_graph, STATUS OnnxConcatParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) { schema::CNodeT *op) {
unique_ptr<schema::ConcatT> attr(new schema::ConcatT()); MS_LOG(DEBUG) << "onnx ConcatParser";
std::unique_ptr<schema::ConcatT> attr(new schema::ConcatT());
for (const auto &onnx_node_attr : onnx_node.attribute()) { for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name(); const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "axis") { if (attribute_name == "axis") {

@ -17,8 +17,8 @@
#ifndef MS_ONNX_CONCAT_PARSER_H #ifndef MS_ONNX_CONCAT_PARSER_H
#define MS_ONNX_CONCAT_PARSER_H #define MS_ONNX_CONCAT_PARSER_H
#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {

@ -15,17 +15,19 @@
*/ */
#include <memory> #include <memory>
#include "mindspore/lite/tools/converter/parser/onnx/onnx_constant_parser.h" #include "tools/converter/parser/onnx/onnx_constant_parser.h"
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
STATUS OnnxConstantParser::Parse(const onnx::GraphProto &onnx_graph, STATUS OnnxConstantParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) { schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx ConstantParser";
if (op != nullptr) { if (op != nullptr) {
std::unique_ptr<schema::ConstantT> attr(new schema::ConstantT());
op->primitive = std::make_unique<schema::PrimitiveT>(); op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Constant; op->primitive->value.type = schema::PrimitiveType_Constant;
op->primitive->value.value = nullptr; op->primitive->value.value = attr.release();
} }
return RET_OK; return RET_OK;
} }

@ -17,8 +17,8 @@
#ifndef MS_ONNX_CONSTANT_PARSER_H #ifndef MS_ONNX_CONSTANT_PARSER_H
#define MS_ONNX_CONSTANT_PARSER_H #define MS_ONNX_CONSTANT_PARSER_H
#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {

@ -17,17 +17,18 @@
#include <vector> #include <vector>
#include <memory> #include <memory>
#include <algorithm> #include <algorithm>
#include "mindspore/lite/tools/converter/parser/onnx/onnx_conv_parser.h" #include "tools/converter/parser/onnx/onnx_conv_parser.h"
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
bool OnnxConvParser::ParseGroupConvolution(schema::CNodeT *op, schema::Conv2DT *attr) { bool OnnxConvParser::ParseGroupConvolution(schema::CNodeT *op, schema::Conv2DT *attr) {
MS_LOG(DEBUG) << "onnx DepthwiseConvParser";
if (attr == nullptr || attr->group != attr->channelIn) { if (attr == nullptr || attr->group != attr->channelIn) {
return false; return false;
} }
std::unique_ptr<schema::DepthwiseConv2DT> depthwiseConv2DParam(new (std::nothrow) schema::DepthwiseConv2DT()); std::unique_ptr<schema::DepthwiseConv2DT> depthwiseConv2DParam(new (std::nothrow) schema::DepthwiseConv2DT());
if (depthwiseConv2DParam == nullptr) { if (depthwiseConv2DParam == nullptr) {
// MS_LOGW("new DepthwiseConv2DT failed"); MS_LOG(ERROR) << "new DepthwiseConv2DT failed";
return false; return false;
} }
depthwiseConv2DParam->format = attr->format; depthwiseConv2DParam->format = attr->format;
@ -48,12 +49,12 @@ bool OnnxConvParser::ParseGroupConvolution(schema::CNodeT *op, schema::Conv2DT *
depthwiseConv2DParam->activationType = attr->activationType; depthwiseConv2DParam->activationType = attr->activationType;
op->primitive = std::make_unique<schema::PrimitiveT>(); op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; op->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D;
delete (op->primitive->value.value);
op->primitive->value.value = depthwiseConv2DParam.release(); op->primitive->value.value = depthwiseConv2DParam.release();
return true; return true;
} }
STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx ConvParser";
auto attr = new schema::Conv2DT(); auto attr = new schema::Conv2DT();
// set opdef each attr params // set opdef each attr params
for (const auto &onnx_node_attr : onnx_node.attribute()) { for (const auto &onnx_node_attr : onnx_node.attribute()) {
@ -61,30 +62,32 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
attr->group = static_cast<int32_t>(onnx_node_attr.i()); attr->group = static_cast<int32_t>(onnx_node_attr.i());
} else if (onnx_node_attr.name() == "dilations") { } else if (onnx_node_attr.name() == "dilations") {
if (onnx_node_attr.ints().size() != 2) { if (onnx_node_attr.ints().size() != 2) {
// MS_LOGE("dilations size %d is not 2", onnx_node_attr.ints().size()); MS_LOG(ERROR) << "dilations size " << onnx_node_attr.ints().size() << " is not 2";
return RET_ERROR; return RET_ERROR;
} }
attr->dilateW = static_cast<int32_t>(onnx_node_attr.ints(0)); // TODO(wangzhe) verify the change
attr->dilateH = static_cast<int32_t>(onnx_node_attr.ints(1)); attr->dilateH = static_cast<int32_t>(onnx_node_attr.ints(0));
attr->dilateW = static_cast<int32_t>(onnx_node_attr.ints(1));
} else if (onnx_node_attr.name() == "kernels") { } else if (onnx_node_attr.name() == "kernels") {
if (onnx_node_attr.ints().size() != 2) { if (onnx_node_attr.ints().size() != 2) {
// MS_LOGE("kernel_shape size %d is not 2", onnx_node_attr.ints().size()); MS_LOG(ERROR) << "kernel_shape size " << onnx_node_attr.ints().size() << " is not 2";
return RET_ERROR; return RET_ERROR;
} }
attr->kernelH = static_cast<int32_t>(onnx_node_attr.ints(0)); attr->kernelH = static_cast<int32_t>(onnx_node_attr.ints(0));
attr->kernelW = static_cast<int32_t>(onnx_node_attr.ints(1)); attr->kernelW = static_cast<int32_t>(onnx_node_attr.ints(1));
} else if (onnx_node_attr.name() == "kernel_shape") { } else if (onnx_node_attr.name() == "kernel_shape") {
if (onnx_node_attr.ints().size() != 2) { if (onnx_node_attr.ints().size() != 2) {
// MS_LOGE("kernel_shape size %d is not 2", onnx_node_attr.ints().size()); MS_LOG(ERROR) << "kernel_shape size " << onnx_node_attr.ints().size() << " is not 2";
return RET_ERROR; return RET_ERROR;
} }
attr->kernelW = static_cast<int32_t>(onnx_node_attr.ints(0)); // TODO(wangzhe) verify the change
attr->kernelH = static_cast<int32_t>(onnx_node_attr.ints(1)); attr->kernelH = static_cast<int32_t>(onnx_node_attr.ints(0));
attr->kernelW = static_cast<int32_t>(onnx_node_attr.ints(1));
} else if (onnx_node_attr.name() == "auto_pad") { } else if (onnx_node_attr.name() == "auto_pad") {
attr->padMode = GetOnnxPadMode(onnx_node_attr); attr->padMode = GetOnnxPadMode(onnx_node_attr);
} else if (onnx_node_attr.name() == "pads") { } else if (onnx_node_attr.name() == "pads") {
if (onnx_node_attr.ints().size() != 4) { if (onnx_node_attr.ints().size() != 4) {
// MS_LOGE("pads size %d is not 4", onnx_node_attr.ints().size()); MS_LOG(ERROR) << "pads size " << onnx_node_attr.ints().size() << " is not 4";
return RET_ERROR; return RET_ERROR;
} }
attr->padUp = static_cast<int32_t>(onnx_node_attr.ints(0)); attr->padUp = static_cast<int32_t>(onnx_node_attr.ints(0));
@ -93,16 +96,17 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
attr->padRight = static_cast<int32_t>(onnx_node_attr.ints(3)); attr->padRight = static_cast<int32_t>(onnx_node_attr.ints(3));
} else if (onnx_node_attr.name() == "strides") { } else if (onnx_node_attr.name() == "strides") {
if (onnx_node_attr.ints().size() != 2) { if (onnx_node_attr.ints().size() != 2) {
// MS_LOGE("strides size %d is not 2", onnx_node_attr.ints().size()); MS_LOG(ERROR) << "strides size " << onnx_node_attr.ints().size() << " is not 2";
return RET_ERROR; return RET_ERROR;
} }
attr->strideW = static_cast<int32_t>(onnx_node_attr.ints(0)); // TODO(wangzhe) verify the change
attr->strideH = static_cast<int32_t>(onnx_node_attr.ints(1)); attr->strideH = static_cast<int32_t>(onnx_node_attr.ints(0));
attr->strideW = static_cast<int32_t>(onnx_node_attr.ints(1));
} else if (onnx_node_attr.name() == "order") { } else if (onnx_node_attr.name() == "order") {
if (onnx_node_attr.s() == "NHWC") { if (onnx_node_attr.s() == "NHWC") {
attr->format = schema::Format_NHWC; attr->format = schema::Format_NHWC;
} else { } else {
// MS_LOGE("Unsupported format: %s", onnx_node_attr.s().c_str()); MS_LOG(ERROR) << "Unsupported format: " << onnx_node_attr.s();
return RET_ERROR; return RET_ERROR;
} }
} }
@ -114,7 +118,7 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
std::find_if(onnx_graph.initializer().begin(), onnx_graph.initializer().end(), std::find_if(onnx_graph.initializer().begin(), onnx_graph.initializer().end(),
[onnx_conv_weight](const onnx::TensorProto &proto) { return proto.name() == onnx_conv_weight; }); [onnx_conv_weight](const onnx::TensorProto &proto) { return proto.name() == onnx_conv_weight; });
if (nodeIter == onnx_graph.initializer().end()) { if (nodeIter == onnx_graph.initializer().end()) {
// MS_LOGE("not find node: %s", onnx_conv_weight.c_str()) MS_LOG(ERROR) << "not find node: " << onnx_conv_weight;
return RET_ERROR; return RET_ERROR;
} }
std::vector<int> weight_shape; std::vector<int> weight_shape;
@ -129,7 +133,7 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
std::find_if(onnx_graph.node().begin(), onnx_graph.node().end(), std::find_if(onnx_graph.node().begin(), onnx_graph.node().end(),
[onnx_conv_weight](const onnx::NodeProto &proto) { return proto.output(0) == onnx_conv_weight; }); [onnx_conv_weight](const onnx::NodeProto &proto) { return proto.output(0) == onnx_conv_weight; });
if (nodeIter == onnx_graph.node().end()) { if (nodeIter == onnx_graph.node().end()) {
// MS_LOGE("can not find node: %s", onnx_conv_weight.c_str()) MS_LOG(ERROR) << "can not find node: " << onnx_conv_weight;
return RET_ERROR; return RET_ERROR;
} }
std::vector<int> dims; std::vector<int> dims;
@ -139,6 +143,7 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
dims.insert(dims.begin(), iter->ints().begin(), iter->ints().end()); dims.insert(dims.begin(), iter->ints().begin(), iter->ints().end());
} }
attr->channelOut = dims[0]; attr->channelOut = dims[0];
// TODO(wangzhe) verify this code
attr->channelIn = dims[3] * attr->group; attr->channelIn = dims[3] * attr->group;
} }
attr->format = schema::Format_NCHW; attr->format = schema::Format_NCHW;
@ -156,7 +161,7 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
if (attr->group != 1) { if (attr->group != 1) {
if (!ParseGroupConvolution(op, attr)) { if (!ParseGroupConvolution(op, attr)) {
delete attr; delete attr;
// MS_LOGE("Convert Convolution to Depthwise failed"); MS_LOG(ERROR) << "Convert Convolution to Depthwise failed";
return RET_ERROR; return RET_ERROR;
} }
} }
@ -169,4 +174,3 @@ OnnxNodeRegistrar g_onnxConvReluParser("ConvRelu", new OnnxConvParser());
OnnxNodeRegistrar g_onnxInt8ConvReluParser("Int8ConvRelu", new OnnxConvParser()); OnnxNodeRegistrar g_onnxInt8ConvReluParser("Int8ConvRelu", new OnnxConvParser());
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

@ -17,8 +17,8 @@
#ifndef MS_ONNX_CONV_PARSER_H #ifndef MS_ONNX_CONV_PARSER_H
#define MS_ONNX_CONV_PARSER_H #define MS_ONNX_CONV_PARSER_H
#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save