From 63a89925ffa4431950c729a2c6e652456c6091b6 Mon Sep 17 00:00:00 2001 From: huangbingjian Date: Fri, 2 Apr 2021 15:25:55 +0800 Subject: [PATCH] remove ControlDepend and its use --- .../ascend/enhancer/split_n_optimizer.cc | 4 +- .../ir_fission/batch_norm_bert_fission.cc | 4 +- .../ascend/mindir/dropout_unify_mindir.cc | 16 +- .../ccsrc/backend/optimizer/common/helper.cc | 13 +- .../ccsrc/backend/optimizer/common/helper.h | 2 +- .../graph_kernel/add_atomic_clean.cc | 25 ++- .../graph_kernel/basic_ops_fusion.cc | 4 +- .../graph_kernel/graph_kernel_helper.cc | 70 ------ .../graph_kernel/graph_kernel_helper.h | 2 - .../optimizer/pass/eliminate_redundant_op.cc | 7 +- .../backend/session/anf_runtime_algorithm.h | 2 +- .../backend/session/ascend_control_parser.cc | 21 +- .../ccsrc/backend/session/kernel_graph.cc | 84 +------- .../ccsrc/backend/session/kernel_graph.h | 8 +- .../ccsrc/backend/session/session_basic.cc | 28 +-- .../optimizer/irpass/branch_culling.cc | 126 ++--------- .../ccsrc/frontend/optimizer/recompute.cc | 10 +- .../frontend/parallel/ops_info/ops_utils.h | 3 +- .../runtime/framework/graph_scheduler.cc | 6 +- mindspore/ccsrc/transform/graph_ir/convert.cc | 201 +----------------- mindspore/ccsrc/transform/graph_ir/convert.h | 12 +- mindspore/ccsrc/vm/graph_partition.cc | 101 +-------- mindspore/ccsrc/vm/segment_runner.cc | 31 +-- mindspore/core/abstract/infer_functions.h | 2 - mindspore/core/abstract/prim_others.cc | 19 +- .../core/abstract/primitive_infer_map.cc | 1 - mindspore/core/base/core_ops.h | 1 - mindspore/core/ir/anf.cc | 5 +- mindspore/core/ops/conv2d.cc | 3 +- mindspore/core/ops/expand_dims.cc | 3 +- mindspore/core/utils/parallel_node_check.cc | 4 +- .../lite/tools/anf_exporter/anf_exporter.cc | 10 +- .../lite/tools/optimizer/common/gllo_utils.cc | 4 +- .../tools/optimizer/graph/infershape_pass.cc | 4 +- .../graph/redundant_op_remove_pass.cc | 7 - .../nlp/gpt2/src/gpt2_for_finetune.py | 9 +- .../pass/optimize_dependence_test.cc | 18 +- 37 files changed, 118 insertions(+), 752 deletions(-) diff --git a/mindspore/ccsrc/backend/optimizer/ascend/enhancer/split_n_optimizer.cc b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/split_n_optimizer.cc index 5390da2373..1f33f29afd 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/enhancer/split_n_optimizer.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/split_n_optimizer.cc @@ -108,7 +108,7 @@ bool InputCheck(const AnfNodePtr &node) { MS_LOG(INFO) << "Data->TransData->split, can not optimizer."; return false; } - if (in_node_name == prim::kPrimControlDepend->name() || in_node_name == prim::kPrimDepend->name()) { + if (in_node_name == prim::kPrimDepend->name()) { return false; } if ((AnfAlgo::HasNodeAttr("non_task", in_cnode) && AnfAlgo::GetNodeAttr(in_node, "non_task")) || @@ -131,7 +131,7 @@ bool OutputCheck(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { return false; } for (const auto &item : outputs) { - if (IsPrimitiveCNode(item, prim::kPrimControlDepend) || IsPrimitiveCNode(item, prim::kPrimDepend)) { + if (IsPrimitiveCNode(item, prim::kPrimDepend)) { MS_LOG(INFO) << "Split has control edge, can not optimizer."; return false; } diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/batch_norm_bert_fission.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/batch_norm_bert_fission.cc index 46dfd9dfd7..1140ee33f1 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/batch_norm_bert_fission.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/batch_norm_bert_fission.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -168,7 +168,7 @@ const AnfNodePtr BatchNormBertFission::Process(const FuncGraphPtr &func_graph, c (void)manager->Replace(output, bn_training_update_v2_outputs[output_index]); output_index++; } - // Return the new node for control depends. + // Return the new node. return bn_training_update_v2; } } // namespace opt diff --git a/mindspore/ccsrc/backend/optimizer/ascend/mindir/dropout_unify_mindir.cc b/mindspore/ccsrc/backend/optimizer/ascend/mindir/dropout_unify_mindir.cc index 1f307b1ab4..a3c3fefc24 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/mindir/dropout_unify_mindir.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/mindir/dropout_unify_mindir.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -201,20 +201,6 @@ const AnfNodePtr DropoutAndDropoutGradUnifyMindIR::Process(const FuncGraphPtr &f } } } - if (dropout_do_mask1 != nullptr) { - // Dropout is used by ControlDepend in some situation, need to replace ControlDepend. - auto &users = manager->node_users(); - iter = users.find(dropout_node); - if (iter != users.end()) { - for (auto &node_index : iter->second) { - auto used_node = node_index.first; - if (AnfAlgo::CheckPrimitiveType(used_node, prim::kPrimControlDepend)) { - (void)manager->Replace(used_node, dropout_do_mask1); - break; - } - } - } - } // CreateDropoutDoMask-backward if (equiv->find(grad_input_) == equiv->end()) { diff --git a/mindspore/ccsrc/backend/optimizer/common/helper.cc b/mindspore/ccsrc/backend/optimizer/common/helper.cc index 36e5ae14a9..a6dd3f5055 100644 --- a/mindspore/ccsrc/backend/optimizer/common/helper.cc +++ b/mindspore/ccsrc/backend/optimizer/common/helper.cc @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -426,9 +426,6 @@ std::shared_ptr>> GetRealNodeUsedListByOu } auto output_info_list = iter->second; for (const auto &output_info : output_info_list) { - if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimControlDepend->name()) { - continue; - } if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimDepend->name() && output_info.second == kDependAttachNodeIndex) { continue; @@ -908,16 +905,12 @@ void TransferDepend(const CNodePtr &old_node, const FuncGraphPtr &graph, const C MS_EXCEPTION_IF_NULL(graph); auto manager = graph->manager(); MS_EXCEPTION_IF_NULL(manager); - // find BatchNorm's output which is a Depend or ControlDepend + // find BatchNorm's output which is a Depend for (const auto &node_index : manager->node_users()[old_node]) { AnfNodePtr output = node_index.first; size_t index = IntToSize(node_index.second); MS_EXCEPTION_IF_NULL(output); - if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimControlDepend)) { - auto control_depend = output->cast(); - MS_EXCEPTION_IF_NULL(control_depend); - control_depend->set_input(index, new_node); - } else if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimDepend)) { + if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimDepend)) { auto depend = output->cast(); MS_EXCEPTION_IF_NULL(depend); depend->set_input(index, new_node); diff --git a/mindspore/ccsrc/backend/optimizer/common/helper.h b/mindspore/ccsrc/backend/optimizer/common/helper.h index 2834c493ce..699f8f0ee5 100644 --- a/mindspore/ccsrc/backend/optimizer/common/helper.h +++ b/mindspore/ccsrc/backend/optimizer/common/helper.h @@ -210,7 +210,7 @@ bool CheckSupportDataType(const AnfNodePtr &node, const std::set &suppor // Create a new value node of func graph,not kernel graph ValueNodePtr MakeValueNode(const ValueNodePtr &value_node); -// Transfer depend or control_depend to the new node +// Transfer depend to the new node void TransferDepend(const CNodePtr &old_node, const FuncGraphPtr &graph, const CNodePtr &new_node); AbstractBasePtr CppInferShape(const PrimitivePtr &prim, const AbstractBasePtrList &args_spec_list); diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/add_atomic_clean.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/add_atomic_clean.cc index 3ee4124309..7202082c39 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/add_atomic_clean.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/add_atomic_clean.cc @@ -339,7 +339,7 @@ void AtomicCleanInsertter::ProcessOriginCNode(const AnfNodePtr &composite_node, void AtomicCleanInsertter::AddDepend(const FuncGraphPtr &main_graph, const AnfNodePtr &clean_node, const AnfNodePtr &composite_node, const AnfNodePtr &user_node, int index) { - // Create depend node to hold new control depend node. + // Create depend node to hold execution order. AnfNodePtrList d_inputs = {NewValueNode(prim::kPrimDepend), clean_node, composite_node}; auto depend_cnode = main_graph->NewCNode(d_inputs); depend_cnode->set_abstract(clean_node->abstract()); @@ -513,18 +513,17 @@ bool AtomicCleanInsertter::IsExistStructuralObstacle(const KernelGraphPtr &main_ const FuncGraphManagerPtr &mng) { auto reduce_users = FindOriginCNodeUsers(main_graph, node, mng, false); // If reduce user is MakeTuple and not last node, there is no cheap method to set right running order between reduce - // node and user node. If reduce is Depend or ControlDepend node, the origin node may be wrong! - return std::all_of(reduce_users.cbegin(), reduce_users.cend(), - [&main_graph](const std::pair &user_info) -> bool { - auto &user = user_info.first; - if ((IsPrimitiveCNode(user, prim::kPrimMakeTuple) || IsPrimitiveCNode(user, prim::kPrimDepend) || - IsPrimitiveCNode(user, prim::kPrimControlDepend)) && - !(IsPrimitiveCNode(user, prim::kPrimReturn) || user == main_graph->output())) { - return false; - } else { - return true; - } - }); + // node and user node. If reduce is Depend node, the origin node may be wrong! + return std::all_of( + reduce_users.cbegin(), reduce_users.cend(), [&main_graph](const std::pair &user_info) -> bool { + auto &user = user_info.first; + if ((IsPrimitiveCNode(user, prim::kPrimMakeTuple) || IsPrimitiveCNode(user, prim::kPrimDepend)) && + !(IsPrimitiveCNode(user, prim::kPrimReturn) || user == main_graph->output())) { + return false; + } else { + return true; + } + }); } bool AtomicCleanInsertter::Run(const FuncGraphPtr &func_graph) { diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/basic_ops_fusion.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/basic_ops_fusion.cc index 111f29de65..2f216e2612 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/basic_ops_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/basic_ops_fusion.cc @@ -137,9 +137,9 @@ bool FuseBasicOps(const FuncGraphPtr &kernel_graph, const std::vectormanager(); - // depend_prior[depend] = pair(prior, controlDependNode) + // depend_prior[depend] = pair(prior, behind) std::multimap> depend_prior; - InitDependPrior(todos, &depend_prior); + // InitDependPrior(todos, &depend_prior); for (auto iter = todos.cbegin(); iter != todos.cend(); ++iter) { auto node = (*iter)->cast(); diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc index edbb9ddfe0..2e1cabf929 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc @@ -635,76 +635,6 @@ void ResetKernelInfo(const AnfNodePtr &node, KernelType kernel_type) { #endif } -void InitDependPrior(const std::vector &todos, - std::multimap> *depend_prior) { - for (auto iter = todos.cbegin(); iter != todos.cend(); ++iter) { - auto cnode = (*iter)->cast(); - if (cnode == nullptr) { - continue; - } - if (!AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimControlDepend)) { - continue; - } - - auto prior_node = cnode->input(kControlDependPriorIndex); - auto depend_node = cnode->input(kControlDependBehindIndex); - MS_EXCEPTION_IF_NULL(prior_node); - MS_EXCEPTION_IF_NULL(depend_node); - std::vector prior_nodes = {prior_node}; - std::vector depend_nodes = {depend_node}; - - int depend_mode = 0; - if (AnfAlgo::HasNodeAttr(kControlDependMode, cnode)) { - depend_mode = AnfAlgo::GetNodeAttr(cnode, kControlDependMode); - } - - auto GetOutputNodes = [cnode](const AnfNodePtr ¶m) -> std::vector { - std::vector out_nodes; - auto user_set = param->func_graph()->manager()->node_users()[param]; - for (auto iter = user_set.cbegin(); iter != user_set.cend(); ++iter) { - if (iter->first != cnode) { - out_nodes.push_back(iter->first); - } - } - return out_nodes; - }; - - if (prior_node->isa() && depend_mode == 1) { - prior_nodes = GetOutputNodes(prior_node); - } - if (depend_node->isa()) { - depend_nodes = depend_mode == 1 ? GetOutputNodes(depend_node) : std::vector{}; - } - - std::vector real_prior_nodes; - std::set prior_visited; - for (const auto &tmp : prior_nodes) { - AnfAlgo::GetAllFatherRealNode(tmp, &real_prior_nodes, &prior_visited); - } - prior_visited.clear(); - std::vector real_depend_nodes; - std::set depend_visited; - for (const auto &tmp : depend_nodes) { - AnfAlgo::GetAllFatherRealNode(tmp, &real_depend_nodes, &depend_visited); - } - depend_visited.clear(); - - for (auto &prior : real_prior_nodes) { - if (AnfAlgo::CheckPrimitiveType(prior, prim::kPrimControlDepend)) { - continue; - } - for (auto &depend : real_depend_nodes) { - if (AnfAlgo::CheckPrimitiveType(depend, prim::kPrimControlDepend)) { - continue; - } - depend_prior->insert({depend, std::make_pair(prior, cnode)}); - } - } - real_prior_nodes.clear(); - real_depend_nodes.clear(); - } -} - void ReplaceNewFuseCNodeForDependPrior(std::multimap> *depend_prior, const AnfNodePtr &new_fuse_cnode, const AnfNodePtrList &outputs) { std::multimap> new_fuse_cnode_dep_pri; diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.h b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.h index 7baa5571af..79847e0603 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.h +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.h @@ -74,8 +74,6 @@ std::string ExtractGraphKernelName(const AnfNodePtrList &cnodes, const string &p std::vector GetFusibleOpList(); bool IsBasicFuseOp(const AnfNodePtr &node); void ResetKernelInfo(const AnfNodePtr &node, KernelType kernel_type = KernelType::UNKNOWN_KERNEL_TYPE); -void InitDependPrior(const std::vector &todos, - std::multimap> *depend_prior); void ReplaceNewFuseCNodeForDependPrior(std::multimap> *depend_prior, const AnfNodePtr &new_fuse_cnode, const AnfNodePtrList &outputs); diff --git a/mindspore/ccsrc/backend/optimizer/pass/eliminate_redundant_op.cc b/mindspore/ccsrc/backend/optimizer/pass/eliminate_redundant_op.cc index abc7d24486..1d3fea1106 100644 --- a/mindspore/ccsrc/backend/optimizer/pass/eliminate_redundant_op.cc +++ b/mindspore/ccsrc/backend/optimizer/pass/eliminate_redundant_op.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -55,7 +55,7 @@ CNodePtr GetRealPrevCNode(const AnfNodePtr &node, size_t index, std::vector(value_node->value()); pass_vector->push_back(make_pair(cnode, IntToSize(1))); return GetRealPrevCNode(cnode->input(1), LongToSize(item_idx), pass_vector); - } else if (IsPrimitive(input0, prim::kPrimDepend) || IsPrimitive(input0, prim::kPrimControlDepend)) { + } else if (IsPrimitive(input0, prim::kPrimDepend)) { pass_vector->push_back(make_pair(cnode, IntToSize(1))); return GetRealPrevCNode(cnode->input(1), 0, pass_vector); } else if (IsPrimitive(input0, prim::kPrimUpdateState)) { @@ -92,8 +92,7 @@ const AnfNodePtr ProcessMatchedNodes(const FuncGraphPtr &func_graph, const CNode auto pass_size = pass_vector->size(); for (size_t idx = 1; idx <= pass_size - 1; ++idx) { auto nd = (*pass_vector)[idx].first; - if (AnfAlgo::CheckPrimitiveType(nd, prim::kPrimDepend) || - AnfAlgo::CheckPrimitiveType(nd, prim::kPrimControlDepend)) { + if (AnfAlgo::CheckPrimitiveType(nd, prim::kPrimDepend)) { has_depend_node = true; } if (users[nd].size() >= 2) { diff --git a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h index 2ed8a160dc..58dba2e36c 100644 --- a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h +++ b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h @@ -248,7 +248,7 @@ class AnfRuntimeAlgorithm { static void InferShape(const CNodePtr &node); static std::vector GetInputRealDeviceShapeIfExist(const AnfNodePtr &anf_node, size_t index); static std::vector GetOutputRealDeviceShapeIfExist(const AnfNodePtr &anf_node, size_t index); - // Find control_depend real input nodes. + // Find real input nodes. static void GetAllFatherRealNode(const AnfNodePtr &anf_node, std::vector *result, std::set *visited); }; diff --git a/mindspore/ccsrc/backend/session/ascend_control_parser.cc b/mindspore/ccsrc/backend/session/ascend_control_parser.cc index 96984a39c6..0fc701f8b4 100644 --- a/mindspore/ccsrc/backend/session/ascend_control_parser.cc +++ b/mindspore/ccsrc/backend/session/ascend_control_parser.cc @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -534,14 +534,17 @@ void AscendControlParser::InsertDependToGraph(NotNull kg, NotNul return_node->set_input(kFirstDataInputIndex, depend_node); } -void AscendControlParser::InsertControlDependToGraph(NotNull kg, NotNull first_node, - NotNull second_node) { - MS_LOG(INFO) << "Insert control depend at the end of graph, the first node is " << first_node->DebugString() - << ", the second node is " << second_node->DebugString(); - std::vector inputs = {NewValueNode(std::make_shared(prim::kPrimControlDepend->name())), - first_node, second_node}; - auto control_depend = kg->NewCNode(inputs); - InsertDependToGraph(kg, NOT_NULL(control_depend)); +void AscendControlParser::InsertControlDependToGraph(NotNull kg, NotNull prior_node, + NotNull behind_node) { + MS_LOG(INFO) << "Insert control dependence at the end of graph, the prior node is " << prior_node->DebugString() + << ", the behind node is " << behind_node->DebugString(); + auto manager = kg->manager(); + MS_EXCEPTION_IF_NULL(manager); + AnfNodePtrList inputs = {NewValueNode(prim::kPrimDepend), behind_node, prior_node}; + auto depend_cnode = kg->NewCNode(inputs); + if (!manager->Replace(behind_node, depend_cnode)) { + MS_LOG(EXCEPTION) << behind_node->DebugString() << ", replace node failed."; + } } void AscendControlParser::LinkParentGraph(NotNull kg, const CNodePtr &from_graph_call_node, diff --git a/mindspore/ccsrc/backend/session/kernel_graph.cc b/mindspore/ccsrc/backend/session/kernel_graph.cc index c1160a28da..079ebb1fed 100644 --- a/mindspore/ccsrc/backend/session/kernel_graph.cc +++ b/mindspore/ccsrc/backend/session/kernel_graph.cc @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -422,7 +422,7 @@ void KernelGraph::CheckLoop() { none_zero_nodes[it.first] = it.second; } } - // if don't consider control depend and loop exit,a exception will be throw + // if don't consider loop exit,a exception will be throw if (!none_zero_nodes.empty()) { MS_LOG(WARNING) << "Nums of loop:" << GetLoopNum(none_zero_nodes); MS_LOG(EXCEPTION) << "Nodes have loop, left node num:" << none_zero_nodes.size(); @@ -851,61 +851,10 @@ std::vector KernelGraph::GetOutputNodes(const AnfNodePtr &node) { return output_nodes; } -// update the depend relations of control depend -void KernelGraph::UpdateControlDependRelations(const std::vector &depends) { - for (const auto &node : depends) { - MS_EXCEPTION_IF_NULL(node); - if (!node->isa()) { - return; - } - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimControlDepend)) { - MS_LOG(EXCEPTION) << node->DebugString() << " is not a control depend"; - } - auto prior_node = cnode->input(kControlDependPriorIndex); - auto depend_node = cnode->input(kControlDependBehindIndex); - MS_EXCEPTION_IF_NULL(prior_node); - MS_EXCEPTION_IF_NULL(depend_node); - std::vector prior_nodes = {prior_node}; - std::vector depend_nodes = {depend_node}; - int depend_mode = 0; - if (AnfAlgo::HasNodeAttr(kControlDependMode, cnode)) { - depend_mode = AnfAlgo::GetNodeAttr(cnode, kControlDependMode); - } - MS_LOG(DEBUG) << "Prior node[" << prior_node->DebugString() << "], depend node[" << depend_node->DebugString() - << "], depend_mode :" << depend_mode << "."; - if (prior_node->isa() && depend_mode == 1) { - prior_nodes = GetOutputNodes(prior_node); - } - if (depend_node->isa()) { - depend_nodes = depend_mode == 1 ? GetOutputNodes(depend_node) : std::vector{}; - } - - std::vector real_prior_nodes; - std::set prior_visited; - for (const auto &tmp : prior_nodes) { - AnfAlgo::GetAllFatherRealNode(tmp, &real_prior_nodes, &prior_visited); - } - std::vector real_depend_nodes; - std::set depend_visited; - for (const auto &tmp : depend_nodes) { - AnfAlgo::GetAllFatherRealNode(tmp, &real_depend_nodes, &depend_visited); - } - UpdateNodeInputOutputEdges(real_prior_nodes, real_depend_nodes); - } -} - void KernelGraph::UpdateNodeInputOutputEdges(const std::vector &real_prior_nodes, const std::vector &real_depend_nodes) { for (auto &first_node : real_prior_nodes) { - if (AnfAlgo::CheckPrimitiveType(first_node, prim::kPrimControlDepend)) { - continue; - } for (auto &second_node : real_depend_nodes) { - if (AnfAlgo::CheckPrimitiveType(second_node, prim::kPrimControlDepend)) { - continue; - } MS_EXCEPTION_IF_NULL(first_node); MS_EXCEPTION_IF_NULL(second_node); MS_LOG(DEBUG) << "Add first node:" << first_node->DebugString() << ",second node:" << second_node->DebugString(); @@ -914,35 +863,6 @@ void KernelGraph::UpdateNodeInputOutputEdges(const std::vector &real } } -bool KernelGraph::HandleControlDependNode(const AnfNodePtr &node, std::queue *que, - std::unordered_set *visited_nodes) { - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(que); - MS_EXCEPTION_IF_NULL(visited_nodes); - if (!node->isa()) { - return false; - } - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimControlDepend)) { - return false; - } - // set the control depend visited but don't push it into the que - if (visited_nodes->find(node) != visited_nodes->end()) { - return true; - } - (void)visited_nodes->insert(cnode); - // add a 0 depend num to keep the link relations to prepare for finding zero output nodes - auto prior_node = cnode->input(kControlDependPriorIndex); - auto depend_node = cnode->input(kControlDependBehindIndex); - for (const auto &input : cnode->inputs()) { - AddDependEdge(node, input, 0); - } - PushNoVisitedNode(depend_node, que, visited_nodes); - PushNoVisitedNode(prior_node, que, visited_nodes); - return true; -} - void KernelGraph::UpdateNodeEdgeList(std::queue *seed_nodes) { MS_EXCEPTION_IF_NULL(seed_nodes); node_output_edges_.clear(); diff --git a/mindspore/ccsrc/backend/session/kernel_graph.h b/mindspore/ccsrc/backend/session/kernel_graph.h index 7b34641959..63b0f7315c 100644 --- a/mindspore/ccsrc/backend/session/kernel_graph.h +++ b/mindspore/ccsrc/backend/session/kernel_graph.h @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -287,15 +287,11 @@ class KernelGraph : public FuncGraph { std::unordered_set *visited_nodes, bool comm_first = true); // update node edge list void UpdateNodeEdgeList(std::queue *seed_nodes); - // add node depend edge by data edge or control depend + // add node depend edge by data edge void AddDependEdge(const AnfNodePtr &node, const AnfNodePtr &input, size_t depend_edge_num); void UpdateNodeInputOutputEdges(const std::vector &real_prior_nodes, const std::vector &real_depend_nodes); - // handle control depend std::vector GetOutputNodes(const AnfNodePtr &node); - bool HandleControlDependNode(const AnfNodePtr &node, std::queue *que, - std::unordered_set *visited_nodes); - void UpdateControlDependRelations(const std::vector &depends); AnfNodePtr TransValueNodeTuple(const AbstractBasePtr abstract, const ValuePtr &value); AnfNodePtr TransParameterTuple(const AbstractBasePtr &abstract); AnfNodePtr TransCNodeTuple(const CNodePtr &node); diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc index 8c2acc2a43..c63a25c3cd 100644 --- a/mindspore/ccsrc/backend/session/session_basic.cc +++ b/mindspore/ccsrc/backend/session/session_basic.cc @@ -223,10 +223,8 @@ BaseRef CreateNodeOutputTensors(const AnfNodePtr &anf, const KernelGraphPtr &gra MS_EXCEPTION_IF_NULL(cnode); VectorRef ret; for (size_t i = 1; i < cnode->inputs().size(); ++i) { - if (!AnfAlgo::CheckPrimitiveType(cnode->input(i), prim::kPrimControlDepend)) { - auto out = CreateNodeOutputTensors(cnode->input(i), graph, input_tensors, tensor_to_node); - ret.push_back(out); - } + auto out = CreateNodeOutputTensors(cnode->input(i), graph, input_tensors, tensor_to_node); + ret.push_back(out); } return ret; } @@ -386,22 +384,6 @@ bool ExistSummaryNode(const KernelGraph *graph) { return false; } -bool IgnoreCreateParameterForMakeTuple(const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimMakeTuple)) { - return false; - } - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - const auto &node_inputs = cnode->inputs(); - for (size_t i = 1; i < node_inputs.size(); ++i) { - if (!AnfAlgo::CheckPrimitiveType(node_inputs[i], prim::kPrimControlDepend)) { - return false; - } - } - return true; -} - void GetParameterIndex(KernelGraph *graph, const std::vector &inputs, std::map *parameter_index) { size_t index = 0; @@ -692,9 +674,6 @@ void SessionBasic::InitInternalOutputParameter(const AnfNodePtr &out_node, const AnfNodePtr SessionBasic::CreateParameterFromTuple(const AnfNodePtr &node, KernelGraph *graph) { MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(graph); - if (IgnoreCreateParameterForMakeTuple(node)) { - return nullptr; - } auto new_parameter = graph->TransTupleToMakeTuple(graph->NewParameter(node->abstract())); auto parameters = AnfAlgo::GetAllOutput(new_parameter); std::vector pre_graph_out = {node}; @@ -1872,9 +1851,6 @@ std::vector ExtendNodeUsers(const FuncGraphManagerPtr &front_func_gr auto &users = front_func_graph_manager->node_users()[front_node]; std::vector result; for (auto &user : users) { - if (IsPrimitiveCNode(user.first, prim::kPrimControlDepend)) { - continue; - } if (IsPrimitiveCNode(user.first, prim::kPrimDepend)) { auto depend_cnode = user.first->cast(); if (depend_cnode == nullptr) { diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/branch_culling.cc b/mindspore/ccsrc/frontend/optimizer/irpass/branch_culling.cc index 04740e0c7f..f832c00c58 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/branch_culling.cc +++ b/mindspore/ccsrc/frontend/optimizer/irpass/branch_culling.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -84,7 +84,7 @@ bool InConvertWhiteList(const AnfNodePtr &node, size_t index) { } } - std::vector adapter_convert_ops = {prim::kPrimDepend, prim::kPrimControlDepend, prim::kPrimLoad}; + std::vector adapter_convert_ops = {prim::kPrimDepend, prim::kPrimLoad}; for (auto &item : adapter_convert_ops) { if (IsPrimitiveCNode(node, item)) { return true; @@ -243,8 +243,7 @@ CNodePtr MergeNode(const FuncGraphPtr &graph, const AnfNodePtr &cond, int64_t sw return merge_op; } -// construct a depend node with merge output node, merge(square_op(switch(ctrl_data)), switch(opposite_ctrl_data)) -// control_depend(output_node, square_op) +// merge(square_op(switch(ctrl_data)), switch(opposite_ctrl_data)) AnfNodePtr GenerateSwitchDependNode(const FuncGraphPtr &graph, const AnfNodePtr &cond, const AnfNodePtr &output_node, int64_t switch_idx) { tensor::TensorPtr const_data = GetConstData(); @@ -259,54 +258,21 @@ AnfNodePtr GenerateSwitchDependNode(const FuncGraphPtr &graph, const AnfNodePtr SetSquareOp(switch_idx, square_op); } + auto manager = graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + AnfNodePtrList inputs = {NewValueNode(prim::kPrimDepend), square_op, output_node}; + auto depend_cnode = graph->NewCNode(inputs); + if (!manager->Replace(square_op, depend_cnode)) { + MS_LOG(EXCEPTION) << square_op->DebugString() << ", replace node failed."; + } + CNodePtr merge_op = GetMergeOp(switch_idx); if (merge_op == nullptr) { merge_op = MergeNode(graph, cond, switch_idx, const_data, square_op); SetMergeOp(switch_idx, merge_op); } - std::vector control_depend_nodes{NewValueNode(prim::kPrimControlDepend), output_node, square_op}; - auto control_depend_op = graph->NewCNode(control_depend_nodes); - - std::vector depend_nodes{NewValueNode(prim::kPrimDepend), merge_op, control_depend_op}; - auto depend_op = graph->NewCNode(depend_nodes); - - return depend_op; -} - -// construct a merge output and add dependency with the netoutput node from control_depend -// we need to reserve the control_depend node, besides the generated merge node and control_depend node -CNodePtr GenerateSwitchControlDependNode(const FuncGraphPtr &graph, const AnfNodePtr &cond, - const AnfNodePtr &ctrl_dep_node, const AnfNodePtr &ctrl_depend_dst, - int64_t switch_idx) { - auto PrimMerge = prim::GetPythonOps("merge", "mindspore.ops.functional")->cast(); - auto PrimSquare = prim::GetPythonOps("square", "mindspore.ops.functional")->cast(); - std::vector shp = {1}; - tensor::TensorPtr const_data = std::make_shared(kInt64->type_id(), shp); - auto *val = static_cast(const_data->data_c()); - *val = 0; - // for the control_depend netoutput node , add two const data to merge the flow ,one for depended node with same - // switch the other use the opposite - auto ctrl_data = NewValueNode(const_data); - auto oppsite_ctrl_data = NewValueNode(const_data); - auto ctrl_node = GenerateSwitchNode(graph, cond, ctrl_data, switch_idx); - auto opposite_ctrl_node = GenerateSwitchNode(graph, cond, oppsite_ctrl_data, 1 - switch_idx); - - std::vector square_nodes{NewValueNode(PrimSquare), ctrl_node}; - auto square_op = graph->NewCNode(square_nodes); - - std::vector merge_nodes; - merge_nodes.push_back(NewValueNode(PrimMerge)); - std::vector make_tuple_nodes{NewValueNode(prim::kPrimMakeTuple), square_op, opposite_ctrl_node}; - merge_nodes.push_back(graph->NewCNode(make_tuple_nodes)); - auto merge_output = graph->NewCNode(merge_nodes); - - std::vector control_depend_nodes{NewValueNode(prim::kPrimControlDepend), ctrl_depend_dst, square_op}; - auto cond_dep_output = graph->NewCNode(control_depend_nodes); - - std::vector depended_make_tuple_nodes{NewValueNode(prim::kPrimMakeTuple), ctrl_dep_node, merge_output, - cond_dep_output}; - return graph->NewCNode(depended_make_tuple_nodes); + return merge_op; } // generate switch nodes for true graph node inputs @@ -321,26 +287,12 @@ AnfNodePtr GenerateSwitchDependFalseNode(const FuncGraphPtr &graph, const AnfNod return GenerateSwitchDependNode(graph, cond, data, 0); } -// generate switch nodes for true graph node inputs -CNodePtr GenerateSwitchControlDependTrueNode(const FuncGraphPtr &graph, const AnfNodePtr &cond, - const AnfNodePtr &con_input, const AnfNodePtr &output) { - // for switch op ,the output is a tuple ,0-th is false_branch, 1-th is true branch - return GenerateSwitchControlDependNode(graph, cond, con_input, output, 1); -} - -// generate switch nodes for false graph node inputs -CNodePtr GenerateSwitchControlDependFalseNode(const FuncGraphPtr &graph, const AnfNodePtr &cond, - const AnfNodePtr &con_input, const AnfNodePtr &output) { - // for switch op ,the output is a tuple ,0-th is false_branch, 1-th is true branch - return GenerateSwitchControlDependNode(graph, cond, con_input, output, 0); -} - -// to judge if the node used in ControlDepend is a net output node +// to judge if the node used in Depend is a net output node bool IsNetOutputNode(const FuncGraphManagerPtr &manager, const AnfNodePtr &node) { auto uses = manager->node_users()[node]; bool is_output_node = true; for (auto &item : uses) { - if (IsPrimitiveCNode(item.first, prim::kPrimControlDepend) || IsPrimitiveCNode(item.first, prim::kPrimDepend)) { + if (IsPrimitiveCNode(item.first, prim::kPrimDepend)) { continue; } is_output_node = false; @@ -353,8 +305,7 @@ bool IsNetOutputNode(const FuncGraphManagerPtr &manager, const AnfNodePtr &node) void GenerateReplNodeForDependMakeTuple( const AnfNodePtr &depended_node, const FuncGraphPtr &graph, const AnfNodePtr &cond, const std::shared_ptr> &repl_node, - const std::function &generate_func, - const std::function &gen_ctl_depd_func) { + const std::function &generate_func) { MS_EXCEPTION_IF_NULL(graph->manager()); auto make_tuple_inputs = depended_node->cast()->inputs(); @@ -368,26 +319,6 @@ void GenerateReplNodeForDependMakeTuple( new_make_tuple_nodes.push_back(depended_tuple_input_node); continue; } - if (IsPrimitiveCNode(depended_tuple_input_node->cast(), prim::kPrimControlDepend)) { - // only when the control depend input is not square op (the op to use as merge output) - auto control_inputs = depended_tuple_input_node->cast()->inputs(); - if (control_inputs.size() != 3) { - MS_LOG(EXCEPTION) << "controldepend input size != 3, got " << control_inputs.size(); - } - // control inputs: primitive, src, dst - auto dst_node = control_inputs[2]; - if (!IsPrimitiveCNode(dst_node, prim::kPrimSquare) && IsNetOutputNode(graph->manager(), dst_node)) { - auto gen_node = gen_ctl_depd_func(graph, cond, make_tuple_inputs[idx], dst_node); - MS_EXCEPTION_IF_NULL(gen_node); - auto tuple_inputs = gen_node->inputs(); - // add depended tuple inputs to new_make_tuple directly - for (size_t i = 1; i < tuple_inputs.size(); i++) { - new_make_tuple_nodes.push_back(tuple_inputs[i]); - } - } - replace_make_tuple = true; - continue; - } if (graph->manager()->node_users()[depended_tuple_input_node].size() == 1) { auto gen_node = generate_func(graph, cond, depended_tuple_input_node); @@ -408,8 +339,7 @@ void GenerateReplNodeForDependMakeTuple( void GenerateRepDepend( const CNodePtr &node, const FuncGraphPtr &graph, const AnfNodePtr &cond, const std::shared_ptr> &repl_node, - const std::function &generate_func, - const std::function &gen_ctl_depd_func) { + const std::function &generate_func) { auto inputs = node->inputs(); if (inputs.size() != 3) { MS_LOG(EXCEPTION) << "Inputs should be [depend, actual_value, depended_node]."; @@ -422,19 +352,7 @@ void GenerateRepDepend( new_depened_inputs.push_back(inputs[1]); // depended node should be make_tuple or a single depended node if (IsPrimitiveCNode(depended_node, prim::kPrimMakeTuple)) { - GenerateReplNodeForDependMakeTuple(depended_node, graph, cond, repl_node, generate_func, gen_ctl_depd_func); - } else if (IsPrimitiveCNode(depended_node, prim::kPrimControlDepend)) { - // only when the control depend input is not square op (the op to use as merge output) - auto control_inputs = depended_node->cast()->inputs(); - // control inputs: primitive, src, dst - if (control_inputs.size() != 3) { - MS_LOG(EXCEPTION) << "controldepend input size != 3, got " << control_inputs.size(); - } - auto dst_node = control_inputs[2]; - if (!IsPrimitiveCNode(dst_node, prim::kPrimSquare) && IsNetOutputNode(graph->manager(), dst_node)) { - auto gen_node = gen_ctl_depd_func(graph, cond, depended_node, dst_node); - (*repl_node)[depended_node] = gen_node; - } + GenerateReplNodeForDependMakeTuple(depended_node, graph, cond, repl_node, generate_func); } else { // Check if there is only single user for depend_node. if (graph->manager()->node_users()[depended_node].size() == 1) { @@ -448,11 +366,9 @@ void GenerateRepDepend( // generate depend node for netoutput node, to resolve the stream synchronize problem of ge // traverse all nodes of depend node, find the graph output node , generaete a merge node of (square, const) -// and add control_depend of graph output node and square node. FuncGraphPtr TransformGraphDependNode( const FuncGraphPtr &graph, const AnfNodePtr &cond, - const std::function &gen_depend_func, - const std::function &gen_ctl_depd_func) { + const std::function &gen_depend_func) { auto manager = graph->manager(); MS_EXCEPTION_IF_NULL(manager); @@ -478,7 +394,7 @@ FuncGraphPtr TransformGraphDependNode( if (IsPrimitiveCNode(depended_node, prim::kPrimDepend)) { continue; } - GenerateRepDepend(cnode, graph, cond, repl_node, gen_depend_func, gen_ctl_depd_func); + GenerateRepDepend(cnode, graph, cond, repl_node, gen_depend_func); } } ResetSharedOp(); @@ -494,12 +410,12 @@ FuncGraphPtr TransformGraphDependNode( FuncGraphPtr TransformGraphCondTrueBranchNodes(const FuncGraphPtr &graph, const AnfNodePtr &cond) { (void)TransformGraphCondBranchNodes(graph, cond, GenerateSwitchTrueNode); - return TransformGraphDependNode(graph, cond, GenerateSwitchDependTrueNode, GenerateSwitchControlDependTrueNode); + return TransformGraphDependNode(graph, cond, GenerateSwitchDependTrueNode); } FuncGraphPtr TransformGraphCondFalseBranchNodes(const FuncGraphPtr &graph, const AnfNodePtr &cond) { (void)TransformGraphCondBranchNodes(graph, cond, GenerateSwitchFalseNode); - return TransformGraphDependNode(graph, cond, GenerateSwitchDependFalseNode, GenerateSwitchControlDependFalseNode); + return TransformGraphDependNode(graph, cond, GenerateSwitchDependFalseNode); } // judge if the true and false graph output is compatible(they shall have same tuple size) diff --git a/mindspore/ccsrc/frontend/optimizer/recompute.cc b/mindspore/ccsrc/frontend/optimizer/recompute.cc index 4b2aeeea39..65f93ad7a7 100644 --- a/mindspore/ccsrc/frontend/optimizer/recompute.cc +++ b/mindspore/ccsrc/frontend/optimizer/recompute.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -218,10 +218,10 @@ bool HasForwardOutput(const FuncGraphManagerPtr &mng, const AnfNodePtr &node) { if (output_set_iter == node_users.end()) { return false; } - for (const auto &node_index_set : output_set_iter->second) { - if (!IsBpropNode(node_index_set.first) && !IsPrimitiveCNode(node_index_set.first, prim::kPrimControlDepend)) { - return true; - } + + if (std::any_of(output_set_iter->second.begin(), output_set_iter->second.end(), + [](const auto &node_index_set) { return !IsBpropNode(node_index_set.first); })) { + return true; } return false; } diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h b/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h index e24dae7cd2..6bfe40fa6a 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -367,7 +367,6 @@ constexpr char HISTOGRAMSUMMARY[] = "HistogramSummary"; constexpr char DEBUG[] = "Debug"; constexpr char BROADCASTGRADIENTARGS[] = "BroadcastGradientArgs"; constexpr char INVERTPERMUTATION[] = "InvertPermutation"; -constexpr char CONTROLDEPEND[] = "ControlDepend"; constexpr char DOT[] = "dot"; constexpr char IM2COL[] = "im2col"; constexpr char COL2IM[] = "col2im"; diff --git a/mindspore/ccsrc/runtime/framework/graph_scheduler.cc b/mindspore/ccsrc/runtime/framework/graph_scheduler.cc index 72a74f95a7..6bcc3450f2 100644 --- a/mindspore/ccsrc/runtime/framework/graph_scheduler.cc +++ b/mindspore/ccsrc/runtime/framework/graph_scheduler.cc @@ -259,10 +259,8 @@ BaseRef CreateOutputTensors(const AnfNodePtr &output_node, const KernelGraphPtr MS_EXCEPTION_IF_NULL(cnode); VectorRef ret; for (size_t i = 1; i < cnode->inputs().size(); ++i) { - if (!AnfAlgo::CheckPrimitiveType(cnode->input(i), prim::kPrimControlDepend)) { - auto out = CreateOutputTensors(cnode->input(i), graph, input_tensors); - ret.push_back(out); - } + auto out = CreateOutputTensors(cnode->input(i), graph, input_tensors); + ret.push_back(out); } return ret; } diff --git a/mindspore/ccsrc/transform/graph_ir/convert.cc b/mindspore/ccsrc/transform/graph_ir/convert.cc index f66cc9aa3e..0bb2c6b373 100644 --- a/mindspore/ccsrc/transform/graph_ir/convert.cc +++ b/mindspore/ccsrc/transform/graph_ir/convert.cc @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -1044,7 +1044,7 @@ bool DfGraphConvertor::IsControlEdgeNode(const AnfNodePtr &node) { OperatorPtr DfGraphConvertor::ToOperatorPtr(const AnfNodePtr &node) { auto op = Convert(GetRealOpNode(node)); if (op == nullptr) { - MS_LOG(ERROR) << "Convert control depend node to operator failed, " << node->ToString(); + MS_LOG(ERROR) << "Convert real op node to operator failed, " << node->ToString(); error_ = FAILED; return nullptr; } @@ -1170,13 +1170,13 @@ void DfGraphConvertor::AutoMonadSetControlInput(const AnfNodePtr &node) { void DfGraphConvertor::SetOpControlInput(const AnfNodePtr &node) { AutoMonadSetControlInput(node); - if (control_depend_cache_.find(node.get()) == control_depend_cache_.end()) { + if (control_edge_cache_.find(node.get()) == control_edge_cache_.end()) { return; } - std::vector control_edges = control_depend_cache_[node.get()]; + std::vector control_edges = control_edge_cache_[node.get()]; if ((control_edges.empty())) { - MS_LOG(ERROR) << "Get control depend node's src or dest operator failed"; + MS_LOG(ERROR) << "Get control edge node's src or dest operator failed"; return; } @@ -1600,7 +1600,7 @@ std::vector DfGraphConvertor::ConvertDependNode(const AnfNodePtr no for (size_t index = 1; index < node_inputs.size(); index++) { auto op = Convert(GetRealOpNode(node_inputs[index])); if (op == nullptr) { - MS_LOG(ERROR) << "Convert control depend node to operator failed"; + MS_LOG(ERROR) << "Convert real op node to operator failed"; error_ = FAILED; return std::vector({}); } @@ -1611,194 +1611,13 @@ std::vector DfGraphConvertor::ConvertDependNode(const AnfNodePtr no auto op = Convert(GetRealOpNode(node)); if (op == nullptr) { - MS_LOG(ERROR) << "Convert control depend node to operator failed"; + MS_LOG(ERROR) << "Convert real op node to operator failed"; error_ = FAILED; return std::vector({}); } return std::vector({op}); } -// get the anf node list for depend -std::vector DfGraphConvertor::GetDependNodes(const AnfNodePtr &node) { - std::vector nodes; - // for make tuple, should control depend on the tuple items - if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) { - auto node_inputs = node->cast()->inputs(); - for (size_t index = 1; index < node_inputs.size(); index++) { - nodes.push_back(GetRealOpNode(node_inputs[index])); - } - return nodes; - } - - // for parameter ,find the apply that used the parameter as the control depended node - if (node->isa()) { - auto uses = node->func_graph()->manager()->node_users()[node]; - for (auto &use : uses) { - auto use_node = use.first; - if ((use_node->isa()) && (!IsPrimitiveCNode(use_node, prim::kPrimControlDepend))) { - nodes.push_back(GetRealOpNode(use_node)); - } - } - return nodes; - } - nodes.push_back(GetRealOpNode(node)); - return nodes; -} - -void DfGraphConvertor::DrawControlDepend(const AnfNodePtr &src_node, const AnfNodePtr &dest_node) { -#ifdef DRAW_GE_GRAPH - auto src_depend_nodes = GetDependNodes(src_node); - auto dst_depend_nodes = GetDependNodes(dest_node); - if (src_depend_nodes.size() == 1 && dst_depend_nodes.size() > 1) { - for (auto &item : dst_depend_nodes) { - compute_sout_ << op_draw_name_[src_depend_nodes[0].get()] << " -> " << op_draw_name_[item.get()] - << "[style=\"dotted\"]" << endl; - } - } else if (src_depend_nodes.size() > 1 && dst_depend_nodes.size() == 1) { - for (auto &item : src_depend_nodes) { - compute_sout_ << op_draw_name_[item.get()] << " -> " << op_draw_name_[dst_depend_nodes[0].get()] - << "[style=\"dotted\"]" << endl; - } - } else if (src_depend_nodes.size() == 1 && dst_depend_nodes.size() == 1) { - compute_sout_ << op_draw_name_[src_depend_nodes[0].get()] << " -> " << op_draw_name_[dst_depend_nodes[0].get()] - << "[style=\"dotted\"]" << endl; - } -#endif -} - -void DfGraphConvertor::GetDependOnParameterUse(const CNodePtr &node, const AnfNodePtr &src_node, - const AnfNodePtr &dest_node, - const std::shared_ptr> &src_ops_list, - const std::shared_ptr> &dst_ops_list) { - if (src_node->isa()) { - auto uses = node->func_graph()->manager()->node_users()[src_node]; - for (auto &use : uses) { - auto use_node = use.first; - if ((use_node->isa()) && (!IsPrimitiveCNode(use_node, prim::kPrimControlDepend)) && - (!IsPrimitiveCNode(use_node, prim::kPrimMakeTuple))) { - auto converted_list = ConvertDependNode(use_node); - src_ops_list->insert(src_ops_list->end(), converted_list.begin(), converted_list.end()); - } - } - } - - if (dest_node->isa()) { - auto uses = node->func_graph()->manager()->node_users()[dest_node]; - for (auto &use : uses) { - auto use_node = use.first; - if ((use_node->isa()) && (!IsPrimitiveCNode(use_node, prim::kPrimControlDepend)) && - (!IsPrimitiveCNode(use_node, prim::kPrimMakeTuple))) { - auto converted_list = ConvertDependNode(use_node); - dst_ops_list->insert(dst_ops_list->end(), converted_list.begin(), converted_list.end()); - } - } - } -} - -bool DfGraphConvertor::GetControlDependList(const CNodePtr &node, - const std::shared_ptr> &src_ops_list, - const std::shared_ptr> &dst_ops_list) { - const int CONTROL_DEPEND_INDEX = 0; - const int SRC_NODE_INDEX = 1; - const int DEST_NODE_INDEX = 2; - const int DEPEND_MODE_NORMAL_USE = 0; - const int DEPEND_MODE_ON_PARAMETER_USE = 1; - - auto node_inputs = node->inputs(); - if (node_inputs.size() <= DEST_NODE_INDEX) { - MS_LOG(WARNING) << "Control depend node input size error"; - return false; - } - auto src_node = node_inputs[SRC_NODE_INDEX]; - auto dest_node = node_inputs[DEST_NODE_INDEX]; - if ((src_node == nullptr) || (dest_node == nullptr)) { - MS_LOG(ERROR) << "Control depend node miss src or dest node"; - error_ = FAILED; - return false; - } - AnfNodePtr fn = node_inputs[CONTROL_DEPEND_INDEX]; - PrimitivePtr prim_ptr = GetValueNode(fn); - ValuePtr mode_ptr = prim_ptr->GetAttr("depend_mode"); - int depend_mode = DEPEND_MODE_NORMAL_USE; - if (mode_ptr != nullptr) { - auto mode_int = mode_ptr->cast(); - MS_EXCEPTION_IF_NULL(mode_int); - depend_mode = mode_int->value(); - MS_LOG(DEBUG) << "depend_mode = " << depend_mode; - } - if (depend_mode == DEPEND_MODE_ON_PARAMETER_USE) { - GetDependOnParameterUse(node, src_node, dest_node, src_ops_list, dst_ops_list); - } - - if (src_node->isa()) { - auto converted_list = ConvertDependNode(src_node); - src_ops_list->insert(src_ops_list->end(), converted_list.begin(), converted_list.end()); - } - - if (dest_node->isa()) { - auto converted_list = ConvertDependNode(dest_node); - dst_ops_list->insert(dst_ops_list->end(), converted_list.begin(), converted_list.end()); - } - if (src_ops_list->empty() || dst_ops_list->empty()) { - MS_LOG(DEBUG) << "Control depend node's src or dest node is not a CNode, ignore it"; - error_ = SUCCESS; - } - return true; -} - -void DfGraphConvertor::ConvertControlDependNode(const CNodePtr node) { - const int SRC_NODE_INDEX = 1; - const int DEST_NODE_INDEX = 2; - if (control_depend_cache_.find(node.get()) != control_depend_cache_.end()) { - return; - } - auto node_inputs = node->inputs(); - if (node_inputs.size() <= DEST_NODE_INDEX) { - MS_LOG(WARNING) << "Control depend node input size error"; - return; - } - auto src_node = node_inputs[SRC_NODE_INDEX]; - auto dest_node = node_inputs[DEST_NODE_INDEX]; - if ((src_node == nullptr) || (dest_node == nullptr)) { - MS_LOG(ERROR) << "Control depend node miss src or dest node"; - error_ = FAILED; - return; - } - std::shared_ptr> src_ops_list = std::make_shared>(); - std::shared_ptr> dst_ops_list = std::make_shared>(); - if (!GetControlDependList(node, src_ops_list, dst_ops_list)) { - MS_LOG(ERROR) << "Get depend list failed"; - error_ = FAILED; - return; - } - std::vector control_edges; - if (src_ops_list->size() == 1 && dst_ops_list->size() > 1) { - (void)std::transform(dst_ops_list->begin(), dst_ops_list->end(), std::back_inserter(control_edges), - [src_ops_list](const OperatorPtr &op) -> ControlEdge { - return {(*src_ops_list)[0], op}; - }); - } else if (src_ops_list->size() > 1 && dst_ops_list->size() == 1) { - (void)std::transform(src_ops_list->begin(), src_ops_list->end(), std::back_inserter(control_edges), - [dst_ops_list](const OperatorPtr &op) -> ControlEdge { - return {op, (*dst_ops_list)[0]}; - }); - } else if (src_ops_list->size() == 1 && dst_ops_list->size() == 1) { - control_edges.push_back({(*src_ops_list)[0], (*dst_ops_list)[0]}); - } else if (src_ops_list->empty() || dst_ops_list->empty()) { - MS_LOG(DEBUG) << "Depend list of src or dst is empty, ignore it"; - } else { - MS_LOG(ERROR) << "Convert control depend node to operator failed, depend src:" << src_ops_list->size() - << " -> dst:" << dst_ops_list->size(); - error_ = FAILED; - return; - } - control_depend_cache_[node.get()] = control_edges; - -#ifdef DRAW_GE_GRAPH - DrawControlDepend(src_node, dest_node); -#endif -} - bool DfGraphConvertor::CheckCNode(const std::string &name, const CNodePtr node) { // ignore apply node of return if (name == "" || name == prim::kPrimReturn->name() || name == prim::kPrimDepend->name() || @@ -1818,12 +1637,6 @@ bool DfGraphConvertor::CheckCNode(const std::string &name, const CNodePtr node) return false; } - // ControlDepend - if (name == prim::kPrimControlDepend->name()) { - ConvertControlDependNode(node); - return false; - } - return true; } diff --git a/mindspore/ccsrc/transform/graph_ir/convert.h b/mindspore/ccsrc/transform/graph_ir/convert.h index 9789dd9467..d305af69de 100644 --- a/mindspore/ccsrc/transform/graph_ir/convert.h +++ b/mindspore/ccsrc/transform/graph_ir/convert.h @@ -1,5 +1,5 @@ /** - * Copyright 2019-2020 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -145,19 +145,11 @@ class DfGraphConvertor { OperatorPtr ConvertCNode(CNodePtr node); std::vector ConvertDependNode(AnfNodePtr node); AnfNodePtr GetRealOpNode(AnfNodePtr node); - std::vector GetDependNodes(const AnfNodePtr &node); OperatorPtr ConvertParameter(AnfNodePtr node); Status TryConvertValueNodeToMultiConst(const ValueNodePtr node); OperatorPtr ConvertValueNode(ValueNodePtr node); void GetCaseNodeInput(const CNodePtr node, const CNodePtr input_node); void ConvertTupleGetItem(const CNodePtr node); - void GetDependOnParameterUse(const CNodePtr &node, const AnfNodePtr &src_node, const AnfNodePtr &dest_node, - const std::shared_ptr> &src_ops_list, - const std::shared_ptr> &dst_ops_list); - bool GetControlDependList(const CNodePtr &node, const std::shared_ptr> &src_ops_list, - const std::shared_ptr> &dst_ops_list); - void DrawControlDepend(const AnfNodePtr &src_node, const AnfNodePtr &dest_node); - void ConvertControlDependNode(const CNodePtr node); void ConvertMakeTuple(const CNodePtr node); bool CheckCNode(const std::string &name, const CNodePtr node); void TraceOutput(AnfNodePtr node); @@ -195,7 +187,7 @@ class DfGraphConvertor { std::shared_ptr broadcast_graph_{nullptr}; std::unordered_map branches_map_; std::unordered_map op_cache_; - std::unordered_map> control_depend_cache_; + std::unordered_map> control_edge_cache_; std::unordered_map> monad_control_edge_cache_; /* record "tuple_getitem"<->"out_handler" mapping */ std::unordered_map out_handle_cache_; diff --git a/mindspore/ccsrc/vm/graph_partition.cc b/mindspore/ccsrc/vm/graph_partition.cc index 20113692e6..0e8e4b3f84 100644 --- a/mindspore/ccsrc/vm/graph_partition.cc +++ b/mindspore/ccsrc/vm/graph_partition.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -51,88 +51,6 @@ std::string GetOtherTarget(const std::vector &nodes) { } return ""; } -bool ExtractNodes(const FuncGraphPtr &graph, const AnfNodePtr &prior_node, const AnfNodePtr &behind_node, - std::vector *prior_nodes, std::vector *depend_nodes) { - MS_EXCEPTION_IF_NULL(prior_node); - MS_EXCEPTION_IF_NULL(behind_node); - MS_EXCEPTION_IF_NULL(graph); - auto manager = graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - auto &node_users = manager->node_users(); - if (prior_node->isa()) { - for (auto &user : node_users[prior_node]) { - auto cnode = user.first->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (!IsPrimitiveCNode(cnode, prim::kPrimControlDepend)) { - prior_nodes->emplace_back(cnode); - } - } - } else if (!IsPrimitiveCNode(prior_node, prim::kPrimControlDepend)) { - prior_nodes->emplace_back(prior_node); - } else { - return false; - } - if (behind_node->isa()) { - for (auto &user : node_users[behind_node]) { - auto cnode = user.first->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (!IsPrimitiveCNode(cnode, prim::kPrimControlDepend)) { - depend_nodes->emplace_back(cnode); - } - } - } else if (!IsPrimitiveCNode(behind_node, prim::kPrimControlDepend)) { - depend_nodes->emplace_back(behind_node); - } else { - return false; - } - return true; -} - -void AddControlEdge(const FuncGraphPtr &graph, const AnfNodePtr &node, - std::map> *control_edges, - std::map *nodes_ref) { - MS_EXCEPTION_IF_NULL(node); - auto input_cnode = node->cast(); - MS_EXCEPTION_IF_NULL(input_cnode); - auto prior_node = input_cnode->input(kControlDependPriorIndex); - auto depend_node = input_cnode->input(kControlDependBehindIndex); - MS_EXCEPTION_IF_NULL(prior_node); - MS_EXCEPTION_IF_NULL(depend_node); - auto prim_ptr = GetValueNode(input_cnode->input(0)); - MS_EXCEPTION_IF_NULL(prim_ptr); - ValuePtr mode_ptr = prim_ptr->GetAttr("depend_mode"); - int64_t depend_mode = 0; - if (mode_ptr != nullptr) { - depend_mode = GetValue(mode_ptr); - } - if ((prior_node->isa() || depend_node->isa()) && depend_mode == 0) { - return; - } - std::vector prior_nodes; - std::vector behind_nodes; - if (!ExtractNodes(graph, prior_node, depend_node, &prior_nodes, &behind_nodes)) { - return; - } - for (auto &first_node : prior_nodes) { - for (auto &second_node : behind_nodes) { - MS_EXCEPTION_IF_NULL(first_node); - MS_EXCEPTION_IF_NULL(second_node); - auto iter = control_edges->find(second_node); - if (iter == control_edges->end()) { - (void)control_edges->insert( - std::pair>(second_node, std::vector{first_node})); - } else { - iter->second.emplace_back(first_node); - } - auto ref_iter = nodes_ref->find(first_node); - if (ref_iter != nodes_ref->end()) { - ref_iter->second++; - } else { - (void)nodes_ref->insert(std::pair(first_node, 1)); - } - } - } -} void CalcNodeRefCount(const FuncGraphPtr &graph, std::map *nodes_ref, std::map> *control_edges) { @@ -149,9 +67,6 @@ void CalcNodeRefCount(const FuncGraphPtr &graph, std::map *n auto cnode = node->cast(); MS_EXCEPTION_IF_NULL(cnode); for (auto &input : cnode->inputs()) { - if (IsPrimitiveCNode(input, prim::kPrimControlDepend)) { - AddControlEdge(graph, input, control_edges, nodes_ref); - } auto iter = nodes_ref->find(input); if (iter != nodes_ref->end()) { iter->second++; @@ -479,11 +394,9 @@ void AddSegmentDependency(const FuncGraphPtr &graph, const std::string &default_ node_inputs.insert(node_inputs.end(), ctrl_inputs->second.begin(), ctrl_inputs->second.end()); } GraphSegmentPtr node_segment{nullptr}; - if (!IsPrimitiveCNode(cnode, prim::kPrimControlDepend)) { - auto node_iter = node_to_segment.find(node); - if (node_iter != node_to_segment.end()) { - node_segment = node_iter->second; - } + auto node_iter = node_to_segment.find(node); + if (node_iter != node_to_segment.end()) { + node_segment = node_iter->second; } for (auto &input : node_inputs) { if (node_segment != nullptr && !node_segment->is_cut_ && input->isa()) { @@ -615,18 +528,14 @@ void SplitDynamicNodeSegment(const std::vector &segment_nodes, std:: std::map *node_to_segment, const std::set &dynamic_nodes_set) { SplitDynamicNodesHelper helper; - bool is_last_node_dynamic = false; for (auto &node : segment_nodes) { auto cnode = node->cast(); MS_EXCEPTION_IF_NULL(cnode); - if (IsPrimitiveCNode(cnode, prim::kPrimControlDepend)) { - helper.AddNode(node, is_last_node_dynamic); - continue; - } auto &inputs = cnode->inputs(); bool has_dynamic_shape = dynamic_nodes_set.find(node) != dynamic_nodes_set.end(); bool depend_common_node = false; bool depend_dynamic_node = false; + bool is_last_node_dynamic = false; for (size_t i = 1; i < inputs.size(); ++i) { if (dynamic_nodes_set.find(inputs[i]) != dynamic_nodes_set.end()) { has_dynamic_shape = true; diff --git a/mindspore/ccsrc/vm/segment_runner.cc b/mindspore/ccsrc/vm/segment_runner.cc index 1a52f5014f..60d0e35d06 100644 --- a/mindspore/ccsrc/vm/segment_runner.cc +++ b/mindspore/ccsrc/vm/segment_runner.cc @@ -1,7 +1,7 @@ /** * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). * - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -87,26 +87,7 @@ AnfNodePtr RefSubGraphNode(const FuncGraphPtr &fg, const AnfNodePtr &node, AnfNo if (node->isa() && !IsValueNode(node)) { eqv[node] = node; } else if (eqv.find(node) == eqv.end()) { - if (IsPrimitiveCNode(node, prim::kPrimControlDepend)) { - eqv[node] = NewValueNode(MakeValue(0)); - return eqv[node]; - } - bool ignore_make_tuple = false; - if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) { - ignore_make_tuple = true; - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - const auto &node_inputs = cnode->inputs(); - for (size_t i = 1; i < node_inputs.size(); ++i) { - if (!IsPrimitiveCNode(node_inputs[i], prim::kPrimControlDepend)) { - ignore_make_tuple = false; - break; - } - } - } - if (!ignore_make_tuple) { - inputs.push_back(node); - } + inputs.push_back(node); eqv[node] = fg->add_parameter(); eqv[node]->set_abstract(node->abstract()); eqv[node]->set_kernel_info(node->kernel_info_ptr()); @@ -148,14 +129,6 @@ std::tuple TransformSegmentToAnfGr for (size_t i = 2; i < inps.size(); ++i) { args.emplace_back(NewValueNode(MakeValue(0))); } - } else if (IsPrimitive(fn, prim::kPrimControlDepend) && inps.size() == 3) { - for (size_t i = 1; i < inps.size(); ++i) { - if (inps[i]->isa() && std::find(lst.begin(), lst.end(), inps[i]) == lst.end()) { - args.emplace_back(NewValueNode(MakeValue(static_cast(i)))); - } else { - args.emplace_back(RefSubGraphNode(fg, inps[i], &inputs, &eqv)); - } - } } else { (void)std::transform(std::begin(inps) + 1, std::end(inps), std::back_inserter(args), [&fg, &inputs, &eqv](const AnfNodePtr &a) { return RefSubGraphNode(fg, a, &inputs, &eqv); }); diff --git a/mindspore/core/abstract/infer_functions.h b/mindspore/core/abstract/infer_functions.h index e8089e4f34..d25c6c51ef 100644 --- a/mindspore/core/abstract/infer_functions.h +++ b/mindspore/core/abstract/infer_functions.h @@ -182,8 +182,6 @@ AbstractBasePtr InferImplDepend(const AnalysisEnginePtr &, const PrimitivePtr &p const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplUpdateState(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplControlDepend(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplDebug(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplMakeSparseTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive, diff --git a/mindspore/core/abstract/prim_others.cc b/mindspore/core/abstract/prim_others.cc index 37688918b8..e71d398e9f 100644 --- a/mindspore/core/abstract/prim_others.cc +++ b/mindspore/core/abstract/prim_others.cc @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -188,23 +188,6 @@ AbstractBasePtr InferImplUpdateState(const AnalysisEnginePtr &, const PrimitiveP return args_spec_list[0]->Broaden(); } -AbstractBasePtr InferImplControlDepend(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // args: Two objects of a subclass of AbstractBase - CheckArgsSize(primitive->name(), args_spec_list, 2); - auto arg_src = args_spec_list[0]; - auto arg_dst = args_spec_list[1]; - // control depend can not setup tuple of ops to tuple of ops dependency relation - if (arg_src->isa() && arg_dst->isa()) { - auto src_size = arg_src->cast()->size(); - auto dst_size = arg_src->cast()->size(); - if (src_size > 1 && dst_size > 1) { - MS_LOG(EXCEPTION) << "Control depend can not setup operator dependency relationship from tuple from tuple"; - } - } - return std::make_shared(kAnyValue, kBool); -} - AbstractBasePtr InferImplMakeRowTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list) { // Inputs: two tensors and a tuple. diff --git a/mindspore/core/abstract/primitive_infer_map.cc b/mindspore/core/abstract/primitive_infer_map.cc index 913e5a6181..6ed0136aab 100644 --- a/mindspore/core/abstract/primitive_infer_map.cc +++ b/mindspore/core/abstract/primitive_infer_map.cc @@ -149,7 +149,6 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { {prim::kPrimStateSetItem, {InferImplStateSetItem, nullptr, true}}, {prim::kPrimDepend, {InferImplDepend, nullptr, true}}, {prim::kPrimUpdateState, {InferImplUpdateState, nullptr, true}}, - {prim::kPrimControlDepend, {InferImplControlDepend, nullptr, true}}, // Debug {prim::kPrimDebug, {InferImplDebug, nullptr, true}}, // Dynamic shape testing diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h index 9f446640df..2b6d46e292 100644 --- a/mindspore/core/base/core_ops.h +++ b/mindspore/core/base/core_ops.h @@ -451,7 +451,6 @@ inline const PrimitivePtr kPrimHookBackward = std::make_shared("HookB inline const PrimitivePtr kPrimPrintShapeType = std::make_shared("PrintShapeType"); inline const PrimitivePtr kPrimSameTypeShape = std::make_shared("SameTypeShape"); inline const PrimitivePtr kPrimPrint = std::make_shared("Print"); -inline const PrimitivePtr kPrimControlDepend = std::make_shared("ControlDepend"); inline const PrimitivePtr kPrimIs_ = std::make_shared("is_"); inline const PrimitivePtr kPrimIsNot = std::make_shared("is_not"); inline const PrimitivePtr kPrimInDict = std::make_shared("in_dict"); diff --git a/mindspore/core/ir/anf.cc b/mindspore/core/ir/anf.cc index 379d0919a2..80c4a936d8 100644 --- a/mindspore/core/ir/anf.cc +++ b/mindspore/core/ir/anf.cc @@ -1,7 +1,7 @@ /** * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). * - * Copyright 2019-2020 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -399,8 +399,7 @@ std::string GetAttrTarget(const PrimitivePtr &primitive, const ValuePtr &att_tar if (IsPrimitive(attr_input, prim::kPrimImageSummary) || IsPrimitive(attr_input, prim::kPrimScalarSummary) || IsPrimitive(attr_input, prim::kPrimTensorSummary) || IsPrimitive(attr_input, prim::kPrimHistogramSummary) || IsPrimitive(attr_input, prim::kPrimStateSetItem) || IsPrimitive(attr_input, prim::kPrimDepend) || - IsPrimitive(attr_input, prim::kPrimControlDepend) || IsPrimitive(attr_input, prim::kPrimReturn) || - IsPrimitive(attr_input, prim::kPrimPartial)) { + IsPrimitive(attr_input, prim::kPrimReturn) || IsPrimitive(attr_input, prim::kPrimPartial)) { primitive->EraseAttr(primitive_target); return default_target; } diff --git a/mindspore/core/ops/conv2d.cc b/mindspore/core/ops/conv2d.cc index 50e4ce1108..14a22629df 100644 --- a/mindspore/core/ops/conv2d.cc +++ b/mindspore/core/ops/conv2d.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,7 +23,6 @@ #include "ir/dtype/tensor_type.h" #include "utils/check_convert_utils.h" #include "abstract/primitive_infer_map.h" -#include "ops/control_depend.h" namespace mindspore { namespace ops { diff --git a/mindspore/core/ops/expand_dims.cc b/mindspore/core/ops/expand_dims.cc index 62d9ad5f68..c813af5c45 100644 --- a/mindspore/core/ops/expand_dims.cc +++ b/mindspore/core/ops/expand_dims.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,7 +23,6 @@ #include "ops/expand_dims.h" #include "utils/check_convert_utils.h" #include "abstract/primitive_infer_map.h" -#include "ops/control_depend.h" #include "utils/log_adapter.h" namespace mindspore { diff --git a/mindspore/core/utils/parallel_node_check.cc b/mindspore/core/utils/parallel_node_check.cc index e48335c0c0..855ac58e78 100644 --- a/mindspore/core/utils/parallel_node_check.cc +++ b/mindspore/core/utils/parallel_node_check.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -30,7 +30,7 @@ static const std::set PARALLEL_BLACK_LIST_ = {prim::kTupleGetItem, "identity", "partial", "env_setitem", "env_getitem", "env_add", "MakeRefKey", "make_ref", "get_ref_key", "get_ref_value", "get_ref_origin", "dot", "im2col", "col2im", "im2col_v1", "state_setitem", "ScalarSummary", "ImageSummary", "TensorSummary", "Debug", "HistogramSummary", "col2im_v1", "resolve", "BroadcastGradientArgs", - "InvertPermutation", "ControlDepend", "DropoutGenMask", "embed", "create_instance", "RefToEmbed", + "InvertPermutation", "DropoutGenMask", "embed", "create_instance", "RefToEmbed", "stop_gradient", "Send", "UpdateState", "Load"}; // clang-format on diff --git a/mindspore/lite/tools/anf_exporter/anf_exporter.cc b/mindspore/lite/tools/anf_exporter/anf_exporter.cc index ab02d82c47..b477839cef 100644 --- a/mindspore/lite/tools/anf_exporter/anf_exporter.cc +++ b/mindspore/lite/tools/anf_exporter/anf_exporter.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -27,7 +27,6 @@ #include "abstract/abstract_value.h" #include "mindspore/core/ir/primitive.h" #include "ops/fusion/partial_fusion.h" -#include "ops/control_depend.h" #include "ops/depend.h" #include "ops/make_tuple.h" #include "ops/quant_dtype_cast.h" @@ -213,8 +212,7 @@ void AnfExporter::RemoveIfDepend(const CNodePtr &cnode) { MS_LOG(ERROR) << "value node is invalid."; return; } - if (value_node->value() != nullptr && (opt::CheckPrimitiveType(depend_node, prim::kPrimDepend) || - opt::CheckPrimitiveType(depend_node, prim::kPrimControlDepend))) { + if (value_node->value() != nullptr && opt::CheckPrimitiveType(depend_node, prim::kPrimDepend)) { has_depend = true; bool mask_out = (depend_node->inputs().size() == 3); for (size_t j = 1; j < depend_node->inputs().size(); ++j) { @@ -466,8 +464,8 @@ int AnfExporter::Anf2Fb(const FuncGraphPtr &func_graph, const std::unique_ptrname() == mindspore::ops::kNameDepend || prim->name() == mindspore::ops::kNameControlDepend || - prim->name() == mindspore::ops::kNameTupleGetItem || prim->name() == mindspore::ops::kNameMakeTuple) { + if (prim->name() == mindspore::ops::kNameDepend || prim->name() == mindspore::ops::kNameTupleGetItem || + prim->name() == mindspore::ops::kNameMakeTuple) { continue; } if (prim->name() == "make_tuple") { diff --git a/mindspore/lite/tools/optimizer/common/gllo_utils.cc b/mindspore/lite/tools/optimizer/common/gllo_utils.cc index 29cef5187f..394b51b1b9 100644 --- a/mindspore/lite/tools/optimizer/common/gllo_utils.cc +++ b/mindspore/lite/tools/optimizer/common/gllo_utils.cc @@ -57,8 +57,8 @@ bool IsRealKernel(const AnfNodePtr &node) { IsPrimitive(input, prim::kPrimTensorSummary) || IsPrimitive(input, prim::kPrimHistogramSummary) || IsPrimitive(input, prim::kPrimMakeTuple) || IsPrimitive(input, prim::kPrimStateSetItem) || IsPrimitive(input, prim::kPrimDepend) || - IsPrimitive(input, prim::kPrimTupleGetItem) || IsPrimitive(input, prim::kPrimControlDepend) || - IsPrimitive(input, prim::kPrimReturn) || IsPrimitive(input, prim::kPrimPartial); + IsPrimitive(input, prim::kPrimTupleGetItem) || IsPrimitive(input, prim::kPrimReturn) || + IsPrimitive(input, prim::kPrimPartial); return !is_virtual_node; } diff --git a/mindspore/lite/tools/optimizer/graph/infershape_pass.cc b/mindspore/lite/tools/optimizer/graph/infershape_pass.cc index f28bb2b842..da8bece8b0 100644 --- a/mindspore/lite/tools/optimizer/graph/infershape_pass.cc +++ b/mindspore/lite/tools/optimizer/graph/infershape_pass.cc @@ -43,8 +43,8 @@ tensor::TensorPtr NewTensorInfo(lite::Tensor *tensor) { bool IsSpecialType(const CNodePtr &cnode) { if (CheckPrimitiveType(cnode, prim::kPrimTupleGetItem) || CheckPrimitiveType(cnode, prim::kPrimDepend) || - CheckPrimitiveType(cnode, prim::kPrimControlDepend) || CheckPrimitiveType(cnode, prim::kPrimMakeTuple) || - CheckPrimitiveType(cnode, prim::kPrimReturn) || CheckPrimitiveType(cnode, std::make_shared("While")) || + CheckPrimitiveType(cnode, prim::kPrimMakeTuple) || CheckPrimitiveType(cnode, prim::kPrimReturn) || + CheckPrimitiveType(cnode, std::make_shared("While")) || CheckPrimitiveType(cnode, std::make_shared("If"))) { return true; } diff --git a/mindspore/lite/tools/optimizer/graph/redundant_op_remove_pass.cc b/mindspore/lite/tools/optimizer/graph/redundant_op_remove_pass.cc index 87dde56380..41a3d354b9 100644 --- a/mindspore/lite/tools/optimizer/graph/redundant_op_remove_pass.cc +++ b/mindspore/lite/tools/optimizer/graph/redundant_op_remove_pass.cc @@ -58,13 +58,6 @@ int RemoveRedundantOpPass::ReplaceOp(const AnfNodePtr &anf_node, const FuncGraph return lite::RET_NO_CHANGE; } } - if (CheckPrimitiveType(anf_node, prim::kPrimControlDepend)) { - if (cnode->size() != InputDoubleNum) { - MS_LOG(DEBUG) << "The node inputs size is bigger than 1"; - remove_cnode_.insert(anf_node); - return lite::RET_NO_CHANGE; - } - } bool replace_succ = manager->Replace(anf_node, cnode->input(1)); if (!replace_succ) { diff --git a/model_zoo/research/nlp/gpt2/src/gpt2_for_finetune.py b/model_zoo/research/nlp/gpt2/src/gpt2_for_finetune.py index 3f5256e277..4bc7bd2a56 100644 --- a/model_zoo/research/nlp/gpt2/src/gpt2_for_finetune.py +++ b/model_zoo/research/nlp/gpt2/src/gpt2_for_finetune.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -96,7 +96,6 @@ class GPT2FinetuneCell(nn.Cell): self.get_status = P.NPUGetFloatStatus() self.clear_before_grad = P.NPUClearFloatStatus() self.reduce_sum = P.ReduceSum(keep_dims=False) - self.depend_parameter_use = P.ControlDepend(depend_mode=1) self.base = Tensor(1, mstype.float32) self.less_equal = P.LessEqual() self.hyper_map = C.HyperMap() @@ -132,8 +131,8 @@ class GPT2FinetuneCell(nn.Cell): if not self.gpu_target: init = self.alloc_status() + init = F.depend(init, loss) clear_before_grad = self.clear_before_grad(init) - F.control_depend(loss, init) self.depend_parameter_use(clear_before_grad, scaling_sens) grads = self.grad(self.network, weights)(input_ids, input_mask, @@ -145,10 +144,10 @@ class GPT2FinetuneCell(nn.Cell): if self.reducer_flag: grads = self.grad_reducer(grads) if not self.gpu_target: + init = F.depend(init, grads) flag = self.get_status(init) + init = F.depend(init, flag) flag_sum = self.reduce_sum(init, (0,)) - F.control_depend(grads, flag) - F.control_depend(flag, flag_sum) else: flag_sum = self.hyper_map(F.partial(_grad_overflow), grads) flag_sum = self.addn(flag_sum) diff --git a/tests/ut/cpp/pre_activate/pass/optimize_dependence_test.cc b/tests/ut/cpp/pre_activate/pass/optimize_dependence_test.cc index f9cfe273bc..ae2e000b28 100644 --- a/tests/ut/cpp/pre_activate/pass/optimize_dependence_test.cc +++ b/tests/ut/cpp/pre_activate/pass/optimize_dependence_test.cc @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -74,9 +74,9 @@ TEST_F(TestHWOptimizeDependence, test_optimize_control_dependence_with_make_tupl /* * def before(x, y, a, b): * z = make_tuple(TransData(a), TransData(b)) - * depend_intput = control_depend(y, z) - * sum = add(x, depend_intput) - * return sum + * depend_intput = depend(y, z) + * sum_add = add(x, depend_intput) + * return sum_add */ FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_optimize_control_dependence_with_make_tuple", "before"); @@ -93,11 +93,11 @@ TEST_F(TestHWOptimizeDependence, test_optimize_control_dependence_with_make_tupl TEST_F(TestHWOptimizeDependence, test_optimize_control_dependence) { /* - * def before(x, y, a, b): - * z = make_tuple(TransData(a), TransData(b)) - * depend_intput = control_depend(y, z) - * sum = add(x, depend_intput) - * return sum + * def before(x, y, z): + * new_z = TransData(z) + * depend_intput = depend(y, new_z) + * sum_add = add(x, depend_intput) + * return sum_add */ FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_optimize_control_dependence", "before");