support item negative index bprop

pull/13197/head
buxue 4 years ago
parent 8eb3e396e5
commit 50ee325b96

@ -47,8 +47,11 @@ class ConvertItemIndexToPositive : public AnfVisitor {
AnfVisitor::Match(prim::kPrimTupleSetItem, {IsCNode, IsVNode, IsNode})(node); AnfVisitor::Match(prim::kPrimTupleSetItem, {IsCNode, IsVNode, IsNode})(node);
AnfVisitor::Match(prim::kPrimListSetItem, {IsCNode, IsVNode, IsNode})(node); AnfVisitor::Match(prim::kPrimListSetItem, {IsCNode, IsVNode, IsNode})(node);
if (is_match_) { FuncGraphPtr fg = node->func_graph();
node->cast<CNodePtr>()->set_input(2, NewValueNode(id_)); if (is_match_ && fg != nullptr) {
auto inputs = node->cast<CNodePtr>()->inputs();
inputs[2] = NewValueNode(id_);
return fg->NewCNode(inputs);
} }
return nullptr; return nullptr;
} }

@ -14,6 +14,7 @@
# ============================================================================ # ============================================================================
"""Providing decorators.""" """Providing decorators."""
from mindspore import log
def deprecated(version, substitute, use_substitute_name=False): def deprecated(version, substitute, use_substitute_name=False):
"""deprecated warning """deprecated warning
@ -28,8 +29,8 @@ def deprecated(version, substitute, use_substitute_name=False):
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
cls = getattr(args[0], "__class__", None) if args else None cls = getattr(args[0], "__class__", None) if args else None
name = cls.__name__ if cls else func.__name__ name = cls.__name__ if cls else func.__name__
print(f"WARNING: '{name}' is deprecated from version {version} and will be removed in a future version, " log.warning(f"'{name}' is deprecated from version {version} and "
f"use '{substitute}' instead.") f"will be removed in a future version, use '{substitute}' instead.")
if cls and use_substitute_name: if cls and use_substitute_name:
cls.substitute_name = substitute cls.substitute_name = substitute
ret = func(*args, **kwargs) ret = func(*args, **kwargs)

@ -20,8 +20,9 @@ from mindspore import nn
from mindspore import Tensor from mindspore import Tensor
from mindspore import context from mindspore import context
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.ops import composite as C
context.set_context(mode=context.GRAPH_MODE, save_graphs=True) context.set_context(mode=context.GRAPH_MODE)
def test_tuple_index_by_negative_number(): def test_tuple_index_by_negative_number():
@ -37,12 +38,24 @@ def test_tuple_index_by_negative_number():
ret[-1] = 100 ret[-1] = 100
return ret return ret
class GradNet(nn.Cell):
def __init__(self, net, get_all):
super(GradNet, self).__init__()
self.forward_net = net
self.sens = Tensor(np.ones((2, 2), np.float32) * 5)
self.grad_all = C.GradOperation(get_all=get_all)
def construct(self, x):
return self.grad_all(self.forward_net)(x)
net = Net() net = Net()
grad_net = GradNet(net, True)
x = Tensor(np.ones((4, 2, 3))) x = Tensor(np.ones((4, 2, 3)))
net(x) net(x)
grad_net(x)
def Ttest_tuple_index_by_negative_number_out_bound(): def test_tuple_index_by_negative_number_out_bound():
class Net(nn.Cell): class Net(nn.Cell):
def __init__(self): def __init__(self):
super(Net, self).__init__() super(Net, self).__init__()

Loading…
Cancel
Save