You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
graphengine/ge/graph/passes/hccl_group_pass.cc

74 lines
2.5 KiB

/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "hccl_group_pass.h"
#include <deque>
#include <unordered_set>
#include "framework/common/debug/ge_log.h"
#include "graph/debug/ge_attr_define.h"
#include "framework/common/util.h"
namespace ge {
/**
 * @brief Entry point of the pass: starts group marking from nodes flagged as HCCL fused nodes.
 *
 * A node is ignored (SUCCESS is returned without changes) when it has no readable
 * ATTR_NAME_HCCL_FUSED_FLAG, when that flag is false, or when the node already carries
 * ATTR_NAME_HCCL_FUSED_GROUP. A failure while marking the group is only reported as a warning.
 *
 * @param node node visited by the pass; must be non-null and have a non-null OpDesc.
 * @return SUCCESS, or the error status produced by GE_CHECK_NOTNULL for a null node / op desc.
 */
Status HcclGroupPass::Run(NodePtr &node) {
  GE_CHECK_NOTNULL(node);
  OpDescPtr op_desc = node->GetOpDesc();
  GE_CHECK_NOTNULL(op_desc);

  bool is_fused_node = false;
  if (!AttrUtils::GetBool(op_desc, ATTR_NAME_HCCL_FUSED_FLAG, is_fused_node)) {
    // Presumably fails because the attribute is absent, which is the common case for
    // non-HCCL nodes; logging this at warning level for every node would be noise.
    GELOGD("Get attr ATTR_NAME_HCCL_FUSED_FLAG of node %s failed, ignore it.", node->GetName().c_str());
    return SUCCESS;
  }
  if (!is_fused_node) {
    GELOGD("Current node %s is not gradient fused node, ignore it.", node->GetName().c_str());
    return SUCCESS;
  }
  if (op_desc->HasAttr(ATTR_NAME_HCCL_FUSED_GROUP)) {
    GELOGD("Current node %s already marked group id, ignore it.", node->GetName().c_str());
    return SUCCESS;
  }
  // Only report the node once it is known to really be a fused node.
  GELOGI("Recognized fused node %s", node->GetName().c_str());

  Status ret = MarkGroupForFusedNode(node);
  if (ret != SUCCESS) {
    // Grouping only affects performance, so a failure here must not fail the pass.
    GELOGW("Mark group for fused node %s failed. It might cause performance problem.", node->GetName().c_str());
  }
  return SUCCESS;
}

/**
 * @brief Breadth-first walk over the data consumers of @p fused_node, setting
 *        ATTR_NAME_HCCL_FUSED_GROUP (value: the fused node's name) on every node reached.
 *
 * A consumer whose type equals the fused node's type is neither tagged nor walked through:
 * it ends the current group along that edge, while the remaining consumers of the same
 * producer are still processed. Each node is tagged and enqueued at most once, so subgraphs
 * reachable through several paths are not re-walked once per path.
 *
 * @param fused_node the fused node that starts the group; its name is used as the group id.
 * @return SUCCESS, or FAILED if setting the group attribute on a consumer fails.
 */
Status HcclGroupPass::MarkGroupForFusedNode(NodePtr &fused_node) {
  const string group_id = fused_node->GetName();
  const string fused_type = fused_node->GetType();

  // Nodes already tagged/enqueued during this walk; seeded with the start node.
  std::unordered_set<NodePtr> visited{fused_node};
  std::deque<NodePtr> queue;
  queue.push_back(fused_node);

  while (!queue.empty()) {
    NodePtr node = queue.front();
    queue.pop_front();
    for (const auto &out_data_node : node->GetOutDataNodes()) {
      if (out_data_node->GetType() == fused_type) {
        // Meeting another fused node ends the current group on this edge only;
        // sibling consumers still belong to the group.
        continue;
      }
      if (!visited.insert(out_data_node).second) {
        // Already tagged and enqueued via another path.
        continue;
      }
      if (!AttrUtils::SetStr(out_data_node->GetOpDesc(), ATTR_NAME_HCCL_FUSED_GROUP, group_id)) {
        GELOGW("Set attr ATTR_NAME_HCCL_FUSED_GROUP for node %s failed.", out_data_node->GetName().c_str());
        return FAILED;
      }
      GELOGI("Set group_id %s for node %s", group_id.c_str(), out_data_node->GetName().c_str());
      queue.emplace_back(out_data_node);
    }
  }
  return SUCCESS;
}
}  // namespace ge