You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
398 lines
20 KiB
398 lines
20 KiB
/**
|
|
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
#include "graph/optimize/optimizer/allreduce_fusion_pass.h"
|
|
#include <string>
|
|
#include "common/debug/log.h"
|
|
#include "framework/common/debug/ge_log.h"
|
|
#include "common/types.h"
|
|
#include "common/util.h"
|
|
#include "graph/anchor.h"
|
|
#include "graph/node.h"
|
|
#include "graph/op_desc.h"
|
|
#include "graph/utils/attr_utils.h"
|
|
#include "graph/utils/graph_utils.h"
|
|
#include "graph/utils/tensor_utils.h"
|
|
#include "graph/debug/ge_attr_define.h"
|
|
#include "hccl/base.h"
|
|
#include "hccl/hcom.h"
|
|
|
|
namespace ge {
|
|
/**
 * @brief Fuses the graph's HcomAllReduce nodes into fewer, larger allreduce nodes.
 *
 * Steps:
 *  1. Collect every direct HcomAllReduce node whose HCOM_ATTR_FUSION attribute is > 0
 *     (a node without the attribute keeps the default value 1 and therefore takes part).
 *  2. Ask HCCL (hcom_get_split_strategy) how to split those nodes into segments;
 *     segmentIndex[k] is the index (into the collected nodes) of the last node of segment k.
 *  3. Validate the returned split.
 *  4. For each segment holding more than one node, remove the segment's nodes and add one node
 *     cloned from the segment's first node, re-attaching every data/control edge of the removed
 *     nodes to it. Every fused node after the first gets a control edge from the previously
 *     created fused node.
 *
 * @param graph graph to optimize in place.
 * @return SUCCESS if at least one segment was fused; NOT_CHANGED if there are fewer than two
 *         candidates or HCCL puts every node in its own segment; an error status otherwise.
 */
