fix valuenode simplify

pull/3062/head
panyifeng 5 years ago
parent bc0a53cfb1
commit 394178569e

@ -43,26 +43,28 @@ static AbstractBasePtr Reabs(const AbstractBasePtr &t) {
return nullptr; return nullptr;
} }
AbstractBasePtr res = t;
if (t->isa<AbstractClass>()) { if (t->isa<AbstractClass>()) {
auto abs_class = dyn_cast<AbstractClass>(t); auto abs_class = dyn_cast<AbstractClass>(t);
AbstractBasePtrList baselist; AbstractBasePtrList baselist;
auto attributes = abs_class->attributes(); auto attributes = abs_class->attributes();
(void)std::transform(attributes.begin(), attributes.end(), std::back_inserter(baselist), (void)std::transform(attributes.begin(), attributes.end(), std::back_inserter(baselist),
[](const AbstractAttribute &item) { return item.second; }); [](const AbstractAttribute &item) { return item.second; });
res = std::make_shared<AbstractTuple>(baselist); return std::make_shared<AbstractTuple>(baselist);
} else if (t->isa<AbstractDictionary>()) { }
if (t->isa<AbstractDictionary>()) {
auto abs_dict = dyn_cast<AbstractDictionary>(t); auto abs_dict = dyn_cast<AbstractDictionary>(t);
AbstractBasePtrList baselist; AbstractBasePtrList baselist;
auto elements = abs_dict->elements(); auto elements = abs_dict->elements();
(void)std::transform(elements.begin(), elements.end(), std::back_inserter(baselist), (void)std::transform(elements.begin(), elements.end(), std::back_inserter(baselist),
[](const AbstractAttribute &item) { return item.second; }); [](const AbstractAttribute &item) { return item.second; });
res = std::make_shared<AbstractTuple>(baselist); return std::make_shared<AbstractTuple>(baselist);
} else if (t->isa<AbstractList>()) { }
auto abs_dict = dyn_cast<AbstractList>(t); if (t->isa<AbstractList>()) {
res = std::make_shared<AbstractTuple>(abs_dict->elements()); auto abs_list = dyn_cast<AbstractList>(t);
return std::make_shared<AbstractTuple>(abs_list->elements());
} }
return res;
return nullptr;
} }
AnfNodePtr ConvertGetAttrToTupleGetItem(const CNodePtr &node) { AnfNodePtr ConvertGetAttrToTupleGetItem(const CNodePtr &node) {
@ -376,7 +378,12 @@ bool SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr
for (auto &node : manager->all_nodes()) { for (auto &node : manager->all_nodes()) {
auto ret = Reabs(node->abstract()); auto ret = Reabs(node->abstract());
node->set_abstract(ret); if (ret) {
MS_LOG(DEBUG) << "Replace " << node->DebugString() << "'s abstract " << node->abstract()->ToString() << " with "
<< ret->ToString();
node->set_abstract(ret);
changed = true;
}
} }
return changed; return changed;
} }

@ -1031,3 +1031,13 @@ def test_grad_if_defer_inline():
inp = Tensor(np.ones([128, 96]).astype(np.float32)) inp = Tensor(np.ones([128, 96]).astype(np.float32))
grads = C.grad_all(network)(inp) grads = C.grad_all(network)(inp)
assert grads == (Tensor(np.full([128, 96], 0.6, dtype=np.float32)),) assert grads == (Tensor(np.full([128, 96], 0.6, dtype=np.float32)),)
def test_dict_const():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.res = {'1': 10}
def construct(self):
return self.res
Net()()

Loading…
Cancel
Save