From d078cbfa99ceeef03c3ab8bd122f45da70d2f539 Mon Sep 17 00:00:00 2001 From: tronzhang <6517937+tronzhang@user.noreply.gitee.com> Date: Tue, 26 Jan 2021 10:25:29 +0800 Subject: [PATCH] support parallel fusion --- akg | 2 +- mindspore/_extends/graph_kernel/__init__.py | 3 +- .../_extends/graph_kernel/model/__init__.py | 3 +- .../graph_kernel/model/graph_parallel.py | 153 +++ .../graph_kernel/parallel_estimate.py | 49 + .../akg/akg_kernel_json_generator.cc | 55 +- .../akg/akg_kernel_json_generator.h | 14 +- .../graph_kernel/add_atomic_clean_gpu.cc | 6 +- .../optimizer/graph_kernel/depend_formater.cc | 155 ++++ .../optimizer/graph_kernel/depend_formater.h | 37 + .../graph_kernel/graph_kernel_helper.cc | 15 +- .../graph_kernel/graph_kernel_helper.h | 5 +- .../graph_kernel/parallel_cost_model.cc | 89 ++ .../graph_kernel/parallel_cost_model.h | 82 ++ .../optimizer/graph_kernel/parallel_fusion.cc | 876 ++++++++++++++++++ .../optimizer/graph_kernel/parallel_fusion.h | 122 +++ .../ccsrc/backend/session/gpu_session.cc | 6 +- mindspore/ccsrc/utils/utils.h | 3 +- .../graph_kernel/model/test_graph_parallel.py | 54 ++ 19 files changed, 1711 insertions(+), 18 deletions(-) create mode 100644 mindspore/_extends/graph_kernel/model/graph_parallel.py create mode 100644 mindspore/_extends/graph_kernel/parallel_estimate.py create mode 100644 mindspore/ccsrc/backend/optimizer/graph_kernel/depend_formater.cc create mode 100644 mindspore/ccsrc/backend/optimizer/graph_kernel/depend_formater.h create mode 100644 mindspore/ccsrc/backend/optimizer/graph_kernel/parallel_cost_model.cc create mode 100644 mindspore/ccsrc/backend/optimizer/graph_kernel/parallel_cost_model.h create mode 100644 mindspore/ccsrc/backend/optimizer/graph_kernel/parallel_fusion.cc create mode 100644 mindspore/ccsrc/backend/optimizer/graph_kernel/parallel_fusion.h create mode 100644 tests/st/graph_kernel/model/test_graph_parallel.py diff --git a/akg b/akg index 20ecddee01..c63b2e6f7e 160000 --- a/akg +++ b/akg @@ -1 +1 @@ -Subproject commit 20ecddee01cd07d0945240672597d7a36499e537 +Subproject commit c63b2e6f7e7704f18b217e42c8c5c0b95e04b9fb diff --git a/mindspore/_extends/graph_kernel/__init__.py b/mindspore/_extends/graph_kernel/__init__.py index 5ae3f3826f..44d79f4071 100644 --- a/mindspore/_extends/graph_kernel/__init__.py +++ b/mindspore/_extends/graph_kernel/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,3 +15,4 @@ """init""" from .splitter import split_with_json from .expander import get_op_expander +from .parallel_estimate import estimate_calulation_amount, estimate_ops diff --git a/mindspore/_extends/graph_kernel/model/__init__.py b/mindspore/_extends/graph_kernel/model/__init__.py index fb4d2b63a8..8125a8b18a 100644 --- a/mindspore/_extends/graph_kernel/model/__init__.py +++ b/mindspore/_extends/graph_kernel/model/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -16,3 +16,4 @@
 from .graph_split import split
 from .model_builder import GraphBuilder, load_composite
+from .graph_parallel import parallel_estimate
diff --git a/mindspore/_extends/graph_kernel/model/graph_parallel.py b/mindspore/_extends/graph_kernel/model/graph_parallel.py
new file mode 100644
index 0000000000..93360ae5df
--- /dev/null
+++ b/mindspore/_extends/graph_kernel/model/graph_parallel.py
@@ -0,0 +1,153 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ===========================================================================
+"""Cost model for parallel fusion"""
+from .model import PrimLib
+
+
+class ParalGain:
+    def __init__(self, fusion_type, bottleneck, gain, block_assign):
+        self.fusion_type = fusion_type
+        self.bottleneck = bottleneck
+        self.gain = gain
+        self.block_assign = block_assign
+
+
+class ScheduleAnalyzer:
+    """schedule analyzer"""
+    WRAP_SIZE = 32
+    MAX_SM = 80  # Volta
+    MAX_NUM_THREADS = 1024
+    MAX_BLOCK = 256
+
+    def __init__(self, graph):
+        self.graph = graph
+        self.block_num = 0
+        self.block_weight = 0
+        _, outputs = graph.deduce_parameters()
+        self.ops = graph.ops
+        self.dom_op = [out.op for out in outputs]
+
+    def prod(self, shape):
+        res = shape[0]
+        for i in range(1, len(shape)):
+            res = res * shape[i]
+        return res
+
+    def _cal_weight(self, ops):
+        weight = 0
+        for op in ops:
+            weight += self.prod(op.output.shape) * \
+                PrimLib.dtype_bytes(op.output.dtype)
+        return weight
+
+    def injective_analyze(self):
+        """analyze injective case"""
+        const_size = max([self.prod(op.output.shape) for op in self.dom_op])
+        const_size = (const_size + self.MAX_NUM_THREADS -
+                      1) // self.MAX_NUM_THREADS * self.MAX_NUM_THREADS
+
+        total_weight = self._cal_weight(self.ops)
+        total_block = (const_size + self.MAX_NUM_THREADS -
+                       1) // self.MAX_NUM_THREADS
+        need_block_split = const_size > self.MAX_BLOCK * self.MAX_NUM_THREADS
+        if need_block_split:
+            self.block_num = self.MAX_BLOCK
+            waves = (total_block + self.MAX_BLOCK - 1) // self.MAX_BLOCK
+            self.block_weight = total_weight // total_block * waves
+        else:
+            self.block_num = total_block
+            self.block_weight = total_weight // self.block_num
+
+    def reduce_analyze(self):
+        """analyze reduce case"""
+        thread_x, thread_y = 32, 32
+        reduce_op = None
+        for op in self.ops:
+            if PrimLib.iter_type(op) == PrimLib.REDUCE:
+                if reduce_op:
+                    raise RuntimeError(
+                        "Multiple reduce ops in one graph are not supported yet.")
+                reduce_op = op
+        if not reduce_op:
+            raise RuntimeError("Reduce analysis requires a reduce op, but none was found!")
+        shape = reduce_op.inputs[0].shape
+        reduce_axis = reduce_op.attrs['reduce_axis']
+        total_space = self.prod(shape)
+        red_space = shape[reduce_axis[0]]
+        for i in range(1, len(reduce_axis)):
+            red_space *= shape[reduce_axis[i]]
+        dtype_size = PrimLib.dtype_bytes(reduce_op.output.dtype)
+
+        weight = self._cal_weight(self.ops)  # reduce + injective
+        block_x = (total_space // red_space + thread_y - 1) // thread_y
+        block_w = (weight + block_x - 1) // block_x
+        waves = (block_x + self.MAX_BLOCK - 1) // self.MAX_BLOCK
+        self.block_num = min(self.MAX_BLOCK, block_x)
+        all_reduce = 10  # 1 reduce init + 3 sync + 5 bin + 1 write
+        self.block_weight = (block_w + all_reduce *
+                             dtype_size * thread_x * thread_y) * waves
+
+    def default_analyze(self):
+        """analyze default case"""
+        def _cal_default_space(op):
+            space = self.prod(op.output.shape)
+            for t in op.inputs:
+                size = self.prod(t.shape)
+                if size > space:
+                    space = size
+            return space
+        space = max([_cal_default_space(op) for op in self.dom_op])
+
+        # Each SM should run at least 4 warps.
+        block = (space + (self.WRAP_SIZE * 4) - 1) // (self.WRAP_SIZE * 4)
+        self.block_num = min(self.MAX_BLOCK, block)
+        self.block_weight = self._cal_weight(self.ops) // self.block_num
+
+    def analyze(self):
+        """analyze ops"""
+        def _ops_type(ops, dom_op):
+            have_reduce = any(
+                [PrimLib.iter_type(op) == PrimLib.REDUCE for op in ops])
+            if have_reduce:
+                return PrimLib.REDUCE
+            return PrimLib.iter_type(dom_op[0])
+
+        dom_type = _ops_type(self.ops, self.dom_op)
+        if dom_type in (PrimLib.ELEMWISE, PrimLib.BROADCAST):
+            self.injective_analyze()
+        elif dom_type == PrimLib.REDUCE:
+            self.reduce_analyze()
+        else:
+            self.default_analyze()
+
+
+def block_parallel_estimate(graphs):
+    """estimate block parallel gain"""
+    sum_block, max_weight, sum_weight, blocks = 0, 0, 0, []
+    for g in graphs:
+        s = ScheduleAnalyzer(g)
+        s.analyze()
+        sum_block += s.block_num
+        if s.block_weight > max_weight:
+            max_weight = s.block_weight
+        sum_weight += s.block_weight
+        blocks.append(s.block_num)
+    if sum_block > ScheduleAnalyzer.MAX_SM * 32:
+        return ParalGain("none", sum_weight, 0, [])
+    return ParalGain("block_fusion", max_weight, sum_weight - max_weight, blocks)
+
+
+def parallel_estimate(graphs):
+    return block_parallel_estimate(graphs)
diff --git a/mindspore/_extends/graph_kernel/parallel_estimate.py b/mindspore/_extends/graph_kernel/parallel_estimate.py
new file mode 100644
index 0000000000..593eb558e9
--- /dev/null
+++ b/mindspore/_extends/graph_kernel/parallel_estimate.py
@@ -0,0 +1,49 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""estimate parallel case"""
+import json
+import json.decoder as jd
+import traceback
+from mindspore import log as logger
+from .
import model + +def estimate_ops(json_str: str): + """Call costmodel to estimate ops.""" + try: + json_obj = json.loads(json_str) + graph_descs = json_obj["graph_desc"] + graphs = [] + for gd in graph_descs: + graphs.append(model.load_composite(gd).graph) + estimation = model.parallel_estimate(graphs) + if estimation.fusion_type == "block_fusion" and estimation.gain > 0: + res = (estimation.block_assign, estimation.gain) + else: + res = ([0 for g in graphs], 0) + return res + except jd.JSONDecodeError: + logger.error(traceback.format_exc()) + return None + +def estimate_calulation_amount(json_str: str): + """Call costmodel to estimate calculation amount of op.""" + try: + graph_desc = json.loads(json_str) + comp = model.load_composite(graph_desc) + estimation = model.parallel_estimate([comp.graph]) + return estimation.bottleneck + except jd.JSONDecodeError: + logger.error(traceback.format_exc()) + return None diff --git a/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_generator.cc b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_generator.cc index b56723277e..6d959adecc 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_generator.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_generator.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -120,7 +120,7 @@ class OpInfoExtractor { } } if (op_attr->type().empty()) { - MS_LOG(DEBUG) << "Unknow type, ignore attr " << name; + MS_LOG(DEBUG) << "Unknown type, ignore attr " << name; continue; } op_info->add_attrs_ptr(op_attr); @@ -174,7 +174,7 @@ bool AkgKernelJsonGenerator::CreateInputDescJson(const AnfNodePtr &anf_node, con // for dynamic input number, dyn_input_sizes has the info of dynamic input num for each input. 
auto inputs_ptr = op_info->inputs_ptr();
   if (inputs_ptr.empty()) {
-    MS_LOG(ERROR) << "Kernel [" << anf_node->fullname_with_scope() << "] regist info has no input info";
+    MS_LOG(ERROR) << "Kernel [" << anf_node->fullname_with_scope() << "] info has no input info";
     return false;
   }
@@ -184,7 +184,7 @@ bool AkgKernelJsonGenerator::CreateInputDescJson(const AnfNodePtr &anf_node, con
   for (size_t i = 0; i < inputs_ptr.size(); i++) {
     auto input_ptr = inputs_ptr[i];
     if (input_ptr == nullptr) {
-      MS_LOG(ERROR) << "Kernel [" << anf_node->fullname_with_scope() << "] regist input[" << i << "] is nullptr";
+      MS_LOG(ERROR) << "Kernel [" << anf_node->fullname_with_scope() << "] input[" << i << "] is nullptr";
       return false;
     }
@@ -204,7 +204,8 @@ bool AkgKernelJsonGenerator::CreateInputDescJson(const AnfNodePtr &anf_node, con
       input_desc_json[kJsonKeyName] = input_ptr->name();
       input_desc_json[kJsonKeyTensorName] = "input_" + std::to_string(GetInputTensorIdxInc(anf_node, real_input_index));
       auto input_shape = this->GetInputShape(anf_node, real_input_index);
-      if (AnfAlgo::IsNodeInGraphKernel(anf_node) && GetInputTensorValue(anf_node, real_input_index, &input_desc_json)) {
+      if (dump_option_.extract_opinfo_from_anfnode &&
+          GetInputTensorValue(anf_node, real_input_index, &input_desc_json)) {
         MS_LOG(DEBUG) << "Take input[" << real_input_index << "] of [" << anf_node->DebugString(2)
                       << "] as const tensor, shape: [" << Vector2Str(input_shape)
                       << "], value: " << input_desc_json[kJsonKeyValue];
@@ -555,6 +556,30 @@ bool AkgKernelJsonGenerator::CollectJson(const AnfNodePtr &anf_node, nlohmann::j
   return true;
 }
 
+void AkgKernelJsonGenerator::SetParallelValueToJson(const std::string &processor,
+                                                    const std::map<size_t, size_t> &dim_infos,
+                                                    nlohmann::json *sub_fusion_json) {
+  if (processor == kProcessorCuda) {
+    std::vector<size_t> cnums;
+    std::transform(dim_infos.cbegin(), dim_infos.cend(), std::back_insert_iterator(cnums),
+                   [](const std::pair<size_t, size_t> &dim) { return dim.second; });
+    (*sub_fusion_json)[kJsonKeyCoreNum] = cnums;
+  } else {
+    MS_LOG(EXCEPTION) << "Parallel fusion does not support " << processor << " now.";
+  }
+}
+
+void AkgKernelJsonGenerator::AddParalleFusionJsonInfo(const std::string &processor, nlohmann::json *kernel_json) {
+  nlohmann::json parallel_fusion_json;
+  parallel_fusion_json[kJsonKeyFusionType] = "block_fusion";
+  std::vector<std::vector<std::string>> sgraphs;
+  std::transform(sub_graphs_.cbegin(), sub_graphs_.cend(), std::back_insert_iterator(sgraphs),
+                 [](const std::pair<size_t, std::vector<std::string>> &sg) { return sg.second; });
+  parallel_fusion_json[kJsonKeySubGraph] = sgraphs;
+  SetParallelValueToJson(processor, dim_infos_, &parallel_fusion_json);
+  (*kernel_json)[kJsonKeyParallelFusion] = parallel_fusion_json;
+}
+
 bool AkgKernelJsonGenerator::CollectFusedJson(const std::vector<AnfNodePtr> &anf_nodes,
                                               const std::vector<AnfNodePtr> &input_list,
                                               const std::vector<AnfNodePtr> &output_list, nlohmann::json *kernel_json) {
@@ -581,6 +606,13 @@ bool AkgKernelJsonGenerator::CollectFusedJson(const std::vector<AnfNodePtr> &anf
   (*kernel_json)[kJsonKeyOutputDesc] =
     CreateOutputsJson(anf_nodes, input_list, output_list, inputs_json, node_json_map);
 
+  auto processor = GetProcessorStr(anf_nodes[0]);
+
+  // Add parallel fusion information.
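+  // sub_graphs_ is filled by CreateOutputsJson from the kAttrParallelDimInfo attribute on the
+  // output nodes, so it is non-empty only for composite kernels produced by parallel fusion.
+  // The added section then looks roughly like this sketch (field values are examples only):
+  //   "parallel_fusion": {"fusion_type": "block_fusion",
+  //                       "sub_graph": [["output_0_0"], ["output_1_0"]],
+  //                       "core_num": [8, 16]}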
+ if (!sub_graphs_.empty()) { + AddParalleFusionJsonInfo(processor, kernel_json); + } + size_t hash_id = std::hash()(kernel_json->dump()); kernel_name_ = "Fused_"; auto fg = anf_nodes[0]->func_graph(); @@ -601,7 +633,7 @@ bool AkgKernelJsonGenerator::CollectFusedJson(const std::vector &anf (*kernel_json)[kJsonKeyId] = GetOpCntInc(); (*kernel_json)[kJsonKeyOp] = kernel_name_; (*kernel_json)[kJsonKeyPlatform] = "AKG"; - (*kernel_json)[kJsonKeyProcess] = GetProcessorStr(anf_nodes[0]); + (*kernel_json)[kJsonKeyProcess] = processor; (*kernel_json)[kJsonKeyComposite] = true; (*kernel_json)[kJsonKeyCompositeGraph] = fg->ToString() + "." + fg->debug_info()->get_id(); @@ -724,6 +756,17 @@ nlohmann::json AkgKernelJsonGenerator::CreateOutputsJson(const std::vectorcast(); + tcnode && AnfAlgo::HasNodeAttr(kAttrParallelDimInfo, tcnode)) { + auto info = AnfAlgo::GetNodeAttr>(tcnode, kAttrParallelDimInfo); + if (info.size() != 2) { + MS_LOG(EXCEPTION) << "Parallel dim info is invalid!"; + } + sub_graphs_[info[0]].push_back(output_desc_json[kJsonKeyTensorName]); + if (dim_infos_.find(info[0]) == dim_infos_.end()) { + dim_infos_[info[0]] = info[1]; + } + } } outputs_json.emplace_back(output_desc_json); } diff --git a/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_generator.h b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_generator.h index dde8cdbd32..9f6f49cccc 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_generator.h +++ b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_generator.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -49,6 +49,11 @@ constexpr auto kJsonKeyPtrAddress = "ptr_address";
 constexpr auto kJsonKeyCompositeGraph = "composite_graph";
 constexpr auto kJsonKeyPlatform = "platform";
 constexpr auto kJsonKeyOpFullName = "op_full_name";
+constexpr auto kJsonKeyFusion = "fusion";
+constexpr auto kJsonKeyParallelFusion = "parallel_fusion";
+constexpr auto kJsonKeyFusionType = "fusion_type";
+constexpr auto kJsonKeySubGraph = "sub_graph";
+constexpr auto kJsonKeyCoreNum = "core_num";
 
 constexpr auto kAttrInputNames = "input_names";
@@ -81,6 +86,8 @@ class AkgKernelJsonGenerator {
     input_tensor_idx_.clear();
     address_node_map_.clear();
     output_tensor_idx_ = 0;
+    sub_graphs_.clear();
+    dim_infos_.clear();
   }
   void set_dump_option(DumpOption dump_option) { dump_option_ = dump_option; }
   std::map<std::string, AnfNodePtr> address_node_map() { return address_node_map_; }
@@ -115,6 +122,9 @@ class AkgKernelJsonGenerator {
   std::string GetOutputFormat(const AnfNodePtr &anf_node, size_t index);
   void SaveNodeAddress(const AnfNodePtr &anf_node, nlohmann::json *node_json);
   OpInfoPtr ExtractOpInfo(const AnfNodePtr &anf_node);
+  void SetParallelValueToJson(const std::string &processor, const std::map<size_t, size_t> &dim_infos,
+                              nlohmann::json *sub_fusion_json);
+  void AddParalleFusionJsonInfo(const std::string &processor, nlohmann::json *kernel_json);
 
   DumpOption dump_option_;
   static int op_cnt_;
@@ -127,6 +137,8 @@ class AkgKernelJsonGenerator {
   std::vector<size_t> input_size_list_;
   std::vector<size_t> output_size_list_;
   std::map<std::string, AnfNodePtr> address_node_map_;
+  std::map<size_t, std::vector<std::string>> sub_graphs_;
+  std::map<size_t, size_t> dim_infos_;
 };
 }  // namespace kernel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/add_atomic_clean_gpu.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/add_atomic_clean_gpu.cc
index 2b73ff8840..ea9fa60ebc 100644
--- a/mindspore/ccsrc/backend/optimizer/graph_kernel/add_atomic_clean_gpu.cc
+++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/add_atomic_clean_gpu.cc
@@ -133,8 +133,10 @@ bool AtomicCleanInsertter::CanActivateAtomicAdd(const AnfNodePtr &anf_node) {
     if (reduce_cnt != 1) {
       return false;
     }
+    real_output_num_ = inputs.size() - 1;
   } else if (IsPrimitiveCNode(real_return_node, prim::kPrimReduceSum)) {
     atomic_add_node_ = real_return_node->cast<CNodePtr>();
+    real_output_num_ = 1;
   } else {
     return false;
   }
@@ -200,7 +202,6 @@ void AtomicCleanInsertter::CreateInplaceAssignNodeAndCorrectReturn(const FuncGra
   auto retrun_node = sub_graph->get_return()->input(kFirstDataInputIndex);
   if (IsPrimitiveCNode(retrun_node, prim::kPrimMakeTuple)) {
     const auto &outs = retrun_node->cast<CNodePtr>()->inputs();
-    real_output_num_ = outs.size() - 1;
     for (size_t i = 1; i < outs.size(); ++i) {
       if (i != reduce_real_output_index_ + 1) {
         out_node = outs[i];
@@ -209,7 +210,6 @@ void AtomicCleanInsertter::CreateInplaceAssignNodeAndCorrectReturn(const FuncGra
       }
     }
   } else {
-    real_output_num_ = 1;
     out_node = atomic_add_node_;  // Use result data itself, and set attr "fake_out" true.
     fake_out = true;
   }
@@ -456,7 +456,7 @@ std::vector<std::pair<AnfNodePtr, int> > AtomicCleanInsertter::FindOriginCNodeUs
     }
   }
   for (auto &pair : getitem_user_nodes) {
-    // dirctory to find real user.
+    // Directly find the real users.
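+    // The users of each getitem node are the real consumers of the reduce output, so collect
+    // them instead of the getitem itself when re-linking users to the atomic-add result.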
auto real_users = mng->node_users()[pair.first]; reduce_user_nodes.insert(reduce_user_nodes.end(), real_users.begin(), real_users.end()); } diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/depend_formater.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/depend_formater.cc new file mode 100644 index 0000000000..ef3429e451 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/depend_formater.cc @@ -0,0 +1,155 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/graph_kernel/depend_formater.h" +#include +#include +#include +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/kernel_compiler/common_utils.h" +#include "backend/optimizer/graph_kernel/graph_kernel_helper.h" + +namespace mindspore { +namespace opt { +namespace { +bool RemoveRedundantDepend(const AnfNodePtr &node, const FuncGraphManagerPtr &mng) { + const auto &users = mng->node_users()[node]; + std::vector> sons; + for (const auto &[user, index] : users) { + if (!IsPrimitiveCNode(user, prim::kPrimTupleGetItem)) { + sons.emplace_back(user, index); + continue; + } + auto &[fake_first_grad_son, grad_index] = *((mng->node_users()[user]).begin()); + sons.emplace_back(fake_first_grad_son, grad_index); + } + + AnfNodePtrList latter_to_delete; + for (const auto &[son, index] : sons) { + if (!IsPrimitiveCNode(son, prim::kPrimDepend) || index != kDependAttachNodeIndex) { + continue; + } + + latter_to_delete.push_back(son); + } + + if (latter_to_delete.empty()) { + return false; + } + + std::vector::iterator delete_begin = latter_to_delete.begin(); + if (latter_to_delete.size() == sons.size()) { + // Left one Depend node relation and delete others! 
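+    // Every user reaches this node only through a Depend attachment; keep the first Depend to
+    // anchor the node in the graph, and bypass the remaining redundant ones below.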
+ ++delete_begin; + } + for (; delete_begin != latter_to_delete.end(); ++delete_begin) { + auto depend_anfnode = *delete_begin; + auto depend_cnode = depend_anfnode->cast(); + MS_EXCEPTION_IF_NULL(depend_cnode); + auto depend_prior_node = depend_cnode->input(kRealInputIndexInDepend); + mng->Replace(depend_anfnode, depend_prior_node); + } + return true; +} + +AnfNodePtr FindPatronNode(const FuncGraphPtr &main_graph, const FuncGraphManagerPtr &mng) { + AnfNodePtr patron_node; + + auto return_cnode = main_graph->get_return()->cast(); + MS_EXCEPTION_IF_NULL(return_cnode); + auto output_node = return_cnode->input(kFirstDataInputIndex); + if (IsPrimitiveCNode(output_node, prim::kPrimMakeTuple)) { + auto output_cnode = output_node->cast(); + MS_EXCEPTION_IF_NULL(output_cnode); + patron_node = output_cnode->input(kFirstDataInputIndex); + } else { + patron_node = output_node; + } + + return patron_node; +} + +void AddDepends(const AnfNodePtr &stable_node, const AnfNodePtrList &free_nodes, const FuncGraphPtr &main_graph, + const FuncGraphManagerPtr &mng) { + AnfNodePtr modified_node = stable_node; + for (const auto &free_node : free_nodes) { + AnfNodePtrList d_inputs = {NewValueNode(prim::kPrimDepend), modified_node, free_node}; + auto depend_cnode = main_graph->NewCNode(d_inputs); + depend_cnode->set_abstract(modified_node->abstract()); + main_graph->AddNode(depend_cnode); + modified_node = depend_cnode; + } + + if (!free_nodes.empty()) { + mng->Replace(stable_node, modified_node); + } +} +} // namespace + +bool DependFormater::Run(const FuncGraphPtr &func_graph) { + MS_EXCEPTION_IF_NULL(func_graph); + auto mng = func_graph->manager(); + if (mng == nullptr) { + mng = Manage(func_graph, true); + func_graph->set_manager(mng); + } + + // 1. Try to remove redundant depend. + bool changed = false; + auto nodes = TopoSort(func_graph->get_return()); + std::for_each(nodes.rbegin(), nodes.rend(), [&changed, &mng](const AnfNodePtr &node) { + if (RemoveRedundantDepend(node, mng)) { + changed = true; + } + }); + + // Should re-toposort for changed graph. + if (changed) { + nodes = TopoSort(func_graph->get_return()); + } + + // 2. Move depend to tail of graph. + AnfNodePtrList old_depends; + AnfNodePtrList free_nodes; + + // Find depend and its free nodes. + for (const auto &node : nodes) { + if (!IsPrimitiveCNode(node, prim::kPrimDepend)) { + continue; + } + + old_depends.push_back(node); + free_nodes.push_back(node->cast()->input(kDependAttachNodeIndex)); + } + + if (old_depends.empty()) { + return changed; + } + + // Delete old depend. + for (const auto &depend_anfnode : old_depends) { + auto depend_cnode = depend_anfnode->cast(); + MS_EXCEPTION_IF_NULL(depend_cnode); + auto depend_prior_node = depend_cnode->input(kControlDependPriorIndex); + mng->Replace(depend_anfnode, depend_prior_node); + } + + // Add new depend node in tail. + AnfNodePtr patron_node = FindPatronNode(func_graph, mng); + AddDepends(patron_node, free_nodes, func_graph, mng); + return true; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/depend_formater.h b/mindspore/ccsrc/backend/optimizer/graph_kernel/depend_formater.h new file mode 100644 index 0000000000..23aec65aeb --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/depend_formater.h @@ -0,0 +1,37 @@ + +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_DEPEND_FORMATER_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_DEPEND_FORMATER_H_ + +#include +#include + +#include "backend/optimizer/common/pass.h" +#include "ir/func_graph.h" + +namespace mindspore { +namespace opt { +class DependFormater : public Pass { + public: + DependFormater() : Pass("depend_formater") {} + ~DependFormater() override = default; + bool Run(const FuncGraphPtr &graph) override; +}; +using DependFormaterPtr = std::shared_ptr; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_DEPEND_FORMATER_H_ diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc index 3f4948e9ce..5134c2269f 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc @@ -274,7 +274,7 @@ bool ConvertNonscalarTensorToParameter(const FuncGraphPtr &fg, AnfNodePtrList *i MS_EXCEPTION_IF_NULL(inputs_ptr); auto nodes = TopoSort(fg->get_return()); - std::map vmap; + OrderedMap vmap; for (const auto &node : nodes) { if (!node->isa()) { continue; @@ -590,7 +590,7 @@ bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, n op_nodes = nodes; } else { // When there are basic and composite ops, the composite ops should be inline to the basic ones' graph, - // so a new graph generation should be done (beacuse they may in the main graph!). + // so a new graph generation should be done (because they may in the main graph!). // If address_node_map is wanted, we should map the new node in new graph to the old nodes. But... not support now. MS_LOG(EXCEPTION) << "No support mixed with basic and composite ops now!"; } @@ -1016,5 +1016,16 @@ CNodePtr CreateCNode(const std::vector &inputs, const FuncGraphPtr & func_graph->AddNode(cnode); return cnode; } + +void MakeCNodeSafeForAttr(const AnfNodePtr &node) { + auto cnode = node->cast(); + if (cnode == nullptr) { + return; + } + AnfNodePtrList new_inputs = {NewValueNode(AnfAlgo::GetCNodePrimitive(cnode)->Clone())}; + auto inputs = cnode->inputs(); + new_inputs.insert(new_inputs.end(), inputs.begin() + 1, inputs.end()); + cnode->set_inputs(new_inputs); +} } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.h b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.h index d04d4c5aae..695d8bcb22 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.h +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -42,6 +42,8 @@ using kernel::DumpOption; constexpr auto kIsFeatureMapOutput = "IsFeatureMapOutput"; constexpr auto kIsFeatureMapInputList = "IsFeatureMapInputList"; constexpr auto kGraphKernelModule = "mindspore._extends.graph_kernel"; +constexpr auto kGraphKernelEstimateOps = "estimate_ops"; +constexpr auto kGraphKernelGetNodeCalAmount = "estimate_calulation_amount"; constexpr auto kGraphKernelSplitFunc = "split_with_json"; constexpr auto kGetGraphKernelOpExpander = "get_op_expander"; constexpr auto kJsonKeyMultiGraph = "multi_graph"; @@ -88,6 +90,7 @@ ShapeVector GetShape(const AnfNodePtr &node); std::vector GetReduceAxis(const AnfNodePtr &node); CNodePtr CreateCNode(const std::vector &inputs, const FuncGraphPtr &func_graph, const DataInfo &out_info); +void MakeCNodeSafeForAttr(const AnfNodePtr &node); template ValueNodePtr CreateScalarTensorValueNode(const DataInfo &info, T value, size_t data_length) { diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/parallel_cost_model.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/parallel_cost_model.cc new file mode 100644 index 0000000000..9a52d05699 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/parallel_cost_model.cc @@ -0,0 +1,89 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/optimizer/graph_kernel/parallel_cost_model.h" + +#include + +#include "backend/kernel_compiler/akg/akg_kernel_json_generator.h" +#include "backend/optimizer/graph_kernel/graph_kernel_helper.h" +#include "pipeline/jit/parse/python_adapter.h" + +namespace mindspore { +namespace opt { +std::string CommonDimInfo::ToString() { + std::ostringstream buffer; + buffer << "Dim(" << dim_info_ << ")"; + return buffer.str(); +} + +int ParallelCostModel::GetNodeCalAmount(const AnfNodePtr &node) { + nlohmann::json json_desc; + AnfNodePtrList nodes = {node}; + DumpOption dump_option; + if (!AnfToJsonDesc(nodes, dump_option, &json_desc)) { + MS_LOG(EXCEPTION) << "Collect json desc failed."; + } + + auto json_desc_str = json_desc.dump(); + auto ret = parse::python_adapter::CallPyFn(kGraphKernelModule, kGraphKernelGetNodeCalAmount, json_desc_str); + if (py::isinstance(ret)) { + MS_LOG(EXCEPTION) << "CallPyFn: [" << kGraphKernelSplitFunc << "] return invalid result. 
input json:\n" + << json_desc_str; + } + return py::cast(ret); +} + +std::tuple, int> ParallelCostModel::CalFuseInfo(const AnfNodePtrList &nodes) { + nlohmann::json json_desc; + std::vector graphs; + std::transform(nodes.begin(), nodes.end(), std::back_inserter(graphs), + [](const AnfNodePtr &node) -> AnfNodePtrList { return {node}; }); + DumpOption dump_option; + if (!AnfToJsonDesc(graphs, dump_option, &json_desc)) { + MS_LOG(EXCEPTION) << "Collect json desc failed."; + } + + auto json_desc_str = json_desc.dump(); + auto ret = parse::python_adapter::CallPyFn(kGraphKernelModule, kGraphKernelEstimateOps, json_desc_str); + if (py::isinstance(ret)) { + MS_LOG(EXCEPTION) << "CallPyFn: [" << kGraphKernelSplitFunc << "] return invalid result. input json:\n" + << json_desc_str; + } + + py::tuple ret_tuple = py::cast(ret); + if (!py::isinstance(ret_tuple) || ret_tuple.size() != 2) { + MS_LOG(EXCEPTION) << "Parallel cost model should return a tuple with two elements!"; + } + + std::vector dim_infos; + py::list dim_list = py::cast(ret_tuple[0]); + for (size_t i = 0; i < dim_list.size(); ++i) { + dim_infos.push_back(std::make_shared(py::cast(dim_list[i]))); + } + int benefit = py::cast(ret_tuple[1]); + + return std::make_tuple(dim_infos, benefit); +} + +ParallelCostModelPtr ParellelCostModelWarehouse::GetParallelCostModel(const std::string &target) { + if (target != kGPUDevice) { + MS_LOG(EXCEPTION) << "Parallel cost model only support " << kGPUDevice << " now."; + } + return cost_model_; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/parallel_cost_model.h b/mindspore/ccsrc/backend/optimizer/graph_kernel/parallel_cost_model.h new file mode 100644 index 0000000000..a5dd442b3f --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/parallel_cost_model.h @@ -0,0 +1,82 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_PARALLEL_COST_MODEL_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_PARALLEL_COST_MODEL_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "base/base.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/optimizer/common/optimizer.h" +#include "backend/optimizer/graph_kernel/parallel_cost_model.h" +#include "backend/session/kernel_graph.h" +#include "utils/ms_context.h" + +namespace mindspore { +namespace opt { +class DimInfo { + public: + DimInfo() = default; + ~DimInfo() {} + virtual std::string ToString() = 0; +}; + +class CommonDimInfo : public DimInfo { + public: + explicit CommonDimInfo(size_t dim) : dim_info_(dim) {} + ~CommonDimInfo() {} + void set_dim_info(size_t d) { dim_info_ = d; } + size_t dim_info() const { return dim_info_; } + std::string ToString() override; + + private: + size_t dim_info_; +}; + +using DimInfoPtr = std::shared_ptr; +using CommonDimInfoPtr = std::shared_ptr; + +class ParallelCostModel { + public: + ParallelCostModel() {} + ~ParallelCostModel() {} + int GetNodeCalAmount(const AnfNodePtr &node); + std::tuple, int> CalFuseInfo(const AnfNodePtrList &nodes); +}; + +using ParallelCostModelPtr = std::shared_ptr; + +class ParellelCostModelWarehouse { + public: + static ParellelCostModelWarehouse &Instance() { + static ParellelCostModelWarehouse instance; + return instance; + } + ParallelCostModelPtr GetParallelCostModel(const std::string &target); + + private: + ParellelCostModelWarehouse() { cost_model_ = std::make_shared(); } + ParallelCostModelPtr cost_model_; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_PARALLEL_COST_MODEL_H_ diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/parallel_fusion.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/parallel_fusion.cc new file mode 100644 index 0000000000..d6389d8ecf --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/parallel_fusion.cc @@ -0,0 +1,876 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "backend/optimizer/graph_kernel/parallel_fusion.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "backend/optimizer/graph_kernel/graph_kernel_helper.h" +#include "frontend/operator/ops.h" +#include "ir/func_graph_cloner.h" +#include "vm/segment_runner.h" + +namespace mindspore { +namespace opt { +namespace { +bool IsOneOf(const AnfNodePtr &node, const std::vector &ops_prim) { + return std::any_of(ops_prim.cbegin(), ops_prim.cend(), + [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); }); +} + +void ProcessThroughPassCNode(std::function pass_fn, + OrderedMap *node_rels) { + std::set latter_to_be_erased; + for (const auto &[node, node_rel] : (*node_rels)) { + if (!pass_fn(node) || latter_to_be_erased.count(node) != 0) { + continue; + } + + auto nexts = node_rel.nexts; + std::vector pre_nodes; + std::queue node_que; + node_que.push(node); + + // Find until all pre nodes get false from pass_fn, and collect all these predecessor nodes. + while (!node_que.empty()) { + auto cur_node = node_que.front(); + node_que.pop(); + + if (!pass_fn(cur_node)) { + pre_nodes.push_back(cur_node); + continue; + } + + latter_to_be_erased.insert(cur_node); + auto predecessors = (*node_rels)[cur_node].pres; + if (predecessors.empty()) { + continue; + } + + for (const auto &pre_node : predecessors) { + (*node_rels)[cur_node].pres.erase(pre_node); + (*node_rels)[pre_node].nexts.erase(cur_node); + node_que.push(pre_node); + } + } + + // Modify the relation: delete node <-> next_node, add pre node <-> next_node. + for (const auto &next_node : nexts) { + (*node_rels)[next_node].pres.erase(node); + for (const auto &cur_node : pre_nodes) { + (*node_rels)[next_node].pres.insert(cur_node); + (*node_rels)[cur_node].nexts.insert(next_node); + } + } + } + + for (const auto &node : latter_to_be_erased) { + node_rels->erase(node); + } +} + +void ProcessDependCNode(OrderedMap *node_rels) { + for (auto &[node, node_rel] : (*node_rels)) { + if (!IsPrimitiveCNode(node, prim::kPrimDepend)) { + continue; + } + + // Make attached nodes deattach with node. + auto cnode = node->cast(); + for (size_t id = kDependAttachNodeIndex; id < cnode->inputs().size(); ++id) { + auto attach_node = cnode->input(id); + if (auto iter = node_rels->find(attach_node); iter != node_rels->end()) { + iter->second.nexts.erase(node); + } + if (auto &cnode_pres = node_rel.pres; cnode_pres.count(attach_node) != 0) { + cnode_pres.erase(attach_node); + } + } + } + + // Eliminate depend node of node relations. + ProcessThroughPassCNode([](const AnfNodePtr &node) { return IsOneOf(node, {prim::kPrimDepend}); }, node_rels); +} + +std::tuple, std::pair> FindRelationOfControlDepend( + const AnfNodePtr &node, OrderedMap *node_rels) { + auto cnode = node->cast(); + auto prior_node = cnode->input(kControlDependPriorIndex); + auto behind_node = cnode->input(kControlDependBehindIndex); + MS_EXCEPTION_IF_NULL(prior_node); + MS_EXCEPTION_IF_NULL(behind_node); + + OrderedSet prior_nodes; + prior_nodes.insert(prior_node); + OrderedSet behind_nodes; + behind_nodes.insert(behind_node); + + int64_t depend_mode = 0; + if (AnfAlgo::HasNodeAttr(kControlDependMode, cnode)) { + depend_mode = AnfAlgo::GetNodeAttr(cnode, kControlDependMode); + } + if (prior_node->isa() && depend_mode == 1) { + prior_nodes = (*node_rels)[prior_node].nexts; + } + if (behind_node->isa()) { + behind_nodes = depend_mode == 1 ? 
(*node_rels)[behind_node].nexts : OrderedSet(); + } + + // Get real nodes. + AnfNodePtrList real_prior_nodes; + std::set prior_visited; + for (const auto &tmp : prior_nodes) { + AnfAlgo::GetAllFatherRealNode(tmp, &real_prior_nodes, &prior_visited); + } + AnfNodePtrList real_behind_nodes; + std::set behind_visited; + for (const auto &tmp : behind_nodes) { + AnfAlgo::GetAllFatherRealNode(tmp, &real_behind_nodes, &behind_visited); + } + + return std::make_tuple(std::make_pair(prior_node, behind_node), std::make_pair(real_prior_nodes, real_behind_nodes)); +} + +void ReLinkNodesOfControlDependByRelation(const std::unordered_map &control_depend_info, + OrderedMap *node_rels) { + // Relink and its log. + for (const auto &m : control_depend_info) { + const auto &prior = m.second[0]; + const auto &behind = m.second[1]; + (*node_rels)[prior].nexts.insert(behind); + (*node_rels)[behind].pres.insert(prior); + MS_LOG(DEBUG) << "Relink relation of " << m.first->fullname_with_scope() << ": " << prior->fullname_with_scope() + << " -> " << behind->fullname_with_scope(); + } +} + +void ProcessControlDependCNode(OrderedMap *node_rels) { + std::unordered_map control_depend_info; + AnfNodePtrList latter_to_be_erased; + + // Collect ControlDepend node and its input and output nodes. + for (auto &[node, node_rel] : (*node_rels)) { + if (!IsPrimitiveCNode(node, prim::kPrimControlDepend)) { + continue; + } + + auto [direct_relation, real_relations] = FindRelationOfControlDepend(node, node_rels); + auto &[prior_node, behind_node] = direct_relation; + auto &[real_prior_nodes, real_behind_nodes] = real_relations; + + (*node_rels)[prior_node].nexts.erase(node); + (*node_rels)[behind_node].nexts.erase(node); + node_rel.pres.erase(prior_node); + node_rel.pres.erase(behind_node); + + for (auto &first_node : real_prior_nodes) { + for (auto &second_node : real_behind_nodes) { + MS_EXCEPTION_IF_NULL(first_node); + MS_EXCEPTION_IF_NULL(second_node); + control_depend_info.insert({node, {first_node, second_node}}); + } + } + latter_to_be_erased.push_back(node); + } + + // Delete ControlDepend node before relink its relation. + for (const auto &node : latter_to_be_erased) { + node_rels->erase(node); + } + + // Rebuild relation between prior and behind node. + ReLinkNodesOfControlDependByRelation(control_depend_info, node_rels); +} + +void ProcessTailMakeTupleCNode(OrderedMap *node_rels) { + AnfNodePtrList latter_to_be_erased; + for (auto &[node, node_rel] : (*node_rels)) { + if (!IsPrimitiveCNode(node, prim::kPrimMakeTuple)) { + continue; + } + + AnfNodePtrList check_next_list; + check_next_list.push_back(node); + + bool disinterested = false; + for (auto &successor : node_rel.nexts) { + if (!IsPrimitiveCNode(successor, prim::kPrimTupleGetItem)) { + disinterested = true; + break; + } + check_next_list.push_back(successor); + } + if (disinterested) { + continue; + } + + if (!std::all_of(check_next_list.cbegin(), check_next_list.cend(), + [&node_rels](const AnfNodePtr &n) -> bool { return (*node_rels)[n].nexts.empty(); })) { + continue; + } + + latter_to_be_erased.push_back(node); + } + + // Delete Tail MakeTuple(including its getitem nodes). + for (const auto &node : latter_to_be_erased) { + for (auto &pre : (*node_rels)[node].pres) { + (*node_rels)[pre].nexts.erase(node); + } + + // Tail MakeTuple is just be consumed by nothing or invalid getitem node. 
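+    // The check above guarantees these getitem successors have no users of their own, so they
+    // can be erased together with the tail MakeTuple without losing any relation.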
+ for (auto &getitem : (*node_rels)[node].nexts) { + node_rels->erase(getitem); + } + + node_rels->erase(node); + } +} + +bool IsSingleInputNode(const OrderedMap &node_rels, const AnfNodePtr &node) { + if (auto iter = node_rels.find(node); iter != node_rels.end() && iter->second.pres.size() == 1) { + return true; + } + return false; +} + +bool IsSingleOutputNode(const OrderedMap &node_rels, const AnfNodePtr &node) { + if (auto iter = node_rels.find(node); iter != node_rels.end() && iter->second.nexts.size() == 1) { + return true; + } + return false; +} + +bool IsMultiInputsNode(const OrderedMap &node_rels, const AnfNodePtr &node) { + if (auto iter = node_rels.find(node); iter != node_rels.end() && iter->second.pres.size() > 1) { + return true; + } + return false; +} + +bool IsMultiOutputsNode(const OrderedMap &node_rels, const AnfNodePtr &node) { + if (auto iter = node_rels.find(node); iter != node_rels.end() && iter->second.nexts.size() > 1) { + return true; + } + return false; +} + +bool IsNoInputsNode(const OrderedMap &node_rels, const AnfNodePtr &node) { + if (auto iter = node_rels.find(node); iter != node_rels.end() && iter->second.pres.size() == 0) { + return true; + } + return false; +} + +bool IsNoOutputsNode(const OrderedMap &node_rels, const AnfNodePtr &node) { + if (auto iter = node_rels.find(node); iter != node_rels.end() && iter->second.nexts.size() == 0) { + return true; + } + return false; +} + +void ProcessLocalStructure(OrderedMap *node_rels, std::set *virtual_noout_nodes, + std::set *ignore_noin_nodes) { + // 1. Local relation + // Graph as following left part, relation D->B and D->E(D is a no input node) + // will make B and E to be multiply inputs node. + // But for parallel, this local relation can ignore for B and E, which make + // them be able to be paralleled. + // + // ************************************ + // * * + // * | | * + // * A D A D * + // * | /| | / \ * + // * | C | | C F * + // * |/ / | | | * + // * B F ====> B x x * + // * | / | * + // * |/ | * + // * E E * + // * | | * + // * * + // ************************************ + AnfNodePtrList no_input_nodes; + for (const auto &node_rel : *node_rels) { + auto &node = node_rel.first; + if (IsNoInputsNode(*node_rels, node)) { + no_input_nodes.push_back(node); + } + } + + std::vector> latter_delete; + + for (const auto &ninode : no_input_nodes) { + AnfNodePtrList cnexts((*node_rels)[ninode].nexts.begin(), (*node_rels)[ninode].nexts.end()); + for (const auto &n : cnexts) { + AnfNodePtr serial_tail = ninode; + AnfNodePtr cur_node = n; + while (IsSingleInputNode(*node_rels, cur_node) && IsSingleOutputNode(*node_rels, cur_node)) { + serial_tail = cur_node; + cur_node = *((*node_rels)[cur_node].nexts.begin()); + } + latter_delete.emplace_back(serial_tail, cur_node); + } + } + + // Delete relation. 
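+  // Cut the edge between each serial chain's tail and its multi-input consumer, and record both
+  // endpoints so the later interest-node analysis does not mistake them for real
+  // no-output/no-input nodes.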
+ for (const auto &[serial_tail, cur_node] : latter_delete) { + virtual_noout_nodes->insert(serial_tail); + ignore_noin_nodes->insert(cur_node); + (*node_rels)[serial_tail].nexts.erase(cur_node); + (*node_rels)[cur_node].pres.erase(serial_tail); + MS_LOG(INFO) << "Process local relation delete relation: " << serial_tail->fullname_with_scope() << " -> " + << cur_node->fullname_with_scope(); + } +} + +std::tuple GetInterestNodeIds( + const OrderedMap &node_rels, const std::set &virtual_noout_nodes, + const std::set &ignore_noin_nodes) { + AnfNodePtrList multi_inputs_nodes, multi_outputs_nodes, no_input_nodes, no_output_nodes; + std::list> func_list = { + [&node_rels, &multi_inputs_nodes](const AnfNodePtr &node) { + if (IsMultiInputsNode(node_rels, node)) { + multi_inputs_nodes.push_back(node); + } + }, + [&node_rels, &multi_outputs_nodes](const AnfNodePtr &node) { + if (IsMultiOutputsNode(node_rels, node)) { + multi_outputs_nodes.push_back(node); + } + }, + [&node_rels, &no_input_nodes, &ignore_noin_nodes](const AnfNodePtr &node) { + if (IsNoInputsNode(node_rels, node) && ignore_noin_nodes.count(node) == 0) { + no_input_nodes.push_back(node); + } + }, + [&node_rels, &no_output_nodes, &virtual_noout_nodes](const AnfNodePtr &node) { + if (IsNoOutputsNode(node_rels, node) && virtual_noout_nodes.count(node) == 0) { + no_output_nodes.push_back(node); + } + }}; + + for (const auto &node_rel : node_rels) { + for (const auto &func : func_list) { + func(node_rel.first); + } + } + + return std::make_tuple(multi_inputs_nodes, multi_outputs_nodes, no_input_nodes, no_output_nodes); +} + +bool WhiteOpsFilter(const AnfNodePtr &node) { + std::vector whiteable_ops = {}; // Not special for now. + return session::AnfRuntimeAlgorithm::IsGraphKernel(node) || IsOneOf(node, whiteable_ops); +} + +std::vector SearchFromNodes(const AnfNodePtrList &nodes, + const std::function &filter_func, + const OrderedMap &node_rels, bool is_backward, + std::set *seen) { + // Start from multi-inputs node, stop on seen node or multi-inputs or multi-outputs nodes. + // For backward search, the other multi-inputs node can be contained in. + // For forward search, the other multi-outputs node can be contained in. + auto get_contain_node_set = is_backward ? [](const NodeRelation &info) { return info.pres; } + : [](const NodeRelation &info) { return info.nexts; }; + auto get_exclude_node_set = is_backward ? [](const NodeRelation &info) { return info.nexts; } + : [](const NodeRelation &info) { return info.pres; }; + std::vector group; + for (const auto &node : nodes) { + AnfNodePtrList stream; + AnfNodePtr n = node; + for (auto iter = node_rels.find(n); + seen->count(n) == 0 && iter != node_rels.end() && get_exclude_node_set(iter->second).size() <= 1; + iter = node_rels.find(n)) { + if (filter_func(n)) { + stream.push_back(n); + seen->insert(n); + } + if (get_contain_node_set(iter->second).size() != 1) { + break; + } + n = *(get_contain_node_set(iter->second).begin()); + } + if (stream.size() > 0) { + group.push_back(stream); + } + } + + if (group.size() == 1) { + for (const auto &drop : group[0]) { + seen->erase(drop); + } + group.clear(); + } + + return group; +} + +void SearchStreamFromMultiRelationNode(const AnfNodePtrList &multi_nodes, + const OrderedMap &node_rels, bool is_backward, + std::vector> *groups, std::set *seen) { + auto get_related_nodes = is_backward ? 
[](const NodeRelation &info) { return info.pres; } + : [](const NodeRelation &info) { return info.nexts; }; + for (const auto &node : multi_nodes) { + if (auto iter = node_rels.find(node); iter != node_rels.end()) { + const auto &pre_nodes = get_related_nodes(iter->second); + AnfNodePtrList related_nodes(pre_nodes.begin(), pre_nodes.end()); + groups->push_back(SearchFromNodes(related_nodes, WhiteOpsFilter, node_rels, is_backward, seen)); + } + } + + // Erase empty groups. + for (auto iter = groups->begin(); iter != groups->end();) { + if (iter->size() == 0) { + iter = groups->erase(iter); + } else { + ++iter; + } + } +} + +void SearchStreamFromUnidirectionalNode(const AnfNodePtrList &ud_nodes, + const OrderedMap &node_rels, bool is_backward, + std::vector> *groups, std::set *seen) { + groups->push_back(SearchFromNodes(ud_nodes, WhiteOpsFilter, node_rels, is_backward, seen)); + + // Erase empty groups. + for (auto iter = groups->begin(); iter != groups->end();) { + if (iter->size() == 0) { + iter = groups->erase(iter); + } else { + ++iter; + } + } +} + +std::string DumpNode(const AnfNodePtr &node) { + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + std::stringstream buf; + buf << (AnfAlgo::IsGraphKernel(cnode) ? "[graph]" : "[primitive]") << cnode->fullname_with_scope() << "|" + << cnode->ToString(); + return buf.str(); +} + +void DumpParallelGroups(const std::vector> &groups) { + MS_LOG(INFO) << "There are " << groups.size() << " parallel groups, their detail is: "; + int i = 0; + for (const auto group : groups) { + std::stringstream buf; + buf << "[" << i << " group] " << group.size() << ":\n"; + for (const auto nodes : group) { + buf << " " << nodes.size() << ": [<"; + for (const auto node : nodes) { + buf << "(" << DumpNode(node) << ") -> "; + } + buf << ">]\n"; + } + i++; + MS_LOG(INFO) << buf.str(); + } +} + +void DumpParallelFusionDetail(const AnfNodePtrList &source, const AnfNodePtr &target) { + std::stringstream buf; + buf << "Parallel fusion detail: "; + for (const auto &node : source) { + buf << "(" << DumpNode(node) << ") + "; + } + buf << "==>" + << "(" << DumpNode(target) << ")"; + MS_LOG(INFO) << buf.str(); +} +} // namespace + +OrderedMap ParallelOpFusion::GenAnalysisGraph(const AnfNodePtrList &nodes) { + // Based on anf node input information, build a simple graph for latter analyzation. + OrderedMap node_rels; + auto get_info = [&node_rels](const AnfNodePtr &node) { + if (node_rels.count(node) == 0) { + node_rels.insert({node, NodeRelation()}); + } + return &(node_rels[node]); + }; + + for (const auto &node : nodes) { + if (!node->isa()) { + continue; + } + + auto prior_node = get_info(node); + for (const auto &input : (node->cast())->inputs()) { + // Parameter for ControlDepend when depend mode is 1. 
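+      // Besides CNode inputs, Parameter inputs are kept as edges here because a ControlDepend
+      // with depend_mode 1 may reference a Parameter; it is resolved in ProcessControlDependCNode.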
+ if (!input->isa() && !input->isa()) { + continue; + } + auto behind_node = get_info(input); + prior_node->pres.insert(input); + behind_node->nexts.insert(node); + } + } + + ProcessDependCNode(&node_rels); + ProcessControlDependCNode(&node_rels); + ProcessThroughPassCNode( + [](const AnfNodePtr &node) { + return IsOneOf(node, {prim::kPrimReshape, prim::kPrimExpandDims, prim::kPrimSqueeze, prim::kPrimTupleGetItem}); + }, + &node_rels); + ProcessThroughPassCNode([](const AnfNodePtr &node) { return node->isa(); }, &node_rels); + ProcessTailMakeTupleCNode(&node_rels); + ProcessLocalStructure(&node_rels, &virtual_noout_nodes_, &ignore_noin_nodes_); + + return node_rels; +} + +std::vector> ParallelOpFusion::SearchParallelGroups( + const OrderedMap &node_rels) { + // Get interesting nodes: multi-inputs nodes, multi-outputs nodes, no input nodes and no output nodes. + auto [mul_ins_nodes, mul_outs_nodes, no_in_nodes, no_out_nodes] = + GetInterestNodeIds(node_rels, virtual_noout_nodes_, ignore_noin_nodes_); + + // Get streams and group them + std::set seen; + std::vector> groups; + + SearchStreamFromMultiRelationNode(mul_ins_nodes, node_rels, true, &groups, &seen); + SearchStreamFromUnidirectionalNode(no_out_nodes, node_rels, true, &groups, &seen); + SearchStreamFromMultiRelationNode(mul_outs_nodes, node_rels, false, &groups, &seen); + SearchStreamFromUnidirectionalNode(no_in_nodes, node_rels, false, &groups, &seen); + + DumpParallelGroups(groups); + return groups; +} + +std::tuple> ParallelOpFusion::GetAvaliableNodesByOffset( + int start, const std::vector &offsets, const std::vector &used, const AnfNodePtrList &nodes, + const std::set &excludes) { + // Get unused nodes by offset index, the result will contain the node with start index. + int node_limit = nodes.size(); + if (start >= node_limit) { + MS_LOG(EXCEPTION) << "Index offset is exceed the limit of given nodes."; + } + AnfNodePtrList target_nodes = {nodes[start]}; + std::vector valid_indices; + std::vector unused; + for (size_t i = start; i < used.size(); ++i) { + if (!used[i] && excludes.count(i) == 0) { + unused.push_back(i); + } + } + int limit = unused.size(); + for (auto offset : offsets) { + if (offset >= limit) { + MS_LOG(EXCEPTION) << "Index offset is exceed the limit of unused nodes."; + } + if (unused[offset] >= node_limit) { + MS_LOG(EXCEPTION) << "Index offset is exceed the limit of nodes."; + } + valid_indices.push_back(unused[offset]); + target_nodes.push_back(nodes[unused[offset]]); + } + + return std::make_tuple(target_nodes, valid_indices); +} + +std::tuple, std::vector> ParallelOpFusion::DoSearchInSortedCandidates( + size_t origin_size, const AnfNodePtrList &candidates, std::map *origin_indices, + std::map *sorted_indices) { + auto get_index = [](std::map *indices, const AnfNodePtr &node) -> int { + MS_EXCEPTION_IF_NULL(node); + if (indices->find(node) == indices->end()) { + MS_LOG(EXCEPTION) << "There is no index record for node " << node->ToString(); + } + return (*indices)[node]; + }; + + std::vector parallel_infos; + std::vector origin_candidates_used(origin_size, false); + std::vector sorted_candidates_used(candidates.size(), false); + + for (size_t i = 0; i < candidates.size(); ++i) { + if (sorted_candidates_used[i]) { + continue; + } + + int max_benefit = 0; + ParallelInfo best_parallel_info; + std::set bad_set; + size_t unused_num = 0; + for (size_t j = i + 1; j < sorted_candidates_used.size(); ++j) { + unused_num += sorted_candidates_used[j] ? 
0 : 1; + } + if (unused_num < 1) { + break; + } + + unused_num = std::min(unused_num, config_.max_num_for_fuse() - 1); + + size_t begin = 1, end = unused_num; + while (begin <= end) { + size_t mid = (begin + end) / 2; + std::vector tc(mid); + std::iota(tc.begin(), tc.end(), 1); + AnfNodePtrList other_candidates; + std::tie(other_candidates, std::ignore) = + GetAvaliableNodesByOffset(i, tc, sorted_candidates_used, candidates, std::set()); + int benefit; + std::tie(std::ignore, benefit) = cost_model_ptr_->CalFuseInfo(other_candidates); + if (benefit > 0) { + begin = mid + 1; + } else { + end = mid - 1; + } + } + + if (begin > 1) { + std::vector tc(begin - 1); + std::iota(tc.begin(), tc.end(), 1); + AnfNodePtrList other_candidates; + std::tie(other_candidates, std::ignore) = + GetAvaliableNodesByOffset(i, tc, sorted_candidates_used, candidates, std::set()); + auto [dim_infos, benefit] = cost_model_ptr_->CalFuseInfo(other_candidates); + if (benefit <= 0) { + MS_LOG(EXCEPTION) << "Internal error in candidate search!"; + } + max_benefit = benefit; + best_parallel_info = ParallelInfo(other_candidates, dim_infos); + i += begin - 1; + } + + if (max_benefit > 0) { + parallel_infos.push_back(best_parallel_info); + for (const auto &node : best_parallel_info.nodes()) { + sorted_candidates_used[get_index(sorted_indices, node)] = true; + origin_candidates_used[get_index(origin_indices, node)] = true; + } + } + } + + // Current nodes is not suitable to fuse, so pop first node to try other fusion possibility. + if (parallel_infos.size() == 0) { + origin_candidates_used[get_index(origin_indices, candidates[0])] = true; + } + + return std::make_tuple(origin_candidates_used, parallel_infos); +} + +std::tuple, std::vector> ParallelOpFusion::SearchFuseNodesInCandidates( + const AnfNodePtrList &cs) { + std::map origin_indices; + std::vector indices; + for (size_t i = 0; i < cs.size(); ++i) { + if (cs[i]) { + origin_indices.insert({cs[i], i}); + indices.push_back(i); + } + } + + // A calculated heavy node can cover more lighter nodes' cost, so sort them first. + std::map cal_amounts; + for (auto id : indices) { + cal_amounts[id] = cost_model_ptr_->GetNodeCalAmount(cs[id]); + } + std::sort(indices.begin(), indices.end(), + [&cal_amounts](size_t a, size_t b) { return cal_amounts[a] > cal_amounts[b]; }); + + AnfNodePtrList candidates; + for (size_t i = 0; i < indices.size(); ++i) { + candidates.push_back(cs[indices[i]]); + } + + std::map sorted_indices; + for (size_t i = 0; i < candidates.size(); ++i) { + sorted_indices.insert({candidates[i], i}); + } + + return DoSearchInSortedCandidates(cs.size(), candidates, &origin_indices, &sorted_indices); +} + +void ParallelOpFusion::SearchFuseNodesInParallelGroup(const std::vector &group, + std::vector *parallel_infos) { + std::vector tails; + std::vector ended; + for (const auto &node_list : group) { + tails.push_back(node_list.begin()); + ended.push_back(node_list.end()); + } + auto get_candidates = [&tails, &ended]() { + AnfNodePtrList candidates; + for (size_t id = 0; id < tails.size(); ++id) { + candidates.push_back(tails[id] != ended[id] ? 
+void ParallelOpFusion::SearchFuseNodesInParallelGroup(const std::vector<AnfNodePtrList> &group,
+                                                      std::vector<ParallelInfo> *parallel_infos) {
+  // Keep one cursor per stream; each round the cursor nodes form the fusion candidate set.
+  std::vector<AnfNodePtrList::const_iterator> tails;
+  std::vector<AnfNodePtrList::const_iterator> ended;
+  for (const auto &node_list : group) {
+    tails.push_back(node_list.begin());
+    ended.push_back(node_list.end());
+  }
+  auto get_candidates = [&tails, &ended]() {
+    AnfNodePtrList candidates;
+    for (size_t id = 0; id < tails.size(); ++id) {
+      candidates.push_back(tails[id] != ended[id] ? *tails[id] : AnfNodePtr());
+    }
+    return candidates;
+  };
+  auto update_tails = [&tails](const std::vector<bool> &used) {
+    if (used.size() != tails.size()) {
+      MS_LOG(EXCEPTION) << "The size of the used-flag vector is not equal to the number of streams!";
+    }
+    for (size_t id = 0; id < used.size(); ++id) {
+      if (used[id]) {
+        tails[id]++;
+      }
+    }
+  };
+  auto valid_candidate_num = [](const AnfNodePtrList &cs) {
+    return std::count_if(cs.begin(), cs.end(), [](const AnfNodePtr &n) { return n != nullptr; });
+  };
+
+  auto candidates = get_candidates();
+  while (valid_candidate_num(candidates) > 1) {
+    auto [used, fnds] = SearchFuseNodesInCandidates(candidates);
+    std::transform(fnds.cbegin(), fnds.cend(), std::back_insert_iterator(*parallel_infos),
+                   [](const ParallelInfo &pi) { return pi; });
+    update_tails(used);
+    candidates = get_candidates();
+  }
+}
+
+std::vector<ParallelInfo> ParallelOpFusion::SearchFusableParallelCNodes(
+  const std::vector<std::vector<AnfNodePtrList>> &groups) {
+  // Find core-fusable groups with the cost model.
+  std::vector<ParallelInfo> parallel_infos;
+  for (const auto &group : groups) {
+    SearchFuseNodesInParallelGroup(group, &parallel_infos);
+  }
+
+  return parallel_infos;
+}
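SearchFuseNodesInParallelGroup advances one cursor per parallel stream: each round the cursor nodes form the candidate set, and only the cursors of consumed nodes move forward. A runnable sketch of that sweep, where `search` is a toy stand-in for SearchFuseNodesInCandidates; like the real routine, it must consume at least one candidate per round for the loop to terminate:

```python
# Frontier sweep over parallel streams: fuse from the fronts, advance the
# cursors of whatever was consumed, repeat while two or more fronts remain.

def sweep_group(group, search):
    tails = [0] * len(group)

    def frontier():
        return [s[t] if t < len(s) else None for s, t in zip(group, tails)]

    infos = []
    cand = frontier()
    while sum(c is not None for c in cand) > 1:
        used, found = search(cand)       # must mark at least one node used
        infos.extend(found)
        tails = [t + 1 if u else t for t, u in zip(tails, used)]
        cand = frontier()
    return infos


if __name__ == "__main__":
    # Toy searcher: fuse every currently visible node into one group.
    def fuse_all(cand):
        used = [c is not None for c in cand]
        return used, [tuple(c for c in cand if c is not None)]

    streams = [["a1", "a2"], ["b1"], ["c1", "c2", "c3"]]
    print(sweep_group(streams, fuse_all))
    # [('a1', 'b1', 'c1'), ('a2', 'c2')] -- 'c3' is left with no peer
```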
+void ParallelOpFusion::SetFusedParallelOpAttrToReturnNode(const ParallelInfo &parallel_info) {
+  for (size_t i = 0; i < parallel_info.GetSize(); ++i) {
+    const auto &fuse_nodes = parallel_info.nodes();
+    std::vector<size_t> info = {i, std::dynamic_pointer_cast<CommonDimInfo>(parallel_info.dims()[i])->dim_info()};
+    if (!AnfAlgo::IsGraphKernel(fuse_nodes[i])) {
+      MakeCNodeSafeForAttr(fuse_nodes[i]);
+      AnfAlgo::SetNodeAttr(kAttrParallelDimInfo, MakeValue<std::vector<size_t>>(info), fuse_nodes[i]);
+    } else {
+      // For a composite (graph kernel) node, set the attribute on every real output of its subgraph.
+      auto node_g = GetValueNode<FuncGraphPtr>((fuse_nodes[i]->cast<CNodePtr>())->input(0));
+      auto out_node = node_g->output();
+      if (IsPrimitiveCNode(out_node, prim::kPrimMakeTuple)) {
+        auto inputs = out_node->cast<CNodePtr>()->inputs();
+        for (size_t j = 1; j < inputs.size(); ++j) {
+          MakeCNodeSafeForAttr(inputs[j]);
+          AnfAlgo::SetNodeAttr(kAttrParallelDimInfo, MakeValue<std::vector<size_t>>(info), inputs[j]);
+        }
+      } else {
+        MakeCNodeSafeForAttr(out_node);
+        AnfAlgo::SetNodeAttr(kAttrParallelDimInfo, MakeValue<std::vector<size_t>>(info), out_node);
+      }
+    }
+  }
+}
+
+void PostProcessForNewSubGraphCNode(const AnfNodePtr &node, const std::shared_ptr<session::KernelGraph> &kernel_graph) {
+  auto mng = kernel_graph->manager();
+  if (mng == nullptr) {
+    mng = Manage(kernel_graph, true);
+    kernel_graph->set_manager(mng);
+  }
+
+  const auto &users = mng->node_users()[node];
+  std::vector<std::pair<AnfNodePtr, int>> sons;
+  for (const auto &[user, index] : users) {
+    if (!IsPrimitiveCNode(user, prim::kPrimTupleGetItem)) {
+      sons.emplace_back(user, index);
+      continue;
+    }
+    auto &[fake_first_grad_son, grad_index] = *((mng->node_users()[user]).begin());
+    sons.emplace_back(fake_first_grad_son, grad_index);
+  }
+
+  AnfNodePtrList latter_to_delete;
+  for (const auto &[son, index] : sons) {
+    if (!IsPrimitiveCNode(son, prim::kPrimDepend) || index != kDependAttachNodeIndex) {
+      continue;
+    }
+
+    latter_to_delete.push_back(son);
+  }
+
+  if (latter_to_delete.empty()) {
+    return;
+  }
+
+  std::vector<AnfNodePtr>::iterator delete_begin = latter_to_delete.begin();
+  if (latter_to_delete.size() == sons.size()) {
+    // All users are Depend nodes: keep one Depend relation alive and delete the others!
+    ++delete_begin;
+  }
+  for (; delete_begin != latter_to_delete.end(); ++delete_begin) {
+    auto depend_anfnode = *delete_begin;
+    auto depend_cnode = depend_anfnode->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(depend_cnode);
+    auto depend_prior_node = depend_cnode->input(kRealInputIndexInDepend);
+    mng->Replace(depend_anfnode, depend_prior_node);
+  }
+}
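PostProcessForNewSubGraphCNode above keeps at most one Depend that merely re-attaches the fused node and replaces the rest with their real inputs. The decision itself reduces to a few lines; in this sketch `users` is a list of (node, input_index) pairs and `is_redundant_depend` stands in for the Depend/attach-index test in the C++ code:

```python
# Sketch of the Depend-deduplication decision: collect the redundant Depend
# users, but if *every* user is such a Depend, spare one so the control
# edge to the fused node survives.

def depends_to_replace(users, is_redundant_depend):
    to_delete = [n for n, idx in users if is_redundant_depend(n, idx)]
    if to_delete and len(to_delete) == len(users):
        to_delete = to_delete[1:]   # keep the first Depend alive
    return to_delete


if __name__ == "__main__":
    users = [("depend1", 2), ("depend2", 2), ("depend3", 2)]
    print(depends_to_replace(users, lambda n, i: i == 2))  # keeps depend1
```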
+bool ParallelOpFusion::CreateParallelOpSubGraphs(const std::vector<ParallelInfo> &parallel_infos,
+                                                 const std::shared_ptr<session::KernelGraph> &kernel_graph) {
+  bool changed = false;
+
+  for (size_t i = 0; i < parallel_infos.size(); ++i) {
+    const auto &fuse_nodes = parallel_infos[i].nodes();
+    if (fuse_nodes.size() <= 1) {
+      continue;
+    }
+    changed = true;
+    SetFusedParallelOpAttrToReturnNode(parallel_infos[i]);
+    AnfNodePtr sg_node;
+    std::tie(sg_node, std::ignore) = FuseNodesToSubGraph(fuse_nodes, kernel_graph, "parallel");
+    PostProcessForNewSubGraphCNode(sg_node, kernel_graph);
+    DumpParallelFusionDetail(fuse_nodes, sg_node);
+  }
+
+  return changed;
+}
+
+bool ParallelOpFusion::Run(const FuncGraphPtr &graph) {
+  MS_EXCEPTION_IF_NULL(graph);
+  auto kernel_graph = graph->cast<std::shared_ptr<session::KernelGraph>>();
+  MS_EXCEPTION_IF_NULL(kernel_graph);
+
+  cost_model_ptr_ = ParellelCostModelWarehouse::Instance().GetParallelCostModel(target_);
+  MS_EXCEPTION_IF_NULL(cost_model_ptr_);
+
+  auto nodes = TopoSort(kernel_graph->get_return());
+  std::reverse(nodes.begin(), nodes.end());
+
+  auto node_rels = GenAnalysisGraph(nodes);
+  auto groups = SearchParallelGroups(node_rels);
+  auto parallel_infos = SearchFusableParallelCNodes(groups);
+
+  // Create parallel-fused subgraphs and change the origin graph.
+  return CreateParallelOpSubGraphs(parallel_infos, kernel_graph);
+}
+}  // namespace opt
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/parallel_fusion.h b/mindspore/ccsrc/backend/optimizer/graph_kernel/parallel_fusion.h
new file mode 100644
index 0000000000..a9372d6018
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/parallel_fusion.h
@@ -0,0 +1,122 @@
+
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_PARALLEL_FUSION_H_
+#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_PARALLEL_FUSION_H_
+
+#include <map>
+#include <memory>
+#include <set>
+#include <string>
+#include <tuple>
+#include <utility>
+#include <vector>
+
+#include "base/base.h"
+#include "backend/session/anf_runtime_algorithm.h"
+#include "backend/optimizer/common/optimizer.h"
+#include "backend/optimizer/graph_kernel/parallel_cost_model.h"
+#include "backend/session/kernel_graph.h"
+#include "utils/ms_context.h"
+
+namespace mindspore {
+namespace opt {
+class ParallelInfo {
+ public:
+  ParallelInfo() = default;
+  ParallelInfo(const AnfNodePtrList &nodes, const std::vector<DimInfoPtr> &dims) : nodes_(nodes), dims_(dims) {}
+  ParallelInfo(const ParallelInfo &obj) {
+    nodes_ = obj.nodes_;
+    dims_ = obj.dims_;
+  }
+  ~ParallelInfo() = default;
+
+  size_t GetSize() const {
+    if (nodes_.size() != dims_.size()) {
+      MS_LOG(EXCEPTION) << "Internal error in parallel info!";
+    }
+    return nodes_.size();
+  }
+  const AnfNodePtrList &nodes() const { return nodes_; }
+  const std::vector<DimInfoPtr> &dims() const { return dims_; }
+
+ private:
+  AnfNodePtrList nodes_;
+  std::vector<DimInfoPtr> dims_;
+};
+
+class ParallelConfig {
+ public:
+  ParallelConfig() = default;
+  explicit ParallelConfig(size_t max_n) : max_num_for_fuse_(max_n) {}
+  explicit ParallelConfig(const ParallelConfig &obj) { max_num_for_fuse_ = obj.max_num_for_fuse_; }
+  ~ParallelConfig() = default;
+  size_t max_num_for_fuse() { return max_num_for_fuse_; }
+
+ private:
+  size_t max_num_for_fuse_{10};  // Fusing too many nodes together may produce a bad result.
+};
+
+struct NodeRelation {
+ public:
+  NodeRelation() {}
+  ~NodeRelation() = default;
+  OrderedSet<AnfNodePtr> pres;
+  OrderedSet<AnfNodePtr> nexts;
+};
+
+class ParallelOpFusion : public Pass {
+ public:
+  ParallelOpFusion(const std::string &target, const ParallelConfig &config)
+      : Pass("parallel_fusion"), target_(target), config_(config) {}
+  ~ParallelOpFusion() override = default;
+  bool Run(const FuncGraphPtr &graph) override;
+
+ private:
+  std::tuple<AnfNodePtrList, std::vector<int>> GetAvaliableNodesByOffset(int start, const std::vector<int> &offsets,
+                                                                         const std::vector<bool> &used,
+                                                                         const AnfNodePtrList &nodes,
+                                                                         const std::set<int> &excludes);
+
+  std::tuple<std::vector<bool>, std::vector<ParallelInfo>> DoSearchInSortedCandidates(
+    size_t origin_size, const AnfNodePtrList &candidates, std::map<AnfNodePtr, int> *origin_indices,
+    std::map<AnfNodePtr, int> *sorted_indices);
+
+  std::tuple<std::vector<bool>, std::vector<ParallelInfo>> SearchFuseNodesInCandidates(const AnfNodePtrList &cs);
+
+  void SearchFuseNodesInParallelGroup(const std::vector<AnfNodePtrList> &group,
+                                      std::vector<ParallelInfo> *parallel_infos);
+
+  std::vector<ParallelInfo> SearchFusableParallelCNodes(const std::vector<std::vector<AnfNodePtrList>> &groups);
+
+  void SetFusedParallelOpAttrToReturnNode(const ParallelInfo &parallel_info);
+
+  bool CreateParallelOpSubGraphs(const std::vector<ParallelInfo> &parallel_infos,
+                                 const std::shared_ptr<session::KernelGraph> &kernel_graph);
+
+  OrderedMap<AnfNodePtr, NodeRelation> GenAnalysisGraph(const AnfNodePtrList &nodes);
+  std::vector<std::vector<AnfNodePtrList>> SearchParallelGroups(const OrderedMap<AnfNodePtr, NodeRelation> &node_rels);
+
+  std::string target_;
+  ParallelConfig config_;
+  ParallelCostModelPtr cost_model_ptr_;
+  std::set<AnfNodePtr> virtual_noout_nodes_;
+  std::set<AnfNodePtr> ignore_noin_nodes_;
+};
+using ParallelOpFusionPtr = std::shared_ptr<ParallelOpFusion>;
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_PARALLEL_FUSION_H_
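For readers cross-referencing the Python cost-model side, here is a rough dataclass analogue of the two carriers declared in the header above (the Py-prefixed names are illustrative and not part of the patch):

```python
from dataclasses import dataclass, field
from typing import Any, List

@dataclass
class PyParallelInfo:
    """Nodes chosen to run in parallel, plus one dim-info entry per node."""
    nodes: List[Any] = field(default_factory=list)
    dims: List[Any] = field(default_factory=list)

    def size(self) -> int:
        if len(self.nodes) != len(self.dims):
            raise ValueError("nodes and dims must stay the same length")
        return len(self.nodes)

@dataclass
class PyParallelConfig:
    """Fusing too many nodes at once tends to hurt, so cap the group size."""
    max_num_for_fuse: int = 10
```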
"backend/optimizer/graph_kernel/arithmetic_simplify.h" #include "backend/optimizer/graph_kernel/basic_ops_fusion.h" #include "backend/optimizer/graph_kernel/clean_all_in_once.h" +#include "backend/optimizer/graph_kernel/depend_formater.h" #include "backend/optimizer/graph_kernel/eliminate_redundant_output.h" #include "backend/optimizer/graph_kernel/tensor_promotion.h" #include "backend/optimizer/graph_kernel/graph_kernel_splitter.h" @@ -51,6 +52,7 @@ #include "backend/optimizer/graph_kernel/graph_kernel_cse.h" #include "backend/optimizer/graph_kernel/shape_ops_splitter.h" #include "backend/optimizer/graph_kernel/value_graph_binder.h" +#include "backend/optimizer/graph_kernel/parallel_fusion.h" #include "backend/optimizer/pass/communication_op_fusion.h" #include "backend/optimizer/pass/getitem_tuple.h" #include "common/trans.h" @@ -179,6 +181,7 @@ void GPUSession::GraphKernelOptimize(const std::shared_ptr &kernel_ auto optimizer = std::make_shared(); auto pm = std::make_shared("graph_kernel_pm"); std::vector duplicated_ops = {prim::kPrimReshape, prim::kPrimExpandDims, prim::kPrimCast}; + pm->AddPass(std::make_shared()); // Make more fusion opportunity. pm->AddPass(std::make_shared()); pm->AddPass(std::make_shared(duplicated_ops)); pm->AddPass(std::make_shared()); @@ -196,7 +199,8 @@ void GPUSession::GraphKernelOptimize(const std::shared_ptr &kernel_ // will be exposed, use GetitemTuple Pass to delete them. pm->AddPass(std::make_shared()); pm->AddPass(std::make_shared()); - pm->AddPass(std::make_shared()); + pm->AddPass(std::make_shared()); // Prevent fake loop in parallel fusion. + pm->AddPass(std::make_shared(kGPUDevice, opt::ParallelConfig(7))); pm->AddPass(std::make_shared()); optimizer->AddPassManager(pm); (void)optimizer->Optimize(kernel_graph); diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index 5971ecc91e..a12d81d4e2 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -382,6 +382,7 @@ constexpr auto kAttrPadding = "padding"; constexpr auto kAttrIsGrad = "is_grad"; constexpr auto kAttrRecompute = "recompute"; constexpr auto kAttrNeedCseAfterRecompute = "need_cse_after_recompute"; +constexpr auto kAttrParallelDimInfo = "parallel_dim_info"; // attr value constexpr auto kValueTargetSwitch = "target_switch"; diff --git a/tests/st/graph_kernel/model/test_graph_parallel.py b/tests/st/graph_kernel/model/test_graph_parallel.py new file mode 100644 index 0000000000..4f2fa89a62 --- /dev/null +++ b/tests/st/graph_kernel/model/test_graph_parallel.py @@ -0,0 +1,54 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/tests/st/graph_kernel/model/test_graph_parallel.py b/tests/st/graph_kernel/model/test_graph_parallel.py
new file mode 100644
index 0000000000..4f2fa89a62
--- /dev/null
+++ b/tests/st/graph_kernel/model/test_graph_parallel.py
@@ -0,0 +1,54 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ===========================================================================
+"""test graph parallel case"""
+import model
+
+def injective_graph(shape):
+    gb = model.GraphBuilder()
+    with gb.graph_scope('injective') as _:
+        a1 = gb.tensor(shape, 'float32')
+        a2 = gb.emit('Abs', a1)
+        a3 = gb.emit('Abs', a2)
+        gb.emit('Abs', a3)
+    return gb.get()[0]
+
+def reduce_graph(shape, reduce_axis):
+    gb = model.GraphBuilder()
+    with gb.graph_scope('reduce') as _:
+        a1 = gb.tensor(shape, 'float32')
+        a2 = gb.emit('Abs', a1)
+        a3 = gb.emit('Abs', a2)
+        gb.emit('ReduceSum', a3, 'C', attrs={'reduce_axis': reduce_axis})
+    return gb.get()[0]
+
+def control_graph(shape):
+    gb = model.GraphBuilder()
+    with gb.graph_scope('control') as _:
+        a1 = gb.tensor(shape, 'float32')
+        a2 = gb.emit('Abs', a1)
+        gb.emit('ControlDepend', a2)
+    return gb.get()[0]
+
+def block_fusion(graphs):
+    gain = model.parallel_estimate(graphs)
+    print("fusion = {}, bottleneck = {}, gain = {}".format(gain.fusion_type, gain.bottleneck, gain.gain))
+    return gain.fusion_type == "block_fusion" and gain.gain > 0
+
+if __name__ == "__main__":
+    assert block_fusion([injective_graph([40, 1024]), injective_graph([40, 1024])])
+    assert block_fusion([reduce_graph([1024, 1024], [1]), injective_graph([24, 1024])])
+    assert not block_fusion([reduce_graph([1024, 1024], [1]), injective_graph([50, 1024])])
+    assert not block_fusion([reduce_graph([1024, 1024], [0, 1]), injective_graph([1024, 1024])])
+    assert block_fusion([control_graph([20, 128]), injective_graph([40, 1024])])
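A hypothetical table-driven variant of the asserts above (not part of the patch) that reuses injective_graph and reduce_graph from this file, so new shape combinations can be added in one place:

```python
# Table-driven form of the same checks: each case pairs a list of graphs
# with the expected block-fusion verdict.

def run_cases(cases):
    for graphs, expect_fusion in cases:
        assert block_fusion(graphs) == expect_fusion

run_cases([
    ([injective_graph([40, 1024]), injective_graph([40, 1024])], True),
    ([reduce_graph([1024, 1024], [1]), injective_graph([50, 1024])], False),
])
```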