Status AllReducePass::Run(ge::ComputeGraphPtr graph) {
  GELOGI("FusionAllReducePass: start");
  GE_CHECK_NOTNULL(graph);

  std::vector<NodePtr> fusionOps;
  std::vector<float> inputGradientSize;
  std::vector<float> inputGradientTime;

  // No per-gradient size/time estimate is computed here; every entry handed to HCCL is 0.
  static const float inputGradientSizeTemp = 0.0;
  static const float inputGradientTimeTemp = 0.0;

  // Step 1: collect the fusible allreduce nodes in graph order.
  for (const auto &nodePtr : graph->GetDirectNode()) {
    if (nullptr == nodePtr) {
      GELOGW("FusionAllReducePass: null node exists");
      continue;
    }
    ge::OpDescPtr opDescPtr = nodePtr->GetOpDesc();
    if (nullptr == opDescPtr) {
      GELOGW("FusionAllReducePass: desc of node %s is null", nodePtr->GetName().c_str());
      continue;
    }
    if (HCOMALLREDUCE != opDescPtr->GetType()) {
      continue;
    }
    // The op is an allreduce: it joins the fusion only if its fusion flag is > 0.
    std::int64_t hcom_fusion = 1;
    if (!ge::AttrUtils::GetInt(opDescPtr, HCOM_ATTR_FUSION, hcom_fusion)) {
      GELOGW("FusionAllReducePass: not get hcom_fusion from opDescPtr by HCOM_ATTR_FUSION");
    }
    GELOGI("after GetInt, hcom_fusion is :%ld", hcom_fusion);
    if (hcom_fusion > 0) {
      fusionOps.push_back(nodePtr);
      inputGradientSize.push_back(inputGradientSizeTemp);
      inputGradientTime.push_back(inputGradientTimeTemp);
    }
  }

  // Fusion needs at least two allreduce operators.
  if (fusionOps.size() <= 1) {
    GELOGW("FusionAllReducePass NOT_CHANGED: the graph has %zu allreduce operator", fusionOps.size());
    return NOT_CHANGED;
  }

  string group = "group";
  u32 gradientNum = static_cast<u32>(fusionOps.size());
  string model_name_str = graph->GetName();
  const char *model_name = model_name_str.c_str();
  model_feature modelFeature{model_name, gradientNum, inputGradientSize.data(), inputGradientTime.data()};

  u32 segmentNum = 0;
  u32 segmentIndex[HCCL_MAX_SEGMENT_NUM] = {};

  // Step 2: ask HCCL for the split strategy.
  GELOGI("FusionAllReducePass: invoking hcom_get_split_strategy");
  if (HCCL_SUCCESS != hcom_get_split_strategy(group.c_str(), &modelFeature, HCCL_MAX_SEGMENT_NUM, &segmentNum,
                                              segmentIndex)) {
    GELOGE(FAILED, "FusionAllReducePass FAILED: hcom_get_split_strategy failed, the graph has %zu allreduce operator",
           fusionOps.size());
    return FAILED;
  }
  GELOGI("FusionAllReducePass: invoke hcom_get_split_strategy successfully");

  // Step 3: validate the split. segmentNum must lie in [1, min(HCCL_MAX_SEGMENT_NUM, gradientNum)] ...
  if (HCCL_MAX_SEGMENT_NUM < segmentNum || 1 > segmentNum || segmentNum > gradientNum) {
    GELOGE(FAILED,
           "FusionAllReducePass FAILED: illegal segmentNum=%u, "
           "HCCL_MAX_SEGMENT_NUM=%u, gradientNum=%u",
           segmentNum, HCCL_MAX_SEGMENT_NUM, gradientNum);
    return FAILED;
  }

  // ... the last segment must end at the last collected node ...
  if (segmentIndex[segmentNum - 1] != gradientNum - 1) {
    GELOGE(FAILED,
           "FusionAllReducePass FAILED: illegal segmentIndex[0]=%u, "
           "segmentIndex[segmentNum-1]=%u, gradientNum=%u",
           segmentIndex[0], segmentIndex[segmentNum - 1], gradientNum);
    return FAILED;
  }

  // ... and the segment end indices must be strictly increasing, which keeps every segment
  // non-empty and every index below gradientNum.
  for (uint32_t i = 0; i < segmentNum - 1; i++) {
    if (segmentIndex[i] >= segmentIndex[i + 1]) {
      GELOGE(FAILED,
             "FusionAllReducePass FAILED: illegal "
             "segmentIndex[%u]=%u, segmentIndex[%u]=%u",
             i, segmentIndex[i], i + 1, segmentIndex[i + 1]);
      return FAILED;
    }
  }

  // One node per segment means there is nothing to fuse. This is a normal outcome, not an error.
  if (segmentNum == gradientNum) {
    GELOGI("FusionAllReducePass NOT_CHANGED: segmentNum=%u, gradientNum=%u", segmentNum, gradientNum);
    return NOT_CHANGED;
  }

  // Per-segment scratch buffers, cleared at the start of every fused segment.
  std::unordered_set<void *> anchorPtrSet;
  std::vector<ge::OutDataAnchorPtr> fusionOpPeerOutDataAnchor;
  std::vector<ge::OutDataAnchorPtr> fusionOpPeerOutDataToInControl;
  std::vector<ge::OutControlAnchorPtr> fusionOpPeerOutControlAnchor;
  std::vector<std::pair<int, ge::InDataAnchorPtr>> fusionOpPeerInDataAnchor;
  std::vector<std::pair<int, ge::InControlAnchorPtr>> fusionOpPeerInControlFromOutData;
  std::vector<ge::InControlAnchorPtr> fusionOpPeerInControlAnchor;
  ge::OutControlAnchorPtr previousNewAllreduceOutControlAnchor = nullptr;

  // Step 4: segment segmentIdx covers fusionOps[start .. end] (inclusive).
  uint32_t start = 0;
  uint32_t end = 0;
  for (uint32_t segmentIdx = 0; segmentIdx < segmentNum; segmentIdx++) {
    end = segmentIndex[segmentIdx];
    // A single-node segment is left as it is (the validation above guarantees end >= start).
    if (end == start) {
      GELOGI("FusionAllReducePass: segmentIndex[%u]=%u", segmentIdx, segmentIndex[segmentIdx]);
      start = end + 1;
      continue;
    }

    ge::OpDescPtr originDescPtr = fusionOps[start]->GetOpDesc();
    GE_CHECK_NOTNULL(originDescPtr);
    ge::OpDescPtr newAllreduceDesc = AttrUtils::CloneOpDesc(originDescPtr);
    GE_CHECK_NOTNULL(newAllreduceDesc);

    // Clear the buffers filled by the previous segment.
    anchorPtrSet.clear();
    fusionOpPeerOutDataAnchor.clear();
    fusionOpPeerOutDataToInControl.clear();
    fusionOpPeerOutControlAnchor.clear();
    fusionOpPeerInDataAnchor.clear();
    fusionOpPeerInControlFromOutData.clear();
    fusionOpPeerInControlAnchor.clear();

    // First node of the segment: the cloned desc already describes its inputs and outputs, so
    // its peers are recorded without appending descs.
    int outDataAnchorIndex = 0;
    GE_CHK_STATUS_RET(GetPeerOutDataToInData(anchorPtrSet, fusionOpPeerOutDataAnchor, fusionOps[start]),
                      "Get peer outDataAnchor to inDataAnchor failed");
    GE_CHK_STATUS_RET(GetPeerInAnchorToOutData(anchorPtrSet, fusionOpPeerInDataAnchor,
                                               fusionOpPeerInControlFromOutData, fusionOps[start]),
                      "Get peer inDataAnchor and inControlAnchor to outDataAnchor failed");
    GE_CHK_STATUS_RET(GetPeerOutDataToInControl(anchorPtrSet, fusionOpPeerOutDataToInControl, fusionOps[start]),
                      "Get peer outDataAnchor to inControlAnchor failed");
    GE_CHK_STATUS_RET(GetPeerOutControlToInControl(anchorPtrSet, fusionOpPeerOutControlAnchor, fusionOps[start]),
                      "Get peer outControlAnchor to inControlAnchor failed");
    GE_CHK_STATUS_RET(GetPeerInControlFromOutControl(anchorPtrSet, fusionOpPeerInControlAnchor, fusionOps[start]),
                      "Get peer outControlAnchor from inControlAnchor failed");
    GE_CHK_STATUS_RET(graph->RemoveNode(fusionOps[start]), "FusionAllReducePass FAILED: remove node %s.",
                      fusionOps[start]->GetName().c_str());

    // Remaining nodes: their input/output descs are appended to the fused desc, and
    // outDataAnchorIndex is advanced for every appended output.
    for (uint32_t idx = start + 1; idx <= end; idx++) {
      GE_CHK_STATUS_RET(
          GetPeerOutDataToInData(anchorPtrSet, fusionOpPeerOutDataAnchor, fusionOps[idx], newAllreduceDesc),
          "Get peer outDataAnchor to inDataAnchor failed");
      GE_CHK_STATUS_RET(GetPeerOutDataToInControl(anchorPtrSet, fusionOpPeerOutDataToInControl, fusionOps[idx]),
                        "Get peer outDataAnchor to inControlAnchor failed");
      GE_CHK_STATUS_RET(GetPeerOutControlToInControl(anchorPtrSet, fusionOpPeerOutControlAnchor, fusionOps[idx]),
                        "Get peer outControlAnchor to inControlAnchor failed");
      GE_CHK_STATUS_RET(GetPeerAnchorFromOutData(anchorPtrSet, fusionOpPeerInDataAnchor,
                                                 fusionOpPeerInControlFromOutData, fusionOps[idx], newAllreduceDesc,
                                                 outDataAnchorIndex),
                        "Get peerAnchor from outDataAnchor failed");
      GE_CHK_STATUS_RET(GetPeerInControlFromOutControl(anchorPtrSet, fusionOpPeerInControlAnchor, fusionOps[idx]),
                        "Get peer outControlAnchor from inControlAnchor failed");

      // Delete the node now that all of its edges have been recorded.
      GE_CHK_STATUS_RET(graph->RemoveNode(fusionOps[idx]), "FusionAllReducePass FAILED: remove node %s.",
                        fusionOps[idx]->GetName().c_str());
    }

    NodePtr newAllReducePtr = graph->AddNode(newAllreduceDesc);
    GE_CHECK_NOTNULL(newAllReducePtr);

    // Link the input data edges: producer i feeds input i of the fused node.
    for (size_t i = 0; i < fusionOpPeerOutDataAnchor.size(); i++) {
      GE_CHK_STATUS_RET(
          GraphUtils::AddEdge(fusionOpPeerOutDataAnchor[i], newAllReducePtr->GetInDataAnchor(static_cast<int>(i))),
          "FusionAllReducePass FAILED: add input data edge failed");
    }

    // Link the input control edges (from control outputs and from data outputs).
    for (const auto &peerOutControlAnchor : fusionOpPeerOutControlAnchor) {
      GE_CHK_STATUS_RET(GraphUtils::AddEdge(peerOutControlAnchor, newAllReducePtr->GetInControlAnchor()),
                        "FusionAllReducePass FAILED: add input control edge failed");
    }
    for (const auto &peerOutDataToInControl : fusionOpPeerOutDataToInControl) {
      GE_CHK_STATUS_RET(GraphUtils::AddEdge(peerOutDataToInControl, newAllReducePtr->GetInControlAnchor()),
                        "FusionAllReducePass FAILED: add edge from out data to incontrol "
                        "failed");
    }

    // Link the output data edges, each from the fused output recorded for that consumer.
    for (const auto &indexedPeer : fusionOpPeerInDataAnchor) {
      GE_CHK_STATUS_RET(GraphUtils::AddEdge(newAllReducePtr->GetOutDataAnchor(indexedPeer.first), indexedPeer.second),
                        "FusionAllReducePass FAILED: add output data edge failed");
    }
    for (const auto &indexedPeer : fusionOpPeerInControlFromOutData) {
      GE_CHK_STATUS_RET(GraphUtils::AddEdge(newAllReducePtr->GetOutDataAnchor(indexedPeer.first), indexedPeer.second),
                        "FusionAllReducePass FAILED: add edge from out data to in control "
                        "failed");
    }

    // Link the output control edges.
    for (const auto &peerInControlAnchor : fusionOpPeerInControlAnchor) {
      GE_CHK_STATUS_RET(GraphUtils::AddEdge(newAllReducePtr->GetOutControlAnchor(), peerInControlAnchor),
                        "FusionAllReducePass FAILED: add output control edge failed");
    }

    // Chain this fused node after the previously created fused node.
    if (segmentIdx > 0 && previousNewAllreduceOutControlAnchor != nullptr) {
      GE_CHK_STATUS_RET(
          GraphUtils::AddEdge(previousNewAllreduceOutControlAnchor, newAllReducePtr->GetInControlAnchor()),
          "FusionAllReducePass FAILED: add input previous control edge failed");
    }

    previousNewAllreduceOutControlAnchor = newAllReducePtr->GetOutControlAnchor();
    start = end + 1;
  }

  return SUCCESS;
}
|
|
|
|
/**
 * Records the distinct producers (peer out-data anchors) that feed srcNodePtr's data inputs and
 * detaches srcNodePtr from each newly recorded producer.
 *
 * @param anchorSet anchors already recorded; a producer found here is skipped, otherwise added.
 * @param peerOutDataAnchorVec receives newly recorded producers in input order.
 * @param srcNodePtr node whose input edges are collected.
 * @return SUCCESS, or the status of the first failing GraphUtils::RemoveEdge.
 */
