From 47dd17a325fb1b7060895a3ec7eca9e3c47bce6a Mon Sep 17 00:00:00 2001 From: buxue Date: Mon, 22 Feb 2021 16:38:35 +0800 Subject: [PATCH] support convert ValueDict to py dict --- mindspore/ccsrc/utils/convert_utils_py.cc | 146 +++++++++--------- mindspore/core/ir/value.h | 2 +- .../python/pipeline/parse/test_dictionary.py | 23 ++- 3 files changed, 99 insertions(+), 72 deletions(-) diff --git a/mindspore/ccsrc/utils/convert_utils_py.cc b/mindspore/ccsrc/utils/convert_utils_py.cc index e3816d88be..fd1c751eb6 100644 --- a/mindspore/ccsrc/utils/convert_utils_py.cc +++ b/mindspore/ccsrc/utils/convert_utils_py.cc @@ -50,55 +50,68 @@ py::object TensorToPyData(const tensor::TensorPtr &tensor) { return v[0]; } +py::object ScalarPtrToPyData(const ScalarPtr &value) { + py::int_ int_v; + py::float_ float_v; + py::bool_ bool_v; + TypeId scalar_type = value->type()->type_id(); + switch (scalar_type) { + case kNumberTypeUInt8: + MS_LOG(DEBUG) << "uint8"; + int_v = value->cast()->value(); + return std::move(int_v); + case kNumberTypeUInt16: + MS_LOG(DEBUG) << "uint16"; + int_v = value->cast()->value(); + return std::move(int_v); + case kNumberTypeUInt32: + MS_LOG(DEBUG) << "uint32"; + int_v = value->cast()->value(); + return std::move(int_v); + case kNumberTypeUInt64: + MS_LOG(DEBUG) << "uint64"; + int_v = value->cast()->value(); + return std::move(int_v); + case kNumberTypeInt8: + MS_LOG(DEBUG) << "int8"; + int_v = value->cast()->value(); + return std::move(int_v); + case kNumberTypeInt16: + MS_LOG(DEBUG) << "int16"; + int_v = value->cast()->value(); + return std::move(int_v); + case kNumberTypeInt32: + MS_LOG(DEBUG) << "int32"; + int_v = value->cast()->value(); + return std::move(int_v); + case kNumberTypeInt64: + MS_LOG(DEBUG) << "int64"; + int_v = value->cast()->value(); + return std::move(int_v); + case kNumberTypeFloat32: + MS_LOG(DEBUG) << "float"; + float_v = value->cast()->value(); + return std::move(float_v); + case kNumberTypeFloat64: + MS_LOG(DEBUG) << "double"; + float_v = value->cast()->value(); + return std::move(float_v); + case kNumberTypeBool: + MS_LOG(DEBUG) << "bool"; + bool_v = value->cast()->value(); + return std::move(bool_v); + default: + MS_EXCEPTION(TypeError) << "Unsupported scalar converted to py data: " << value->ToString(); + } +} + py::object ValuePtrToPyData(const ValuePtr &value) { if (value == nullptr) { MS_LOG(EXCEPTION) << "value is null"; } py::object ret; - if (value->isa()) { - MS_LOG(DEBUG) << "int8"; - py::int_ v = value->cast()->value(); - ret = v; - } else if (value->isa()) { - MS_LOG(DEBUG) << "int16"; - py::int_ v = value->cast()->value(); - ret = v; - } else if (value->isa()) { - MS_LOG(DEBUG) << "int32"; - py::int_ v = value->cast()->value(); - ret = v; - } else if (value->isa()) { - MS_LOG(DEBUG) << "int64"; - py::int_ v = value->cast()->value(); - ret = v; - } else if (value->isa()) { - MS_LOG(DEBUG) << "uint8"; - py::int_ v = value->cast()->value(); - ret = v; - } else if (value->isa()) { - MS_LOG(DEBUG) << "uint16"; - py::int_ v = value->cast()->value(); - ret = v; - } else if (value->isa()) { - MS_LOG(DEBUG) << "uint32"; - py::int_ v = value->cast()->value(); - ret = v; - } else if (value->isa()) { - MS_LOG(DEBUG) << "uint64"; - py::int_ v = value->cast()->value(); - ret = v; - } else if (value->isa()) { - MS_LOG(DEBUG) << "bool"; - py::bool_ v = value->cast()->value(); - ret = v; - } else if (value->isa()) { - MS_LOG(DEBUG) << "double"; - py::float_ v = value->cast()->value(); - ret = v; - } else if (value->isa()) { - MS_LOG(DEBUG) << "float"; - py::float_ v = value->cast()->value(); - ret = v; + if (value->isa()) { + ret = ScalarPtrToPyData(value->cast()); } else if (value->isa()) { MS_LOG(DEBUG) << "String"; py::str v = value->cast()->value(); @@ -117,28 +130,27 @@ py::object ValuePtrToPyData(const ValuePtr &value) { py::tuple v(1); v[0] = value->cast(); ret = v[0]; - } else if (value->isa()) { - MS_LOG(DEBUG) << "tuple"; - auto value_tuple = value->cast()->value(); - py::tuple rets(value_tuple.size()); + } else if (value->isa()) { + MS_LOG(DEBUG) << "tuple or list"; + auto value_sequeue = value->cast()->value(); + py::tuple ret_sequeue(value_sequeue.size()); - size_t i = 0; - for (auto &v : value_tuple) { - rets[i] = ValuePtrToPyData(v); - i++; + for (size_t i = 0; i < value_sequeue.size(); i++) { + ret_sequeue[i] = ValuePtrToPyData(value_sequeue[i]); } - ret = rets; - } else if (value->isa()) { - MS_LOG(DEBUG) << "list"; - auto value_list = value->cast()->value(); - py::list rets(value_list.size()); - - size_t i = 0; - for (auto &v : value_list) { - rets[i] = ValuePtrToPyData(v); - i++; + if (value->isa()) { + ret = ret_sequeue; + } else { + ret = ret_sequeue.cast(); } - ret = rets; + } else if (value->isa()) { + MS_LOG(DEBUG) << "dict"; + auto value_list = value->cast()->value(); + py::dict ret_dict; + for (const auto &v : value_list) { + ret_dict[py::str(v.first)] = ValuePtrToPyData(v.second); + } + ret = ret_dict; } else if (value->isa()) { ret = py::ellipsis(); } else if (value->isa()) { @@ -152,15 +164,9 @@ py::object ValuePtrToPyData(const ValuePtr &value) { py::tuple v(1); v[0] = value->cast(); ret = v[0]; - } else if (value->isa()) { - ret = py::none(); - } else if (value->isa()) { - ret = py::none(); - } else if (value->isa()) { + } else if (value->isa() || value->isa() || value->isa() || value->isa()) { // FuncGraph is not used in the backend, return None ret = py::none(); - } else if (value->isa()) { - ret = py::none(); } else { MS_LOG(EXCEPTION) << "Unsupported convert value: " << value->ToString() << " to a PyData."; } diff --git a/mindspore/core/ir/value.h b/mindspore/core/ir/value.h index a3e43d2bc8..c2db08c7a0 100644 --- a/mindspore/core/ir/value.h +++ b/mindspore/core/ir/value.h @@ -184,7 +184,7 @@ class ValueDictionary : public Value { buffer << ") values:("; for (const auto &value : values) { MS_EXCEPTION_IF_NULL(value); - buffer << value->DumpText() << ", "; + buffer << value->ToString() << ", "; } buffer << ")"; return buffer.str(); diff --git a/tests/ut/python/pipeline/parse/test_dictionary.py b/tests/ut/python/pipeline/parse/test_dictionary.py index a2e4adfdcb..977dd83154 100644 --- a/tests/ut/python/pipeline/parse/test_dictionary.py +++ b/tests/ut/python/pipeline/parse/test_dictionary.py @@ -15,9 +15,11 @@ """ test_dictionary """ import numpy as np -from mindspore import Tensor +from mindspore import Tensor, context from mindspore.nn import Cell +context.set_context(mode=context.GRAPH_MODE) + class Net1(Cell): def __init__(self): @@ -32,6 +34,7 @@ class Net1(Cell): output.append(j) return output + class Net2(Cell): def __init__(self): super().__init__() @@ -45,6 +48,7 @@ class Net2(Cell): output.append(j) return output + class Net3(Cell): def __init__(self): super().__init__() @@ -59,6 +63,7 @@ class Net3(Cell): output.append(j) return output + def test_dict1(): input_np = np.random.randn(2, 3, 4, 5).astype(np.float32) input_me = Tensor(input_np) @@ -73,9 +78,25 @@ def test_dict2(): net = Net2() net(input_me) + def test_dict3(): input_np = np.random.randn(2, 3, 4, 5).astype(np.float32) input_me = Tensor(input_np) net = Net3() out_me = net(input_me) assert out_me == ('x', 'y', 0, (0, 1)) + + +def test_dict4(): + class Net(Cell): + def __init__(self): + super().__init__() + + def construct(self, tuple_x): + output = tuple_x + tuple_x + return output + + x = (1, Tensor([1, 2, 3]), {"a": Tensor([1, 2, 3]), "b": 1}) + net = Net() + out_me = net(x) + assert out_me == x + x