|
|
|
@ -119,9 +119,12 @@ py::dtype PaddleDTypeToNumpyDType(PaddleDType dtype) {
|
|
|
|
|
case PaddleDType::FLOAT32:
|
|
|
|
|
dt = py::dtype::of<float>();
|
|
|
|
|
break;
|
|
|
|
|
case PaddleDType::UINT8:
|
|
|
|
|
dt = py::dtype::of<uint8_t>();
|
|
|
|
|
break;
|
|
|
|
|
default:
|
|
|
|
|
PADDLE_THROW(platform::errors::Unimplemented(
|
|
|
|
|
"Unsupported data type. Now only supports INT32, INT64 and "
|
|
|
|
|
"Unsupported data type. Now only supports INT32, INT64, UINT8 and "
|
|
|
|
|
"FLOAT32."));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -187,9 +190,12 @@ py::array ZeroCopyTensorToNumpy(ZeroCopyTensor &tensor) { // NOLINT
|
|
|
|
|
case PaddleDType::FLOAT32:
|
|
|
|
|
tensor.copy_to_cpu<float>(static_cast<float *>(array.mutable_data()));
|
|
|
|
|
break;
|
|
|
|
|
case PaddleDType::UINT8:
|
|
|
|
|
tensor.copy_to_cpu<uint8_t>(static_cast<uint8_t *>(array.mutable_data()));
|
|
|
|
|
break;
|
|
|
|
|
default:
|
|
|
|
|
PADDLE_THROW(platform::errors::Unimplemented(
|
|
|
|
|
"Unsupported data type. Now only supports INT32, INT64 and "
|
|
|
|
|
"Unsupported data type. Now only supports INT32, INT64, UINT8 and "
|
|
|
|
|
"FLOAT32."));
|
|
|
|
|
}
|
|
|
|
|
return array;
|
|
|
|
|