Status AllReducePass::GetPeerOutDataToInData(std::unordered_set<void *> &anchorSet,
                                             vector<ge::OutDataAnchorPtr> &peerOutDataAnchorVec,
                                             ge::NodePtr &srcNodePtr) {
  for (const auto &dataIn : srcNodePtr->GetAllInDataAnchors()) {
    if (dataIn == nullptr) {
      continue;
    }
    const OutDataAnchorPtr producer = dataIn->GetPeerOutAnchor();
    // Skip unconnected inputs and producers that were recorded earlier.
    if (producer == nullptr || !anchorSet.insert(producer.get()).second) {
      continue;
    }
    peerOutDataAnchorVec.push_back(producer);
    GE_CHK_STATUS_RET(GraphUtils::RemoveEdge(producer, dataIn));
  }
  return SUCCESS;
}
|
|
|
|
/**
 * Records and detaches every consumer of srcNodePtr's data outputs. Run uses this for the first
 * node of a segment, whose output descs are kept as-is by the cloned fused desc, so each
 * consumer is recorded against the index of the output it is attached to.
 *
 * @param anchorSet anchors already recorded; consumers found here are skipped, otherwise added.
 * @param fusionOpPeerInDataAnchor receives (output index, consumer in-data anchor) pairs.
 * @param fusionOpPeerInControlFromOutData receives (output index, consumer in-control anchor)
 *        pairs for control edges that leave a data output.
 * @param srcNodePtr node whose output edges are collected and removed.
 * @return SUCCESS, or the status of the first failing GraphUtils::RemoveEdge.
 */
