|
|
|
@ -14,8 +14,8 @@
|
|
|
|
|
* limitations under the License.
|
|
|
|
|
*/
|
|
|
|
|
|
|
|
|
|
#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ITEM_TUPLE_ELIMINATE_H_
|
|
|
|
|
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ITEM_TUPLE_ELIMINATE_H_
|
|
|
|
|
#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ITEM_TUPLE_OR_LIST_ELIMINATE_H_
|
|
|
|
|
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ITEM_TUPLE_OR_LIST_ELIMINATE_H_
|
|
|
|
|
|
|
|
|
|
#include <algorithm>
|
|
|
|
|
#include <memory>
|
|
|
|
@ -33,6 +33,7 @@ namespace irpass {
|
|
|
|
|
// (a, b, c, ...)[0] => a
|
|
|
|
|
// (a, b, c, ...)[1] => b
|
|
|
|
|
// {prim::kPrimTupleGetItem, {prim::kPrimMakeTuple, Xs}, C}
|
|
|
|
|
// {prim::kPrimListGetItem, {prim::kPrimMakeList, Xs}, C}
|
|
|
|
|
class GetitemEliminater : public AnfVisitor {
|
|
|
|
|
public:
|
|
|
|
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
|
|
|
@ -54,7 +55,7 @@ class GetitemEliminater : public AnfVisitor {
|
|
|
|
|
|
|
|
|
|
void Visit(const ValueNodePtr &vnode) override {
|
|
|
|
|
if (tuple_ != nullptr && IsValueNode<Int64Imm>(vnode)) {
|
|
|
|
|
int64_t idx = GetValue<int64_t>(vnode->value());
|
|
|
|
|
auto idx = GetValue<int64_t>(vnode->value());
|
|
|
|
|
if (idx < 0) {
|
|
|
|
|
idx = idx + tuple_->size() - 1;
|
|
|
|
|
}
|
|
|
|
@ -80,6 +81,7 @@ class GetitemEliminater : public AnfVisitor {
|
|
|
|
|
// (a, b, c, ...)[0] => a
|
|
|
|
|
// (a, b, c, ...)[1] => b
|
|
|
|
|
// {prim::kPrimTupleGetItem, C1, C}
|
|
|
|
|
// {prim::kPrimListGetItem, C1, C}
|
|
|
|
|
class GetitemConstEliminater : public AnfVisitor {
|
|
|
|
|
public:
|
|
|
|
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
|
|
|
@ -124,11 +126,13 @@ class GetitemConstEliminater : public AnfVisitor {
|
|
|
|
|
// setitem((a, b, c, ...), 0, z) => (z, b, c, ...)
|
|
|
|
|
// setitem((a, b, c, ...), 1, z) => (a, z, c, ...)
|
|
|
|
|
// {prim::kPrimTupleSetItem, {prim::kPrimMakeTuple, Xs}, C, Z}
|
|
|
|
|
// {prim::kPrimListSetItem, {prim::kPrimMakeList, Xs}, C, Z}
|
|
|
|
|
class SetitemEliminater : public AnfVisitor {
|
|
|
|
|
public:
|
|
|
|
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
|
|
|
|
Reset();
|
|
|
|
|
AnfVisitor::Match(prim::kPrimTupleSetItem, {IsCNode, IsVNode, IsNode})(node);
|
|
|
|
|
AnfVisitor::Match(prim::kPrimListSetItem, {IsCNode, IsVNode, IsNode})(node);
|
|
|
|
|
|
|
|
|
|
auto fg = node->func_graph();
|
|
|
|
|
if (fg != nullptr && z_ != nullptr) {
|
|
|
|
@ -178,11 +182,13 @@ class SetitemEliminater : public AnfVisitor {
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// {prim::kPrimTupleGetItem, {prim::kPrimTupleSetItem, Y, C1, X}, C2}
|
|
|
|
|
// {prim::kPrimListGetItem, {prim::kPrimListSetItem, Y, C1, X}, C2}
|
|
|
|
|
class GetSetitemEliminater : public AnfVisitor {
|
|
|
|
|
public:
|
|
|
|
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
|
|
|
|
Reset();
|
|
|
|
|
AnfVisitor::Match(prim::kPrimTupleGetItem, {IsCNode, IsVNode})(node);
|
|
|
|
|
AnfVisitor::Match(prim::kPrimListGetItem, {IsCNode, IsVNode})(node);
|
|
|
|
|
|
|
|
|
|
auto fg = node->func_graph();
|
|
|
|
|
if (fg != nullptr && key1_ >= 0 && key2_ >= 0) {
|
|
|
|
@ -195,7 +201,7 @@ class GetSetitemEliminater : public AnfVisitor {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void Visit(const CNodePtr &cnode) override {
|
|
|
|
|
if (IsPrimitiveCNode(cnode, prim::kPrimTupleSetItem)) {
|
|
|
|
|
if (IsPrimitiveCNode(cnode, prim::kPrimTupleSetItem) || IsPrimitiveCNode(cnode, prim::kPrimListSetItem)) {
|
|
|
|
|
if (cnode->size() < 4) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
@ -239,6 +245,8 @@ class GetSetitemEliminater : public AnfVisitor {
|
|
|
|
|
|
|
|
|
|
// {prim::kPrimTupleGetItem, {prim::kPrimDepend, X, Y}, C} ->
|
|
|
|
|
// {prim::kPrimDepend, {prim::kPrimTupleGetItem, X, C}, Y}
|
|
|
|
|
// {prim::kPrimListGetItem, {prim::kPrimDepend, X, Y}, C} ->
|
|
|
|
|
// {prim::kPrimDepend, {prim::kPrimListGetItem, X, C}, Y}
|
|
|
|
|
class GetitemDependReorder : public AnfVisitor {
|
|
|
|
|
public:
|
|
|
|
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
|
|
|
@ -274,9 +282,9 @@ class GetitemDependReorder : public AnfVisitor {
|
|
|
|
|
AnfNodePtr x_{nullptr}, y_{nullptr}, c_{nullptr};
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class ItemTupleEliminater : public OptimizerCaller {
|
|
|
|
|
class ItemTupleOrListEliminater : public OptimizerCaller {
|
|
|
|
|
public:
|
|
|
|
|
ItemTupleEliminater()
|
|
|
|
|
ItemTupleOrListEliminater()
|
|
|
|
|
: get_item_eliminater_(std::make_shared<GetitemEliminater>()),
|
|
|
|
|
get_item_const_eliminater_(std::make_shared<GetitemConstEliminater>()),
|
|
|
|
|
set_item_eliminater_(std::make_shared<SetitemEliminater>()),
|
|
|
|
@ -288,7 +296,7 @@ class ItemTupleEliminater : public OptimizerCaller {
|
|
|
|
|
eliminaters_.emplace_back(get_set_item_eliminater_);
|
|
|
|
|
eliminaters_.emplace_back(get_item_depend_reorder_);
|
|
|
|
|
}
|
|
|
|
|
~ItemTupleEliminater() = default;
|
|
|
|
|
~ItemTupleOrListEliminater() = default;
|
|
|
|
|
|
|
|
|
|
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
|
|
|
|
|
AnfNodePtr new_node;
|
|
|
|
@ -309,4 +317,4 @@ class ItemTupleEliminater : public OptimizerCaller {
|
|
|
|
|
} // namespace irpass
|
|
|
|
|
} // namespace opt
|
|
|
|
|
} // namespace mindspore
|
|
|
|
|
#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ITEM_TUPLE_ELIMINATE_H_
|
|
|
|
|
#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ITEM_TUPLE_OR_LIST_ELIMINATE_H_
|