tflite parser remove useless code and recitify registry

pull/8934/head
cjh9368 4 years ago
parent 1a7347d29f
commit 8e32dbb959

@ -23,71 +23,6 @@
#include "tools/converter/parser/tflite/tflite_util.h"
namespace mindspore::lite {
STATUS TfliteActivationParser::Parse(TfliteTensorsInfo *tensors_info,
const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) {
MS_ASSERT(tflite_op != nullptr);
MS_ASSERT(tflite_model != nullptr);
MS_ASSERT(tflite_subgraph != nullptr);
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::ActivationT> attr = std::make_unique<schema::ActivationT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
std::vector<std::string> node_name_str;
Split(op->name, &node_name_str, "-");
const char *node_name = node_name_str.data()->c_str();
if (std::strcmp(node_name, "Relu") == 0) {
MS_LOG(DEBUG) << "parse TfliteReluParser";
attr->type = schema::ActivationType_RELU;
} else if (std::strcmp(node_name, "Relu6") == 0) {
MS_LOG(DEBUG) << "parse TfliteRelu6Parser";
attr->type = schema::ActivationType_RELU6;
} else if (std::strcmp(node_name, "Tanh") == 0) {
MS_LOG(DEBUG) << "parse TfliteTanhParser";
attr->type = schema::ActivationType_TANH;
} else if (std::strcmp(node_name, "Logistic") == 0) {
MS_LOG(DEBUG) << "parse TfliteLogisticParser";
attr->type = schema::ActivationType_SIGMOID;
} else if (std::strcmp(node_name, "Swish") == 0) {
MS_LOG(DEBUG) << "parse TfliteSwishParser";
attr->type = schema::ActivationType_SWISH;
} else if (std::strcmp(node_name, "HardSwish") == 0) {
MS_LOG(DEBUG) << "parse TfliteHardSwishParser";
attr->type = schema::ActivationType_HSWISH;
} else if (std::strcmp(node_name, "LeakyRelu") == 0) {
const auto &tflite_attr = tflite_op->builtin_options.AsLeakyReluOptions();
if (tflite_attr == nullptr) {
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
return RET_NULL_PTR;
}
attr->alpha = tflite_attr->alpha;
attr->type = schema::ActivationType_LEAKY_RELU;
} else {
MS_LOG(ERROR) << node_name << " hasn't been supported";
return RET_NOT_FIND_OP;
}
op->primitive->value.type = schema::PrimitiveType_Activation;
op->primitive->value.value = attr.release();
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
return RET_OK;
}
lite::PrimitiveC *TfliteActivationParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) {
std::unique_ptr<schema::ActivationT> attr = std::make_unique<schema::ActivationT>();
@ -117,11 +52,10 @@ lite::PrimitiveC *TfliteActivationParser::ParseLitePrimitive(const std::unique_p
return PrimitiveC::Create(primitive.release());
}
TfliteNodeRegister g_TfliteReluParser("ReLU", new TfliteActivationParser());
TfliteNodeRegister g_TfliteRelu6Parser("ReLU6", new TfliteActivationParser());
TfliteNodeRegister g_TfliteTanhParser("Tanh", new TfliteActivationParser());
TfliteNodeRegister g_TfliteSwishParser("Swish", new TfliteActivationParser());
TfliteNodeRegister g_TfliteHardSwishParser("HSwish", new TfliteActivationParser());
TfliteNodeRegister g_tfliteLogisticParser("Logistic", new TfliteActivationParser());
TfliteNodeRegister g_TfliteLeakyReluParser("LeakyRelu", new TfliteActivationParser());
TfliteNodeRegister g_TfliteReluParser(tflite::BuiltinOperator_RELU, new TfliteActivationParser());
TfliteNodeRegister g_TfliteRelu6Parser(tflite::BuiltinOperator_RELU6, new TfliteActivationParser());
TfliteNodeRegister g_TfliteTanhParser(tflite::BuiltinOperator_TANH, new TfliteActivationParser());
TfliteNodeRegister g_TfliteSwishParser(tflite::BuiltinOperator_HARD_SWISH, new TfliteActivationParser());
TfliteNodeRegister g_tfliteLogisticParser(tflite::BuiltinOperator_LOGISTIC, new TfliteActivationParser());
TfliteNodeRegister g_TfliteLeakyReluParser(tflite::BuiltinOperator_LEAKY_RELU, new TfliteActivationParser());
} // namespace mindspore::lite

