!4890 modify onnx parsers format

Merge pull request !4890 from lyvette/parser
pull/4890/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 0ce052e4f9

@ -14,8 +14,8 @@
* limitations under the License.
*/
#include <memory>
#include "tools/converter/parser/onnx/onnx_argmax_parser.h"
#include <memory>
namespace mindspore {
namespace lite {
@ -23,7 +23,22 @@ STATUS OnnxArgMaxParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx ArgMaxParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::ArgMaxT> attr = std::make_unique<schema::ArgMaxT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "axis") {
@ -32,11 +47,9 @@ STATUS OnnxArgMaxParser::Parse(const onnx::GraphProto &onnx_graph,
attr->keepDims = static_cast<bool>(onnx_node_attr.i());
}
}
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_ArgMax;
op->primitive->value.value = attr.release();
}
op->primitive->value.type = schema::PrimitiveType_ArgMax;
op->primitive->value.value = attr.release();
return RET_OK;
}

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MS_ONNX_ARGMAX_PARSER_H
#define MS_ONNX_ARGMAX_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ARGMAX_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ARGMAX_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
@ -25,9 +25,12 @@ namespace lite {
class OnnxArgMaxParser : public OnnxNodeParser {
public:
OnnxArgMaxParser() : OnnxNodeParser("ArgMax") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
STATUS Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) override;
};
} // namespace lite
} // namespace mindspore
#endif // MS_ONNX_ARGMAX_PARSER_H
#endif // MMINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ARGMAX_PARSER_H

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MS_ONNX_ARITHMETIC_OPREATION_PARSER_H
#define MS_ONNX_ARITHMETIC_OPREATION_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ARITHMETIC_OPREATION_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ARITHMETIC_OPREATION_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
@ -167,5 +167,5 @@ class OnnxTanhParser : public OnnxNodeParser {
};
} // namespace lite
} // namespace mindspore
#endif // MS_ONNX_ARITHMETIC_OPREATION_PARSER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ARITHMETIC_OPREATION_PARSER_H

@ -19,10 +19,26 @@
namespace mindspore {
namespace lite {
STATUS OnnxBatchNormParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
STATUS OnnxBatchNormParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx BatchNormParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::FusedBatchNormT> attr = std::make_unique<schema::FusedBatchNormT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
for (const auto &onnx_node_attr : onnx_node.attribute()) {
if (onnx_node_attr.name() == "epsilon") {
attr->epsilon = onnx_node_attr.f();
@ -32,11 +48,9 @@ STATUS OnnxBatchNormParser::Parse(const onnx::GraphProto &onnx_graph, const onnx
attr->spatial = static_cast<int32_t>(onnx_node_attr.i());
}
}
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_FusedBatchNorm;
op->primitive->value.value = attr.release();
}
op->primitive->value.type = schema::PrimitiveType_FusedBatchNorm;
op->primitive->value.value = attr.release();
return RET_OK;
}

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MS_ONNX_ADD_PARSER_H
#define MS_ONNX_ADD_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_BATCHNORM_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_BATCHNORM_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
@ -25,9 +25,12 @@ namespace lite {
class OnnxBatchNormParser : public OnnxNodeParser {
public:
OnnxBatchNormParser() : OnnxNodeParser("BatchNormalization") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
STATUS Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) override;
};
} // namespace lite
} // namespace mindspore
#endif // MS_ONNX_ADD_PARSER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_BATCHNORM_PARSER_H

@ -14,26 +14,36 @@
* limitations under the License.
*/
#include <memory>
#include "tools/converter/parser/onnx/onnx_biasadd_parser.h"
#include <memory>
// using namespace mindspore::predict;
// using namespace onnx;
// using namespace std;
namespace mindspore {
namespace lite {
STATUS OnnxBiasAddParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx BiasAddParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::BiasAddT> attr = std::make_unique<schema::BiasAddT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
// use channel dim as axis
attr->axis = {1};
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_BiasAdd;
op->primitive->value.value = attr.release();
}
op->primitive->value.type = schema::PrimitiveType_BiasAdd;
op->primitive->value.value = attr.release();
return RET_OK;
}

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MS_ONNX_BIASADD_PARSER_H
#define MS_ONNX_BIASADD_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_BIASADD_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_BIASADD_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
@ -26,9 +26,11 @@ class OnnxBiasAddParser : public OnnxNodeParser {
public:
OnnxBiasAddParser() : OnnxNodeParser("BiasAdd") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
STATUS Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) override;
};
} // namespace lite
} // namespace mindspore
#endif // MS_ONNX_BIASADD_PARSER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_BIASADD_PARSER_H

