modify tensor name

pull/9589/head
yvette 4 years ago
parent 49172089f1
commit 35c6f9325e

@@ -322,11 +322,6 @@ void LiteSession::InitGraphInOutTensors(const lite::Model *model) {
}
int LiteSession::CompileGraph(Model *model) {
if (!ModelVerify(*model)) {
MS_LOG(ERROR) << "wrong model input, please check";
return RET_ERROR;
}
bool expected = false;
if (!is_running_.compare_exchange_strong(expected, true)) {
MS_LOG(ERROR) << "Not support multi-threading";
@@ -343,6 +338,11 @@ int LiteSession::CompileGraph(Model *model) {
is_running_.store(false);
return RET_PARAM_INVALID;
}
if (!ModelVerify(*model)) {
MS_LOG(ERROR) << "wrong model input, please check";
is_running_.store(false);
return RET_ERROR;
}
auto ret = ConvertTensors(model);
if (ret != RET_OK) {
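Note on the two hunks above: ModelVerify now runs after the is_running_ guard is claimed, so its failure path can release the flag like every other early return. A minimal standalone sketch of that guard pattern, with an illustrative class and return codes rather than the MindSpore sources:

#include <atomic>
#include <iostream>

class SessionSketch {
 public:
  int Compile(bool model_ok) {
    bool expected = false;
    // Claim the running flag; refuse concurrent compilation.
    if (!is_running_.compare_exchange_strong(expected, true)) {
      std::cerr << "Not support multi-threading\n";
      return -1;
    }
    if (!model_ok) {
      std::cerr << "wrong model input, please check\n";
      is_running_.store(false);  // every early return taken after the claim releases the flag
      return -1;
    }
    // ... compilation work would go here ...
    is_running_.store(false);
    return 0;
  }

 private:
  std::atomic<bool> is_running_{false};
};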

@@ -44,7 +44,11 @@ int AnfImporterFromMetaGraphT::ConverterConstTensor() {
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
MS_ASSERT(nullptr != abstract_tensor);
parameter->set_abstract(abstract_tensor);
parameter->set_name("const_" + std::to_string(i) + "_parameter");
if (!tensor->name.empty()) {
parameter->set_name(tensor->name);
} else {
parameter->set_name("const-" + std::to_string(i));
}
ParamValueLitePtr param_value = std::make_shared<ParamValueLite>();
MS_ASSERT(nullptr != param_value);
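The hunk above prefers the name already recorded on the tensor and only falls back to a synthesized one. As a standalone helper this is roughly the following (illustrative function, not part of the converter's API):

#include <string>

std::string ConstParameterName(const std::string &tensor_name, int index) {
  // Keep the original tensor name when present, otherwise synthesize "const-<i>".
  return tensor_name.empty() ? "const-" + std::to_string(index) : tensor_name;
}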

@@ -74,6 +74,7 @@ class TensorCache {
} else {
tensor->nodeType = schema::NodeType_Parameter;
}
tensor->name = name;
tensors.push_back(tensor);
if (Category == GRAPH_INPUT) {

@@ -180,24 +180,25 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
}
}
// topological sorting
// tensor name
{
Optimizer topologicalOptimizer;
topologicalOptimizer.AddPass(new (std::nothrow) TopologicalSortPass());
status = topologicalOptimizer.Run(graphDefT);
Optimizer nameOptimizer;
nameOptimizer.AddPass(new (std::nothrow) TopologicalSortPass());
nameOptimizer.AddPass(new (std::nothrow) TensorNamePass());
status = nameOptimizer.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run topologicalOptimizer graphPasses Failed";
MS_LOG(ERROR) << "Run nameOptimizer graphPasses Failed";
return status;
}
}
// tensor name
// topological sorting
{
Optimizer nameOptimizer;
nameOptimizer.AddPass(new (std::nothrow) TensorNamePass());
status = nameOptimizer.Run(graphDefT);
Optimizer topologicalOptimizer;
topologicalOptimizer.AddPass(new (std::nothrow) TopologicalSortPass());
status = topologicalOptimizer.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run nameOptimizer graphPasses Failed";
MS_LOG(ERROR) << "Run topologicalOptimizer graphPasses Failed";
return status;
}
}
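For readers outside the converter: both blocks use the same Optimizer/Pass pattern, and the hunk only changes which passes the first block runs (TopologicalSortPass followed by TensorNamePass) before the later topological sort block. A simplified sketch of that pattern, using hypothetical stand-in types instead of schema::MetaGraphT and the real passes:

#include <memory>
#include <vector>

struct GraphT {};  // stand-in for schema::MetaGraphT

class PassSketch {
 public:
  virtual ~PassSketch() = default;
  virtual int Run(GraphT *graph) = 0;  // 0 == RET_OK in this sketch
};

class OptimizerSketch {
 public:
  void AddPass(PassSketch *pass) { passes_.emplace_back(pass); }
  int Run(GraphT *graph) {
    // Passes run in insertion order; the first failure aborts the pipeline.
    for (auto &pass : passes_) {
      int status = pass->Run(graph);
      if (status != 0) {
        return status;
      }
    }
    return 0;
  }

 private:
  std::vector<std::unique_ptr<PassSketch>> passes_;  // owns the added passes in this sketch
};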

@@ -21,54 +21,31 @@
namespace mindspore::lite {
STATUS TensorNamePass::Run(schema::MetaGraphT *graph) {
MS_ASSERT(graph != nullptr);
for (int i = 0; i < static_cast<int>(graph->inputIndex.size()); i++) {
auto tensor_id = graph->inputIndex.at(i);
auto &tensor = graph->allTensors.at(tensor_id);
tensor->name = "graph_input-" + std::to_string(i);
if (graph == nullptr) {
MS_LOG(ERROR) << "graph is nullptr";
return RET_NULL_PTR;
}
for (auto &node : graph->nodes) {
if (node == nullptr || node->primitive == nullptr) {
MS_LOG(ERROR) << "node or node->primitive is nullptr";
return RET_ERROR;
return RET_NULL_PTR;
}
for (int i = 0; i < static_cast<int>(node->outputIndex.size()); i++) {
auto tensor_id = node->outputIndex.at(i);
for (int i = 0; i < static_cast<int>(node->inputIndex.size()); i++) {
auto tensor_id = node->inputIndex.at(i);
auto &tensor = graph->allTensors.at(tensor_id);
if (tensor->name.empty()) {
tensor->name = node->name + "/output-" + std::to_string(i);
}
MS_LOG(WARNING) << "input tensor (id = " << tensor_id << ") name is null";
tensor->name = node->name + "/input-" + std::to_string(i);
}
auto type = node->primitive->value.type;
if (type == PrimitiveType_Conv2D || type == PrimitiveType_DeConv2D || type == PrimitiveType_DepthwiseConv2D ||
type == PrimitiveType_DeDepthwiseConv2D || type == PrimitiveType_FullConnection) {
auto input_size = node->inputIndex.size();
if (input_size > 1) {
auto weight_tensor_id = node->inputIndex.at(1);
auto &weight_tensor = graph->allTensors.at(weight_tensor_id);
if (weight_tensor->name.empty()) {
weight_tensor->name = node->name + "/weight";
}
if (input_size > 2) {
auto bias_tensor_id = node->inputIndex.at(2);
auto &bias_tensor = graph->allTensors.at(bias_tensor_id);
if (bias_tensor->name.empty()) {
bias_tensor->name = node->name + "/bias";
}
}
}
} else {
for (int i = 0; i < static_cast<int>(node->inputIndex.size()); i++) {
auto tensor_id = node->inputIndex.at(i);
for (int i = 0; i < static_cast<int>(node->outputIndex.size()); i++) {
auto tensor_id = node->outputIndex.at(i);
auto &tensor = graph->allTensors.at(tensor_id);
if (tensor->name.empty()) {
tensor->name = node->name + "/input-" + std::to_string(i);
}
tensor->name = node->name + "/output-" + std::to_string(i);
}
}
}
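The rewritten pass boils down to one naming convention for tensors that arrive without a name. A standalone illustration of that convention (the helper name is ours, not part of the pass):

#include <string>

// Unnamed node inputs become "<node>/input-<i>", outputs "<node>/output-<i>".
std::string DefaultTensorName(const std::string &node_name, bool is_input, int index) {
  return node_name + (is_input ? "/input-" : "/output-") + std::to_string(index);
}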

@@ -115,7 +115,8 @@ STATUS TfliteModelParser::ConvertOps() {
std::vector<AnfNodePtr> op_inputs = {NewValueNode(std::shared_ptr<lite::PrimitiveC>(primitiveC))};
// parse inputs
for (auto input_idx : op->inputs) {
for (int i = 0; i < static_cast<int>(op->inputs.size()); i++) {
auto input_idx = op->inputs.at(i);
if (tflite_op_type == tflite::BuiltinOperator_FULLY_CONNECTED && input_idx == -1) {
continue;
}
@@ -127,9 +128,27 @@ STATUS TfliteModelParser::ConvertOps() {
op_inputs.emplace_back(nodes_.at(input_idx));
continue;
}
// const tensor
std::string tensor_name;
if (!input_tensor->name.empty()) {
tensor_name = input_tensor->name;
} else {
tensor_name = op_name + "/input-" + std::to_string(op_inputs.size());
if (tflite_op_type == tflite::BuiltinOperator_CONV_2D ||
tflite_op_type == tflite::BuiltinOperator_TRANSPOSE_CONV ||
tflite_op_type == tflite::BuiltinOperator_DEPTHWISE_CONV_2D ||
tflite_op_type == tflite::BuiltinOperator_FULLY_CONNECTED) {
if (i == 1) {
tensor_name = op_name + "/weight";
}
if (i == 2) {
tensor_name = op_name + "/bias";
}
}
}
auto parameter = func_graph_->add_parameter();
status = ConvertConstTensor(input_tensor.get(), parameter.get());
status = ConvertConstTensor(input_tensor.get(), parameter.get(), tensor_name);
if (status != RET_OK) {
MS_LOG(ERROR) << "convert " << op_name << " node: " << input_idx << " const node failed.";
continue;
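The new naming logic for const inputs can be summarized by the helper below; the signature and the is_conv_like flag are illustrative, since the real code inlines this inside ConvertOps and checks the tflite builtin operator directly:

#include <string>

std::string ConstInputName(const std::string &tflite_name, const std::string &op_name,
                           bool is_conv_like, int input_index, size_t inputs_so_far) {
  if (!tflite_name.empty()) {
    return tflite_name;  // keep the name exported in the tflite model
  }
  std::string name = op_name + "/input-" + std::to_string(inputs_so_far);
  if (is_conv_like && input_index == 1) {
    name = op_name + "/weight";  // second input of conv/deconv/depthwise/fully-connected
  } else if (is_conv_like && input_index == 2) {
    name = op_name + "/bias";    // third input
  }
  return name;
}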
@@ -248,11 +267,12 @@ STATUS TfliteModelParser::ConvertGraphInputs() {
auto type_ptr = TypeIdToType(GetTfliteDataType(tensor->type));
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
parameter->set_abstract(abstract_tensor);
parameter->set_name("graph_input_" + std::to_string(tflite_graph_input) + "_parameter");
parameter->set_name("graph_input-" + std::to_string(tflite_graph_input));
nodes_.insert(std::pair(tflite_graph_input, parameter));
}
return RET_OK;
}
STATUS TfliteModelParser::ConvertGraphOutputs() {
const auto &tflite_subgraph = tflite_model_->subgraphs.front();
if (tflite_subgraph->outputs.size() > 1) {
@@ -312,7 +332,8 @@ STATUS TfliteModelParser::ConvertGraphOutputs() {
return RET_OK;
}
STATUS TfliteModelParser::ConvertConstTensor(const tflite::TensorT *tensor, Parameter *parameter) {
STATUS TfliteModelParser::ConvertConstTensor(const tflite::TensorT *tensor, Parameter *parameter,
const std::string &tensor_name) {
if (tensor == nullptr) {
MS_LOG(ERROR) << "tensor is null, get const tensor failed.";
return RET_NULL_PTR;
@@ -329,7 +350,7 @@ STATUS TfliteModelParser::ConvertConstTensor(const tflite::TensorT *tensor, Para
[](const int32_t &value) { return static_cast<int64_t>(value); });
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
parameter->set_abstract(abstract_tensor);
parameter->set_name("const_" + std::to_string(nodes_.size()) + "_parameter");
parameter->set_name(tensor_name);
ParamValueLitePtr param_value = std::make_shared<ParamValueLite>();
MS_ASSERT(param_value != nullptr);

@@ -42,13 +42,13 @@ class TfliteModelParser : public ModelParser {
FuncGraphPtr func_graph_;
char *tflite_model_buf_ = nullptr;
std::unique_ptr<tflite::ModelT> ReadTfliteModel(const char *model_path);
STATUS ConvertConstTensor(const tflite::TensorT *tensor, Parameter *parameter);
STATUS ConvertConstTensor(const tflite::TensorT *tensor, Parameter *parameter, const std::string &tensor_name);
STATUS ConvertOutputTensor(const tflite::OperatorT *op, const CNodePtr &dst_cnode);
STATUS ConvertOpQuantParams(const tflite::OperatorT *op, lite::PrimitiveC *primitive_c);
STATUS ConvertOps();
STATUS ConvertGraphInputs();
STATUS ConvertGraphOutputs();
STATUS SetTensorQuantParam(const tflite::TensorT *tflite_tensor, std::vector<QuantParamT> *quant_params);
static STATUS SetTensorQuantParam(const tflite::TensorT *tflite_tensor, std::vector<QuantParamT> *quant_params);
};
} // namespace mindspore::lite
#endif // LITE_TFLITE_MODEL_PARSER_H
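The header hunk also marks SetTensorQuantParam static, which is consistent with the function not touching parser instance state. A minimal sketch of that general idea, with simplified, hypothetical types rather than the real declarations:

#include <vector>

struct QuantParamSketch {
  double scale = 1.0;
  int zero_point = 0;
};

class ParserSketch {
 public:
  // static: callable without an instance and guaranteed not to read or write member state
  static int SetQuantParams(double scale, int zero_point, std::vector<QuantParamSketch> *out) {
    if (out == nullptr) {
      return -1;
    }
    out->push_back({scale, zero_point});
    return 0;
  }
};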