@ -28,10 +28,6 @@ class TfliteActivationParser : public TfliteNodeParser {
public:
TfliteActivationParser() : TfliteNodeParser("node_name") {}
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
};

@ -22,40 +22,6 @@
#include "src/ops/addn.h"
namespace mindspore::lite {
STATUS TfliteAddNParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) {
MS_LOG(DEBUG) << "parse TfliteAddNParser";
MS_ASSERT(tflite_op != nullptr);
MS_ASSERT(tflite_model != nullptr);
MS_ASSERT(tflite_subgraph != nullptr);
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::AddNT> attr = std::make_unique<schema::AddNT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
attr->N = tflite_subgraph->tensors.size() - 1;
op->primitive->value.type = schema::PrimitiveType_AddN;
op->primitive->value.value = attr.release();
for (int input : tflite_op->inputs) {
AddOpInput(op, tensors_info, input, tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
}
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
return RET_OK;
}
lite::PrimitiveC *TfliteAddNParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) {
auto attr = std::make_unique<schema::AddNT>();
@ -69,5 +35,5 @@ lite::PrimitiveC *TfliteAddNParser::ParseLitePrimitive(const std::unique_ptr<tfl
return PrimitiveC::Create(primitive.release());
}
TfliteNodeRegister g_tfliteAddNParser("AddN", new TfliteAddNParser());
TfliteNodeRegister g_tfliteAddNParser(tflite::BuiltinOperator_ADD_N, new TfliteAddNParser());
} // namespace mindspore::lite

@ -29,10 +29,6 @@ class TfliteAddNParser : public TfliteNodeParser {
public:
TfliteAddNParser() : TfliteNodeParser("AddN") {}
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
};

@ -19,63 +19,7 @@
#include <vector>
#include <map>
namespace mindspore {
namespace lite {
STATUS TfliteArgmaxParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) {
MS_LOG(DEBUG) << "parse TfliteArgmaxParser";
MS_ASSERT(tflite_op != nullptr);
MS_ASSERT(tflite_model != nullptr);
MS_ASSERT(tflite_subgraph != nullptr);
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::ArgMaxT> attr = std::make_unique<schema::ArgMaxT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
attr->outMaxValue = false;
attr->topK = 1;
attr->keepDims = false;
attr->axisType = 1;
// get axis attr
auto axis_idx = tflite_op->inputs[1];
auto axis_tensor = tflite_subgraph->tensors[axis_idx].get();
if (axis_tensor == nullptr) {
MS_LOG(ERROR) << "axis_tensor is null";
return RET_NULL_PTR;
}
auto buffer_idx = axis_tensor->buffer;
auto &buf_data = tflite_model->buffers[buffer_idx];
if (buf_data == nullptr) {
MS_LOG(ERROR) << "the buf data is null";
return RET_NULL_PTR;
}
auto data_ptr = buf_data->data.data();
if (data_ptr == nullptr) {
MS_LOG(ERROR) << "the data is null";
return RET_NULL_PTR;
}
attr->axis = *(static_cast<int32_t *>(static_cast<void *>(data_ptr)));
op->primitive->value.type = schema::PrimitiveType_ArgMax;
op->primitive->value.value = attr.release();
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
return RET_OK;
}
namespace mindspore::lite {
PrimitiveC *TfliteArgmaxParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) {
const auto &tflite_subgraph = tflite_model->subgraphs.front();
@ -110,6 +54,5 @@ PrimitiveC *TfliteArgmaxParser::ParseLitePrimitive(const std::unique_ptr<tflite:
return PrimitiveC::Create(primitive.release());
}
TfliteNodeRegister g_tfliteArgmaxParser("Argmax", new TfliteArgmaxParser());
} // namespace lite
} // namespace mindspore
TfliteNodeRegister g_tfliteArgmaxParser(tflite::BuiltinOperator_ARG_MAX, new TfliteArgmaxParser());
} // namespace mindspore::lite

