copy_to_cpu support uint8 (#28372)

TCChenlong-patch-1
Wilber 4 years ago committed by GitHub
parent 09fd2b2aab
commit 6f0f45f69c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -119,9 +119,12 @@ py::dtype PaddleDTypeToNumpyDType(PaddleDType dtype) {
case PaddleDType::FLOAT32:
dt = py::dtype::of<float>();
break;
case PaddleDType::UINT8:
dt = py::dtype::of<uint8_t>();
break;
default:
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported data type. Now only supports INT32, INT64 and "
"Unsupported data type. Now only supports INT32, INT64, UINT8 and "
"FLOAT32."));
}
@ -187,9 +190,12 @@ py::array ZeroCopyTensorToNumpy(ZeroCopyTensor &tensor) { // NOLINT
case PaddleDType::FLOAT32:
tensor.copy_to_cpu<float>(static_cast<float *>(array.mutable_data()));
break;
case PaddleDType::UINT8:
tensor.copy_to_cpu<uint8_t>(static_cast<uint8_t *>(array.mutable_data()));
break;
default:
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported data type. Now only supports INT32, INT64 and "
"Unsupported data type. Now only supports INT32, INT64, UINT8 and "
"FLOAT32."));
}
return array;

@ -112,4 +112,4 @@ if(LINUX AND NOT WITH_SW)
message(FATAL_ERROR "patchelf not found, please install it.\n"
"For Ubuntu, the command is: apt-get install -y patchelf.")
endif()
endif(LINUX)
endif()

Loading…
Cancel
Save