@ -14,25 +14,40 @@
* limitations under the License.
*/
#include <memory>
#include "tools/converter/parser/onnx/onnx_cast_parser.h"
#include <memory>
namespace mindspore {
namespace lite {
STATUS OnnxCastParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
STATUS OnnxCastParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx CastParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::CastT> attr = std::make_unique<schema::CastT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "to") {
attr->dstT = static_cast<int32_t>(onnx_node_attr.i());
}
}
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Cast;
op->primitive->value.value = attr.release();
}
op->primitive->value.type = schema::PrimitiveType_Cast;
op->primitive->value.value = attr.release();
return RET_OK;
}

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MS_ONNX_CAST_PARSER_H
#define MS_ONNX_CAST_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CAST_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CAST_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
@ -25,9 +25,12 @@ namespace lite {
class OnnxCastParser : public OnnxNodeParser {
public:
OnnxCastParser() : OnnxNodeParser("Cast") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
STATUS Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) override;
};
} // namespace lite
} // namespace mindspore
#endif // MS_ONNX_CAST_PARSER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CAST_PARSER_H

@ -14,13 +14,25 @@
* limitations under the License.
*/
#include <memory>
#include "tools/converter/parser/onnx/onnx_clip_parser.h"
#include <memory>
namespace mindspore {
namespace lite {
STATUS OnnxClipParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
STATUS OnnxClipParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx ClipParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
float min = -1, max = -1;
for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name();
@ -32,15 +44,17 @@ STATUS OnnxClipParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
}
if (min == 0 && max == 6) {
std::unique_ptr<schema::ActivationT> attr = std::make_unique<schema::ActivationT>();
attr->type = schema::ActivationType_RELU6;
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Activation;
op->primitive->value.value = attr.release();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
attr->type = schema::ActivationType_RELU6;
op->primitive->value.type = schema::PrimitiveType_Activation;
op->primitive->value.value = attr.release();
} else {
MS_LOG(ERROR) << "only support convert clip(0,6) to relu6, other value is not supported";
return RET_PARAM_INVALID;
return RET_ERROR;
}
return RET_OK;
}

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MS_ONNX_CLIP_PARSER_H
#define MS_ONNX_CLIP_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CLIP_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CLIP_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
@ -25,9 +25,12 @@ namespace lite {
class OnnxClipParser : public OnnxNodeParser {
public:
OnnxClipParser() : OnnxNodeParser("Clip") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
STATUS Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) override;
};
} // namespace lite
} // namespace mindspore
#endif // MS_ONNX_ARGMAX_PARSER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CLIP_PARSER_H

@ -14,8 +14,8 @@
* limitations under the License.
*/
#include <memory>
#include "tools/converter/parser/onnx/onnx_concat_parser.h"
#include <memory>
namespace mindspore {
namespace lite {
@ -23,18 +23,31 @@ STATUS OnnxConcatParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx ConcatParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::ConcatT> attr = std::make_unique<schema::ConcatT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "axis") {
attr->axis = static_cast<int32_t>(onnx_node_attr.i());
}
}
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Concat;
op->primitive->value.value = attr.release();
}
op->primitive->value.type = schema::PrimitiveType_Concat;
op->primitive->value.value = attr.release();
return RET_OK;
}

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MS_ONNX_CONCAT_PARSER_H
#define MS_ONNX_CONCAT_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CONCAT_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CONCAT_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
@ -25,9 +25,12 @@ namespace lite {
class OnnxConcatParser : public OnnxNodeParser {
public:
OnnxConcatParser() : OnnxNodeParser("Concat") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
STATUS Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) override;
};
} // namespace lite
} // namespace mindspore
#endif // MS_ONNX_CONCAT_PARSER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CONCAT_PARSER_H

