!8770 tflite parser: support parsing to anf

From: @cjh9368
Reviewed-by: 
Signed-off-by:
pull/8770/MERGE
mindspore-ci-bot committed 4 years ago (via Gitee)
commit f9e4af259a

@ -47,13 +47,17 @@ using TensorPtr = std::shared_ptr<mindspore::tensor::Tensor>;
constexpr int kAnfPopulaterInputNumOne = 1;
constexpr int kAnfPopulaterInputNumTwo = 2;
constexpr int kAnfPopulaterInputNumThree = 3;
static std::map<std::string, schema::ActivationType> kActivationTypeMap{{"ReLU", schema::ActivationType_RELU},
static std::map<std::string, schema::ActivationType> kActivationTypeMap{
{"ReLU", schema::ActivationType_RELU},
{"ReLU6", schema::ActivationType_RELU6},
{"Sigmoid", schema::ActivationType_SIGMOID},
{"HSwish", schema::ActivationType_HSWISH},
{"HSigmoid", schema::ActivationType_HSIGMOID}};
{"HSigmoid", schema::ActivationType_HSIGMOID},
{"Swish", schema::ActivationType_SWISH},
{"LeakyRelu", schema::ActivationType_LEAKY_RELU},
{"Tanh", schema::ActivationType_TANH},
{"Logistic", schema::ActivationType_SIGMOID}};
std::vector<int> CastToInt(const ValuePtr value, bool is_vector);
class PrimitiveC : public mindspore::Primitive {
public:
// Argument primitive is delivered into PrimitiveC and will be deleted in ~PrimitiveC().
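For reference (not part of this diff), the expanded kActivationTypeMap is consumed by looking up the MindSpore op-type string; a minimal lookup sketch, with "LeakyRelu" used only as an example key:

  auto iter = kActivationTypeMap.find("LeakyRelu");
  if (iter == kActivationTypeMap.end()) {
    MS_LOG(ERROR) << "unsupported activation type";
  } else {
    const schema::ActivationType act_type = iter->second;  // schema::ActivationType_LEAKY_RELU for this key
  }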

@ -104,8 +104,8 @@ int Split::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outpu
MS_ASSERT(this->primitive_ != nullptr);
auto input = inputs_.front();
MS_ASSERT(input != nullptr);
if (inputs_.size() != kSplitInputNum) {
MS_LOG(ERROR) << "inputs number is not equal to " << kSplitInputNum;
if (inputs_.size() < kSplitInputNum) {
MS_LOG(ERROR) << "inputs number is less than " << kSplitInputNum;
return RET_ERROR;
}
auto output = outputs_.front();

@ -194,6 +194,8 @@ if(ENABLE_CONVERTER)
${LITE_DIR}/tools/optimizer/graph/weight_format_transform_pass.cc
${LITE_DIR}/tools/optimizer/graph/weight_format_hardcode_pass.cc
${LITE_DIR}/tools/optimizer/graph/clip_convert_activation_pass.cc
${LITE_DIR}/tools/optimizer/graph/group_depthwise_op_convert_pass.cc
${LITE_DIR}/tools/optimizer/graph/tflite_inputs_order_exchange_pass.cc
${LITE_DIR}/tools/optimizer/graph/unused_cast_node_remove_pass.cc
${LITE_DIR}/tools/optimizer/graph/unused_transpose_node_remove_pass.cc
${LITE_DIR}/tools/optimizer/graph/identity_remove_pass.cc

@ -135,6 +135,6 @@ mtk_convert_model.tflite
mtk_model_face_dress_fp16.tflite
smartreply.tflite
mindspore_text_classification_tflite.tflite
ml_location.tflite
# ml_location.tflite
ml_text_correction.tflite
ml_pic_shopping.tflite

@ -49,6 +49,8 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
../optimizer/graph/weight_format_transform_pass.cc
../optimizer/graph/weight_format_hardcode_pass.cc
../optimizer/graph/clip_convert_activation_pass.cc
../optimizer/graph/group_depthwise_op_convert_pass.cc
../optimizer/graph/tflite_inputs_order_exchange_pass.cc
../optimizer/graph/unused_cast_node_remove_pass.cc
../optimizer/graph/unused_transpose_node_remove_pass.cc
../optimizer/graph/identity_remove_pass.cc

@ -25,7 +25,6 @@
#include "tools/optimizer/fusion/conv_bn_fusion.h"
#include "tools/optimizer/fusion/conv_tuplegetitem_fusion.h"
#include "tools/optimizer/fusion/constant_folding_fusion.h"
#include "tools/optimizer/fusion/quant_dtype_cast_fusion.h"
#include "tools/optimizer/fusion/layer_norm_fusion.h"
#include "tools/optimizer/fusion/batchmatmul_fusion.h"
#include "tools/optimizer/fusion/sigmoid_mul_fusion.h"
@ -34,6 +33,8 @@
#include "tools/optimizer/graph/weight_format_hardcode_pass.h"
#include "tools/optimizer/graph/weight_format_transform_pass.h"
#include "tools/optimizer/graph/clip_convert_activation_pass.h"
#include "tools/optimizer/graph/group_depthwise_op_convert_pass.h"
#include "tools/optimizer/graph/tflite_inputs_order_exchange_pass.h"
#include "tools/optimizer/graph/unused_cast_node_remove_pass.h"
#include "tools/optimizer/graph/unused_transpose_node_remove_pass.h"
#include "tools/optimizer/graph/infershape_pass.h"
@ -43,8 +44,7 @@
#include "tools/converter/quantizer/weight_quantizer.h"
using std::string;
namespace mindspore {
namespace lite {
namespace mindspore::lite {
AnfTransform::AnfTransform() = default;
AnfTransform::~AnfTransform() = default;
@ -65,7 +65,7 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver
cf_pm->AddPass(std::make_shared<opt::ConstFoldPass>());
// for now - training does not support fused operations
if (config != nullptr && !config->trainModel) {
if (!config->trainModel) {
// remove quantdtype when awaretraining
pm->AddPass(std::make_shared<opt::RemoveIdentityOpPass>());
pm->AddPass(std::make_shared<opt::ConvBiasaddFusion>());
@ -119,6 +119,10 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver
}
pm->AddPass(std::make_shared<opt::ConvConvFusion>());
convert_pm->AddPass(std::make_shared<opt::ClipConvertActivationPass>());
if (config->fmk == lite::converter::FmkType_TFLITE) {
convert_pm->AddPass(std::make_shared<opt::GroupDepthwiseOpConvertPass>());
convert_pm->AddPass(std::make_shared<opt::TfliteInputsOrderExchangePass>());
}
optimizer->AddPassManager(cf_pm);
optimizer->AddPassManager(convert_pm);
optimizer->AddPassManager(pm);
@ -168,5 +172,4 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver
return new_graph;
}
} // namespace lite
} // namespace mindspore
} // namespace mindspore::lite
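The two TFLite-only passes above are added to convert_pm only when config->fmk is FmkType_TFLITE. For orientation, a graph pass in this optimizer framework typically has the following shape; this is a hedged sketch assuming the opt::Pass interface used by the existing passes, and the class name is hypothetical, not code from this commit:

  class ExampleTflitePass : public Pass {  // hypothetical example pass
   public:
    ExampleTflitePass() : Pass("ExampleTflitePass") {}
    bool Run(const FuncGraphPtr &func_graph) override {
      // walk func_graph's CNodes and rewrite them; returning false aborts this pass manager run
      return true;
    }
  };
  // and it would be scheduled like the passes above:
  // convert_pm->AddPass(std::make_shared<opt::ExampleTflitePass>());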

@ -32,8 +32,9 @@ class ModelParser {
virtual ~ModelParser() = default;
virtual FuncGraphPtr Parse(const std::string &modelFile, const std::string &weightFile, const QuantType &quantType) {
auto *meta_graph = ParseToFb(modelFile, weightFile, quantType);
virtual FuncGraphPtr Parse(const std::string &model_file, const std::string &weight_file,
const QuantType &quant_type) {
auto *meta_graph = ParseToFb(model_file, weight_file, quant_type);
if (meta_graph == nullptr) {
MS_LOG(ERROR) << "parse model to fb failed";
return nullptr;
@ -43,8 +44,8 @@ class ModelParser {
return func_graph;
}
virtual schema::MetaGraphT *ParseToFb(const std::string &modelFile, const std::string &weightFile,
const QuantType &quantType = QuantType_QUANT_NONE) = 0;
virtual schema::MetaGraphT *ParseToFb(const std::string &model_file, const std::string &weight_file,
const QuantType &quant_type = QuantType_QUANT_NONE) = 0;
public:
static FuncGraphPtr Fb2Anf(schema::MetaGraphT *meta_graph) {
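With the renaming above, Parse() remains the high-level entry point: it calls ParseToFb() and then, as the surrounding context implies, Fb2Anf() to turn the flatbuffer MetaGraphT into a FuncGraph. A hedged usage sketch from a caller's point of view, where the concrete subclass and file names are illustrative only:

  std::unique_ptr<ModelParser> parser = std::make_unique<CaffeModelParser>();
  FuncGraphPtr func_graph = parser->Parse("model.prototxt", "model.caffemodel", QuantType_QUANT_NONE);
  if (func_graph == nullptr) {
    MS_LOG(ERROR) << "convert model to FuncGraph failed";
  }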

@ -31,22 +31,22 @@ CaffeModelParser::~CaffeModelParser() {}
const std::set<std::string> CaffeModelParser::skipedLayerType = {"Dropout"};
schema::MetaGraphT *CaffeModelParser::ParseToFb(const std::string &modelFile, const std::string &weightFile,
const QuantType &quantType) {
int status = ValidateFileStr(modelFile, ".prototxt");
schema::MetaGraphT *CaffeModelParser::ParseToFb(const std::string &model_file, const std::string &weight_file,
const QuantType &quant_type) {
int status = ValidateFileStr(model_file, ".prototxt");
if (status != RET_OK) {
MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.prototxt";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
if (weightFile.empty()) {
if (weight_file.empty()) {
MS_LOG(ERROR) << "INPUT MISSING: weightFile is necessary";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_GRAPH_FILE_ERR);
return nullptr;
}
status = ValidateFileStr(weightFile, ".caffemodel");
status = ValidateFileStr(weight_file, ".caffemodel");
if (status != RET_OK) {
MS_LOG(ERROR) << "INPUT ILLEGAL: weightFile must be *.caffemodel";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
@ -57,18 +57,18 @@ schema::MetaGraphT *CaffeModelParser::ParseToFb(const std::string &modelFile, co
TensorCache tensorCache;
caffe::NetParameter proto;
status = ReadProtoFromText((const char *)modelFile.c_str(), &proto);
status = ReadProtoFromText((const char *)model_file.c_str(), &proto);
if (status != RET_OK) {
MS_LOG(ERROR) << "Read prototxt file failed, model path: " << modelFile;
MS_LOG(ERROR) << "Read prototxt file failed, model path: " << model_file;
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
metaGraph->name = proto.name();
caffe::NetParameter weight;
status = ReadProtoFromBinaryFile((const char *)weightFile.c_str(), &weight);
status = ReadProtoFromBinaryFile((const char *)weight_file.c_str(), &weight);
if (status != RET_OK) {
MS_LOG(ERROR) << "Read caffemodel file failed, model path: " << weightFile;
MS_LOG(ERROR) << "Read caffemodel file failed, model path: " << weight_file;
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
@ -81,7 +81,7 @@ schema::MetaGraphT *CaffeModelParser::ParseToFb(const std::string &modelFile, co
}
NoSupportOp::GetInstance()->SetFmkType("CAFFE");
status = ParseLayer(proto, weight, &tensorCache, metaGraph.get(), quantType);
status = ParseLayer(proto, weight, &tensorCache, metaGraph.get(), quant_type);
if (status != RET_OK) {
MS_LOG(ERROR) << "ParseLayer failed " << status;
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
@ -97,7 +97,7 @@ schema::MetaGraphT *CaffeModelParser::ParseToFb(const std::string &modelFile, co
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
metaGraph->name = GetModelName(modelFile);
metaGraph->name = GetModelName(model_file);
SetAllTensors(tensorCache, metaGraph.get());

@ -34,8 +34,8 @@ class CaffeModelParser : public ModelParser {
virtual ~CaffeModelParser();
schema::MetaGraphT *ParseToFb(const std::string &modelFile, const std::string &weightFile,
const QuantType &quantType = QuantType_QUANT_NONE) override;
schema::MetaGraphT *ParseToFb(const std::string &model_file, const std::string &weight_file,
const QuantType &quant_type = QuantType_QUANT_NONE) override;
private:
STATUS SetOpInputIdx(const caffe::LayerParameter &layer, schema::CNodeT *op, TensorCache *tensorCache);

@ -623,9 +623,9 @@ int OnnxModelParser::ParseGraph(schema::MetaGraphT *dst_graph, schema::SubGraphT
return RET_OK;
}
schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, const std::string &weightFile,
const QuantType &quantType) {
int status = ValidateFileStr(modelFile, ".onnx");
schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &model_file, const std::string &weight_file,
const QuantType &quant_type) {
int status = ValidateFileStr(model_file, ".onnx");
if (status != RET_OK) {
MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.onnx";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
@ -633,9 +633,9 @@ schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, con
}
onnx::ModelProto onnx_model;
status = ReadProtoFromBinaryFile((const char *)modelFile.c_str(), &onnx_model);
status = ReadProtoFromBinaryFile((const char *)model_file.c_str(), &onnx_model);
if (status != RET_OK) {
MS_LOG(ERROR) << "Read onnx model file failed, model path: " << modelFile;
MS_LOG(ERROR) << "Read onnx model file failed, model path: " << model_file;
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
@ -645,13 +645,13 @@ schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, con
auto dst_graph = std::make_unique<schema::MetaGraphT>();
auto dst_sub_graph = std::make_unique<schema::SubGraphT>();
int ret = ParseGraph(dst_graph.get(), dst_sub_graph.get(), onnx_graph, quantType);
int ret = ParseGraph(dst_graph.get(), dst_sub_graph.get(), onnx_graph, quant_type);
dst_graph->subGraph.push_back(std::move(dst_sub_graph));
subGraphNum += 1;
if (ret == RET_ERROR) {
return nullptr;
}
dst_graph->name = GetModelName(modelFile);
dst_graph->name = GetModelName(model_file);
std::vector<uint32_t> input_temp_index;
for (size_t i = 0; i < dst_graph->subGraph.front()->inputIndices.size(); i++) {

@ -45,8 +45,8 @@ class OnnxModelParser : public ModelParser {
int ParseGraph(schema::MetaGraphT *dst_graph, schema::SubGraphT *dst_sub_graph, const onnx::GraphProto &onnx_graph,
const QuantType &quantType);
schema::MetaGraphT *ParseToFb(const std::string &modelFile, const std::string &weightFile,
const QuantType &quantType = QuantType_QUANT_NONE) override;
schema::MetaGraphT *ParseToFb(const std::string &model_file, const std::string &weight_file,
const QuantType &quant_type = QuantType_QUANT_NONE) override;
static TypeId GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type);

@ -1,44 +0,0 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LITE_MODEL_PARSER_FOR_TFLITE_H
#define LITE_MODEL_PARSER_FOR_TFLITE_H
#include <string>
#include <unordered_map>
#include <memory>
#include "tools/converter/parser/tflite/tflite_model_parser.h"
namespace mindspore::lite {
class ModelParserForTflite : public TfliteModelParser {
public:
ModelParserForTflite() = default;
~ModelParserForTflite() override = default;
FuncGraphPtr Parse(const std::string &modelFile, const std::string &weightFile, const QuantType &quantType) override;
private:
std::unordered_map<int, AnfNodePtr> nodes;
std::unique_ptr<tflite::ModelT> tfliteModel;
FuncGraphPtr funcGraphPtr;
STATUS ConvertConstTensor(const tflite::TensorT *tensor, ParameterPtr parameter);
STATUS ConvertOutputTensor(const tflite::OperatorT *op, CNodePtr dstCNode);
STATUS ConvertOps();
STATUS ConvertGraphInputs();
STATUS ConvertGraphOutputs();
};
} // namespace mindspore::lite
#endif // LITE_MODEL_PARSER_FOR_TFLITE_H
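This header, and the ModelParserForTflite wrapper it declared, is removed; its Convert* responsibilities presumably move into TfliteModelParser itself, in files that are not shown in this truncated diff. A hedged sketch of the Parse flow those method names imply, where every identifier not taken from the deleted header is hypothetical:

  FuncGraphPtr TfliteModelParser::Parse(const std::string &model_file, const std::string &weight_file,
                                        const QuantType &quant_type) {
    tflite_model_ = ReadTfliteModel(model_file.c_str());  // hypothetical member and helper
    if (tflite_model_ == nullptr) {
      return nullptr;
    }
    func_graph_ = std::make_shared<FuncGraph>();          // hypothetical member
    if (ConvertGraphInputs() != RET_OK || ConvertOps() != RET_OK || ConvertGraphOutputs() != RET_OK) {
      return nullptr;
    }
    return func_graph_;
  }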

@ -18,9 +18,11 @@
#include <memory>
#include <vector>
#include <string>
#include "src/ops/activation.h"
#include "src/ops/primitive_c.h"
#include "tools/converter/parser/tflite/tflite_util.h"
namespace mindspore {
namespace lite {
namespace mindspore::lite {
STATUS TfliteActivationParser::Parse(TfliteTensorsInfo *tensors_info,
const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
@ -86,12 +88,40 @@ STATUS TfliteActivationParser::Parse(TfliteTensorsInfo *tensors_info,
return RET_OK;
}
TfliteNodeRegister g_tfliteReluParser("Relu", new TfliteActivationParser());
TfliteNodeRegister g_tfliteRelu6Parser("Relu6", new TfliteActivationParser());
TfliteNodeRegister g_tfliteTanhParser("Tanh", new TfliteActivationParser());
TfliteNodeRegister g_tfliteSwishParser("Swish", new TfliteActivationParser());
TfliteNodeRegister g_tfliteHardSwishParser("HardSwish", new TfliteActivationParser());
lite::PrimitiveC *TfliteActivationParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) {
std::unique_ptr<schema::ActivationT> attr = std::make_unique<schema::ActivationT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
auto tflite_op_type = (tflite_model->operator_codes[tflite_op->opcode_index])->builtin_code;
auto ms_op_type = GetMSOpType(tflite_op_type);
if (kActivationTypeMap.find(ms_op_type) == kActivationTypeMap.end()) {
MS_LOG(ERROR) << ms_op_type << " is not a supported activation type";
return nullptr;
}
attr->type = kActivationTypeMap.find(GetMSOpType(tflite_op_type))->second;
if (attr->type == schema::ActivationType_LEAKY_RELU) {
const auto &tflite_attr = tflite_op->builtin_options.AsLeakyReluOptions();
if (tflite_attr == nullptr) {
MS_LOG(ERROR) << "get op: " << GetMSOpType(tflite_op_type) << " attr failed";
return nullptr;
}
attr->alpha = tflite_attr->alpha;
}
auto primitive = std::make_unique<schema::PrimitiveT>();
primitive->value.type = schema::PrimitiveType_Activation;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
TfliteNodeRegister g_TfliteReluParser("ReLU", new TfliteActivationParser());
TfliteNodeRegister g_TfliteRelu6Parser("ReLU6", new TfliteActivationParser());
TfliteNodeRegister g_TfliteTanhParser("Tanh", new TfliteActivationParser());
TfliteNodeRegister g_TfliteSwishParser("Swish", new TfliteActivationParser());
TfliteNodeRegister g_TfliteHardSwishParser("HSwish", new TfliteActivationParser());
TfliteNodeRegister g_tfliteLogisticParser("Logistic", new TfliteActivationParser());
TfliteNodeRegister g_tfliteLeakyReluParser("LeakyRelu", new TfliteActivationParser());
} // namespace lite
} // namespace mindspore
TfliteNodeRegister g_TfliteLeakyReluParser("LeakyRelu", new TfliteActivationParser());
} // namespace mindspore::lite
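The registrations above are renamed so that their keys ("ReLU", "ReLU6", "HSwish", ...) line up with the strings returned by GetMSOpType and with the keys in kActivationTypeMap. A hedged sketch of how the new ParseLitePrimitive hook is typically reached; the registry accessor name is an assumption, since the declaration in tflite_node_parser_registry.h is not shown in this diff:

  auto op_type = GetMSOpType(tflite_op_type);  // e.g. "LeakyRelu"
  auto *node_parser = TfliteNodeParserRegistry::GetInstance()->GetNodeParser(op_type);
  if (node_parser == nullptr) {
    MS_LOG(ERROR) << "cannot find node parser: " << op_type;
  } else {
    lite::PrimitiveC *primitive_c = node_parser->ParseLitePrimitive(tflite_op, tflite_model);
    // primitive_c then backs the value node of the CNode built for this operator
  }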

@ -23,8 +23,7 @@
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
namespace mindspore {
namespace lite {
namespace mindspore::lite {
class TfliteActivationParser : public TfliteNodeParser {
public:
TfliteActivationParser() : TfliteNodeParser("node_name") {}
@ -32,9 +31,10 @@ class TfliteActivationParser : public TfliteNodeParser {
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
};
} // namespace lite
} // namespace mindspore
lite::PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
};
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ACTIVATION_PARSER_H

@ -18,9 +18,10 @@
#include "tools/converter/parser/tflite/tflite_addn_parser.h"
#include <vector>
#include <memory>
#include <map>
#include "src/ops/addn.h"
namespace mindspore {
namespace lite {
namespace mindspore::lite {
STATUS TfliteAddNParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) {
@ -55,7 +56,18 @@ STATUS TfliteAddNParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniqu
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
return RET_OK;
}
lite::PrimitiveC *TfliteAddNParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) {
auto attr = std::make_unique<schema::AddNT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
auto primitive = std::make_unique<schema::PrimitiveT>();
primitive->value.type = schema::PrimitiveType_AddN;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
TfliteNodeRegister g_tfliteAddNParser("AddN", new TfliteAddNParser());
} // namespace lite
} // namespace mindspore
} // namespace mindspore::lite

@ -32,6 +32,9 @@ class TfliteAddNParser : public TfliteNodeParser {
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
};
} // namespace lite
} // namespace mindspore

@ -76,6 +76,39 @@ STATUS TfliteArgmaxParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
return RET_OK;
}
PrimitiveC *TfliteArgmaxParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) {
const auto &tflite_subgraph = tflite_model->subgraphs.front();
std::unique_ptr<schema::ArgMaxT> attr = std::make_unique<schema::ArgMaxT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
attr->outMaxValue = false;
attr->topK = 1;
attr->keepDims = false;
attr->axisType = 1;
// get axis attr
auto axis_idx = tflite_op->inputs[1];
auto buffer_idx = tflite_subgraph->tensors[axis_idx]->buffer;
auto &buf_data = tflite_model->buffers[buffer_idx];
if (buf_data == nullptr) {
MS_LOG(ERROR) << "the buf data is null";
return nullptr;
}
auto data_ptr = buf_data->data.data();
if (data_ptr == nullptr) {
MS_LOG(ERROR) << "the data is null";
return nullptr;
}
attr->axis = *(static_cast<int32_t *>(static_cast<void *>(data_ptr)));
auto primitive = std::make_unique<schema::PrimitiveT>();
primitive->value.type = schema::PrimitiveType_ArgMax;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
TfliteNodeRegister g_tfliteArgmaxParser("Argmax", new TfliteArgmaxParser());
} // namespace lite
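ParseLitePrimitive for ArgMax (and ArgMin below) reads the axis directly out of the tensor's raw buffer. The same constant input could also be read through the GetTfliteData helper that the BatchToSpace parser further down uses; a hedged equivalent for the axis would look roughly like this, with illustrative error text:

  std::vector<int32_t> axis_data;
  if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, axis_data) != RET_OK ||
      axis_data.empty()) {
    MS_LOG(ERROR) << "get ArgMax -> axis failed";
    return nullptr;
  }
  attr->axis = axis_data.front();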

@ -32,6 +32,9 @@ class TfliteArgmaxParser : public TfliteNodeParser {
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
};
} // namespace lite
} // namespace mindspore

@ -76,6 +76,39 @@ STATUS TfliteArgminParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
return RET_OK;
}
PrimitiveC *TfliteArgminParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) {
const auto &tflite_subgraph = tflite_model->subgraphs.front();
std::unique_ptr<schema::ArgMinT> attr = std::make_unique<schema::ArgMinT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
attr->outMaxValue = false;
attr->topK = 1;
attr->keepDims = false;
attr->axisType = 1;
// get axis attr
auto axis_idx = tflite_op->inputs[1];
auto buffer_idx = tflite_subgraph->tensors[axis_idx]->buffer;
auto &buf_data = tflite_model->buffers[buffer_idx];
if (buf_data == nullptr) {
MS_LOG(ERROR) << "the buf data is null";
return nullptr;
}
auto data_ptr = buf_data->data.data();
if (data_ptr == nullptr) {
MS_LOG(ERROR) << "the data is null";
return nullptr;
}
attr->axis = *(static_cast<int32_t *>(static_cast<void *>(data_ptr)));
auto primitive = std::make_unique<schema::PrimitiveT>();
primitive->value.type = schema::PrimitiveType_ArgMin;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
TfliteNodeRegister g_tfliteArgminParser("Argmin", new TfliteArgminParser());
} // namespace lite

@ -32,6 +32,8 @@ class TfliteArgminParser : public TfliteNodeParser {
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
};
} // namespace lite
} // namespace mindspore

@ -32,6 +32,9 @@ class TfliteDoubleInputOpParser : public TfliteNodeParser {
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
};
class TfliteSingleInputOpParser : public TfliteNodeParser {
@ -41,6 +44,9 @@ class TfliteSingleInputOpParser : public TfliteNodeParser {
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
};
class TfliteCompareOpParser : public TfliteNodeParser {
@ -50,7 +56,11 @@ class TfliteCompareOpParser : public TfliteNodeParser {
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
};
} // namespace lite
} // namespace mindspore

@ -74,6 +74,29 @@ STATUS TfliteBatchToSpaceParser::Parse(TfliteTensorsInfo *tensors_info,
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
return RET_OK;
}
PrimitiveC *TfliteBatchToSpaceParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) {
const auto &tflite_subgraph = tflite_model->subgraphs.front();
auto primitive = std::make_unique<schema::PrimitiveT>();
std::unique_ptr<schema::BatchToSpaceT> attr = std::make_unique<schema::BatchToSpaceT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->blockShape)) {
MS_LOG(ERROR) << "get batchToSpace -> blockShape failed";
return nullptr;
}
if (GetTfliteData(tflite_op->inputs[2], tflite_subgraph->tensors, tflite_model->buffers, attr->crops)) {
MS_LOG(ERROR) << "get batchToSpace -> crops failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_BatchToSpace;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
TfliteNodeRegister g_tfliteBatchToSpaceParser("BatchToSpace", new TfliteBatchToSpaceParser());
TfliteNodeRegister g_tfliteBatchToSpaceNDParser("BatchToSpaceND", new TfliteBatchToSpaceParser());

@ -32,7 +32,10 @@ class TfliteBatchToSpaceParser : public TfliteNodeParser {
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
};
} // namespace lite
} // namespace mindspore

Some files were not shown because too many files have changed in this diff.