|
|
|
@ -50,55 +50,68 @@ py::object TensorToPyData(const tensor::TensorPtr &tensor) {
|
|
|
|
|
return v[0];
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
py::object ScalarPtrToPyData(const ScalarPtr &value) {
|
|
|
|
|
py::int_ int_v;
|
|
|
|
|
py::float_ float_v;
|
|
|
|
|
py::bool_ bool_v;
|
|
|
|
|
TypeId scalar_type = value->type()->type_id();
|
|
|
|
|
switch (scalar_type) {
|
|
|
|
|
case kNumberTypeUInt8:
|
|
|
|
|
MS_LOG(DEBUG) << "uint8";
|
|
|
|
|
int_v = value->cast<UInt8ImmPtr>()->value();
|
|
|
|
|
return std::move(int_v);
|
|
|
|
|
case kNumberTypeUInt16:
|
|
|
|
|
MS_LOG(DEBUG) << "uint16";
|
|
|
|
|
int_v = value->cast<UInt16ImmPtr>()->value();
|
|
|
|
|
return std::move(int_v);
|
|
|
|
|
case kNumberTypeUInt32:
|
|
|
|
|
MS_LOG(DEBUG) << "uint32";
|
|
|
|
|
int_v = value->cast<UInt32ImmPtr>()->value();
|
|
|
|
|
return std::move(int_v);
|
|
|
|
|
case kNumberTypeUInt64:
|
|
|
|
|
MS_LOG(DEBUG) << "uint64";
|
|
|
|
|
int_v = value->cast<UInt64ImmPtr>()->value();
|
|
|
|
|
return std::move(int_v);
|
|
|
|
|
case kNumberTypeInt8:
|
|
|
|
|
MS_LOG(DEBUG) << "int8";
|
|
|
|
|
int_v = value->cast<Int8ImmPtr>()->value();
|
|
|
|
|
return std::move(int_v);
|
|
|
|
|
case kNumberTypeInt16:
|
|
|
|
|
MS_LOG(DEBUG) << "int16";
|
|
|
|
|
int_v = value->cast<Int16ImmPtr>()->value();
|
|
|
|
|
return std::move(int_v);
|
|
|
|
|
case kNumberTypeInt32:
|
|
|
|
|
MS_LOG(DEBUG) << "int32";
|
|
|
|
|
int_v = value->cast<Int32ImmPtr>()->value();
|
|
|
|
|
return std::move(int_v);
|
|
|
|
|
case kNumberTypeInt64:
|
|
|
|
|
MS_LOG(DEBUG) << "int64";
|
|
|
|
|
int_v = value->cast<Int64ImmPtr>()->value();
|
|
|
|
|
return std::move(int_v);
|
|
|
|
|
case kNumberTypeFloat32:
|
|
|
|
|
MS_LOG(DEBUG) << "float";
|
|
|
|
|
float_v = value->cast<FP32ImmPtr>()->value();
|
|
|
|
|
return std::move(float_v);
|
|
|
|
|
case kNumberTypeFloat64:
|
|
|
|
|
MS_LOG(DEBUG) << "double";
|
|
|
|
|
float_v = value->cast<FP64ImmPtr>()->value();
|
|
|
|
|
return std::move(float_v);
|
|
|
|
|
case kNumberTypeBool:
|
|
|
|
|
MS_LOG(DEBUG) << "bool";
|
|
|
|
|
bool_v = value->cast<BoolImmPtr>()->value();
|
|
|
|
|
return std::move(bool_v);
|
|
|
|
|
default:
|
|
|
|
|
MS_EXCEPTION(TypeError) << "Unsupported scalar converted to py data: " << value->ToString();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
py::object ValuePtrToPyData(const ValuePtr &value) {
|
|
|
|
|
if (value == nullptr) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "value is null";
|
|
|
|
|
}
|
|
|
|
|
py::object ret;
|
|
|
|
|
if (value->isa<Int8Imm>()) {
|
|
|
|
|
MS_LOG(DEBUG) << "int8";
|
|
|
|
|
py::int_ v = value->cast<Int8ImmPtr>()->value();
|
|
|
|
|
ret = v;
|
|
|
|
|
} else if (value->isa<Int16Imm>()) {
|
|
|
|
|
MS_LOG(DEBUG) << "int16";
|
|
|
|
|
py::int_ v = value->cast<Int16ImmPtr>()->value();
|
|
|
|
|
ret = v;
|
|
|
|
|
} else if (value->isa<Int32Imm>()) {
|
|
|
|
|
MS_LOG(DEBUG) << "int32";
|
|
|
|
|
py::int_ v = value->cast<Int32ImmPtr>()->value();
|
|
|
|
|
ret = v;
|
|
|
|
|
} else if (value->isa<Int64Imm>()) {
|
|
|
|
|
MS_LOG(DEBUG) << "int64";
|
|
|
|
|
py::int_ v = value->cast<Int64ImmPtr>()->value();
|
|
|
|
|
ret = v;
|
|
|
|
|
} else if (value->isa<UInt8Imm>()) {
|
|
|
|
|
MS_LOG(DEBUG) << "uint8";
|
|
|
|
|
py::int_ v = value->cast<UInt8ImmPtr>()->value();
|
|
|
|
|
ret = v;
|
|
|
|
|
} else if (value->isa<UInt16Imm>()) {
|
|
|
|
|
MS_LOG(DEBUG) << "uint16";
|
|
|
|
|
py::int_ v = value->cast<UInt16ImmPtr>()->value();
|
|
|
|
|
ret = v;
|
|
|
|
|
} else if (value->isa<UInt32Imm>()) {
|
|
|
|
|
MS_LOG(DEBUG) << "uint32";
|
|
|
|
|
py::int_ v = value->cast<UInt32ImmPtr>()->value();
|
|
|
|
|
ret = v;
|
|
|
|
|
} else if (value->isa<UInt64Imm>()) {
|
|
|
|
|
MS_LOG(DEBUG) << "uint64";
|
|
|
|
|
py::int_ v = value->cast<UInt64ImmPtr>()->value();
|
|
|
|
|
ret = v;
|
|
|
|
|
} else if (value->isa<BoolImm>()) {
|
|
|
|
|
MS_LOG(DEBUG) << "bool";
|
|
|
|
|
py::bool_ v = value->cast<BoolImmPtr>()->value();
|
|
|
|
|
ret = v;
|
|
|
|
|
} else if (value->isa<FP64Imm>()) {
|
|
|
|
|
MS_LOG(DEBUG) << "double";
|
|
|
|
|
py::float_ v = value->cast<FP64ImmPtr>()->value();
|
|
|
|
|
ret = v;
|
|
|
|
|
} else if (value->isa<FP32Imm>()) {
|
|
|
|
|
MS_LOG(DEBUG) << "float";
|
|
|
|
|
py::float_ v = value->cast<FP32ImmPtr>()->value();
|
|
|
|
|
ret = v;
|
|
|
|
|
if (value->isa<Scalar>()) {
|
|
|
|
|
ret = ScalarPtrToPyData(value->cast<ScalarPtr>());
|
|
|
|
|
} else if (value->isa<StringImm>()) {
|
|
|
|
|
MS_LOG(DEBUG) << "String";
|
|
|
|
|
py::str v = value->cast<StringImmPtr>()->value();
|
|
|
|
@ -117,28 +130,27 @@ py::object ValuePtrToPyData(const ValuePtr &value) {
|
|
|
|
|
py::tuple v(1);
|
|
|
|
|
v[0] = value->cast<RefKeyPtr>();
|
|
|
|
|
ret = v[0];
|
|
|
|
|
} else if (value->isa<ValueTuple>()) {
|
|
|
|
|
MS_LOG(DEBUG) << "tuple";
|
|
|
|
|
auto value_tuple = value->cast<ValueTuplePtr>()->value();
|
|
|
|
|
py::tuple rets(value_tuple.size());
|
|
|
|
|
} else if (value->isa<ValueSequeue>()) {
|
|
|
|
|
MS_LOG(DEBUG) << "tuple or list";
|
|
|
|
|
auto value_sequeue = value->cast<ValueSequeuePtr>()->value();
|
|
|
|
|
py::tuple ret_sequeue(value_sequeue.size());
|
|
|
|
|
|
|
|
|
|
size_t i = 0;
|
|
|
|
|
for (auto &v : value_tuple) {
|
|
|
|
|
rets[i] = ValuePtrToPyData(v);
|
|
|
|
|
i++;
|
|
|
|
|
for (size_t i = 0; i < value_sequeue.size(); i++) {
|
|
|
|
|
ret_sequeue[i] = ValuePtrToPyData(value_sequeue[i]);
|
|
|
|
|
}
|
|
|
|
|
ret = rets;
|
|
|
|
|
} else if (value->isa<ValueList>()) {
|
|
|
|
|
MS_LOG(DEBUG) << "list";
|
|
|
|
|
auto value_list = value->cast<ValueListPtr>()->value();
|
|
|
|
|
py::list rets(value_list.size());
|
|
|
|
|
|
|
|
|
|
size_t i = 0;
|
|
|
|
|
for (auto &v : value_list) {
|
|
|
|
|
rets[i] = ValuePtrToPyData(v);
|
|
|
|
|
i++;
|
|
|
|
|
if (value->isa<ValueTuple>()) {
|
|
|
|
|
ret = ret_sequeue;
|
|
|
|
|
} else {
|
|
|
|
|
ret = ret_sequeue.cast<py::list>();
|
|
|
|
|
}
|
|
|
|
|
ret = rets;
|
|
|
|
|
} else if (value->isa<ValueDictionary>()) {
|
|
|
|
|
MS_LOG(DEBUG) << "dict";
|
|
|
|
|
auto value_list = value->cast<ValueDictionaryPtr>()->value();
|
|
|
|
|
py::dict ret_dict;
|
|
|
|
|
for (const auto &v : value_list) {
|
|
|
|
|
ret_dict[py::str(v.first)] = ValuePtrToPyData(v.second);
|
|
|
|
|
}
|
|
|
|
|
ret = ret_dict;
|
|
|
|
|
} else if (value->isa<Ellipsis>()) {
|
|
|
|
|
ret = py::ellipsis();
|
|
|
|
|
} else if (value->isa<ValueSlice>()) {
|
|
|
|
@ -152,15 +164,9 @@ py::object ValuePtrToPyData(const ValuePtr &value) {
|
|
|
|
|
py::tuple v(1);
|
|
|
|
|
v[0] = value->cast<TypePtr>();
|
|
|
|
|
ret = v[0];
|
|
|
|
|
} else if (value->isa<AnyValue>()) {
|
|
|
|
|
ret = py::none();
|
|
|
|
|
} else if (value->isa<None>()) {
|
|
|
|
|
ret = py::none();
|
|
|
|
|
} else if (value->isa<FuncGraph>()) {
|
|
|
|
|
} else if (value->isa<AnyValue>() || value->isa<None>() || value->isa<Monad>() || value->isa<FuncGraph>()) {
|
|
|
|
|
// FuncGraph is not used in the backend, return None
|
|
|
|
|
ret = py::none();
|
|
|
|
|
} else if (value->isa<Monad>()) {
|
|
|
|
|
ret = py::none();
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Unsupported convert value: " << value->ToString() << " to a PyData.";
|
|
|
|
|
}
|
|
|
|
|