|
|
@ -38,7 +38,7 @@ namespace irpass {
|
|
|
|
// setitem([a, b, c, ...], -1, z) => setitem([a, b, c, ...], length - 1, z)
|
|
|
|
// setitem([a, b, c, ...], -1, z) => setitem([a, b, c, ...], length - 1, z)
|
|
|
|
// {prim::kPrimTupleSetItem, T, N, Z}
|
|
|
|
// {prim::kPrimTupleSetItem, T, N, Z}
|
|
|
|
// {prim::kPrimListSetItem, L, N, Z}
|
|
|
|
// {prim::kPrimListSetItem, L, N, Z}
|
|
|
|
class ConvertItemIndexToPositive : public AnfVisitor {
|
|
|
|
class TupleListConvertItemIndexToPositive : public AnfVisitor {
|
|
|
|
public:
|
|
|
|
public:
|
|
|
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
|
|
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
|
|
|
Reset();
|
|
|
|
Reset();
|
|
|
@ -96,7 +96,7 @@ class ConvertItemIndexToPositive : public AnfVisitor {
|
|
|
|
// (a, b, c, ...)[1] => b
|
|
|
|
// (a, b, c, ...)[1] => b
|
|
|
|
// {prim::kPrimTupleGetItem, {prim::kPrimMakeTuple, Xs}, C}
|
|
|
|
// {prim::kPrimTupleGetItem, {prim::kPrimMakeTuple, Xs}, C}
|
|
|
|
// {prim::kPrimListGetItem, {prim::kPrimMakeList, Xs}, C}
|
|
|
|
// {prim::kPrimListGetItem, {prim::kPrimMakeList, Xs}, C}
|
|
|
|
class GetitemEliminator : public AnfVisitor {
|
|
|
|
class TupleListGetitemEliminator : public AnfVisitor {
|
|
|
|
public:
|
|
|
|
public:
|
|
|
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
|
|
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
|
|
|
Reset();
|
|
|
|
Reset();
|
|
|
@ -144,7 +144,7 @@ class GetitemEliminator : public AnfVisitor {
|
|
|
|
// (a, b, c, ...)[1] => b
|
|
|
|
// (a, b, c, ...)[1] => b
|
|
|
|
// {prim::kPrimTupleGetItem, C1, C}
|
|
|
|
// {prim::kPrimTupleGetItem, C1, C}
|
|
|
|
// {prim::kPrimListGetItem, C1, C}
|
|
|
|
// {prim::kPrimListGetItem, C1, C}
|
|
|
|
class GetitemConstEliminator : public AnfVisitor {
|
|
|
|
class TupleListGetitemConstEliminator : public AnfVisitor {
|
|
|
|
public:
|
|
|
|
public:
|
|
|
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
|
|
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
|
|
|
Reset();
|
|
|
|
Reset();
|
|
|
@ -195,7 +195,7 @@ class GetitemConstEliminator : public AnfVisitor {
|
|
|
|
// {prim::kPrimListSetItem, {prim::kPrimMakeList, a, b, c, ...}, 0, z} => {prim::kPrimMakeList, z, b, c, ...}
|
|
|
|
// {prim::kPrimListSetItem, {prim::kPrimMakeList, a, b, c, ...}, 0, z} => {prim::kPrimMakeList, z, b, c, ...}
|
|
|
|
// {prim::kPrimTupleSetItem, (a, b, c, ...), 0, z} => {prim::kPrimMakeTuple, z, b, c, ...}
|
|
|
|
// {prim::kPrimTupleSetItem, (a, b, c, ...), 0, z} => {prim::kPrimMakeTuple, z, b, c, ...}
|
|
|
|
// {prim::kPrimListSetItem, [a, b, c, ...], 0, z} => {prim::kPrimMakeList, z, b, c, ...}
|
|
|
|
// {prim::kPrimListSetItem, [a, b, c, ...], 0, z} => {prim::kPrimMakeList, z, b, c, ...}
|
|
|
|
class SetitemEliminator : public AnfVisitor {
|
|
|
|
class TupleListSetitemEliminator : public AnfVisitor {
|
|
|
|
public:
|
|
|
|
public:
|
|
|
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
|
|
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
|
|
|
Reset();
|
|
|
|
Reset();
|
|
|
@ -277,7 +277,7 @@ class SetitemEliminator : public AnfVisitor {
|
|
|
|
|
|
|
|
|
|
|
|
// {prim::kPrimTupleGetItem, {prim::kPrimTupleSetItem, Y, C1, X}, C2}
|
|
|
|
// {prim::kPrimTupleGetItem, {prim::kPrimTupleSetItem, Y, C1, X}, C2}
|
|
|
|
// {prim::kPrimListGetItem, {prim::kPrimListSetItem, Y, C1, X}, C2}
|
|
|
|
// {prim::kPrimListGetItem, {prim::kPrimListSetItem, Y, C1, X}, C2}
|
|
|
|
class GetSetitemEliminator : public AnfVisitor {
|
|
|
|
class TupleListGetSetitemEliminator : public AnfVisitor {
|
|
|
|
public:
|
|
|
|
public:
|
|
|
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
|
|
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
|
|
|
Reset();
|
|
|
|
Reset();
|
|
|
@ -348,7 +348,7 @@ class GetSetitemEliminator : public AnfVisitor {
|
|
|
|
// {prim::kPrimDepend, {prim::kPrimTupleGetItem, X, C}, Y}
|
|
|
|
// {prim::kPrimDepend, {prim::kPrimTupleGetItem, X, C}, Y}
|
|
|
|
// {prim::kPrimListGetItem, {prim::kPrimDepend, X, Y}, C} ->
|
|
|
|
// {prim::kPrimListGetItem, {prim::kPrimDepend, X, Y}, C} ->
|
|
|
|
// {prim::kPrimDepend, {prim::kPrimListGetItem, X, C}, Y}
|
|
|
|
// {prim::kPrimDepend, {prim::kPrimListGetItem, X, C}, Y}
|
|
|
|
class GetitemDependReorder : public AnfVisitor {
|
|
|
|
class TupleListGetitemDependReorder : public AnfVisitor {
|
|
|
|
public:
|
|
|
|
public:
|
|
|
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
|
|
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
|
|
|
Reset();
|
|
|
|
Reset();
|
|
|
@ -405,41 +405,6 @@ class GetitemDependReorder : public AnfVisitor {
|
|
|
|
AnfNodePtr x_{nullptr}, y_{nullptr}, c_{nullptr};
|
|
|
|
AnfNodePtr x_{nullptr}, y_{nullptr}, c_{nullptr};
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
class ItemTupleOrListEliminator : public OptimizerCaller {
|
|
|
|
|
|
|
|
public:
|
|
|
|
|
|
|
|
ItemTupleOrListEliminator()
|
|
|
|
|
|
|
|
: get_item_eliminator_(std::make_shared<GetitemEliminator>()),
|
|
|
|
|
|
|
|
get_item_const_eliminator_(std::make_shared<GetitemConstEliminator>()),
|
|
|
|
|
|
|
|
set_item_eliminator_(std::make_shared<SetitemEliminator>()),
|
|
|
|
|
|
|
|
get_set_item_eliminator_(std::make_shared<GetSetitemEliminator>()),
|
|
|
|
|
|
|
|
get_item_depend_reorder_(std::make_shared<GetitemDependReorder>()),
|
|
|
|
|
|
|
|
convert_item_index_to_positive_(std::make_shared<ConvertItemIndexToPositive>()) {
|
|
|
|
|
|
|
|
eliminators_.emplace_back(get_item_eliminator_);
|
|
|
|
|
|
|
|
eliminators_.emplace_back(get_item_const_eliminator_);
|
|
|
|
|
|
|
|
eliminators_.emplace_back(set_item_eliminator_);
|
|
|
|
|
|
|
|
eliminators_.emplace_back(get_set_item_eliminator_);
|
|
|
|
|
|
|
|
eliminators_.emplace_back(get_item_depend_reorder_);
|
|
|
|
|
|
|
|
eliminators_.emplace_back(convert_item_index_to_positive_);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
~ItemTupleOrListEliminator() = default;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
|
|
|
|
|
|
|
|
AnfNodePtr new_node;
|
|
|
|
|
|
|
|
for (auto &eliminator : eliminators_) {
|
|
|
|
|
|
|
|
new_node = (*eliminator)(optimizer, node);
|
|
|
|
|
|
|
|
if (new_node != nullptr) {
|
|
|
|
|
|
|
|
return new_node;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
return nullptr;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
|
|
|
OptimizerCallerPtr get_item_eliminator_, get_item_const_eliminator_, set_item_eliminator_, get_set_item_eliminator_,
|
|
|
|
|
|
|
|
get_item_depend_reorder_, convert_item_index_to_positive_;
|
|
|
|
|
|
|
|
std::vector<OptimizerCallerPtr> eliminators_{};
|
|
|
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
} // namespace irpass
|
|
|
|
} // namespace irpass
|
|
|
|
} // namespace opt
|
|
|
|
} // namespace opt
|
|
|
|
} // namespace mindspore
|
|
|
|
} // namespace mindspore
|
|
|
|