From 3b218228244bf0e6ca2af16bfc6afd4608f01b22 Mon Sep 17 00:00:00 2001 From: simson Date: Thu, 29 Oct 2020 15:13:49 +0800 Subject: [PATCH] get keys and values from dictionary & set tuple to dictionary --- mindspore/ccsrc/frontend/optimizer/clean.cc | 9 +++ mindspore/ccsrc/pipeline/jit/resource.cc | 2 + mindspore/core/abstract/infer_functions.h | 4 + mindspore/core/abstract/prim_structures.cc | 26 ++++++ .../core/abstract/primitive_infer_map.cc | 2 + mindspore/core/base/core_ops.h | 2 + .../composite/multitype_ops/setitem_impl.py | 14 ++++ .../python/pipeline/parse/test_dictionary.py | 81 +++++++++++++++++++ 8 files changed, 140 insertions(+) create mode 100644 tests/ut/python/pipeline/parse/test_dictionary.py diff --git a/mindspore/ccsrc/frontend/optimizer/clean.cc b/mindspore/ccsrc/frontend/optimizer/clean.cc index 1059fc2f63..79f97fc6ea 100644 --- a/mindspore/ccsrc/frontend/optimizer/clean.cc +++ b/mindspore/ccsrc/frontend/optimizer/clean.cc @@ -304,6 +304,13 @@ AnfNodePtr EraseMakeDictNode(const CNodePtr &node) { return inputs[2]; } +AnfNodePtr EraseDictGetValues(const CNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + const auto &inputs = node->inputs(); + MS_ASSERT(inputs.size() == 2 && "DictGetValues should have two inputs"); + return inputs[1]; +} + AnfNodePtr EraseMakeKeywordArgNode(const CNodePtr &node) { MS_EXCEPTION_IF_NULL(node); const auto &inputs = node->inputs(); @@ -374,6 +381,8 @@ bool SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr new_node = ConvertDictGetItemToTupleGetItem(cnode); } else if (IsPrimitiveCNode(node, prim::kPrimDictSetItem)) { new_node = ConvertDictSetItemToTupleSetItem(cnode); + } else if (IsPrimitiveCNode(node, prim::kPrimDictGetValues)) { + new_node = EraseDictGetValues(cnode); } else if (IsPrimitiveCNode(node, prim::kPrimMakeDict)) { new_node = EraseMakeDictNode(cnode); } else if (IsPrimitiveCNode(node, prim::kPrimMakeKeywordArg)) { diff --git a/mindspore/ccsrc/pipeline/jit/resource.cc b/mindspore/ccsrc/pipeline/jit/resource.cc index ede5c59bae..7f8d732531 100644 --- a/mindspore/ccsrc/pipeline/jit/resource.cc +++ b/mindspore/ccsrc/pipeline/jit/resource.cc @@ -141,6 +141,8 @@ BuiltInTypeMap &GetMethodMap() { {"__len__", prim::kPrimDictLen}, // P.dict_len {"__getitem__", prim::kPrimDictGetItem}, // P.dict_getitem {"__setitem__", prim::kPrimDictSetItem}, // P.dict_setitem, + {"keys", prim::kPrimDictGetKeys}, // P.dict_getkeys, + {"values", prim::kPrimDictGetValues}, // P.dict_getvalues, {"__bool__", std::string("dict_bool")} // C.dict_bool }}, {kObjectTypeTensorType, diff --git a/mindspore/core/abstract/infer_functions.h b/mindspore/core/abstract/infer_functions.h index 13f6a370d8..dc93c6b306 100644 --- a/mindspore/core/abstract/infer_functions.h +++ b/mindspore/core/abstract/infer_functions.h @@ -131,6 +131,10 @@ AbstractBasePtr InferImplDictGetItem(const AnalysisEnginePtr &, const PrimitiveP const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplDictSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplDictGetKeys(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplDictGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplListAppend(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplTupleLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive, diff --git a/mindspore/core/abstract/prim_structures.cc b/mindspore/core/abstract/prim_structures.cc index 73b799566a..0388c4ab49 100644 --- a/mindspore/core/abstract/prim_structures.cc +++ b/mindspore/core/abstract/prim_structures.cc @@ -249,6 +249,32 @@ AbstractBasePtr InferImplDictSetItem(const AnalysisEnginePtr &, const PrimitiveP return std::make_shared(dict_elems); } +AbstractBasePtr InferImplDictGetKeys(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: a dict. + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 1); + AbstractDictionaryPtr dict = CheckArg(op_name, args_spec_list, 0); + std::vector dict_elems = dict->elements(); + AbstractBasePtrList keys; + std::transform(dict_elems.begin(), dict_elems.end(), std::back_inserter(keys), + [](const AbstractAttribute &item) { return std::make_shared(item.first); }); + return std::make_shared(keys); +} + +AbstractBasePtr InferImplDictGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: a dict. + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 1); + AbstractDictionaryPtr dict = CheckArg(op_name, args_spec_list, 0); + std::vector dict_elems = dict->elements(); + AbstractBasePtrList values; + std::transform(dict_elems.begin(), dict_elems.end(), std::back_inserter(values), + [](const AbstractAttribute &item) { return item.second; }); + return std::make_shared(values); +} + AbstractBasePtr InferImplListAppend(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list) { // Inputs: a list and an object of a subclass of AbstractBase. diff --git a/mindspore/core/abstract/primitive_infer_map.cc b/mindspore/core/abstract/primitive_infer_map.cc index bf92b9ca0a..eed75d1b44 100644 --- a/mindspore/core/abstract/primitive_infer_map.cc +++ b/mindspore/core/abstract/primitive_infer_map.cc @@ -72,6 +72,8 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { {prim::kPrimListSetItem, {InferImplListSetItem, true}}, {prim::kPrimDictGetItem, {InferImplDictGetItem, true}}, {prim::kPrimDictSetItem, {InferImplDictSetItem, true}}, + {prim::kPrimDictGetKeys, {InferImplDictGetKeys, true}}, + {prim::kPrimDictGetValues, {InferImplDictGetValues, true}}, {prim::kPrimListAppend, {InferImplListAppend, true}}, {prim::kPrimTupleLen, {InferImplTupleLen, true}}, {prim::kPrimListLen, {InferImplListLen, true}}, diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h index d02996c1b2..ebf9b52b11 100644 --- a/mindspore/core/base/core_ops.h +++ b/mindspore/core/base/core_ops.h @@ -279,6 +279,8 @@ inline const PrimitivePtr kPrimListGetItem = std::make_shared("list_g inline const PrimitivePtr kPrimListSetItem = std::make_shared("list_setitem"); inline const PrimitivePtr kPrimDictGetItem = std::make_shared("dict_getitem"); inline const PrimitivePtr kPrimDictSetItem = std::make_shared("dict_setitem"); +inline const PrimitivePtr kPrimDictGetKeys = std::make_shared("dict_getkeys"); +inline const PrimitivePtr kPrimDictGetValues = std::make_shared("dict_getvalues"); inline const PrimitivePtr kPrimListAppend = std::make_shared("list_append"); inline const PrimitivePtr kPrimListLen = std::make_shared("list_len"); diff --git a/mindspore/ops/composite/multitype_ops/setitem_impl.py b/mindspore/ops/composite/multitype_ops/setitem_impl.py index e6c7e28c95..b18bdcf4e8 100644 --- a/mindspore/ops/composite/multitype_ops/setitem_impl.py +++ b/mindspore/ops/composite/multitype_ops/setitem_impl.py @@ -132,6 +132,20 @@ def _dict_setitem_with_number(data, key, value): """ return F.dict_setitem(data, key, value) +@setitem.register("Dictionary", "String", "Tuple") +def _dict_setitem_with_tuple(data, key, value): + """ + Assigns value to dictionary. + + Inputs: + data (dict): Data of type dict. + key (str): Key of the data. + value (Tuple): Value given. + + Outputs: + dict, type is as same as the element type of data. + """ + return F.dict_setitem(data, key, value) @setitem.register("Tensor", "Tensor", "Tensor") def _tensor_setitem_by_tensor_with_tensor(data, index, value_tensor): diff --git a/tests/ut/python/pipeline/parse/test_dictionary.py b/tests/ut/python/pipeline/parse/test_dictionary.py new file mode 100644 index 0000000000..a2e4adfdcb --- /dev/null +++ b/tests/ut/python/pipeline/parse/test_dictionary.py @@ -0,0 +1,81 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" test_dictionary """ +import numpy as np + +from mindspore import Tensor +from mindspore.nn import Cell + + +class Net1(Cell): + def __init__(self): + super().__init__() + + def construct(self, x): + dic = {'x': 0, 'y': 1} + output = [] + for i in dic.keys(): + output.append(i) + for j in dic.values(): + output.append(j) + return output + +class Net2(Cell): + def __init__(self): + super().__init__() + + def construct(self, x): + dic = {'x': x, 'y': 1} + output = [] + for i in dic.keys(): + output.append(i) + for j in dic.values(): + output.append(j) + return output + +class Net3(Cell): + def __init__(self): + super().__init__() + + def construct(self, x): + dic = {'x': 0} + dic['y'] = (0, 1) + output = [] + for i in dic.keys(): + output.append(i) + for j in dic.values(): + output.append(j) + return output + +def test_dict1(): + input_np = np.random.randn(2, 3, 4, 5).astype(np.float32) + input_me = Tensor(input_np) + net = Net1() + out_me = net(input_me) + assert out_me == ('x', 'y', 0, 1) + + +def test_dict2(): + input_np = np.random.randn(2, 3, 4, 5).astype(np.float32) + input_me = Tensor(input_np) + net = Net2() + net(input_me) + +def test_dict3(): + input_np = np.random.randn(2, 3, 4, 5).astype(np.float32) + input_me = Tensor(input_np) + net = Net3() + out_me = net(input_me) + assert out_me == ('x', 'y', 0, (0, 1))