/** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include #include #include "common/common_test.h" #include "common/py_func_graph_fetcher.h" #include "ir/anf.h" #include "ir/visitor.h" #include "ir/func_graph_cloner.h" #include "frontend/optimizer/opt.h" #include "frontend/optimizer/anf_visitor.h" #include "frontend/optimizer/irpass.h" #include "frontend/optimizer/irpass/arithmetic_simplify.h" #include "debug/draw.h" #include "frontend/operator/ops.h" #include "frontend/optimizer/cse.h" namespace mindspore { namespace opt { class TestOptOpt : public UT::Common { public: TestOptOpt() : getPyFun("gtest_input.optimizer.opt_test", true) {} class IdempotentEliminater : public AnfVisitor { public: AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { x_ = nullptr; AnfVisitor::Match(P, {irpass::IsCNode})(node); if (x_ == nullptr || node->func_graph() == nullptr) { return nullptr; } return node->func_graph()->NewCNode({NewValueNode(P), x_}); }; void Visit(const CNodePtr &cnode) override { if (IsPrimitiveCNode(cnode, P) && cnode->inputs().size() == 2) { x_ = cnode->input(1); } } private: AnfNodePtr x_{nullptr}; }; class QctToP : public AnfVisitor { public: AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { v_ = nullptr; AnfVisitor::Match(Q, {irpass::IsVNode})(node); if (v_ == nullptr || node->func_graph() == nullptr) { return nullptr; } return node->func_graph()->NewCNode({NewValueNode(P), v_}); }; void Visit(const ValueNodePtr &vnode) override { v_ = vnode; } private: AnfNodePtr v_{nullptr}; }; void SetUp() { elim_Z = MakeSubstitution(std::make_shared(), "elim_Z", prim::kPrimScalarAdd); elim_R = MakeSubstitution(std::make_shared(R), "elim_R", R); idempotent_P = MakeSubstitution(std::make_shared(), "idempotent_P", P); Qct_to_P = MakeSubstitution(std::make_shared(), "Qct_to_P", Q); } bool CheckTransform(FuncGraphPtr gbefore, FuncGraphPtr gafter, const SubstitutionList &transform) { equiv_node.clear(); equiv_graph.clear(); FuncGraphPtr gbefore_clone = BasicClone(gbefore); OptimizerPtr optimizer = std::make_shared("ut_test", std::make_shared()); transform(gbefore_clone, optimizer); return Isomorphic(gbefore_clone, gafter, &equiv_graph, &equiv_node); } bool CheckOpt(FuncGraphPtr before, FuncGraphPtr after, std::vector opts = {}) { SubstitutionList eq(opts); return CheckTransform(before, after, eq); } public: UT::PyFuncGraphFetcher getPyFun; FuncGraphPairMapEquiv equiv_graph; NodeMapEquiv equiv_node; static const PrimitivePtr P; static const PrimitivePtr Q; static const PrimitivePtr R; SubstitutionPtr elim_Z; SubstitutionPtr elim_R; SubstitutionPtr idempotent_P; SubstitutionPtr Qct_to_P; }; const PrimitivePtr TestOptOpt::P = std::make_shared("P"); const PrimitivePtr TestOptOpt::Q = std::make_shared("Q"); const PrimitivePtr TestOptOpt::R = std::make_shared("R"); TEST_F(TestOptOpt, TestCheckOptIsClone) { FuncGraphPtr before = getPyFun.CallAndParseRet("test_add_zero", "before_1"); ASSERT_TRUE(nullptr != before); ASSERT_TRUE(CheckOpt(before, before)); ASSERT_FALSE(CheckOpt(before, before, std::vector({elim_Z}))); } TEST_F(TestOptOpt, Elim) { FuncGraphPtr before = getPyFun.CallAndParseRet("test_add_zero", "before_1"); FuncGraphPtr after = getPyFun.CallAndParseRet("test_add_zero", "after"); ASSERT_TRUE(nullptr != before); ASSERT_TRUE(nullptr != after); ASSERT_TRUE(CheckOpt(before, after, std::vector({elim_Z}))); } TEST_F(TestOptOpt, ElimTwo) { FuncGraphPtr before = getPyFun.CallAndParseRet("test_add_zero", "before_2"); FuncGraphPtr after = getPyFun.CallAndParseRet("test_add_zero", "after"); ASSERT_TRUE(nullptr != before); ASSERT_TRUE(nullptr != after); ASSERT_TRUE(CheckOpt(before, after, std::vector({elim_Z}))); } TEST_F(TestOptOpt, ElimR) { FuncGraphPtr before = getPyFun.CallAndParseRet("test_elimR", "before_1"); FuncGraphPtr after = getPyFun.CallAndParseRet("test_elimR", "after"); ASSERT_TRUE(nullptr != before); ASSERT_TRUE(nullptr != after); ASSERT_TRUE(CheckOpt(before, after, std::vector({elim_R}))); } TEST_F(TestOptOpt, idempotent) { FuncGraphPtr before_2 = getPyFun.CallAndParseRet("test_idempotent", "before_2"); FuncGraphPtr before_1 = getPyFun.CallAndParseRet("test_idempotent", "before_1"); FuncGraphPtr after = getPyFun.CallAndParseRet("test_idempotent", "after"); ASSERT_TRUE(nullptr != before_2); ASSERT_TRUE(nullptr != before_1); ASSERT_TRUE(nullptr != after); ASSERT_TRUE(CheckOpt(before_1, after, std::vector({idempotent_P}))); ASSERT_TRUE(CheckOpt(before_2, after, std::vector({idempotent_P}))); } TEST_F(TestOptOpt, ConstantVariable) { FuncGraphPtr before = getPyFun.CallAndParseRet("test_constant_variable", "before_1"); FuncGraphPtr after = getPyFun.CallAndParseRet("test_constant_variable", "after"); ASSERT_TRUE(nullptr != before); ASSERT_TRUE(nullptr != after); ASSERT_TRUE(CheckOpt(before, after, std::vector({Qct_to_P}))); } TEST_F(TestOptOpt, CSE) { // test a simple cse testcase test_f1 FuncGraphPtr test_graph1 = getPyFun.CallAndParseRet("test_cse", "test_f1"); ASSERT_TRUE(nullptr != test_graph1); // add func_graph the GraphManager FuncGraphManagerPtr manager1 = Manage(test_graph1); draw::Draw("opt_cse_before_1.dot", test_graph1); ASSERT_EQ(manager1->all_nodes().size(), 9); auto cse = std::make_shared(); ASSERT_TRUE(cse != nullptr); bool is_changed = cse->Cse(test_graph1, manager1); ASSERT_TRUE(is_changed); ASSERT_EQ(manager1->all_nodes().size(), 8); draw::Draw("opt_cse_after_1.dot", test_graph1); // test a more complicated case test_f2 FuncGraphPtr test_graph2 = getPyFun.CallAndParseRet("test_cse", "test_f2"); ASSERT_TRUE(nullptr != test_graph2); FuncGraphManagerPtr manager2 = Manage(test_graph2); draw::Draw("opt_cse_before_2.dot", test_graph2); ASSERT_EQ(manager2->all_nodes().size(), 16); is_changed = cse->Cse(test_graph2, manager2); ASSERT_TRUE(is_changed); ASSERT_EQ(manager2->all_nodes().size(), 12); draw::Draw("opt_cse_after_2.dot", test_graph2); } } // namespace opt } // namespace mindspore