Status AllReducePass::GetPeerInAnchorToOutData(
    std::unordered_set<void *> &anchorSet, std::vector<std::pair<int, ge::InDataAnchorPtr>> &fusionOpPeerInDataAnchor,
    std::vector<std::pair<int, ge::InControlAnchorPtr>> &fusionOpPeerInControlFromOutData, ge::NodePtr &srcNodePtr) {
  GE_CHECK_NOTNULL(srcNodePtr);
  for (const auto &outDataAnchor : srcNodePtr->GetAllOutDataAnchors()) {
    if (outDataAnchor == nullptr) {
      continue;
    }
    // The fused desc is a clone of this node's desc, so output i here is output i of the fused
    // node (previously every consumer was mapped to output 0).
    const int outIndex = outDataAnchor->GetIdx();

    for (const auto &peerInDataAnchor : outDataAnchor->GetPeerInDataAnchors()) {
      if (peerInDataAnchor == nullptr || anchorSet.count(peerInDataAnchor.get()) != 0) {
        continue;
      }
      fusionOpPeerInDataAnchor.emplace_back(outIndex, peerInDataAnchor);
      anchorSet.insert(peerInDataAnchor.get());
      GE_CHK_STATUS_RET(GraphUtils::RemoveEdge(outDataAnchor, peerInDataAnchor));
    }

    for (const auto &peerInControlAnchorFromData : outDataAnchor->GetPeerInControlAnchors()) {
      if (peerInControlAnchorFromData == nullptr || anchorSet.count(peerInControlAnchorFromData.get()) != 0) {
        continue;
      }
      // Built directly as pair<int, ...> to match the container's element type.
      fusionOpPeerInControlFromOutData.emplace_back(outIndex, peerInControlAnchorFromData);
      anchorSet.insert(peerInControlAnchorFromData.get());
      GE_CHK_STATUS_RET(GraphUtils::RemoveEdge(outDataAnchor, peerInControlAnchorFromData));
    }
  }
  return SUCCESS;
}
|
|
|
|
/**
 * Records the distinct producers feeding srcNodePtr's data inputs, appends the matching input
 * desc of srcNodePtr to dstOpDescPtr for each one, and detaches srcNodePtr from it.
 *
 * Producer k recorded in peerOutDataAnchorVec is later wired to input k of the fused node, so a
 * failed AddInputDesc is reported immediately instead of being only warned about. Otherwise the
 * input count and the recorded producers fall out of step and a later AddEdge fails with a
 * misleading message.
 *
 * @param anchorSet anchors already recorded; a producer found here is skipped, otherwise added.
 * @param peerOutDataAnchorVec receives newly recorded producers in input order.
 * @param srcNodePtr node whose input edges are collected.
 * @param dstOpDescPtr fused op desc that receives the input descs.
 * @return SUCCESS; FAILED if an input desc cannot be appended; a null-check error; or the status
 *         of the first failing GraphUtils::RemoveEdge.
 */
