|
|
|
@ -24,7 +24,19 @@
|
|
|
|
|
namespace mindspore {
|
|
|
|
|
namespace parallel {
|
|
|
|
|
Status AllreduceGraph::AddNode(const CNodePtr& node, const AnfNodePtr& para) {
|
|
|
|
|
auto arnode = std::make_shared<AllreduceNode>(AllreduceNode());
|
|
|
|
|
AllreduceNodePtr arnode;
|
|
|
|
|
auto cnode_emplace_return = cnode_set_.emplace(node);
|
|
|
|
|
if (!cnode_emplace_return.second) {
|
|
|
|
|
MS_LOG(INFO) << "node: " << node->DebugString() << " has already been added!";
|
|
|
|
|
auto cnode_arnode_pair = cnode_arnode_map_.find(node);
|
|
|
|
|
if (cnode_arnode_pair == cnode_arnode_map_.end()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "node is not in cnode_arnode_map_!";
|
|
|
|
|
}
|
|
|
|
|
arnode = cnode_arnode_pair->second;
|
|
|
|
|
} else {
|
|
|
|
|
arnode = std::make_shared<AllreduceNode>(AllreduceNode());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (arnode->Init(node) != SUCCESS) {
|
|
|
|
|
MS_LOG(ERROR) << "AllreduceNode Init failed";
|
|
|
|
|
return FAILED;
|
|
|
|
@ -39,10 +51,6 @@ Status AllreduceGraph::AddNode(const CNodePtr& node, const AnfNodePtr& para) {
|
|
|
|
|
if (!arnode_emplace_return.second) {
|
|
|
|
|
MS_LOG(INFO) << "node: " << node->DebugString() << "'s arnode has already been added!";
|
|
|
|
|
}
|
|
|
|
|
auto cnode_emplace_return = cnode_set_.emplace(node);
|
|
|
|
|
if (!cnode_emplace_return.second) {
|
|
|
|
|
MS_LOG(INFO) << "node: " << node->DebugString() << " has already been added!";
|
|
|
|
|
}
|
|
|
|
|
cnode_emplace_return = para_cnodeset_map_[para].emplace(node);
|
|
|
|
|
if (!cnode_emplace_return.second) {
|
|
|
|
|
MS_LOG(INFO) << "node: " << node->DebugString() << " already in para: " << para->fullname_with_scope()
|
|
|
|
@ -75,7 +83,7 @@ Status AllreduceGraph::AddEdge(const CNodePtr& from, const CNodePtr& to, double
|
|
|
|
|
MS_LOG(ERROR) << "from_arnode AddNext failed";
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|
if (to_arnode->AddPrev(from_arnode, dist) != SUCCESS) {
|
|
|
|
|
if (to_arnode->AddPrev(from_arnode, dist, &max_) != SUCCESS) {
|
|
|
|
|
MS_LOG(ERROR) << "to_arnode AddPrev failed";
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
@ -110,7 +118,7 @@ std::pair<std::vector<AnfNodePtr>, double> AllreduceGraph::GetParaByParaSize(dou
|
|
|
|
|
double cur_para_size = 0;
|
|
|
|
|
double from = to;
|
|
|
|
|
for (auto& arnode : arnode_vec_) {
|
|
|
|
|
if (arnode.depend_feat_size() >= to) {
|
|
|
|
|
if (arnode.depend_feat_size() != max_ && arnode.depend_feat_size() >= to) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
if (para_size > 0 && cur_para_size >= para_size && arnode.depend_feat_size() < from) {
|
|
|
|
|