From 3cf94daefc000fc164523a82078aab3e0452ef2a Mon Sep 17 00:00:00 2001 From: huangdongrun Date: Tue, 19 May 2020 15:09:08 +0800 Subject: [PATCH] add support for dict setitem operation format code remove save_graph raise exception when dictionary setitem if the key not exists resolve error remove unnessary op_ --- mindspore/ccsrc/optimizer/clean.cc | 43 +++++++++++++++++++ .../optimizer/irpass/item_tuple_eliminate.h | 7 ++- tests/ut/python/dtype/test_dictionary.py | 14 ++++++ 3 files changed, 60 insertions(+), 4 deletions(-) diff --git a/mindspore/ccsrc/optimizer/clean.cc b/mindspore/ccsrc/optimizer/clean.cc index 97ac72e3fb..4ada67893e 100644 --- a/mindspore/ccsrc/optimizer/clean.cc +++ b/mindspore/ccsrc/optimizer/clean.cc @@ -139,6 +139,47 @@ AnfNodePtr ConvertDictGetItemToTupleGetItem(const CNodePtr &node) { return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, idx_c}); } +AnfNodePtr ConvertDictSetItemToTupleSetItem(const CNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(node->func_graph()); + + // Inputs should be [dict_setitem, dict, item, value] + const auto &inputs = node->inputs(); + MS_ASSERT(inputs.size() == 4 && "DictSetItem should have three inputs."); + + AnfNodePtr data = inputs[1]; + AnfNodePtr cons = inputs[2]; + AnfNodePtr item_value = inputs[3]; + MS_EXCEPTION_IF_NULL(data); + MS_EXCEPTION_IF_NULL(cons); + + auto dt = data->abstract(); + MS_EXCEPTION_IF_NULL(dt); + if (!dt->isa()) { + MS_LOG(EXCEPTION) << "first parameter of dict_setitem is not AbstractDictionary, but " << dt->type_name(); + } + auto cons_is_str = IsValueNode(cons); + auto cons_str = cons_is_str ? GetValue(GetValueNode(cons)) : ""; + + auto ct = dyn_cast(dt); + const auto &cmap = ct->elements(); + int count = 0; + for (auto &item : cmap) { + if (cons_is_str && item.first == cons_str) { + break; + } + count++; + } + if (IntToSize(count) >= cmap.size()) { + MS_LOG(EXCEPTION) << "dictionary assignment key " << cons_str + << " does not exist, can not create new dictionary item for now."; + } + auto idx_c = NewValueNode(count); + AbstractBasePtr aptr = std::make_shared(std::make_shared(count)); + idx_c->set_abstract(aptr); + return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleSetItem), data, idx_c, item_value}); +} + AnfNodePtr ConvertMakeRecordToMakeTuple(const CNodePtr &node) { MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node->func_graph()); @@ -300,6 +341,8 @@ bool SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr new_node = ErasePartialNode(cnode); } else if (IsPrimitiveCNode(node, prim::kPrimDictGetItem)) { new_node = ConvertDictGetItemToTupleGetItem(cnode); + } else if (IsPrimitiveCNode(node, prim::kPrimDictSetItem)) { + new_node = ConvertDictSetItemToTupleSetItem(cnode); } else if (IsPrimitiveCNode(node, prim::kPrimMakeDict)) { new_node = EraseMakeDictNode(cnode); } else if (IsPrimitiveCNode(node, prim::kPrimMakeKeywordArg)) { diff --git a/mindspore/ccsrc/optimizer/irpass/item_tuple_eliminate.h b/mindspore/ccsrc/optimizer/irpass/item_tuple_eliminate.h index 2edf5298ad..2693aec1c9 100644 --- a/mindspore/ccsrc/optimizer/irpass/item_tuple_eliminate.h +++ b/mindspore/ccsrc/optimizer/irpass/item_tuple_eliminate.h @@ -138,7 +138,7 @@ class GetSetitemEliminater : public AnfVisitor { if (key1_ == key2_) { return last_; } - return fg->NewCNode({op_, tuple_, c2_}); + return fg->NewCNode({NewValueNode(prim::kPrimTupleGetItem), tuple_, c2_}); } return nullptr; } @@ -148,7 +148,7 @@ class GetSetitemEliminater : public AnfVisitor { if (cnode->size() < 4) { return; } - op_ = cnode->input(0); + tuple_ = cnode->input(1); last_ = cnode->input(3); @@ -174,7 +174,6 @@ class GetSetitemEliminater : public AnfVisitor { void Reset() { key1_ = -1; key2_ = -1; - op_ = nullptr; c2_ = nullptr; last_ = nullptr; tuple_ = nullptr; @@ -184,7 +183,7 @@ class GetSetitemEliminater : public AnfVisitor { private: bool is_in_set_{false}; int key1_{-1}, key2_{-1}; - AnfNodePtr op_{nullptr}, tuple_{nullptr}, last_{nullptr}, c2_{nullptr}; + AnfNodePtr tuple_{nullptr}, last_{nullptr}, c2_{nullptr}; }; // {prim::kPrimTupleGetItem, {prim::kPrimDepend, X, Y}, C} -> diff --git a/tests/ut/python/dtype/test_dictionary.py b/tests/ut/python/dtype/test_dictionary.py index 4535fb82d0..8c8b2a9389 100644 --- a/tests/ut/python/dtype/test_dictionary.py +++ b/tests/ut/python/dtype/test_dictionary.py @@ -136,3 +136,17 @@ def test_dict_set_or_get_item_3(): net = DictNet() assert net() == Tensor(np.ones([4, 2, 3], np.float32)) + +def test_dict_set_item(): + class DictSetNet(Cell): + def __init__(self): + super(DictSetNet, self).__init__() + self.attrs = ("abc", "edf", "ghi", "jkl") + def construct(self, x): + my_dict = {"def": x, "abc":x, "edf":x, "ghi":x, "jkl":x} + for i in range(len(self.attrs)): + my_dict[self.attrs[i]] = x - i + return my_dict["jkl"], my_dict["edf"] + x = Tensor(np.ones([2, 2, 3], np.float32)) + net = DictSetNet() + out = net(x) \ No newline at end of file