!9560 [lite]reconstruct onnx

From: @xu_anyue
Reviewed-by: 
Signed-off-by:
pull/9560/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 4a7c87f442

@ -0,0 +1,36 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifdef PRIMITIVE_WRITEABLE
#ifndef LITE_MINDSPORE_LITE_C_OPS_CONSTANT_H_
#define LITE_MINDSPORE_LITE_C_OPS_CONSTANT_H_
#include "src/ops/primitive_c.h"
namespace mindspore {
namespace lite {
class Constant : public PrimitiveC {
public:
Constant() = default;
~Constant() = default;
MS_DECLARE_PARENT(Constant, PrimitiveC);
explicit Constant(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
};
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_C_OPS_CONSTANT_H_
#endif

@ -149,6 +149,7 @@
#include "src/ops/oneslike.h"
#include "src/ops/unsorted_segment_sum.h"
#include "src/ops/reciprocal.h"
#include "src/ops/constant.h"
#ifdef SUPPORT_TRAIN
#include "src/ops/neg_grad.h"
@ -186,7 +187,7 @@
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
std::vector<int> CastToInt(const ValuePtr value) {
std::vector<int> CastToInt(const ValuePtr &value) {
if (value == nullptr) {
MS_LOG(WARNING) << "valueptr is nullptr.";
return {};
@ -903,6 +904,8 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) {
return new (std::nothrow) Dequant(primitive);
case schema::PrimitiveType_Reciprocal:
return new (std::nothrow) Reciprocal(primitive);
case schema::PrimitiveType_Constant:
return new (std::nothrow) Constant(primitive);
#ifdef SUPPORT_TRAIN
case schema::PrimitiveType_ActivationGrad:

@ -57,7 +57,7 @@ static std::map<std::string, schema::ActivationType> kActivationTypeMap{
{"LeakyRelu", schema::ActivationType_LEAKY_RELU},
{"Tanh", schema::ActivationType_TANH},
{"Logistic", schema::ActivationType_SIGMOID}};
std::vector<int> CastToInt(const ValuePtr value);
std::vector<int> CastToInt(const ValuePtr &value);
class PrimitiveC : public mindspore::Primitive {
public:
// Argument primitive is deliverd into PrimitiveC and will be deleted in ~PrimitiveC().

@ -204,6 +204,7 @@ if(ENABLE_CONVERTER)
${LITE_DIR}/tools/optimizer/graph/infershape_pass.cc
${LITE_DIR}/tools/optimizer/graph/slice_prepose_pass.cc
${LITE_DIR}/tools/optimizer/graph/mindir_adjust_pass.cc
${LITE_DIR}/tools/optimizer/graph/onnx_inputs_adjust_pass.cc
)
endif()
### train

@ -58,6 +58,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
../optimizer/graph/infershape_pass.cc
../optimizer/graph/slice_prepose_pass.cc
../optimizer/graph/mindir_adjust_pass.cc
../optimizer/graph/onnx_inputs_adjust_pass.cc
)
add_subdirectory(../anf_importer anf_importer)

@ -36,6 +36,7 @@
#include "tools/optimizer/graph/clip_convert_activation_pass.h"
#include "tools/optimizer/graph/group_depthwise_op_convert_pass.h"
#include "tools/optimizer/graph/tflite_inputs_order_exchange_pass.h"
#include "tools/optimizer/graph/onnx_inputs_adjust_pass.h"
#include "tools/optimizer/graph/update_conv2d_param_pass.h"
#include "tools/optimizer/graph/unused_cast_node_remove_pass.h"
#include "tools/optimizer/graph/unused_transpose_node_remove_pass.h"
@ -74,6 +75,16 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver
}
}
// onnx pre adjustment
if (config->fmk == converter::FmkType_ONNX) {
auto onnx_adjust_pass = std::make_shared<opt::OnnxInputAdjustOpPass>();
if (!onnx_adjust_pass->Run(old_graph)) {
MS_LOG(ERROR) << "onnx adjust failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
return nullptr;
}
}
// for now - trainning is not supporting fuse operations
if (!config->trainModel) {
// remove quantdtype when awaretraining

@ -90,6 +90,7 @@ STATUS CaffeModelParser::ConvertLayers() {
auto primitive_c = node_parser->ParseLitePrimitive(layer, weight);
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "parse node " << layer.name() << " failed.";
status = RET_ERROR;
continue;
}
@ -98,8 +99,7 @@ STATUS CaffeModelParser::ConvertLayers() {
status = ConvertBottom(layer, &input_nodes);
if (status != RET_OK) {
MS_LOG(ERROR) << "Convert layer bottom for " << layer.name() << " failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return status;
continue;
}
// build weights
@ -107,8 +107,7 @@ STATUS CaffeModelParser::ConvertLayers() {
status = ConvertBlobs(weight, &const_parameters);
if (status != RET_OK) {
MS_LOG(ERROR) << "Convert blobs for " << layer.name() << " failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return status;
continue;
}
// build cnode
@ -122,15 +121,13 @@ STATUS CaffeModelParser::ConvertLayers() {
status = ConvertTop(layer, new_cnode);
if (status != RET_OK) {
MS_LOG(ERROR) << "Convert outputs for " << layer.name() << " failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return status;
continue;
}
status = ConvertLayerQuantParams(layer, weight, primitive_c);
if (status != RET_OK) {
MS_LOG(ERROR) << "Convert quant params for " << layer.name() << " failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return status;
continue;
}
}
return status;

@ -19,27 +19,22 @@
namespace mindspore {
namespace lite {
STATUS OnnxAdderParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
lite::PrimitiveC *OnnxAdderParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node) {
MS_LOG(DEBUG) << "onnx AdderParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
auto attr = std::make_unique<schema::AdderT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
return nullptr;
}
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "new primitive failed";
return nullptr;
}
op->primitive->value.type = schema::PrimitiveType_Adder;
op->primitive->value.value = attr.release();
return RET_OK;
primitive->value.type = schema::PrimitiveType_Adder;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
OnnxNodeRegistrar g_onnxAdderParser("adder_f", new OnnxAdderParser());

@ -26,8 +26,7 @@ class OnnxAdderParser : public OnnxNodeParser {
public:
OnnxAdderParser() : OnnxNodeParser("Adder") {}
~OnnxAdderParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
} // namespace lite
} // namespace mindspore

@ -19,23 +19,14 @@
namespace mindspore {
namespace lite {
STATUS OnnxArgMaxParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
lite::PrimitiveC *OnnxArgMaxParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node) {
MS_LOG(DEBUG) << "onnx ArgMaxParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::ArgMaxT> attr = std::make_unique<schema::ArgMaxT>();
auto attr = std::make_unique<schema::ArgMaxT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
return nullptr;
}
for (const auto &onnx_node_attr : onnx_node.attribute()) {
@ -46,10 +37,14 @@ STATUS OnnxArgMaxParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
attr->keepDims = static_cast<bool>(onnx_node_attr.i());
}
}
op->primitive->value.type = schema::PrimitiveType_ArgMax;
op->primitive->value.value = attr.release();
return RET_OK;
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "new primitive failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_ArgMax;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
OnnxNodeRegistrar g_onnxArgMaxParser("ArgMax", new OnnxArgMaxParser());

@ -27,7 +27,7 @@ class OnnxArgMaxParser : public OnnxNodeParser {
OnnxArgMaxParser() : OnnxNodeParser("ArgMax") {}
~OnnxArgMaxParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
} // namespace lite
} // namespace mindspore

@ -26,203 +26,203 @@ class OnnxAddParser : public OnnxNodeParser {
public:
OnnxAddParser() : OnnxNodeParser("Add") {}
~OnnxAddParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
class OnnxSubParser : public OnnxNodeParser {
public:
OnnxSubParser() : OnnxNodeParser("Sub") {}
~OnnxSubParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
class OnnxMulParser : public OnnxNodeParser {
public:
OnnxMulParser() : OnnxNodeParser("Mul") {}
~OnnxMulParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
class OnnxDivParser : public OnnxNodeParser {
public:
OnnxDivParser() : OnnxNodeParser("Div") {}
~OnnxDivParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
class OnnxPowParser : public OnnxNodeParser {
public:
OnnxPowParser() : OnnxNodeParser("Power") {}
~OnnxPowParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
class OnnxEqualParser : public OnnxNodeParser {
public:
OnnxEqualParser() : OnnxNodeParser("Equal") {}
~OnnxEqualParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
class OnnxLessParser : public OnnxNodeParser {
public:
OnnxLessParser() : OnnxNodeParser("Less") {}
~OnnxLessParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
class OnnxGreaterParser : public OnnxNodeParser {
public:
OnnxGreaterParser() : OnnxNodeParser("Greater") {}
~OnnxGreaterParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
class OnnxMinParser : public OnnxNodeParser {
public:
OnnxMinParser() : OnnxNodeParser("Min") {}
~OnnxMinParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
class OnnxEltwiseParser : public OnnxNodeParser {
public:
OnnxEltwiseParser() : OnnxNodeParser("Eltwise") {}
~OnnxEltwiseParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
class OnnxFloorParser : public OnnxNodeParser {
public:
OnnxFloorParser() : OnnxNodeParser("Floor") {}
~OnnxFloorParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
class OnnxAbsParser : public OnnxNodeParser {
public:
OnnxAbsParser() : OnnxNodeParser("Abs") {}
~OnnxAbsParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
class OnnxNegParser : public OnnxNodeParser {
public:
OnnxNegParser() : OnnxNodeParser("Neg") {}
~OnnxNegParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
class OnnxExpParser : public OnnxNodeParser {
public:
OnnxExpParser() : OnnxNodeParser("Exp") {}
~OnnxExpParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
class OnnxCosParser : public OnnxNodeParser {
public:
OnnxCosParser() : OnnxNodeParser("Cos") {}
~OnnxCosParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
class OnnxSinParser : public OnnxNodeParser {
public:
OnnxSinParser() : OnnxNodeParser("Sin") {}
~OnnxSinParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
class OnnxSqrtParser : public OnnxNodeParser {
public:
OnnxSqrtParser() : OnnxNodeParser("Sqrt") {}
~OnnxSqrtParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
class OnnxCeilParser : public OnnxNodeParser {
public:
OnnxCeilParser() : OnnxNodeParser("Ceil") {}
~OnnxCeilParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
class OnnxLogParser : public OnnxNodeParser {
public:
OnnxLogParser() : OnnxNodeParser("Log") {}
~OnnxLogParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
class OnnxTanParser : public OnnxNodeParser {
public:
OnnxTanParser() : OnnxNodeParser("Tan") {}
~OnnxTanParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
class OnnxAtanParser : public OnnxNodeParser {
public:
OnnxAtanParser() : OnnxNodeParser("Atan") {}
~OnnxAtanParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
class OnnxAsinParser : public OnnxNodeParser {
public:
OnnxAsinParser() : OnnxNodeParser("Asin") {}
~OnnxAsinParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
class OnnxTanhParser : public OnnxNodeParser {
public:
OnnxTanhParser() : OnnxNodeParser("Tanh") {}
~OnnxTanhParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
class OnnxSignParser : public OnnxNodeParser {
public:
OnnxSignParser() : OnnxNodeParser("Sign") {}
~OnnxSignParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
class OnnxAndParser : public OnnxNodeParser {
public:
OnnxAndParser() : OnnxNodeParser("And") {}
~OnnxAndParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
class OnnxOrParser : public OnnxNodeParser {
public:
OnnxOrParser() : OnnxNodeParser("Or") {}
~OnnxOrParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
class OnnxNotParser : public OnnxNodeParser {
public:
OnnxNotParser() : OnnxNodeParser("Not") {}
~OnnxNotParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
class OnnxRoundParser : public OnnxNodeParser {
public:
OnnxRoundParser() : OnnxNodeParser("Round") {}
~OnnxRoundParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
class OnnxReciprocalParser : public OnnxNodeParser {
public:
OnnxReciprocalParser() : OnnxNodeParser("Reciprocal") {}
~OnnxReciprocalParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
} // namespace lite
} // namespace mindspore

@ -19,23 +19,13 @@
namespace mindspore {
namespace lite {
STATUS OnnxBatchNormParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
lite::PrimitiveC *OnnxBatchNormParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node) {
MS_LOG(DEBUG) << "onnx BatchNormParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::FusedBatchNormT> attr = std::make_unique<schema::FusedBatchNormT>();
auto attr = std::make_unique<schema::FusedBatchNormT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
return nullptr;
}
for (const auto &onnx_node_attr : onnx_node.attribute()) {
@ -47,10 +37,14 @@ STATUS OnnxBatchNormParser::Parse(const onnx::GraphProto &onnx_graph, const onnx
attr->spatial = static_cast<int32_t>(onnx_node_attr.i());
}
}
op->primitive->value.type = schema::PrimitiveType_FusedBatchNorm;
op->primitive->value.value = attr.release();
return RET_OK;
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "new primitive failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_FusedBatchNorm;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
OnnxNodeRegistrar g_onnxBatchNormParser("BatchNormalization", new OnnxBatchNormParser());

@ -27,7 +27,7 @@ class OnnxBatchNormParser : public OnnxNodeParser {
OnnxBatchNormParser() : OnnxNodeParser("BatchNormalization") {}
~OnnxBatchNormParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
} // namespace lite
} // namespace mindspore

@ -19,30 +19,25 @@
namespace mindspore {
namespace lite {
STATUS OnnxBiasAddParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
lite::PrimitiveC *OnnxBiasAddParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node) {
MS_LOG(DEBUG) << "onnx BiasAddParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::BiasAddT> attr = std::make_unique<schema::BiasAddT>();
auto attr = std::make_unique<schema::BiasAddT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
return nullptr;
}
attr->axis = {1};
op->primitive->value.type = schema::PrimitiveType_BiasAdd;
op->primitive->value.value = attr.release();
return RET_OK;
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "new primitive failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_BiasAdd;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
OnnxNodeRegistrar g_onnxBiasAddParser("BiasAdd", new OnnxBiasAddParser());

@ -27,7 +27,7 @@ class OnnxBiasAddParser : public OnnxNodeParser {
OnnxBiasAddParser() : OnnxNodeParser("BiasAdd") {}
~OnnxBiasAddParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
} // namespace lite
} // namespace mindspore

@ -20,22 +20,13 @@
namespace mindspore {
namespace lite {
STATUS OnnxCastParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
lite::PrimitiveC *OnnxCastParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node) {
MS_LOG(DEBUG) << "onnx CastParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::CastT> attr = std::make_unique<schema::CastT>();
auto attr = std::make_unique<schema::CastT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
return nullptr;
}
for (const auto &onnx_node_attr : onnx_node.attribute()) {
@ -48,10 +39,14 @@ STATUS OnnxCastParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
attr->dstT = static_cast<int>(dst_type);
}
}
op->primitive->value.type = schema::PrimitiveType_Cast;
op->primitive->value.value = attr.release();
return RET_OK;
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "new primitive failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_Cast;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
OnnxNodeRegistrar g_onnxCastParser("Cast", new OnnxCastParser());

@ -27,7 +27,7 @@ class OnnxCastParser : public OnnxNodeParser {
OnnxCastParser() : OnnxNodeParser("Cast") {}
~OnnxCastParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
} // namespace lite
} // namespace mindspore

@ -19,39 +19,32 @@
namespace mindspore {
namespace lite {
STATUS OnnxClipParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
lite::PrimitiveC *OnnxClipParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node) {
MS_LOG(DEBUG) << "onnx ClipParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
auto attr = std::make_unique<schema::ClipT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
float min = -1, max = -1;
attr->max = -1;
attr->min = -1;
for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "max") {
max = onnx_node_attr.f();
attr->max = onnx_node_attr.f();
} else if (attribute_name == "min") {
min = onnx_node_attr.f();
attr->min = onnx_node_attr.f();
}
}
std::unique_ptr<schema::ClipT> attr = std::make_unique<schema::ClipT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "new primitive failed";
return nullptr;
}
attr->max = max;
attr->min = min;
op->primitive->value.type = schema::PrimitiveType_Clip;
op->primitive->value.value = attr.release();
return RET_OK;
primitive->value.type = schema::PrimitiveType_Clip;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
OnnxNodeRegistrar g_onnxClipParser("Clip", new OnnxClipParser());

@ -27,7 +27,7 @@ class OnnxClipParser : public OnnxNodeParser {
OnnxClipParser() : OnnxNodeParser("Clip") {}
~OnnxClipParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
} // namespace lite
} // namespace mindspore

@ -19,23 +19,13 @@
namespace mindspore {
namespace lite {
STATUS OnnxConcatParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
lite::PrimitiveC *OnnxConcatParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node) {
MS_LOG(DEBUG) << "onnx ConcatParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::ConcatT> attr = std::make_unique<schema::ConcatT>();
auto attr = std::make_unique<schema::ConcatT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
return nullptr;
}
for (const auto &onnx_node_attr : onnx_node.attribute()) {
@ -44,10 +34,14 @@ STATUS OnnxConcatParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
attr->axis = static_cast<int32_t>(onnx_node_attr.i());
}
}
op->primitive->value.type = schema::PrimitiveType_Concat;
op->primitive->value.value = attr.release();
return RET_OK;
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "new primitive failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_Concat;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
OnnxNodeRegistrar g_onnxConcatParser("Concat", new OnnxConcatParser());

@ -27,7 +27,7 @@ class OnnxConcatParser : public OnnxNodeParser {
OnnxConcatParser() : OnnxNodeParser("Concat") {}
~OnnxConcatParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
} // namespace lite
} // namespace mindspore

@ -20,23 +20,13 @@
namespace mindspore {
namespace lite {
STATUS OnnxConstantOfShapeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
lite::PrimitiveC *OnnxConstantOfShapeParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node) {
MS_LOG(DEBUG) << "onnx ConstantOfShapeParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::ConstantOfShapeT> attr = std::make_unique<schema::ConstantOfShapeT>();
auto attr = std::make_unique<schema::ConstantOfShapeT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
return nullptr;
}
for (const auto &onnx_node_attr : onnx_node.attribute()) {
@ -55,19 +45,24 @@ STATUS OnnxConstantOfShapeParser::Parse(const onnx::GraphProto &onnx_graph, cons
const auto &tensor = onnx_node_attr.t();
auto ret = GetTensorDataFromOnnx(tensor, &attr->value, &attr->dataType);
if (ret != RET_OK) {
return ret;
MS_LOG(ERROR) << "get data from tensor failed";
return nullptr;
}
} break;
default:
MS_LOG(ERROR) << "The data type is not supported.";
return RET_ERROR;
return nullptr;
}
}
}
op->primitive->value.type = schema::PrimitiveType_ConstantOfShape;
op->primitive->value.value = attr.release();
return RET_OK;
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "new primitive failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_ConstantOfShape;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
OnnxNodeRegistrar g_onnxConstantOfShapeParser("ConstantOfShape", new OnnxConstantOfShapeParser());

@ -27,7 +27,7 @@ class OnnxConstantOfShapeParser : public OnnxNodeParser {
OnnxConstantOfShapeParser() : OnnxNodeParser("ConstantOfShape") {}
~OnnxConstantOfShapeParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
} // namespace lite
} // namespace mindspore

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save