From 35c6f9325e3b2890534db064b81cb2120e9d1363 Mon Sep 17 00:00:00 2001
From: yvette
Date: Fri, 11 Dec 2020 14:24:54 +0800
Subject: [PATCH] modify tensor name

---
 mindspore/lite/src/lite_session.cc                 | 10 ++--
 .../anf_importer/import_from_meta_graphT.cc        |  6 ++-
 mindspore/lite/tools/common/tensor_util.h          |  1 +
 .../tools/converter/graphdef_transform.cc          | 21 ++++----
 .../graph/tensor_name_pass.cc                      | 51 +++++--------
 .../parser/tflite/tflite_model_parser.cc           | 31 +++++++++--
 .../parser/tflite/tflite_model_parser.h            |  4 +-
 7 files changed, 64 insertions(+), 60 deletions(-)

diff --git a/mindspore/lite/src/lite_session.cc b/mindspore/lite/src/lite_session.cc
index 2166cc80b0..e5d419bdcd 100644
--- a/mindspore/lite/src/lite_session.cc
+++ b/mindspore/lite/src/lite_session.cc
@@ -322,11 +322,6 @@ void LiteSession::InitGraphInOutTensors(const lite::Model *model) {
 }
 
 int LiteSession::CompileGraph(Model *model) {
-  if (!ModelVerify(*model)) {
-    MS_LOG(ERROR) << "wrong model input, please check";
-    return RET_ERROR;
-  }
-
   bool expected = false;
   if (!is_running_.compare_exchange_strong(expected, true)) {
     MS_LOG(ERROR) << "Not support multi-threading";
     return RET_ERROR;
   }
@@ -343,6 +338,11 @@ int LiteSession::CompileGraph(Model *model) {
     is_running_.store(false);
     return RET_PARAM_INVALID;
   }
+  if (!ModelVerify(*model)) {
+    MS_LOG(ERROR) << "wrong model input, please check";
+    is_running_.store(false);
+    return RET_ERROR;
+  }
 
   auto ret = ConvertTensors(model);
   if (ret != RET_OK) {
diff --git a/mindspore/lite/tools/anf_importer/import_from_meta_graphT.cc b/mindspore/lite/tools/anf_importer/import_from_meta_graphT.cc
index 7ee035d808..5dd94a0bfb 100644
--- a/mindspore/lite/tools/anf_importer/import_from_meta_graphT.cc
+++ b/mindspore/lite/tools/anf_importer/import_from_meta_graphT.cc
@@ -44,7 +44,11 @@ int AnfImporterFromMetaGraphT::ConverterConstTensor() {
     auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
     MS_ASSERT(nullptr != abstract_tensor);
     parameter->set_abstract(abstract_tensor);
-    parameter->set_name("const_" + std::to_string(i) + "_parameter");
+    if (!tensor->name.empty()) {
+      parameter->set_name(tensor->name);
+    } else {
+      parameter->set_name("const-" + std::to_string(i));
+    }
 
     ParamValueLitePtr param_value = std::make_shared<ParamValueLite>();
     MS_ASSERT(nullptr != param_value);
diff --git a/mindspore/lite/tools/common/tensor_util.h b/mindspore/lite/tools/common/tensor_util.h
index c53fc8d302..71c199b9a4 100644
--- a/mindspore/lite/tools/common/tensor_util.h
+++ b/mindspore/lite/tools/common/tensor_util.h
@@ -74,6 +74,7 @@ class TensorCache {
     } else {
       tensor->nodeType = schema::NodeType_Parameter;
     }
+    tensor->name = name;
     tensors.push_back(tensor);
 
     if (Category == GRAPH_INPUT) {
diff --git a/mindspore/lite/tools/converter/graphdef_transform.cc b/mindspore/lite/tools/converter/graphdef_transform.cc
index 6ed31418b1..c06b73491e 100644
--- a/mindspore/lite/tools/converter/graphdef_transform.cc
+++ b/mindspore/lite/tools/converter/graphdef_transform.cc
@@ -180,24 +180,25 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
     }
   }
 
-  // topological sorting
+  // tensor name
   {
-    Optimizer topologicalOptimizer;
-    topologicalOptimizer.AddPass(new (std::nothrow) TopologicalSortPass());
-    status = topologicalOptimizer.Run(graphDefT);
+    Optimizer nameOptimizer;
+    nameOptimizer.AddPass(new (std::nothrow) TopologicalSortPass());
+    nameOptimizer.AddPass(new (std::nothrow) TensorNamePass());
+    status = nameOptimizer.Run(graphDefT);
     if (status != RET_OK && status != RET_NO_CHANGE) {
-      MS_LOG(ERROR) << "Run topologicalOptimizer graphPasses Failed";
+      MS_LOG(ERROR) << "Run nameOptimizer graphPasses Failed";
       return status;
     }
   }
 
-  // tensor name
+  // topological sorting
   {
-    Optimizer nameOptimizer;
-    nameOptimizer.AddPass(new (std::nothrow) TensorNamePass());
-    status = nameOptimizer.Run(graphDefT);
+    Optimizer topologicalOptimizer;
+    topologicalOptimizer.AddPass(new (std::nothrow) TopologicalSortPass());
+    status = topologicalOptimizer.Run(graphDefT);
     if (status != RET_OK && status != RET_NO_CHANGE) {
-      MS_LOG(ERROR) << "Run nameOptimizer graphPasses Failed";
+      MS_LOG(ERROR) << "Run topologicalOptimizer graphPasses Failed";
       return status;
     }
   }
diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/tensor_name_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/tensor_name_pass.cc
index a203b3c515..566709a848 100644
--- a/mindspore/lite/tools/converter/legacy_optimizer/graph/tensor_name_pass.cc
+++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/tensor_name_pass.cc
@@ -21,54 +21,31 @@
 
 namespace mindspore::lite {
 STATUS TensorNamePass::Run(schema::MetaGraphT *graph) {
-  MS_ASSERT(graph != nullptr);
-
-  for (int i = 0; i < static_cast<int>(graph->inputIndex.size()); i++) {
-    auto tensor_id = graph->inputIndex.at(i);
-    auto &tensor = graph->allTensors.at(tensor_id);
-    tensor->name = "graph_input-" + std::to_string(i);
+  if (graph == nullptr) {
+    MS_LOG(ERROR) << "graph is nullptr";
+    return RET_NULL_PTR;
   }
 
   for (auto &node : graph->nodes) {
     if (node == nullptr || node->primitive == nullptr) {
-      MS_LOG(ERROR) << " node or node->primitive is nullptr";
-      return RET_ERROR;
+      MS_LOG(ERROR) << "node or node->primitive is nullptr";
+      return RET_NULL_PTR;
     }
 
-    for (int i = 0; i < static_cast<int>(node->outputIndex.size()); i++) {
-      auto tensor_id = node->outputIndex.at(i);
+    for (int i = 0; i < static_cast<int>(node->inputIndex.size()); i++) {
+      auto tensor_id = node->inputIndex.at(i);
       auto &tensor = graph->allTensors.at(tensor_id);
       if (tensor->name.empty()) {
-        tensor->name = node->name + "/output-" + std::to_string(i);
+        MS_LOG(WARNING) << "input tensor (id = " << tensor_id << ") name is null";
+        tensor->name = node->name + "/input-" + std::to_string(i);
       }
     }
 
-    auto type = node->primitive->value.type;
-    if (type == PrimitiveType_Conv2D || type == PrimitiveType_DeConv2D || type == PrimitiveType_DepthwiseConv2D ||
-        type == PrimitiveType_DeDepthwiseConv2D || type == PrimitiveType_FullConnection) {
-      auto input_size = node->inputIndex.size();
-      if (input_size > 1) {
-        auto weight_tensor_id = node->inputIndex.at(1);
-        auto &weight_tensor = graph->allTensors.at(weight_tensor_id);
-        if (weight_tensor->name.empty()) {
-          weight_tensor->name = node->name + "/weight";
-        }
-
-        if (input_size > 2) {
-          auto bias_tensor_id = node->inputIndex.at(2);
-          auto &bias_tensor = graph->allTensors.at(bias_tensor_id);
-          if (bias_tensor->name.empty()) {
-            bias_tensor->name = node->name + "/bias";
-          }
-        }
-      }
-    } else {
-      for (int i = 0; i < static_cast<int>(node->inputIndex.size()); i++) {
-        auto tensor_id = node->inputIndex.at(i);
-        auto &tensor = graph->allTensors.at(tensor_id);
-        if (tensor->name.empty()) {
-          tensor->name = node->name + "/input-" + std::to_string(i);
-        }
+    for (int i = 0; i < static_cast<int>(node->outputIndex.size()); i++) {
+      auto tensor_id = node->outputIndex.at(i);
+      auto &tensor = graph->allTensors.at(tensor_id);
+      if (tensor->name.empty()) {
+        tensor->name = node->name + "/output-" + std::to_string(i);
       }
     }
   }
diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc
index 3dcce9d923..8b8077d0d8 100644
--- a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc
+++ b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc
@@ -115,7 +115,8 @@ STATUS TfliteModelParser::ConvertOps() {
     std::vector<AnfNodePtr> op_inputs = {NewValueNode(std::shared_ptr<lite::PrimitiveC>(primitiveC))};
 
     // parse inputs
-    for (auto input_idx : op->inputs) {
+    for (int i = 0; i < static_cast<int>(op->inputs.size()); i++) {
+      auto input_idx = op->inputs.at(i);
       if (tflite_op_type == tflite::BuiltinOperator_FULLY_CONNECTED && input_idx == -1) {
         continue;
       }
@@ -127,9 +128,27 @@ STATUS TfliteModelParser::ConvertOps() {
         op_inputs.emplace_back(nodes_.at(input_idx));
         continue;
       }
+      // const tensor
+      std::string tensor_name;
+      if (!input_tensor->name.empty()) {
+        tensor_name = input_tensor->name;
+      } else {
+        tensor_name = op_name + "/input-" + std::to_string(op_inputs.size());
+        if (tflite_op_type == tflite::BuiltinOperator_CONV_2D ||
+            tflite_op_type == tflite::BuiltinOperator_TRANSPOSE_CONV ||
+            tflite_op_type == tflite::BuiltinOperator_DEPTHWISE_CONV_2D ||
+            tflite_op_type == tflite::BuiltinOperator_FULLY_CONNECTED) {
+          if (i == 1) {
+            tensor_name = op_name + "/weight";
+          }
+          if (i == 2) {
+            tensor_name = op_name + "/bias";
+          }
+        }
+      }
       auto parameter = func_graph_->add_parameter();
-      status = ConvertConstTensor(input_tensor.get(), parameter.get());
+      status = ConvertConstTensor(input_tensor.get(), parameter.get(), tensor_name);
       if (status != RET_OK) {
         MS_LOG(ERROR) << "convert " << op_name << " node: " << input_idx << " const node failed.";
         continue;
       }
@@ -248,11 +267,12 @@ STATUS TfliteModelParser::ConvertGraphInputs() {
     auto type_ptr = TypeIdToType(GetTfliteDataType(tensor->type));
     auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
     parameter->set_abstract(abstract_tensor);
-    parameter->set_name("graph_input_" + std::to_string(tflite_graph_input) + "_parameter");
+    parameter->set_name("graph_input-" + std::to_string(tflite_graph_input));
     nodes_.insert(std::pair(tflite_graph_input, parameter));
   }
   return RET_OK;
 }
+
 STATUS TfliteModelParser::ConvertGraphOutputs() {
   const auto &tflite_subgraph = tflite_model_->subgraphs.front();
   if (tflite_subgraph->outputs.size() > 1) {
@@ -312,7 +332,8 @@ STATUS TfliteModelParser::ConvertGraphOutputs() {
   return RET_OK;
 }
-STATUS TfliteModelParser::ConvertConstTensor(const tflite::TensorT *tensor, Parameter *parameter) {
+STATUS TfliteModelParser::ConvertConstTensor(const tflite::TensorT *tensor, Parameter *parameter,
+                                             const std::string &tensor_name) {
   if (tensor == nullptr) {
     MS_LOG(ERROR) << "tensor is null, get const tensor failed.";
     return RET_NULL_PTR;
   }
@@ -329,7 +350,7 @@ STATUS TfliteModelParser::ConvertConstTensor(const tflite::TensorT *tensor, Para
                  [](const int32_t &value) { return static_cast<int64_t>(value); });
   auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
   parameter->set_abstract(abstract_tensor);
-  parameter->set_name("const_" + std::to_string(nodes_.size()) + "_parameter");
+  parameter->set_name(tensor_name);
 
   ParamValueLitePtr param_value = std::make_shared<ParamValueLite>();
   MS_ASSERT(param_value != nullptr);
diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h
index 86e4810e4f..3c69b9aa05 100644
--- a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h
+++ b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h
@@ -42,13 +42,13 @@ class TfliteModelParser : public ModelParser {
   FuncGraphPtr func_graph_;
   char *tflite_model_buf_ = nullptr;
   std::unique_ptr<tflite::ModelT> ReadTfliteModel(const char *model_path);
-  STATUS ConvertConstTensor(const tflite::TensorT *tensor, Parameter *parameter);
+  STATUS ConvertConstTensor(const tflite::TensorT *tensor, Parameter *parameter, const std::string &tensor_name);
   STATUS ConvertOutputTensor(const tflite::OperatorT *op, const CNodePtr &dst_cnode);
   STATUS ConvertOpQuantParams(const tflite::OperatorT *op, lite::PrimitiveC *primitive_c);
   STATUS ConvertOps();
   STATUS ConvertGraphInputs();
   STATUS ConvertGraphOutputs();
-  STATUS SetTensorQuantParam(const tflite::TensorT *tflite_tensor, std::vector<schema::QuantParamT> *quant_params);
+  static STATUS SetTensorQuantParam(const tflite::TensorT *tflite_tensor, std::vector<schema::QuantParamT> *quant_params);
 };
 }  // namespace mindspore::lite
 #endif  // LITE_TFLITE_MODEL_PARSER_H
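
Note (not part of the patch): after this change the converter settles on one tensor-naming convention. Graph inputs become "graph_input-<i>"; unnamed node inputs and outputs become "<node>/input-<i>" and "<node>/output-<i>"; and for conv-like and fully-connected ops the second and third inputs become "<node>/weight" and "<node>/bias". The sketch below is a minimal stand-alone illustration of that convention only; the struct and helper names are hypothetical and are not MindSpore APIs.

#include <iostream>
#include <string>

// Hypothetical stand-in for a converter node; only the fields needed to
// demonstrate the naming rules are included.
struct NodeInfo {
  std::string name;
  bool is_conv_or_fc = false;  // Conv2D / DepthwiseConv2D / DeConv2D / FullConnection
};

// Name an unnamed tensor feeding input slot i of a node.
std::string NameForInput(const NodeInfo &node, int i) {
  if (node.is_conv_or_fc && i == 1) return node.name + "/weight";
  if (node.is_conv_or_fc && i == 2) return node.name + "/bias";
  return node.name + "/input-" + std::to_string(i);
}

// Name an unnamed tensor produced at output slot i of a node.
std::string NameForOutput(const NodeInfo &node, int i) {
  return node.name + "/output-" + std::to_string(i);
}

int main() {
  NodeInfo conv{"conv1", true};
  std::cout << NameForInput(conv, 1) << "\n";   // conv1/weight
  std::cout << NameForInput(conv, 2) << "\n";   // conv1/bias
  std::cout << NameForOutput(conv, 0) << "\n";  // conv1/output-0
  std::cout << "graph_input-" << 0 << "\n";     // graph inputs are named by index
  return 0;
}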