|
|
|
@ -74,9 +74,11 @@ bool FindAllReduceStreamSwitchPos(const std::shared_ptr<session::KernelGraph> &k
|
|
|
|
|
MS_LOG(WARNING) << "Can't find send node place before AllReduce node.";
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
SendRecvPair pair1 = {kAllReduceStreamSwitch, *mock_send_node_iter, *iter,
|
|
|
|
|
IntToSize(mock_send_node_iter - iter_begin + 1), IntToSize(iter - iter_begin)};
|
|
|
|
|
send_recv_pairs->push_back(pair1);
|
|
|
|
|
if (AnfAlgo::GetCNodeName(*mock_send_node_iter) != kAllReduceOpName) {
|
|
|
|
|
SendRecvPair pair1 = {kAllReduceStreamSwitch, *mock_send_node_iter, *iter,
|
|
|
|
|
IntToSize(mock_send_node_iter - iter_begin + 1), IntToSize(iter - iter_begin)};
|
|
|
|
|
send_recv_pairs->push_back(pair1);
|
|
|
|
|
}
|
|
|
|
|
// Find node which uses AllReduce as input[0].
|
|
|
|
|
std::vector<CNodePtr>::iterator mock_recv_node_iter =
|
|
|
|
|
FindRecvNodePos(iter, iter_end, *iter, kAllReduceStreamSwitch);
|
|
|
|
@ -84,9 +86,11 @@ bool FindAllReduceStreamSwitchPos(const std::shared_ptr<session::KernelGraph> &k
|
|
|
|
|
MS_LOG(WARNING) << "Can't find recv node place after AllReduce node.";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
SendRecvPair pair2 = {kAllReduceStreamSwitch, *iter, *mock_recv_node_iter, IntToSize(iter - iter_begin + 1),
|
|
|
|
|
IntToSize(mock_recv_node_iter - iter_begin)};
|
|
|
|
|
send_recv_pairs->push_back(pair2);
|
|
|
|
|
if (AnfAlgo::GetCNodeName(*mock_recv_node_iter) != kAllReduceOpName) {
|
|
|
|
|
SendRecvPair pair2 = {kAllReduceStreamSwitch, *iter, *mock_recv_node_iter, IntToSize(iter - iter_begin + 1),
|
|
|
|
|
IntToSize(mock_recv_node_iter - iter_begin)};
|
|
|
|
|
send_recv_pairs->push_back(pair2);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
@ -110,17 +114,22 @@ std::vector<CNodePtr>::iterator FindRecvNodePos(std::vector<CNodePtr>::iterator
|
|
|
|
|
std::vector<CNodePtr>::iterator end, const CNodePtr mock_send_node,
|
|
|
|
|
StreamSwitchType stream_switch_type) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(mock_send_node);
|
|
|
|
|
auto ret = end;
|
|
|
|
|
for (auto iter = begin; iter != end; iter++) {
|
|
|
|
|
auto node = *iter;
|
|
|
|
|
if (stream_switch_type == kAllReduceStreamSwitch) {
|
|
|
|
|
for (auto input : node->inputs()) {
|
|
|
|
|
if (mock_send_node == AnfAlgo::VisitKernel(input, 0).first) {
|
|
|
|
|
return iter;
|
|
|
|
|
if (AnfAlgo::GetCNodeName(node) != kAllReduceOpName) {
|
|
|
|
|
return iter;
|
|
|
|
|
} else if (ret == end) {
|
|
|
|
|
ret = iter;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return end;
|
|
|
|
|
return ret;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void InsertStreamSwitchNode(const std::shared_ptr<session::KernelGraph> &kernel_graph,
|
|
|
|
|