@ -14,8 +14,8 @@
* limitations under the License.
*/
#include <memory>
#include "tools/converter/parser/onnx/onnx_constant_parser.h"
#include <memory>
namespace mindspore {
namespace lite {
@ -23,12 +23,24 @@ STATUS OnnxConstantParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx ConstantParser";
if (op != nullptr) {
std::unique_ptr<schema::ConstantT> attr = std::make_unique<schema::ConstantT>();
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Constant;
op->primitive->value.value = attr.release();
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::ConstantT> attr = std::make_unique<schema::ConstantT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
op->primitive->value.type = schema::PrimitiveType_Constant;
op->primitive->value.value = attr.release();
return RET_OK;
}

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MS_ONNX_CONSTANT_PARSER_H
#define MS_ONNX_CONSTANT_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CONSTANT_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CONSTANT_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
@ -25,9 +25,12 @@ namespace lite {
class OnnxConstantParser : public OnnxNodeParser {
public:
OnnxConstantParser() : OnnxNodeParser("Constant") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
STATUS Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) override;
};
} // namespace lite
} // namespace mindspore
#endif // MS_ONNX_CONSTANT_PARSER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CONSTANT_PARSER_H

@ -14,21 +14,22 @@
* limitations under the License.
*/
#include "tools/converter/parser/onnx/onnx_conv_parser.h"
#include <vector>
#include <memory>
#include <algorithm>
#include "tools/converter/parser/onnx/onnx_conv_parser.h"
namespace mindspore {
namespace lite {
bool OnnxConvParser::ParseGroupConvolution(const std::unique_ptr<schema::Conv2DT> &attr, schema::CNodeT *op) {
bool OnnxConvParser::ParseGroupConvolution(const std::unique_ptr<schema::Conv2DT> &attr,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx DepthwiseConvParser";
if (attr == nullptr || attr->group != attr->channelIn) {
return false;
}
std::unique_ptr<schema::DepthwiseConv2DT> depthwiseConv2DParam = std::make_unique<schema::DepthwiseConv2DT>();
if (depthwiseConv2DParam == nullptr) {
MS_LOG(ERROR) << "new DepthwiseConv2DT failed";
MS_LOG(ERROR) << "new op failed";
return false;
}
depthwiseConv2DParam->format = attr->format;
@ -47,15 +48,32 @@ bool OnnxConvParser::ParseGroupConvolution(const std::unique_ptr<schema::Conv2DT
depthwiseConv2DParam->dilateH = attr->dilateH;
depthwiseConv2DParam->hasBias = attr->hasBias;
depthwiseConv2DParam->activationType = attr->activationType;
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D;
op->primitive->value.value = depthwiseConv2DParam.release();
return true;
}
STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx ConvParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::Conv2DT> attr = std::make_unique<schema::Conv2DT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
// set opdef each attr params
for (const auto &onnx_node_attr : onnx_node.attribute()) {
if (onnx_node_attr.name() == "group") {
@ -149,13 +167,13 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
} else {
attr->activationType = schema::ActivationType_NO_ACTIVATION;
}
if (attr->group != 1) {
if (!ParseGroupConvolution(attr, op)) {
MS_LOG(ERROR) << "Convert Convolution to Depthwise failed";
return RET_ERROR;
}
} else {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Conv2D;
op->primitive->value.value = attr.release();
}

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MS_ONNX_CONV_PARSER_H
#define MS_ONNX_CONV_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CONV_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CONV_PARSER_H
#include <memory>
#include "tools/converter/parser/onnx/onnx_node_parser.h"
@ -26,11 +26,16 @@ namespace lite {
class OnnxConvParser : public OnnxNodeParser {
public:
OnnxConvParser() : OnnxNodeParser("Conv") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
STATUS Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) override;
private:
bool ParseGroupConvolution(const std::unique_ptr<schema::Conv2DT> &attr, schema::CNodeT *op);
bool ParseGroupConvolution(const std::unique_ptr<schema::Conv2DT> &attr,
schema::CNodeT *op);
};
} // namespace lite
} // namespace mindspore
#endif // MS_ONNX_CONV_PARSER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CONV_PARSER_H

@ -22,6 +22,7 @@ namespace lite {
OnnxConverter::OnnxConverter() {
modelParser = new OnnxModelParser();
}
} // namespace lite
} // namespace mindspore

@ -14,8 +14,9 @@
* limitations under the License.
*/
#ifndef MS_ONNX_CONVERTER_H
#define MS_ONNX_CONVERTER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CONVERTER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CONVERTER_H
#include <string>
#include <memory>
#include "tools/converter/converter.h"
@ -27,10 +28,10 @@ class OnnxConverter : public Converter {
public:
OnnxConverter();
~OnnxConverter() override = default;
~OnnxConverter() = default;
};
} // namespace lite
} // namespace mindspore
#endif // MS_ONNX_CONVERTER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CONVERTER_H