@ -29,10 +29,6 @@ class TfliteArgmaxParser : public TfliteNodeParser {
public:
TfliteArgmaxParser() : TfliteNodeParser("Argmax") {}
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
};

@ -19,63 +19,7 @@
#include <vector>
#include <map>
namespace mindspore {
namespace lite {
STATUS TfliteArgminParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) {
MS_LOG(DEBUG) << "parse TfliteArgminParser";
MS_ASSERT(tflite_op != nullptr);
MS_ASSERT(tflite_model != nullptr);
MS_ASSERT(tflite_subgraph != nullptr);
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::ArgMinT> attr = std::make_unique<schema::ArgMinT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
attr->outMaxValue = false;
attr->topK = 1;
attr->keepDims = false;
attr->axisType = 1;
// get axis attr
auto axis_idx = tflite_op->inputs[1];
auto axis_tensor = tflite_subgraph->tensors[axis_idx].get();
if (axis_tensor == nullptr) {
MS_LOG(ERROR) << "axis_tensor is null";
return RET_NULL_PTR;
}
auto buffer_idx = axis_tensor->buffer;
auto &buf_data = tflite_model->buffers[buffer_idx];
if (buf_data == nullptr) {
MS_LOG(ERROR) << "the buf data is null";
return RET_NULL_PTR;
}
auto data_ptr = buf_data->data.data();
if (data_ptr == nullptr) {
MS_LOG(ERROR) << "the data is null";
return RET_NULL_PTR;
}
attr->axis = *(static_cast<int32_t *>(static_cast<void *>(data_ptr)));
op->primitive->value.type = schema::PrimitiveType_ArgMin;
op->primitive->value.value = attr.release();
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
return RET_OK;
}
namespace mindspore::lite {
PrimitiveC *TfliteArgminParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) {
const auto &tflite_subgraph = tflite_model->subgraphs.front();
@ -110,6 +54,5 @@ PrimitiveC *TfliteArgminParser::ParseLitePrimitive(const std::unique_ptr<tflite:
return PrimitiveC::Create(primitive.release());
}
TfliteNodeRegister g_tfliteArgminParser("Argmin", new TfliteArgminParser());
} // namespace lite
} // namespace mindspore
TfliteNodeRegister g_tfliteArgminParser(tflite::BuiltinOperator_ARG_MIN, new TfliteArgminParser());
} // namespace mindspore::lite

