!9926 optimize list setitem in bprop

From: @zhangbuxue
Reviewed-by: @zh_qh
Signed-off-by: @zh_qh
pull/9926/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit dd134d7554

@ -25,7 +25,7 @@
#include "frontend/optimizer/irpass/inline.h" #include "frontend/optimizer/irpass/inline.h"
#include "frontend/optimizer/irpass/incorporate_call.h" #include "frontend/optimizer/irpass/incorporate_call.h"
#include "frontend/optimizer/irpass/incorporate_getitem.h" #include "frontend/optimizer/irpass/incorporate_getitem.h"
#include "frontend/optimizer/irpass/item_tuple_eliminate.h" #include "frontend/optimizer/irpass/item_tuple_or_list_eliminate.h"
#include "frontend/optimizer/irpass/mark_interface_fusion.h" #include "frontend/optimizer/irpass/mark_interface_fusion.h"
#include "frontend/optimizer/irpass/merge_addn.h" #include "frontend/optimizer/irpass/merge_addn.h"
#include "frontend/optimizer/irpass/accumulaten_eliminate.h" #include "frontend/optimizer/irpass/accumulaten_eliminate.h"
@ -67,8 +67,9 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
MakeSubstitution(std::make_shared<AdjustAllReduceMulAdd>(), "adjust_all_reduce_mul_add", prim::kPrimAddN); MakeSubstitution(std::make_shared<AdjustAllReduceMulAdd>(), "adjust_all_reduce_mul_add", prim::kPrimAddN);
// ops eliminate // ops eliminate
item_tuple_eliminate_ = MakeSubstitution(std::make_shared<ItemTupleEliminater>(), "item_tuple_eliminate", item_tuple_or_list_eliminate_ = MakeSubstitution(
{prim::kPrimTupleGetItem, prim::kPrimTupleSetItem, prim::kPrimListGetItem}); std::make_shared<ItemTupleOrListEliminater>(), "item_tuple_or_list_eliminate",
{prim::kPrimTupleGetItem, prim::kPrimTupleSetItem, prim::kPrimListGetItem, prim::kPrimListSetItem});
tile_eliminate_ = MakeSubstitution(std::make_shared<TileEliminater>(), "tile_eliminate", prim::kPrimTile); tile_eliminate_ = MakeSubstitution(std::make_shared<TileEliminater>(), "tile_eliminate", prim::kPrimTile);
cast_eliminate_ = MakeSubstitution(std::make_shared<CastEliminater>(), "cast_eliminate", prim::kPrimCast); cast_eliminate_ = MakeSubstitution(std::make_shared<CastEliminater>(), "cast_eliminate", prim::kPrimCast);
reshape_eliminate_ = MakeSubstitution(std::make_shared<ReshapeEliminater>(), "reshape_eliminate", prim::kPrimReshape); reshape_eliminate_ = MakeSubstitution(std::make_shared<ReshapeEliminater>(), "reshape_eliminate", prim::kPrimReshape);

