|
|
@ -236,18 +236,31 @@ void TfliteModelParser::SetInputTensor(const std::unique_ptr<tflite::SubGraphT>
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
void TfliteModelParser::SetGraphTensorIndex(const mindspore::lite::TensorCache &tensorCache,
|
|
|
|
void TfliteModelParser::SetGraphTensorIndex(const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph,
|
|
|
|
|
|
|
|
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
|
|
|
|
|
|
|
const mindspore::lite::TensorCache &tensorCache,
|
|
|
|
schema::MetaGraphT *subGraphDef) {
|
|
|
|
schema::MetaGraphT *subGraphDef) {
|
|
|
|
auto opGraph = OpGraphT::Build(subGraphDef);
|
|
|
|
|
|
|
|
auto graphInputs = tensorCache.GetGraphInputs();
|
|
|
|
auto graphInputs = tensorCache.GetGraphInputs();
|
|
|
|
auto graphOutputs = opGraph->GetOutputNode();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
subGraphDef->inputIndex.assign(graphInputs.begin(), graphInputs.end());
|
|
|
|
subGraphDef->inputIndex.assign(graphInputs.begin(), graphInputs.end());
|
|
|
|
|
|
|
|
|
|
|
|
for (const auto &output : graphOutputs) {
|
|
|
|
for (auto outputIndex : tflite_subgraph->outputs) {
|
|
|
|
auto op = opMap[output->ID()];
|
|
|
|
int i = 0;
|
|
|
|
for (auto outputIndex : op->outputIndex) {
|
|
|
|
bool found = false;
|
|
|
|
subGraphDef->outputIndex.emplace_back(outputIndex);
|
|
|
|
for (const auto &tfliteOp : tflite_subgraph->operators) {
|
|
|
|
|
|
|
|
int j = 0;
|
|
|
|
|
|
|
|
auto opType = GetTfliteNodeType(tfliteOp, tflite_model);
|
|
|
|
|
|
|
|
std::string opName = opType + "-" + std::to_string(i++);
|
|
|
|
|
|
|
|
for (auto opOutputIndex : tfliteOp->outputs) {
|
|
|
|
|
|
|
|
if (outputIndex == opOutputIndex) {
|
|
|
|
|
|
|
|
subGraphDef->outputIndex.emplace_back(opMap[opName]->outputIndex[j]);
|
|
|
|
|
|
|
|
found = true;
|
|
|
|
|
|
|
|
break;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
j++;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
if (found) {
|
|
|
|
|
|
|
|
break;
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -284,7 +297,7 @@ MetaGraphT *TfliteModelParser::Parse(const std::string &modelFile, const std::st
|
|
|
|
return nullptr;
|
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
SetGraphTensorIndex(tensorCache, subGraph.get());
|
|
|
|
SetGraphTensorIndex(tflite_subgraph, tflite_model, tensorCache, subGraph.get());
|
|
|
|
SetAllTensors(tensorCache, subGraph.get());
|
|
|
|
SetAllTensors(tensorCache, subGraph.get());
|
|
|
|
return subGraph.release();
|
|
|
|
return subGraph.release();
|
|
|
|
}
|
|
|
|
}
|
|
|
|