Status AllReducePass::GetPeerOutDataToInData(std::unordered_set<void *> &anchorSet,
                                             vector<ge::OutDataAnchorPtr> &peerOutDataAnchorVec,
                                             ge::NodePtr &srcNodePtr, ge::OpDescPtr &dstOpDescPtr) {
  GE_CHECK_NOTNULL(srcNodePtr);
  GE_CHECK_NOTNULL(dstOpDescPtr);
  const ge::OpDescPtr srcOpDescPtr = srcNodePtr->GetOpDesc();
  GE_CHECK_NOTNULL(srcOpDescPtr);

  for (const auto &inDataAnchor : srcNodePtr->GetAllInDataAnchors()) {
    if (inDataAnchor == nullptr) {
      continue;
    }
    OutDataAnchorPtr peerOutDataAnchor = inDataAnchor->GetPeerOutAnchor();
    if (peerOutDataAnchor == nullptr || anchorSet.count(peerOutDataAnchor.get()) != 0) {
      continue;
    }
    peerOutDataAnchorVec.push_back(peerOutDataAnchor);
    anchorSet.insert(peerOutDataAnchor.get());
    if (dstOpDescPtr->AddInputDesc(srcOpDescPtr->GetInputDesc(inDataAnchor->GetIdx())) != ge::GRAPH_SUCCESS) {
      GELOGE(FAILED, "GetPeerOutDataToInData: AddInputDesc failed, node %s", srcNodePtr->GetName().c_str());
      return FAILED;
    }
    GE_CHK_STATUS_RET(GraphUtils::RemoveEdge(peerOutDataAnchor, inDataAnchor));
  }
  return SUCCESS;
}
|
|
|
|
/**
 * Records the distinct data outputs that have a control edge into srcNodePtr and detaches each
 * newly recorded one from srcNodePtr's in-control anchor.
 *
 * @param anchorSet anchors already recorded; a source found here is skipped, otherwise added.
 * @param peerOutDataToInControlVec receives newly recorded data-output sources.
 * @param srcNodePtr node whose incoming data-to-control edges are collected.
 * @return SUCCESS, a null-check error if the node has no in-control anchor, or the status of the
 *         first failing GraphUtils::RemoveEdge.
 */