@ -39,7 +39,7 @@ class OptimizeIRPassLib {
SubstitutionPtr adjust_all_reduce_mul_add_; SubstitutionPtr adjust_all_reduce_mul_add_;
// ops eliminate // ops eliminate
SubstitutionPtr item_tuple_eliminate_; SubstitutionPtr item_tuple_or_list_eliminate_;
SubstitutionPtr tile_eliminate_; SubstitutionPtr tile_eliminate_;
SubstitutionPtr cast_eliminate_; SubstitutionPtr cast_eliminate_;
SubstitutionPtr reshape_eliminate_; SubstitutionPtr reshape_eliminate_;

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ITEM_TUPLE_ELIMINATE_H_ #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ITEM_TUPLE_OR_LIST_ELIMINATE_H_
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ITEM_TUPLE_ELIMINATE_H_ #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ITEM_TUPLE_OR_LIST_ELIMINATE_H_
#include <algorithm> #include <algorithm>
#include <memory> #include <memory>
@ -33,6 +33,7 @@ namespace irpass {
// (a, b, c, ...)[0] => a // (a, b, c, ...)[0] => a
// (a, b, c, ...)[1] => b // (a, b, c, ...)[1] => b
// {prim::kPrimTupleGetItem, {prim::kPrimMakeTuple, Xs}, C} // {prim::kPrimTupleGetItem, {prim::kPrimMakeTuple, Xs}, C}
// {prim::kPrimListGetItem, {prim::kPrimMakeList, Xs}, C}
class GetitemEliminater : public AnfVisitor { class GetitemEliminater : public AnfVisitor {
public: public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
@ -54,7 +55,7 @@ class GetitemEliminater : public AnfVisitor {
void Visit(const ValueNodePtr &vnode) override { void Visit(const ValueNodePtr &vnode) override {
if (tuple_ != nullptr && IsValueNode<Int64Imm>(vnode)) { if (tuple_ != nullptr && IsValueNode<Int64Imm>(vnode)) {
int64_t idx = GetValue<int64_t>(vnode->value()); auto idx = GetValue<int64_t>(vnode->value());
if (idx < 0) { if (idx < 0) {
idx = idx + tuple_->size() - 1; idx = idx + tuple_->size() - 1;
} }
@ -80,6 +81,7 @@ class GetitemEliminater : public AnfVisitor {
// (a, b, c, ...)[0] => a // (a, b, c, ...)[0] => a
// (a, b, c, ...)[1] => b // (a, b, c, ...)[1] => b
// {prim::kPrimTupleGetItem, C1, C} // {prim::kPrimTupleGetItem, C1, C}
// {prim::kPrimListGetItem, C1, C}
class GetitemConstEliminater : public AnfVisitor { class GetitemConstEliminater : public AnfVisitor {
public: public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
@ -124,11 +126,13 @@ class GetitemConstEliminater : public AnfVisitor {
// setitem((a, b, c, ...), 0, z) => (z, b, c, ...) // setitem((a, b, c, ...), 0, z) => (z, b, c, ...)
// setitem((a, b, c, ...), 1, z) => (a, z, c, ...) // setitem((a, b, c, ...), 1, z) => (a, z, c, ...)
// {prim::kPrimTupleSetItem, {prim::kPrimMakeTuple, Xs}, C, Z} // {prim::kPrimTupleSetItem, {prim::kPrimMakeTuple, Xs}, C, Z}
// {prim::kPrimListSetItem, {prim::kPrimMakeList, Xs}, C, Z}
class SetitemEliminater : public AnfVisitor { class SetitemEliminater : public AnfVisitor {
public: public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
Reset(); Reset();
AnfVisitor::Match(prim::kPrimTupleSetItem, {IsCNode, IsVNode, IsNode})(node); AnfVisitor::Match(prim::kPrimTupleSetItem, {IsCNode, IsVNode, IsNode})(node);
AnfVisitor::Match(prim::kPrimListSetItem, {IsCNode, IsVNode, IsNode})(node);
auto fg = node->func_graph(); auto fg = node->func_graph();
if (fg != nullptr && z_ != nullptr) { if (fg != nullptr && z_ != nullptr) {
@ -178,11 +182,13 @@ class SetitemEliminater : public AnfVisitor {
}; };
// {prim::kPrimTupleGetItem, {prim::kPrimTupleSetItem, Y, C1, X}, C2} // {prim::kPrimTupleGetItem, {prim::kPrimTupleSetItem, Y, C1, X}, C2}
// {prim::kPrimListGetItem, {prim::kPrimListSetItem, Y, C1, X}, C2}
class GetSetitemEliminater : public AnfVisitor { class GetSetitemEliminater : public AnfVisitor {
public: public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
Reset(); Reset();
AnfVisitor::Match(prim::kPrimTupleGetItem, {IsCNode, IsVNode})(node); AnfVisitor::Match(prim::kPrimTupleGetItem, {IsCNode, IsVNode})(node);
AnfVisitor::Match(prim::kPrimListGetItem, {IsCNode, IsVNode})(node);
auto fg = node->func_graph(); auto fg = node->func_graph();
if (fg != nullptr && key1_ >= 0 && key2_ >= 0) { if (fg != nullptr && key1_ >= 0 && key2_ >= 0) {
@ -195,7 +201,7 @@ class GetSetitemEliminater : public AnfVisitor {
} }
void Visit(const CNodePtr &cnode) override { void Visit(const CNodePtr &cnode) override {
if (IsPrimitiveCNode(cnode, prim::kPrimTupleSetItem)) { if (IsPrimitiveCNode(cnode, prim::kPrimTupleSetItem) || IsPrimitiveCNode(cnode, prim::kPrimListSetItem)) {
if (cnode->size() < 4) { if (cnode->size() < 4) {
return; return;
} }
@ -239,6 +245,8 @@ class GetSetitemEliminater : public AnfVisitor {
// {prim::kPrimTupleGetItem, {prim::kPrimDepend, X, Y}, C} -> // {prim::kPrimTupleGetItem, {prim::kPrimDepend, X, Y}, C} ->
// {prim::kPrimDepend, {prim::kPrimTupleGetItem, X, C}, Y} // {prim::kPrimDepend, {prim::kPrimTupleGetItem, X, C}, Y}
// {prim::kPrimListGetItem, {prim::kPrimDepend, X, Y}, C} ->
// {prim::kPrimDepend, {prim::kPrimListGetItem, X, C}, Y}
class GetitemDependReorder : public AnfVisitor { class GetitemDependReorder : public AnfVisitor {
public: public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
@ -274,9 +282,9 @@ class GetitemDependReorder : public AnfVisitor {
AnfNodePtr x_{nullptr}, y_{nullptr}, c_{nullptr}; AnfNodePtr x_{nullptr}, y_{nullptr}, c_{nullptr};
}; };
class ItemTupleEliminater : public OptimizerCaller { class ItemTupleOrListEliminater : public OptimizerCaller {
public: public:
ItemTupleEliminater() ItemTupleOrListEliminater()
: get_item_eliminater_(std::make_shared<GetitemEliminater>()), : get_item_eliminater_(std::make_shared<GetitemEliminater>()),
get_item_const_eliminater_(std::make_shared<GetitemConstEliminater>()), get_item_const_eliminater_(std::make_shared<GetitemConstEliminater>()),
set_item_eliminater_(std::make_shared<SetitemEliminater>()), set_item_eliminater_(std::make_shared<SetitemEliminater>()),
@ -288,7 +296,7 @@ class ItemTupleEliminater : public OptimizerCaller {
eliminaters_.emplace_back(get_set_item_eliminater_); eliminaters_.emplace_back(get_set_item_eliminater_);
eliminaters_.emplace_back(get_item_depend_reorder_); eliminaters_.emplace_back(get_item_depend_reorder_);
} }
~ItemTupleEliminater() = default; ~ItemTupleOrListEliminater() = default;
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
AnfNodePtr new_node; AnfNodePtr new_node;
@ -309,4 +317,4 @@ class ItemTupleEliminater : public OptimizerCaller {
} // namespace irpass } // namespace irpass
} // namespace opt } // namespace opt
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ITEM_TUPLE_ELIMINATE_H_ #endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ITEM_TUPLE_OR_LIST_ELIMINATE_H_

@ -100,7 +100,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
irpass.specialize_transform_, irpass.specialize_transform_,
// Miscellaneous // Miscellaneous
irpass.item_tuple_eliminate_, irpass.item_tuple_or_list_eliminate_,
irpass.env_get_item_eliminate_, irpass.env_get_item_eliminate_,
irpass.cast_eliminate_, irpass.cast_eliminate_,
irpass.reshape_eliminate_, irpass.reshape_eliminate_,
@ -188,8 +188,9 @@ OptPassGroupMap GetOptPassesAfterCconv(const opt::irpass::OptimizeIRPassLib &irp
} }
OptPassGroupMap GetOptPassesTransformGraph(const opt::irpass::OptimizeIRPassLib &irpass) { OptPassGroupMap GetOptPassesTransformGraph(const opt::irpass::OptimizeIRPassLib &irpass) {
opt::OptPassConfig d_1 = opt::OptPassConfig({// Safe inlining opt::OptPassConfig d_1 =
irpass.call_graph_tuple_transform_, irpass.item_tuple_eliminate_}); opt::OptPassConfig({// Safe inlining
irpass.call_graph_tuple_transform_, irpass.item_tuple_or_list_eliminate_});
OptPassGroupMap map_a({{"d_1", d_1}, {"renormalize", opt::OptPassConfig::Renormalize()}}); OptPassGroupMap map_a({{"d_1", d_1}, {"renormalize", opt::OptPassConfig::Renormalize()}});
@ -198,7 +199,7 @@ OptPassGroupMap GetOptPassesTransformGraph(const opt::irpass::OptimizeIRPassLib
OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) { OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) {
opt::OptPassConfig b_1 = opt::OptPassConfig( opt::OptPassConfig b_1 = opt::OptPassConfig(
{irpass.zero_like_fill_zero_, irpass.item_tuple_eliminate_, irpass.float_tuple_getitem_switch_, {irpass.zero_like_fill_zero_, irpass.item_tuple_or_list_eliminate_, irpass.float_tuple_getitem_switch_,
irpass.reset_defer_inline_, irpass.inline_, irpass.special_op_eliminate_, irpass.get_make_ref_eliminate_, irpass.reset_defer_inline_, irpass.inline_, irpass.special_op_eliminate_, irpass.get_make_ref_eliminate_,
irpass.incorporate_env_getitem_, irpass.incorporate_env_getitem_switch_, irpass.env_get_item_eliminate_, irpass.incorporate_env_getitem_, irpass.incorporate_env_getitem_switch_, irpass.env_get_item_eliminate_,
irpass.incorporate_env_getitem_switch_layer_, irpass.value_based_eliminate_, irpass.receive_eliminate_}); irpass.incorporate_env_getitem_switch_layer_, irpass.value_based_eliminate_, irpass.receive_eliminate_});

@ -232,7 +232,7 @@ def ms_function(fn=None, obj=None, input_signature=None):
equal to the case when `fn` is not None. equal to the case when `fn` is not None.
Examples: Examples:
>>> from mindspore.ops import functional as F >>> from mindspore.ops import functional as F
... ...
>>> def tensor_add(x, y): >>> def tensor_add(x, y):
... z = x + y ... z = x + y

@ -360,7 +360,7 @@ TEST_F(TestOptLib, test_tuple_getitem) {
FuncGraphPtr after_2 = std::make_shared<FuncGraph>(); FuncGraphPtr after_2 = std::make_shared<FuncGraph>();
after_2->set_output(value_node_2); after_2->set_output(value_node_2);
auto patterns = std::vector<SubstitutionPtr>({irpass.item_tuple_eliminate_}); auto patterns = std::vector<SubstitutionPtr>({irpass.item_tuple_or_list_eliminate_});
ASSERT_TRUE(CheckOpt(make_get_0, after_0, patterns)); ASSERT_TRUE(CheckOpt(make_get_0, after_0, patterns));
ASSERT_TRUE(CheckOpt(make_get_1, after_1, patterns)); ASSERT_TRUE(CheckOpt(make_get_1, after_1, patterns));
ASSERT_TRUE(CheckOpt(make_get_const, after_2, patterns)); ASSERT_TRUE(CheckOpt(make_get_const, after_2, patterns));
@ -372,7 +372,7 @@ TEST_F(TestOptLib, test_tuple_setitem) {
FuncGraphPtr after_0 = getPyFun.CallAndParseRet("test_tuple_setitem", "after_0"); FuncGraphPtr after_0 = getPyFun.CallAndParseRet("test_tuple_setitem", "after_0");
FuncGraphPtr after_1 = getPyFun.CallAndParseRet("test_tuple_setitem", "after_1"); FuncGraphPtr after_1 = getPyFun.CallAndParseRet("test_tuple_setitem", "after_1");
auto patterns = std::vector<SubstitutionPtr>({irpass.item_tuple_eliminate_}); auto patterns = std::vector<SubstitutionPtr>({irpass.item_tuple_or_list_eliminate_});
ASSERT_TRUE(CheckOpt(before_0, after_0, patterns)); ASSERT_TRUE(CheckOpt(before_0, after_0, patterns));
ASSERT_TRUE(CheckOpt(before_1, after_1, patterns)); ASSERT_TRUE(CheckOpt(before_1, after_1, patterns));
@ -384,7 +384,7 @@ TEST_F(TestOptLib, test_tuple_get_set_item) {
FuncGraphPtr before_1 = getPyFun.CallAndParseRet("test_tuple_get_set_item", "before_0"); FuncGraphPtr before_1 = getPyFun.CallAndParseRet("test_tuple_get_set_item", "before_0");
FuncGraphPtr after_1 = getPyFun.CallAndParseRet("test_tuple_get_set_item", "after_0"); FuncGraphPtr after_1 = getPyFun.CallAndParseRet("test_tuple_get_set_item", "after_0");
auto patterns = std::vector<SubstitutionPtr>({irpass.item_tuple_eliminate_}); auto patterns = std::vector<SubstitutionPtr>({irpass.item_tuple_or_list_eliminate_});
ASSERT_TRUE(CheckOpt(before_0, after_0, patterns)); ASSERT_TRUE(CheckOpt(before_0, after_0, patterns));
ASSERT_TRUE(CheckOpt(before_1, after_1, patterns)); ASSERT_TRUE(CheckOpt(before_1, after_1, patterns));

@ -13,9 +13,14 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
""" test enumerate""" """ test enumerate"""
import numpy as np
import mindspore.nn as nn import mindspore.nn as nn
from mindspore import Tensor from mindspore import Tensor
from mindspore import context from mindspore import context
from mindspore.ops import operations as P
from mindspore.ops import composite as C
context.set_context(mode=context.GRAPH_MODE) context.set_context(mode=context.GRAPH_MODE)
@ -168,3 +173,60 @@ def test_list_index_3D_parameter():
net = Net() net = Net()
net(Tensor(0)) net(Tensor(0))
def test_const_list_index_3D_bprop():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.value = [[1], [2, 2], [[3, 3], [3, 3]]]
self.relu = P.ReLU()
def construct(self, input_x):
list_x = self.value
list_x[2][0][1] = input_x
return self.relu(list_x[2][0][1])
class GradNet(nn.Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.net = net
self.grad_all_with_sens = C.GradOperation(get_all=True, sens_param=True)
def construct(self, x, sens):
return self.grad_all_with_sens(self.net)(x, sens)
net = Net()
grad_net = GradNet(net)
x = Tensor(np.arange(2 * 3).reshape(2, 3))
sens = Tensor(np.arange(2 * 3).reshape(2, 3))
grad_net(x, sens)
def test_parameter_list_index_3D_bprop():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.value = [[1], [2, 2], [[3, 3], [3, 3]]]
self.relu = P.ReLU()
def construct(self, x, value):
list_value = [[x], [x, x], [[x, x], [x, x]]]
list_value[2][0][1] = value
return self.relu(list_value[2][0][1])
class GradNet(nn.Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.net = net
self.grad_all_with_sens = C.GradOperation(get_all=True, sens_param=True)
def construct(self, x, value, sens):
return self.grad_all_with_sens(self.net)(x, value, sens)
net = Net()
grad_net = GradNet(net)
x = Tensor(np.arange(2 * 3).reshape(2, 3))
value = Tensor(np.ones((2, 3), np.int64))
sens = Tensor(np.arange(2 * 3).reshape(2, 3))
grad_net(x, value, sens)

Loading…
Cancel
Save