Fix bug of convert python object

pull/5760/head
fary86 5 years ago
parent d76ac7c6e8
commit d2d8a911ff

@ -284,17 +284,13 @@ bool ConvertOtherObj(py::object obj, ValuePtr *const data) {
return false;
}
bool ConvertIntegerWithType(const int &obj, ValuePtr *const data, TypePtr dtype = nullptr) {
if (dtype == nullptr) {
*data = std::make_shared<Int32Imm>(obj);
return true;
}
template <typename T>
bool ConvertNumberWithType(const T &obj, ValuePtr *const data, TypePtr dtype) {
auto int_dypte = dyn_cast<Int>(dtype);
if (int_dypte != nullptr) {
switch (int_dypte->nbits()) {
case 8:
*data = std::make_shared<Int8Imm>(static_cast<int8_t>(obj));
*data = std::make_shared<Int8Imm>(obj);
break;
case 16:
*data = std::make_shared<Int16Imm>(obj);
@ -312,7 +308,7 @@ bool ConvertIntegerWithType(const int &obj, ValuePtr *const data, TypePtr dtype
}
auto uint_dypte = dyn_cast<UInt>(dtype);
if (int_dypte != nullptr) {
if (uint_dypte != nullptr) {
switch (uint_dypte->nbits()) {
case 8:
*data = std::make_shared<UInt8Imm>(obj);
@ -350,28 +346,22 @@ bool ConvertIntegerWithType(const int &obj, ValuePtr *const data, TypePtr dtype
return false;
}
bool ConvertFloatWithType(const float &obj, ValuePtr *const data, TypePtr dtype = nullptr) {
bool ConvertIntegerWithType(const int &obj, ValuePtr *const data, TypePtr dtype = nullptr) {
if (dtype == nullptr) {
*data = std::make_shared<FP32Imm>(obj);
*data = std::make_shared<Int32Imm>(obj);
return true;
}
auto float_dypte = dyn_cast<Float>(dtype);
if (float_dypte == nullptr) {
return false;
}
return ConvertNumberWithType<int>(obj, data, dtype);
}
switch (float_dypte->nbits()) {
case 32:
*data = std::make_shared<FP32Imm>(obj);
break;
case 64:
*data = std::make_shared<FP64Imm>(obj);
break;
default:
*data = std::make_shared<FP32Imm>(obj);
bool ConvertFloatWithType(const float &obj, ValuePtr *const data, TypePtr dtype = nullptr) {
if (dtype == nullptr) {
*data = std::make_shared<FP32Imm>(obj);
return true;
}
return true;
return ConvertNumberWithType<float>(obj, data, dtype);
}
} // namespace

@ -80,7 +80,6 @@ REGISTER_PYBIND_DEFINE(MsContextPy, ([](const py::module *m) {
.value("enable_gpu_summary", MsCtxParam::MS_CTX_ENABLE_GPU_SUMMARY)
.value("enable_graph_kernel", MsCtxParam::MS_CTX_ENABLE_GRAPH_KERNEL)
.value("enable_hccl", MsCtxParam::MS_CTX_ENABLE_HCCL)
.value("enable_loop_sink", MsCtxParam::MS_CTX_ENABLE_LOOP_SINK)
.value("enable_mem_reuse", MsCtxParam::MS_CTX_ENABLE_MEM_REUSE)
.value("enable_pynative_hook", MsCtxParam::MS_CTX_ENABLE_PYNATIVE_HOOK)
.value("enable_pynative_infer", MsCtxParam::MS_CTX_ENABLE_PYNATIVE_INFER)

Loading…
Cancel
Save