Status AllReducePass::GetPeerOutDataToInControl(std::unordered_set<void *> &anchorSet,
                                                vector<ge::OutDataAnchorPtr> &peerOutDataToInControlVec,
                                                ge::NodePtr &srcNodePtr) {
  const InControlAnchorPtr inControlAnchor = srcNodePtr->GetInControlAnchor();
  GE_CHECK_NOTNULL(inControlAnchor);
  for (const auto &dataSource : inControlAnchor->GetPeerOutDataAnchors()) {
    const bool alreadySeen = (dataSource == nullptr) || !anchorSet.insert(dataSource.get()).second;
    if (alreadySeen) {
      continue;
    }
    peerOutDataToInControlVec.push_back(dataSource);
    GE_CHK_STATUS_RET(GraphUtils::RemoveEdge(dataSource, inControlAnchor));
  }
  return SUCCESS;
}
|
|
|
|
/**
 * Records the distinct control outputs that have a control edge into srcNodePtr and detaches
 * each newly recorded one from srcNodePtr's in-control anchor.
 *
 * @param anchorSet anchors already recorded; a source found here is skipped, otherwise added.
 * @param peerOutControlToInControlVec receives newly recorded control sources.
 * @param srcNodePtr node whose incoming control edges are collected.
 * @return SUCCESS, a null-check error if the node has no in-control anchor, or the status of the
 *         first failing GraphUtils::RemoveEdge.
 */
Status AllReducePass::GetPeerOutControlToInControl(std::unordered_set<void *> &anchorSet,
                                                   vector<ge::OutControlAnchorPtr> &peerOutControlToInControlVec,
                                                   ge::NodePtr &srcNodePtr) {
  const InControlAnchorPtr inControlAnchor = srcNodePtr->GetInControlAnchor();
  GE_CHECK_NOTNULL(inControlAnchor);
  for (const auto &controlSource : inControlAnchor->GetPeerOutControlAnchors()) {
    if (controlSource == nullptr) {
      continue;
    }
    // insert() reports false when this source was recorded earlier.
    if (!anchorSet.insert(controlSource.get()).second) {
      continue;
    }
    peerOutControlToInControlVec.push_back(controlSource);
    GE_CHK_STATUS_RET(GraphUtils::RemoveEdge(controlSource, inControlAnchor));
  }
  return SUCCESS;
}
|
|
|
|
/**
 * For every data output of srcNodePtr that has consumers, appends that output's desc to
 * dstOpDescPtr, advances index, and records and detaches the output's consumers against the
 * new index.
 *
 * A failed AddOutputDesc is reported immediately. Previously it was only warned about while
 * index still advanced, so recorded indices pointed past the fused node's outputs and a later
 * AddEdge failed with a misleading message.
 *
 * @param anchorSet anchors already recorded; consumers found here are skipped, otherwise added.
 * @param peerInDataFromOutDataVec receives (fused output index, consumer in-data anchor) pairs.
 * @param peerInControlFromOutDataVec receives (fused output index, consumer in-control anchor) pairs.
 * @param srcNodePtr node whose output edges are collected and removed.
 * @param dstOpDescPtr fused op desc that receives the output descs.
 * @param index in/out counter incremented once per appended output.
 * @return SUCCESS; FAILED if an output desc cannot be appended; a null-check error; or the status
 *         of the first failing GraphUtils::RemoveEdge.
 */
