|
|
@ -33,11 +33,11 @@ Status FlowCtrlPass::Run(ComputeGraphPtr compute_graph) {
|
|
|
|
GE_CHECK_NOTNULL(compute_graph);
|
|
|
|
GE_CHECK_NOTNULL(compute_graph);
|
|
|
|
|
|
|
|
|
|
|
|
if (!PassUtils::IsNeedTrainIteFlowCtrl(compute_graph)) {
|
|
|
|
if (!PassUtils::IsNeedTrainIteFlowCtrl(compute_graph)) {
|
|
|
|
GELOGI("No need FlowCtrl for graph %u", compute_graph->GetGraphID());
|
|
|
|
GELOGI("No need FlowCtrl for graph %u.", compute_graph->GetGraphID());
|
|
|
|
return NOT_CHANGED;
|
|
|
|
return NOT_CHANGED;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
GELOGI("FlowCtrl pass begin.graph is [%s]", compute_graph->GetName().c_str());
|
|
|
|
GELOGI("FlowCtrl pass begin.graph is [%s].", compute_graph->GetName().c_str());
|
|
|
|
bool graph_change = false;
|
|
|
|
bool graph_change = false;
|
|
|
|
// 1. Add FP/BP flow ctrl (big cycle)
|
|
|
|
// 1. Add FP/BP flow ctrl (big cycle)
|
|
|
|
for (auto &node : compute_graph->GetDirectNode()) {
|
|
|
|
for (auto &node : compute_graph->GetDirectNode()) {
|
|
|
@ -347,11 +347,11 @@ Status FlowCtrlPass::CreateIterCtrlFalseBranch(ComputeGraphPtr &compute_graph, c
|
|
|
|
NodePtr assign_node =
|
|
|
|
NodePtr assign_node =
|
|
|
|
InsertAssignOp(compute_graph, ASSIGN, NODE_NAME_FLOWCTRL_LOOP_ASSIGN, loop_cond_node, loop_reset_node);
|
|
|
|
InsertAssignOp(compute_graph, ASSIGN, NODE_NAME_FLOWCTRL_LOOP_ASSIGN, loop_cond_node, loop_reset_node);
|
|
|
|
if (assign_node == nullptr || switch_node == nullptr) {
|
|
|
|
if (assign_node == nullptr || switch_node == nullptr) {
|
|
|
|
GELOGE(PARAM_INVALID, "assign_node or switch node is null");
|
|
|
|
GELOGE(PARAM_INVALID, "assign_node or switch node is null.");
|
|
|
|
return FAILED;
|
|
|
|
return FAILED;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
GE_CHK_STATUS_RET(SetStreamLabel(assign_node, switch_node->GetName()), "set stream label failed");
|
|
|
|
GE_CHK_STATUS_RET(SetStreamLabel(assign_node, switch_node->GetName()), "set stream label failed.");
|
|
|
|
|
|
|
|
|
|
|
|
graphStatus add_ret = GraphUtils::AddEdge(switch_node->GetOutControlAnchor(), assign_node->GetInControlAnchor());
|
|
|
|
graphStatus add_ret = GraphUtils::AddEdge(switch_node->GetOutControlAnchor(), assign_node->GetInControlAnchor());
|
|
|
|
if (add_ret != GRAPH_SUCCESS) {
|
|
|
|
if (add_ret != GRAPH_SUCCESS) {
|
|
|
@ -370,7 +370,7 @@ Status FlowCtrlPass::CreateIterCtrlFalseBranch(ComputeGraphPtr &compute_graph, c
|
|
|
|
}
|
|
|
|
}
|
|
|
|
GE_CHK_STATUS_RET(SetStreamLabel(active_node, switch_node->GetName()), "set stream label failed");
|
|
|
|
GE_CHK_STATUS_RET(SetStreamLabel(active_node, switch_node->GetName()), "set stream label failed");
|
|
|
|
GE_CHK_STATUS_RET(SetSwitchBranchNodeLabel(active_node, switch_node->GetName()),
|
|
|
|
GE_CHK_STATUS_RET(SetSwitchBranchNodeLabel(active_node, switch_node->GetName()),
|
|
|
|
"set switch branch node label failed");
|
|
|
|
"set switch branch node label failed.");
|
|
|
|
|
|
|
|
|
|
|
|
string model_exit_name = switch_node->GetName() + "_ModelExit";
|
|
|
|
string model_exit_name = switch_node->GetName() + "_ModelExit";
|
|
|
|
GE_CHK_STATUS_RET(SetActiveLabelList(active_node, { model_exit_name }), "set active label list failed");
|
|
|
|
GE_CHK_STATUS_RET(SetActiveLabelList(active_node, { model_exit_name }), "set active label list failed");
|
|
|
@ -401,7 +401,7 @@ Status FlowCtrlPass::CreateIterCtrlFalseBranch(ComputeGraphPtr &compute_graph, c
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
Status FlowCtrlPass::AddFpBpIteratorCtrl(ComputeGraphPtr &compute_graph, NodePtr &pre_node) {
|
|
|
|
Status FlowCtrlPass::AddFpBpIteratorCtrl(ComputeGraphPtr &compute_graph, NodePtr &pre_node) {
|
|
|
|
GE_IF_BOOL_EXEC(pre_node == nullptr, DOMI_LOGE("pre_node is nullptr"); return FAILED);
|
|
|
|
GE_IF_BOOL_EXEC(pre_node == nullptr, DOMI_LOGE("pre_node is nullptr."); return FAILED);
|
|
|
|
string pre_node_name = pre_node->GetName();
|
|
|
|
string pre_node_name = pre_node->GetName();
|
|
|
|
GELOGI("Add FpBp Iterator ctrl, pre node:%s.", pre_node_name.c_str());
|
|
|
|
GELOGI("Add FpBp Iterator ctrl, pre node:%s.", pre_node_name.c_str());
|
|
|
|
// 1. Get or add variables
|
|
|
|
// 1. Get or add variables
|
|
|
@ -477,7 +477,7 @@ Status FlowCtrlPass::AddSpecialNodeIteratorCtrl(ComputeGraphPtr &compute_graph,
|
|
|
|
* itersPerLoop loopCond
|
|
|
|
* itersPerLoop loopCond
|
|
|
|
*/
|
|
|
|
*/
|
|
|
|
GE_IF_BOOL_EXEC(loop_after_node == nullptr || compute_graph == nullptr,
|
|
|
|
GE_IF_BOOL_EXEC(loop_after_node == nullptr || compute_graph == nullptr,
|
|
|
|
DOMI_LOGE("loop after node or compute graph is null"); return FAILED);
|
|
|
|
DOMI_LOGE("loop after node or compute graph is null."); return FAILED);
|
|
|
|
InDataAnchorPtr in_anchor = loop_after_node->GetInDataAnchor(0);
|
|
|
|
InDataAnchorPtr in_anchor = loop_after_node->GetInDataAnchor(0);
|
|
|
|
if (in_anchor == nullptr || in_anchor->GetPeerOutAnchor() == nullptr) {
|
|
|
|
if (in_anchor == nullptr || in_anchor->GetPeerOutAnchor() == nullptr) {
|
|
|
|
GELOGE(FAILED, "Find %s in data anchor failed.", loop_after_node->GetName().c_str());
|
|
|
|
GELOGE(FAILED, "Find %s in data anchor failed.", loop_after_node->GetName().c_str());
|
|
|
@ -498,7 +498,7 @@ Status FlowCtrlPass::AddSpecialNodeIteratorCtrl(ComputeGraphPtr &compute_graph,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 2. Add StreamSwitch and edges to switch_node.
|
|
|
|
// 2. Add StreamSwitch and edges to switch_node.
|
|
|
|
GE_IF_BOOL_EXEC(loop_pre_node == nullptr, DOMI_LOGE("loop pre node is null"); return FAILED);
|
|
|
|
GE_IF_BOOL_EXEC(loop_pre_node == nullptr, DOMI_LOGE("loop pre node is null."); return FAILED);
|
|
|
|
string switch_name = loop_pre_node->GetName() + "_" + NODE_NAME_STREAM_SWITCH;
|
|
|
|
string switch_name = loop_pre_node->GetName() + "_" + NODE_NAME_STREAM_SWITCH;
|
|
|
|
NodePtr switch_node = InsertStreamSwitchOp(compute_graph, switch_name, loop_cond_node, iter_per_loop_node);
|
|
|
|
NodePtr switch_node = InsertStreamSwitchOp(compute_graph, switch_name, loop_cond_node, iter_per_loop_node);
|
|
|
|
if (switch_node == nullptr) {
|
|
|
|
if (switch_node == nullptr) {
|
|
|
@ -506,7 +506,7 @@ Status FlowCtrlPass::AddSpecialNodeIteratorCtrl(ComputeGraphPtr &compute_graph,
|
|
|
|
return FAILED;
|
|
|
|
return FAILED;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
GE_CHK_STATUS_RET(SetStreamLabel(switch_node, switch_name), "set stream label failed");
|
|
|
|
GE_CHK_STATUS_RET(SetStreamLabel(switch_node, switch_name), "set stream label failed.");
|
|
|
|
|
|
|
|
|
|
|
|
graphStatus add_ret = GraphUtils::AddEdge(loop_pre_node->GetOutControlAnchor(), switch_node->GetInControlAnchor());
|
|
|
|
graphStatus add_ret = GraphUtils::AddEdge(loop_pre_node->GetOutControlAnchor(), switch_node->GetInControlAnchor());
|
|
|
|
if (add_ret != GRAPH_SUCCESS) {
|
|
|
|
if (add_ret != GRAPH_SUCCESS) {
|
|
|
@ -529,7 +529,7 @@ Status FlowCtrlPass::AddSpecialNodeIteratorCtrl(ComputeGraphPtr &compute_graph,
|
|
|
|
return FAILED;
|
|
|
|
return FAILED;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
GE_CHK_STATUS_RET(SetStreamLabel(active_node, active_name), "set stream label failed");
|
|
|
|
GE_CHK_STATUS_RET(SetStreamLabel(active_node, active_name), "set stream label failed.");
|
|
|
|
|
|
|
|
|
|
|
|
GE_IF_BOOL_EXEC(!AttrUtils::SetBool(active_node->GetOpDesc(), ATTR_NAME_IS_LOOP_ACTIVE, true),
|
|
|
|
GE_IF_BOOL_EXEC(!AttrUtils::SetBool(active_node->GetOpDesc(), ATTR_NAME_IS_LOOP_ACTIVE, true),
|
|
|
|
DOMI_LOGE("set ATTR_NAME_IS_LOOP_ACTIVE failed"); return FAILED);
|
|
|
|
DOMI_LOGE("set ATTR_NAME_IS_LOOP_ACTIVE failed"); return FAILED);
|
|
|
@ -542,7 +542,7 @@ Status FlowCtrlPass::AddSpecialNodeIteratorCtrl(ComputeGraphPtr &compute_graph,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// used for stream assign to find true branch
|
|
|
|
// used for stream assign to find true branch
|
|
|
|
GE_CHK_STATUS_RET(SetActiveLabelList(switch_node, { active_name }), "set active label list failed");
|
|
|
|
GE_CHK_STATUS_RET(SetActiveLabelList(switch_node, { active_name }), "set active label list failed.");
|
|
|
|
// used for stream assign to find active stream
|
|
|
|
// used for stream assign to find active stream
|
|
|
|
GE_CHK_STATUS_RET(SetActiveLabelList(active_node, { loop_pre_node->GetName() }), "set active label list failed");
|
|
|
|
GE_CHK_STATUS_RET(SetActiveLabelList(active_node, { loop_pre_node->GetName() }), "set active label list failed");
|
|
|
|
active_nodes_in_iter_loop_.push_back(active_node);
|
|
|
|
active_nodes_in_iter_loop_.push_back(active_node);
|
|
|
|