@ -23,19 +23,14 @@
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
namespace mindspore {
namespace lite {
namespace mindspore::lite {
class TfliteArgminParser : public TfliteNodeParser {
public:
TfliteArgminParser() : TfliteNodeParser("Argmin") {}
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
};
} // namespace lite
} // namespace mindspore
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ARGMIN_PARSER_H

@ -29,10 +29,6 @@ class TfliteDoubleInputOpParser : public TfliteNodeParser {
public:
TfliteDoubleInputOpParser() : TfliteNodeParser("node_name") {}
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
};
@ -41,10 +37,6 @@ class TfliteSingleInputOpParser : public TfliteNodeParser {
public:
TfliteSingleInputOpParser() : TfliteNodeParser("node_name") {}
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
};
@ -53,10 +45,6 @@ class TfliteCompareOpParser : public TfliteNodeParser {
public:
TfliteCompareOpParser() : TfliteNodeParser("node_name") {}
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
};

@ -21,59 +21,7 @@
#include <string>
#include <map>
namespace mindspore {
namespace lite {
STATUS TfliteBatchToSpaceParser::Parse(TfliteTensorsInfo *tensors_info,
const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) {
MS_ASSERT(tflite_op != nullptr);
MS_ASSERT(tflite_model != nullptr);
MS_ASSERT(tflite_subgraph != nullptr);
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::vector<std::string> node_name_str;
Split(op->name, &node_name_str, "-");
const char *node_name = node_name_str.data()->c_str();
if (std::strcmp(node_name, "BatchToSpace") == 0) {
MS_LOG(DEBUG) << "parse TfliteBatchToSpaceParser";
} else if (std::strcmp(node_name, "BatchToSpaceND") == 0) {
MS_LOG(DEBUG) << "parse TfliteBatchToSpaceNDParser";
} else {
MS_LOG(ERROR) << node_name << " hasn't been supported";
return RET_NOT_FIND_OP;
}
std::unique_ptr<schema::BatchToSpaceT> attr = std::make_unique<schema::BatchToSpaceT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->blockShape)) {
MS_LOG(ERROR) << "get batchToSpace -> blockShape failed";
return RET_ERROR;
}
if (GetTfliteData(tflite_op->inputs[2], tflite_subgraph->tensors, tflite_model->buffers, attr->crops)) {
MS_LOG(ERROR) << "get batchToSpace -> crops failed";
return RET_ERROR;
}
op->primitive->value.type = schema::PrimitiveType_BatchToSpace;
op->primitive->value.value = attr.release();
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
return RET_OK;
}
namespace mindspore::lite {
PrimitiveC *TfliteBatchToSpaceParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) {
const auto &tflite_subgraph = tflite_model->subgraphs.front();
@ -98,7 +46,6 @@ PrimitiveC *TfliteBatchToSpaceParser::ParseLitePrimitive(const std::unique_ptr<t
return PrimitiveC::Create(primitive.release());
}
TfliteNodeRegister g_tfliteBatchToSpaceParser("BatchToSpace", new TfliteBatchToSpaceParser());
TfliteNodeRegister g_tfliteBatchToSpaceNDParser("BatchToSpaceND", new TfliteBatchToSpaceParser());
} // namespace lite
} // namespace mindspore
TfliteNodeRegister g_tfliteBatchToSpaceNDParser(tflite::BuiltinOperator_BATCH_TO_SPACE_ND,
new TfliteBatchToSpaceParser());
} // namespace mindspore::lite

@ -23,20 +23,15 @@
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
namespace mindspore {
namespace lite {
namespace mindspore::lite {
class TfliteBatchToSpaceParser : public TfliteNodeParser {
public:
TfliteBatchToSpaceParser() : TfliteNodeParser("BatchToSpace") {}
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
};
} // namespace lite
} // namespace mindspore
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_BATCH_TO_SPACE_PARSER_H

@ -19,44 +19,7 @@
#include <vector>
#include <memory>
namespace mindspore {
namespace lite {
STATUS TfliteBroadcastToParser::Parse(TfliteTensorsInfo *tensors_info,
const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) {
MS_LOG(DEBUG) << "parse TfliteBroadcastToParser";
MS_ASSERT(tflite_op != nullptr);
MS_ASSERT(tflite_model != nullptr);
MS_ASSERT(tflite_subgraph != nullptr);
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::BroadcastToT> attr = std::make_unique<schema::BroadcastToT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->dst_shape)) {
MS_LOG(ERROR) << "get broadCastTo -> dst_shape failed";
return RET_ERROR;
}
op->primitive->value.type = schema::PrimitiveType_BroadcastTo;
op->primitive->value.value = attr.release();
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
return RET_OK;
}
namespace mindspore::lite {
PrimitiveC *TfliteBroadcastToParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) {
auto &tflite_subgraph = tflite_model->subgraphs.front();
@ -80,6 +43,4 @@ PrimitiveC *TfliteBroadcastToParser::ParseLitePrimitive(const std::unique_ptr<tf
return PrimitiveC::Create(primitive.release());
}
TfliteNodeRegister g_tfliteBroadcastToParser("BroadcastTo", new TfliteBroadcastToParser());
} // namespace lite
} // namespace mindspore
} // namespace mindspore::lite

