diff --git a/mindspore/ccsrc/pre_activate/pass/optimize_dependence.cc b/mindspore/ccsrc/pre_activate/pass/optimize_dependence.cc index 6ae3f3be36..db32354abf 100644 --- a/mindspore/ccsrc/pre_activate/pass/optimize_dependence.cc +++ b/mindspore/ccsrc/pre_activate/pass/optimize_dependence.cc @@ -27,6 +27,69 @@ namespace mindspore { namespace opt { constexpr auto kSingleInputIndex = 1; +namespace { +AnfNodePtr GetReplaceNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(node); + if (!node->isa<CNode>()) { + return nullptr; + } + auto cnode = node->cast<CNodePtr>(); + MS_EXCEPTION_IF_NULL(cnode); + string op_name = AnfAlgo::GetCNodeName(cnode); + // Currently we only eliminate transdata or cast nodes. + if (op_name != kTransDataOpName && op_name != prim::kPrimCast->name()) { + return nullptr; + } + auto manager = func_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + // Check whether the node has only one output node. + if (manager->node_users().find(cnode) == manager->node_users().end()) { + MS_LOG(EXCEPTION) << "The node should be used by at least another node's input"; + } + if (manager->node_users()[cnode].size() > 1) { + return nullptr; + } + CheckCNodeInputSize(cnode, kSingleInputIndex + 1); + return cnode->input(kSingleInputIndex); +} + +bool ReplaceMakeTuple(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(cnode); + if (AnfAlgo::GetCNodeName(cnode) != prim::kPrimMakeTuple->name()) { + return false; + } + std::vector<AnfNodePtr> new_make_tuple_inputs; + bool need_update = false; + for (const auto &input : cnode->inputs()) { + AnfNodePtr replace_input = GetReplaceNode(func_graph, input); + // If replace input is not null, it will be the input of the TransData or Cast.
+ if (replace_input == nullptr) { + new_make_tuple_inputs.push_back(input); + continue; + } + new_make_tuple_inputs.push_back(replace_input); + need_update = true; + } + if (need_update) { + auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>(); + CNodePtr new_make_tuple = nullptr; + if (kernel_graph == nullptr) { + new_make_tuple = func_graph->NewCNode(new_make_tuple_inputs); + } else { + new_make_tuple = kernel_graph->NewCNode(cnode); + } + MS_EXCEPTION_IF_NULL(new_make_tuple); + new_make_tuple->set_inputs(new_make_tuple_inputs); + auto manager = func_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + manager->Replace(cnode, new_make_tuple); + } + return true; +} +} // namespace + const BaseRef OptimizeDependence::DefinePattern() const { VarPtr X = std::make_shared<Var>("X"); MS_EXCEPTION_IF_NULL(X); @@ -43,9 +106,8 @@ const AnfNodePtr OptimizeDependence::Process(const FuncGraphPtr &func_graph, con return nullptr; } auto depend_cnode = node->cast<CNodePtr>(); - if (depend_cnode->inputs().size() < kDependInputNum) { - return nullptr; - } + MS_EXCEPTION_IF_NULL(depend_cnode); + CheckCNodeInputSize(depend_cnode, kDependInputNum); auto replacing_node = depend_cnode->input(kDependInputNum - 1); MS_EXCEPTION_IF_NULL(replacing_node); if (!replacing_node->isa<CNode>()) { @@ -53,36 +115,29 @@ const AnfNodePtr OptimizeDependence::Process(const FuncGraphPtr &func_graph, con } auto replacing_cnode = replacing_node->cast<CNodePtr>(); MS_EXCEPTION_IF_NULL(replacing_cnode); - // Currently we only optimize transdata or cast nodes. - string replacing_cnode_op_name = AnfAlgo::GetCNodeName(replacing_cnode); - if (replacing_cnode_op_name != kTransDataOpName && replacing_cnode_op_name != prim::kPrimCast->name()) { + // Deal with the make_tuple with TransData or Cast inputs. + if (ReplaceMakeTuple(func_graph, replacing_cnode)) { return nullptr; } - auto manager = func_graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - // Check whether the replacing node has only one input and one output.
- if (replacing_cnode->inputs().size() != kSingleInputIndex + 1) { - return nullptr; - } - if (manager->node_users().find(replacing_node) == manager->node_users().end()) { - MS_LOG(EXCEPTION) << "The node should be used by at least another node input"; - } - if (manager->node_users()[replacing_node].size() > 1) { + AnfNodePtr replace_node = GetReplaceNode(func_graph, replacing_cnode); + if (replace_node == nullptr) { + MS_LOG(DEBUG) << "Can not find the TransData or Cast with single output node. Depend node: " << node->DebugString(); return nullptr; } std::vector<AnfNodePtr> new_depend_inputs = {depend_cnode->input(kAnfPrimitiveIndex), - depend_cnode->input(kRealInputIndexInDepend), - replacing_cnode->input(kSingleInputIndex)}; + depend_cnode->input(kRealInputIndexInDepend), replace_node}; auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>(); CNodePtr new_depend; if (kernel_graph == nullptr) { new_depend = func_graph->NewCNode(new_depend_inputs); + MS_EXCEPTION_IF_NULL(new_depend); + new_depend->set_abstract(node->abstract()); + new_depend->set_scope(node->scope()); } else { new_depend = kernel_graph->NewCNode(depend_cnode); MS_EXCEPTION_IF_NULL(new_depend); new_depend->set_inputs(new_depend_inputs); } - new_depend->set_abstract(node->abstract()); return new_depend; } } // namespace opt diff --git a/tests/ut/cpp/pre_activate/pass/optimize_dependence_test.cc b/tests/ut/cpp/pre_activate/pass/optimize_dependence_test.cc index 3f59b6159a..e95d63e93e 100644 --- a/tests/ut/cpp/pre_activate/pass/optimize_dependence_test.cc +++ b/tests/ut/cpp/pre_activate/pass/optimize_dependence_test.cc @@ -48,5 +48,25 @@ TEST_F(TestHWOptimizeDependence, test_optimize_dependence) { FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_optimize_dependence", "after"); EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); } + +TEST_F(TestHWOptimizeDependence, test_optimize_dependence_with_make_tuple) { + /* + * def before(x, y, a, b): + * z = make_tuple(TransData(a), TransData(b)) + * depend_input = depend(y, z) + *
 sum = add(x, depend_input) + * return sum + */ + FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_optimize_dependence_with_make_tuple", "before"); + + auto optimizer = std::make_shared<opt::GraphOptimizer>(); + auto pm = std::make_shared<opt::PassManager>(); + pm->AddPass(std::make_shared<opt::OptimizeDependence>()); + optimizer->AddPassManager(pm); + FuncGraphPtr new_graph = optimizer->Optimize(g); + + FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_optimize_dependence_with_make_tuple", "after"); + EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); +} } // namespace opt } // namespace mindspore diff --git a/tests/ut/cpp/python_input/gtest_input/pre_activate/optimize_dependence_test.py b/tests/ut/cpp/python_input/gtest_input/pre_activate/optimize_dependence_test.py index 45c419d25d..05eb057327 100644 --- a/tests/ut/cpp/python_input/gtest_input/pre_activate/optimize_dependence_test.py +++ b/tests/ut/cpp/python_input/gtest_input/pre_activate/optimize_dependence_test.py @@ -18,6 +18,8 @@ from mindspore.ops import Primitive depend = Primitive('depend') TransData = Primitive('TransData') add = P.TensorAdd() +make_tuple = Primitive('make_tuple') + class FnDict: def __init__(self): @@ -47,3 +49,23 @@ def test_optimize_dependence(tag): return sum return fns[tag] + + +def test_optimize_dependence_with_make_tuple(tag): + fns = FnDict() + + @fns + def before(x, y, a, b): + z = make_tuple(TransData(a), TransData(b)) + depend_input = depend(y, z) + sum = add(x, depend_input) + return sum + + @fns + def after(x, y, a, b): + z = make_tuple(a, b) + depend_input = depend(y, z) + sum = add(x, depend_input) + return sum + + return fns[tag]