Status AllReducePass::GetPeerAnchorFromOutData(
    std::unordered_set<void *> &anchorSet, vector<std::pair<int, ge::InDataAnchorPtr>> &peerInDataFromOutDataVec,
    vector<std::pair<int, ge::InControlAnchorPtr>> &peerInControlFromOutDataVec, ge::NodePtr &srcNodePtr,
    ge::OpDescPtr &dstOpDescPtr, int &index) {
  GE_CHECK_NOTNULL(srcNodePtr);
  GE_CHECK_NOTNULL(dstOpDescPtr);
  const ge::OpDescPtr srcOpDescPtr = srcNodePtr->GetOpDesc();
  GE_CHECK_NOTNULL(srcOpDescPtr);

  for (const auto &outDataAnchor : srcNodePtr->GetAllOutDataAnchors()) {
    if (outDataAnchor == nullptr) {
      continue;
    }
    // Fetch the data consumers once; they are used both for the check and for the loop below.
    auto peerInDataAnchors = outDataAnchor->GetPeerInDataAnchors();
    const bool hasConsumers = peerInDataAnchors.size() > 0 || outDataAnchor->GetPeerInControlAnchors().size() > 0;
    if (hasConsumers) {
      if (dstOpDescPtr->AddOutputDesc(srcOpDescPtr->GetOutputDesc(outDataAnchor->GetIdx())) != ge::GRAPH_SUCCESS) {
        GELOGE(FAILED, "GetPeerAnchorFromOutData: AddOutputDesc failed, node %s", srcNodePtr->GetName().c_str());
        return FAILED;
      }
      index++;
    }

    for (const auto &peerInDataAnchor : peerInDataAnchors) {
      if (peerInDataAnchor == nullptr || anchorSet.count(peerInDataAnchor.get()) != 0) {
        continue;
      }
      peerInDataFromOutDataVec.emplace_back(index, peerInDataAnchor);
      anchorSet.insert(peerInDataAnchor.get());
      GE_CHK_STATUS_RET(GraphUtils::RemoveEdge(outDataAnchor, peerInDataAnchor));
    }

    for (const auto &peerInControlAnchorFromData : outDataAnchor->GetPeerInControlAnchors()) {
      if (peerInControlAnchorFromData == nullptr || anchorSet.count(peerInControlAnchorFromData.get()) != 0) {
        continue;
      }
      peerInControlFromOutDataVec.emplace_back(index, peerInControlAnchorFromData);
      anchorSet.insert(peerInControlAnchorFromData.get());
      GE_CHK_STATUS_RET(GraphUtils::RemoveEdge(outDataAnchor, peerInControlAnchorFromData));
    }
  }
  return SUCCESS;
}
|
|
|
|
/**
 * Records the distinct in-control anchors reached by srcNodePtr's out-control anchor and detaches
 * each newly recorded one from it.
 *
 * @param anchorSet anchors already recorded; a target found here is skipped, otherwise added.
 * @param peerInControlFromOutControlVec receives newly recorded control targets.
 * @param srcNodePtr node whose outgoing control edges are collected.
 * @return SUCCESS, a null-check error if the node has no out-control anchor, or the status of the
 *         first failing GraphUtils::RemoveEdge.
 */
Status AllReducePass::GetPeerInControlFromOutControl(std::unordered_set<void *> &anchorSet,
                                                     vector<ge::InControlAnchorPtr> &peerInControlFromOutControlVec,
                                                     ge::NodePtr &srcNodePtr) {
  const OutControlAnchorPtr outControlAnchor = srcNodePtr->GetOutControlAnchor();
  GE_CHECK_NOTNULL(outControlAnchor);
  for (const auto &controlTarget : outControlAnchor->GetPeerInControlAnchors()) {
    const bool skip = (controlTarget == nullptr) || !anchorSet.insert(controlTarget.get()).second;
    if (skip) {
      continue;
    }
    peerInControlFromOutControlVec.push_back(controlTarget);
    GE_CHK_STATUS_RET(GraphUtils::RemoveEdge(outControlAnchor, controlTarget));
  }
  return SUCCESS;
}
|
|
} // namespace ge
|