!2426 add one opt pass which optimize tuple_getitem with constant input to const value

Merge pull request !2426 from xychow/add-tuple-getitem-const-pass
pull/2426/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit d831474d33

@ -51,6 +51,7 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
arithmetic_simplify_ = MakeSubstitution(ArithmeticSimplify(), "arithmetic_simplify",
{prim::kPrimScalarAdd, prim::kPrimScalarMul, prim::kPrimTensorAdd,
prim::kPrimIdentity, prim::kPrimMomentum, prim::kPrimMul, prim::kPrimPow});
arithmetic_simplify2_ = MakeSubstitution(ArithmeticSimplify2(), "arithmetic_simplify2", {prim::kPrimMul});
special_op_eliminate_ =
MakeSubstitution(SpecialOpEliminater(), "special_op_eliminate",
{prim::kPrimInsertGradientOf, prim::kPrimStopGradient, prim::kPrimHookBackward,

@ -33,6 +33,7 @@ class OptimizeIRPassLib {
~OptimizeIRPassLib() = default;
SubstitutionPtr arithmetic_simplify_;
SubstitutionPtr arithmetic_simplify2_;
SubstitutionPtr special_op_eliminate_;
SubstitutionPtr zero_like_fill_zero_;
SubstitutionPtr adjust_all_reduce_mul_add_;

File diff suppressed because it is too large Load Diff

@ -70,6 +70,45 @@ class GetitemEliminater : public AnfVisitor {
CNodePtr tuple_{nullptr};
};
// (a, b, c, ...)[0] => a
// (a, b, c, ...)[1] => b
// {prim::kPrimTupleGetItem, C1, C}
class GetitemConstEliminater : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
Reset();
AnfVisitor::Match(prim::kPrimTupleGetItem, {IsVNode, IsVNode})(node);
if (is_match_) {
return NewValueNode((*tuple_)[id_]);
}
return nullptr;
}
void Visit(const ValueNodePtr &vnode) override {
if (IsValueNode<ValueTuple>(vnode)) {
tuple_ = GetValueNode<ValueTuplePtr>(vnode);
}
if (tuple_ != nullptr && IsValueNode<Int32Imm>(vnode)) {
id_ = IntToSize(GetValue<int>(vnode->value()));
if (tuple_->size() > id_) {
is_match_ = true;
}
}
}
void Reset() {
id_ = 0;
tuple_ = nullptr;
is_match_ = false;
}
private:
bool is_match_{false};
size_t id_{0};
ValueTuplePtr tuple_{nullptr};
};
// setitem((a, b, c, ...), 0, z) => (z, b, c, ...)
// setitem((a, b, c, ...), 1, z) => (a, z, c, ...)
// {prim::kPrimTupleSetItem, {prim::kPrimMakeTuple, Xs}, C, Z}
@ -225,8 +264,13 @@ class GetitemDependReorder : public AnfVisitor {
class ItemTupleEliminater {
public:
ItemTupleEliminater()
: get_item_eliminater_(), set_item_eliminater_(), get_set_item_eliminater_(), get_item_depend_reorder_() {
: get_item_eliminater_(),
get_item_const_eliminater_(),
set_item_eliminater_(),
get_set_item_eliminater_(),
get_item_depend_reorder_() {
eliminaters_.emplace_back(get_item_eliminater_);
eliminaters_.emplace_back(get_item_const_eliminater_);
eliminaters_.emplace_back(set_item_eliminater_);
eliminaters_.emplace_back(get_set_item_eliminater_);
eliminaters_.emplace_back(get_item_depend_reorder_);
@ -246,6 +290,7 @@ class ItemTupleEliminater {
private:
GetitemEliminater get_item_eliminater_;
GetitemConstEliminater get_item_const_eliminater_;
SetitemEliminater set_item_eliminater_;
GetSetitemEliminater get_set_item_eliminater_;
GetitemDependReorder get_item_depend_reorder_;

@ -114,6 +114,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
irpass.depend_value_elim_,
});
opt::OptPassConfig a_3 = opt::OptPassConfig({
irpass.arithmetic_simplify2_,
irpass.same_eliminate_,
irpass.check_bprop_eliminate_,
irpass.replace_applicator_,

@ -20,9 +20,12 @@
#include "common/py_func_graph_fetcher.h"
#include "ir/anf.h"
#include "ir/func_graph.h"
#include "ir/func_graph_cloner.h"
#include "ir/manager.h"
#include "ir/value.h"
#include "ir/visitor.h"
#include "operator/ops.h"
#include "optimizer/irpass.h"
#include "pipeline/resource.h"
#include "debug/draw.h"
@ -343,9 +346,26 @@ TEST_F(TestOptLib, test_tuple_getitem) {
FuncGraphPtr after_0 = getPyFun.CallAndParseRet("test_tuple_getitem", "after_0");
FuncGraphPtr after_1 = getPyFun.CallAndParseRet("test_tuple_getitem", "after_1");
FuncGraphPtr make_get_const = std::make_shared<FuncGraph>();
auto value_node_1 = NewValueNode(1);
auto value_node_2 = NewValueNode(2);
std::vector<int> vec{1, 2};
auto value_node_tuple = NewValueNode(MakeValue(vec));
std::vector<AnfNodePtr> node_list{
NewValueNode(prim::kPrimTupleGetItem),
value_node_tuple,
value_node_1
};
auto get_item = make_get_const->NewCNode(node_list);
make_get_const->set_output(get_item);
FuncGraphPtr after_2 = std::make_shared<FuncGraph>();
after_2->set_output(value_node_2);
auto patterns = std::vector<SubstitutionPtr>({irpass.item_tuple_eliminate_});
ASSERT_TRUE(CheckOpt(make_get_0, after_0, patterns));
ASSERT_TRUE(CheckOpt(make_get_1, after_1, patterns));
ASSERT_TRUE(CheckOpt(make_get_const, after_2, patterns));
}
TEST_F(TestOptLib, test_tuple_setitem) {

Loading…
Cancel
Save