Modify the method for getting output index of metagraph.

pull/4372/head
wsc 5 years ago
parent 04371f6d38
commit 067beb2590

@ -58,7 +58,7 @@ int Reduce::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
for (size_t i = 0; i < in_shape.size(); i++) { for (size_t i = 0; i < in_shape.size(); i++) {
bool reduce_axis = false; bool reduce_axis = false;
for (int idx = 0; idx < num_axes; ++idx) { for (int idx = 0; idx < num_axes; ++idx) {
if (static_cast<size_t>((*axes)[idx]) == i) { if (static_cast<size_t>((*axes)[idx]) == i || static_cast<size_t>((*axes)[idx] + in_shape.size()) == i) {
reduce_axis = true; reduce_axis = true;
break; break;
} }

@ -71,7 +71,7 @@ int ReduceCPUKernel::CheckParameters() {
return RET_ERROR; return RET_ERROR;
} }
for (auto i = 0; i < num_axes_; i++) { for (auto i = 0; i < num_axes_; i++) {
if (axes_[i] < -static_cast<int>(input_rank) || static_cast<size_t>(axes_[i]) >= input_rank) { if (axes_[i] < -static_cast<int>(input_rank) || axes_[i] >= static_cast<int>(input_rank)) {
MS_LOG(ERROR) << "Reduce got invalid axis " << axes_[i] << ", axis should be in [" MS_LOG(ERROR) << "Reduce got invalid axis " << axes_[i] << ", axis should be in ["
<< -static_cast<int>(input_rank) << ", " << input_rank - 1 << "]."; << -static_cast<int>(input_rank) << ", " << input_rank - 1 << "].";
return RET_ERROR; return RET_ERROR;

@ -236,18 +236,31 @@ void TfliteModelParser::SetInputTensor(const std::unique_ptr<tflite::SubGraphT>
} }
} }
void TfliteModelParser::SetGraphTensorIndex(const mindspore::lite::TensorCache &tensorCache, void TfliteModelParser::SetGraphTensorIndex(const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const mindspore::lite::TensorCache &tensorCache,
schema::MetaGraphT *subGraphDef) { schema::MetaGraphT *subGraphDef) {
auto opGraph = OpGraphT::Build(subGraphDef);
auto graphInputs = tensorCache.GetGraphInputs(); auto graphInputs = tensorCache.GetGraphInputs();
auto graphOutputs = opGraph->GetOutputNode();
subGraphDef->inputIndex.assign(graphInputs.begin(), graphInputs.end()); subGraphDef->inputIndex.assign(graphInputs.begin(), graphInputs.end());
for (const auto &output : graphOutputs) { for (auto outputIndex : tflite_subgraph->outputs) {
auto op = opMap[output->ID()]; int i = 0;
for (auto outputIndex : op->outputIndex) { bool found = false;
subGraphDef->outputIndex.emplace_back(outputIndex); for (const auto &tfliteOp : tflite_subgraph->operators) {
int j = 0;
auto opType = GetTfliteNodeType(tfliteOp, tflite_model);
std::string opName = opType + "-" + std::to_string(i++);
for (auto opOutputIndex : tfliteOp->outputs) {
if (outputIndex == opOutputIndex) {
subGraphDef->outputIndex.emplace_back(opMap[opName]->outputIndex[j]);
found = true;
break;
}
j++;
}
if (found) {
break;
}
} }
} }
} }
@ -284,7 +297,7 @@ MetaGraphT *TfliteModelParser::Parse(const std::string &modelFile, const std::st
return nullptr; return nullptr;
} }
SetGraphTensorIndex(tensorCache, subGraph.get()); SetGraphTensorIndex(tflite_subgraph, tflite_model, tensorCache, subGraph.get());
SetAllTensors(tensorCache, subGraph.get()); SetAllTensors(tensorCache, subGraph.get());
return subGraph.release(); return subGraph.release();
} }

@ -50,7 +50,10 @@ class TfliteModelParser : public ModelParser {
void SetInputTensor(const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, TensorCache *tensor_cache); void SetInputTensor(const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, TensorCache *tensor_cache);
void SetGraphTensorIndex(const mindspore::lite::TensorCache &tensorCache, schema::MetaGraphT *subGraphDef); void SetGraphTensorIndex(const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const mindspore::lite::TensorCache &tensorCache,
schema::MetaGraphT *subGraphDef);
STATUS ParseOp(const std::unique_ptr<tflite::ModelT> &tflite_model, STATUS ParseOp(const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::MetaGraphT *sub_graph, const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::MetaGraphT *sub_graph,

Loading…
Cancel
Save