|
|
@ -306,6 +306,7 @@ STATUS TFModelParser::ConvertGraphInputsAndConsts(
|
|
|
|
FuncGraphPtr paserTfFuction() { return nullptr; }
|
|
|
|
FuncGraphPtr paserTfFuction() { return nullptr; }
|
|
|
|
FuncGraphPtr TFModelParser::Parse(const std::string &modelFile, const std::string &weightFile,
|
|
|
|
FuncGraphPtr TFModelParser::Parse(const std::string &modelFile, const std::string &weightFile,
|
|
|
|
const QuantType &quantType) {
|
|
|
|
const QuantType &quantType) {
|
|
|
|
|
|
|
|
NoSupportOp::GetInstance()->SetFmkType("TF");
|
|
|
|
auto status = ValidateFileStr(modelFile, ".pb");
|
|
|
|
auto status = ValidateFileStr(modelFile, ".pb");
|
|
|
|
if (status != RET_OK) {
|
|
|
|
if (status != RET_OK) {
|
|
|
|
MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.pb";
|
|
|
|
MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.pb";
|
|
|
@ -321,7 +322,7 @@ FuncGraphPtr TFModelParser::Parse(const std::string &modelFile, const std::strin
|
|
|
|
status = ReadProtoFromBinaryFile((const char *)modelFile.c_str(), tf_root_graph_.get());
|
|
|
|
status = ReadProtoFromBinaryFile((const char *)modelFile.c_str(), tf_root_graph_.get());
|
|
|
|
if (status != RET_OK) {
|
|
|
|
if (status != RET_OK) {
|
|
|
|
MS_LOG(ERROR) << "Open modelFile for TF converter failed!";
|
|
|
|
MS_LOG(ERROR) << "Open modelFile for TF converter failed!";
|
|
|
|
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
|
|
|
|
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
|
|
|
return nullptr;
|
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
anf_root_graph_ = std::make_shared<FuncGraph>();
|
|
|
|
anf_root_graph_ = std::make_shared<FuncGraph>();
|
|
|
@ -346,13 +347,13 @@ FuncGraphPtr TFModelParser::Parse(const std::string &modelFile, const std::strin
|
|
|
|
for (int i = 0; i < tf_root_graph_->node_size(); i++) {
|
|
|
|
for (int i = 0; i < tf_root_graph_->node_size(); i++) {
|
|
|
|
auto &node_def = tf_root_graph_->node(i);
|
|
|
|
auto &node_def = tf_root_graph_->node(i);
|
|
|
|
status = ConvertOps(node_def, tf_root_graph_nodes_, anf_root_graph_, &anf_root_node_map_);
|
|
|
|
status = ConvertOps(node_def, tf_root_graph_nodes_, anf_root_graph_, &anf_root_node_map_);
|
|
|
|
|
|
|
|
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
|
|
|
if (status != RET_OK) {
|
|
|
|
if (status != RET_OK) {
|
|
|
|
success_flag = false;
|
|
|
|
success_flag = false;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if (!success_flag) {
|
|
|
|
if (!success_flag) {
|
|
|
|
MS_LOG(ERROR) << "Convert ops failed.";
|
|
|
|
MS_LOG(ERROR) << "Convert ops failed.";
|
|
|
|
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
|
|
|
|
|
|
|
return nullptr;
|
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
status = ConvertRootGraphOutputs();
|
|
|
|
status = ConvertRootGraphOutputs();
|
|
|
@ -376,6 +377,7 @@ STATUS TFModelParser::ConvertSubgraph() {
|
|
|
|
auto subgraph_size = graph_def_liarary.function_size();
|
|
|
|
auto subgraph_size = graph_def_liarary.function_size();
|
|
|
|
std::map<CNodePtr, FuncGraphPtr> while_cond_map;
|
|
|
|
std::map<CNodePtr, FuncGraphPtr> while_cond_map;
|
|
|
|
std::map<CNodePtr, FuncGraphPtr> while_body_map;
|
|
|
|
std::map<CNodePtr, FuncGraphPtr> while_body_map;
|
|
|
|
|
|
|
|
bool success_flag = true;
|
|
|
|
for (int i = 0; i < subgraph_size; i++) {
|
|
|
|
for (int i = 0; i < subgraph_size; i++) {
|
|
|
|
auto &tf_sub_fuction = graph_def_liarary.function(i);
|
|
|
|
auto &tf_sub_fuction = graph_def_liarary.function(i);
|
|
|
|
auto &tf_sub_signature = tf_sub_fuction.signature();
|
|
|
|
auto &tf_sub_signature = tf_sub_fuction.signature();
|
|
|
@ -421,12 +423,16 @@ STATUS TFModelParser::ConvertSubgraph() {
|
|
|
|
for (int j = 0; j < tf_sub_fuction.node_def_size(); j++) {
|
|
|
|
for (int j = 0; j < tf_sub_fuction.node_def_size(); j++) {
|
|
|
|
auto &node_def = tf_sub_fuction.node_def(j);
|
|
|
|
auto &node_def = tf_sub_fuction.node_def(j);
|
|
|
|
status = ConvertOps(node_def, tf_sub_node_map, sub_func_graph, &anf_sub_node_map);
|
|
|
|
status = ConvertOps(node_def, tf_sub_node_map, sub_func_graph, &anf_sub_node_map);
|
|
|
|
|
|
|
|
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
|
|
|
if (status != RET_OK) {
|
|
|
|
if (status != RET_OK) {
|
|
|
|
MS_LOG(ERROR) << "Convert subgraph ops failed.";
|
|
|
|
MS_LOG(ERROR) << "Convert subgraph ops failed.";
|
|
|
|
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
|
|
|
success_flag = false;
|
|
|
|
return RET_ERROR;
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
if (!success_flag) {
|
|
|
|
|
|
|
|
MS_LOG(ERROR) << "Convert subgraph is failed.";
|
|
|
|
|
|
|
|
continue;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// convert subgraph outputs
|
|
|
|
// convert subgraph outputs
|
|
|
|
std::vector<AnfNodePtr> sub_output_nodes;
|
|
|
|
std::vector<AnfNodePtr> sub_output_nodes;
|
|
|
@ -483,6 +489,10 @@ STATUS TFModelParser::ConvertSubgraph() {
|
|
|
|
|
|
|
|
|
|
|
|
MS_LOG(INFO) << "parse subgraph end:" << sub_graph_name;
|
|
|
|
MS_LOG(INFO) << "parse subgraph end:" << sub_graph_name;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
if (!success_flag) {
|
|
|
|
|
|
|
|
MS_LOG(ERROR) << "Convert subgraph is failed.";
|
|
|
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
|
|
|
}
|
|
|
|
auto status = WhileNodePostProcess(while_cond_map, while_body_map);
|
|
|
|
auto status = WhileNodePostProcess(while_cond_map, while_body_map);
|
|
|
|
if (status != RET_OK) {
|
|
|
|
if (status != RET_OK) {
|
|
|
|
MS_LOG(ERROR) << "while node post process failed";
|
|
|
|
MS_LOG(ERROR) << "while node post process failed";
|
|
|
@ -593,7 +603,6 @@ STATUS TFModelParser::ConvertOps(const tensorflow::NodeDef &node_def,
|
|
|
|
std::unordered_map<std::string, AnfNodePtr> *anf_node_map) {
|
|
|
|
std::unordered_map<std::string, AnfNodePtr> *anf_node_map) {
|
|
|
|
MS_ASSERT(node_def != nullptr);
|
|
|
|
MS_ASSERT(node_def != nullptr);
|
|
|
|
MS_ASSERT(func_graph_ptr != nullptr);
|
|
|
|
MS_ASSERT(func_graph_ptr != nullptr);
|
|
|
|
NoSupportOp::GetInstance()->SetFmkType("TF");
|
|
|
|
|
|
|
|
STATUS status = RET_OK;
|
|
|
|
STATUS status = RET_OK;
|
|
|
|
const auto &op_type = node_def.op();
|
|
|
|
const auto &op_type = node_def.op();
|
|
|
|
if (op_type == "Placeholder" || op_type == "Const" || op_type == "Identity" || op_type == "StopGradient") {
|
|
|
|
if (op_type == "Placeholder" || op_type == "Const" || op_type == "Identity" || op_type == "StopGradient") {
|
|
|
@ -645,8 +654,6 @@ STATUS TFModelParser::ConvertOps(const tensorflow::NodeDef &node_def,
|
|
|
|
status = ConvertOutputTensor(node_def, anf_node, anf_node_map, func_graph_ptr, output_size);
|
|
|
|
status = ConvertOutputTensor(node_def, anf_node, anf_node_map, func_graph_ptr, output_size);
|
|
|
|
if (status != RET_OK) {
|
|
|
|
if (status != RET_OK) {
|
|
|
|
MS_LOG(ERROR) << "Convert output tensors for " << anf_node->fullname_with_scope() << " failed.";
|
|
|
|
MS_LOG(ERROR) << "Convert output tensors for " << anf_node->fullname_with_scope() << " failed.";
|
|
|
|
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
|
|
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return status;
|
|
|
|
return status;
|
|
|
|
}
|
|
|
|
}
|
|
|
|