diff --git a/akg b/akg
index 94cb709eca..24ba04df56 160000
--- a/akg
+++ b/akg
@@ -1 +1 @@
-Subproject commit 94cb709ecaf5d1d869883dfe80cee7497dd0692c
+Subproject commit 24ba04df564fb3d2578e1b4324c760783b34d551
diff --git a/mindspore/_extends/graph_kernel/model/graph_parallel.py b/mindspore/_extends/graph_kernel/model/graph_parallel.py
index 93360ae5df..d4a5cacd0e 100644
--- a/mindspore/_extends/graph_kernel/model/graph_parallel.py
+++ b/mindspore/_extends/graph_kernel/model/graph_parallel.py
@@ -17,11 +17,12 @@ from .model import PrimLib
 class ParalGain:
-    def __init__(self, fusion_type, bottleneck, gain, block_assign):
+    def __init__(self, fusion_type, bottleneck, gain, block_assign, type_info):
         self.fusion_type = fusion_type
         self.bottleneck = bottleneck
         self.gain = gain
         self.block_assign = block_assign
+        self.type_info = type_info
 
 
 class ScheduleAnalyzer:
@@ -30,6 +31,7 @@ class ScheduleAnalyzer:
     MAX_SM = 80  # Volta
     MAX_NUM_THREADS = 1024
     MAX_BLOCK = 256
+    PIPELINE_OP_THRESHOLD = 5
 
     def __init__(self, graph):
         self.graph = graph
@@ -132,11 +134,141 @@ class ScheduleAnalyzer:
         else:
             self.default_analyze()
 
+    def suitable_to_pipeline(self):
+        """Judge whether the graph is suitable for pipeline optimization."""
+        # Reduce is not suitable.
+        def _contain_reduce(ops):
+            for op in ops:
+                # Reduce may make the tiling bad.
+                if PrimLib.primtives.get(op.prim, None) == PrimLib.REDUCE:
+                    return True
+            return False
+
+        suitable = True
+        if _contain_reduce(self.ops):
+            suitable = False
+        return suitable
+
+    @staticmethod
+    def k_mean(data, class_n=2, exclude_id=()):
+        """
+        Find k clusters in which the elements are close to each other.
+
+        Args:
+            data (list): Elements' information.
+            class_n (int): Number of clusters to partition into, default is 2.
+            exclude_id (tuple[int]): Indices of excluded elements, default is ().
+
+        Returns:
+            classes (list[list[int]]): The list of clusters. Each cluster is a list of indices.
+        """
+        def _cal_mean(classes):
+            class_datas = [[data[cid] for cid in cls] for cls in classes]
+            return [sum(cls) / len(cls) if cls else float('inf') for cls in class_datas]
+
+        def _cal_distance(a, b):
+            return abs(a - b)
+
+        def _check_different(old_classes, new_classes):
+            for o, n in zip(old_classes, new_classes):
+                if o != n:
+                    return True
+            return False
+
+        if len(data) < class_n:
+            return None
+        classes = []
+        for i, _ in enumerate(data):
+            if i in exclude_id:
+                continue
+            if len(classes) >= class_n:
+                break
+            classes.append([i])
+        changed = True
+        while changed:
+            new_classes = [[] for cls in classes]
+            means = _cal_mean(classes)
+            for idx, d in enumerate(data):
+                if idx in exclude_id:
+                    continue
+                min_idx = -1
+                min_dis = float('inf')
+                for i, m in enumerate(means):
+                    cur_dis = _cal_distance(m, d)
+                    min_idx = i if min_dis > cur_dis else min_idx
+                    min_dis = cur_dis if min_dis > cur_dis else min_dis
+                new_classes[min_idx].append(idx)
+            changed = _check_different(classes, new_classes)
+            classes = new_classes
+        return classes
+
+    @staticmethod
+    def pipeline_fusion_analyze(blocks, op_sizes, exclude_id):
+        """Analyze whether the segments can be pipeline optimized."""
+        # Op size first, block second.
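Reviewer note: to make the clustering step above concrete, here is a minimal standalone sketch of how `pipeline_fusion_analyze` drives `k_mean` on per-segment indicators. The segment sizes are invented for illustration and are not from this patch.

```python
# Invented example: five fused segments with their block counts and op sizes.
blocks = [2, 3, 40, 2, 48]
op_sizes = [1, 1, 12, 2, 15]

def simple_factor(block, op_size):
    # Mirrors _simple_factor: op size is weighted 5x over block count.
    return block + 5 * op_size

indicators = [simple_factor(b, s) for b, s in zip(blocks, op_sizes)]
# indicators == [7, 8, 100, 12, 123]

# ScheduleAnalyzer.k_mean(indicators, 2) converges to a light cluster
# [0, 1, 3] (mean 9) and a heavy cluster [2, 4] (mean 111.5). The light
# cluster's mean is below the threshold
#   simple_factor(MAX_SM / len(blocks), PIPELINE_OP_THRESHOLD)
#     == simple_factor(80 / 5, 5) == 41,
# so segments [0, 1, 3] are returned as one pipeline group (weight 3 is
# not > 3, so the group is not split into two halves).
```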
+        def _simple_factor(block, op_size):
+            return block + 5 * op_size
+
+        def _take_second(elem):
+            return elem[1]
+
+        simple_indicators = [_simple_factor(b, s)
+                             for b, s in zip(blocks, op_sizes)]
+        # Two classes: one heavy, the other light.
+        classes = ScheduleAnalyzer.k_mean(simple_indicators, 2, exclude_id)
+        if not classes:
+            return []
+        means = [sum([simple_indicators[idx] for idx in cls]) /
+                 len(cls) if cls else float('inf') for cls in classes]
+
+        # The target two clusters should be a heavy one and a light one.
+        # The light one may be suitable for pipeline optimization.
+        classes_infos = [[cls, m] for cls, m in zip(classes, means)]
+        classes_infos.sort(key=_take_second)
+        pipeline_target = None
+        for ci in classes_infos:
+            if ci:
+                pipeline_target = ci
+                break
+        pipeline_gids, pipeline_mean = pipeline_target
+        if pipeline_mean > _simple_factor(float(ScheduleAnalyzer.MAX_SM) / len(blocks),
+                                          ScheduleAnalyzer.PIPELINE_OP_THRESHOLD):
+            return []
+
+        pipeline_blocks = []
+        pipeline_weight = len(pipeline_gids)
+        # Try to get at least two parallel parts.
+        if pipeline_weight > 3 and pipeline_weight > len(blocks) / 2:
+            if len(pipeline_gids[:pipeline_weight // 2]) > 1:
+                pipeline_blocks.append(pipeline_gids[:pipeline_weight // 2])
+            if len(pipeline_gids[pipeline_weight // 2:]) > 1:
+                pipeline_blocks.append(pipeline_gids[pipeline_weight // 2:])
+        elif pipeline_weight > 1:
+            pipeline_blocks.append(pipeline_gids)
+        return pipeline_blocks
+
+    @staticmethod
+    def fusion_consult(blocks, op_sizes, exclude_gid):
+        """Get a recommendation for parallel fusion."""
+        # Default is block fusion.
+        fusion_type = "block_fusion"
+        type_info = None
+
+        activate_pipeline_optimization = False  # Disable pipeline optimization for now.
+        if activate_pipeline_optimization:
+            pipeline_info = ScheduleAnalyzer.pipeline_fusion_analyze(
+                blocks, op_sizes, exclude_gid)
+            if pipeline_info:
+                fusion_type = "block_pipeline_fusion"
+                type_info = pipeline_info
+
+        return fusion_type, type_info
+
 
 def block_parallel_estimate(graphs):
     """estimate block parallel gain"""
-    sum_block, max_weight, sum_weight, blocks = 0, 0, 0, []
-    for g in graphs:
+    sum_block, max_weight, sum_weight, blocks, op_sizes, exclude_gid = 0, 0, 0, [], [], []
+    for gid, g in enumerate(graphs):
         s = ScheduleAnalyzer(g)
         s.analyze()
         sum_block += s.block_num
@@ -144,9 +276,14 @@ def block_parallel_estimate(graphs):
             max_weight = s.block_weight
         sum_weight += s.block_weight
         blocks.append(s.block_num)
+        op_sizes.append(len(s.ops))
+        if not s.suitable_to_pipeline():
+            exclude_gid.append(gid)
     if sum_block > ScheduleAnalyzer.MAX_SM * 32:
-        return ParalGain("none", sum_weight, 0, [])
-    return ParalGain("block_fusion", max_weight, sum_weight - max_weight, blocks)
+        return ParalGain("none", sum_weight, 0, [0 for _ in graphs], None)
+
+    fusion_type, type_info = ScheduleAnalyzer.fusion_consult(blocks, op_sizes, tuple(exclude_gid))
+    return ParalGain(fusion_type, max_weight, sum_weight - max_weight, blocks, type_info)
 
 
 def parallel_estimate(graphs):
diff --git a/mindspore/_extends/graph_kernel/parallel_estimate.py b/mindspore/_extends/graph_kernel/parallel_estimate.py
index 593eb558e9..f5d2105a2f 100644
--- a/mindspore/_extends/graph_kernel/parallel_estimate.py
+++ b/mindspore/_extends/graph_kernel/parallel_estimate.py
@@ -28,10 +28,8 @@ def estimate_ops(json_str: str):
         for gd in graph_descs:
             graphs.append(model.load_composite(gd).graph)
         estimation = model.parallel_estimate(graphs)
-        if estimation.fusion_type == "block_fusion" and estimation.gain > 0:
-            res = (estimation.block_assign, estimation.gain)
-        else:
-            res = ([0 for g in graphs], 0)
+        res = (estimation.block_assign, estimation.gain,
+               estimation.fusion_type, estimation.type_info)
        return res
     except jd.JSONDecodeError:
         logger.error(traceback.format_exc())
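For clarity, this is the shape of the tuple `estimate_ops` now returns; the field values below are hypothetical, only the layout follows the patch.

```python
# Hypothetical return value of estimate_ops(json_str) after this change:
# the old (block_assign, gain) pair also carries fusion type and type info.
res = (
    [0, 1, 2, 3],      # block_assign: block index assigned to each sub-graph
    42,                # gain: estimated gain of parallel fusion
    "block_fusion",    # fusion_type recommended by the cost model
    None,              # type_info: pipeline id groups for block_pipeline_fusion, else None
)
block_assign, gain, fusion_type, type_info = res
```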
diff --git a/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_generator.cc b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_generator.cc
index 619802248a..7508476cb5 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_generator.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_generator.cc
@@ -557,30 +557,6 @@ bool AkgKernelJsonGenerator::CollectJson(const AnfNodePtr &anf_node, nlohmann::j
   return true;
 }
 
-void AkgKernelJsonGenerator::SetParallelValueToJson(const std::string &processor,
-                                                    const std::map<size_t, size_t> &dim_infos,
-                                                    nlohmann::json *sub_fusion_json) {
-  if (processor == kProcessorCuda) {
-    std::vector<size_t> cnums;
-    std::transform(dim_infos.cbegin(), dim_infos.cend(), std::back_insert_iterator(cnums),
-                   [](const std::pair<size_t, size_t> &dim) { return dim.second; });
-    (*sub_fusion_json)[kJsonKeyCoreNum] = cnums;
-  } else {
-    MS_LOG(EXCEPTION) << "Parallel fusion not support " << processor << " now.";
-  }
-}
-
-void AkgKernelJsonGenerator::AddParalleFusionJsonInfo(const std::string &processor, nlohmann::json *kernel_json) {
-  nlohmann::json parallel_fusion_json;
-  parallel_fusion_json[kJsonKeyFusionType] = "block_fusion";
-  std::vector<std::vector<std::string>> sgraphs;
-  std::transform(sub_graphs_.cbegin(), sub_graphs_.cend(), std::back_insert_iterator(sgraphs),
-                 [](const std::pair<size_t, std::vector<std::string>> &sg) { return sg.second; });
-  parallel_fusion_json[kJsonKeySubGraph] = sgraphs;
-  SetParallelValueToJson(processor, dim_infos_, &parallel_fusion_json);
-  (*kernel_json)[kJsonKeyParallelFusion] = parallel_fusion_json;
-}
-
 void AkgKernelJsonGenerator::GenStitchJson(const std::vector<AnfNodePtr> &anf_nodes,
                                            std::map<AnfNodePtr, nlohmann::json> *node_json_map,
                                            nlohmann::json *kernel_json) {
@@ -633,12 +609,8 @@ bool AkgKernelJsonGenerator::CollectFusedJson(const std::vector<AnfNodePtr> &anf
   (*kernel_json)[kJsonKeyOutputDesc] =
     CreateOutputsJson(anf_nodes, input_list, output_list, inputs_json, node_json_map);
 
-  auto processor = GetProcessorStr(anf_nodes[0]);
-
-  // Add parallel fusion information.
-  if (!sub_graphs_.empty()) {
-    AddParalleFusionJsonInfo(processor, kernel_json);
-  }
+  GenParallelJson(anf_nodes, input_list, output_list, node_json_map, kernel_json);
 
   size_t hash_id = std::hash<std::string>()(kernel_json->dump());
   kernel_name_ = "Fused_";
@@ -660,7 +632,7 @@ bool AkgKernelJsonGenerator::CollectFusedJson(const std::vector<AnfNodePtr> &anf
   (*kernel_json)[kJsonKeyId] = GetOpCntInc();
   (*kernel_json)[kJsonKeyOp] = kernel_name_;
   (*kernel_json)[kJsonKeyPlatform] = "AKG";
-  (*kernel_json)[kJsonKeyProcess] = processor;
+  (*kernel_json)[kJsonKeyProcess] = GetProcessorStr(anf_nodes[0]);
   (*kernel_json)[kJsonKeyComposite] = true;
   (*kernel_json)[kJsonKeyCompositeGraph] = fg->ToString() + "." + fg->debug_info()->get_id();
@@ -755,6 +727,70 @@ nlohmann::json AkgKernelJsonGenerator::CreateInputsJson(const std::vector<AnfNod
+void AkgKernelJsonGenerator::GenParallelJson(const std::vector<AnfNodePtr> &anf_nodes,
+                                             const std::vector<AnfNodePtr> &input_list,
+                                             const std::vector<AnfNodePtr> &output_list,
+                                             const std::map<AnfNodePtr, nlohmann::json> &node_json_map,
+                                             nlohmann::json *kernel_json) {
+  std::map<size_t, std::pair<size_t, std::vector<std::string>>> sub_graphs_info;
+  std::string fusion_type;
+  std::vector<std::vector<size_t>> type_info;
+
+  auto output_index = GetOutputIndex(anf_nodes, input_list, output_list);
+  for (size_t i = 0; i < output_index.size(); ++i) {
+    auto [tmp_output, tmp_output_index] = output_index[i];
+    bool found = std::any_of(input_list.cbegin(), input_list.cend(),
+                             [&tmp_output](const AnfNodePtr &in) { return tmp_output == in; });
+    if (!found) {
+      auto tcnode = tmp_output->cast<CNodePtr>();
+      if (tcnode == nullptr) {
+        return;
+      }
+      // Get dim info.
+      if (AnfAlgo::HasNodeAttr(kAttrParallelDimInfo, tcnode)) {
+        auto info = AnfAlgo::GetNodeAttr<std::vector<size_t>>(tcnode, kAttrParallelDimInfo);
+        if (info.size() != 2) {
+          MS_LOG(EXCEPTION) << "Parallel dim info is invalid!";
+        }
+        auto tensor_name =
+          GetTensorName(node_json_map.at(tmp_output), kJsonKeyOutputDesc, std::make_pair(0, tmp_output_index));
+        sub_graphs_info[info[0]].second.push_back(tensor_name);
+        sub_graphs_info[info[0]].first = info[1];
+      }
+      // Get fusion type.
+      if (AnfAlgo::HasNodeAttr(kAttrParallelFusionType, tcnode)) {
+        fusion_type = AnfAlgo::GetNodeAttr<std::string>(tcnode, kAttrParallelFusionType);
+      }
+      // Get fusion type info.
+      if (AnfAlgo::HasNodeAttr(kAttrParallelTypeInfo, tcnode)) {
+        type_info = AnfAlgo::GetNodeAttr<std::vector<std::vector<size_t>>>(tcnode, kAttrParallelTypeInfo);
+      }
+    }
+  }
+
+  if (!sub_graphs_info.empty()) {
+    auto processor = GetProcessorStr(anf_nodes[0]);
+    if (processor != kProcessorCuda) {
+      MS_LOG(EXCEPTION) << "Parallel fusion does not support " << processor << " now.";
+    }
+
+    nlohmann::json parallel_fusion_json;
+    parallel_fusion_json[kJsonKeyFusionType] = fusion_type;
+    parallel_fusion_json[kJsonKeyTypeInfo] = type_info;
+    std::vector<std::vector<std::string>> sgraphs;
+    std::vector<size_t> cnums;
+    std::for_each(sub_graphs_info.cbegin(), sub_graphs_info.cend(),
+                  [&sgraphs, &cnums](const std::pair<size_t, std::pair<size_t, std::vector<std::string>>> &sg_info) {
+                    sgraphs.push_back(sg_info.second.second);
+                    cnums.push_back(sg_info.second.first);
+                  });
+    parallel_fusion_json[kJsonKeySubGraph] = sgraphs;
+    parallel_fusion_json[kJsonKeyCoreNum] = cnums;
+
+    (*kernel_json)[kJsonKeyParallelFusion] = parallel_fusion_json;
+  }
+}
+
 nlohmann::json AkgKernelJsonGenerator::CreateOutputsJson(const std::vector<AnfNodePtr> &anf_nodes,
                                                          const std::vector<AnfNodePtr> &input_list,
                                                          const std::vector<AnfNodePtr> &output_list,
@@ -785,17 +821,6 @@ nlohmann::json AkgKernelJsonGenerator::CreateOutputsJson(const std::vector<AnfNo
-      if (auto tcnode = tmp_output.first->cast<CNodePtr>();
-          tcnode && AnfAlgo::HasNodeAttr(kAttrParallelDimInfo, tcnode)) {
-        auto info = AnfAlgo::GetNodeAttr<std::vector<size_t>>(tcnode, kAttrParallelDimInfo);
-        if (info.size() != 2) {
-          MS_LOG(EXCEPTION) << "Parallel dim info is invalid!";
-        }
-        sub_graphs_[info[0]].push_back(output_desc_json[kJsonKeyTensorName]);
-        if (dim_infos_.find(info[0]) == dim_infos_.end()) {
-          dim_infos_[info[0]] = info[1];
-        }
-      }
       outputs_json.emplace_back(output_desc_json);
     }
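A sketch of the `parallel_fusion` fragment that `GenParallelJson` now writes into the kernel JSON, shown as a Python literal. The keys come from the constants in this patch; the tensor names and core numbers are invented.

```python
kernel_json_fragment = {
    "parallel_fusion": {                          # kJsonKeyParallelFusion
        "fusion_type": "block_fusion",            # kJsonKeyFusionType
        "type_info": [],                          # kJsonKeyTypeInfo: pipeline ids, if any
        "sub_graph": [["out_0_0"], ["out_1_0"]],  # kJsonKeySubGraph: output tensors per sub-graph
        "core_num": [24, 56],                     # kJsonKeyCoreNum: dim info per sub-graph
    }
}
```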
diff --git a/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_generator.h b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_generator.h
index 148a54ad9f..99bca58db8 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_generator.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_generator.h
@@ -54,6 +54,7 @@ constexpr auto kJsonKeyParallelFusion = "parallel_fusion";
 constexpr auto kJsonKeyFusionType = "fusion_type";
 constexpr auto kJsonKeySubGraph = "sub_graph";
 constexpr auto kJsonKeyCoreNum = "core_num";
+constexpr auto kJsonKeyTypeInfo = "type_info";
 constexpr auto kJsonKeyBufferStitch = "buffer_stitch";
 constexpr auto kJsonKeyStitchOp = "stitch_op";
 constexpr auto kJsonKeyStitchAtomicOp = "stitch_atomic_op";
@@ -89,8 +90,6 @@ class AkgKernelJsonGenerator {
     input_tensor_idx_.clear();
     address_node_map_.clear();
     output_tensor_idx_ = 0;
-    sub_graphs_.clear();
-    dim_infos_.clear();
   }
   void set_dump_option(DumpOption dump_option) { dump_option_ = dump_option; }
   std::map<std::string, AnfNodePtr> address_node_map() { return address_node_map_; }
@@ -127,9 +126,10 @@ class AkgKernelJsonGenerator {
   std::string GetOutputFormat(const AnfNodePtr &anf_node, size_t index);
   void SaveNodeAddress(const AnfNodePtr &anf_node, nlohmann::json *node_json);
   OpInfoPtr ExtractOpInfo(const AnfNodePtr &anf_node);
-  void SetParallelValueToJson(const std::string &processor, const std::map<size_t, size_t> &dim_infos,
-                              nlohmann::json *sub_fusion_json);
-  void AddParalleFusionJsonInfo(const std::string &processor, nlohmann::json *kernel_json);
+  void CollectParallelDimInfo(const AnfNodePtr &anf_node);
+  void GenParallelJson(const std::vector<AnfNodePtr> &anf_nodes, const std::vector<AnfNodePtr> &input_list,
+                       const std::vector<AnfNodePtr> &output_list,
+                       const std::map<AnfNodePtr, nlohmann::json> &node_json_map, nlohmann::json *kernel_json);
   DumpOption dump_option_;
   static int op_cnt_;
@@ -142,8 +142,6 @@ class AkgKernelJsonGenerator {
   std::vector<size_t> input_size_list_;
   std::vector<size_t> output_size_list_;
   std::map<std::string, AnfNodePtr> address_node_map_;
-  std::map<size_t, std::vector<std::string>> sub_graphs_;
-  std::map<size_t, size_t> dim_infos_;
   bool is_basic_op_{false};
 };
 }  // namespace kernel
"sub_graph"; constexpr auto kJsonKeyCoreNum = "core_num"; +constexpr auto kJsonKeyTypeInfo = "type_info"; constexpr auto kJsonKeyBufferStitch = "buffer_stitch"; constexpr auto kJsonKeyStitchOp = "stitch_op"; constexpr auto kJsonKeyStitchAtomicOp = "stitch_atomic_op"; @@ -89,8 +90,6 @@ class AkgKernelJsonGenerator { input_tensor_idx_.clear(); address_node_map_.clear(); output_tensor_idx_ = 0; - sub_graphs_.clear(); - dim_infos_.clear(); } void set_dump_option(DumpOption dump_option) { dump_option_ = dump_option; } std::map address_node_map() { return address_node_map_; } @@ -127,9 +126,10 @@ class AkgKernelJsonGenerator { std::string GetOutputFormat(const AnfNodePtr &anf_node, size_t index); void SaveNodeAddress(const AnfNodePtr &anf_node, nlohmann::json *node_json); OpInfoPtr ExtractOpInfo(const AnfNodePtr &anf_node); - void SetParallelValueToJson(const std::string &processor, const std::map &dim_infos, - nlohmann::json *sub_fusion_json); - void AddParalleFusionJsonInfo(const std::string &processor, nlohmann::json *kernel_json); + void CollectParallelDimInfo(const AnfNodePtr &anf_node); + void GenParallelJson(const std::vector &anf_nodes, const std::vector &input_list, + const std::vector &output_list, + const std::map &node_json_map, nlohmann::json *kernel_json); DumpOption dump_option_; static int op_cnt_; @@ -142,8 +142,6 @@ class AkgKernelJsonGenerator { std::vector input_size_list_; std::vector output_size_list_; std::map address_node_map_; - std::map> sub_graphs_; - std::map dim_infos_; bool is_basic_op_{false}; }; } // namespace kernel diff --git a/mindspore/ccsrc/backend/kernel_compiler/akg/gpu/akg_gpu_kernel_build.cc b/mindspore/ccsrc/backend/kernel_compiler/akg/gpu/akg_gpu_kernel_build.cc index 143caac576..27486b1f4b 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/akg/gpu/akg_gpu_kernel_build.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/akg/gpu/akg_gpu_kernel_build.cc @@ -60,8 +60,9 @@ KernelPackPtr AkgGpuKernelBuilder::OpBuild(const AkgKernelJsonGenerator &json_ge return cached_kernel_pack; } - (void)alarm(AUTODIFF_COMPILE_OVERTIME); auto kernel_json = json_generator.kernel_json_str(); + kernel::SaveJsonInfo(kernel_name, kernel_json, kernel::KernelMeta::GetInstance()->kernel_meta_path()); + (void)alarm(AUTODIFF_COMPILE_OVERTIME); auto res = GpuKernelBuildClient::Instance().AkgCompileSingle(kernel_json); (void)alarm(0); if (!res) { @@ -70,7 +71,6 @@ KernelPackPtr AkgGpuKernelBuilder::OpBuild(const AkgKernelJsonGenerator &json_ge } auto new_kernel_pack = InsertCache(kernel_name, processor); - kernel::SaveJsonInfo(kernel_name, kernel_json, kernel::KernelMeta::GetInstance()->kernel_meta_path()); if (new_kernel_pack == nullptr) { MS_LOG(ERROR) << "Insert to cache failed, kernel_name[" << kernel_name << "], fullname_with_scope[" << anf_node->fullname_with_scope() << "]."; diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/parallel_cost_model.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/parallel_cost_model.cc index 9a52d05699..8a98c29491 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/parallel_cost_model.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/parallel_cost_model.cc @@ -47,7 +47,7 @@ int ParallelCostModel::GetNodeCalAmount(const AnfNodePtr &node) { return py::cast(ret); } -std::tuple, int> ParallelCostModel::CalFuseInfo(const AnfNodePtrList &nodes) { +std::tuple, int, FusionInfoPtr> ParallelCostModel::CalFuseInfo(const AnfNodePtrList &nodes) { nlohmann::json json_desc; std::vector graphs; std::transform(nodes.begin(), nodes.end(), 
diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/parallel_cost_model.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/parallel_cost_model.cc
index 9a52d05699..8a98c29491 100644
--- a/mindspore/ccsrc/backend/optimizer/graph_kernel/parallel_cost_model.cc
+++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/parallel_cost_model.cc
@@ -47,7 +47,7 @@ int ParallelCostModel::GetNodeCalAmount(const AnfNodePtr &node) {
   return py::cast<int>(ret);
 }
 
-std::tuple<std::vector<DimInfoPtr>, int> ParallelCostModel::CalFuseInfo(const AnfNodePtrList &nodes) {
+std::tuple<std::vector<DimInfoPtr>, int, FusionInfoPtr> ParallelCostModel::CalFuseInfo(const AnfNodePtrList &nodes) {
   nlohmann::json json_desc;
   std::vector<AnfNodePtrList> graphs;
   std::transform(nodes.begin(), nodes.end(), std::back_inserter(graphs),
@@ -65,7 +65,7 @@ std::tuple<std::vector<DimInfoPtr>, int> ParallelCostModel::CalFuseInfo(const An
   }
 
   py::tuple ret_tuple = py::cast<py::tuple>(ret);
-  if (!py::isinstance<py::tuple>(ret_tuple) || ret_tuple.size() != 2) {
-    MS_LOG(EXCEPTION) << "Parallel cost model should return a tuple with two elements!";
+  if (!py::isinstance<py::tuple>(ret_tuple) || ret_tuple.size() != 4) {
+    MS_LOG(EXCEPTION) << "Parallel cost model should return a tuple with four elements!";
   }
@@ -75,8 +75,41 @@ std::tuple<std::vector<DimInfoPtr>, int> ParallelCostModel::CalFuseInfo(const An
     dim_infos.push_back(std::make_shared<CommonDimInfo>(py::cast<int>(dim_list[i])));
   }
   int benefit = py::cast<int>(ret_tuple[1]);
+  auto fusion_info = ProcessFusionInfo(ret_tuple[2], ret_tuple[3]);
 
-  return std::make_tuple(dim_infos, benefit);
+  return std::make_tuple(dim_infos, benefit, fusion_info);
+}
+
+FusionInfoPtr ParallelCostModel::ProcessFusionInfo(py::object fusion_type, py::object type_info) {
+  if (!py::isinstance<py::str>(fusion_type)) {
+    MS_LOG(EXCEPTION) << "Fusion type for parallel is invalid!";
+  }
+
+  std::string fusion_type_name = py::cast<std::string>(fusion_type);
+
+  FusionInfoPtr fusion_info;
+  if (fusion_type_name == "block_fusion") {
+    fusion_info = std::make_shared<BlockFusionInfo>();
+  } else if (fusion_type_name == "block_pipeline_fusion") {
+    if (!py::isinstance<py::list>(type_info)) {
+      MS_LOG(EXCEPTION) << "Fusion type info for block pipeline fusion type is invalid!";
+    }
+    std::vector<std::vector<size_t>> pipeline_ids;
+    py::list pipeline_ids_list = py::cast<py::list>(type_info);
+    for (size_t i = 0; i < pipeline_ids_list.size(); ++i) {
+      std::vector<size_t> part_ids;
+      py::list inner_ids_list = py::cast<py::list>(pipeline_ids_list[i]);
+      for (size_t j = 0; j < inner_ids_list.size(); ++j) {
+        part_ids.push_back(py::cast<size_t>(inner_ids_list[j]));
+      }
+      pipeline_ids.push_back(part_ids);
+    }
+
+    fusion_info = std::make_shared<BlockPipelineFusionInfo>(pipeline_ids);
+  } else {
+    MS_LOG(EXCEPTION) << "Unsupported parallel fusion type: " << fusion_type_name;
+  }
+  return fusion_info;
 }
 
 ParallelCostModelPtr ParellelCostModelWarehouse::GetParallelCostModel(const std::string &target) {
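To illustrate the contract that `CalFuseInfo` and `ProcessFusionInfo` check: the Python cost model now hands back a 4-tuple. The values here are hypothetical; only the layout and types follow the patch.

```python
# What ret_tuple may look like from the C++ side (hypothetical values):
ret_tuple = (
    [24, 56],                 # ret_tuple[0]: dim list -> one CommonDimInfo per node
    42,                       # ret_tuple[1]: benefit
    "block_pipeline_fusion",  # ret_tuple[2]: fusion type -> BlockPipelineFusionInfo
    [[0, 1], [3, 4]],         # ret_tuple[3]: nested pipeline ids, parsed into
)                             #               std::vector<std::vector<size_t>>
```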
diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/parallel_cost_model.h b/mindspore/ccsrc/backend/optimizer/graph_kernel/parallel_cost_model.h
index a5dd442b3f..1689bd62f6 100644
--- a/mindspore/ccsrc/backend/optimizer/graph_kernel/parallel_cost_model.h
+++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/parallel_cost_model.h
@@ -29,6 +29,7 @@
 #include "backend/optimizer/common/optimizer.h"
 #include "backend/optimizer/graph_kernel/parallel_cost_model.h"
 #include "backend/session/kernel_graph.h"
+#include "pipeline/jit/parse/python_adapter.h"
 #include "utils/ms_context.h"
 
 namespace mindspore {
@@ -55,12 +56,50 @@ class CommonDimInfo : public DimInfo {
 using DimInfoPtr = std::shared_ptr<DimInfo>;
 using CommonDimInfoPtr = std::shared_ptr<CommonDimInfo>;
 
+class FusionInfo {
+ public:
+  FusionInfo() = default;
+  explicit FusionInfo(const std::string &type) : fusion_type_(type) {}
+  ~FusionInfo() = default;
+  std::string FusionType() { return fusion_type_; }
+  virtual bool ExistTypeInfo() { return false; }
+
+ private:
+  std::string fusion_type_{"none"};
+};
+
+class BlockFusionInfo : public FusionInfo {
+ public:
+  BlockFusionInfo() : FusionInfo("block_fusion") {}
+  ~BlockFusionInfo() = default;
+  bool ExistTypeInfo() { return false; }
+};
+
+class BlockPipelineFusionInfo : public FusionInfo {
+ public:
+  explicit BlockPipelineFusionInfo(const std::vector<std::vector<size_t>> &ids)
+      : FusionInfo("block_pipeline_fusion"), pipeline_ids_(ids) {}
+  ~BlockPipelineFusionInfo() = default;
+  bool ExistTypeInfo() { return true; }
+  std::vector<std::vector<size_t>> PipelineIds() { return pipeline_ids_; }
+
+ private:
+  std::vector<std::vector<size_t>> pipeline_ids_;
+};
+
+using FusionInfoPtr = std::shared_ptr<FusionInfo>;
+using BlockFusionInfoPtr = std::shared_ptr<BlockFusionInfo>;
+using BlockPipelineFusionInfoPtr = std::shared_ptr<BlockPipelineFusionInfo>;
+
 class ParallelCostModel {
  public:
   ParallelCostModel() {}
   ~ParallelCostModel() {}
   int GetNodeCalAmount(const AnfNodePtr &node);
-  std::tuple<std::vector<DimInfoPtr>, int> CalFuseInfo(const AnfNodePtrList &nodes);
+  std::tuple<std::vector<DimInfoPtr>, int, FusionInfoPtr> CalFuseInfo(const AnfNodePtrList &nodes);
+
+ private:
+  FusionInfoPtr ProcessFusionInfo(py::object fusion_type, py::object type_info);
 };
 
 using ParallelCostModelPtr = std::shared_ptr<ParallelCostModel>;
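A Python mirror of the new `FusionInfo` hierarchy, only to make the `ExistTypeInfo` plus downcast pattern concrete; this sketch is not part of the source tree.

```python
class FusionInfo:
    def __init__(self, fusion_type="none"):
        self.fusion_type = fusion_type

    def exist_type_info(self):
        return False

class BlockFusionInfo(FusionInfo):
    def __init__(self):
        super().__init__("block_fusion")

class BlockPipelineFusionInfo(FusionInfo):
    def __init__(self, pipeline_ids):
        super().__init__("block_pipeline_fusion")
        self.pipeline_ids = pipeline_ids

    def exist_type_info(self):
        return True

info = BlockPipelineFusionInfo([[0, 1], [3, 4]])
if info.exist_type_info():  # C++ then does dynamic_pointer_cast<BlockPipelineFusionInfo>
    assert info.pipeline_ids == [[0, 1], [3, 4]]
```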
diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/parallel_fusion.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/parallel_fusion.cc
index bcd94f1c78..73e79c3cdd 100644
--- a/mindspore/ccsrc/backend/optimizer/graph_kernel/parallel_fusion.cc
+++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/parallel_fusion.cc
@@ -553,7 +553,7 @@ std::tuple<std::vector<bool>, std::vector<ParallelInfo>> ParallelOpFusion::DoSea
       std::tie(other_candidates, std::ignore) =
         GetAvaliableNodesByOffset(i, tc, sorted_candidates_used, candidates, std::set<int>());
       int benefit;
-      std::tie(std::ignore, benefit) = cost_model_ptr_->CalFuseInfo(other_candidates);
+      std::tie(std::ignore, benefit, std::ignore) = cost_model_ptr_->CalFuseInfo(other_candidates);
       if (benefit > 0) {
         begin = mid + 1;
       } else {
@@ -567,12 +567,12 @@ std::tuple<std::vector<bool>, std::vector<ParallelInfo>> ParallelOpFusion::DoSea
       AnfNodePtrList other_candidates;
       std::tie(other_candidates, std::ignore) =
         GetAvaliableNodesByOffset(i, tc, sorted_candidates_used, candidates, std::set<int>());
-      auto [dim_infos, benefit] = cost_model_ptr_->CalFuseInfo(other_candidates);
+      auto [dim_infos, benefit, fusion_info] = cost_model_ptr_->CalFuseInfo(other_candidates);
       if (benefit <= 0) {
         MS_LOG(EXCEPTION) << "Internal error in candidate search!";
       }
       max_benefit = benefit;
-      best_parallel_info = ParallelInfo(other_candidates, dim_infos);
+      best_parallel_info = ParallelInfo(other_candidates, dim_infos, fusion_info);
       i += begin - 1;
     }
 
@@ -676,10 +676,13 @@ std::vector<ParallelInfo> ParallelOpFusion::SearchFusableParallelCNodes(
 }
 
 void ParallelOpFusion::SetFusedParallelOpAttrToReturnNode(const ParallelInfo &parallel_info) {
+  AnfNodePtr attach_node;
+  // Dim info should be attached to each segment's output.
   for (size_t i = 0; i < parallel_info.GetSize(); ++i) {
     const auto &fuse_nodes = parallel_info.nodes();
     std::vector<size_t> info = {i, std::dynamic_pointer_cast<CommonDimInfo>(parallel_info.dims()[i])->dim_info()};
     if (!AnfAlgo::IsGraphKernel(fuse_nodes[i])) {
+      attach_node = fuse_nodes[i];
       SetNodeAttrSafely(kAttrParallelDimInfo, MakeValue<std::vector<size_t>>(info), fuse_nodes[i]);
     } else {
       auto node_g = GetValueNode<FuncGraphPtr>((fuse_nodes[i]->cast<CNodePtr>())->input(0));
@@ -689,11 +692,16 @@ void ParallelOpFusion::SetFusedParallelOpAttrToReturnNode(const ParallelInfo &pa
       for (size_t j = 1; j < inputs.size(); ++j) {
         SetNodeAttrSafely(kAttrParallelDimInfo, MakeValue<std::vector<size_t>>(info), inputs[j]);
       }
+      attach_node = inputs[1];
     } else {
+      attach_node = out_node;
       SetNodeAttrSafely(kAttrParallelDimInfo, MakeValue<std::vector<size_t>>(info), out_node);
     }
   }
+
+  // Fusion info only needs to be attached to one of the segments.
+  SetFusionInfoAttrToNode(attach_node, parallel_info);
 }
 
 void PostProcessForNewSubGraphCNode(const AnfNodePtr &node, const std::shared_ptr<session::KernelGraph> &kernel_graph) {
@@ -741,6 +749,17 @@ void PostProcessForNewSubGraphCNode(const AnfNodePtr &node, const std::shared_pt
   }
 }
 
+void ParallelOpFusion::SetFusionInfoAttrToNode(const AnfNodePtr &node, const ParallelInfo &parallel_info) {
+  auto fusion_type = parallel_info.fusion_info()->FusionType();
+  AnfAlgo::SetNodeAttr(kAttrParallelFusionType, MakeValue(fusion_type), node);
+  if (parallel_info.fusion_info()->ExistTypeInfo()) {
+    if (auto pipeline_fusion = std::dynamic_pointer_cast<BlockPipelineFusionInfo>(parallel_info.fusion_info())) {
+      AnfAlgo::SetNodeAttr(kAttrParallelTypeInfo,
+                           MakeValue<std::vector<std::vector<size_t>>>(pipeline_fusion->PipelineIds()), node);
+    }
+  }
+}
+
 bool ParallelOpFusion::CreateParallelOpSubGraphs(const std::vector<ParallelInfo> &parallel_infos,
                                                  const std::shared_ptr<session::KernelGraph> &kernel_graph) {
   bool changed = false;
@@ -755,6 +774,7 @@ bool ParallelOpFusion::CreateParallelOpSubGraphs(const std::vector<ParallelInfo>
     AnfNodePtr sg_node;
     std::tie(sg_node, std::ignore) = FuseNodesToSubGraph(fuse_nodes, kernel_graph, "parallel");
     PostProcessForNewSubGraphCNode(sg_node, kernel_graph);
+    AnfAlgo::SetNodeAttr(kAttrCompositeType, MakeValue("parallel_fusion"), sg_node);
     DumpParallelFusionDetail(fuse_nodes, sg_node);
   }
diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/parallel_fusion.h b/mindspore/ccsrc/backend/optimizer/graph_kernel/parallel_fusion.h
index a9372d6018..d65b3cca4f 100644
--- a/mindspore/ccsrc/backend/optimizer/graph_kernel/parallel_fusion.h
+++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/parallel_fusion.h
@@ -37,10 +37,12 @@ namespace opt {
 class ParallelInfo {
  public:
   ParallelInfo() = default;
-  ParallelInfo(const AnfNodePtrList &nodes, const std::vector<DimInfoPtr> &dims) : nodes_(nodes), dims_(dims) {}
+  ParallelInfo(const AnfNodePtrList &nodes, const std::vector<DimInfoPtr> &dims, const FusionInfoPtr &fusion_info)
+      : nodes_(nodes), dims_(dims), fusion_info_(fusion_info) {}
   ParallelInfo(const ParallelInfo &obj) {
     nodes_ = obj.nodes_;
     dims_ = obj.dims_;
+    fusion_info_ = obj.fusion_info_;
   }
   ~ParallelInfo() = default;
 
@@ -52,10 +54,12 @@ class ParallelInfo {
   }
   const AnfNodePtrList &nodes() const { return nodes_; }
   const std::vector<DimInfoPtr> &dims() const { return dims_; }
+  const FusionInfoPtr &fusion_info() const { return fusion_info_; }
 
  private:
   AnfNodePtrList nodes_;
   std::vector<DimInfoPtr> dims_;
+  FusionInfoPtr fusion_info_;
 };
 
 class ParallelConfig {
@@ -102,6 +106,8 @@ class ParallelOpFusion : public Pass {
   std::vector<ParallelInfo> SearchFusableParallelCNodes(const std::vector<std::vector<AnfNodePtrList>> &groups);
 
+  void SetFusionInfoAttrToNode(const AnfNodePtr &node, const ParallelInfo &parallel_info);
+
   void SetFusedParallelOpAttrToReturnNode(const ParallelInfo &parallel_info);
 
   bool CreateParallelOpSubGraphs(const std::vector<ParallelInfo> &parallel_infos,
diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h
index 22681b31c3..7497ae7b0a 100644
--- a/mindspore/ccsrc/utils/utils.h
+++ b/mindspore/ccsrc/utils/utils.h
@@ -397,6 +397,9 @@ constexpr auto kAttrIsGrad = "is_grad";
 constexpr auto kAttrRecompute = "recompute";
 constexpr auto kAttrNeedCseAfterRecompute = "need_cse_after_recompute";
 constexpr auto kAttrParallelDimInfo = "parallel_dim_info";
+constexpr auto kAttrParallelFusionType = "parallel_fusion_type";
+constexpr auto kAttrParallelTypeInfo = "parallel_type_info";
+constexpr auto kAttrCompositeType = "composite_type";
 constexpr auto kAttrStitch = "stitch";
 constexpr auto kAttrTopoSortRhsFirst = "topo_sort_rhs_first";
 constexpr auto kAttrSwitchLayer = "switch_layer";
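Reviewer summary of the node attributes this PR introduces, as a sketch with invented values; the keys come from `utils.h` above.

```python
attrs = {
    "parallel_dim_info": [1, 56],             # kAttrParallelDimInfo: [segment index, dim info], on each segment's output
    "parallel_fusion_type": "block_fusion",   # kAttrParallelFusionType: attached to a single segment
    "parallel_type_info": [[0, 1], [3, 4]],   # kAttrParallelTypeInfo: only for block_pipeline_fusion
    "composite_type": "parallel_fusion",      # kAttrCompositeType: marks the fused sub-graph node
}
```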