diff --git a/ge/graph/passes/cond_pass.cc b/ge/graph/passes/cond_pass.cc index 372af921..06a209ed 100644 --- a/ge/graph/passes/cond_pass.cc +++ b/ge/graph/passes/cond_pass.cc @@ -26,9 +26,9 @@ namespace { namespace ge { Status CondPass::Run(NodePtr &node) { ComputeGraphPtr graph = nullptr; - OutDataAnchorPtr cond_out_anchor = nullptr; + OutDataAnchorPtr peer_out_anchor = nullptr; InDataAnchorPtr cond_in_anchor = nullptr; - Status ret = GetCondInfo(node, graph, cond_out_anchor, cond_in_anchor); + Status ret = GetCondInfo(node, graph, peer_out_anchor, cond_in_anchor); if (ret == NOT_CHANGED) { return SUCCESS; } else if (ret != SUCCESS) { @@ -48,18 +48,18 @@ Status CondPass::Run(NodePtr &node) { if (cond_tensor.MutableShape().GetDim(0) == UNKNOWN_DIM_NUM) { GELOGI("Output tensor rank of Cond is unknown."); if (cond_tensor.GetDataType() == DT_STRING) { - GE_CHK_STATUS_RET(HandleStringCond(graph, cond_out_anchor, cond_in_anchor), "HandleStringCond for %s failed.", + GE_CHK_STATUS_RET(HandleStringCond(graph, peer_out_anchor, cond_in_anchor), "HandleStringCond for %s failed.", op_desc->GetName().c_str()) } return SUCCESS; } if (!cond_tensor.GetShape().IsScalar()) { - GE_CHK_STATUS_RET(HandleNonScalarCond(graph, cond_out_anchor, cond_in_anchor), "HandleNonScalarCond for %s failed.", + GE_CHK_STATUS_RET(HandleNonScalarCond(graph, peer_out_anchor, cond_in_anchor), "HandleNonScalarCond for %s failed.", op_desc->GetName().c_str()) } else { switch (cond_tensor.GetDataType()) { case DT_STRING: - GE_CHK_STATUS_RET(HandleStringCond(graph, cond_out_anchor, cond_in_anchor), "HandleStringCond for %s failed.", + GE_CHK_STATUS_RET(HandleStringCond(graph, peer_out_anchor, cond_in_anchor), "HandleStringCond for %s failed.", op_desc->GetName().c_str()) break; case DT_BOOL: @@ -69,7 +69,7 @@ Status CondPass::Run(NodePtr &node) { case DT_INT16: case DT_INT8: case DT_INT64: - GE_CHK_STATUS_RET(HandleScalarCond(graph, cond_out_anchor, cond_in_anchor, cond_tensor.GetDataType()), + GE_CHK_STATUS_RET(HandleScalarCond(graph, peer_out_anchor, cond_in_anchor, cond_tensor.GetDataType()), "HandleScalarCond for %s failed.", op_desc->GetName().c_str()) break; case DT_INT32: @@ -96,21 +96,21 @@ Status CondPass::Run(NodePtr &node) { /// @brief Get cond info for if / while /// @param [in] node: If / While op /// @param [out] graph: owner_graph of if node / while_cond subgraph -/// @param [out] cond_out_anchor: peer_cond_anchor +/// @param [out] peer_out_anchor: peer_cond_anchor /// @param [out] cond_in_anchor: cond_input /// @return Status /// -Status CondPass::GetCondInfo(const NodePtr &node, ComputeGraphPtr &graph, OutDataAnchorPtr &cond_out_anchor, +Status CondPass::GetCondInfo(const NodePtr &node, ComputeGraphPtr &graph, OutDataAnchorPtr &peer_out_anchor, InDataAnchorPtr &cond_in_anchor) { GE_CHECK_NOTNULL(node); std::string type = node->GetType(); if (kIfOpTypes.count(type) != 0) { - if (GetCondInfoForIf(node, graph, cond_out_anchor, cond_in_anchor) != SUCCESS) { + if (GetCondInfoForIf(node, graph, peer_out_anchor, cond_in_anchor) != SUCCESS) { GELOGE(FAILED, "Get cond_info for if node failed."); return FAILED; } } else if (kWhileOpTypes.count(type) != 0) { - if (GetCondInfoForWhile(node, graph, cond_out_anchor, cond_in_anchor) != SUCCESS) { + if (GetCondInfoForWhile(node, graph, peer_out_anchor, cond_in_anchor) != SUCCESS) { GELOGE(FAILED, "Get cond_info for while node failed."); return FAILED; } @@ -126,19 +126,19 @@ Status CondPass::GetCondInfo(const NodePtr &node, ComputeGraphPtr &graph, OutDat /// @brief Get cond info for if node /// @param [in] node: If op /// @param [out] graph: owner_graph of if node -/// @param [out] cond_out_anchor: peer_cond_anchor +/// @param [out] peer_out_anchor: peer_cond_anchor /// @param [out] cond_in_anchor: cond_input of if /// @return Status /// -Status CondPass::GetCondInfoForIf(const NodePtr &node, ComputeGraphPtr &graph, OutDataAnchorPtr &cond_out_anchor, +Status CondPass::GetCondInfoForIf(const NodePtr &node, ComputeGraphPtr &graph, OutDataAnchorPtr &peer_out_anchor, InDataAnchorPtr &cond_in_anchor) { GE_CHECK_NOTNULL(node); graph = node->GetOwnerComputeGraph(); GE_CHECK_NOTNULL(graph); cond_in_anchor = node->GetInDataAnchor(IF_COND_INPUT); GE_CHECK_NOTNULL(cond_in_anchor); - cond_out_anchor = cond_in_anchor->GetPeerOutAnchor(); - GE_CHECK_NOTNULL(cond_out_anchor); + peer_out_anchor = cond_in_anchor->GetPeerOutAnchor(); + GE_CHECK_NOTNULL(peer_out_anchor); return SUCCESS; } @@ -146,11 +146,11 @@ Status CondPass::GetCondInfoForIf(const NodePtr &node, ComputeGraphPtr &graph, O /// @brief Get cond info for while node /// @param [in] node: While op /// @param [out] graph: while_cond subgraph -/// @param [out] cond_out_anchor: peer_cond_anchor +/// @param [out] peer_out_anchor: peer_cond_anchor /// @param [out] cond_in_anchor: input of NetOutput in cond_graph /// @return Status /// -Status CondPass::GetCondInfoForWhile(const NodePtr &node, ComputeGraphPtr &graph, OutDataAnchorPtr &cond_out_anchor, +Status CondPass::GetCondInfoForWhile(const NodePtr &node, ComputeGraphPtr &graph, OutDataAnchorPtr &peer_out_anchor, InDataAnchorPtr &cond_in_anchor) { GE_CHECK_NOTNULL(node); OpDescPtr op_desc = node->GetOpDesc(); @@ -177,8 +177,8 @@ Status CondPass::GetCondInfoForWhile(const NodePtr &node, ComputeGraphPtr &graph cond_in_anchor = net_output_node->GetInDataAnchor(0); GE_CHECK_NOTNULL(cond_in_anchor); - cond_out_anchor = cond_in_anchor->GetPeerOutAnchor(); - GE_CHECK_NOTNULL(cond_out_anchor); + peer_out_anchor = cond_in_anchor->GetPeerOutAnchor(); + GE_CHECK_NOTNULL(peer_out_anchor); return SUCCESS; } @@ -186,56 +186,56 @@ Status CondPass::GetCondInfoForWhile(const NodePtr &node, ComputeGraphPtr &graph /// /// @brief Process Cond Op with non-scalar cond_input: cond->Size->If / NetOutput(while) /// @param [in] graph -/// @param [in] out_anchor: peer_cond_anchor -/// @param [in] in_anchor: cond_input +/// @param [in] peer_out_anchor: peer_cond_anchor +/// @param [in] cond_in_anchor: cond_input /// @return Status /// -Status CondPass::HandleNonScalarCond(const ComputeGraphPtr &graph, const OutDataAnchorPtr &out_anchor, - const InDataAnchorPtr &in_anchor) { +Status CondPass::HandleNonScalarCond(const ComputeGraphPtr &graph, const OutDataAnchorPtr &peer_out_anchor, + const InDataAnchorPtr &cond_in_anchor) { GELOGI("Handle cond with non-scalar cond-input."); - return InsertNode(graph, out_anchor, in_anchor, SIZE); + return InsertNode(graph, peer_out_anchor, cond_in_anchor, SIZE); } /// /// @brief Process Cond Op with scalar-string cond_input: cond->StringLength(int32)->If / NetOutput(while) /// @param [in] graph -/// @param [in] out_anchor: peer_cond_anchor -/// @param [in] in_anchor: cond_input +/// @param [in] peer_out_anchor: peer_cond_anchor +/// @param [in] cond_in_anchor: cond_input /// @return Status /// -Status CondPass::HandleStringCond(const ComputeGraphPtr &graph, const OutDataAnchorPtr &out_anchor, - const InDataAnchorPtr &in_anchor) { +Status CondPass::HandleStringCond(const ComputeGraphPtr &graph, const OutDataAnchorPtr &peer_out_anchor, + const InDataAnchorPtr &cond_in_anchor) { GELOGI("Handle cond with scalar-string cond-input."); - return InsertNode(graph, out_anchor, in_anchor, kStringLength); + return InsertNode(graph, peer_out_anchor, cond_in_anchor, kStringLength); } /// /// @brief Process Cond Op with scalar cond_input: cond->Cast(2int32)->If / NetOutput(while) /// @param [in] graph -/// @param [in] out_anchor: peer_cond_anchor -/// @param [in] in_anchor: cond_input +/// @param [in] peer_out_anchor: peer_cond_anchor +/// @param [in] cond_in_anchor: cond_input /// @param [in] src_type /// @return Status /// -Status CondPass::HandleScalarCond(const ComputeGraphPtr &graph, const OutDataAnchorPtr &out_anchor, - const InDataAnchorPtr &in_anchor, DataType src_type) { - GE_CHECK_NOTNULL(in_anchor); - GE_CHECK_NOTNULL(out_anchor); - GE_CHECK_NOTNULL(out_anchor->GetOwnerNode()->GetOpDesc()); +Status CondPass::HandleScalarCond(const ComputeGraphPtr &graph, const OutDataAnchorPtr &peer_out_anchor, + const InDataAnchorPtr &cond_in_anchor, DataType src_type) { + GE_CHECK_NOTNULL(cond_in_anchor); + GE_CHECK_NOTNULL(peer_out_anchor); + GE_CHECK_NOTNULL(peer_out_anchor->GetOwnerNode()->GetOpDesc()); GELOGI("Handle cond with scalar cond-input."); - GeTensorDesc tensor = out_anchor->GetOwnerNode()->GetOpDesc()->GetOutputDesc(out_anchor->GetIdx()); - std::string cast_name = in_anchor->GetOwnerNode()->GetName() + "_Cast"; + GeTensorDesc tensor = peer_out_anchor->GetOwnerNode()->GetOpDesc()->GetOutputDesc(peer_out_anchor->GetIdx()); + std::string cast_name = cond_in_anchor->GetOwnerNode()->GetName() + "_Cast"; NodePtr cast_node = AddCastNode(graph, cast_name, tensor, src_type, DT_INT32); if (cast_node == nullptr) { GELOGE(FAILED, "Add Cast node failed, name:%s.", cast_name.c_str()); return FAILED; } - if (GraphUtils::InsertNodeAfter(out_anchor, { in_anchor }, cast_node) != GRAPH_SUCCESS) { + if (GraphUtils::InsertNodeAfter(peer_out_anchor, { cond_in_anchor }, cast_node) != GRAPH_SUCCESS) { GELOGE(FAILED, "Insert Cast node %s between %s->%s failed.", - cast_node->GetName().c_str(), out_anchor->GetOwnerNode()->GetName().c_str(), - in_anchor->GetOwnerNode()->GetName().c_str()); + cast_node->GetName().c_str(), peer_out_anchor->GetOwnerNode()->GetName().c_str(), + cond_in_anchor->GetOwnerNode()->GetName().c_str()); return FAILED; } @@ -245,27 +245,27 @@ Status CondPass::HandleScalarCond(const ComputeGraphPtr &graph, const OutDataAnc /// /// @brief Insert node /// @param [in] graph -/// @param [in] out_anchor -/// @param [in] in_anchor +/// @param [in] peer_out_anchor +/// @param [in] in_data_anchor /// @param [in] type /// @return Status /// -Status CondPass::InsertNode(const ComputeGraphPtr &graph, const OutDataAnchorPtr &out_anchor, - const InDataAnchorPtr &in_anchor, const std::string &type) { - GE_CHECK_NOTNULL(out_anchor); - GE_CHECK_NOTNULL(in_anchor); +Status CondPass::InsertNode(const ComputeGraphPtr &graph, const OutDataAnchorPtr &peer_out_anchor, + const InDataAnchorPtr &in_data_anchor, const std::string &type) { + GE_CHECK_NOTNULL(peer_out_anchor); + GE_CHECK_NOTNULL(in_data_anchor); GELOGD("Begin to insert %s node.", type.c_str()); - GE_CHECK_NOTNULL(out_anchor->GetOwnerNode()->GetOpDesc()); - GE_CHECK_NOTNULL(in_anchor->GetOwnerNode()->GetOpDesc()); - GeTensorDesc in_tensor = out_anchor->GetOwnerNode()->GetOpDesc()->GetOutputDesc(out_anchor->GetIdx()); - GeTensorDesc out_tensor = in_anchor->GetOwnerNode()->GetOpDesc()->GetInputDesc(out_anchor->GetIdx()); + GE_CHECK_NOTNULL(peer_out_anchor->GetOwnerNode()->GetOpDesc()); + GE_CHECK_NOTNULL(in_data_anchor->GetOwnerNode()->GetOpDesc()); + GeTensorDesc in_tensor = peer_out_anchor->GetOwnerNode()->GetOpDesc()->GetOutputDesc(peer_out_anchor->GetIdx()); + GeTensorDesc out_tensor = in_data_anchor->GetOwnerNode()->GetOpDesc()->GetInputDesc(in_data_anchor->GetIdx()); out_tensor.SetDataType(DT_INT32); out_tensor.SetOriginDataType(DT_INT32); out_tensor.SetShape(in_tensor.GetShape()); out_tensor.SetOriginShape(in_tensor.GetOriginShape()); - OpDescBuilder op_desc_builder(in_anchor->GetOwnerNode()->GetName() + "_" + type, type); + OpDescBuilder op_desc_builder(in_data_anchor->GetOwnerNode()->GetName() + "_" + type, type); OpDescPtr op_desc = op_desc_builder.AddInput("x", in_tensor).AddOutput("y", out_tensor).Build(); if (op_desc == nullptr) { GELOGE(FAILED, "Create op_desc failed."); @@ -278,10 +278,10 @@ Status CondPass::InsertNode(const ComputeGraphPtr &graph, const OutDataAnchorPtr } AddRePassNode(new_node); - if (GraphUtils::InsertNodeAfter(out_anchor, { in_anchor }, new_node) != GRAPH_SUCCESS) { + if (GraphUtils::InsertNodeAfter(peer_out_anchor, { in_data_anchor }, new_node) != GRAPH_SUCCESS) { GELOGE(FAILED, "Insert %s node %s between %s->%s failed.", type.c_str(), - new_node->GetName().c_str(), out_anchor->GetOwnerNode()->GetName().c_str(), - in_anchor->GetOwnerNode()->GetName().c_str()); + new_node->GetName().c_str(), peer_out_anchor->GetOwnerNode()->GetName().c_str(), + in_data_anchor->GetOwnerNode()->GetName().c_str()); return FAILED; } diff --git a/ge/graph/passes/cond_pass.h b/ge/graph/passes/cond_pass.h index 5c0c83bc..eeb3d5e5 100644 --- a/ge/graph/passes/cond_pass.h +++ b/ge/graph/passes/cond_pass.h @@ -28,76 +28,76 @@ class CondPass : public BaseNodePass { /// @brief Get cond info for if / while /// @param [in] node: If / While op /// @param [out] graph: owner_graph of if node / while_cond subgraph - /// @param [out] cond_out_anchor: peer_cond_anchor + /// @param [out] peer_out_anchor: peer_cond_anchor /// @param [out] cond_in_anchor: cond_input /// @return Status /// - static Status GetCondInfo(const NodePtr &node, ComputeGraphPtr &graph, OutDataAnchorPtr &cond_out_anchor, - InDataAnchorPtr &cond_in_anchor); + static Status GetCondInfo(const NodePtr &node, ComputeGraphPtr &graph, OutDataAnchorPtr &peer_out_anchor, + InDataAnchorPtr &cond_in_anchor); /// /// @brief Get cond info for if node /// @param [in] node: If op /// @param [out] graph: owner_graph of if node - /// @param [out] cond_out_anchor: peer_cond_anchor + /// @param [out] peer_out_anchor: peer_cond_anchor /// @param [out] cond_in_anchor: cond_input of if /// @return Status /// - static Status GetCondInfoForIf(const NodePtr &node, ComputeGraphPtr &graph, OutDataAnchorPtr &cond_out_anchor, - InDataAnchorPtr &cond_in_anchor); + static Status GetCondInfoForIf(const NodePtr &node, ComputeGraphPtr &graph, OutDataAnchorPtr &peer_out_anchor, + InDataAnchorPtr &cond_in_anchor); /// /// @brief Get cond info for while node /// @param [in] node: While op /// @param [out] graph: while_cond subgraph - /// @param [out] cond_out_anchor: peer_cond_anchor + /// @param [out] peer_out_anchor: peer_cond_anchor /// @param [out] cond_in_anchor: input of NetOutput in cond_graph /// @return Status /// - static Status GetCondInfoForWhile(const NodePtr &node, ComputeGraphPtr &graph, OutDataAnchorPtr &cond_out_anchor, - InDataAnchorPtr &cond_in_anchor); + static Status GetCondInfoForWhile(const NodePtr &node, ComputeGraphPtr &graph, OutDataAnchorPtr &peer_out_anchor, + InDataAnchorPtr &cond_in_anchor); /// /// @brief Process Cond Op with non-scalar cond_input /// @param [in] graph - /// @param [in] out_anchor: peer_cond_anchor - /// @param [in] in_anchor: cond_input + /// @param [in] peer_out_anchor: peer_cond_anchor + /// @param [in] cond_in_anchor: cond_input /// @return Status /// - Status HandleNonScalarCond(const ComputeGraphPtr &graph, const OutDataAnchorPtr &out_anchor, - const InDataAnchorPtr &in_anchor); + Status HandleNonScalarCond(const ComputeGraphPtr &graph, const OutDataAnchorPtr &peer_out_anchor, + const InDataAnchorPtr &cond_in_anchor); /// /// @brief Process Cond Op with scalar-string cond_input /// @param [in] graph - /// @param [in] out_anchor: peer_cond_anchor - /// @param [in] in_anchor: cond_input + /// @param [in] peer_out_anchor: peer_cond_anchor + /// @param [in] cond_in_anchor: cond_input /// @return Status /// - Status HandleStringCond(const ComputeGraphPtr &graph, const OutDataAnchorPtr &out_anchor, - const InDataAnchorPtr &in_anchor); + Status HandleStringCond(const ComputeGraphPtr &graph, const OutDataAnchorPtr &peer_out_anchor, + const InDataAnchorPtr &cond_in_anchor); /// /// @brief Process Cond Op with scalar cond_input /// @param [in] graph - /// @param [in] out_anchor: peer_cond_anchor - /// @param [in] in_anchor: cond_input + /// @param [in] peer_out_anchor: peer_cond_anchor + /// @param [in] cond_in_anchor: cond_input /// @param [in] src_type /// @return Status /// - Status HandleScalarCond(const ComputeGraphPtr &graph, const OutDataAnchorPtr &out_anchor, - const InDataAnchorPtr &in_anchor, DataType src_type); + Status HandleScalarCond(const ComputeGraphPtr &graph, const OutDataAnchorPtr &peer_out_anchor, + const InDataAnchorPtr &cond_in_anchor, DataType src_type); /// /// @brief Insert node /// @param [in] graph - /// @param [in] out_anchor - /// @param [in] in_anchor + /// @param [in] peer_out_anchor + /// @param [in] in_data_anchor /// @param [in] type /// @return Status /// - Status InsertNode(const ComputeGraphPtr &graph, const OutDataAnchorPtr &out_anchor, - const InDataAnchorPtr &in_anchor, const std::string &type); + Status InsertNode(const ComputeGraphPtr &graph, const OutDataAnchorPtr &peer_out_anchor, + const InDataAnchorPtr &in_data_anchor, const std::string &type); /// /// @brief Add cast node diff --git a/metadef b/metadef index 85ed8691..8c5be2db 160000 --- a/metadef +++ b/metadef @@ -1 +1 @@ -Subproject commit 85ed8691aab1f0c7d3b45785129e9063c84993ed +Subproject commit 8c5be2db907c26fcd9e7ffdd27f631302140bc2a diff --git a/parser b/parser index b45f4e83..81eb1792 160000 --- a/parser +++ b/parser @@ -1 +1 @@ -Subproject commit b45f4e83d6a10bc22e15123a13fc8544c29f8c5d +Subproject commit 81eb1792fbbca4b569b391fd31d7dd7281e3c228