|
|
@ -81,7 +81,6 @@ FuncGraphPtr TfliteModelParser::Parse(const std::string &model_file, const std::
|
|
|
|
|
|
|
|
|
|
|
|
STATUS TfliteModelParser::ConvertOps() {
|
|
|
|
STATUS TfliteModelParser::ConvertOps() {
|
|
|
|
const auto &tflite_subgraph = tflite_model_->subgraphs.front();
|
|
|
|
const auto &tflite_subgraph = tflite_model_->subgraphs.front();
|
|
|
|
const auto &tflite_model_buffers = tflite_model_->buffers;
|
|
|
|
|
|
|
|
NoSupportOp::GetInstance()->SetFmkType("TFLITE");
|
|
|
|
NoSupportOp::GetInstance()->SetFmkType("TFLITE");
|
|
|
|
STATUS status = RET_OK;
|
|
|
|
STATUS status = RET_OK;
|
|
|
|
int op_idx = 0;
|
|
|
|
int op_idx = 0;
|
|
|
@ -117,6 +116,9 @@ STATUS TfliteModelParser::ConvertOps() {
|
|
|
|
std::vector<AnfNodePtr> op_inputs = {NewValueNode(std::shared_ptr<lite::PrimitiveC>(primitiveC))};
|
|
|
|
std::vector<AnfNodePtr> op_inputs = {NewValueNode(std::shared_ptr<lite::PrimitiveC>(primitiveC))};
|
|
|
|
// parse inputs
|
|
|
|
// parse inputs
|
|
|
|
for (auto input_idx : op->inputs) {
|
|
|
|
for (auto input_idx : op->inputs) {
|
|
|
|
|
|
|
|
if (tflite_op_type == tflite::BuiltinOperator_FULLY_CONNECTED && input_idx == -1) {
|
|
|
|
|
|
|
|
continue;
|
|
|
|
|
|
|
|
}
|
|
|
|
if (input_idx < 0) {
|
|
|
|
if (input_idx < 0) {
|
|
|
|
input_idx += tflite_subgraph->tensors.size();
|
|
|
|
input_idx += tflite_subgraph->tensors.size();
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -126,18 +128,14 @@ STATUS TfliteModelParser::ConvertOps() {
|
|
|
|
continue;
|
|
|
|
continue;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
// const tensor
|
|
|
|
// const tensor
|
|
|
|
if (!tflite_model_buffers.at(input_tensor->buffer)->data.empty()) {
|
|
|
|
auto parameter = func_graph_->add_parameter();
|
|
|
|
auto parameter = func_graph_->add_parameter();
|
|
|
|
status = ConvertConstTensor(input_tensor.get(), parameter.get());
|
|
|
|
status = ConvertConstTensor(input_tensor.get(), parameter.get());
|
|
|
|
if (status != RET_OK) {
|
|
|
|
if (status != RET_OK) {
|
|
|
|
MS_LOG(ERROR) << "convert " << op_name << " node: " << input_idx << " const node failed.";
|
|
|
|
MS_LOG(ERROR) << "convert " << op_name << " node: " << input_idx << " const node failed.";
|
|
|
|
return status;
|
|
|
|
return status;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
op_inputs.emplace_back(parameter);
|
|
|
|
|
|
|
|
nodes_.insert(std::pair(input_idx, parameter));
|
|
|
|
|
|
|
|
continue;
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
MS_LOG(WARNING) << "tensor " << input_idx << " is neither a node output nor a weight tensor.";
|
|
|
|
op_inputs.emplace_back(parameter);
|
|
|
|
|
|
|
|
nodes_.insert(std::pair(input_idx, parameter));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
auto new_cnode = func_graph_->NewCNode(op_inputs);
|
|
|
|
auto new_cnode = func_graph_->NewCNode(op_inputs);
|
|
|
|
new_cnode->set_fullname_with_scope(op_name);
|
|
|
|
new_cnode->set_fullname_with_scope(op_name);
|
|
|
@ -268,6 +266,7 @@ STATUS TfliteModelParser::ConvertGraphOutputs() {
|
|
|
|
auto make_tuple_prim = NewValueNode(make_tuple_prim_ptr);
|
|
|
|
auto make_tuple_prim = NewValueNode(make_tuple_prim_ptr);
|
|
|
|
make_tuple_inputs.emplace_back(make_tuple_prim);
|
|
|
|
make_tuple_inputs.emplace_back(make_tuple_prim);
|
|
|
|
for (auto outputNode : tflite_subgraph->outputs) {
|
|
|
|
for (auto outputNode : tflite_subgraph->outputs) {
|
|
|
|
|
|
|
|
outputNode = outputNode < 0 ? outputNode + tflite_subgraph->tensors.size() : outputNode;
|
|
|
|
auto cnode = nodes_.at(outputNode);
|
|
|
|
auto cnode = nodes_.at(outputNode);
|
|
|
|
if (nullptr == cnode) {
|
|
|
|
if (nullptr == cnode) {
|
|
|
|
MS_LOG(ERROR) << "Can't find input node.";
|
|
|
|
MS_LOG(ERROR) << "Can't find input node.";
|
|
|
@ -296,9 +295,12 @@ STATUS TfliteModelParser::ConvertGraphOutputs() {
|
|
|
|
MS_LOG(ERROR) << "GetReturnPrim return nullptr";
|
|
|
|
MS_LOG(ERROR) << "GetReturnPrim return nullptr";
|
|
|
|
return RET_NULL_PTR;
|
|
|
|
return RET_NULL_PTR;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
int outputNode = tflite_subgraph->outputs.front() < 0
|
|
|
|
|
|
|
|
? static_cast<int>(tflite_subgraph->outputs.front() + tflite_subgraph->tensors.size())
|
|
|
|
|
|
|
|
: static_cast<int>(tflite_subgraph->outputs.front());
|
|
|
|
auto valueNode = NewValueNode(returnPrim);
|
|
|
|
auto valueNode = NewValueNode(returnPrim);
|
|
|
|
std::vector<AnfNodePtr> op_inputs{valueNode};
|
|
|
|
std::vector<AnfNodePtr> op_inputs{valueNode};
|
|
|
|
auto cnode = nodes_.at(tflite_subgraph->outputs.front());
|
|
|
|
auto cnode = nodes_.at(outputNode);
|
|
|
|
if (nullptr == cnode) {
|
|
|
|
if (nullptr == cnode) {
|
|
|
|
MS_LOG(ERROR) << "Can't find input node.";
|
|
|
|
MS_LOG(ERROR) << "Can't find input node.";
|
|
|
|
return RET_NOT_FIND_OP;
|
|
|
|
return RET_NOT_FIND_OP;
|
|
|
@ -345,8 +347,8 @@ STATUS TfliteModelParser::ConvertConstTensor(const tflite::TensorT *tensor, Para
|
|
|
|
}
|
|
|
|
}
|
|
|
|
std::memcpy(tensor_data, data.data(), size);
|
|
|
|
std::memcpy(tensor_data, data.data(), size);
|
|
|
|
param_value->SetTensorData(tensor_data, size);
|
|
|
|
param_value->SetTensorData(tensor_data, size);
|
|
|
|
parameter->set_default_param(param_value);
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
parameter->set_default_param(param_value);
|
|
|
|
return RET_OK;
|
|
|
|
return RET_OK;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|