@ -28,10 +28,6 @@ class TfliteBroadcastToParser : public TfliteNodeParser {
public:
TfliteBroadcastToParser() : TfliteNodeParser("BroadcastTo") {}
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
};

@ -18,51 +18,7 @@
#include <vector>
#include <memory>
namespace mindspore {
namespace lite {
STATUS TfliteCastParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) {
MS_LOG(DEBUG) << "parse TfliteCastParser";
MS_ASSERT(tflite_op != nullptr);
MS_ASSERT(tflite_model != nullptr);
MS_ASSERT(tflite_subgraph != nullptr);
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::CastT> attr = std::make_unique<schema::CastT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
const auto &in_tensor = tflite_subgraph->tensors[tflite_op->inputs[0]];
if (in_tensor == nullptr) {
MS_LOG(ERROR) << "tensor is null";
return RET_NULL_PTR;
}
attr->srcT = GetTfliteDataType(in_tensor->type);
const auto &out_tensor = tflite_subgraph->tensors[tflite_op->outputs[0]];
if (out_tensor == nullptr) {
MS_LOG(ERROR) << "tensor is null";
return RET_NULL_PTR;
}
attr->dstT = GetTfliteDataType(out_tensor->type);
op->primitive->value.type = schema::PrimitiveType_Cast;
op->primitive->value.value = attr.release();
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
return RET_OK;
}
namespace mindspore::lite {
PrimitiveC *TfliteCastParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) {
auto &tflite_subgraph = tflite_model->subgraphs.front();
@ -90,6 +46,5 @@ PrimitiveC *TfliteCastParser::ParseLitePrimitive(const std::unique_ptr<tflite::O
return PrimitiveC::Create(primitive.release());
}
TfliteNodeRegister g_tfliteCastParser("Cast", new TfliteCastParser());
} // namespace lite
} // namespace mindspore
TfliteNodeRegister g_tfliteCastParser(tflite::BuiltinOperator_CAST, new TfliteCastParser());
} // namespace mindspore::lite

@ -29,9 +29,6 @@ class TfliteCastParser : public TfliteNodeParser {
public:
TfliteCastParser() : TfliteNodeParser("Cast") {}
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
};

@ -18,48 +18,7 @@
#include <vector>
#include <memory>
namespace mindspore {
namespace lite {
STATUS TfliteConcatParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) {
MS_LOG(DEBUG) << "parse TfliteConcatParser";
MS_ASSERT(tflite_op != nullptr);
MS_ASSERT(tflite_model != nullptr);
MS_ASSERT(tflite_subgraph != nullptr);
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::ConcatT> attr = std::make_unique<schema::ConcatT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
const auto &tfliteAttr = tflite_op->builtin_options.AsConcatenationOptions();
if (tfliteAttr == nullptr) {
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
return RET_NULL_PTR;
}
attr->axis = tfliteAttr->axis;
attr->n = tflite_op->inputs.size();
op->primitive->value.type = schema::PrimitiveType_Concat;
op->primitive->value.value = attr.release();
for (int input : tflite_op->inputs) {
AddOpInput(op, tensors_info, input, tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
}
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
return RET_OK;
}
namespace mindspore::lite {
PrimitiveC *TfliteConcatParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) {
auto primitive = std::make_unique<schema::PrimitiveT>();
@ -81,6 +40,5 @@ PrimitiveC *TfliteConcatParser::ParseLitePrimitive(const std::unique_ptr<tflite:
return PrimitiveC::Create(primitive.release());
}
TfliteNodeRegister g_tfliteConcatParser("Concat", new TfliteConcatParser());
} // namespace lite
} // namespace mindspore
TfliteNodeRegister g_tfliteConcatParser(tflite::BuiltinOperator_CONCATENATION, new TfliteConcatParser());
} // namespace mindspore::lite

