!9285 Fix GPU stream assignment bug where the StreamSwitch node is not found.

From: @zpac
Reviewed-by: @kisnwang,@limingqi107
Signed-off-by: @limingqi107
pull/9285/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 9da41e6f18

@@ -71,28 +71,32 @@ bool FindAllReduceStreamSwitchPos(const std::shared_ptr<session::KernelGraph> &k
std::vector<CNodePtr>::iterator mock_send_node_iter = std::vector<CNodePtr>::iterator mock_send_node_iter =
FindSendNodePos(iter_begin, iter + 1, *iter, kAllReduceStreamSwitch); FindSendNodePos(iter_begin, iter + 1, *iter, kAllReduceStreamSwitch);
if (mock_send_node_iter == iter + 1) { if (mock_send_node_iter == iter + 1) {
MS_LOG(WARNING) << "Can't find send node place before AllReduce node."; MS_LOG(INFO) << "Can't find send node place before AllReduce node.";
continue; } else if (AnfAlgo::GetCNodeName(*mock_send_node_iter) != kAllReduceOpName) {
}
if (AnfAlgo::GetCNodeName(*mock_send_node_iter) != kAllReduceOpName) {
SendRecvPair pair1 = {kAllReduceStreamSwitch, *mock_send_node_iter, *iter, SendRecvPair pair1 = {kAllReduceStreamSwitch, *mock_send_node_iter, *iter,
IntToSize(mock_send_node_iter - iter_begin + 1), IntToSize(iter - iter_begin)}; IntToSize(mock_send_node_iter - iter_begin + 1), IntToSize(iter - iter_begin)};
send_recv_pairs->push_back(pair1); send_recv_pairs->push_back(pair1);
} else {
MS_LOG(INFO) << "mock_send_node is AllReduce, no need to add stream switch node.";
} }
// Find node which uses AllReduce as input[0]. // Find node which uses AllReduce as input[0].
std::vector<CNodePtr>::iterator mock_recv_node_iter = std::vector<CNodePtr>::iterator mock_recv_node_iter =
FindRecvNodePos(iter, iter_end, *iter, kAllReduceStreamSwitch); FindRecvNodePos(iter, iter_end, *iter, kAllReduceStreamSwitch);
if (mock_recv_node_iter == iter_end) { if (mock_recv_node_iter == iter_end) {
MS_LOG(WARNING) << "Can't find recv node place after AllReduce node."; MS_LOG(INFO) << "Can't find recv node place after AllReduce node.";
return false; } else if (AnfAlgo::GetCNodeName(*mock_recv_node_iter) != kAllReduceOpName) {
}
if (AnfAlgo::GetCNodeName(*mock_recv_node_iter) != kAllReduceOpName) {
SendRecvPair pair2 = {kAllReduceStreamSwitch, *iter, *mock_recv_node_iter, IntToSize(iter - iter_begin + 1), SendRecvPair pair2 = {kAllReduceStreamSwitch, *iter, *mock_recv_node_iter, IntToSize(iter - iter_begin + 1),
IntToSize(mock_recv_node_iter - iter_begin)}; IntToSize(mock_recv_node_iter - iter_begin)};
send_recv_pairs->push_back(pair2); send_recv_pairs->push_back(pair2);
} else {
MS_LOG(INFO) << "mock_recv_node is AllReduce, no need to add stream switch node.";
} }
} }
} }
if (send_recv_pairs->empty()) {
MS_LOG(INFO) << "No stream switch node is found.";
return false;
}
return true; return true;
} }

Loading…
Cancel
Save