!14407 separate irpass EnvGetItemEliminater and ItemTupleOrListEliminator

From: @huangbingjian
Reviewed-by: @ginfung,@zh_qh
Signed-off-by: @zh_qh
pull/14407/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit ce248c37e0

@ -70,9 +70,25 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
float_depend_g_call_ = MakeSubstitution(std::make_shared<FloatDependGCall>(), "float_depend_g_call", IsCNodeDup); float_depend_g_call_ = MakeSubstitution(std::make_shared<FloatDependGCall>(), "float_depend_g_call", IsCNodeDup);
// ops eliminate // ops eliminate
item_tuple_or_list_eliminate_ = MakeSubstitution( tuple_list_get_item_eliminator_ =
std::make_shared<ItemTupleOrListEliminator>(), "item_tuple_or_list_eliminate", MakeSubstitution(std::make_shared<TupleListGetitemEliminator>(), "tuple_list_get_item_eliminator",
{prim::kPrimTupleGetItem, prim::kPrimListGetItem});
tuple_list_get_item_const_eliminator_ =
MakeSubstitution(std::make_shared<TupleListGetitemConstEliminator>(), "tuple_list_get_item_const_eliminator",
{prim::kPrimTupleGetItem, prim::kPrimListGetItem});
tuple_list_set_item_eliminator_ =
MakeSubstitution(std::make_shared<TupleListSetitemEliminator>(), "tuple_list_set_item_eliminator",
{prim::kPrimTupleSetItem, prim::kPrimListSetItem});
tuple_list_get_set_item_eliminator_ =
MakeSubstitution(std::make_shared<TupleListGetSetitemEliminator>(), "tuple_list_get_set_item_eliminator",
{prim::kPrimTupleGetItem, prim::kPrimListGetItem});
tuple_list_get_item_depend_reorder_ =
MakeSubstitution(std::make_shared<TupleListGetitemDependReorder>(), "tuple_list_get_item_depend_reorder",
{prim::kPrimTupleGetItem, prim::kPrimListGetItem});
tuple_list_convert_item_index_to_positive_ = MakeSubstitution(
std::make_shared<TupleListConvertItemIndexToPositive>(), "tuple_list_convert_item_index_to_positive",
{prim::kPrimTupleGetItem, prim::kPrimTupleSetItem, prim::kPrimListGetItem, prim::kPrimListSetItem}); {prim::kPrimTupleGetItem, prim::kPrimTupleSetItem, prim::kPrimListGetItem, prim::kPrimListSetItem});
tile_eliminate_ = MakeSubstitution(std::make_shared<TileEliminater>(), "tile_eliminate", prim::kPrimTile); tile_eliminate_ = MakeSubstitution(std::make_shared<TileEliminater>(), "tile_eliminate", prim::kPrimTile);
cast_eliminate_ = MakeSubstitution(std::make_shared<CastEliminater>(), "cast_eliminate", prim::kPrimCast); cast_eliminate_ = MakeSubstitution(std::make_shared<CastEliminater>(), "cast_eliminate", prim::kPrimCast);
reshape_eliminate_ = MakeSubstitution(std::make_shared<ReshapeEliminater>(), "reshape_eliminate", prim::kPrimReshape); reshape_eliminate_ = MakeSubstitution(std::make_shared<ReshapeEliminater>(), "reshape_eliminate", prim::kPrimReshape);
@ -99,7 +115,13 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
// Env Item Eliminate // Env Item Eliminate
env_get_item_eliminate_ = env_get_item_eliminate_ =
MakeSubstitution(std::make_shared<EnvGetItemEliminater>(), "env_get_item_eliminate", prim::kPrimEnvGetItem); MakeSubstitution(std::make_shared<EnvGetItemEliminater>(), "env_get_item_eliminate", prim::kPrimEnvGetItem);
new_env_get_item_ = MakeSubstitution(std::make_shared<NewEnvGetItem>(), "new_env_get_item", prim::kPrimEnvGetItem); env_get_item_add_eliminate_ =
MakeSubstitution(std::make_shared<EnvGetItemAddEliminater>(), "env_get_item_add_eliminate_", prim::kPrimEnvGetItem);
env_get_set_item_eliminate_ =
MakeSubstitution(std::make_shared<EnvGetSetItemEliminater>(), "env_get_set_item_eliminate", prim::kPrimEnvGetItem);
env_get_item_depend_swap_ =
MakeSubstitution(std::make_shared<EnvGetItemDependSwap>(), "env_get_item_depend_swap", prim::kPrimEnvGetItem);
incorporate_env_getitem_bypass_recursive_ = incorporate_env_getitem_bypass_recursive_ =
MakeSubstitution(std::make_shared<IncorporateEnvGetitem>(true), "incorporate_env_get_item", prim::kPrimEnvGetItem); MakeSubstitution(std::make_shared<IncorporateEnvGetitem>(true), "incorporate_env_get_item", prim::kPrimEnvGetItem);
incorporate_env_getitem_switch_ = MakeSubstitution(std::make_shared<IncorporateEnvGetitemSwitch>(), incorporate_env_getitem_switch_ = MakeSubstitution(std::make_shared<IncorporateEnvGetitemSwitch>(),

@ -39,7 +39,13 @@ class OptimizeIRPassLib {
SubstitutionPtr adjust_all_reduce_mul_add_; SubstitutionPtr adjust_all_reduce_mul_add_;
SubstitutionPtr float_depend_g_call_; SubstitutionPtr float_depend_g_call_;
// ops eliminate // ops eliminate
SubstitutionPtr item_tuple_or_list_eliminate_; SubstitutionPtr tuple_list_get_item_eliminator_;
SubstitutionPtr tuple_list_get_item_const_eliminator_;
SubstitutionPtr tuple_list_set_item_eliminator_;
SubstitutionPtr tuple_list_get_set_item_eliminator_;
SubstitutionPtr tuple_list_get_item_depend_reorder_;
SubstitutionPtr tuple_list_convert_item_index_to_positive_;
SubstitutionPtr tile_eliminate_; SubstitutionPtr tile_eliminate_;
SubstitutionPtr cast_eliminate_; SubstitutionPtr cast_eliminate_;
SubstitutionPtr reshape_eliminate_; SubstitutionPtr reshape_eliminate_;
@ -57,7 +63,9 @@ class OptimizeIRPassLib {
// Env Item Eliminate // Env Item Eliminate
SubstitutionPtr env_get_item_eliminate_; SubstitutionPtr env_get_item_eliminate_;
SubstitutionPtr new_env_get_item_; SubstitutionPtr env_get_item_add_eliminate_;
SubstitutionPtr env_get_set_item_eliminate_;
SubstitutionPtr env_get_item_depend_swap_;
SubstitutionPtr incorporate_env_getitem_; SubstitutionPtr incorporate_env_getitem_;
SubstitutionPtr incorporate_env_getitem_bypass_recursive_; SubstitutionPtr incorporate_env_getitem_bypass_recursive_;
SubstitutionPtr incorporate_env_getitem_switch_; SubstitutionPtr incorporate_env_getitem_switch_;

@ -157,7 +157,7 @@ class EnvGetitemTransformACrossGraph {
} // namespace internal } // namespace internal
// {prim::kPrimEnvGetItem, C1, C2, Y} -> Y // {prim::kPrimEnvGetItem, C1, C2, Y} -> Y
class NewEnvGetItem : public AnfVisitor { class EnvGetItemEliminater : public AnfVisitor {
public: public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
PatternNode c1, c2, y; PatternNode c1, c2, y;
@ -170,10 +170,10 @@ class NewEnvGetItem : public AnfVisitor {
// {prim::kPrimEnvGetItem, {prim::kPrimEnvAdd, X, Y}, C, Z} -> // {prim::kPrimEnvGetItem, {prim::kPrimEnvAdd, X, Y}, C, Z} ->
// {prim::GetPythonOps("hyper_add"), {prim::kPrimEnvGetItem, X, C, Z}, {prim::kPrimEnvGetItem, Y, C, Z}} // {prim::GetPythonOps("hyper_add"), {prim::kPrimEnvGetItem, X, C, Z}, {prim::kPrimEnvGetItem, Y, C, Z}}
class AddEnvGetItem : public AnfVisitor { class EnvGetItemAddEliminater : public AnfVisitor {
public: public:
AddEnvGetItem() : PrimHyperAdd_(prim::GetPythonOps("hyper_add")) {} EnvGetItemAddEliminater() : PrimHyperAdd_(prim::GetPythonOps("hyper_add")) {}
~AddEnvGetItem() override = default; ~EnvGetItemAddEliminater() override = default;
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
is_match_ = false; is_match_ = false;
@ -211,7 +211,7 @@ class AddEnvGetItem : public AnfVisitor {
}; };
// {prim::kPrimEnvGetItem, {prim::kPrimEnvSetItem, X, C1, Y}, C2, Z} // {prim::kPrimEnvGetItem, {prim::kPrimEnvSetItem, X, C1, Y}, C2, Z}
class EnvGetSetItem : public AnfVisitor { class EnvGetSetItemEliminater : public AnfVisitor {
public: public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
is_match_ = false; is_match_ = false;
@ -281,7 +281,7 @@ class EnvGetSetItem : public AnfVisitor {
// {prim::kPrimEnvGetitem, {prim::kPrimDepend, X1, X2}, item, dflt} -> // {prim::kPrimEnvGetitem, {prim::kPrimDepend, X1, X2}, item, dflt} ->
// {prim::kPrimDepend, {prim::kPrimEnvGetitem, X1, item, dflt}, X2} // {prim::kPrimDepend, {prim::kPrimEnvGetitem, X1, item, dflt}, X2}
class SwapEnvGetItemDepend : public OptimizerCaller { class EnvGetItemDependSwap : public OptimizerCaller {
public: public:
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
if (!node->isa<CNode>() || node->func_graph() == nullptr) { if (!node->isa<CNode>() || node->func_graph() == nullptr) {
@ -297,36 +297,6 @@ class SwapEnvGetItemDepend : public OptimizerCaller {
} }
}; };
class EnvGetItemEliminater : public OptimizerCaller {
public:
EnvGetItemEliminater()
: new_env_get_item_(std::make_shared<NewEnvGetItem>()),
add_env_get_item_(std::make_shared<AddEnvGetItem>()),
env_get_set_item_(std::make_shared<EnvGetSetItem>()),
swap_env_get_item_depend_(std::make_shared<SwapEnvGetItemDepend>()) {
eliminaters_.emplace_back(new_env_get_item_);
eliminaters_.emplace_back(add_env_get_item_);
eliminaters_.emplace_back(env_get_set_item_);
eliminaters_.emplace_back(swap_env_get_item_depend_);
}
~EnvGetItemEliminater() = default;
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
AnfNodePtr new_node;
for (auto &eliminater : eliminaters_) {
new_node = (*eliminater)(optimizer, node);
if (new_node != nullptr) {
return new_node;
}
}
return nullptr;
}
private:
OptimizerCallerPtr new_env_get_item_, add_env_get_item_, env_get_set_item_, swap_env_get_item_depend_;
std::vector<OptimizerCallerPtr> eliminaters_{};
};
// {prim::kPrimEnvGetItem, {G, Xs}, C, Y} // {prim::kPrimEnvGetItem, {G, Xs}, C, Y}
class IncorporateEnvGetitem : public AnfVisitor { class IncorporateEnvGetitem : public AnfVisitor {
public: public:

@ -38,7 +38,7 @@ namespace irpass {
// setitem([a, b, c, ...], -1, z) => setitem([a, b, c, ...], length - 1, z) // setitem([a, b, c, ...], -1, z) => setitem([a, b, c, ...], length - 1, z)
// {prim::kPrimTupleSetItem, T, N, Z} // {prim::kPrimTupleSetItem, T, N, Z}
// {prim::kPrimListSetItem, L, N, Z} // {prim::kPrimListSetItem, L, N, Z}
class ConvertItemIndexToPositive : public AnfVisitor { class TupleListConvertItemIndexToPositive : public AnfVisitor {
public: public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
Reset(); Reset();
@ -96,7 +96,7 @@ class ConvertItemIndexToPositive : public AnfVisitor {
// (a, b, c, ...)[1] => b // (a, b, c, ...)[1] => b
// {prim::kPrimTupleGetItem, {prim::kPrimMakeTuple, Xs}, C} // {prim::kPrimTupleGetItem, {prim::kPrimMakeTuple, Xs}, C}
// {prim::kPrimListGetItem, {prim::kPrimMakeList, Xs}, C} // {prim::kPrimListGetItem, {prim::kPrimMakeList, Xs}, C}
class GetitemEliminator : public AnfVisitor { class TupleListGetitemEliminator : public AnfVisitor {
public: public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
Reset(); Reset();
@ -144,7 +144,7 @@ class GetitemEliminator : public AnfVisitor {
// (a, b, c, ...)[1] => b // (a, b, c, ...)[1] => b
// {prim::kPrimTupleGetItem, C1, C} // {prim::kPrimTupleGetItem, C1, C}
// {prim::kPrimListGetItem, C1, C} // {prim::kPrimListGetItem, C1, C}
class GetitemConstEliminator : public AnfVisitor { class TupleListGetitemConstEliminator : public AnfVisitor {
public: public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
Reset(); Reset();
@ -195,7 +195,7 @@ class GetitemConstEliminator : public AnfVisitor {
// {prim::kPrimListSetItem, {prim::kPrimMakeList, a, b, c, ...}, 0, z} => {prim::kPrimMakeList, z, b, c, ...} // {prim::kPrimListSetItem, {prim::kPrimMakeList, a, b, c, ...}, 0, z} => {prim::kPrimMakeList, z, b, c, ...}
// {prim::kPrimTupleSetItem, (a, b, c, ...), 0, z} => {prim::kPrimMakeTuple, z, b, c, ...} // {prim::kPrimTupleSetItem, (a, b, c, ...), 0, z} => {prim::kPrimMakeTuple, z, b, c, ...}
// {prim::kPrimListSetItem, [a, b, c, ...], 0, z} => {prim::kPrimMakeList, z, b, c, ...} // {prim::kPrimListSetItem, [a, b, c, ...], 0, z} => {prim::kPrimMakeList, z, b, c, ...}
class SetitemEliminator : public AnfVisitor { class TupleListSetitemEliminator : public AnfVisitor {
public: public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
Reset(); Reset();
@ -277,7 +277,7 @@ class SetitemEliminator : public AnfVisitor {
// {prim::kPrimTupleGetItem, {prim::kPrimTupleSetItem, Y, C1, X}, C2} // {prim::kPrimTupleGetItem, {prim::kPrimTupleSetItem, Y, C1, X}, C2}
// {prim::kPrimListGetItem, {prim::kPrimListSetItem, Y, C1, X}, C2} // {prim::kPrimListGetItem, {prim::kPrimListSetItem, Y, C1, X}, C2}
class GetSetitemEliminator : public AnfVisitor { class TupleListGetSetitemEliminator : public AnfVisitor {
public: public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
Reset(); Reset();
@ -348,7 +348,7 @@ class GetSetitemEliminator : public AnfVisitor {
// {prim::kPrimDepend, {prim::kPrimTupleGetItem, X, C}, Y} // {prim::kPrimDepend, {prim::kPrimTupleGetItem, X, C}, Y}
// {prim::kPrimListGetItem, {prim::kPrimDepend, X, Y}, C} -> // {prim::kPrimListGetItem, {prim::kPrimDepend, X, Y}, C} ->
// {prim::kPrimDepend, {prim::kPrimListGetItem, X, C}, Y} // {prim::kPrimDepend, {prim::kPrimListGetItem, X, C}, Y}
class GetitemDependReorder : public AnfVisitor { class TupleListGetitemDependReorder : public AnfVisitor {
public: public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
Reset(); Reset();
@ -405,41 +405,6 @@ class GetitemDependReorder : public AnfVisitor {
AnfNodePtr x_{nullptr}, y_{nullptr}, c_{nullptr}; AnfNodePtr x_{nullptr}, y_{nullptr}, c_{nullptr};
}; };
class ItemTupleOrListEliminator : public OptimizerCaller {
public:
ItemTupleOrListEliminator()
: get_item_eliminator_(std::make_shared<GetitemEliminator>()),
get_item_const_eliminator_(std::make_shared<GetitemConstEliminator>()),
set_item_eliminator_(std::make_shared<SetitemEliminator>()),
get_set_item_eliminator_(std::make_shared<GetSetitemEliminator>()),
get_item_depend_reorder_(std::make_shared<GetitemDependReorder>()),
convert_item_index_to_positive_(std::make_shared<ConvertItemIndexToPositive>()) {
eliminators_.emplace_back(get_item_eliminator_);
eliminators_.emplace_back(get_item_const_eliminator_);
eliminators_.emplace_back(set_item_eliminator_);
eliminators_.emplace_back(get_set_item_eliminator_);
eliminators_.emplace_back(get_item_depend_reorder_);
eliminators_.emplace_back(convert_item_index_to_positive_);
}
~ItemTupleOrListEliminator() = default;
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
AnfNodePtr new_node;
for (auto &eliminator : eliminators_) {
new_node = (*eliminator)(optimizer, node);
if (new_node != nullptr) {
return new_node;
}
}
return nullptr;
}
private:
OptimizerCallerPtr get_item_eliminator_, get_item_const_eliminator_, set_item_eliminator_, get_set_item_eliminator_,
get_item_depend_reorder_, convert_item_index_to_positive_;
std::vector<OptimizerCallerPtr> eliminators_{};
};
} // namespace irpass } // namespace irpass
} // namespace opt } // namespace opt
} // namespace mindspore } // namespace mindspore

@ -112,8 +112,18 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
irpass.replace_applicator_, irpass.replace_applicator_,
// Miscellaneous // Miscellaneous
irpass.item_tuple_or_list_eliminate_, irpass.tuple_list_get_item_eliminator_,
irpass.tuple_list_get_item_const_eliminator_,
irpass.tuple_list_set_item_eliminator_,
irpass.tuple_list_get_set_item_eliminator_,
irpass.tuple_list_get_item_depend_reorder_,
irpass.tuple_list_convert_item_index_to_positive_,
irpass.env_get_item_eliminate_, irpass.env_get_item_eliminate_,
irpass.env_get_item_add_eliminate_,
irpass.env_get_set_item_eliminate_,
irpass.env_get_item_depend_swap_,
irpass.cast_eliminate_, irpass.cast_eliminate_,
irpass.reshape_eliminate_, irpass.reshape_eliminate_,
irpass.reduce_eliminate_, irpass.reduce_eliminate_,
@ -146,7 +156,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
irpass.incorporate_call_switch_, irpass.incorporate_call_switch_,
irpass.incorporate_env_getitem_bypass_recursive_, irpass.incorporate_env_getitem_bypass_recursive_,
irpass.incorporate_env_getitem_switch_, irpass.incorporate_env_getitem_switch_,
irpass.new_env_get_item_, irpass.env_get_item_eliminate_,
irpass.depend_value_elim_, irpass.depend_value_elim_,
irpass.all_reduce_const_elim_, irpass.all_reduce_const_elim_,
}, },
@ -220,7 +230,10 @@ OptPassGroupMap GetOptPassesAfterCconv(const opt::irpass::OptimizeIRPassLib &irp
OptPassGroupMap GetOptPassesTransformGraph(const opt::irpass::OptimizeIRPassLib &irpass) { OptPassGroupMap GetOptPassesTransformGraph(const opt::irpass::OptimizeIRPassLib &irpass) {
opt::OptPassConfig d_1 = opt::OptPassConfig d_1 =
opt::OptPassConfig({// Safe inlining opt::OptPassConfig({// Safe inlining
irpass.call_graph_tuple_transform_, irpass.item_tuple_or_list_eliminate_}); irpass.call_graph_tuple_transform_, irpass.tuple_list_get_item_eliminator_,
irpass.tuple_list_get_item_const_eliminator_, irpass.tuple_list_set_item_eliminator_,
irpass.tuple_list_get_set_item_eliminator_, irpass.tuple_list_get_item_depend_reorder_,
irpass.tuple_list_convert_item_index_to_positive_});
OptPassGroupMap map_a({{"d_1", d_1}, {"renormalize", opt::OptPassConfig::Renormalize()}}); OptPassGroupMap map_a({{"d_1", d_1}, {"renormalize", opt::OptPassConfig::Renormalize()}});
@ -228,13 +241,31 @@ OptPassGroupMap GetOptPassesTransformGraph(const opt::irpass::OptimizeIRPassLib
} }
OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) { OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) {
opt::OptPassConfig b_1 = opt::OptPassConfig( opt::OptPassConfig b_1 = opt::OptPassConfig({irpass.zero_like_fill_zero_,
{irpass.zero_like_fill_zero_, irpass.item_tuple_or_list_eliminate_, irpass.float_tuple_getitem_switch_, irpass.tuple_list_get_item_eliminator_,
irpass.reset_defer_inline_, irpass.inline_, irpass.updatestate_eliminater_, irpass.load_eliminater_, irpass.tuple_list_get_item_const_eliminator_,
irpass.stopgrad_eliminater_, irpass.special_op_eliminate_, irpass.get_make_ref_eliminate_, irpass.tuple_list_set_item_eliminator_,
irpass.incorporate_env_getitem_, irpass.incorporate_env_getitem_switch_, irpass.env_get_item_eliminate_, irpass.tuple_list_get_set_item_eliminator_,
irpass.incorporate_env_getitem_switch_layer_, irpass.value_based_eliminate_, irpass.receive_eliminate_}, irpass.tuple_list_get_item_depend_reorder_,
false, true); irpass.tuple_list_convert_item_index_to_positive_,
irpass.float_tuple_getitem_switch_,
irpass.reset_defer_inline_,
irpass.inline_,
irpass.updatestate_eliminater_,
irpass.load_eliminater_,
irpass.stopgrad_eliminater_,
irpass.special_op_eliminate_,
irpass.get_make_ref_eliminate_,
irpass.incorporate_env_getitem_,
irpass.incorporate_env_getitem_switch_,
irpass.env_get_item_eliminate_,
irpass.env_get_item_add_eliminate_,
irpass.env_get_set_item_eliminate_,
irpass.env_get_item_depend_swap_,
irpass.incorporate_env_getitem_switch_layer_,
irpass.value_based_eliminate_,
irpass.receive_eliminate_},
false, true);
opt::OptPassConfig b_2 = opt::OptPassConfig({ opt::OptPassConfig b_2 = opt::OptPassConfig({
irpass.replace_refkey_by_param_, irpass.replace_refkey_by_param_,
irpass.make_ref_eliminate_, irpass.make_ref_eliminate_,

@ -15,9 +15,9 @@
import os import os
import sys import sys
import json import json
import openpyxl as opx
import matplotlib.ticker as ticker import matplotlib.ticker as ticker
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import openpyxl as opx
def parse_arguments(): def parse_arguments():

@ -355,7 +355,10 @@ TEST_F(TestOptLib, test_tuple_getitem) {
FuncGraphPtr after_2 = std::make_shared<FuncGraph>(); FuncGraphPtr after_2 = std::make_shared<FuncGraph>();
after_2->set_output(value_node_2); after_2->set_output(value_node_2);
auto patterns = std::vector<SubstitutionPtr>({irpass.item_tuple_or_list_eliminate_}); auto patterns = std::vector<SubstitutionPtr>(
{irpass.tuple_list_get_item_eliminator_, irpass.tuple_list_get_item_const_eliminator_,
irpass.tuple_list_set_item_eliminator_, irpass.tuple_list_get_set_item_eliminator_,
irpass.tuple_list_get_item_depend_reorder_, irpass.tuple_list_convert_item_index_to_positive_});
ASSERT_TRUE(CheckOpt(make_get_0, after_0, patterns)); ASSERT_TRUE(CheckOpt(make_get_0, after_0, patterns));
ASSERT_TRUE(CheckOpt(make_get_1, after_1, patterns)); ASSERT_TRUE(CheckOpt(make_get_1, after_1, patterns));
ASSERT_TRUE(CheckOpt(make_get_const, after_2, patterns)); ASSERT_TRUE(CheckOpt(make_get_const, after_2, patterns));
@ -367,7 +370,10 @@ TEST_F(TestOptLib, test_tuple_setitem) {
FuncGraphPtr after_0 = getPyFun.CallAndParseRet("test_tuple_setitem", "after_0"); FuncGraphPtr after_0 = getPyFun.CallAndParseRet("test_tuple_setitem", "after_0");
FuncGraphPtr after_1 = getPyFun.CallAndParseRet("test_tuple_setitem", "after_1"); FuncGraphPtr after_1 = getPyFun.CallAndParseRet("test_tuple_setitem", "after_1");
auto patterns = std::vector<SubstitutionPtr>({irpass.item_tuple_or_list_eliminate_}); auto patterns = std::vector<SubstitutionPtr>(
{irpass.tuple_list_get_item_eliminator_, irpass.tuple_list_get_item_const_eliminator_,
irpass.tuple_list_set_item_eliminator_, irpass.tuple_list_get_set_item_eliminator_,
irpass.tuple_list_get_item_depend_reorder_, irpass.tuple_list_convert_item_index_to_positive_});
ASSERT_TRUE(CheckOpt(before_0, after_0, patterns)); ASSERT_TRUE(CheckOpt(before_0, after_0, patterns));
ASSERT_TRUE(CheckOpt(before_1, after_1, patterns)); ASSERT_TRUE(CheckOpt(before_1, after_1, patterns));
@ -379,7 +385,10 @@ TEST_F(TestOptLib, test_tuple_get_set_item) {
FuncGraphPtr before_1 = getPyFun.CallAndParseRet("test_tuple_get_set_item", "before_0"); FuncGraphPtr before_1 = getPyFun.CallAndParseRet("test_tuple_get_set_item", "before_0");
FuncGraphPtr after_1 = getPyFun.CallAndParseRet("test_tuple_get_set_item", "after_0"); FuncGraphPtr after_1 = getPyFun.CallAndParseRet("test_tuple_get_set_item", "after_0");
auto patterns = std::vector<SubstitutionPtr>({irpass.item_tuple_or_list_eliminate_}); auto patterns = std::vector<SubstitutionPtr>(
{irpass.tuple_list_get_item_eliminator_, irpass.tuple_list_get_item_const_eliminator_,
irpass.tuple_list_set_item_eliminator_, irpass.tuple_list_get_set_item_eliminator_,
irpass.tuple_list_get_item_depend_reorder_, irpass.tuple_list_convert_item_index_to_positive_});
ASSERT_TRUE(CheckOpt(before_0, after_0, patterns)); ASSERT_TRUE(CheckOpt(before_0, after_0, patterns));
ASSERT_TRUE(CheckOpt(before_1, after_1, patterns)); ASSERT_TRUE(CheckOpt(before_1, after_1, patterns));

Loading…
Cancel
Save