@ -29,9 +29,6 @@ class TfliteConcatParser : public TfliteNodeParser {
public:
TfliteConcatParser() : TfliteNodeParser("Concat") {}
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
};

@ -19,82 +19,6 @@
#include <memory>
namespace mindspore::lite {
STATUS TfliteConvParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) {
MS_LOG(DEBUG) << "parse TfliteConvParser";
MS_ASSERT(tflite_op != nullptr);
MS_ASSERT(tflite_model != nullptr);
MS_ASSERT(tflite_subgraph != nullptr);
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::Conv2DT> attr = std::make_unique<schema::Conv2DT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
const auto &tflite_attr = tflite_op->builtin_options.AsConv2DOptions();
if (tflite_attr == nullptr) {
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
return RET_NULL_PTR;
}
attr->group = 1;
attr->strideW = tflite_attr->stride_w;
attr->strideH = tflite_attr->stride_h;
attr->dilateH = tflite_attr->dilation_h_factor;
attr->dilateW = tflite_attr->dilation_w_factor;
attr->padMode = GetPadMode(tflite_attr->padding);
attr->format = schema::Format::Format_NHWC;
attr->activationType = GetActivationFunctionType(tflite_attr->fused_activation_function);
attr->hasBias = true;
// get the conv op weight tensor
auto weight_index = tflite_op->inputs[1];
const auto &weight_tensor = tflite_subgraph->tensors[weight_index];
if (weight_tensor == nullptr) {
MS_LOG(ERROR) << "the weight tensor is null";
return RET_NULL_PTR;
}
auto weight_shape = weight_tensor->shape;
attr->channelIn = weight_shape[3];
attr->channelOut = weight_shape[0];
attr->kernelH = weight_shape[1];
attr->kernelW = weight_shape[2];
// calculate pad params
auto data_index = tflite_op->inputs[0];
const auto &data_tensor = tflite_subgraph->tensors[data_index];
std::vector<int64_t> params;
int status =
getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->kernelH, attr->kernelW, &params);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "get padding params failed";
return RET_ERROR;
} else if (status == RET_OK) {
attr->padUp = params.at(0);
attr->padDown = params.at(1);
attr->padLeft = params.at(2);
attr->padRight = params.at(3);
}
op->primitive->value.type = schema::PrimitiveType_Conv2D;
op->primitive->value.value = attr.release();
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_subgraph->tensors.size(), schema::Format::Format_KHWC);
AddOpInput(op, tensors_info, tflite_op->inputs[2], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
return RET_OK;
}
lite::PrimitiveC *TfliteConvParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) {
const auto &tflite_subgraph = tflite_model->subgraphs.front();
@ -153,5 +77,5 @@ lite::PrimitiveC *TfliteConvParser::ParseLitePrimitive(const std::unique_ptr<tfl
return PrimitiveC::Create(primitive.release());
}
TfliteNodeRegister g_tfliteConv2DParser("Conv2D", new TfliteConvParser());
TfliteNodeRegister g_tfliteConv2DParser(tflite::BuiltinOperator_CONV_2D, new TfliteConvParser());
} // namespace mindspore::lite

@ -28,9 +28,6 @@ class TfliteConvParser : public TfliteNodeParser {
public:
TfliteConvParser() : TfliteNodeParser("Conv2D") {}
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
};

