You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
graphengine/ge/graph/passes/transop_nearby_allreduce_fu...

177 lines
7.2 KiB

/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "graph/passes/transop_nearby_allreduce_fusion_pass.h"
#include "framework/common/debug/ge_log.h"
#include "common/debug/log.h"
#include "common/types.h"
#include "graph/utils/graph_utils.h"
#include "graph/common/transop_util.h"
namespace ge {
/**
 * Per-node entry of the pass.
 *
 * A null node only produces a warning and is treated as success. Nodes whose type is neither
 * HCOMALLREDUCE nor HVDCALLBACKALLREDUCE are left untouched. For allreduce nodes, the
 * neighbouring paired trans ops are removed via RemoveNearbyPairedTransOps.
 *
 * @param node node currently visited by the pass
 * @return SUCCESS, or FAILED when removing the paired trans ops of an allreduce node fails
 */
Status TransOpNearbyAllreduceFusionPass::Run(NodePtr &node) {
  if (node == nullptr) {
    GELOGW("null node is existed in graph");
    return SUCCESS;
  }

  const auto op_type = node->GetType();
  const bool is_allreduce = (op_type == HCOMALLREDUCE) || (op_type == HVDCALLBACKALLREDUCE);
  if (!is_allreduce) {
    // Nothing to fuse around non-allreduce nodes.
    return SUCCESS;
  }

  GELOGI("found allreduce op %s", node->GetName().c_str());
  if (RemoveNearbyPairedTransOps(node) != SUCCESS) {
    GELOGE(FAILED, "failed to remove paired transop for allreduce op %s", node->GetName().c_str());
    return FAILED;
  }
  GELOGI("successfully remove paired transop for allreduce op (%s)", node->GetName().c_str());
  return SUCCESS;
}
/**
 * Checks whether two nodes are a pair of TransData ops that undo each other.
 *
 * The pair is symmetric when both nodes are TRANSDATA, and node1's input tensor desc matches
 * node2's output tensor desc (and node1's output matches node2's input) in format, data type
 * and shape dims. Only input/output index 0 of each node is compared.
 *
 * @param node1 first trans op (RemoveNearbyPairedTransOps passes the producer of the allreduce input)
 * @param node2 second trans op (RemoveNearbyPairedTransOps passes the consumer of the allreduce output)
 * @return true if the two ops are symmetric; false otherwise, including when any node,
 *         op desc or index-0 tensor desc is null
 */
bool TransOpNearbyAllreduceFusionPass::IsSymmetricTransOps(const NodePtr &node1, const NodePtr &node2) {
  if (node1 == nullptr || node2 == nullptr || node1->GetOpDesc() == nullptr || node2->GetOpDesc() == nullptr) {
    return false;
  }
  // Only TransData pairs are handled. Requiring both to be TRANSDATA also guarantees that the
  // two symmetric trans ops have the same type, so no separate type-equality check is needed.
  if (node1->GetType() != TRANSDATA || node2->GetType() != TRANSDATA) {
    return false;
  }
  const auto &node1_input_desc = node1->GetOpDesc()->MutableInputDesc(0);
  const auto &node1_output_desc = node1->GetOpDesc()->MutableOutputDesc(0);
  GE_CHECK_NOTNULL_EXEC(node1_input_desc, return false);
  GE_CHECK_NOTNULL_EXEC(node1_output_desc, return false);
  const auto &node2_input_desc = node2->GetOpDesc()->MutableInputDesc(0);
  const auto &node2_output_desc = node2->GetOpDesc()->MutableOutputDesc(0);
  GE_CHECK_NOTNULL_EXEC(node2_input_desc, return false);
  GE_CHECK_NOTNULL_EXEC(node2_output_desc, return false);
  // two symmetric trans ops should have symmetric input/output format
  GELOGD("format: nod1_input=%d, nod1_output=%d, nod2_input=%d, nod2_output=%d",
         node1_input_desc->GetFormat(), node1_output_desc->GetFormat(), node2_input_desc->GetFormat(),
         node2_output_desc->GetFormat());
  if (node1_input_desc->GetFormat() != node2_output_desc->GetFormat() ||
      node1_output_desc->GetFormat() != node2_input_desc->GetFormat()) {
    return false;
  }
  // two symmetric trans ops should have symmetric input/output datatype
  GELOGD("datatype: nod1_input=%d, nod1_output=%d, nod2_input=%d, nod2_output=%d",
         node1_input_desc->GetDataType(), node1_output_desc->GetDataType(), node2_input_desc->GetDataType(),
         node2_output_desc->GetDataType());
  if (node1_input_desc->GetDataType() != node2_output_desc->GetDataType() ||
      node1_output_desc->GetDataType() != node2_input_desc->GetDataType()) {
    return false;
  }
  // two symmetric trans ops should have symmetric input/output shape
  if (node1_input_desc->GetShape().GetDims() != node2_output_desc->GetShape().GetDims() ||
      node1_output_desc->GetShape().GetDims() != node2_input_desc->GetShape().GetDims()) {
    return false;
  }
  return true;
}
/**
 * Removes pairs of symmetric TransData ops that surround an allreduce node.
 *
 * For every data index i, a pair is fused when all of the following hold:
 *  - input i has exactly one producer A;
 *  - output i has exactly one consumer B;
 *  - IsSymmetricTransOps(A, B) holds;
 *  - A's output feeds nothing but this allreduce.
 * Both A and B are then isolated and deleted. The allreduce's input desc i is set to A's
 * input desc, and its output desc i to B's output desc. Indices that do not qualify are
 * skipped with a log message.
 *
 * @param node the allreduce node
 * @return SUCCESS when all indices were processed. FAILED when node is null, the input/output
 *         data anchor counts differ, deleting a trans op fails, or updating a desc fails.
 *         When a qualifying pair is found but an op desc is null, GE_CHECK_NOTNULL returns its
 *         own error status.
 */
Status TransOpNearbyAllreduceFusionPass::RemoveNearbyPairedTransOps(const NodePtr &node) {
  if (node == nullptr) {
    return FAILED;
  }
  GELOGI("find allReduce node %s", node->GetName().c_str());
  auto in_data_anchors = node->GetAllInDataAnchors();
  auto out_data_anchors = node->GetAllOutDataAnchors();
  if (in_data_anchors.size() != out_data_anchors.size()) {
    GELOGE(FAILED, "in and out data anchor size are not equal, node=%s, in_size=%zu, out_size=%zu",
           node->GetName().c_str(), in_data_anchors.size(), out_data_anchors.size());
    return FAILED;
  }
  size_t data_anchor_size = in_data_anchors.size();
  GELOGI("node = %s, data_anchor_size = %zu", node->GetName().c_str(), data_anchor_size);
  size_t removed_pair_count = 0;
  for (size_t i = 0; i < data_anchor_size; i++) {
    const auto &in_data_anchor = in_data_anchors.at(i);
    const auto &out_data_anchor = out_data_anchors.at(i);
    if (in_data_anchor == nullptr || out_data_anchor == nullptr) {
      GELOGW("node=%s has a null anchor at idx=%zu", node->GetName().c_str(), i);
      continue;
    }
    if (in_data_anchor->GetPeerAnchors().size() != 1) {
      GELOGW("nodes=%s has abnormal in peer anchors at %zu", node->GetName().c_str(), i);
      continue;
    }
    if (out_data_anchor->GetPeerAnchors().size() != 1) {
      GELOGW("nodes=%s has abnormal out peer anchors at %zu", node->GetName().c_str(), i);
      continue;
    }
    auto in_first_peer_anchor = in_data_anchor->GetFirstPeerAnchor();
    if (in_first_peer_anchor == nullptr) {
      GELOGW("node=%s, input anchor idx=%zu, first peer anchor is null", node->GetName().c_str(), i);
      continue;
    }
    auto out_first_peer_anchor = out_data_anchor->GetFirstPeerAnchor();
    if (out_first_peer_anchor == nullptr) {
      GELOGW("node=%s, output anchor idx=%zu, first peer anchor is null", node->GetName().c_str(), i);
      continue;
    }
    auto in_node = in_first_peer_anchor->GetOwnerNode();
    auto out_node = out_first_peer_anchor->GetOwnerNode();
    // Must be checked before the log below, which dereferences both nodes.
    if (in_node == nullptr || out_node == nullptr) {
      GELOGW("node=%s, anchor idx=%zu, owner node of a peer anchor is null", node->GetName().c_str(), i);
      continue;
    }
    GELOGI("in_node=%s, out_node=%s", in_node->GetName().c_str(), out_node->GetName().c_str());
    if (!IsSymmetricTransOps(in_node, out_node)) {
      GELOGD("ignore asymmetric transop %s and %s for node %s",
             in_node->GetName().c_str(), out_node->GetName().c_str(), node->GetName().c_str());
      continue;
    }
    // NOTE(review): IsolateAndDeleteNode(in_node, {0}) is assumed to reconnect in_node's input to
    // every consumer of its output 0. If anything besides this allreduce consumes in_node's
    // output, that consumer would silently receive the untransformed tensor, so skip the pair.
    if (in_first_peer_anchor->GetPeerAnchors().size() != 1) {
      GELOGD("ignore transop %s for node %s: its output has other consumers",
             in_node->GetName().c_str(), node->GetName().c_str());
      continue;
    }
    // Validate op descs and capture everything needed *before* mutating the graph, so a failure
    // here cannot leave the trans ops removed while the allreduce descs are stale.
    const auto op_desc = node->GetOpDesc();
    GE_CHECK_NOTNULL(op_desc);
    GE_CHECK_NOTNULL(in_node->GetOpDesc());
    GE_CHECK_NOTNULL(out_node->GetOpDesc());
    const auto input_desc = in_node->GetOpDesc()->GetInputDesc(0);
    const auto output_desc = out_node->GetOpDesc()->GetOutputDesc(0);
    const auto in_node_name = in_node->GetName();
    const auto out_node_name = out_node->GetName();
    // delete in_node
    if (IsolateAndDeleteNode(in_node, {0}) != SUCCESS) {
      GELOGE(FAILED, "remove node %s failed", in_node_name.c_str());
      return FAILED;
    }
    // delete out_node
    if (IsolateAndDeleteNode(out_node, {0}) != SUCCESS) {
      GELOGE(FAILED, "remove node %s failed", out_node_name.c_str());
      return FAILED;
    }
    // The allreduce now reads/writes the untransformed tensors, so its descs must follow.
    // A failed update would leave the graph inconsistent, so it is treated as an error.
    if (op_desc->UpdateInputDesc(static_cast<uint32_t>(i), input_desc) != GRAPH_SUCCESS) {
      GELOGE(FAILED, "UpdateInputDesc fail, node=%s, idx=%zu.", node->GetName().c_str(), i);
      return FAILED;
    }
    if (op_desc->UpdateOutputDesc(static_cast<uint32_t>(i), output_desc) != GRAPH_SUCCESS) {
      GELOGE(FAILED, "UpdateOutputDesc fail, node=%s, idx=%zu.", node->GetName().c_str(), i);
      return FAILED;
    }
    removed_pair_count++;
    GELOGI("successfully remove paired transop (%s and %s) for node %s",
           in_node_name.c_str(), out_node_name.c_str(), node->GetName().c_str());
  }
  GELOGI("successfully remove %zu pair of transops in total for node %s", removed_pair_count, node->GetName().c_str());
  return SUCCESS;
}
} // namespace ge