!9926 optimize list setitem in bprop

From: @zhangbuxue
Reviewed-by: @zh_qh
Signed-off-by: @zh_qh
pull/9926/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit dd134d7554

@ -25,7 +25,7 @@
#include "frontend/optimizer/irpass/inline.h"
#include "frontend/optimizer/irpass/incorporate_call.h"
#include "frontend/optimizer/irpass/incorporate_getitem.h"
#include "frontend/optimizer/irpass/item_tuple_eliminate.h"
#include "frontend/optimizer/irpass/item_tuple_or_list_eliminate.h"
#include "frontend/optimizer/irpass/mark_interface_fusion.h"
#include "frontend/optimizer/irpass/merge_addn.h"
#include "frontend/optimizer/irpass/accumulaten_eliminate.h"
@ -67,8 +67,9 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
MakeSubstitution(std::make_shared<AdjustAllReduceMulAdd>(), "adjust_all_reduce_mul_add", prim::kPrimAddN);
// ops eliminate
item_tuple_eliminate_ = MakeSubstitution(std::make_shared<ItemTupleEliminater>(), "item_tuple_eliminate",
{prim::kPrimTupleGetItem, prim::kPrimTupleSetItem, prim::kPrimListGetItem});
item_tuple_or_list_eliminate_ = MakeSubstitution(
std::make_shared<ItemTupleOrListEliminater>(), "item_tuple_or_list_eliminate",
{prim::kPrimTupleGetItem, prim::kPrimTupleSetItem, prim::kPrimListGetItem, prim::kPrimListSetItem});
tile_eliminate_ = MakeSubstitution(std::make_shared<TileEliminater>(), "tile_eliminate", prim::kPrimTile);
cast_eliminate_ = MakeSubstitution(std::make_shared<CastEliminater>(), "cast_eliminate", prim::kPrimCast);
reshape_eliminate_ = MakeSubstitution(std::make_shared<ReshapeEliminater>(), "reshape_eliminate", prim::kPrimReshape);

