You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
graphengine/ge/graph/passes/set_input_output_offset_pas...

285 lines
13 KiB

/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "graph/passes/set_input_output_offset_pass.h"
#include "runtime/mem.h"
namespace ge {
Status SetInputOutputOffsetPass::Run(ComputeGraphPtr graph) {
GE_CHECK_NOTNULL(graph);
for (auto &node : graph->GetDirectNode()) {
GE_CHECK_NOTNULL(node->GetOpDesc());
vector<int> connect_input;
(void)AttrUtils::GetListInt(node->GetOpDesc(), ATTR_NAME_NODE_CONNECT_INPUT, connect_input);
if (!connect_input.empty()) {
Status ret = SetInputOffset(node, connect_input);
if (ret != SUCCESS) {
GELOGE(ret, "SetInputOffset failed.");
return ret;
}
}
vector<int> connect_output;
(void)AttrUtils::GetListInt(node->GetOpDesc(), ATTR_NAME_NODE_CONNECT_OUTPUT, connect_output);
if (!connect_output.empty()) {
Status ret = SetOutputOffset(node, connect_output);
if (ret != SUCCESS) {
GELOGE(ret, "SetOutputOffset failed.");
return ret;
}
}
}
return SUCCESS;
}
Status SetInputOutputOffsetPass::SetInputOffsetForFusion(const std::vector<int64_t> &memory_type,
const ge::NodePtr &node) {
GELOGI("Start to SetInputOffsetForFusion for %s", node->GetName().c_str());
auto op_desc = node->GetOpDesc();
for (size_t i = 0; i < memory_type.size(); ++i) {
if (memory_type.at(i) != RT_MEMORY_L1) {
std::vector<int64_t> input_offset_of_node;
input_offset_of_node = op_desc->GetInputOffset();
if (input_offset_of_node.size() < i) {
GELOGE(PARAM_INVALID, "not get input_offset of %zu", i);
return PARAM_INVALID;
}
int64_t input_offset = input_offset_of_node.at(i);
GELOGI("input_offset of %s is %ld.", node->GetName().c_str(), input_offset);
auto in_anchor = node->GetInDataAnchor(i);
GE_IF_BOOL_EXEC(in_anchor == nullptr, continue);
auto peer_out_anchor = in_anchor->GetPeerOutAnchor();
GE_IF_BOOL_EXEC(peer_out_anchor == nullptr, continue);
int out_index = peer_out_anchor->GetIdx();
auto data_op_desc = peer_out_anchor->GetOwnerNode()->GetOpDesc();
GE_CHECK_NOTNULL(data_op_desc);
int64_t out_offset = data_op_desc->GetOutputOffset().at(out_index);
GELOGI("output_offset of %s is %ld.", peer_out_anchor->GetOwnerNode()->GetName().c_str(), out_offset);
vector<int64_t> zero_copy_basic_offset;
vector<int64_t> zero_copy_relative_offset;
(void)ge::AttrUtils::GetListInt(data_op_desc, ATTR_ZERO_COPY_BASIC_OFFSET, zero_copy_basic_offset);
(void)ge::AttrUtils::GetListInt(data_op_desc, ATTR_ZERO_COPY_RELATIVE_OFFSET, zero_copy_relative_offset);
zero_copy_basic_offset.emplace_back(out_offset);
int64_t relative_offset = input_offset - out_offset;
zero_copy_relative_offset.emplace_back(relative_offset);
GE_CHK_BOOL_EXEC(ge::AttrUtils::SetListInt(data_op_desc, ATTR_ZERO_COPY_BASIC_OFFSET, zero_copy_basic_offset),
GELOGE(FAILED, "SetListInt of zero_copy_basic_offset failed.");
return FAILED);
GE_CHK_BOOL_EXEC(
ge::AttrUtils::SetListInt(data_op_desc, ATTR_ZERO_COPY_RELATIVE_OFFSET, zero_copy_relative_offset),
GELOGE(FAILED, "SetListInt of zero_copy_relative_offset failed.");
return FAILED);
}
}
return SUCCESS;
}
Status SetInputOutputOffsetPass::SetInputOffsetForHcom(const ge::NodePtr &node, const vector<int> &connect_input) {
GELOGI("Start SetInputOffsetForHcom for %s.", node->GetName().c_str());
auto op_desc = node->GetOpDesc();
vector<int64_t> input_offset_of_node;
input_offset_of_node = node->GetOpDesc()->GetInputOffset();
for (size_t input_index = 0; input_index < connect_input.size(); ++input_index) {
int connect_input_index = connect_input.at(input_index);
int64_t input_offset = input_offset_of_node.at(connect_input_index);
NodePtr in_data = node->GetInDataNodes().at(connect_input_index);
auto in_op_desc = in_data->GetOpDesc();
GE_CHECK_NOTNULL(in_op_desc);
if (in_op_desc->GetType() == DATA) {
int64_t output_offset = in_op_desc->GetOutputOffset().at(0);
if (output_offset == input_offset) {
continue;
} else {
vector<int64_t> zero_copy_basic_offset;
vector<int64_t> zero_copy_relative_offset;
(void)ge::AttrUtils::GetListInt(in_op_desc, ATTR_ZERO_COPY_BASIC_OFFSET, zero_copy_basic_offset);
(void)ge::AttrUtils::GetListInt(in_op_desc, ATTR_ZERO_COPY_RELATIVE_OFFSET, zero_copy_relative_offset);
GELOGI("input offset from %s to %s is %ld to %ld.", in_op_desc->GetName().c_str(), op_desc->GetName().c_str(),
output_offset, input_offset);
int64_t relative_offset = input_offset - output_offset;
zero_copy_basic_offset.emplace_back(output_offset);
zero_copy_relative_offset.emplace_back(relative_offset);
GE_CHK_BOOL_EXEC(ge::AttrUtils::SetListInt(in_op_desc, ATTR_ZERO_COPY_BASIC_OFFSET, zero_copy_basic_offset),
GELOGE(FAILED, "SetListInt of zero_copy_basic_offset failed.");
return FAILED);
GE_CHK_BOOL_EXEC(
ge::AttrUtils::SetListInt(in_op_desc, ATTR_ZERO_COPY_RELATIVE_OFFSET, zero_copy_relative_offset),
GELOGE(FAILED, "SetListInt of zero_copy_relative_offset failed.");
return FAILED);
}
}
}
return SUCCESS;
}
Status SetInputOutputOffsetPass::SetInputOffset(const NodePtr &node, const vector<int> &connect_input) {
GELOGI("Start to SetInputOffset for %s.", node->GetName().c_str());
std::vector<int64_t> memory_type;
auto op_desc = node->GetOpDesc();
(void)ge::AttrUtils::GetListInt(op_desc, ATTR_NAME_INPUT_MEM_TYPE_LIST, memory_type);
if (!memory_type.empty()) {
Status ret = SetInputOffsetForFusion(memory_type, node);
if (ret != SUCCESS) {
GELOGE(ret, "SetInputOffsetForFusion failed.");
return ret;
}
}
// Data->Hcom
bool is_input_continuous = false;
(void)ge::AttrUtils::GetBool(op_desc, ATTR_NAME_CONTINUOUS_INPUT, is_input_continuous);
if (is_input_continuous) {
Status ret = SetInputOffsetForHcom(node, connect_input);
if (ret != SUCCESS) {
GELOGE(ret, "SetInputOffsetForHcom failed.");
return ret;
}
}
return SUCCESS;
}
Status SetInputOutputOffsetPass::SetOutputOffsetForConcat(const NodePtr &node) {
GELOGI("Start SetOutputOffsetForConcat for %s.", node->GetName().c_str());
auto op_desc = node->GetOpDesc();
std::vector<int64_t> output_offset_of_concat;
output_offset_of_concat = op_desc->GetOutputOffset();
// phony_concat has one output
GE_IF_BOOL_EXEC(output_offset_of_concat.size() != 1,
GELOGE(PARAM_INVALID, "%s should has one output.", node->GetName().c_str());
return PARAM_INVALID);
NodePtr net_output = node->GetOutDataNodes().at(0);
auto out_op_desc = net_output->GetOpDesc();
GE_CHECK_NOTNULL(out_op_desc);
vector<int64_t> zero_copy_basic_offset;
vector<int64_t> zero_copy_relative_offset;
(void)ge::AttrUtils::GetListInt(out_op_desc, ATTR_ZERO_COPY_BASIC_OFFSET, zero_copy_basic_offset);
(void)ge::AttrUtils::GetListInt(out_op_desc, ATTR_ZERO_COPY_RELATIVE_OFFSET, zero_copy_relative_offset);
int64_t basic_offset = output_offset_of_concat.at(0);
GELOGI("output_offset of %s is %ld.", op_desc->GetName().c_str(), basic_offset);
for (InDataAnchorPtr &in_anchor : node->GetAllInDataAnchors()) {
OutDataAnchorPtr peer_out_anchor = in_anchor->GetPeerOutAnchor();
GE_IF_BOOL_EXEC(peer_out_anchor == nullptr, continue);
NodePtr in_node = peer_out_anchor->GetOwnerNode();
auto out_index = peer_out_anchor->GetIdx();
std::vector<int64_t> output_offset_of_in_node;
GE_CHECK_NOTNULL(in_node->GetOpDesc());
output_offset_of_in_node = in_node->GetOpDesc()->GetOutputOffset();
GELOGI("input offset from %s to %s is %ld.", in_node->GetName().c_str(), op_desc->GetName().c_str(),
output_offset_of_in_node.at(out_index));
int64_t relative_offset = output_offset_of_in_node.at(out_index) - basic_offset;
zero_copy_basic_offset.emplace_back(basic_offset);
zero_copy_relative_offset.emplace_back(relative_offset);
}
GE_CHK_BOOL_EXEC(ge::AttrUtils::SetListInt(out_op_desc, ATTR_ZERO_COPY_BASIC_OFFSET, zero_copy_basic_offset),
GELOGE(FAILED, "SetListInt of zero_copy_basic_offset failed.");
return FAILED);
GE_CHK_BOOL_EXEC(ge::AttrUtils::SetListInt(out_op_desc, ATTR_ZERO_COPY_RELATIVE_OFFSET, zero_copy_relative_offset),
GELOGE(FAILED, "SetListInt of zero_copy_relative_offset failed.");
return FAILED);
return SUCCESS;
}
Status SetInputOutputOffsetPass::SetOutputOffsetForHcom(const NodePtr &node, const vector<int> &connect_output) {
GELOGI("Start SetOutputOffsetForHcom, %s connect with %zu output.", node->GetName().c_str(), connect_output.size());
vector<int64_t> output_offset_of_node;
output_offset_of_node = node->GetOpDesc()->GetOutputOffset();
int connect_output_index = connect_output.at(0);
int64_t basic_offset = output_offset_of_node.at(connect_output_index);
GELOGI("basic_offset of %s is %ld.", node->GetName().c_str(), basic_offset);
NodePtr net_output = node->GetOutDataNodes().at(connect_output_index);
auto out_op_desc = net_output->GetOpDesc();
GE_CHECK_NOTNULL(out_op_desc);
vector<int64_t> zero_copy_basic_offset;
vector<int64_t> zero_copy_relative_offset;
(void)ge::AttrUtils::GetListInt(out_op_desc, ATTR_ZERO_COPY_BASIC_OFFSET, zero_copy_basic_offset);
(void)ge::AttrUtils::GetListInt(out_op_desc, ATTR_ZERO_COPY_RELATIVE_OFFSET, zero_copy_relative_offset);
for (auto &out_anchor : node->GetAllOutDataAnchors()) {
GE_IF_BOOL_EXEC(out_anchor == nullptr, continue);
for (auto &in_anchor : out_anchor->GetPeerInDataAnchors()) {
GE_IF_BOOL_EXEC(in_anchor == nullptr, continue);
if (in_anchor->GetOwnerNode()->GetType() == NETOUTPUT && out_anchor->GetIdx() != connect_output_index) {
continue;
} else {
NodePtr out_node = in_anchor->GetOwnerNode();
auto in_index = in_anchor->GetIdx();
std::vector<int64_t> input_offset_of_out_node;
GE_CHECK_NOTNULL(out_node->GetOpDesc());
input_offset_of_out_node = out_node->GetOpDesc()->GetInputOffset();
GELOGI("input offset from %s to %s is %ld.", node->GetName().c_str(), out_node->GetName().c_str(),
input_offset_of_out_node.at(in_index));
int64_t relative_offset = input_offset_of_out_node.at(in_index) - basic_offset;
zero_copy_basic_offset.emplace_back(basic_offset);
zero_copy_relative_offset.emplace_back(relative_offset);
}
}
}
GE_CHK_BOOL_EXEC(ge::AttrUtils::SetListInt(out_op_desc, ATTR_ZERO_COPY_BASIC_OFFSET, zero_copy_basic_offset),
GELOGE(FAILED, "SetListInt of zero_copy_basic_offset failed.");
return FAILED);
GE_CHK_BOOL_EXEC(ge::AttrUtils::SetListInt(out_op_desc, ATTR_ZERO_COPY_RELATIVE_OFFSET, zero_copy_relative_offset),
GELOGE(FAILED, "SetListInt of zero_copy_relative_offset failed.");
return FAILED);
return SUCCESS;
}
Status SetInputOutputOffsetPass::SetOutputOffset(const NodePtr &node, const vector<int> &connect_output) {
GELOGI("Start SetOutputOffset of %s.", node->GetName().c_str());
bool attr_no_task = false;
bool get_attr_no_task = ge::AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_NOTASK, attr_no_task);
if (get_attr_no_task && attr_no_task) {
bool is_input_continuous = false;
(void)ge::AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_CONTINUOUS_INPUT, is_input_continuous);
bool buffer_fusion = CheckBufferFusion(node);
// A/B/C -> Phony_concat -> Netoutput : input_continuous
if (is_input_continuous || buffer_fusion) {
Status ret = SetOutputOffsetForConcat(node);
if (ret != SUCCESS) {
GELOGE(ret, "SetOutputOffsetForConcat failed.");
return ret;
}
}
}
// allreduce->netoutput : output_continuous
bool is_output_continuous = false;
(void)ge::AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_CONTINUOUS_OUTPUT, is_output_continuous);
if (is_output_continuous) {
Status ret = SetOutputOffsetForHcom(node, connect_output);
if (ret != SUCCESS) {
GELOGE(ret, "SetOutputOffsetForHcom failed.");
return ret;
}
}
return SUCCESS;
}
bool SetInputOutputOffsetPass::CheckBufferFusion(const NodePtr &node) {
for (auto &in_node : node->GetInDataNodes()) {
GE_CHECK_NOTNULL(in_node);
auto op_desc = in_node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);
if (!op_desc->HasAttr(ATTR_NAME_OUTPUT_OFFSET_FOR_BUFFER_FUSION)) {
GELOGI("The node: %s not have ATTR_NAME_OUTPUT_OFFSET_FOR_BUFFER_FUSION.", node->GetName().c_str());
return false;
}
}
return true;
}
} // namespace ge