@ -14,21 +14,21 @@
* limitations under the License.
*/
#include "tools/converter/parser/onnx/onnx_deconv_parser.h"
#include <vector>
#include <memory>
#include <algorithm>
#include "tools/converter/parser/onnx/onnx_deconv_parser.h"
namespace mindspore {
namespace lite {
bool OnnxDeConvParser::ParseGroupDeConvolution(const std::unique_ptr<schema::DeConv2DT> &attr, schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx DeConvParser";
bool OnnxDeConvParser::ParseGroupDeConvolution(const std::unique_ptr<schema::DeConv2DT> &attr,
schema::CNodeT *op) {
if (attr == nullptr || attr->group != attr->channelOut) {
return false;
}
std::unique_ptr<schema::DeDepthwiseConv2DT> deDepthwiseConv2DParam = std::make_unique<schema::DeDepthwiseConv2DT>();
if (deDepthwiseConv2DParam == nullptr) {
MS_LOG(ERROR) << "new DeDepthwiseConv2DT failed";
MS_LOG(WARNING) << "new op failed";
return false;
}
deDepthwiseConv2DParam->format = attr->format;
@ -47,38 +47,53 @@ bool OnnxDeConvParser::ParseGroupDeConvolution(const std::unique_ptr<schema::DeC
deDepthwiseConv2DParam->dilateH = attr->dilateH;
deDepthwiseConv2DParam->hasBias = attr->hasBias;
deDepthwiseConv2DParam->activationType = attr->activationType;
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_DeDepthwiseConv2D;
op->primitive->value.value = deDepthwiseConv2DParam.release();
}
op->primitive->value.type = schema::PrimitiveType_DeDepthwiseConv2D;
op->primitive->value.value = deDepthwiseConv2DParam.release();
return true;
}
STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx DeConvParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::DeConv2DT> attr = std::make_unique<schema::DeConv2DT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
// set opdef each attr params
for (const auto &onnx_node_attr : onnx_node.attribute()) {
if (onnx_node_attr.name() == "group") {
attr->group = static_cast<int32_t>(onnx_node_attr.i());
} else if (onnx_node_attr.name() == "dilations") {
if (onnx_node_attr.ints().size() != 2) {
// MS_LOGE("dilations size %d is not 2", onnx_node_attr.ints().size());
MS_LOG(ERROR) << "dilations size " << onnx_node_attr.ints().size() << " is not 2";
return RET_ERROR;
}
attr->dilateW = static_cast<int32_t>(onnx_node_attr.ints(0));
attr->dilateH = static_cast<int32_t>(onnx_node_attr.ints(1));
} else if (onnx_node_attr.name() == "kernels") {
if (onnx_node_attr.ints().size() != 2) {
// MS_LOGE("kernel_shape size %d is not 2", onnx_node_attr.ints().size());
MS_LOG(ERROR) << "kernel_shape size " << onnx_node_attr.ints().size() << " is not 2";
return RET_ERROR;
}
attr->kernelH = static_cast<int32_t>(onnx_node_attr.ints(0));
attr->kernelW = static_cast<int32_t>(onnx_node_attr.ints(1));
} else if (onnx_node_attr.name() == "kernel_shape") {
if (onnx_node_attr.ints().size() != 2) {
// MS_LOGE("kernel_shape size %d is not 2", onnx_node_attr.ints().size());
MS_LOG(ERROR) << "kernel_shape size " << onnx_node_attr.ints().size() << " is not 2";
return RET_ERROR;
}
attr->kernelW = static_cast<int32_t>(onnx_node_attr.ints(0));
@ -87,7 +102,7 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
attr->padMode = GetOnnxPadMode(onnx_node_attr);
} else if (onnx_node_attr.name() == "pads") {
if (onnx_node_attr.ints().size() != 4) {
// MS_LOGE("pads size %d is not 4", onnx_node_attr.ints().size());
MS_LOG(ERROR) << "pads size " << onnx_node_attr.ints().size() << " is not 4";
return RET_ERROR;
}
attr->padUp = static_cast<int32_t>(onnx_node_attr.ints(0));
@ -96,7 +111,7 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
attr->padRight = static_cast<int32_t>(onnx_node_attr.ints(3));
} else if (onnx_node_attr.name() == "strides") {
if (onnx_node_attr.ints().size() != 2) {
// MS_LOGE("strides size %d is not 2", onnx_node_attr.ints().size());
MS_LOG(ERROR) << "strides size " << onnx_node_attr.ints().size() << " is not 2";
return RET_ERROR;
}
attr->strideW = static_cast<int32_t>(onnx_node_attr.ints(0));
@ -105,7 +120,7 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
if (onnx_node_attr.s() == "NHWC") {
attr->format = schema::Format_NHWC;
} else {
// MS_LOGE("Unsupported format: %s", onnx_node_attr.s().c_str());
MS_LOG(ERROR) << "Unsupported format: " << onnx_node_attr.s().c_str();
return RET_ERROR;
}
}
@ -116,7 +131,7 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
std::find_if(onnx_graph.initializer().begin(), onnx_graph.initializer().end(),
[onnx_conv_weight](const onnx::TensorProto &proto) { return proto.name() == onnx_conv_weight; });
if (nodeIter == onnx_graph.initializer().end()) {
// MS_LOGE("not find node: %s", onnx_conv_weight.c_str())
MS_LOG(ERROR) << "not find node: " << onnx_conv_weight.c_str();
return RET_ERROR;
}
std::vector<int> weight_shape;
@ -137,7 +152,6 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
return RET_ERROR;
}
} else {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_DeConv2D;
op->primitive->value.value = attr.release();
}

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MS_ONNX_DECONV_PARSER_H
#define MS_ONNX_DECONV_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_DECONV_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_DECONV_PARSER_H
#include <memory>
#include "tools/converter/parser/onnx/onnx_node_parser.h"
@ -26,11 +26,16 @@ namespace lite {
class OnnxDeConvParser : public OnnxNodeParser {
public:
OnnxDeConvParser() : OnnxNodeParser("DeConv") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
STATUS Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) override;
private:
bool ParseGroupDeConvolution(const std::unique_ptr<schema::DeConv2DT> &attr, schema::CNodeT *op);
bool ParseGroupDeConvolution(const std::unique_ptr<schema::DeConv2DT> &attr,
schema::CNodeT *op);
};
} // namespace lite
} // namespace mindspore
#endif // MS_ONNX_DECONV_PARSER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_DECONV_PARSER_H

