!4830 [gpu] fix continuous allreduces bug

Merge pull request !4830 from yuchaojie/gpu_allreduce
pull/4830/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit b69b1ca8a8

@ -74,9 +74,11 @@ bool FindAllReduceStreamSwitchPos(const std::shared_ptr<session::KernelGraph> &k
MS_LOG(WARNING) << "Can't find send node place before AllReduce node.";
continue;
}
SendRecvPair pair1 = {kAllReduceStreamSwitch, *mock_send_node_iter, *iter,
IntToSize(mock_send_node_iter - iter_begin + 1), IntToSize(iter - iter_begin)};
send_recv_pairs->push_back(pair1);
if (AnfAlgo::GetCNodeName(*mock_send_node_iter) != kAllReduceOpName) {
SendRecvPair pair1 = {kAllReduceStreamSwitch, *mock_send_node_iter, *iter,
IntToSize(mock_send_node_iter - iter_begin + 1), IntToSize(iter - iter_begin)};
send_recv_pairs->push_back(pair1);
}
// Find node which uses AllReduce as input[0].
std::vector<CNodePtr>::iterator mock_recv_node_iter =
FindRecvNodePos(iter, iter_end, *iter, kAllReduceStreamSwitch);
@ -84,9 +86,11 @@ bool FindAllReduceStreamSwitchPos(const std::shared_ptr<session::KernelGraph> &k
MS_LOG(WARNING) << "Can't find recv node place after AllReduce node.";
return false;
}
SendRecvPair pair2 = {kAllReduceStreamSwitch, *iter, *mock_recv_node_iter, IntToSize(iter - iter_begin + 1),
IntToSize(mock_recv_node_iter - iter_begin)};
send_recv_pairs->push_back(pair2);
if (AnfAlgo::GetCNodeName(*mock_recv_node_iter) != kAllReduceOpName) {
SendRecvPair pair2 = {kAllReduceStreamSwitch, *iter, *mock_recv_node_iter, IntToSize(iter - iter_begin + 1),
IntToSize(mock_recv_node_iter - iter_begin)};
send_recv_pairs->push_back(pair2);
}
}
}
return true;
@ -110,17 +114,22 @@ std::vector<CNodePtr>::iterator FindRecvNodePos(std::vector<CNodePtr>::iterator
std::vector<CNodePtr>::iterator end, const CNodePtr mock_send_node,
StreamSwitchType stream_switch_type) {
MS_EXCEPTION_IF_NULL(mock_send_node);
auto ret = end;
for (auto iter = begin; iter != end; iter++) {
auto node = *iter;
if (stream_switch_type == kAllReduceStreamSwitch) {
for (auto input : node->inputs()) {
if (mock_send_node == AnfAlgo::VisitKernel(input, 0).first) {
return iter;
if (AnfAlgo::GetCNodeName(node) != kAllReduceOpName) {
return iter;
} else if (ret == end) {
ret = iter;
}
}
}
}
}
return end;
return ret;
}
void InsertStreamSwitchNode(const std::shared_ptr<session::KernelGraph> &kernel_graph,

@ -41,7 +41,8 @@ struct StreamSwitchNode {
if (offset < n.offset) {
return true;
} else if (offset == n.offset) {
return AnfAlgo::GetCNodeName(cnode) == kSendOpName ? true : false;
return (AnfAlgo::GetCNodeName(cnode) == kRecvOpName && AnfAlgo::GetCNodeName(n.cnode) == kSendOpName) ? false
: true;
} else {
return false;
}

Loading…
Cancel
Save