diff --git a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc index 8ed290cc13..38c040e6b1 100644 --- a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc +++ b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc @@ -40,6 +40,9 @@ using kernel::KernelBuildInfoPtr; using kernel::KernelMod; using kernel::KernelModPtr; namespace { +constexpr size_t kNopNodeInputSize = 2; +constexpr size_t kNopNodeRealInputIndex = 1; + std::vector TransShapeToSizet(const abstract::ShapePtr &shape) { MS_EXCEPTION_IF_NULL(shape); std::vector shape_size_t; @@ -48,6 +51,26 @@ std::vector TransShapeToSizet(const abstract::ShapePtr &shape) { } } // namespace +AnfNodePtr AnfRuntimeAlgorithm::GetTupleGetItemRealInput(const CNodePtr &tuple_get_item) { + MS_EXCEPTION_IF_NULL(tuple_get_item); + if (tuple_get_item->size() != kTupleGetItemInputSize) { + MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!"; + } + return tuple_get_item->input(kRealInputNodeIndexInTupleGetItem); +} + +size_t AnfRuntimeAlgorithm::GetTupleGetItemOutIndex(const CNodePtr &tuple_get_item) { + MS_EXCEPTION_IF_NULL(tuple_get_item); + if (tuple_get_item->size() != kTupleGetItemInputSize) { + MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!"; + } + auto output_index_value_node = tuple_get_item->input(kInputNodeOutputIndexInTupleGetItem); + MS_EXCEPTION_IF_NULL(output_index_value_node); + auto value_node = output_index_value_node->cast(); + MS_EXCEPTION_IF_NULL(value_node); + return IntToSize(GetValue(value_node->value())); +} + KernelWithIndex AnfRuntimeAlgorithm::VisitKernel(const AnfNodePtr &anf_node, size_t index) { MS_EXCEPTION_IF_NULL(anf_node); if (anf_node->isa()) { @@ -83,49 +106,47 @@ KernelWithIndex AnfRuntimeAlgorithm::VisitKernel(const AnfNodePtr &anf_node, siz } } -KernelWithIndex AnfRuntimeAlgorithm::VisitKernelWithReturnType(const AnfNodePtr &anf_node, size_t index, +KernelWithIndex AnfRuntimeAlgorithm::VisitKernelWithReturnType(const AnfNodePtr &anf_node, int index, bool visit_nop_node, const std::vector &return_types) { MS_EXCEPTION_IF_NULL(anf_node); - for (const auto &prim_type : return_types) { - if (CheckPrimitiveType(anf_node, prim_type)) { - return std::make_pair(anf_node, index); - } + if (std::any_of(return_types.begin(), return_types.end(), [&anf_node](const PrimitivePtr &prim_type) -> bool { + return CheckPrimitiveType(anf_node, prim_type); + })) { + return KernelWithIndex(anf_node, index); } - if (anf_node->isa()) { - return std::make_pair(anf_node, 0); - } else if (anf_node->isa()) { - return std::make_pair(anf_node, 0); - } else if (anf_node->isa()) { - auto cnode = anf_node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - auto input0 = cnode->input(0); - MS_EXCEPTION_IF_NULL(input0); - if (IsPrimitive(input0, prim::kPrimTupleGetItem)) { - if (cnode->inputs().size() != kTupleGetItemInputSize) { - MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!"; - } - auto input2 = cnode->input(kInputNodeOutputIndexInTupleGetItem); - MS_EXCEPTION_IF_NULL(input2); - auto value_node = input2->cast(); - MS_EXCEPTION_IF_NULL(value_node); - int item_idx = GetValue(value_node->value()); - return VisitKernelWithReturnType(cnode->input(kRealInputNodeIndexInTupleGetItem), IntToSize(item_idx), - visit_nop_node, return_types); - } else if (IsPrimitive(input0, prim::kPrimDepend) || IsPrimitive(input0, prim::kPrimControlDepend)) { - return VisitKernelWithReturnType(cnode->input(kRealInputIndexInDepend), 0, visit_nop_node, return_types); - } else if (opt::IsNopNode(cnode) && visit_nop_node) { - if (cnode->inputs().size() == 2) { - return VisitKernelWithReturnType(cnode->input(1), 0, visit_nop_node, return_types); - } else { - MS_LOG(EXCEPTION) << cnode->DebugString() << "Invalid nop node"; + if (!anf_node->isa()) { + return KernelWithIndex(anf_node, 0); + } + auto cnode = anf_node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (CheckPrimitiveType(cnode, prim::kPrimTupleGetItem)) { + auto item_with_index_tmp = VisitKernelWithReturnType(GetTupleGetItemRealInput(cnode), + GetTupleGetItemOutIndex(cnode), visit_nop_node, return_types); + if (CheckPrimitiveType(item_with_index_tmp.first, prim::kPrimMakeTuple)) { + MS_EXCEPTION_IF_NULL(item_with_index_tmp.first); + auto make_tuple = item_with_index_tmp.first->cast(); + MS_EXCEPTION_IF_NULL(make_tuple); + const std::vector &make_tuple_inputs = make_tuple->inputs(); + size_t make_tuple_input_index = item_with_index_tmp.second + 1; + if (make_tuple_input_index >= make_tuple_inputs.size()) { + MS_LOG(EXCEPTION) << "Index[" << make_tuple_input_index << "] out of range[" << make_tuple_inputs.size() + << "]."; } - } else { - return std::make_pair(anf_node, index); + return VisitKernelWithReturnType(make_tuple_inputs[make_tuple_input_index], 0, visit_nop_node, return_types); } - } else { - MS_LOG(EXCEPTION) << "The input is invalid"; + return item_with_index_tmp; + } + if (CheckPrimitiveType(cnode, prim::kPrimDepend) || CheckPrimitiveType(cnode, prim::kPrimControlDepend)) { + return VisitKernelWithReturnType(cnode->input(kRealInputIndexInDepend), index, visit_nop_node, return_types); + } + if (opt::IsNopNode(cnode) && visit_nop_node) { + if (cnode->size() != kNopNodeInputSize) { + MS_LOG(EXCEPTION) << "Invalid nop node " << cnode->DebugString(); + } + return VisitKernelWithReturnType(cnode->input(kNopNodeRealInputIndex), 0, visit_nop_node, return_types); } + return KernelWithIndex(anf_node, index); } std::vector AnfRuntimeAlgorithm::GetAllOutput(const AnfNodePtr &node, @@ -591,7 +612,7 @@ const DeviceAddress *AnfRuntimeAlgorithm::GetOutputAddr(const AnfNodePtr &node, if (opt::IsNopNode(node) && visit_nop_node) { auto cnode = node->cast(); MS_EXCEPTION_IF_NULL(cnode); - if (cnode->inputs().size() == 2) { + if (cnode->size() == kNopNodeInputSize) { return AnfRuntimeAlgorithm::GetPrevNodeOutputAddr(cnode, 0); } else { MS_LOG(EXCEPTION) << node->DebugString() << "Invalid nop node"; @@ -613,7 +634,7 @@ DeviceAddressPtr AnfRuntimeAlgorithm::GetMutableOutputAddr(const AnfNodePtr &nod if (opt::IsNopNode(node) && visit_nop_node) { auto cnode = node->cast(); MS_EXCEPTION_IF_NULL(cnode); - if (cnode->inputs().size() == 2) { + if (cnode->inputs().size() == kNopNodeInputSize) { return AnfRuntimeAlgorithm::GetPrevNodeMutableOutputAddr(cnode, 0); } else { MS_LOG(EXCEPTION) << node->DebugString() << "Invalid nop node."; @@ -806,7 +827,7 @@ bool AnfRuntimeAlgorithm::IsRealKernel(const AnfNodePtr &node) { IsPrimitive(input, prim::kPrimHistogramSummary) || IsPrimitive(input, prim::kPrimMakeTuple) || IsPrimitive(input, prim::kPrimStateSetItem) || IsPrimitive(input, prim::kPrimDepend) || IsPrimitive(input, prim::kPrimTupleGetItem) || IsPrimitive(input, prim::kPrimControlDepend) || - IsPrimitive(input, prim::kPrimReturn); + IsPrimitive(input, prim::kPrimReturn) || IsPrimitive(input, prim::kPrimPartial); return !is_virtual_node; } @@ -1117,5 +1138,14 @@ TypeId AnfRuntimeAlgorithm::GetPrevNodeOutputPrecision(const AnfNodePtr &node, s } return GetCNodeOutputPrecision(kernel_with_index.first); } + +bool AnfRuntimeAlgorithm::IsCondControlKernel(const CNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + if (node->inputs().empty()) { + MS_LOG(EXCEPTION) << "Illegal null input of cnode."; + } + auto input = node->input(kAnfPrimitiveIndex); + return IsPrimitive(input, prim::kPrimLabelGoto) || IsPrimitive(input, prim::kPrimLabelSwitch); +} } // namespace session } // namespace mindspore diff --git a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h index d5e8016a29..4fa3150e36 100644 --- a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h +++ b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h @@ -42,9 +42,12 @@ using DeviceAddress = device::DeviceAddress; using DeviceAddressPtr = device::DeviceAddressPtr; class AnfRuntimeAlgorithm { public: + // get real input node of tuple_get_item + static AnfNodePtr GetTupleGetItemRealInput(const CNodePtr &tuple_get_item); + static size_t GetTupleGetItemOutIndex(const CNodePtr &tuple_get_item); // get input_anf_node's real kernel by recurse static KernelWithIndex VisitKernel(const AnfNodePtr &input_anf_node, size_t output_index); - static KernelWithIndex VisitKernelWithReturnType(const AnfNodePtr &input_anf_node, size_t output_index, + static KernelWithIndex VisitKernelWithReturnType(const AnfNodePtr &input_anf_node, int output_index, bool visit_nop_node = false, const std::vector &return_types = { prim::kPrimMakeTuple}); @@ -205,6 +208,7 @@ class AnfRuntimeAlgorithm { static TypeId GetCNodeOutputPrecision(const AnfNodePtr &node); // get fix output precision from prev node, input_idx is the input index of current node related to prev node. static TypeId GetPrevNodeOutputPrecision(const AnfNodePtr &node, size_t input_idx); + static bool IsCondControlKernel(const CNodePtr &node); }; } // namespace session using AnfAlgo = session::AnfRuntimeAlgorithm; diff --git a/mindspore/ccsrc/backend/session/ascend_control_parser.cc b/mindspore/ccsrc/backend/session/ascend_control_parser.cc index 656a6b40ed..274b355679 100644 --- a/mindspore/ccsrc/backend/session/ascend_control_parser.cc +++ b/mindspore/ccsrc/backend/session/ascend_control_parser.cc @@ -17,6 +17,7 @@ #include "backend/session/ascend_control_parser.h" #include #include +#include #include "backend/session/anf_runtime_algorithm.h" #include "utils/union_find_set.h" #include "runtime/device/ascend/ascend_label_assign.h" @@ -31,94 +32,11 @@ static constexpr size_t kCNodePartialLength = 2; static constexpr size_t kCNodePartialFunc = 1; static constexpr size_t kCNodeSwitchLayerBranch = 2; static constexpr size_t kCNodeSwitchLayerLength = 3; +static constexpr size_t kCNodeAssignTarget = 1; +static constexpr size_t kCNodeAssignSource = 2; namespace mindspore { namespace session { -static CNodePtr GetJumpNode(NotNull parent_graph, NotNull child_graph) { - auto &nodes = parent_graph->execution_order(); - CNodePtr last_jump_node = nullptr; - for (auto &node : nodes) { - if (IsPrimitiveCNode(node, prim::kPrimLabelGoto)) { - if (child_graph->get_start_label() == node->input(kCNodeCallArg)) { - return node; - } - last_jump_node = node; - } else if (IsPrimitiveCNode(node, prim::kPrimLabelSwitch)) { - if (child_graph->get_start_label() == node->input(kCNodeSwitchFalse) || - child_graph->get_start_label() == node->input(kCNodeSwitchTrue)) { - return node; - } - last_jump_node = node; - } - } - if (last_jump_node == nullptr) { - MS_LOG(EXCEPTION) << "Cannot find jump node from " << parent_graph->ToString() << " to " << child_graph->ToString(); - } - return last_jump_node; -} - -static void InitUnionFindSet(NotNull kg, const NotNull *> union_find_set, - const NotNull *> memo) { - if (memo->find(kg.get()) != memo->end()) { - return; - } - memo->insert(kg.get()); - - const std::vector>> &real_inputs = kg->real_inputs(); - for (auto &iter : real_inputs) { - auto ¶ = iter.first; - MS_EXCEPTION_IF_NULL(para); - if (para->isa()) { - union_find_set->Add(para); - } - for (auto &arg : iter.second) { - MS_EXCEPTION_IF_NULL(arg); - if (!arg->isa()) { - continue; - } - union_find_set->Add(arg); - } - } - for (auto &child : kg->child_graph_order()) { - InitUnionFindSet(NOT_NULL(child), union_find_set, memo); - } -} - -static void UnionParentParameter(NotNull kg, const NotNull *> union_find_set, - const NotNull *> memo) { - if (memo->find(kg.get()) != memo->end()) { - return; - } - memo->insert(kg.get()); - - const std::vector>> &real_inputs = kg->real_inputs(); - for (auto &iter : real_inputs) { - auto ¶ = iter.first; - for (auto &arg : iter.second) { - MS_EXCEPTION_IF_NULL(arg); - if (!arg->isa()) { - continue; - } - if (kg->unreuse_args().find(arg) != kg->unreuse_args().end()) { - continue; - } - union_find_set->Union(arg, para); - } - } - for (auto &child : kg->child_graph_order()) { - UnionParentParameter(NOT_NULL(child), union_find_set, memo); - } -} - -static UnionFindSet MakeUnionFindSet(NotNull root_kg) { - UnionFindSet result; - std::set memo; - InitUnionFindSet(root_kg, NOT_NULL(&result), NOT_NULL(&memo)); - memo.clear(); - UnionParentParameter(root_kg, NOT_NULL(&result), NOT_NULL(&memo)); - return result; -} - static void RecursiveReplaceNode(NotNull kg, NotNull main_parameter, const std::set ¶meter_reuse_set, const NotNull *> memo) { @@ -135,8 +53,9 @@ static void RecursiveReplaceNode(NotNull kg, NotNull continue; } MS_EXCEPTION_IF_NULL(para); - MS_LOG(INFO) << "Replace " << para->DebugString() << " of graph " << AnfAlgo::GetGraphId(para.get()) << " to " - << main_parameter->DebugString() << " of graph " << AnfAlgo::GetGraphId(main_parameter.get().get()); + MS_LOG(INFO) << "In " << kg->ToString() << " replace " << para->DebugString() << " of graph " + << AnfAlgo::GetGraphId(para.get()) << " to " << main_parameter->DebugString() << " of graph " + << AnfAlgo::GetGraphId(main_parameter.get().get()); kg->ReplaceNode(NOT_NULL(para), main_parameter); } @@ -145,7 +64,7 @@ static void RecursiveReplaceNode(NotNull kg, NotNull } } -static AnfNodePtr GetMainParameter(NotNull root_kg, const AnfNodePtr key, +static AnfNodePtr GetMainParameter(NotNull root_kg, const AnfNodePtr &key, const std::set ¶meter_reuse_set) { AnfNodePtr main_parameter = key; std::set root_inputs_set; @@ -160,8 +79,19 @@ static AnfNodePtr GetMainParameter(NotNull root_kg, const AnfNod return main_parameter; } -static void ReuseParameter(NotNull root_kg, NotNull *> parameter_set) { - auto parameter_reuse_sets = parameter_set->GetSets(); +static void ReuseParameter(NotNull root_kg, + const std::vector> &link_list) { + // make union find set + UnionFindSet union_find_set; + for (auto &[param, arg] : link_list) { + union_find_set.Add(param); + union_find_set.Add(arg); + } + for (auto &[param, arg] : link_list) { + union_find_set.Union(param, arg); + } + auto parameter_reuse_sets = union_find_set.GetSets(); + for (auto &[key, parameter_reuse_set] : parameter_reuse_sets) { if (parameter_reuse_set.size() <= 1) { continue; @@ -172,7 +102,7 @@ static void ReuseParameter(NotNull root_kg, NotNull &list, size_t start) { +static CNodePtr GetNextRealKernel(const std::vector &list, size_t start) { for (size_t i = start; i < list.size() - 1; ++i) { if (!IsPrimitiveCNode(list[i], prim::kPrimPartial) && AnfAlgo::IsRealKernel(list[i])) { return list[i]; @@ -181,71 +111,287 @@ CNodePtr GetNextRealKernel(const std::vector &list, size_t start) { return nullptr; } +static void UpdateLabelIdToLabelSetMap(const std::vector &exec_order, + const NotNull *> label_id_to_label_set) { + for (auto &node : exec_order) { + MS_EXCEPTION_IF_NULL(node); + if (!IsPrimitiveCNode(node, prim::kPrimLabelSet)) { + continue; + } + if (!AnfAlgo::HasNodeAttr(kAttrLabelIndex, node)) { + MS_LOG(EXCEPTION) << node->DebugString() << " has no attr kAttrLabelIndex"; + } + uint32_t label_id = AnfAlgo::GetNodeAttr(node, kAttrLabelIndex); + if (auto iter = label_id_to_label_set->find(label_id); iter != label_id_to_label_set->end()) { + MS_LOG(EXCEPTION) << "There are more than one node has same label id " << label_id + << ", node: " << iter->second->DebugString() << " and " << node->DebugString(); + } + (*label_id_to_label_set)[label_id] = node; + } +} + +static std::vector GetTargetLabelSetNodes(NotNull jump_node, + const std::map &label_id_to_label_set) { + std::vector target_label_list; + std::vector target_labelset_nodes; + if (IsPrimitiveCNode(jump_node.get(), prim::kPrimLabelGoto)) { + if (!AnfAlgo::HasNodeAttr(kAttrLabelIndex, jump_node)) { + MS_LOG(EXCEPTION) << jump_node->DebugString() << " has no attr kAttrLabelIndex"; + } + uint32_t label_id = AnfAlgo::GetNodeAttr(jump_node.get(), kAttrLabelIndex); + target_label_list.push_back(label_id); + } else if (IsPrimitiveCNode(jump_node.get(), prim::kPrimLabelSwitch)) { + if (!AnfAlgo::HasNodeAttr(kAttrLabelSwitchList, jump_node)) { + MS_LOG(EXCEPTION) << jump_node->DebugString() << " has no attr kPrimLabelSwitch"; + } + target_label_list = AnfAlgo::GetNodeAttr>(jump_node.get(), kAttrLabelSwitchList); + } else { + MS_LOG(EXCEPTION) << "Unknown type jump node " << jump_node->DebugString(); + } + + for (auto label_id : target_label_list) { + auto iter = label_id_to_label_set.find(label_id); + if (iter == label_id_to_label_set.end()) { + MS_LOG(EXCEPTION) << "Connot find LabelSet node has label id " << label_id; + } + target_labelset_nodes.push_back(iter->second); + } + return target_labelset_nodes; +} + +static void EraseNodeFromExecOrder(const AnfNodePtr &node, const NotNull *> exec_order) { + MS_EXCEPTION_IF_NULL(node); + auto exec_iter = std::find(exec_order->begin(), exec_order->end(), node); + if (exec_iter == exec_order->end()) { + MS_LOG(EXCEPTION) << "Cannot find " << node->DebugString() << " in exec order."; + } + exec_order->erase(exec_iter); +} + void AscendControlParser::LinkGraph(NotNull kg) { std::set memo; + std::vector> link_list; + // Insert Assign + ChildGraphDataAssign(kg, NOT_NULL(&link_list), NOT_NULL(&memo)); + // Reuse Parameter + ReuseParameter(kg, link_list); + // replace call by label goto / label switch + memo.clear(); (void)ProcessKernelGraph(kg, nullptr, nullptr, NOT_NULL(&memo)); + // assign label resource device::ascend::AscendLabelAssign::GetInstance().AssignLabel(kg); - std::map graph_id_map; - for (auto &g : memo) { - MS_EXCEPTION_IF_NULL(g); - if (graph_id_map.find(g->graph_id()) != graph_id_map.end()) { - MS_LOG(EXCEPTION) << "Two graph has same graph id " << g->graph_id() - << ", graph: " << graph_id_map[g->graph_id()]->ToString() << " " << g->ToString(); +} + +void AscendControlParser::EraseParameter(NotNull root_graph, + const std::set &graph_list) { + std::vector exec_order = root_graph->execution_order(); + std::set search_list(exec_order.begin(), exec_order.end()); + std::set root_inputs(root_graph->inputs().begin(), root_graph->inputs().end()); + auto ref_map = root_graph->GetRefMap(); + ReferenceCounter parameter_count([](int32_t read, int32_t write) -> bool { return write == 1; }); + std::multimap> ref_multimap; + std::transform(ref_map.begin(), ref_map.end(), std::inserter(ref_multimap, ref_multimap.end()), + [](const std::pair, std::pair> &p) + -> std::pair> { + return {p.first.first, {p.first.second, p.second.first, p.second.second}}; + }); + std::set all_nodes; + std::map para_to_written_node; + for (auto &graph : graph_list) { + auto out = graph->get_return(); + MS_EXCEPTION_IF_NULL(out); + search_list.insert(out->cast()); + auto nodes = TopoSort(out); + for (auto &node : nodes) { + MS_EXCEPTION_IF_NULL(node); + auto cnode = node->cast(); + if (cnode != nullptr) { + all_nodes.insert(cnode); + } + } + } + // prepare referance count + for (auto &node : search_list) { + MS_EXCEPTION_IF_NULL(node); + // if assign node + std::set refed_parameters; + for (auto [iter, end] = ref_multimap.equal_range(node); iter != end; ++iter) { + refed_parameters.insert(std::get<1>(iter->second)); + } + + for (auto &in : node->inputs()) { + auto visit_node = AnfAlgo::VisitKernelWithReturnType(in, 0).first; + if (!visit_node->isa() || root_inputs.find(visit_node) != root_inputs.end()) { + continue; + } + if (refed_parameters.find(visit_node) != refed_parameters.end()) { + parameter_count.AddWriteCount(visit_node, 1); + para_to_written_node[visit_node] = node; + } else { + parameter_count.AddReadCount(visit_node, 1); + } } - graph_id_map[g->graph_id()] = g; } - // Insert Assign - ChildGraphDataAssign(graph_id_map); - // Make UnionFindSet - UnionFindSet parameter_set = MakeUnionFindSet(kg); - // Reuse Parameter - ReuseParameter(kg, NOT_NULL(¶meter_set)); + while (parameter_count.HasValidElem()) { + auto [para, read, written] = parameter_count.GetOneValidElem(); + MS_LOG(INFO) << para->DebugString() << " was read " << read << " times, written " << written << " times."; + auto assign_iter = para_to_written_node.find(para); + if (assign_iter == para_to_written_node.end()) { + MS_LOG(EXCEPTION) << "Cannot find assign node that write " << para->DebugString(); + } + auto &assign_node = assign_iter->second; + MS_EXCEPTION_IF_NULL(assign_node); + if (!IsPrimitiveCNode(assign_node, prim::kPrimAssign)) { + parameter_count.EraseElem(para); + continue; + } + MS_LOG(INFO) << "Erase " << assign_node->DebugString(5); + EraseNodeFromExecOrder(assign_node, NOT_NULL(&exec_order)); + + auto source = AnfAlgo::VisitKernelWithReturnType(assign_node->input(kCNodeAssignSource), 0).first; + parameter_count.AddReadCount(source, -1); + parameter_count.AddWriteCount(para, -1); + for (auto &node : all_nodes) { + for (size_t i = 0; i < node->size(); ++i) { + if (node->input(i) == para) { + MS_LOG_INFO << "Replace " << node->DebugString() << " input " << i << " by " << source->DebugString(); + node->set_input(i, source); + } + } + } + parameter_count.AddReadCount(source, 1); + parameter_count.AddReadCount(para, -1); + } + root_graph->set_execution_order(exec_order); +} + +void AscendControlParser::EraseLabel(NotNull root_graph) { + std::vector exec_order = root_graph->execution_order(); + ReferenceCounter label_count([](int32_t read, int32_t write) -> bool { return read <= 1; }); + std::map label_to_written_node; + std::map label_id_to_label_set; + UpdateLabelIdToLabelSetMap(exec_order, NOT_NULL(&label_id_to_label_set)); + CNodePtr last_node = nullptr; + for (auto &cur_node : exec_order) { + MS_EXCEPTION_IF_NULL(cur_node); + if (AnfAlgo::IsCondControlKernel(cur_node)) { + std::vector target_labelset_nodes = GetTargetLabelSetNodes(NOT_NULL(cur_node), label_id_to_label_set); + for (auto &label_set : target_labelset_nodes) { + label_count.AddReadCount(label_set, 1); + label_to_written_node[label_set] = cur_node; + } + } else if (IsPrimitiveCNode(cur_node, prim::kPrimLabelSet)) { + label_count.AddWriteCount(cur_node, 1); + if (last_node != nullptr && !AnfAlgo::IsCondControlKernel(last_node)) { + label_count.AddReadCount(cur_node, 1); + label_to_written_node[cur_node] = last_node; + } + } + last_node = cur_node; + } + + while (label_count.HasValidElem()) { + auto [label_set, read, written] = label_count.GetOneValidElem(); + MS_LOG(INFO) << label_set->DebugString() << " was read " << read << " times, written " << written << " times."; + auto iter = label_to_written_node.find(label_set); + if (read > 0 && iter == label_to_written_node.end()) { + MS_LOG(EXCEPTION) << "Cannot find node jump to " << label_set->DebugString(); + } + CNodePtr jump_node = read > 0 ? iter->second : nullptr; + if (jump_node == nullptr || IsPrimitiveCNode(jump_node, prim::kPrimLabelGoto)) { + MS_LOG(INFO) << "Erase node " << label_set->DebugString(); + EraseNodeFromExecOrder(label_set, NOT_NULL(&exec_order)); + } + if (jump_node != nullptr && IsPrimitiveCNode(jump_node, prim::kPrimLabelGoto)) { + MS_LOG(INFO) << "Erase node " << jump_node->DebugString(); + EraseNodeFromExecOrder(jump_node, NOT_NULL(&exec_order)); + } + label_count.EraseElem(label_set); + } + + root_graph->set_execution_order(exec_order); } void AscendControlParser::ExecutorValidate(NotNull root_graph) { std::set memo; (void)RecurseGraph(root_graph, NOT_NULL(&memo)); + EraseParameter(root_graph, memo); + EraseLabel(root_graph); } -void AscendControlParser::ChildGraphDataAssign(const std::map &graph_id_map) { - for (auto &iter : graph_id_map) { - auto &kg = iter.second; - MS_LOG(INFO) << "Data assign graph:" << kg->graph_id(); - MS_EXCEPTION_IF_NULL(kg); - std::set> memo; - const std::vector>> &real_inputs = kg->real_inputs(); - for (auto &it : real_inputs) { - auto ¶meter = it.first; - auto &args = it.second; - for (auto &arg : args) { - MS_EXCEPTION_IF_NULL(arg); - if (memo.find({parameter, arg}) != memo.end()) { - continue; - } else { - memo.emplace(parameter, arg); - } - auto unreuse_args_map = kg->unreuse_args(); - auto unreuse_arg_iter = unreuse_args_map.find(arg); - if (unreuse_arg_iter == unreuse_args_map.end()) { - MS_EXCEPTION_IF_NULL(arg); - MS_EXCEPTION_IF_NULL(parameter); - if (!arg->isa()) { - MS_LOG(EXCEPTION) << "Reused arg must be parameter, arg:" << arg->DebugString() << "."; - } - MS_LOG(DEBUG) << "Parameter should be reused, no need insert assign, parameter: " << parameter->DebugString() - << ", arg:" << arg->DebugString(); +std::vector>> AscendControlParser::ParseCallNode( + NotNull call_node) { + std::vector>> ret; + if (!IsPrimitiveCNode(call_node.get(), prim::kPrimCall)) { + MS_LOG(EXCEPTION) << "Node " << call_node->DebugString() << " is not a call node."; + } + if (call_node->size() <= kCNodeCallArg) { + MS_LOG(EXCEPTION) << "Node " << call_node->DebugString() << " has invalid inputs size " << call_node->size(); + } + const std::vector &call_node_inputs = call_node->inputs(); + auto call_arg = call_node_inputs[kCNodeCallArg]; + MS_EXCEPTION_IF_NULL(call_arg); + if (IsValueNode(call_arg)) { + ret.emplace_back(GetValueNode(call_arg), + std::vector(call_node_inputs.begin() + kCNodeCallArg + 1, call_node_inputs.end())); + } else if (IsPrimitiveCNode(call_arg, prim::kPrimSwitch)) { + auto switch_cnode = call_arg->cast(); + MS_EXCEPTION_IF_NULL(switch_cnode); + const std::vector &switch_inputs = switch_cnode->inputs(); + if (switch_inputs.size() <= kCNodeSwitchCond) { + MS_LOG(EXCEPTION) << "Node " << switch_cnode->DebugString() << " has invalid inputs size " + << switch_inputs.size(); + } + for (auto iter = switch_inputs.begin() + kCNodeSwitchCond + 1; iter != switch_inputs.end(); ++iter) { + const auto &[target_graph, args] = ParsePartial(NOT_NULL(*iter)); + ret.emplace_back(target_graph, args); + } + } else { + MS_LOG(EXCEPTION) << "Unsupport call node: " << call_node->DebugString(5); + } + return ret; +} + +void AscendControlParser::ChildGraphDataAssign( + NotNull kg, const NotNull> *> link_list, + const NotNull *> memo) { + if (memo->find(kg) != memo->end()) { + return; + } + memo->insert(kg.get()); + + MS_LOG(INFO) << "Start link data for " << kg->ToString(); + const std::vector &nodes = kg->execution_order(); + + for (auto &node : nodes) { + if (!IsPrimitiveCNode(node, prim::kPrimCall)) { + continue; + } + + auto child_graph_list = ParseCallNode(NOT_NULL(node)); + for (auto &[child_graph, args] : child_graph_list) { + MS_EXCEPTION_IF_NULL(child_graph); + const std::vector ¶ms = child_graph->inputs(); + if (args.size() != params.size()) { + MS_LOG(EXCEPTION) << child_graph->ToString() << " needs " << params.size() << " inputs but call node " + << node->DebugString(5) << " gives " << args.size(); + } + for (size_t i = 0; i < args.size(); ++i) { + if (args[i]->isa() && memo->find(child_graph) == memo->end()) { + MS_LOG(INFO) << args[i]->DebugString() << " to " << params[i]->DebugString() + << " should be reused, continue."; + link_list->emplace_back(args[i], params[i]); continue; } - auto target_graph_iter = graph_id_map.find(AnfAlgo::GetGraphId(arg.get())); - if (target_graph_iter == graph_id_map.end()) { - MS_LOG(EXCEPTION) << "Graph id " << AnfAlgo::GetGraphId(arg.get()) << " not found."; - } - InsertMultipleAssignToGraph(NOT_NULL(target_graph_iter->second), NOT_NULL(kg), NOT_NULL(arg), - NOT_NULL(parameter)); + + InsertMultipleAssignToGraph(kg, node, NOT_NULL(args[i]), NOT_NULL(params[i])); } } - kg->SetExecOrderByDefault(); + } + kg->SetExecOrderByDefault(); + for (auto &child_graph : kg->child_graph_order()) { + ChildGraphDataAssign(NOT_NULL(child_graph), link_list, memo); } } @@ -325,7 +471,7 @@ void AscendControlParser::InsertDependToGraph(NotNull kg, NotNul std::vector inputs = {NewValueNode(std::make_shared(prim::kPrimDepend->name())), return_node->input(kFirstDataInputIndex), attch_node.get()}; auto depend_node = kg->NewCNode(inputs); - return_node->set_input(1, depend_node); + return_node->set_input(kFirstDataInputIndex, depend_node); } void AscendControlParser::InsertControlDependToGraph(NotNull kg, NotNull first_node, @@ -381,6 +527,7 @@ void AscendControlParser::RecurseCall(NotNull kg, NotNullset_inputs(new_inputs); cur_node->set_abstract(nullptr); + AnfAlgo::SetNodeAttr(kAttrChildGraph, MakeValue>({call_kg}), cur_node.get()); MS_LOG(INFO) << "Succeed processing call func " << cur_node->DebugString(); } @@ -409,9 +556,12 @@ void AscendControlParser::RecurseSwitch(NotNull kg, NotNull new_switch_inputs = { std::make_shared(std::make_shared(kLabelSwitchOpName)), origin_switch_inputs[kCNodeSwitchCond]}; + std::vector child_graphs; for (size_t i = kCNodeSwitchCond + 1; i < kCNodeSwitchLength; ++i) { // 3.1 branch kernel graph and args - KernelGraphPtr branch_fg = ParsePartial(NOT_NULL(origin_switch_inputs[i])); + KernelGraphPtr branch_fg; + std::tie(branch_fg, std::ignore) = ParsePartial(NOT_NULL(origin_switch_inputs[i])); + child_graphs.push_back(branch_fg); // 3.2 recurse sub graph CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, memo); new_switch_inputs.push_back(branch_label); @@ -420,6 +570,7 @@ void AscendControlParser::RecurseSwitch(NotNull kg, NotNullset_inputs(new_switch_inputs); cur_node->set_abstract(nullptr); + AnfAlgo::SetNodeAttr(kAttrChildGraph, MakeValue>(child_graphs), cur_node.get()); MS_LOG(INFO) << "Succeed processing switch func " << cur_node->DebugString(); } @@ -453,9 +604,12 @@ void AscendControlParser::RecurseSwitchLayer(NotNull kg, NotNull std::vector new_switch_inputs = { std::make_shared(std::make_shared(kLabelSwitchOpName)), origin_switch_inputs[kCNodeSwitchCond]}; + std::vector child_graphs; for (size_t i = 0; i < branch_partial.size(); ++i) { // 3.1 branch kernel graph and args - KernelGraphPtr branch_fg = ParsePartial(NOT_NULL(origin_switch_inputs[i])); + KernelGraphPtr branch_fg; + std::tie(branch_fg, std::ignore) = ParsePartial(NOT_NULL(origin_switch_inputs[i])); + child_graphs.push_back(branch_fg); // 3.2 recurse sub graph CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, memo); new_switch_inputs.push_back(branch_label); @@ -463,13 +617,14 @@ void AscendControlParser::RecurseSwitchLayer(NotNull kg, NotNull new_switch_inputs.insert(new_switch_inputs.end(), branch_partial.begin(), branch_partial.end()); cur_node->set_inputs(new_switch_inputs); cur_node->set_abstract(nullptr); + AnfAlgo::SetNodeAttr(kAttrChildGraph, MakeValue>(child_graphs), cur_node.get()); MS_LOG(INFO) << "Succeed processing switch layer " << cur_node->DebugString(); } -KernelGraphPtr AscendControlParser::ParsePartial(NotNull node) { +std::tuple> AscendControlParser::ParsePartial(NotNull node) { if (!node.get()->isa()) { if (IsValueNode(node)) { - return GetValueNode(node); + return {GetValueNode(node), {}}; } MS_LOG(EXCEPTION) << "Switch branches must be partial, node: " << node->DebugString(); } @@ -485,12 +640,11 @@ KernelGraphPtr AscendControlParser::ParsePartial(NotNull node) { MS_LOG(EXCEPTION) << "Index out of range:" << partial_inputs.size() << "."; } auto branch_kg = GetValueNode(partial_inputs[kCNodePartialFunc]); - return branch_kg; + return {branch_kg, std::vector(partial_inputs.begin() + kCNodePartialFunc + 1, partial_inputs.end())}; } -void AscendControlParser::InsertMultipleAssignToGraph(NotNull from_graph, - NotNull to_graph, NotNull from, - NotNull to) { +void AscendControlParser::InsertMultipleAssignToGraph(NotNull from_graph, const AnfNodePtr &jump_node, + NotNull from, NotNull to) { std::vector from_outputs = AnfAlgo::GetAllOutput(from, {prim::kPrimTupleGetItem}); std::vector to_outputs = AnfAlgo::GetAllOutput(to, {prim::kPrimTupleGetItem}); MS_LOG(INFO) << "Insert multi-assign from [" << from->DebugString() << "] to [" << to->DebugString() << "]"; @@ -500,22 +654,35 @@ void AscendControlParser::InsertMultipleAssignToGraph(NotNull fr } for (size_t i = 0; i < from_outputs.size(); i++) { auto assign_node = InsertAssignToGraph(from_graph, NOT_NULL(from_outputs[i]), NOT_NULL(to_outputs[i])); - if (assign_node != nullptr) { - auto jump_node = GetJumpNode(from_graph, to_graph); - const auto &from_graph_exe_order = from_graph->execution_order(); - auto jump_node_iter = std::find(from_graph_exe_order.begin(), from_graph_exe_order.end(), jump_node); - if (jump_node_iter == from_graph_exe_order.end()) { - MS_EXCEPTION_IF_NULL(jump_node); - MS_LOG(EXCEPTION) << "Can't find node:" << jump_node->DebugString() << " in graph:" << from_graph->graph_id(); - } - // insert assign between jump_node -1 and jump_node - if (jump_node_iter != from_graph_exe_order.begin()) { - InsertControlDependToGraph(from_graph, NOT_NULL(*(jump_node_iter - 1)), NOT_NULL(assign_node)); - } - if (jump_node != nullptr) { - InsertControlDependToGraph(from_graph, NOT_NULL(assign_node), NOT_NULL(jump_node)); + const auto &from_graph_exe_order = from_graph->execution_order(); + std::vector real_exe_order(from_graph_exe_order.size()); + size_t real_exe_order_size = 0; + std::copy_if(from_graph_exe_order.begin(), from_graph_exe_order.end(), real_exe_order.begin(), + [&real_exe_order_size](const CNodePtr &node) -> bool { + return (IsPrimitiveCNode(node, prim::kPrimSwitch) || IsPrimitiveCNode(node, prim::kPrimPartial)) + ? false + : (++real_exe_order_size, true); + }); + real_exe_order.resize(real_exe_order_size); + if (jump_node == nullptr) { + if (!real_exe_order.empty()) { + InsertControlDependToGraph(from_graph, NOT_NULL(*(real_exe_order.rbegin())), NOT_NULL(assign_node)); + } else { + InsertDependToGraph(from_graph, NOT_NULL(assign_node)); } + continue; + } + + auto jump_node_iter = std::find(real_exe_order.begin(), real_exe_order.end(), jump_node); + if (jump_node_iter == real_exe_order.end()) { + MS_LOG(EXCEPTION) << "Cannot find jump node " << jump_node->DebugString() << " in graph " + << from_graph->ToString(); } + // insert assign between jump_node -1 and jump_node + if (jump_node_iter != real_exe_order.begin()) { + InsertControlDependToGraph(from_graph, NOT_NULL(*(jump_node_iter - 1)), NOT_NULL(assign_node)); + } + InsertControlDependToGraph(from_graph, NOT_NULL(assign_node), NOT_NULL(jump_node)); } } @@ -618,26 +785,45 @@ bool AscendControlParser::CheckLabelIndex(uint32_t order_index, uint32_t label_i } } -void AscendControlParser::UpdateChildGraphOrder(NotNull kg) { - MS_LOG(INFO) << "Graph id:" << kg->graph_id(); - kg->SetExecOrderByDefault(); - auto call_nodes = kg->FindNodeByPrimitive(std::make_shared(prim::kPrimCall->name())); - std::vector child_graph_order; - for (auto &call_node : call_nodes) { - MS_EXCEPTION_IF_NULL(call_node); - auto call_child_graphs = AnfAlgo::GetCallNodeKernelGraph(call_node->cast()); - for (const auto &child_graph : call_child_graphs) { - MS_EXCEPTION_IF_NULL(child_graph); - if (child_graph != kg->parent_graph()) { - child_graph->set_parent_graph(kg.get()); - } - child_graph_order.push_back(child_graph); - } +void AscendControlParser::ReferenceCounter::AddReadCount(const AnfNodePtr &key, int32_t num) { + auto iter = count_.find(key); + if (iter != count_.end()) { + iter->second.first += num; + } else { + count_[key] = {num, 0}; } - for (size_t i = 0; i < child_graph_order.size(); i++) { - MS_LOG(INFO) << "Child graph[" << i << "][id:" << child_graph_order[i]->graph_id() << "]"; +} + +void AscendControlParser::ReferenceCounter::AddWriteCount(const AnfNodePtr &key, int32_t num) { + auto iter = count_.find(key); + if (iter != count_.end()) { + iter->second.second += num; + } else { + count_[key] = {0, num}; + } +} + +void AscendControlParser::ReferenceCounter::EraseElem(const AnfNodePtr &key) { count_.erase(key); } + +bool AscendControlParser::ReferenceCounter::HasValidElem() const { + auto it = std::find_if(count_.begin(), count_.end(), + [this](const std::pair> &p) -> bool { + auto &[read, written] = p.second; + return predicate_(read, written); + }); + return it != count_.end(); +} + +std::tuple AscendControlParser::ReferenceCounter::GetOneValidElem() const { + auto it = std::find_if(count_.begin(), count_.end(), + [this](const std::pair> &p) -> bool { + auto &[read, written] = p.second; + return predicate_(read, written); + }); + if (it == count_.end()) { + MS_LOG(EXCEPTION) << "No valid parameter."; } - kg->set_child_graph_order(child_graph_order); + return {it->first, it->second.first, it->second.second}; } } // namespace session } // namespace mindspore diff --git a/mindspore/ccsrc/backend/session/ascend_control_parser.h b/mindspore/ccsrc/backend/session/ascend_control_parser.h index bd35d68b36..ac24735139 100644 --- a/mindspore/ccsrc/backend/session/ascend_control_parser.h +++ b/mindspore/ccsrc/backend/session/ascend_control_parser.h @@ -20,6 +20,8 @@ #include #include #include +#include +#include #include "backend/session/kernel_graph.h" #include "utils/base_ref.h" #include "utils/contract.h" @@ -29,16 +31,23 @@ namespace mindspore { namespace session { class AscendControlParser { public: - static void ChildGraphDataAssign(const std::map &graph_id_map); static void LinkGraph(NotNull kg); static void InsertDependToGraph(NotNull kg, NotNull attch_node); static void InsertControlDependToGraph(NotNull kg, NotNull first_node, NotNull second_node); static void ExecutorValidate(NotNull root_graph); - static void UpdateChildGraphOrder(NotNull kg); + static void InsertMultipleAssignToGraph(NotNull from_graph, const AnfNodePtr &jump_node, + NotNull from, NotNull to); private: + class ReferenceCounter; + + static void EraseParameter(NotNull root_graph, const std::set &graph_list); + static void EraseLabel(NotNull root_graph); + static void ChildGraphDataAssign(NotNull kg, + const NotNull> *> link_list, + const NotNull *> memo); static NotNull GetStartLabel(NotNull kg, const CNodePtr &last_node, const CNodePtr &last_label); static NotNull ProcessKernelGraph(NotNull kg, const CNodePtr &last_node, @@ -53,11 +62,10 @@ class AscendControlParser { static void LinkParentGraph(NotNull kg, const CNodePtr &from_graph_call_node, const CNodePtr &last_label); - static KernelGraphPtr ParsePartial(NotNull node); - static void InsertMultipleAssignToGraph(NotNull from_graph, NotNull to_graph, - NotNull from, NotNull to); static AnfNodePtr InsertAssignToGraph(NotNull kg, NotNull from, NotNull to); + static std::vector>> ParseCallNode(NotNull call_node); + static std::tuple> ParsePartial(NotNull node); // root graph order static bool CheckLabelIndex(uint32_t order_index, uint32_t label_index, const CNodePtr &cnode, @@ -65,6 +73,19 @@ class AscendControlParser { static std::vector RecurseGraph(NotNull graph, const NotNull *> memo); }; +class AscendControlParser::ReferenceCounter { + public: + explicit ReferenceCounter(std::function func) : predicate_(func), count_() {} + void AddReadCount(const AnfNodePtr &key, int32_t num); + void AddWriteCount(const AnfNodePtr &key, int32_t num); + void EraseElem(const AnfNodePtr &key); + bool HasValidElem() const; + std::tuple GetOneValidElem() const; + + private: + std::function predicate_; + std::map> count_; +}; } // namespace session } // namespace mindspore diff --git a/mindspore/ccsrc/backend/session/ascend_session.cc b/mindspore/ccsrc/backend/session/ascend_session.cc index 9995518c00..7348f4f82e 100644 --- a/mindspore/ccsrc/backend/session/ascend_session.cc +++ b/mindspore/ccsrc/backend/session/ascend_session.cc @@ -289,6 +289,17 @@ static void RecurseToUpdateCallRealInput(NotNull graph, // this action should from bottom to top graph->UpdateCallRealInput(); } + +void InsertMakeTupleForOutput(NotNull root_graph) { + auto return_node = root_graph->get_return(); + MS_EXCEPTION_IF_NULL(return_node); + if (return_node->size() <= kReturnDataIndex) { + return; + } + auto make_tuple = root_graph->NewCNode( + {NewValueNode(std::make_shared(prim::kPrimMakeTuple->name())), root_graph->output()}); + root_graph->set_output(make_tuple); +} } // namespace GraphId AscendSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { @@ -305,22 +316,39 @@ GraphId AscendSession::CompileGraph(NotNull func_graph) { std::vector all_graphs; auto root_graph = ConstructKernelGraph(func_graph, &all_graphs); BackendOptimization(all_graphs); - // split switch - SplitGraphs(NOT_NULL(root_graph)); // empty graph dont entry to backend if (root_graph->execution_order().empty()) { MS_LOG(INFO) << root_graph->ToString() << " is empty graph."; + InsertMakeTupleForOutput(NOT_NULL(root_graph)); root_graph->set_executable(false); InitRuntimeResource(); return root_graph->graph_id(); } + // create parameter for multiple branch + std::set memo; + CreateMultiBranchOutput(NOT_NULL(root_graph), NOT_NULL(&memo)); + memo.clear(); // insert goto labels and label_sets LinkChildGraphs(NOT_NULL(root_graph)); // resource initialize InitRuntimeResource(); - // recurse compile child root_graph - std::set memo; - RecurseCompileGraph(NOT_NULL(root_graph), NOT_NULL(&memo)); + + IrFusionPass(NOT_NULL(root_graph), NOT_NULL(&memo)); + memo.clear(); + + SelectKernel(NOT_NULL(root_graph)); + memo.clear(); + + HardwareOptimize(NOT_NULL(root_graph), NOT_NULL(&memo)); + memo.clear(); + + AssignStaticMemory(NOT_NULL(root_graph), NOT_NULL(&memo)); + memo.clear(); + + UpdateRefOutputMap(NOT_NULL(root_graph), NOT_NULL(&memo)); + memo.clear(); + // add make_tuple to the output graph + InsertMakeTupleForOutput(NOT_NULL(root_graph)); // root root_graph valiate,include genearte execute order and so on RootGraphExecutorValidate(NOT_NULL(root_graph)); // adjust kernel @@ -1677,7 +1705,7 @@ void AscendSession::SplitGraph(NotNull graph, const std::setget_return())); // update the root graph child graph order - AscendControlParser::UpdateChildGraphOrder(graph); + graph->UpdateChildGraphOrder(); // get child list from current graph std::vector> child_graph_lists = GetChildList(apply_list, cut_prims); if (child_graph_lists.size() > 1) { @@ -1709,7 +1737,7 @@ void AscendSession::SplitGraph(NotNull graph, const std::setUpdateChildGraphOrder(); UpdateRealInput(graph, split_flag, memo); MS_LOG(INFO) << "Split graph[" << graph->graph_id() << "] end"; } @@ -1748,5 +1776,216 @@ void AscendSession::RecurseCompileGraph(NotNull graph, const Not } } } + +void AscendSession::CreateMultiBranchOutput(NotNull graph, NotNull *> memo) { + if (memo->find(graph.get()) != memo->end()) { + return; + } + memo->insert(graph.get()); + + graph->UpdateChildGraphOrder(); + for (auto &child_graph : graph->child_graph_order()) { + CreateMultiBranchOutput(NOT_NULL(child_graph), memo); + } + + std::map need_replace_list; + auto node_list = GetCNodes(TopoSort(graph->get_return())); + for (auto &node : node_list) { + if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimCall)) { + // create a parameter to store the output of multiple branch and set the parameter as the condition graph's output + // auto multi_output_param = graph->NewParameter(); + auto origin_inputs = graph->inputs(); + auto output_param = CreateNewParameterFromCNode(node, true, graph.get().get()); + MS_EXCEPTION_IF_NULL(graph->MutableInputs()); + graph->MutableInputs()->operator=(origin_inputs); + graph->AddChildGraphResult(output_param); + + std::vector depend_inputs = { + graph->NewValueNode(NewValueNode(std::make_shared(prim::kPrimDepend->name()))), output_param, node}; + auto depend = graph->NewCNode(depend_inputs); + need_replace_list.emplace(node, depend); + MS_LOG(INFO) << "Create parameter " << output_param->DebugString() << " for call node " << node->DebugString() + << ", depend node is " << depend->DebugString(); + // insert assign in order to transfer child graph output to parameter + auto child_graphs = AnfAlgo::GetCallNodeKernelGraph(node); + for (auto &child_graph : child_graphs) { + MS_EXCEPTION_IF_NULL(child_graph); + if (child_graph->get_output_null()) { + continue; + } + auto graph_output = child_graph->output(); + AscendControlParser::InsertMultipleAssignToGraph(NOT_NULL(child_graph), nullptr, NOT_NULL(graph_output), + NOT_NULL(output_param)); + } + } + } + // searching for nodes' input to replace call by depend(parameter, call) + for (auto &node : node_list) { + for (size_t i = 0; i < node->size(); ++i) { + auto input = node->input(i); + auto iter = need_replace_list.find(input); + if (iter != need_replace_list.end()) { + node->set_input(i, iter->second); + } + } + } +} + +void AscendSession::IrFusionPass(const NotNull graph, NotNull *> memo) { + if (memo->find(graph) != memo->end()) { + return; + } + memo->insert(graph.get()); + + opt::AscendBackendIRFusionOptimization(graph); + opt::AscendBackendFuseBasicOpt(graph, true); + opt::AscendBackendGraphKernelOpt(graph, true); + graph->SetExecOrderByDefault(); + + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + bool save_graphs = context_ptr->save_graphs_flag(); + auto save_graphs_path = context_ptr->save_graphs_path(); + if (save_graphs) { + if (save_graphs_path.empty()) { + save_graphs_path = "."; + } + std::string file_path = + save_graphs_path + "/" + "select_kernel_before" + "_graph_" + std::to_string(graph->graph_id()) + ".ir"; + DumpIR(file_path, graph.get()); + } + + for (auto &child_graph : graph->child_graph_order()) { + IrFusionPass(NOT_NULL(child_graph), memo); + } +} + +void AscendSession::SelectKernel(NotNull root_graph) { + MS_LOG(INFO) << "Start select kernel."; + size_t raise_precision_count = 0; + size_t reduce_precision_count = 0; + + std::set memo; + (void)RecurseSelectKernelInfo(root_graph, NOT_NULL(&memo), &raise_precision_count, &reduce_precision_count); + memo.clear(); + + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); + if (ms_context->execution_mode() == kGraphMode) { + if (raise_precision_count > 0) { + MS_LOG(WARNING) << "There has " << raise_precision_count + << " node/nodes used raise precision to selected the kernel!"; + } + if (reduce_precision_count > 0) { + MS_LOG(WARNING) << "There has " << raise_precision_count + << " node/nodes used reduce precision to selected the kernel!"; + } + } + MS_LOG(INFO) << "Finish!"; +} + +void AscendSession::RecurseSelectKernelInfo(NotNull graph, + NotNull *> const memo, + size_t *const raise_precision_count, + size_t *const reduce_precision_count) const { + if (memo->find(graph) != memo->end()) { + return; + } + memo->insert(graph.get()); + MS_LOG(INFO) << "Start to select kernel info in graph: " << graph->graph_id(); + + for (const auto &cnode : graph->execution_order()) { + if (AnfAlgo::IsCondControlKernel(cnode)) { + std::vector child_graphs; + if (AnfAlgo::HasNodeAttr(kAttrChildGraph, cnode)) { + child_graphs = AnfAlgo::GetNodeAttr>(cnode, kAttrChildGraph); + } + for (auto &child_graph : child_graphs) { + RecurseSelectKernelInfo(NOT_NULL(child_graph), memo, raise_precision_count, reduce_precision_count); + } + } + + auto status = device::ascend::SelectKernelInfo(cnode); + if (status == device::ascend::kStatusRaisePrecision) { + (*raise_precision_count)++; + } else if (status == device::ascend::kStatusReducePrecision) { + (*reduce_precision_count)++; + } + MS_LOG(INFO) << "Select ApplyKernel: " << cnode->DebugString(); + } + + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + bool save_graphs = context_ptr->save_graphs_flag(); + auto save_graphs_path = context_ptr->save_graphs_path(); + if (save_graphs) { + if (save_graphs_path.empty()) { + save_graphs_path = "."; + } + std::string file_path = + save_graphs_path + "/" + "select_kernel_after" + "_graph_" + std::to_string(graph->graph_id()) + ".ir"; + DumpIR(file_path, graph.get()); + } + MS_LOG(INFO) << "Finish selecting kernel info in graph: " << graph->graph_id(); +} + +void AscendSession::HardwareOptimize(NotNull graph, + NotNull *> const memo) const { + if (memo->find(graph) != memo->end()) { + return; + } + memo->insert(graph.get()); + + MS_LOG(INFO) << "Start to do HardwareOptimize in graph: " << graph->graph_id(); + // convert kernel Graph to model + predictmodel::StepConvertGraph(graph.get()); + + HardwareOptimize(graph.get()); + for (auto &child_graph : graph->child_graph_order()) { + HardwareOptimize(NOT_NULL(child_graph), memo); + } + MS_LOG(INFO) << "Finish doing HardwareOptimize in graph: " << graph->graph_id(); +} + +void AscendSession::AssignStaticMemory(NotNull graph, + NotNull *> const memo) const { + if (memo->find(graph) != memo->end()) { + return; + } + memo->insert(graph.get()); + + MS_LOG(INFO) << "Start to assign static memory for parameter in graph: " << graph->graph_id(); + // assign static memory for parameters + auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_); + MS_EXCEPTION_IF_NULL(runtime_instance); + runtime_instance->AssignStaticMemoryInput(graph.get().get()); + runtime_instance->AssignStaticMemoryValueNode(graph.get().get()); + for (auto &child_graph : graph->child_graph_order()) { + AssignStaticMemory(NOT_NULL(child_graph), memo); + } + MS_LOG(INFO) << "Finish assigning static memory for parameter in graph: " << graph->graph_id(); +} + +void AscendSession::UpdateRefOutputMap(NotNull graph, + NotNull *> const memo) const { + if (memo->find(graph) != memo->end()) { + return; + } + memo->insert(graph.get()); + + for (auto &child_graph : graph->child_graph_order()) { + UpdateRefOutputMap(NOT_NULL(child_graph), memo); + // copy ref map to final graph + auto child_ref_map = child_graph->GetRefMap(); + for (auto &item : child_ref_map) { + if (graph->IsInRefOutputMap(item.first)) { + MS_LOG(WARNING) << "The ref pair <" << item.first.first->DebugString() << ", " << item.first.second + << "> is already in " << graph->ToString(); + continue; + } + graph->AddRefCorrespondPairs(item.first, item.second); + } + } +} } // namespace session } // namespace mindspore diff --git a/mindspore/ccsrc/backend/session/ascend_session.h b/mindspore/ccsrc/backend/session/ascend_session.h index f8ec7e8545..11cb1c92d2 100755 --- a/mindspore/ccsrc/backend/session/ascend_session.h +++ b/mindspore/ccsrc/backend/session/ascend_session.h @@ -151,6 +151,15 @@ class AscendSession : public SessionBasic { // sync intial tensors' data to device void SyncInitialTenosrToDevice(); void SetFinalGraphSummaryFlag(const std::shared_ptr &kernel_graph); + // create parameter to receive data from multiple branch output + void CreateMultiBranchOutput(NotNull graph, NotNull *> memo); + void SelectKernel(NotNull root_graph); + void RecurseSelectKernelInfo(NotNull graph, NotNull *> const memo, + size_t *const raise_precision_count, size_t *const reduce_precision_count) const; + void IrFusionPass(const NotNull graph, NotNull *> memo); + void HardwareOptimize(const NotNull graph, NotNull *> memo) const; + void AssignStaticMemory(const NotNull graph, NotNull *> memo) const; + void UpdateRefOutputMap(const NotNull graph, NotNull *> memo) const; // member variables // key is final_graph_id,value is child graph execute order of final graph diff --git a/mindspore/ccsrc/backend/session/kernel_graph.cc b/mindspore/ccsrc/backend/session/kernel_graph.cc index 0bf447751b..df810fe6ef 100644 --- a/mindspore/ccsrc/backend/session/kernel_graph.cc +++ b/mindspore/ccsrc/backend/session/kernel_graph.cc @@ -616,8 +616,8 @@ void KernelGraph::UpdateControlDependRelations(const std::vector &de if (AnfAlgo::HasNodeAttr(kControlDependMode, cnode)) { depend_mode = AnfAlgo::GetNodeAttr(cnode, kControlDependMode); } - MS_LOG(INFO) << "Prior node[" << prior_node->DebugString() << "], depend node[" << depend_node->DebugString() - << "], depend_mode :" << depend_mode << "."; + MS_LOG(DEBUG) << "Prior node[" << prior_node->DebugString() << "], depend node[" << depend_node->DebugString() + << "], depend_mode :" << depend_mode << "."; if (prior_node->isa() && depend_mode == 1) { prior_nodes = GetOutputNodes(prior_node); } @@ -647,7 +647,8 @@ void KernelGraph::UpdateControlDependRelations(const std::vector &de } MS_EXCEPTION_IF_NULL(first_node); MS_EXCEPTION_IF_NULL(second_node); - MS_LOG(INFO) << "Add first node:" << first_node->DebugString() << ",second node:" << second_node->DebugString(); + MS_LOG(DEBUG) << "Add first node:" << first_node->DebugString() + << ",second node:" << second_node->DebugString(); AddDependEdge(second_node, first_node, 1); } } @@ -991,6 +992,30 @@ bool KernelGraph::IsFinalOutputKernel(const AnfNodePtr &node) const { return false; } +void KernelGraph::UpdateChildGraphOrder() { + MS_LOG(INFO) << "Update " << ToString() << " child graph order."; + SetExecOrderByDefault(); + auto call_nodes = FindNodeByPrimitive(std::make_shared(prim::kPrimCall->name())); + std::vector child_graph_order; + for (auto &call_node : call_nodes) { + MS_EXCEPTION_IF_NULL(call_node); + auto call_child_graphs = AnfAlgo::GetCallNodeKernelGraph(call_node->cast()); + for (const auto &child_graph : call_child_graphs) { + MS_EXCEPTION_IF_NULL(child_graph); + if (child_graph != parent_graph_) { + auto shared_this = std::dynamic_pointer_cast(shared_from_this()); + MS_EXCEPTION_IF_NULL(shared_this); + child_graph->set_parent_graph(shared_this); + } + child_graph_order.push_back(child_graph); + } + } + for (size_t i = 0; i < child_graph_order.size(); ++i) { + MS_LOG(INFO) << "Child graph[" << i << "][id:" << child_graph_order[i]->graph_id() << "]"; + } + child_graph_order_ = child_graph_order; +} + std::string KernelGraph::ToString() const { return std::string("kernel_graph_").append(std::to_string(graph_id_)); } KernelGraph::~KernelGraph() { device::KernelRuntimeManager::Instance().ClearGraphResource(graph_id_); } diff --git a/mindspore/ccsrc/backend/session/kernel_graph.h b/mindspore/ccsrc/backend/session/kernel_graph.h index f353ed1dda..48df351120 100644 --- a/mindspore/ccsrc/backend/session/kernel_graph.h +++ b/mindspore/ccsrc/backend/session/kernel_graph.h @@ -156,6 +156,12 @@ class KernelGraph : public FuncGraph { bool IsFinalOutputKernel(const AnfNodePtr &node) const; uint32_t current_epoch() const { return current_epoch_; } void set_current_epoch(uint32_t epoch) { current_epoch_ = epoch; } + void UpdateChildGraphOrder(); + const std::vector &child_graph_result() const { return child_graph_result_; } + void AddChildGraphResult(const AnfNodePtr ¶meter) { child_graph_result_.push_back(parameter); } + void set_child_graph_result(const std::vector &child_graph_result) { + child_graph_result_ = child_graph_result; + } private: // remove value node form graph @@ -173,6 +179,7 @@ class KernelGraph : public FuncGraph { void UpdateControlDependRelations(const std::vector &depends); std::shared_ptr> inputs_; + std::vector child_graph_result_; std::vector execution_order_; uint32_t graph_id_; uint32_t stream_distinction_label_; diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc index 117e48fbb8..dd65fe89f3 100644 --- a/mindspore/ccsrc/backend/session/session_basic.cc +++ b/mindspore/ccsrc/backend/session/session_basic.cc @@ -74,7 +74,7 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne return input_tensors[input_idx]; } } - MS_LOG(EXCEPTION) << "Parameter : " << node->DebugString() << "has no output addr"; + MS_LOG(EXCEPTION) << "Parameter : " << node->DebugString() << " has no output addr"; } } // if proccess reach here,it remarks item_with_index is a real node(Parameter,or executable CNode) @@ -107,8 +107,8 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne return tensor; } -BaseRef CreatTensorForOutput(const AnfNodePtr &anf, const KernelGraph &graph, - const std::vector &input_tensors) { +BaseRef CreateTensorForOutput(const AnfNodePtr &anf, const KernelGraph &graph, + const std::vector &input_tensors) { MS_EXCEPTION_IF_NULL(anf); MS_LOG(INFO) << "Create tensor for output[" << anf->DebugString() << "]"; auto item_with_index = AnfAlgo::VisitKernelWithReturnType(anf, 0); @@ -120,7 +120,7 @@ BaseRef CreatTensorForOutput(const AnfNodePtr &anf, const KernelGraph &graph, MS_EXCEPTION_IF_NULL(cnode); VectorRef ret; for (size_t i = 1; i < cnode->inputs().size(); ++i) { - auto out = CreatTensorForOutput(cnode->input(i), graph, input_tensors); + auto out = CreateTensorForOutput(cnode->input(i), graph, input_tensors); ret.push_back(out); } return ret; @@ -133,25 +133,6 @@ BaseRef CreatTensorForOutput(const AnfNodePtr &anf, const KernelGraph &graph, return CreateOneTensor(item_with_index.first, item_with_index.second, graph, input_tensors); } -BaseRef CreatTupleForOutput(const AnfNodePtr &anf, const KernelGraph &graph, - const std::vector &input_tensors) { - MS_EXCEPTION_IF_NULL(anf); - if (!AnfAlgo::IsRealKernel(anf)) { - MS_LOG(EXCEPTION) << "Anf[" << anf->DebugString() << "] should be a executable kernel"; - } - if (anf->isa()) { - return CreateOneTensor(anf, 0, graph, input_tensors); - } - VectorRef ret; - if (anf->isa() && AnfAlgo::GetCNodeName(anf) != prim::kPrimMakeTuple->name()) { - for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(anf); ++i) { - auto out = CreateOneTensor(anf, i, graph, input_tensors); - ret.emplace_back(out); - } - } - return ret; -} - ValueNodePtr CreateNewValueNode(const AnfNodePtr &anf, KernelGraph *graph) { MS_EXCEPTION_IF_NULL(anf); MS_EXCEPTION_IF_NULL(graph); @@ -880,20 +861,11 @@ void SessionBasic::UpdateOutputs(const std::shared_ptr &kernel_grap const std::vector &input_tensors) const { MS_EXCEPTION_IF_NULL(kernel_graph); MS_EXCEPTION_IF_NULL(outputs); - if (!kernel_graph->child_graph_order().empty()) { - // use the last child graph output as the root graph output - UpdateOutputs(kernel_graph->child_graph_order().back(), outputs, input_tensors); - return; - } auto anf_outputs = kernel_graph->outputs(); for (auto &item : anf_outputs) { MS_EXCEPTION_IF_NULL(item); MS_LOG(INFO) << "Update output[" << item->DebugString() << "]"; - if (AnfAlgo::IsTupleOutput(item) && AnfAlgo::IsRealKernel(item)) { - outputs->emplace_back(CreatTupleForOutput(item, *kernel_graph, input_tensors)); - continue; - } - outputs->emplace_back(CreatTensorForOutput(item, *kernel_graph, input_tensors)); + outputs->emplace_back(CreateTensorForOutput(item, *kernel_graph, input_tensors)); } } diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime.cc b/mindspore/ccsrc/runtime/device/kernel_runtime.cc index d5fd00da5b..3de9af8c23 100644 --- a/mindspore/ccsrc/runtime/device/kernel_runtime.cc +++ b/mindspore/ccsrc/runtime/device/kernel_runtime.cc @@ -294,6 +294,7 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) { MS_EXCEPTION_IF_NULL(mem_manager_); auto graph_inputs = graph->inputs(); auto graph_valid_input = graph->valid_inputs(); + graph_inputs.insert(graph_inputs.end(), graph->child_graph_result().begin(), graph->child_graph_result().end()); std::vector need_alloc_nodes; for (size_t i = 0; i < graph_inputs.size(); ++i) { auto item = graph_inputs[i]; diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index 8317ce3116..e437ce8534 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -240,6 +240,7 @@ constexpr auto kAttrReduceScatterFlag = "reduce_scatter_flag"; constexpr auto kAttrOffset = "offset"; constexpr auto kAttrPsKey = "ps_key"; constexpr auto kAttrOptimizerType = "optim_type"; +constexpr auto kAttrChildGraph = "child_graph"; // attr value constexpr auto kValueTargetSwitch = "target_switch";