parallel group

pull/1162/head
陈华 4 years ago
parent 6d51265781
commit cf101e0aa2

@ -320,6 +320,7 @@ set(TRAIN_SRC_LIST
"graph/passes/variable_ref_useless_control_out_delete_pass.cc"
"graph/passes/end_of_sequence_add_control_pass.cc"
"graph/passes/memcpy_addr_async_pass.cc"
"graph/passes/parallel_group_pass.cc"
"graph/passes/set_input_output_offset_pass.cc"
"graph/preprocess/graph_preprocess.cc"
"graph/preprocess/insert_op/ge_aipp_op.cc"
@ -607,6 +608,7 @@ set(INFER_SRC_LIST
"graph/passes/hccl_group_pass.cc"
"graph/passes/memcpy_addr_async_pass.cc"
"graph/passes/set_input_output_offset_pass.cc"
"graph/passes/parallel_group_pass.cc"
"graph/manager/model_manager/event_manager.cc"
"graph/manager/util/rt_context_util.cc"
"graph/manager/util/variable_accelerate_ctrl.cc"

@ -376,6 +376,48 @@ Status NodeStreamUpdatePass::Run(ComputeGraphPtr graph, const vector<SubgraphPtr
return SUCCESS;
}
Status UpdateForParallelGroupPass::Run(ComputeGraphPtr graph, const vector<SubgraphPtr> &subgraphs, Context &context) {
std::map<int, vector<OpDescPtr>> stream_op_map;
for (const SubgraphPtr &subgraph : subgraphs) {
auto compute_graph = subgraph->subgraph_info.GetSubGraph();
for (const NodePtr &node : compute_graph->GetDirectNode()) {
OpDescPtr op_desc = node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);
if (op_desc->HasAttr(ATTR_NAME_PARALLEL_GROUP)) {
int64_t op_desc_stream_id = op_desc->GetStreamId();
stream_op_map[op_desc_stream_id].push_back(op_desc);
}
}
}
for (const auto &itr : stream_op_map) {
if (itr.first == kInvalidStream) {
continue;
}
std::map<std::string, int64_t> group_2_stream_id;
for (const auto &op_desc : itr.second) {
std::string group_name;
if (!AttrUtils::GetStr(op_desc, ATTR_NAME_PARALLEL_GROUP, group_name)) {
GELOGE(FAILED, "[GetAttr][OpDesc]Get node %s ATTR_NAME_PARALLEL_GROUP failed.", op_desc->GetName().c_str());
REPORT_INNER_ERROR("E19999", "Get node %s ATTR_NAME_PARALLEL_GROUP failed.", op_desc->GetName().c_str());
return FAILED;
}
const auto &itr = group_2_stream_id.find(group_name);
int64_t new_stream_id = kInvalidStream;
int64_t old_stream_id = op_desc->GetStreamId();
if (itr != group_2_stream_id.end()) {
new_stream_id = itr->second;
} else {
new_stream_id = context.next_stream++;
group_2_stream_id[group_name] = new_stream_id;
}
op_desc->SetStreamId(new_stream_id);
GELOGD("Node %s assigned stream %ld from stream %ld.",
op_desc->GetName().c_str(), new_stream_id, old_stream_id);
}
}
return SUCCESS;
}
int64_t UpdateForSkippedEnginePass::GetSingleInoutStream(const NodePtr &node) const {
set<int64_t> stream_ids;
@ -665,6 +707,7 @@ Status LogicalStreamAllocator::RunPasses(const ComputeGraphPtr &graph, const vec
passes.emplace_back(MakeShared<IndependentStreamPass>());
passes.emplace_back(MakeShared<AssignByDependencyPass>());
passes.emplace_back(MakeShared<NodeStreamUpdatePass>());
passes.emplace_back(MakeShared<UpdateForParallelGroupPass>());
passes.emplace_back(MakeShared<AllReduceParallelPass>());
passes.emplace_back(MakeShared<UpdateForSkippedEnginePass>());
}

@ -149,6 +149,13 @@ class NodeStreamUpdatePass : public LogicalStreamPass {
Status Run(ComputeGraphPtr graph, const std::vector<SubgraphPtr> &subgraphs, Context &context) override;
};
// assign stream by parallel group
class UpdateForParallelGroupPass : public LogicalStreamPass {
public:
STREAM_PASS_DEFAULT_FUNC(UpdateForParallelGroupPass);
Status Run(ComputeGraphPtr graph, const std::vector<SubgraphPtr> &subgraphs, Context &context) override;
};
// Update the stream of subgraphs to nodes.
class UpdateForSkippedEnginePass : public LogicalStreamPass {
public:

@ -93,6 +93,7 @@
#include "graph/passes/global_step_insert_pass.h"
#include "graph/passes/memcpy_addr_async_pass.h"
#include "graph/passes/hccl_continuous_memcpy_pass.h"
#include "graph/passes/parallel_group_pass.h"
#include "graph/build/label_allocator.h"
#include "graph/utils/tensor_adapter.h"
#include "inc/pass_manager.h"
@ -2381,6 +2382,12 @@ Status GraphManager::OptimizeStage2(ge::ComputeGraphPtr &compute_graph) {
GE_CHK_STATUS_RET(memcpy_addr.Run(compute_graph), "Add memcpy_addr_async node failed.");
GE_TIMESTAMP_END(AddMemcpyAddrAsyncNode, "MemcpyAddrAsyncPass::Run.");
// Handle parallel group .
GE_TIMESTAMP_START(ParallelGroup);
ParallelGroupPass parallel_group_pass;
GE_CHK_STATUS_RET(parallel_group_pass.Run(compute_graph), "Handle parallel group failed.");
GE_TIMESTAMP_END(ParallelGroup, "ParallelGroupPass::Run.");
// After while sub graph handle, mark all node rw type
auto result = GetCompilerStages(compute_graph->GetGraphID()).optimizer.HandleMemoryRWConflict(compute_graph);
if (result != SUCCESS) {

@ -22,6 +22,10 @@
using std::string;
namespace ge {
namespace {
const int64_t kLoopType = 1;
}
Status NextIterationPass::Run(ComputeGraphPtr graph) {
GELOGD("NextIterationPass Enter");
/// Enter-----------+
@ -121,7 +125,10 @@ Status NextIterationPass::FindWhileGroups() {
if (switch_node == nullptr) {
continue;
}
if (!AttrUtils::SetInt(switch_node->GetOpDesc(), ATTR_NAME_STREAM_SWITCH_TYPE, kLoopType)) {
GELOGE(INTERNAL_ERROR, "set int failed");
return INTERNAL_ERROR;
}
NodePtr loop_cond = nullptr;
if (FindTargetNode(switch_node, LOOPCOND, true, loop_cond) != SUCCESS) {
GELOGE(INTERNAL_ERROR, "Get LoopCond node failed, frame_name: %s.", frame_name.c_str());

File diff suppressed because it is too large Load Diff

@ -0,0 +1,53 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef GE_GRAPH_PASSES_PARALLEL_GROUP_PASS_H
#define GE_GRAPH_PASSES_PARALLEL_GROUP_PASS_H
#include <map>
#include <unordered_set>
#include "graph/graph.h"
#include "inc/graph_pass.h"
namespace ge {
class ParallelGroupPass : public GraphPass {
public:
Status Run(ComputeGraphPtr graph) override;
private:
Status ProcessGraphGroupNodes(ComputeGraphPtr graph, int32_t depth, std::unordered_set<std::string> &parallel_group);
Status AddCtrlEdge(NodePtr pre_node, NodePtr cur_node);
Status ReplaceWithSwitchAndMerge(NodePtr pre_node, NodePtr cur_node,
const std::map<NodePtr, std::pair<std::set<NodePtr>, NodePtr>> &node_2_switch_merge);
bool HasSameSwitch(const std::set<NodePtr> &a, const std::set<NodePtr> &b);
Status ProcessGroupNodeInSwitch(ComputeGraphPtr graph,
std::map<NodePtr, std::pair<std::set<NodePtr>, NodePtr>> &node_2_switch_merge);
void FindGroupNodeAndMerge(NodePtr stream_switch_node, std::set<NodePtr> &group_nodes,
std::vector<NodePtr> &merge_nodes, std::set<std::string> &stream_labels);
Status MappingNodeToSwitchAndMerge(const std::set<NodePtr> &group_set, const std::vector<NodePtr> &merge_vec,
const NodePtr &cast_node, const NodePtr &switch_node,
std::map<NodePtr, std::pair<std::set<NodePtr>, NodePtr>> &node_2_switch_merge);
bool IsBigSmallLoopStreamSwitch(OpDescPtr switch_op_desc);
bool IsWhileStreamSwitch(OpDescPtr switch_op_desc);
};
} // namespace ge
#endif // GE_GRAPH_PASSES_PARALLEL_GROUP_PASS_H

@ -307,6 +307,13 @@ NodePtr SwitchToStreamSwitchPass::CreateStreamSwitchNode(const ComputeGraphPtr &
hccl_group_id.c_str());
}
int64_t switch_type;
if (AttrUtils::GetInt(switch_node->GetOpDesc(), ATTR_NAME_STREAM_SWITCH_TYPE, switch_type)) {
(void)AttrUtils::SetInt(op_desc, ATTR_NAME_STREAM_SWITCH_TYPE, switch_type);
GELOGD("Set attr ATTR_NAME_STREAM_SWITCH_TYPE for Stream_Switch %s, value is %ld.", node_name.c_str(),
switch_type);
}
if (!AttrUtils::SetInt(op_desc, ATTR_NAME_SWITCH_DATA_TYPE, RT_SWITCH_INT32) ||
!AttrUtils::SetInt(op_desc, ATTR_NAME_STREAM_SWITCH_COND, (int64_t)RT_EQUAL)) {
GELOGE(INTERNAL_ERROR, "set int failed");

@ -273,6 +273,7 @@ set(COMMON_SRC_FILES
"${GE_CODE_DIR}/ge/graph/passes/set_input_output_offset_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/remove_same_const_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/useless_control_out_remove_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/parallel_group_pass.cc"
"${GE_CODE_DIR}/ge/model/ge_model.cc"
"${GE_CODE_DIR}/ge/common/cust_aicpu_kernel_store.cc"
"${GE_CODE_DIR}/ge/graph/load/model_manager/model_utils.cc"
@ -518,6 +519,7 @@ set(GRAPH_PASS_COMMON_SRC_FILES
"${GE_CODE_DIR}/ge/graph/passes/compile_nodes_pass.cc"
"${GE_CODE_DIR}/ge/graph/common/transop_util.cc"
"${GE_CODE_DIR}/ge/graph/passes/flow_ctrl_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/parallel_group_pass.cc"
#"${GE_CODE_DIR}/ge/graph/optimize/optimizer/allreduce_fusion_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/folding_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/variable_op_pass.cc"
@ -695,6 +697,7 @@ set(PASS_TEST_FILES
"graph/passes/multi_batch_clone_pass_unittest.cc"
"graph/passes/replace_with_empty_const_pass_unittest.cc"
"graph/passes/transpose_transdata_pass_unittest.cc"
"graph/passes/parallel_group_pass_unittest.cc"
)
set(KERNEL_TEST_FILES

@ -32,6 +32,7 @@
#include "graph/compute_graph.h"
#include "graph/utils/attr_utils.h"
#include "graph/utils/graph_utils.h"
#include "graph/debug/ge_attr_define.h"
using namespace std;
@ -153,6 +154,22 @@ class UtestLogicalStreamAllocator : public testing::Test {
return CreateSubgraphWithName("graph", engine, stream_label, in_num, out_num);
}
SubGraphInfoPtr CreateParallelGroupSubgraphWithName(const string &name, const string &engine,
const string &stream_label = "",
std::string group_name = "1") {
ComputeGraphPtr compute_graph = make_shared<ComputeGraph>(name);
OpDescPtr op_desc = std::make_shared<OpDesc>("relu", "Relu");
op_desc->AddInputDesc(GeTensorDesc());
op_desc->AddOutputDesc(GeTensorDesc());
AttrUtils::SetStr(op_desc, ATTR_NAME_PARALLEL_GROUP, group_name);
compute_graph->AddNode(op_desc);
SubGraphInfoPtr subgraph = BuildSubGraph(compute_graph, engine, stream_label);
AddPlaceHolderAndEnd(subgraph, 1, 1);
return subgraph;
}
void LinkSubGraph(SubGraphInfoPtr subgraph1, const string &end_name, SubGraphInfoPtr subgraph2,
const string &placeholder_name) {
NodePtr end_node = subgraph1->GetSubGraph()->FindNode(end_name);
@ -878,4 +895,30 @@ TEST_F(UtestLogicalStreamAllocator, test_all_reduce_parallel_pass) {
EXPECT_EQ(ret, NOT_CHANGED);
}
TEST_F(UtestLogicalStreamAllocator, test_parallel_group) {
SubGraphInfoPtr data = CreateDataSubgraph();
SubGraphInfoPtr subgraph1 = CreateParallelGroupSubgraphWithName("graph1", "engine1", "");
SubGraphInfoPtr subgraph2 = CreateParallelGroupSubgraphWithName("graph2", "engine2", "", "2");
SubGraphInfoPtr subgraph3 = CreateParallelGroupSubgraphWithName("graph3", "engine3", "", "3");
SubGraphInfoPtr subgraph4 = CreateParallelGroupSubgraphWithName("graph4", "engine4", "", "4");
LinkSubGraph(data, "end", subgraph1, "placeholder");
LinkSubGraph(subgraph1, "end", subgraph2, "placeholder");
LinkSubGraph(subgraph2, "end", subgraph3, "placeholder");
LinkSubGraph(subgraph3, "end", subgraph4, "placeholder");
EngineConfPtr conf1 = make_shared<EngineConf>();
conf1->id = subgraph1->GetEngineName();
EngineConfPtr conf2 = make_shared<EngineConf>();
conf2->id = subgraph2->GetEngineName();
conf2->attach = false;
EngineConfPtr conf3 = make_shared<EngineConf>();
conf3->id = subgraph3->GetEngineName();
conf3->attach = false;
EngineConfPtr conf4 = make_shared<EngineConf>();
conf4->id = subgraph4->GetEngineName();
Status status = AssignLogicalStreams({subgraph1, subgraph2, subgraph3, subgraph4}, {conf1, conf2, conf3, conf4});
EXPECT_EQ(status, ge::SUCCESS);
}
} // namespace ge

File diff suppressed because it is too large Load Diff
Loading…
Cancel
Save