@ -207,70 +207,6 @@ STATUS TfliteCustomParser::BatchMatMul(const std::vector<uint8_t> &custom_attr,
return RET_OK;
}
STATUS TfliteCustomParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) {
MS_LOG(DEBUG) << "parse TfliteCustomParser";
MS_ASSERT(tflite_op != nullptr);
MS_ASSERT(tflite_model != nullptr);
MS_ASSERT(tflite_subgraph != nullptr);
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
const auto &custom_attr = tflite_op->custom_options;
const auto opcode_index = tflite_op->opcode_index;
const auto &operator_code = tflite_model->operator_codes[opcode_index];
if (operator_code == nullptr) {
MS_LOG(ERROR) << "operator_code is null";
return RET_NULL_PTR;
}
const auto &custom_type = operator_code->custom_code;
int status = RET_OK;
if (custom_type == "TFLite_Detection_PostProcess") {
status = DetectPostProcess(custom_attr, op, tflite_op);
} else if (custom_type == "Predict") {
status = Predict(custom_attr, op, tflite_op);
} else if (custom_type == "Normalize") {
status = Normalize(custom_attr, op, tflite_op);
} else if (custom_type == "ExtractFeatures") {
status = ExtractFeatures(custom_attr, op, tflite_op);
} else if (custom_type == "AudioSpectrogram") {
status = AudioSpectrogram(custom_attr, op, tflite_op);
} else if (custom_type == "Mfcc") {
status = Mfcc(custom_attr, op, tflite_op);
} else if (custom_type == "FlexRFFT") {
status = Rfft(custom_attr, op, tflite_op, tflite_model, tflite_subgraph);
} else if (custom_type == "FlexReal") {
status = FftReal(custom_attr, op, tflite_op);
} else if (custom_type == "FlexImag") {
status = FftImag(custom_attr, op, tflite_op);
} else if (custom_type == "FlexIdentityN" || custom_type == "FlexIdentity") {
status = Identity(custom_attr, op, tflite_op);
} else if (custom_type == "FlexBatchMatMul") {
status = BatchMatMul(custom_attr, op, tflite_op);
} else {
MS_LOG(ERROR) << "the custom op hasn't been supported now";
status = RET_NOT_FIND_OP;
}
if (status != RET_OK) {
return status;
}
for (int input : tflite_op->inputs) {
AddOpInput(op, tensors_info, input, tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
}
for (int output : tflite_op->outputs) {
AddOpOutput(op, tensors_info, output, tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
}
return status;
}
PrimitiveC *TfliteCustomParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) {
auto &tflite_subgraph = tflite_model->subgraphs.front();
@ -314,6 +250,6 @@ PrimitiveC *TfliteCustomParser::ParseLitePrimitive(const std::unique_ptr<tflite:
return PrimitiveC::Create(primitive);
}
TfliteNodeRegister g_tfliteCustomParser("Custom", new TfliteCustomParser());
TfliteNodeRegister g_tfliteCustomParser(tflite::BuiltinOperator_CUSTOM, new TfliteCustomParser());
} // namespace lite
} // namespace mindspore

@ -28,9 +28,6 @@ class TfliteCustomParser : public TfliteNodeParser {
public:
TfliteCustomParser() : TfliteNodeParser("Custom") {}
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) override;

