!12507 support convert ValueDict to py dict

From: @zhangbuxue
Reviewed-by: 
Signed-off-by:
pull/12507/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit e73f43aa85

@ -50,55 +50,68 @@ py::object TensorToPyData(const tensor::TensorPtr &tensor) {
return v[0];
}
py::object ScalarPtrToPyData(const ScalarPtr &value) {
py::int_ int_v;
py::float_ float_v;
py::bool_ bool_v;
TypeId scalar_type = value->type()->type_id();
switch (scalar_type) {
case kNumberTypeUInt8:
MS_LOG(DEBUG) << "uint8";
int_v = value->cast<UInt8ImmPtr>()->value();
return std::move(int_v);
case kNumberTypeUInt16:
MS_LOG(DEBUG) << "uint16";
int_v = value->cast<UInt16ImmPtr>()->value();
return std::move(int_v);
case kNumberTypeUInt32:
MS_LOG(DEBUG) << "uint32";
int_v = value->cast<UInt32ImmPtr>()->value();
return std::move(int_v);
case kNumberTypeUInt64:
MS_LOG(DEBUG) << "uint64";
int_v = value->cast<UInt64ImmPtr>()->value();
return std::move(int_v);
case kNumberTypeInt8:
MS_LOG(DEBUG) << "int8";
int_v = value->cast<Int8ImmPtr>()->value();
return std::move(int_v);
case kNumberTypeInt16:
MS_LOG(DEBUG) << "int16";
int_v = value->cast<Int16ImmPtr>()->value();
return std::move(int_v);
case kNumberTypeInt32:
MS_LOG(DEBUG) << "int32";
int_v = value->cast<Int32ImmPtr>()->value();
return std::move(int_v);
case kNumberTypeInt64:
MS_LOG(DEBUG) << "int64";
int_v = value->cast<Int64ImmPtr>()->value();
return std::move(int_v);
case kNumberTypeFloat32:
MS_LOG(DEBUG) << "float";
float_v = value->cast<FP32ImmPtr>()->value();
return std::move(float_v);
case kNumberTypeFloat64:
MS_LOG(DEBUG) << "double";
float_v = value->cast<FP64ImmPtr>()->value();
return std::move(float_v);
case kNumberTypeBool:
MS_LOG(DEBUG) << "bool";
bool_v = value->cast<BoolImmPtr>()->value();
return std::move(bool_v);
default:
MS_EXCEPTION(TypeError) << "Unsupported scalar converted to py data: " << value->ToString();
}
}
py::object ValuePtrToPyData(const ValuePtr &value) {
if (value == nullptr) {
MS_LOG(EXCEPTION) << "value is null";
}
py::object ret;
if (value->isa<Int8Imm>()) {
MS_LOG(DEBUG) << "int8";
py::int_ v = value->cast<Int8ImmPtr>()->value();
ret = v;
} else if (value->isa<Int16Imm>()) {
MS_LOG(DEBUG) << "int16";
py::int_ v = value->cast<Int16ImmPtr>()->value();
ret = v;
} else if (value->isa<Int32Imm>()) {
MS_LOG(DEBUG) << "int32";
py::int_ v = value->cast<Int32ImmPtr>()->value();
ret = v;
} else if (value->isa<Int64Imm>()) {
MS_LOG(DEBUG) << "int64";
py::int_ v = value->cast<Int64ImmPtr>()->value();
ret = v;
} else if (value->isa<UInt8Imm>()) {
MS_LOG(DEBUG) << "uint8";
py::int_ v = value->cast<UInt8ImmPtr>()->value();
ret = v;
} else if (value->isa<UInt16Imm>()) {
MS_LOG(DEBUG) << "uint16";
py::int_ v = value->cast<UInt16ImmPtr>()->value();
ret = v;
} else if (value->isa<UInt32Imm>()) {
MS_LOG(DEBUG) << "uint32";
py::int_ v = value->cast<UInt32ImmPtr>()->value();
ret = v;
} else if (value->isa<UInt64Imm>()) {
MS_LOG(DEBUG) << "uint64";
py::int_ v = value->cast<UInt64ImmPtr>()->value();
ret = v;
} else if (value->isa<BoolImm>()) {
MS_LOG(DEBUG) << "bool";
py::bool_ v = value->cast<BoolImmPtr>()->value();
ret = v;
} else if (value->isa<FP64Imm>()) {
MS_LOG(DEBUG) << "double";
py::float_ v = value->cast<FP64ImmPtr>()->value();
ret = v;
} else if (value->isa<FP32Imm>()) {
MS_LOG(DEBUG) << "float";
py::float_ v = value->cast<FP32ImmPtr>()->value();
ret = v;
if (value->isa<Scalar>()) {
ret = ScalarPtrToPyData(value->cast<ScalarPtr>());
} else if (value->isa<StringImm>()) {
MS_LOG(DEBUG) << "String";
py::str v = value->cast<StringImmPtr>()->value();
@ -117,28 +130,27 @@ py::object ValuePtrToPyData(const ValuePtr &value) {
py::tuple v(1);
v[0] = value->cast<RefKeyPtr>();
ret = v[0];
} else if (value->isa<ValueTuple>()) {
MS_LOG(DEBUG) << "tuple";
auto value_tuple = value->cast<ValueTuplePtr>()->value();
py::tuple rets(value_tuple.size());
} else if (value->isa<ValueSequeue>()) {
MS_LOG(DEBUG) << "tuple or list";
auto value_sequeue = value->cast<ValueSequeuePtr>()->value();
py::tuple ret_sequeue(value_sequeue.size());
size_t i = 0;
for (auto &v : value_tuple) {
rets[i] = ValuePtrToPyData(v);
i++;
for (size_t i = 0; i < value_sequeue.size(); i++) {
ret_sequeue[i] = ValuePtrToPyData(value_sequeue[i]);
}
ret = rets;
} else if (value->isa<ValueList>()) {
MS_LOG(DEBUG) << "list";
auto value_list = value->cast<ValueListPtr>()->value();
py::list rets(value_list.size());
size_t i = 0;
for (auto &v : value_list) {
rets[i] = ValuePtrToPyData(v);
i++;
if (value->isa<ValueTuple>()) {
ret = ret_sequeue;
} else {
ret = ret_sequeue.cast<py::list>();
}
ret = rets;
} else if (value->isa<ValueDictionary>()) {
MS_LOG(DEBUG) << "dict";
auto value_list = value->cast<ValueDictionaryPtr>()->value();
py::dict ret_dict;
for (const auto &v : value_list) {
ret_dict[py::str(v.first)] = ValuePtrToPyData(v.second);
}
ret = ret_dict;
} else if (value->isa<Ellipsis>()) {
ret = py::ellipsis();
} else if (value->isa<ValueSlice>()) {
@ -152,15 +164,9 @@ py::object ValuePtrToPyData(const ValuePtr &value) {
py::tuple v(1);
v[0] = value->cast<TypePtr>();
ret = v[0];
} else if (value->isa<AnyValue>()) {
ret = py::none();
} else if (value->isa<None>()) {
ret = py::none();
} else if (value->isa<FuncGraph>()) {
} else if (value->isa<AnyValue>() || value->isa<None>() || value->isa<Monad>() || value->isa<FuncGraph>()) {
// FuncGraph is not used in the backend, return None
ret = py::none();
} else if (value->isa<Monad>()) {
ret = py::none();
} else {
MS_LOG(EXCEPTION) << "Unsupported convert value: " << value->ToString() << " to a PyData.";
}

@ -184,7 +184,7 @@ class ValueDictionary : public Value {
buffer << ") values:(";
for (const auto &value : values) {
MS_EXCEPTION_IF_NULL(value);
buffer << value->DumpText() << ", ";
buffer << value->ToString() << ", ";
}
buffer << ")";
return buffer.str();

@ -15,9 +15,11 @@
""" test_dictionary """
import numpy as np
from mindspore import Tensor
from mindspore import Tensor, context
from mindspore.nn import Cell
context.set_context(mode=context.GRAPH_MODE)
class Net1(Cell):
def __init__(self):
@ -32,6 +34,7 @@ class Net1(Cell):
output.append(j)
return output
class Net2(Cell):
def __init__(self):
super().__init__()
@ -45,6 +48,7 @@ class Net2(Cell):
output.append(j)
return output
class Net3(Cell):
def __init__(self):
super().__init__()
@ -59,6 +63,7 @@ class Net3(Cell):
output.append(j)
return output
def test_dict1():
input_np = np.random.randn(2, 3, 4, 5).astype(np.float32)
input_me = Tensor(input_np)
@ -73,9 +78,25 @@ def test_dict2():
net = Net2()
net(input_me)
def test_dict3():
input_np = np.random.randn(2, 3, 4, 5).astype(np.float32)
input_me = Tensor(input_np)
net = Net3()
out_me = net(input_me)
assert out_me == ('x', 'y', 0, (0, 1))
def test_dict4():
class Net(Cell):
def __init__(self):
super().__init__()
def construct(self, tuple_x):
output = tuple_x + tuple_x
return output
x = (1, Tensor([1, 2, 3]), {"a": Tensor([1, 2, 3]), "b": 1})
net = Net()
out_me = net(x)
assert out_me == x + x

Loading…
Cancel
Save