@ -39,7 +39,7 @@ class OptimizeIRPassLib {
SubstitutionPtr adjust_all_reduce_mul_add_;
// ops eliminate
SubstitutionPtr item_tuple_eliminate_;
SubstitutionPtr item_tuple_or_list_eliminate_;
SubstitutionPtr tile_eliminate_;
SubstitutionPtr cast_eliminate_;
SubstitutionPtr reshape_eliminate_;

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ITEM_TUPLE_ELIMINATE_H_
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ITEM_TUPLE_ELIMINATE_H_
#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ITEM_TUPLE_OR_LIST_ELIMINATE_H_
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ITEM_TUPLE_OR_LIST_ELIMINATE_H_
#include <algorithm>
#include <memory>
@ -33,6 +33,7 @@ namespace irpass {
// (a, b, c, ...)[0] => a
// (a, b, c, ...)[1] => b
// {prim::kPrimTupleGetItem, {prim::kPrimMakeTuple, Xs}, C}
// {prim::kPrimListGetItem, {prim::kPrimMakeList, Xs}, C}
class GetitemEliminater : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
@ -54,7 +55,7 @@ class GetitemEliminater : public AnfVisitor {
void Visit(const ValueNodePtr &vnode) override {
if (tuple_ != nullptr && IsValueNode<Int64Imm>(vnode)) {
int64_t idx = GetValue<int64_t>(vnode->value());
auto idx = GetValue<int64_t>(vnode->value());
if (idx < 0) {
idx = idx + tuple_->size() - 1;
}
@ -80,6 +81,7 @@ class GetitemEliminater : public AnfVisitor {
// (a, b, c, ...)[0] => a
// (a, b, c, ...)[1] => b
// {prim::kPrimTupleGetItem, C1, C}
// {prim::kPrimListGetItem, C1, C}
class GetitemConstEliminater : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
@ -124,11 +126,13 @@ class GetitemConstEliminater : public AnfVisitor {
// setitem((a, b, c, ...), 0, z) => (z, b, c, ...)
// setitem((a, b, c, ...), 1, z) => (a, z, c, ...)
// {prim::kPrimTupleSetItem, {prim::kPrimMakeTuple, Xs}, C, Z}
// {prim::kPrimListSetItem, {prim::kPrimMakeList, Xs}, C, Z}
class SetitemEliminater : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
Reset();
AnfVisitor::Match(prim::kPrimTupleSetItem, {IsCNode, IsVNode, IsNode})(node);
AnfVisitor::Match(prim::kPrimListSetItem, {IsCNode, IsVNode, IsNode})(node);
auto fg = node->func_graph();
if (fg != nullptr && z_ != nullptr) {
@ -178,11 +182,13 @@ class SetitemEliminater : public AnfVisitor {
};
// {prim::kPrimTupleGetItem, {prim::kPrimTupleSetItem, Y, C1, X}, C2}
// {prim::kPrimListGetItem, {prim::kPrimListSetItem, Y, C1, X}, C2}
class GetSetitemEliminater : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
Reset();
AnfVisitor::Match(prim::kPrimTupleGetItem, {IsCNode, IsVNode})(node);
AnfVisitor::Match(prim::kPrimListGetItem, {IsCNode, IsVNode})(node);
auto fg = node->func_graph();
if (fg != nullptr && key1_ >= 0 && key2_ >= 0) {
@ -195,7 +201,7 @@ class GetSetitemEliminater : public AnfVisitor {
}
void Visit(const CNodePtr &cnode) override {
if (IsPrimitiveCNode(cnode, prim::kPrimTupleSetItem)) {
if (IsPrimitiveCNode(cnode, prim::kPrimTupleSetItem) || IsPrimitiveCNode(cnode, prim::kPrimListSetItem)) {
if (cnode->size() < 4) {
return;
}
@ -239,6 +245,8 @@ class GetSetitemEliminater : public AnfVisitor {
// {prim::kPrimTupleGetItem, {prim::kPrimDepend, X, Y}, C} ->
// {prim::kPrimDepend, {prim::kPrimTupleGetItem, X, C}, Y}
// {prim::kPrimListGetItem, {prim::kPrimDepend, X, Y}, C} ->
// {prim::kPrimDepend, {prim::kPrimListGetItem, X, C}, Y}
class GetitemDependReorder : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
@ -274,9 +282,9 @@ class GetitemDependReorder : public AnfVisitor {
AnfNodePtr x_{nullptr}, y_{nullptr}, c_{nullptr};
};
class ItemTupleEliminater : public OptimizerCaller {
class ItemTupleOrListEliminater : public OptimizerCaller {
public:
ItemTupleEliminater()
ItemTupleOrListEliminater()
: get_item_eliminater_(std::make_shared<GetitemEliminater>()),
get_item_const_eliminater_(std::make_shared<GetitemConstEliminater>()),
set_item_eliminater_(std::make_shared<SetitemEliminater>()),
@ -288,7 +296,7 @@ class ItemTupleEliminater : public OptimizerCaller {
eliminaters_.emplace_back(get_set_item_eliminater_);
eliminaters_.emplace_back(get_item_depend_reorder_);
}
~ItemTupleEliminater() = default;
~ItemTupleOrListEliminater() = default;
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
AnfNodePtr new_node;
@ -309,4 +317,4 @@ class ItemTupleEliminater : public OptimizerCaller {
} // namespace irpass
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ITEM_TUPLE_ELIMINATE_H_
#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ITEM_TUPLE_OR_LIST_ELIMINATE_H_

@ -100,7 +100,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
irpass.specialize_transform_,
// Miscellaneous
irpass.item_tuple_eliminate_,
irpass.item_tuple_or_list_eliminate_,
irpass.env_get_item_eliminate_,
irpass.cast_eliminate_,
irpass.reshape_eliminate_,
@ -188,8 +188,9 @@ OptPassGroupMap GetOptPassesAfterCconv(const opt::irpass::OptimizeIRPassLib &irp
}
OptPassGroupMap GetOptPassesTransformGraph(const opt::irpass::OptimizeIRPassLib &irpass) {
opt::OptPassConfig d_1 = opt::OptPassConfig({// Safe inlining
irpass.call_graph_tuple_transform_, irpass.item_tuple_eliminate_});
opt::OptPassConfig d_1 =
opt::OptPassConfig({// Safe inlining
irpass.call_graph_tuple_transform_, irpass.item_tuple_or_list_eliminate_});
OptPassGroupMap map_a({{"d_1", d_1}, {"renormalize", opt::OptPassConfig::Renormalize()}});
@ -198,7 +199,7 @@ OptPassGroupMap GetOptPassesTransformGraph(const opt::irpass::OptimizeIRPassLib
OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) {
opt::OptPassConfig b_1 = opt::OptPassConfig(
{irpass.zero_like_fill_zero_, irpass.item_tuple_eliminate_, irpass.float_tuple_getitem_switch_,
{irpass.zero_like_fill_zero_, irpass.item_tuple_or_list_eliminate_, irpass.float_tuple_getitem_switch_,
irpass.reset_defer_inline_, irpass.inline_, irpass.special_op_eliminate_, irpass.get_make_ref_eliminate_,
irpass.incorporate_env_getitem_, irpass.incorporate_env_getitem_switch_, irpass.env_get_item_eliminate_,
irpass.incorporate_env_getitem_switch_layer_, irpass.value_based_eliminate_, irpass.receive_eliminate_});

@ -232,7 +232,7 @@ def ms_function(fn=None, obj=None, input_signature=None):
equal to the case when `fn` is not None.
Examples:
>>> from mindspore.ops import functional as F
>>> from mindspore.ops import functional as F
...
>>> def tensor_add(x, y):
... z = x + y

@ -360,7 +360,7 @@ TEST_F(TestOptLib, test_tuple_getitem) {
FuncGraphPtr after_2 = std::make_shared<FuncGraph>();
after_2->set_output(value_node_2);
auto patterns = std::vector<SubstitutionPtr>({irpass.item_tuple_eliminate_});
auto patterns = std::vector<SubstitutionPtr>({irpass.item_tuple_or_list_eliminate_});
ASSERT_TRUE(CheckOpt(make_get_0, after_0, patterns));
ASSERT_TRUE(CheckOpt(make_get_1, after_1, patterns));
ASSERT_TRUE(CheckOpt(make_get_const, after_2, patterns));
@ -372,7 +372,7 @@ TEST_F(TestOptLib, test_tuple_setitem) {
FuncGraphPtr after_0 = getPyFun.CallAndParseRet("test_tuple_setitem", "after_0");
FuncGraphPtr after_1 = getPyFun.CallAndParseRet("test_tuple_setitem", "after_1");
auto patterns = std::vector<SubstitutionPtr>({irpass.item_tuple_eliminate_});
auto patterns = std::vector<SubstitutionPtr>({irpass.item_tuple_or_list_eliminate_});
ASSERT_TRUE(CheckOpt(before_0, after_0, patterns));
ASSERT_TRUE(CheckOpt(before_1, after_1, patterns));
@ -384,7 +384,7 @@ TEST_F(TestOptLib, test_tuple_get_set_item) {
FuncGraphPtr before_1 = getPyFun.CallAndParseRet("test_tuple_get_set_item", "before_0");
FuncGraphPtr after_1 = getPyFun.CallAndParseRet("test_tuple_get_set_item", "after_0");
auto patterns = std::vector<SubstitutionPtr>({irpass.item_tuple_eliminate_});
auto patterns = std::vector<SubstitutionPtr>({irpass.item_tuple_or_list_eliminate_});
ASSERT_TRUE(CheckOpt(before_0, after_0, patterns));
ASSERT_TRUE(CheckOpt(before_1, after_1, patterns));

@ -13,9 +13,14 @@
# limitations under the License.
# ============================================================================
""" test enumerate"""
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.ops import operations as P
from mindspore.ops import composite as C
context.set_context(mode=context.GRAPH_MODE)
@ -168,3 +173,60 @@ def test_list_index_3D_parameter():
net = Net()
net(Tensor(0))
def test_const_list_index_3D_bprop():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.value = [[1], [2, 2], [[3, 3], [3, 3]]]
self.relu = P.ReLU()
def construct(self, input_x):
list_x = self.value
list_x[2][0][1] = input_x
return self.relu(list_x[2][0][1])
class GradNet(nn.Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.net = net
self.grad_all_with_sens = C.GradOperation(get_all=True, sens_param=True)
def construct(self, x, sens):
return self.grad_all_with_sens(self.net)(x, sens)
net = Net()
grad_net = GradNet(net)
x = Tensor(np.arange(2 * 3).reshape(2, 3))
sens = Tensor(np.arange(2 * 3).reshape(2, 3))
grad_net(x, sens)
def test_parameter_list_index_3D_bprop():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.value = [[1], [2, 2], [[3, 3], [3, 3]]]
self.relu = P.ReLU()
def construct(self, x, value):
list_value = [[x], [x, x], [[x, x], [x, x]]]
list_value[2][0][1] = value
return self.relu(list_value[2][0][1])
class GradNet(nn.Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.net = net
self.grad_all_with_sens = C.GradOperation(get_all=True, sens_param=True)
def construct(self, x, value, sens):
return self.grad_all_with_sens(self.net)(x, value, sens)
net = Net()
grad_net = GradNet(net)
x = Tensor(np.arange(2 * 3).reshape(2, 3))
value = Tensor(np.ones((2, 3), np.int64))
sens = Tensor(np.arange(2 * 3).reshape(2, 3))
grad_net(x, value, sens)

Loading…
Cancel
Save