diff --git a/mindspore/_extends/graph_kernel/model/graph_split.py b/mindspore/_extends/graph_kernel/model/graph_split.py index b07648ed85..1b7d7fddf0 100644 --- a/mindspore/_extends/graph_kernel/model/graph_split.py +++ b/mindspore/_extends/graph_kernel/model/graph_split.py @@ -14,138 +14,221 @@ # =========================================================================== """Cost model splitter""" -from .model import PrimLib, Graph +from .model import PrimLib, Graph, Tensor class GraphSplitByPattern: - """Graph split by pattern""" + """Graph splitter""" + class Area: + """Area""" + MODE_BASIC = 1 + MODE_COMPOSITE = 2 + + def __init__(self, init_op): + self.pattern = PrimLib.iter_type(init_op) + self.ops = [init_op] + self.in_relations = dict() # {area1: relation1, area2: relation2, ...} + self.out_relations = dict() # {area1: relation1, area2: relation2, ...} + self.mode = self.MODE_BASIC + + def __str__(self): + return '<' + '-'.join([op.output.name for op in self.ops]) + '>' + + def __repr__(self): + return str(self) + + def link_input(self, area_map): + """Link inputs""" + def get_relation(op, i): + relation = PrimLib.UNKNOWN + _, elem_relation = PrimLib.input_relation(op, i) + for r in elem_relation: + if r is not None and r > relation: + relation = r + return relation + for i, t in enumerate(self.ops[0].inputs): + if t.op is not None: + area, relation = area_map[t.op], get_relation(self.ops[0], i) + self.in_relations[area] = relation + + def link_output(self): + """Link outputs""" + for input_area, r in self.in_relations.items(): + input_area.out_relations[self] = r + + def fuse(self, area): + """Fuse `area` to `self`""" + def _update_relation(relations, a, r): + relations[a] = max(r, relations[a]) if a in relations else r + + def _update_pattern(): + self.pattern = max(self.pattern, area.pattern, self.in_relations[area]) + + def _fuse_relation(self_relations, new_relations): + for a, r in new_relations.items(): + if a != self: + _update_relation(self_relations, a, r) + if area in self_relations: + self_relations.pop(area) + + def _redirect_relation(rels): + """Replace `area` with `self` in relations""" + if area in rels: + r = rels.pop(area) + _update_relation(rels, self, r) + + self.ops.extend(area.ops) + _update_pattern() + _fuse_relation(self.in_relations, area.in_relations) + _fuse_relation(self.out_relations, area.out_relations) + for a, _ in area.in_relations.items(): + _redirect_relation(a.out_relations) + for a, _ in area.out_relations.items(): + _redirect_relation(a.in_relations) + self.mode = self.MODE_COMPOSITE + + def check_circle(self, to): + """Check circle. 
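+            A circle is formed if `to` is also reachable from this area through another path.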
+            It returns False if a circle exists."""
+        def _reached(area, to):
+            for out, _ in area.out_relations.items():
+                if out == to or _reached(out, to):
+                    return True
+            return False
+        for out, _ in self.out_relations.items():
+            if out != to and _reached(out, to):
+                return False
+        return True
+
+    BROADCAST_FUSE_DEPTH = 3
+    REDUCE_FUSE_DEPTH = 3
 
     def __init__(self, graph):
         self.graph = graph
-        self.groups = []
-        self.op_group = {}
-        for op in self.graph.ops:
-            g = [op]
-            self.groups.append(g)
-            self.op_group[op] = g
-        self.ids = {}
-        for i, op in enumerate(graph.ops):
-            self.ids[op] = i
-        self.doms = self.post_dom(graph.ops)
-        _, outputs = graph.deduce_parameters()
-        self.outputs = set(outputs)
-
-    def post_dom(self, ops):
-        """Post dom"""
-        doms, i_doms = {}, {}
-        for i in range(len(ops) - 1, -1, -1):
-            op = ops[i]
-            doms[op] = {op}
-            i_dom = None
-            if op.output.to_ops:
-                suc_dom = set(doms[op.output.to_ops[0]])
-                for to in op.output.to_ops[1:]:
-                    suc_dom.intersection_update(doms[to])
-                doms[op].update(suc_dom)
-                for dom in suc_dom:
-                    if i_dom is None or self.ids[dom] < self.ids[i_dom]:
-                        i_dom = dom
-            i_doms[op] = i_dom
-        return i_doms
-
-    def get_pattern(self, op, i):
-        """Get pattern"""
-        pattern = PrimLib.UNKNOWN
-        _, elem_relation = PrimLib.input_relation(op, i)
-        for pat in elem_relation:
-            if pat and pat > pattern:
-                pattern = pat
-        return pattern
-
-    def fuse(self, check_fun):
-        """Fuse ops"""
-        def _get_path(op, dom):
-            path_ops, visited = [], set()
-
-            def _get_path_depth(p):
-                visited.add(p)
-                if self.op_group[p][0] == p:
-                    path_ops.append(p)
-                for to in p.output.to_ops:
-                    if to != dom and to not in visited:
-                        _get_path_depth(to)
-            _get_path_depth(op)
-            return path_ops
-        changed = True
-        while changed:
-            for group in self.groups:
-                op = group[0]
-                dom = self.doms[op]
-                if dom is None or op.output in self.outputs:
-                    continue
-                ops = _get_path(op, dom)
-                if check_fun(op, dom, ops):
-                    dom_group = self.op_group[dom]
-                    fused = []
-                    for fop in ops:
-                        f_group = self.op_group[fop]
-                        for p in f_group:
-                            self.op_group[p] = dom_group
-                        fused.append(f_group)
-                        dom_group += f_group
-                    for g in fused:
-                        self.groups.remove(g)
+        self.areas = []
+        area_map = {}
+        for op in graph.ops:
+            a = self.Area(op)
+            self.areas.append(a)
+            area_map[op] = a
+        for a in self.areas:
+            a.link_input(area_map)
+        for a in self.areas:
+            a.link_output()
+
+    def fuse(self, selector):
+        """Fuse areas"""
+        changed = False
+        while True:
+            for dominant in self.areas:
+                fuse_areas = selector(dominant)
+                if fuse_areas:
+                    for area in fuse_areas:
+                        changed = True
+                        dominant.fuse(area)
+                        self.areas.remove(area)
                     break
             else:
-                changed = False
+                return changed
 
     def to_subgraphs(self):
         """Transform op groups to subgraphs"""
+        ids = {}
+        for i, op in enumerate(self.graph.ops):
+            ids[op] = i
         subgraphs = []
-        for i, group in enumerate(self.groups):
-            group.sort(key=lambda op: self.ids[op])
-            subgraphs.append(Graph('{}_{}'.format(self.graph.name, i), group))
-        return subgraphs
+        graphmodes = []
+        for i, area in enumerate(self.areas):
+            area.ops.sort(key=lambda op: ids[op])
+            subgraphs.append(Graph('{}_{}'.format(self.graph.name, i), area.ops))
+            graphmodes.append("basic" if area.mode == self.Area.MODE_BASIC else "composite")
+        return subgraphs, graphmodes
 
     def split(self):
-        """Split graph"""
-        def _buddy(op, dom, path_ops):
-            """Fuse buddy together"""
-            group = self.op_group[op]
-            for p in group:
-                # p is buddy
-                if p.output.buddy is not None and p.output.buddy.members[0].op not in group:
+        """Split graph by pattern"""
+        def _elemwise_depth(dom):
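+            # Depth fuse: the producer area must feed only this area, and the
+            # relation must be elemwise, so fusing it cannot skip any other
+            # consumer or form a circle.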
+            if dom.pattern > PrimLib.BROADCAST or len(dom.in_relations) != 1:
+                return None
+            a, r = list(dom.in_relations.items())[0]
+            if a.pattern > PrimLib.BROADCAST or len(a.out_relations) != 1 or r != PrimLib.ELEMWISE:
+                return None
+            return [a]
+
+        def _elemwise_width(dom):
+            if dom.pattern > PrimLib.BROADCAST:
+                return None
+            fused = []
+            for a, r in dom.in_relations.items():
+                if a.pattern <= PrimLib.BROADCAST and r == PrimLib.ELEMWISE and a.check_circle(dom):
+                    fused.append(a)
+            return fused
+
+        def _broadcast_depth(dom):
+            if dom.pattern > PrimLib.BROADCAST or len(dom.in_relations) != 1:
+                return None
+            a, r = list(dom.in_relations.items())[0]
+            if a.pattern > PrimLib.BROADCAST or len(a.out_relations) != 1 or \
+                    r != PrimLib.BROADCAST or len(a.ops) > self.BROADCAST_FUSE_DEPTH:
+                return None
+            return [a]
+
+        def _broadcast_width(dom):
+            if dom.pattern > PrimLib.BROADCAST:
+                return None
+            fused = []
+            for a, r in dom.in_relations.items():
+                if a.pattern <= PrimLib.BROADCAST and r == PrimLib.BROADCAST and \
+                        a.check_circle(dom) and len(a.ops) <= self.BROADCAST_FUSE_DEPTH:
+                    fused.append(a)
+            return fused
+
+        def _check_reduce_exclude(dom):
+            # exclude large all-reduce
+            if len(dom.ops[0].inputs[0].shape) == len(dom.ops[0].attrs["reduce_axis"]) and \
+                    dom.ops[0].inputs[0].get_size() > 10000:
+                return True
+
+            # exclude multi output
+            for a in dom.in_relations.keys():
+                if len(a.out_relations) > 1:
+                    return True
+                if any([op.output.para_type == Tensor.PARA_OUTPUT for op in a.ops]):
                     return True
-                # p's output is buddy
-                for to in p.output.to_ops:
-                    if to.output.buddy is not None and to not in group:
-                        return True
             return False
 
-        def _injective(pattern, limit):
-            def _checker(op, dom, path_ops):
-                for p in op.output.to_ops:
-                    if p not in self.op_group[dom]:
-                        return False
-                if PrimLib.iter_type(op) in (PrimLib.ELEMWISE, PrimLib.BROADCAST):
-                    for i, t in enumerate(dom.inputs):
-                        if t == op.output:
-                            return self.get_pattern(dom, i) == pattern and len(self.op_group[op]) < limit
-                return False
-            return _checker
+        def _reduce_depth(dom):
+            if dom.pattern != PrimLib.REDUCE or len(dom.in_relations) != 1:
+                return None
+            if _check_reduce_exclude(dom):
+                return None
+            a, r = list(dom.in_relations.items())[0]
+            if a.pattern > PrimLib.BROADCAST or len(a.out_relations) != 1 or \
+                    r > PrimLib.REDUCE or len(a.ops) > self.REDUCE_FUSE_DEPTH:
+                return None
+            return [a]
 
-        def _diamond(op, dom, path_ops):
-            if PrimLib.iter_type(op) not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or \
-               PrimLib.iter_type(dom) in (PrimLib.UNKNOWN, PrimLib.TRANSFORM):
-                return False
-            return len(path_ops) == 1 and op.output not in dom.inputs
-        self.fuse(_buddy)
-        self.fuse(_injective(PrimLib.ELEMWISE, 100))
-        self.fuse(_injective(PrimLib.BROADCAST, 6))
-        self.fuse(_injective(PrimLib.REDUCE, 6))
-        self.fuse(_diamond)
-        return self.to_subgraphs()
+        def _reduce_width(dom):
+            if dom.pattern != PrimLib.REDUCE:
+                return None
+            if _check_reduce_exclude(dom):
+                return None
+            fused = []
+            for a, r in dom.in_relations.items():
+                if a.pattern <= PrimLib.BROADCAST and r <= PrimLib.REDUCE and \
+                        a.check_circle(dom) and len(a.ops) <= self.REDUCE_FUSE_DEPTH:
+                    fused.append(a)
+            return fused
+        changed = True
+        while changed:
+            changed = self.fuse(_elemwise_depth)
+            changed = self.fuse(_elemwise_width) or changed
+            changed = self.fuse(_broadcast_depth) or changed
+            changed = self.fuse(_broadcast_width) or changed
+            changed = self.fuse(_reduce_depth) or changed
+            changed = self.fuse(_reduce_width) or changed
+        subgraphs, graphmodes = 
self.to_subgraphs() + return subgraphs, graphmodes def split(graph): + """Split graph""" return GraphSplitByPattern(graph).split() diff --git a/mindspore/_extends/graph_kernel/model/model_builder.py b/mindspore/_extends/graph_kernel/model/model_builder.py index 9ceb1bc827..f5772b586f 100644 --- a/mindspore/_extends/graph_kernel/model/model_builder.py +++ b/mindspore/_extends/graph_kernel/model/model_builder.py @@ -196,8 +196,7 @@ class CompositeGraph: shape, dtype, data_format, name=name, para_type=Tensor.PARA_OUTPUT) cur_fusion = None for op in desc['op_desc']: - inputs = [self.tensors[d[0]['tensor_name']] - for d in op['input_desc'] if 'value' not in d[0]] + inputs = [self.tensors[d['tensor_name']] for x in op['input_desc'] for d in x if 'value' not in d] out_desc = op['output_desc'] name, shape, dtype, data_format = out_desc[0]['tensor_name'], out_desc[ 0]['shape'], out_desc[0]['data_type'], out_desc[0]['format'] @@ -263,7 +262,7 @@ class CompositeGraph: self.tensors[y], True) inplace_desc = copy.deepcopy(d) inplace_desc['attr'] = {'name': 'fake_output', 'value': fake} - z_desc, out_desc = inplace_desc['input_desc'][2][0].inplace_desc['output_desc'][0] + z_desc, out_desc = inplace_desc['input_desc'][2][0], inplace_desc['output_desc'][0] z_desc['shape'] = z.shape z_desc['data_type'] = z.dtype z_desc['tensor_name'] = z.name diff --git a/mindspore/_extends/graph_kernel/splitter.py b/mindspore/_extends/graph_kernel/splitter.py index 38f30c693b..bcb74ac718 100644 --- a/mindspore/_extends/graph_kernel/splitter.py +++ b/mindspore/_extends/graph_kernel/splitter.py @@ -26,10 +26,12 @@ def split_with_json(json_str: str): try: graph_desc = json.loads(json_str) comp = model.load_composite(graph_desc) - graph_split = model.split(comp.graph) + graph_split, graph_mode = model.split(comp.graph) is_multi_graph = len(graph_split) > 1 graph_list = list(map(comp.dump, graph_split)) - result = {"multi_graph": is_multi_graph, "graph_desc": graph_list} + result = {"multi_graph": is_multi_graph, + "graph_desc": graph_list, + "graph_mode": graph_mode} return json.dumps(result) except jd.JSONDecodeError: logger.error(traceback.format_exc()) diff --git a/mindspore/_extends/graph_kernel/tests/test_split.py b/mindspore/_extends/graph_kernel/tests/test_split.py deleted file mode 100644 index 670872fdfe..0000000000 --- a/mindspore/_extends/graph_kernel/tests/test_split.py +++ /dev/null @@ -1,53 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ===========================================================================
-"""test split"""
-import model
-
-
-def graph_1():
-    gb = model.GraphBuilder()
-    with gb.graph_scope("main"):
-        a = gb.tensor([1024, 16], "float32", name="a")
-        b = gb.emit("Abs", a, 'b')
-        c = gb.emit("Abs", b, 'c')
-        d = gb.emit("Abs", c, 'd')
-        gb.emit("TensorAdd", [b, d], "e")
-    return gb.get()[0]
-
-
-def graph_2():
-    gb = model.GraphBuilder()
-    with gb.graph_scope("main"):
-        a = gb.tensor([1024, 16], "float32", name="a")
-        b = gb.emit("Abs", a, 'b')
-        c = gb.emit("Abs", b, 'c')
-        d = gb.emit("ReduceSum", c, 'd', attrs={'reduce_axis': (1,)})
-        gb.emit("Sqrt", d, 'e')
-    return gb.get()[0]
-
-
-def test_split_by_pattern():
-    def _test(graph):
-        print("***************** main graph ***************")
-        print(graph)
-        subgraphs = model.split(graph)
-        for i, g in enumerate(subgraphs):
-            print('------------- subgraph {} --------------'.format(i))
-            print(g)
-    _test(graph_2())
-
-
-if __name__ == '__main__':
-    test_split_by_pattern()
diff --git a/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_generator.cc b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_generator.cc
index 840a7f13ff..e390702163 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_generator.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_generator.cc
@@ -485,7 +485,7 @@ bool AkgKernelJsonGenerator::CollectFusedJson(const std::vector<AnfNodePtr> &anf
   (*kernel_json)[kJsonKeyPlatform] = "AKG";
   (*kernel_json)[kJsonKeyProcess] = GetProcessorStr(anf_nodes[0]);
   (*kernel_json)[kJsonKeyComposite] = true;
-  (*kernel_json)[kJsonKeyCompositeGraph] = fg->ToString();
+  (*kernel_json)[kJsonKeyCompositeGraph] = fg->ToString() + "." + fg->debug_info()->get_id();
 
   if (!GetIOSize(*kernel_json, &input_size_list_, &output_size_list_)) {
     MS_LOG(ERROR) << "Cal mem size failed.";
diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/basic_ops_fusion.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/basic_ops_fusion.cc
index da7d0bc985..aa6b06bdd3 100644
--- a/mindspore/ccsrc/backend/optimizer/graph_kernel/basic_ops_fusion.cc
+++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/basic_ops_fusion.cc
@@ -37,22 +37,17 @@ namespace opt {
 namespace {
 bool IsBasicOp(const AnfNodePtr &node, bool is_before_kernel_select) {
 #if ENABLE_D
-  std::vector<PrimitivePtr> fusable_basic_ops = {prim::kPrimTensorAdd, prim::kPrimMul, prim::kPrimSub,
+  std::vector<PrimitivePtr> fusible_basic_ops = {prim::kPrimTensorAdd, prim::kPrimMul, prim::kPrimSub,
                                                  prim::kPrimExpandDims};
   if (!is_before_kernel_select) {
-    fusable_basic_ops.push_back(prim::kPrimCast);
+    fusible_basic_ops.push_back(prim::kPrimCast);
   }
 #elif ENABLE_GPU
-  std::vector<PrimitivePtr> fusable_basic_ops = {
-    prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimTensorAdd,
-    prim::kPrimRealDiv, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog,
-    prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimCast,
-    prim::kPrimAddN, prim::kPrimEqual, prim::kPrimReciprocal, prim::KPrimTransData, prim::kPrimSelect,
-    prim::kPrimGreater, prim::kPrimAssign};
+  std::vector<PrimitivePtr> fusible_basic_ops = GetFusibleOpList();
 #else
-  std::vector<PrimitivePtr> fusable_basic_ops;
+  std::vector<PrimitivePtr> fusible_basic_ops;
 #endif
-  return std::any_of(fusable_basic_ops.begin(), fusable_basic_ops.end(),
+  return std::any_of(fusible_basic_ops.begin(), fusible_basic_ops.end(),
                      [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
 }
diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/composite_ops_fusion.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/composite_ops_fusion.cc
index 97741f455c..60d3d560e5 100644
--- a/mindspore/ccsrc/backend/optimizer/graph_kernel/composite_ops_fusion.cc
+++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/composite_ops_fusion.cc
@@ -49,12 +49,7 @@ bool IsBasicFuseOp(const AnfNodePtr &node, bool is_before_kernel_select) {
     basic_ops.push_back(prim::kPrimCast);
   }
 #elif ENABLE_GPU
-  std::vector<PrimitivePtr> basic_ops = {
-    prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimTensorAdd,
-    prim::kPrimRealDiv, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog,
-    prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimCast,
-    prim::kPrimAddN, prim::kPrimEqual, prim::kPrimReciprocal, prim::KPrimTransData, prim::kPrimSelect,
-    prim::kPrimGreater, prim::kPrimAssign};
+  std::vector<PrimitivePtr> basic_ops = GetFusibleOpList();
 #else
   std::vector<PrimitivePtr> basic_ops;
 #endif
diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc
index 5189bb030e..97ec115cdd 100644
--- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc
+++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc
@@ -26,8 +26,8 @@
 #include "ir/func_graph_cloner.h"
 #include "ir/func_graph.h"
 #include "backend/optimizer/pass/const_input_to_attr_registry.h"
-#ifdef ENABLE_D
-#include "backend/kernel_compiler/tbe/tbe_kernel_build.h"
+#if ENABLE_GPU
+#include "runtime/device/gpu/kernel_info_setter.h"
 #endif
 
 namespace mindspore {
@@ -612,36 +612,6 @@ FuncGraphPtr JsonDescToAnf(const std::string &json_desc, const std::vector<AnfNodePtr> &inputs)
-bool JsonDescToAnf(const std::string &json_desc, const std::map<std::string, AnfNodePtr> &address_node_map,
-                   std::vector<AnfNodePtrList> *res_graphs) {
-  MS_EXCEPTION_IF_NULL(res_graphs);
-  auto kernel_json = nlohmann::json::parse(json_desc);
-  if (kernel_json.find(kJsonKeyMultiGraph) == kernel_json.end() || kernel_json[kJsonKeyMultiGraph].is_null()) {
-    // not multi graphs.
-    MS_LOG(ERROR) << "Input json is not multi graph, " << json_desc;
-    return false;
-  }
-
-  kernel::AkgKernelJsonDecoder akg_kernel_json_decoder;
-  std::vector<nlohmann::json> graph_descs = kernel_json[kJsonKeyGraphDesc];
-  if (graph_descs.empty()) {
-    MS_LOG(ERROR) << "No sub graph found, " << json_desc;
-    return false;
-  }
-
-  for (size_t i = 0; i < graph_descs.size(); ++i) {
-    const auto &graph_desc = graph_descs[i];
-    AnfNodePtrList res_graph;
-    if (!akg_kernel_json_decoder.DecodeSplitNodes(graph_desc, address_node_map, &res_graph)) {
-      MS_LOG(ERROR) << "Failed decode sub graph, " << graph_desc;
-      return false;
-    }
-    res_graphs->push_back(res_graph);
-  }
-
-  return true;
-}
-
 std::unordered_set<PrimitivePtr> GetExpandOps() {
   std::unordered_set<PrimitivePtr> expand_ops = {
     prim::kPrimSquare,
@@ -664,5 +634,23 @@ std::string ExtractGraphKernelName(const AnfNodePtrList &cnodes, const string &p
   }
   return name.str();
 }
+
+std::vector<PrimitivePtr> GetFusibleOpList() {
+  std::vector<PrimitivePtr> fusible_basic_ops = {
+    prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimTensorAdd,
+    prim::kPrimRealDiv, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog,
+    prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimCast,
+    prim::kPrimAddN, prim::kPrimEqual, prim::kPrimReciprocal, prim::KPrimTransData, prim::kPrimSelect,
+    prim::kPrimGreater, prim::kPrimAssign, prim::kPrimReduceSum};
+  return fusible_basic_ops;
+}
+
+void ResetKernelInfo(const AnfNodePtr &node, KernelType kernel_type) {
+  auto cnode = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(cnode);
+#if ENABLE_GPU
+  device::gpu::SetKernelInfo(cnode, kernel_type);
+#endif
+}
 }  // namespace opt
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.h b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.h
index 1bd74664fc..ebcddbdff3 100644
--- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.h
+++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.h
@@ -35,6 +35,7 @@ constexpr auto kGraphKernelSplitFunc = "split_with_json";
 constexpr auto kGetGraphKernelOpExpander = "get_op_expander";
 constexpr auto kJsonKeyMultiGraph = "multi_graph";
 constexpr auto kJsonKeyGraphDesc = "graph_desc";
+constexpr auto kJsonKeyGraphMode = "graph_mode";
 
 void SetNewKernelInfo(const AnfNodePtr &new_node, const FuncGraphPtr &fg, const AnfNodePtrList &inputs,
                       const AnfNodePtrList &outputs, kernel::Processor processor);
@@ -50,10 +51,10 @@ bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, n
                    std::map<std::string, AnfNodePtr> *address_node_map = nullptr);
 bool AnfToJsonDesc(const std::vector<AnfNodePtrList> &graphs, const DumpOption &dump_option, nlohmann::json *op_desc);
 FuncGraphPtr JsonDescToAnf(const std::string &json_desc, const std::vector<AnfNodePtr> &inputs);
-bool JsonDescToAnf(const std::string &json_desc, const std::map<std::string, AnfNodePtr> &address_node_map,
-                   std::vector<AnfNodePtrList> *res_graphs);
 std::unordered_set<PrimitivePtr> GetExpandOps();
 std::string ExtractGraphKernelName(const AnfNodePtrList &cnodes, const string &prefix = "",
                                    const string &postfix = "");
+std::vector<PrimitivePtr> GetFusibleOpList();
+void ResetKernelInfo(const AnfNodePtr &node, KernelType kernel_type = KernelType::UNKNOWN_KERNEL_TYPE);
 }  // namespace opt
 }  // namespace mindspore
 #endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_HELPER_H_
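The `graph_mode` list introduced above is the contract between the Python cost model and the C++ `CostModelSplitSchemer` below: its entries are parallel to `graph_desc`, and `"basic"` marks a subgraph to be inlined rather than kept as a fused kernel. A minimal sketch of that contract in Python (illustrative values only, not real AKG kernel JSON, assuming the result layout produced by `split_with_json` in splitter.py):

```python
import json

# Illustrative split result in the shape produced by split_with_json.
result = json.dumps({
    "multi_graph": True,
    "graph_desc": [{"op": "sub_0"}, {"op": "sub_1"}],  # one desc per subgraph
    "graph_mode": ["basic", "composite"],              # parallel to graph_desc
})

kernel_json = json.loads(result)
# Mirror of DecodeJson below: "basic" subgraphs get need_inline == 1.
need_inline = [1 if m == "basic" else 0 for m in kernel_json["graph_mode"]]
assert len(need_inline) == len(kernel_json["graph_desc"])
```
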
diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_splitter.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_splitter.cc
index ac000d0739..ff8d27c0ed 100644
--- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_splitter.cc
+++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_splitter.cc
@@ -26,6 +26,7 @@
 #include "pipeline/jit/parse/python_adapter.h"
 #include "backend/session/anf_runtime_algorithm.h"
 #include "backend/kernel_compiler/common_utils.h"
+#include "backend/kernel_compiler/akg/akg_kernel_json_decoder.h"
 #include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
 #include "debug/anf_ir_dump.h"
 
@@ -203,7 +204,7 @@ class AreaGraph {
     }
 
     SortCNodes(main_cnodes);
-    cnode_group_id->swap(topo_order_);  // The topo_order is not used anymore.
+    *cnode_group_id = std::move(topo_order_);  // The topo_order is not used anymore.
     return;
   }
 
@@ -291,7 +292,7 @@ class AreaGraph {
     std::vector<CNodePtr> main_cnodes_sorted;
     std::transform(topo_order_.begin(), topo_order_.end(), std::back_inserter(main_cnodes_sorted),
                    [main_cnodes](int index) { return main_cnodes->at(index); });
-    main_cnodes->swap(main_cnodes_sorted);
+    *main_cnodes = std::move(main_cnodes_sorted);
   }
 
   // Areas in this subgraph
@@ -415,6 +416,9 @@ class Splitter {
           cnode->set_input(i, iter->second);
         }
       }
+      if (AnfAlgo::IsRealKernel(node)) {
+        ResetKernelInfo(node);
+      }
     }
   }
   return output;
 }
@@ -445,7 +449,7 @@ class Splitter {
         tmp_subgraph_cnodes.push_back(new_subgraph_cnodes_[i]);
       }
     }
-    new_subgraph_cnodes_.swap(tmp_subgraph_cnodes);
+    new_subgraph_cnodes_ = std::move(tmp_subgraph_cnodes);
 
     TraverseFuncGraph(main_func_graph_, [&replace_map](const AnfNodePtr &node) {
       auto cnode = node->cast<CNodePtr>();
@@ -580,15 +584,38 @@ class CostModelSplitSchemer : public Splitter::SplitSchemer {
       return false;
     }
 
-    // recover json to anf-ir.
-    split_plan_.clear();
-    if (!JsonDescToAnf(split_graphs_str, address_node_map, &split_plan_)) {
-      MS_LOG(ERROR) << "Failed to decode split graphs.";
+    if (!DecodeJson(split_graphs_str, address_node_map)) {
+      MS_LOG(ERROR) << "Failed to decode split graphs. input json:\n" << split_graphs_str;
       return false;
     }
+    return true;
+  }
+
+  virtual bool DecodeJson(const std::string &json_desc, const std::map<std::string, AnfNodePtr> &address_node_map) {
+    auto kernel_json = nlohmann::json::parse(json_desc);
+    kernel::AkgKernelJsonDecoder akg_kernel_json_decoder;
+    std::vector<nlohmann::json> graph_descs = kernel_json[kJsonKeyGraphDesc];
+    std::vector<std::string> graph_modes = kernel_json[kJsonKeyGraphMode];
+    if (graph_modes.size() != graph_descs.size()) {
+      MS_LOG(ERROR) << "Size of graph_mode " << graph_modes.size() << " mismatch graph_desc " << graph_descs.size();
+      return false;
+    }
+
+    // recover json to anfnode.
+    split_plan_.clear();
+    for (const auto &graph_desc : graph_descs) {
+      AnfNodePtrList res_graph;
+      if (!akg_kernel_json_decoder.DecodeSplitNodes(graph_desc, address_node_map, &res_graph)) {
+        MS_LOG(ERROR) << "Failed decode sub graph, " << graph_desc;
+        return false;
+      }
+      split_plan_.push_back(std::move(res_graph));
+    }
 
-    // The info should be returned from costmodel.
-    need_inline_.assign(split_plan_.size(), 0);
+    // ops to be inlined.
+    need_inline_.clear();
+    std::transform(graph_modes.begin(), graph_modes.end(), std::back_inserter(need_inline_),
+                   [](const std::string &mode) { return mode == "basic" ? 
1 : 0; }); return true; } diff --git a/mindspore/_extends/graph_kernel/tests/env.sh b/tests/st/graph_kernel/model/env.sh similarity index 89% rename from mindspore/_extends/graph_kernel/tests/env.sh rename to tests/st/graph_kernel/model/env.sh index f16225f4b7..d314236f88 100644 --- a/mindspore/_extends/graph_kernel/tests/env.sh +++ b/tests/st/graph_kernel/model/env.sh @@ -13,5 +13,5 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -PYTHONPATH="$(pwd)/..:${PYTHONPATH}" +PYTHONPATH="$(pwd)/../../../../mindspore/_extends/graph_kernel:${PYTHONPATH}" export PYTHONPATH diff --git a/mindspore/_extends/graph_kernel/tests/graph_kernel_split.py b/tests/st/graph_kernel/model/graph_kernel_split.py similarity index 100% rename from mindspore/_extends/graph_kernel/tests/graph_kernel_split.py rename to tests/st/graph_kernel/model/graph_kernel_split.py diff --git a/tests/st/graph_kernel/model/test_split.py b/tests/st/graph_kernel/model/test_split.py new file mode 100644 index 0000000000..1e824e0ee0 --- /dev/null +++ b/tests/st/graph_kernel/model/test_split.py @@ -0,0 +1,436 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ===========================================================================
+"""Test split"""
+import model
+from model import model as estimate
+from model import graph_split as split
+
+
+def get_nodes(sp, ops):
+    """Get nodes"""
+    if isinstance(ops[0], str):
+        new_ops = []
+        for t in ops:
+            for op in sp.graph.ops:
+                if op.output.name == t:
+                    new_ops.append(op)
+                    break
+            else:
+                print("ERROR: op not found: ", t)
+        ops = new_ops
+    return [sp.nodes[sp.graph.ops.index(op)] for op in ops]
+
+
+def first_connected(sp, space):
+    for cand in space:
+        nodes = [sp.nodes[i] for i in cand[0]]
+        graphs = sp.resolve_connnected_graphs(nodes)
+        if len(graphs) != 1:
+            print("connect check failed: ", nodes)
+            return False
+    return True
+
+
+def split_format(sp, cand):
+    names = []
+    for ids in cand:
+        ops = []
+        for i in ids:
+            ops.append(sp.graph.ops[i].output.name)
+        names.append(','.join(ops))
+    return '|'.join(names)
+
+
+def graph_1():
+    ''' ring, no succ_dep, no prev '''
+    gb = model.GraphBuilder()
+    with gb.graph_scope("main"):
+        a = gb.tensor([10240, 16], "float32", name="a")
+        b = gb.emit("Abs", a, 'b')
+        c = gb.emit("Abs", b, 'c')
+        d = gb.emit("Abs", c, 'd')
+        gb.emit('TensorAdd', [b, d], 'e')
+    return gb.get()[0]
+
+
+def graph_2():
+    ''' ring, succ_dep, no prev '''
+    gb = model.GraphBuilder()
+    with gb.graph_scope("main"):
+        a0 = gb.tensor([10240, 16], "float32", name="a0")
+        a = gb.emit("Abs", a0, 'a')
+        b = gb.emit("Abs", a, 'b')
+        c = gb.emit("Abs", a, 'c')
+        d = gb.emit("Abs", b, 'd')
+        e = gb.emit('TensorAdd', [c, d], 'e')
+        gb.emit("Abs", e, 'f')
+    return gb.get()[0]
+
+
+def graph_3():
+    ''' no ring, 1 sibling node '''
+    gb = model.GraphBuilder()
+    with gb.graph_scope("main"):
+        a0 = gb.tensor([10240, 16], "float32", name="a0")
+        a1 = gb.tensor([10240, 16], "float32", name="a1")
+        b = gb.emit("Abs", a0, 'b')
+        c = gb.emit("Abs", a1, 'c')
+        d = gb.emit("Abs", b, 'd')
+        e = gb.emit('TensorAdd', [c, d], 'e')
+        gb.emit("Abs", e, 'f')
+    return gb.get()[0]
+
+
+def graph_4():
+    ''' no ring, 2 sibling nodes in 1 step '''
+    gb = model.GraphBuilder()
+    with gb.graph_scope("main"):
+        a0 = gb.tensor([10240, 16], "float32", name="a0")
+        a1 = gb.tensor([10240, 16], "float32", name="a1")
+        b = gb.emit("Abs", a0, 'b')
+        c = gb.emit("Abs", b, 'c')
+        d = gb.emit("Abs", a1, 'd')
+        e = gb.emit("Abs", d, 'e')
+        f = gb.emit('TensorAdd', [c, e], 'f')
+        gb.emit('Abs', f, 'g')
+        h = gb.emit("Abs", d, 'h')
+        i = gb.emit('TensorAdd', [c, h], 'i')
+        gb.emit("Abs", i, 'j')
+    return gb.get()[0]
+
+
+def graph_5():
+    ''' no ring, 2 sibling step '''
+    gb = model.GraphBuilder()
+    with gb.graph_scope("main") as g:
+        a0 = gb.tensor([10240, 16], "float32", name="a0")
+        a1 = gb.tensor([10240, 16], "float32", name="a1")
+        a2 = gb.tensor([10240, 16], "float32", name="a2")
+        a = gb.emit("Abs", a0, 'a')
+        b = gb.emit("Abs", a1, 'b')
+        c = gb.emit("Abs", b, 'c')
+        d = gb.emit('TensorAdd', [a, c], 'd')
+        gb.emit("Abs", d, 'e')
+        f = gb.emit("Abs", a2, 'f')
+        g = gb.emit('TensorAdd', [c, f], 'g')
+        gb.emit("Abs", g, 'h')
+    return gb.get()[0]
+
+
+def graph_6():
+    ''' no ring, tree down '''
+    gb = model.GraphBuilder()
+    with gb.graph_scope("main"):
+        a0 = gb.tensor([10240, 16], "float32", name="a0")
+        a = gb.emit("Abs", a0, 'a')
+        b = gb.emit("Abs", a, 'b')
+        gb.emit("Abs", b, 'd')
+        gb.emit("Abs", b, 'e')
+        c = gb.emit("Abs", a, 'c')
+        gb.emit("Abs", c, 'f')
+        gb.emit("Abs", c, 'g')
+    return gb.get()[0]
+
+
+def graph_pat_1():
+    ''' split by reduce '''
+    gb = model.GraphBuilder()
+    with gb.graph_scope("main"):
+        a0 = gb.tensor([1024, 1024], "float32", name="a0")
+        a = gb.emit("Abs", a0, 'a')
+        b = gb.emit("Abs", a, 'b')
+        c = gb.emit("ReduceSum", b, 'c', attrs={'reduce_axis': (1,)})
+        d = gb.emit("Sqrt", c, 'd')
+        gb.emit("Sqrt", d, 'f')
+    return gb.get()[0]
+
+
+def graph_pat_2():
+    ''' multi output '''
+    gb = model.GraphBuilder()
+    with gb.graph_scope("main"):
+        a0 = gb.tensor([1024, 1024], "float32", name="a0")
+        a = gb.emit("Abs", a0, 'a')
+        b = gb.emit("Abs", a, 'b')
+        gb.emit("ReduceSum", b, 'c', attrs={'reduce_axis': (1,)})
+        gb.emit("ReduceSum", b, 'e', attrs={'reduce_axis': (1,)})
+    return gb.get()[0]
+
+
+def graph_pat_3():
+    ''' two reduce '''
+    gb = model.GraphBuilder()
+    with gb.graph_scope("main"):
+        a0 = gb.tensor([1024, 1024], "float32", name="a0")
+        a = gb.emit("Abs", a0, 'a')
+        b = gb.emit("Abs", a, 'b')
+        c = gb.emit("ReduceSum", b, 'c', attrs={'reduce_axis': (1,)})
+        d = gb.emit("Abs", c, 'd')
+        gb.emit("ReduceSum", d, 'e', attrs={'reduce_axis': (1,)})
+    return gb.get()[0]
+
+
+def graph_pat_4():
+    ''' elemwise + broadcast '''
+    gb = model.GraphBuilder()
+    with gb.graph_scope("main"):
+        a0 = gb.tensor([1, 1024], "float32", name="a0")
+        a2 = gb.tensor([1014, 1024], "float32", name="a2")
+        a = gb.emit("Abs", a0, 'a')
+        b = gb.emit("Abs", a, 'b')
+        c = gb.emit("Abs", b, 'c')
+        d = gb.emit("Abs", c, 'd')
+        e = gb.emit("Abs", d, 'e')
+        f = gb.emit("Abs", e, 'f')
+        g0 = gb.emit("Abs", a2, 'g0')
+        # g0 = gb.emit("Abs", g0, 'g0')
+        # g0 = gb.emit("Abs", g0, 'g0')
+        # g0 = gb.emit("Abs", g0, 'g0')
+        # g0 = gb.emit("Abs", g0, 'g0')
+        # g0 = gb.emit("Abs", g0, 'g0')
+        # g0 = gb.emit("Abs", g0, 'g0')
+        g0 = gb.emit("Abs", g0, 'g0')
+        g1 = gb.emit('TensorAdd', [f, g0], 'g1')
+        g2 = gb.emit("Abs", g1, 'g2')
+        g3 = gb.emit("Abs", g2, 'g3')
+        g4 = gb.emit("Abs", g3, 'g4')
+        gb.emit("Abs", g4, 'g5')
+    return gb.get()[0]
+
+
+def graph_pat_5():
+    ''' reduce + reshape '''
+    gb = model.GraphBuilder()
+    with gb.graph_scope("main"):
+        a0 = gb.tensor([1024, 1024], "float32", name="a0")
+        a = gb.emit("Abs", a0, 'a')
+        b = gb.emit("Abs", a, 'b')
+        c = gb.emit("ReduceSum", b, 'c', attrs={'reduce_axis': (1,)})
+        d = gb.emit("Abs", c, 'd')
+        e = gb.tensor([512, 2048], "float32", name="e")
+        gb.op("Reshape", e, [d])
+    return gb.get()[0]
+
+
+def graph_pat_6():
+    ''' diamond '''
+    gb = model.GraphBuilder()
+    with gb.graph_scope("main"):
+        a0 = gb.tensor([1024, 1024], "float32", name="a0")
+        a = gb.emit("Abs", a0, 'a')
+        b = gb.emit("Abs", a, 'b')
+        c = gb.emit("Abs", a, 'c')
+        gb.emit("TensorAdd", [b, c], 'd')
+        gb.emit("Abs", c, 'f')  # break the diamond
+    return gb.get()[0]
+
+
+def graph_pat_7():
+    ''' buddy of control op '''
+    gb = model.GraphBuilder()
+    with gb.graph_scope("main"):
+        a0 = gb.tensor([1024, 1024], "float32", name="a0")
+        a1 = gb.tensor([1024, 1024], "float32", name="a1")
+        a = gb.emit("Abs", a0, 'a')
+        b = gb.emit("Abs", a1, 'b')
+        c = gb.emit("make_tuple", [a, b], 'c')
+        d = gb.tensor([1024, 1024], "float32", name="d")
+        gb.op("AddN", d, [c])
+        gb.emit("Abs", d, 'f')
+    graph = gb.get()[0]
+    estimate.AddControlBuddy().visit_graph(graph)
+    return graph
+
+
+def graph_pat_8():
+    ''' reduce + elemwise on shared input '''
+    gb = model.GraphBuilder()
+    with gb.graph_scope("main"):
+        a0 = gb.tensor([1024, 1024], "float32", name="a0")
+        a = gb.emit("Abs", a0, 'a')
+        b = gb.emit("Abs", a, 'b')
+        #c = gb.emit("Abs", b, 'b')
+        c = gb.emit("ReduceSum", b, 'c', attrs={'reduce_axis': (1,)})
+        gb.emit("TensorAdd", [b, c], 'd')
+    return gb.get()[0]
+
+
+def graph_pat_9():
+    ''' scalar '''
+    gb = model.GraphBuilder()
+    with gb.graph_scope("main"):
+        a0 = gb.tensor([1024, 1024], "float32", name="a0")
+        a1 = gb.tensor([1], "float32", name="a1")
+        a = gb.emit("Maximum", a1, 'a')
+        b = gb.emit("Mul", [a, a1], 'b')
+        gb.emit('Mul', [b, a0], 'c')
+    return gb.get()[0]
+
+
+def graph_mo_1():
+    gb = model.GraphBuilder()
+    with gb.graph_scope("main"):
+        a0 = gb.tensor([1024, 1024], "float32", name="a0")
+        a = gb.emit("Abs", a0, 'a')
+        gb.emit("Abs", a, 'b')
+        gb.emit("Abs", a, 'c')
+    return gb.get()[0]
+
+
+def graph_mo_2():
+    gb = model.GraphBuilder()
+    with gb.graph_scope("main") as g:
+        a0 = gb.tensor([1024, 1024], "float32", name="a0")
+        a = gb.emit("Abs", a0, 'a')
+        b = gb.emit("Abs", a, 'b')
+        c = gb.emit("Abs", b, 'c')
+        g.set_output(b, c)
+    return gb.get()[0]
+
+
+def graph_mo_3():
+    ''' multi output with reduce '''
+    gb = model.GraphBuilder()
+    with gb.graph_scope("main") as g:
+        a0 = gb.tensor([1024, 1024], "float32", name="a0")
+        a = gb.emit("Abs", a0, 'a')
+        b = gb.emit("Abs", a, 'b')
+        c = gb.emit("ReduceSum", b, 'c', attrs={'reduce_axis': (1,)})
+        g.set_output(b, c)
+    return gb.get()[0]
+
+
+def graph_mo_4():
+    ''' multi output, reduce and elemwise on shared input '''
+    gb = model.GraphBuilder()
+    with gb.graph_scope("main") as g:
+        a0 = gb.tensor([1024, 1024], "float32", name="a0")
+        a = gb.emit("Abs", a0, 'a')
+        b = gb.emit("Abs", a, 'b')
+        c = gb.emit("ReduceSum", a, 'c', attrs={'reduce_axis': (1,)})
+        g.set_output(b, c)
+    return gb.get()[0]
+
+
+def test_binary_split():
+    """Test binary split"""
+    def _test(graph, expected_space_size):
+        print("********* test on graph : {} *************".format(graph.name))
+        sp = split.GraphSpliter(graph)
+        nodes = get_nodes(sp, graph.ops)
+        space = sp.binary_split(nodes)
+        for i, s in enumerate(space):
+            print('{}: {}'.format(i, split_format(sp, s)))
+        assert len(space) == expected_space_size
+        assert first_connected(sp, space)
+    _test(graph_1(), 3)
+    _test(graph_2(), 7)
+    _test(graph_3(), 4)
+    _test(graph_4(), 17)
+    _test(graph_5(), 11)
+    _test(graph_6(), 24)
+
+
+def test_resolve_connnected_graphs():
+    """Test resolve connected graphs"""
+    graph = graph_5()
+    sp = split.GraphSpliter(graph)
+    n1 = get_nodes(sp, ['a', 'd', 'b', 'c'])
+    graphs = sp.resolve_connnected_graphs(n1)
+    print(graphs)
+    assert len(graphs) == 1
+    n2 = get_nodes(sp, ['a', 'd', 'e', 'f', 'g'])
+    graphs = sp.resolve_connnected_graphs(n2)
+    print(graphs)
+    assert len(graphs) == 2
+    n3 = get_nodes(sp, ['a', 'b', 'f'])
+    graphs = sp.resolve_connnected_graphs(n3)
+    print(graphs)
+    assert len(graphs) == 3
+
+
+def test_split():
+    """Test split"""
+    def _print_cost(name, c):
+        print("%s\tdma_ratio=%f, saturation=%f, mix_saturation=%f, type=%s" %
+              (name, c.dma_ratio(), c.saturation(), c.mix_saturation(), c.cost_type()))
+
+    def _test(graph):
+        print("********* test on graph : {} *************".format(graph.name))
+        sp = split.GraphSpliter(graph)
+        subgraphs = sp.split(False)
+        print('----- main graph -------')
+        print(graph)
+        for i, g in enumerate(subgraphs):
+            print(' -------- subgraph {} -------'.format(i))
+            print(g)
+        print("--------- cost ------------")
+        cost, _ = model.estimate(graph)
+        _print_cost("main graph", cost)
+        fc, sub_costs = model.estimate(subgraphs)
+        _print_cost("Subgraphs:", fc)
+        for i, cost in enumerate(sub_costs):
+            _print_cost("  |_%d:\t" % (i), cost)
+    _test(graph_5())
+    # _test(graph_4())
+
+
+def test_estimate():
+    """Test estimate"""
+    graph = graph_5()
+    e = estimate.Estimator(graph)
+    e.estimate()
+    print(e.iter_space)
+
+
+def test_pattern_split():
+    """Test pattern split"""
+    def _test(graph, expect_n=0):
+        print("************* main graph **************")
+        print(graph)
+        subgraphs, _ = split.GraphSplitByPattern(graph).split()
+        for i, g in enumerate(subgraphs):
+            print(' -------- subgraph {} -------'.format(i))
+            print(g)
+        if expect_n > 0:
+            assert len(subgraphs) == expect_n
+
+    # _test(graph_1(), 1)
+    # _test(graph_pat_1(), 2)
+    # _test(graph_pat_2())
+    # _test(graph_pat_3())
+    # _test(graph_pat_4())
+    # _test(graph_pat_5())
+    # _test(graph_pat_6())
+    # _test(graph_pat_7())
+    # _test(graph_pat_8())
+    # _test(graph_pat_9())
+
+    # _test(graph_mo_1())
+    # _test(graph_mo_2())
+    # _test(graph_mo_3())
+    _test(graph_mo_4())
+
+
+def main():
+    # test_binary_split()
+    # test_resolve_connnected_graphs()
+    # test_split()
+    # test_estimate()
+    test_pattern_split()
+
+
+if __name__ == '__main__':
+    main()
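
For reference, a minimal sketch of driving the new splitter end to end, using the `GraphBuilder` API from the tests above. The graph here is illustrative; `GraphSplitByPattern` is the class added in graph_split.py, and the `model` package must be on PYTHONPATH as in env.sh:

```python
import model
from model import graph_split as split

# Build a small graph (same API as the tests above).
gb = model.GraphBuilder()
with gb.graph_scope("main"):
    a = gb.tensor([1024, 1024], "float32", name="a")
    b = gb.emit("Abs", a, 'b')
    c = gb.emit("ReduceSum", b, 'c', attrs={'reduce_axis': (1,)})
    gb.emit("Sqrt", c, 'd')
graph = gb.get()[0]

# split() now returns the subgraphs plus one mode string per subgraph;
# "basic" areas are later inlined by the C++ splitter, "composite" stay fused.
subgraphs, graph_modes = split.GraphSplitByPattern(graph).split()
for g, mode in zip(subgraphs, graph_modes):
    print(mode)
    print(g)
```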