@ -18,84 +18,7 @@
#include <vector>
#include <memory>
namespace mindspore {
namespace lite {
STATUS TfliteDeConvParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) {
MS_LOG(DEBUG) << "parse tflite Transpose_Conv parser";
MS_ASSERT(tflite_op != nullptr);
MS_ASSERT(tflite_model != nullptr);
MS_ASSERT(tflite_subgraph != nullptr);
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::DeConv2DT> attr = std::make_unique<schema::DeConv2DT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
const auto &tflite_attr = tflite_op->builtin_options.AsTransposeConvOptions();
if (tflite_attr == nullptr) {
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
return RET_NULL_PTR;
}
attr->group = 1;
attr->strideW = tflite_attr->stride_w;
attr->strideH = tflite_attr->stride_h;
attr->dilateH = 1;
attr->dilateW = 1;
attr->padMode = GetPadMode(tflite_attr->padding);
attr->format = schema::Format::Format_NHWC;
attr->activationType = schema::ActivationType_NO_ACTIVATION;
attr->hasBias = true;
// get the conv op weight tensor
auto weight_index = tflite_op->inputs[1];
const auto &weight_tensor = tflite_subgraph->tensors[weight_index];
if (weight_tensor == nullptr) {
MS_LOG(ERROR) << "the weight tensor is null";
return RET_NULL_PTR;
}
auto weight_shape = weight_tensor->shape;
attr->channelIn = weight_shape[3];
attr->channelOut = weight_shape[0];
attr->kernelH = weight_shape[1];
attr->kernelW = weight_shape[2];
// calculate pad params
auto data_index = tflite_op->inputs[2];
const auto &data_tensor = tflite_subgraph->tensors[data_index];
std::vector<int64_t> params;
int status =
getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->kernelH, attr->kernelW, &params);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "get padding params failed";
return RET_ERROR;
} else if (status == RET_OK) {
attr->padUp = params.at(0);
attr->padDown = params.at(1);
attr->padLeft = params.at(2);
attr->padRight = params.at(3);
}
op->primitive->value.type = schema::PrimitiveType_DeConv2D;
op->primitive->value.value = attr.release();
AddOpInput(op, tensors_info, tflite_op->inputs[2], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_subgraph->tensors.size(), schema::Format::Format_KHWC);
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
return RET_OK;
}
namespace mindspore::lite {
PrimitiveC *TfliteDeConvParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) {
auto primitive = std::make_unique<schema::PrimitiveT>();
@ -155,6 +78,5 @@ PrimitiveC *TfliteDeConvParser::ParseLitePrimitive(const std::unique_ptr<tflite:
return PrimitiveC::Create(primitive.release());
}
TfliteNodeRegister g_tfliteDeConv2DParser("DeConv2D", new TfliteDeConvParser());
} // namespace lite
} // namespace mindspore
TfliteNodeRegister g_tfliteDeConv2DParser(tflite::BuiltinOperator_TRANSPOSE_CONV, new TfliteDeConvParser());
} // namespace mindspore::lite

@ -23,19 +23,14 @@
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
namespace mindspore {
namespace lite {
namespace mindspore::lite {
class TfliteDeConvParser : public TfliteNodeParser {
public:
TfliteDeConvParser() : TfliteNodeParser("DeConv2D") {}
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
};
} // namespace lite
} // namespace mindspore
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DECONV_PARSER_H

@ -21,45 +21,6 @@
namespace mindspore {
namespace lite {
STATUS TfliteDepthToSpaceParser::Parse(TfliteTensorsInfo *tensors_info,
const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) {
MS_LOG(DEBUG) << "parse TfliteDepthToSpaceParser";
MS_ASSERT(tflite_op != nullptr);
MS_ASSERT(tflite_model != nullptr);
MS_ASSERT(tflite_subgraph != nullptr);
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::DepthToSpaceT> attr = std::make_unique<schema::DepthToSpaceT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
const auto &tflite_attr = tflite_op->builtin_options.AsDepthToSpaceOptions();
if (tflite_attr == nullptr) {
MS_LOG(ERROR) << "get op: %s attr failed", op->name.c_str();
return RET_NULL_PTR;
}
attr->blockSize = tflite_attr->block_size;
attr->format = schema::Format::Format_NHWC;
op->primitive->value.type = schema::PrimitiveType_DepthToSpace;
op->primitive->value.value = attr.release();
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
return RET_OK;
}
PrimitiveC *TfliteDepthToSpaceParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) {
std::unique_ptr<schema::DepthToSpaceT> attr = std::make_unique<schema::DepthToSpaceT>();
@ -81,6 +42,6 @@ PrimitiveC *TfliteDepthToSpaceParser::ParseLitePrimitive(const std::unique_ptr<t
return PrimitiveC::Create(primitive.release());
}
TfliteNodeRegister g_tfliteDepthToSpaceParser("DepthToSpace", new TfliteDepthToSpaceParser());
TfliteNodeRegister g_tfliteDepthToSpaceParser(tflite::BuiltinOperator_DEPTH_TO_SPACE, new TfliteDepthToSpaceParser());
} // namespace lite
} // namespace mindspore

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save