You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
217 lines
7.0 KiB
217 lines
7.0 KiB
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
|
|
#include <iostream>
|
|
#include <memory>
|
|
|
|
#include "common/common_test.h"
|
|
#include "common/py_func_graph_fetcher.h"
|
|
|
|
#include "ir/anf.h"
|
|
#include "ir/visitor.h"
|
|
#include "ir/func_graph_cloner.h"
|
|
#include "frontend/optimizer/opt.h"
|
|
#include "frontend/optimizer/anf_visitor.h"
|
|
#include "frontend/optimizer/irpass.h"
|
|
#include "frontend/optimizer/irpass/arithmetic_simplify.h"
|
|
|
|
#include "debug/draw.h"
|
|
#include "frontend/operator/ops.h"
|
|
#include "frontend/optimizer/cse.h"
|
|
|
|
namespace mindspore {
|
|
namespace opt {
|
|
class TestOptOpt : public UT::Common {
|
|
public:
|
|
TestOptOpt() : getPyFun("gtest_input.optimizer.opt_test", true) {}
|
|
|
|
class IdempotentEliminater : public AnfVisitor {
|
|
public:
|
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
|
x_ = nullptr;
|
|
AnfVisitor::Match(P, {irpass::IsCNode})(node);
|
|
if (x_ == nullptr || node->func_graph() == nullptr) {
|
|
return nullptr;
|
|
}
|
|
|
|
return node->func_graph()->NewCNode({NewValueNode(P), x_});
|
|
};
|
|
|
|
void Visit(const CNodePtr &cnode) override {
|
|
if (IsPrimitiveCNode(cnode, P) && cnode->inputs().size() == 2) {
|
|
x_ = cnode->input(1);
|
|
}
|
|
}
|
|
|
|
private:
|
|
AnfNodePtr x_{nullptr};
|
|
};
|
|
|
|
class QctToP : public AnfVisitor {
|
|
public:
|
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
|
v_ = nullptr;
|
|
AnfVisitor::Match(Q, {irpass::IsVNode})(node);
|
|
if (v_ == nullptr || node->func_graph() == nullptr) {
|
|
return nullptr;
|
|
}
|
|
|
|
return node->func_graph()->NewCNode({NewValueNode(P), v_});
|
|
};
|
|
|
|
void Visit(const ValueNodePtr &vnode) override { v_ = vnode; }
|
|
|
|
private:
|
|
AnfNodePtr v_{nullptr};
|
|
};
|
|
|
|
void SetUp() {
|
|
elim_Z = MakeSubstitution(std::make_shared<irpass::ArithmeticSimplify>(), "elim_Z", prim::kPrimScalarAdd);
|
|
elim_R = MakeSubstitution(std::make_shared<irpass::PrimEliminater>(R), "elim_R", R);
|
|
idempotent_P = MakeSubstitution(std::make_shared<IdempotentEliminater>(), "idempotent_P", P);
|
|
Qct_to_P = MakeSubstitution(std::make_shared<QctToP>(), "Qct_to_P", Q);
|
|
}
|
|
|
|
bool CheckTransform(FuncGraphPtr gbefore, FuncGraphPtr gafter, const SubstitutionList &transform) {
|
|
equiv_node.clear();
|
|
equiv_graph.clear();
|
|
|
|
FuncGraphPtr gbefore_clone = BasicClone(gbefore);
|
|
OptimizerPtr optimizer = std::make_shared<Optimizer>("ut_test", std::make_shared<pipeline::Resource>());
|
|
transform(gbefore_clone, optimizer);
|
|
|
|
return Isomorphic(gbefore_clone, gafter, &equiv_graph, &equiv_node);
|
|
}
|
|
|
|
bool CheckOpt(FuncGraphPtr before, FuncGraphPtr after, std::vector<SubstitutionPtr> opts = {}) {
|
|
SubstitutionList eq(opts);
|
|
return CheckTransform(before, after, eq);
|
|
}
|
|
|
|
public:
|
|
UT::PyFuncGraphFetcher getPyFun;
|
|
|
|
FuncGraphPairMapEquiv equiv_graph;
|
|
NodeMapEquiv equiv_node;
|
|
|
|
static const PrimitivePtr P;
|
|
static const PrimitivePtr Q;
|
|
static const PrimitivePtr R;
|
|
|
|
SubstitutionPtr elim_Z;
|
|
SubstitutionPtr elim_R;
|
|
SubstitutionPtr idempotent_P;
|
|
SubstitutionPtr Qct_to_P;
|
|
};
|
|
|
|
// Out-of-class definitions of the fixture's static primitives; each one is a
// Primitive constructed with its own single-letter name.
const PrimitivePtr TestOptOpt::P{std::make_shared<Primitive>("P")};
const PrimitivePtr TestOptOpt::Q{std::make_shared<Primitive>("Q")};
const PrimitivePtr TestOptOpt::R{std::make_shared<Primitive>("R")};
|
|
|
|
// Sanity check of the CheckOpt helper itself, using one graph as both the
// input and the expected output.
TEST_F(TestOptOpt, TestCheckOptIsClone) {
  const FuncGraphPtr graph = getPyFun.CallAndParseRet("test_add_zero", "before_1");
  ASSERT_TRUE(graph != nullptr);

  // No substitutions supplied: the result must match the source graph.
  ASSERT_TRUE(CheckOpt(graph, graph));

  // With elim_Z supplied the result must differ from the source graph.
  const std::vector<SubstitutionPtr> passes{elim_Z};
  ASSERT_FALSE(CheckOpt(graph, graph, passes));
}
|
|
|
|
// elim_Z applied to "before_1" of test_add_zero must yield the "after" graph.
TEST_F(TestOptOpt, Elim) {
  const FuncGraphPtr input_graph = getPyFun.CallAndParseRet("test_add_zero", "before_1");
  const FuncGraphPtr expected_graph = getPyFun.CallAndParseRet("test_add_zero", "after");
  ASSERT_TRUE(input_graph != nullptr);
  ASSERT_TRUE(expected_graph != nullptr);

  const std::vector<SubstitutionPtr> passes{elim_Z};
  ASSERT_TRUE(CheckOpt(input_graph, expected_graph, passes));
}
|
|
|
|
// elim_Z applied to "before_2" of test_add_zero must yield the "after" graph.
TEST_F(TestOptOpt, ElimTwo) {
  const FuncGraphPtr input_graph = getPyFun.CallAndParseRet("test_add_zero", "before_2");
  const FuncGraphPtr expected_graph = getPyFun.CallAndParseRet("test_add_zero", "after");
  ASSERT_TRUE(input_graph != nullptr);
  ASSERT_TRUE(expected_graph != nullptr);

  const std::vector<SubstitutionPtr> passes{elim_Z};
  ASSERT_TRUE(CheckOpt(input_graph, expected_graph, passes));
}
|
|
|
|
// elim_R applied to "before_1" of test_elimR must yield the "after" graph.
TEST_F(TestOptOpt, ElimR) {
  const FuncGraphPtr input_graph = getPyFun.CallAndParseRet("test_elimR", "before_1");
  const FuncGraphPtr expected_graph = getPyFun.CallAndParseRet("test_elimR", "after");
  ASSERT_TRUE(input_graph != nullptr);
  ASSERT_TRUE(expected_graph != nullptr);

  const std::vector<SubstitutionPtr> passes{elim_R};
  ASSERT_TRUE(CheckOpt(input_graph, expected_graph, passes));
}
|
|
|
|
// idempotent_P must turn both "before_1" and "before_2" of test_idempotent
// into the same "after" graph.
TEST_F(TestOptOpt, idempotent) {
  // Graphs are fetched in the same order as before: before_2, before_1, after.
  const FuncGraphPtr second_input = getPyFun.CallAndParseRet("test_idempotent", "before_2");
  const FuncGraphPtr first_input = getPyFun.CallAndParseRet("test_idempotent", "before_1");
  const FuncGraphPtr expected_graph = getPyFun.CallAndParseRet("test_idempotent", "after");
  ASSERT_TRUE(second_input != nullptr);
  ASSERT_TRUE(first_input != nullptr);
  ASSERT_TRUE(expected_graph != nullptr);

  const std::vector<SubstitutionPtr> passes{idempotent_P};
  ASSERT_TRUE(CheckOpt(first_input, expected_graph, passes));
  ASSERT_TRUE(CheckOpt(second_input, expected_graph, passes));
}
|
|
|
|
// Qct_to_P applied to "before_1" of test_constant_variable must yield the
// "after" graph.
TEST_F(TestOptOpt, ConstantVariable) {
  const FuncGraphPtr input_graph = getPyFun.CallAndParseRet("test_constant_variable", "before_1");
  const FuncGraphPtr expected_graph = getPyFun.CallAndParseRet("test_constant_variable", "after");
  ASSERT_TRUE(input_graph != nullptr);
  ASSERT_TRUE(expected_graph != nullptr);

  const std::vector<SubstitutionPtr> passes{Qct_to_P};
  ASSERT_TRUE(CheckOpt(input_graph, expected_graph, passes));
}
|
|
|
|
// Runs CSE on two graphs from test_cse and checks that the pass reports a
// change and that the manager's node count drops to the expected value.
// Before/after graphs are dumped as .dot files for inspection.
TEST_F(TestOptOpt, CSE) {
  // Simple case: test_f1.
  FuncGraphPtr test_graph1 = getPyFun.CallAndParseRet("test_cse", "test_f1");
  ASSERT_TRUE(nullptr != test_graph1);

  // Register the func graph with a manager; guard the pointer before use.
  FuncGraphManagerPtr manager1 = Manage(test_graph1);
  ASSERT_TRUE(nullptr != manager1);
  draw::Draw("opt_cse_before_1.dot", test_graph1);

  // Unsigned literals keep the comparison with size() free of sign mismatch.
  ASSERT_EQ(manager1->all_nodes().size(), 9U);

  // std::make_shared never yields a null pointer, so no null check is needed.
  auto cse = std::make_shared<CSE>();
  bool is_changed = cse->Cse(test_graph1, manager1);

  ASSERT_TRUE(is_changed);
  ASSERT_EQ(manager1->all_nodes().size(), 8U);

  draw::Draw("opt_cse_after_1.dot", test_graph1);

  // More complicated case: test_f2.
  FuncGraphPtr test_graph2 = getPyFun.CallAndParseRet("test_cse", "test_f2");
  ASSERT_TRUE(nullptr != test_graph2);

  FuncGraphManagerPtr manager2 = Manage(test_graph2);
  ASSERT_TRUE(nullptr != manager2);
  draw::Draw("opt_cse_before_2.dot", test_graph2);
  ASSERT_EQ(manager2->all_nodes().size(), 16U);

  is_changed = cse->Cse(test_graph2, manager2);
  ASSERT_TRUE(is_changed);
  ASSERT_EQ(manager2->all_nodes().size(), 12U);
  draw::Draw("opt_cse_after_2.dot", test_graph2);
}
|
|
|
|
} // namespace opt
|
|
} // namespace mindspore
|