@ -14,8 +14,8 @@
* limitations under the License.
*/
#include <memory>
#include "tools/converter/parser/onnx/onnx_depth_to_space_parser.h"
#include <memory>
namespace mindspore {
namespace lite {
@ -23,18 +23,31 @@ STATUS OnnxDepthToSpaceParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx DepthToSpaceParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::DepthToSpaceT> attr = std::make_unique<schema::DepthToSpaceT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto& attribute_name = onnx_node_attr.name();
if (attribute_name == "blocksize") {
attr->blockSize = static_cast<int32_t>(onnx_node_attr.i());
}
}
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_DepthToSpace;
op->primitive->value.value = attr.release();
}
op->primitive->value.type = schema::PrimitiveType_DepthToSpace;
op->primitive->value.value = attr.release();
return RET_OK;
}

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MS_ONNX_DEPTH_TO_SPACE_PARSER_H
#define MS_ONNX_DEPTH_TO_SPACE_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_DEPTH_TO_SPACE_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_DEPTH_TO_SPACE_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
@ -25,9 +25,12 @@ namespace lite {
class OnnxDepthToSpaceParser : public OnnxNodeParser {
public:
OnnxDepthToSpaceParser() : OnnxNodeParser("DepthToSpace") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
STATUS Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) override;
};
} // namespace lite
} // namespace mindspore
#endif // MS_ONNX_DEPTH_TO_SPACE_PARSER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_DEPTH_TO_SPACE_PARSER_H

@ -14,8 +14,8 @@
* limitations under the License.
*/
#include <memory>
#include "tools/converter/parser/onnx/onnx_dropout_parser.h"
#include <memory>
namespace mindspore {
namespace lite {
@ -23,18 +23,31 @@ STATUS OnnxDropoutParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx DropoutParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::DropoutT> attr = std::make_unique<schema::DropoutT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "ratio") {
attr->ratio = static_cast<int32_t>(onnx_node_attr.i());
}
}
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Dropout;
op->primitive->value.value = attr.release();
}
op->primitive->value.type = schema::PrimitiveType_Dropout;
op->primitive->value.value = attr.release();
return RET_OK;
}

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save