diff --git a/mindspore/ccsrc/frontend/optimizer/clean.cc b/mindspore/ccsrc/frontend/optimizer/clean.cc index 45a271f692..e35760ceaf 100644 --- a/mindspore/ccsrc/frontend/optimizer/clean.cc +++ b/mindspore/ccsrc/frontend/optimizer/clean.cc @@ -43,26 +43,28 @@ static AbstractBasePtr Reabs(const AbstractBasePtr &t) { return nullptr; } - AbstractBasePtr res = t; if (t->isa()) { auto abs_class = dyn_cast(t); AbstractBasePtrList baselist; auto attributes = abs_class->attributes(); (void)std::transform(attributes.begin(), attributes.end(), std::back_inserter(baselist), [](const AbstractAttribute &item) { return item.second; }); - res = std::make_shared(baselist); - } else if (t->isa()) { + return std::make_shared(baselist); + } + if (t->isa()) { auto abs_dict = dyn_cast(t); AbstractBasePtrList baselist; auto elements = abs_dict->elements(); (void)std::transform(elements.begin(), elements.end(), std::back_inserter(baselist), [](const AbstractAttribute &item) { return item.second; }); - res = std::make_shared(baselist); - } else if (t->isa()) { - auto abs_dict = dyn_cast(t); - res = std::make_shared(abs_dict->elements()); + return std::make_shared(baselist); + } + if (t->isa()) { + auto abs_list = dyn_cast(t); + return std::make_shared(abs_list->elements()); } - return res; + + return nullptr; } AnfNodePtr ConvertGetAttrToTupleGetItem(const CNodePtr &node) { @@ -376,7 +378,12 @@ bool SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr for (auto &node : manager->all_nodes()) { auto ret = Reabs(node->abstract()); - node->set_abstract(ret); + if (ret) { + MS_LOG(DEBUG) << "Replace " << node->DebugString() << "'s abstract " << node->abstract()->ToString() << " with " + << ret->ToString(); + node->set_abstract(ret); + changed = true; + } } return changed; } diff --git a/tests/ut/python/pynative_mode/test_framstruct.py b/tests/ut/python/pynative_mode/test_framstruct.py index cdae50dc8f..3b99d0dc5f 100644 --- a/tests/ut/python/pynative_mode/test_framstruct.py +++ b/tests/ut/python/pynative_mode/test_framstruct.py @@ -1031,3 +1031,13 @@ def test_grad_if_defer_inline(): inp = Tensor(np.ones([128, 96]).astype(np.float32)) grads = C.grad_all(network)(inp) assert grads == (Tensor(np.full([128, 96], 0.6, dtype=np.float32)),) + + +def test_dict_const(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.res = {'1': 10} + def construct(self): + return self.res + Net()()