fix bug of tensor copy of CUDAPinnedPlace (#27966)

swt-req
Zhou Wei 4 years ago committed by GitHub
parent f58434ef2c
commit 2ac6c6c3af
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -84,6 +84,12 @@ void TensorCopy(const Tensor& src, const platform::Place& dst_place,
}
#endif
#ifdef PADDLE_WITH_CUDA
else if (platform::is_cuda_pinned_place(src_place) && // NOLINT
platform::is_cuda_pinned_place(dst_place)) {
memory::Copy(BOOST_GET_CONST(platform::CUDAPinnedPlace, dst_place), dst_ptr,
BOOST_GET_CONST(platform::CUDAPinnedPlace, src_place), src_ptr,
size);
}
else if (platform::is_cuda_pinned_place(src_place) && // NOLINT
platform::is_cpu_place(dst_place)) {
memory::Copy(BOOST_GET_CONST(platform::CPUPlace, dst_place), dst_ptr,
@ -285,6 +291,12 @@ void TensorCopySync(const Tensor& src, const platform::Place& dst_place,
}
#endif
#ifdef PADDLE_WITH_CUDA
else if (platform::is_cuda_pinned_place(src_place) && // NOLINT
platform::is_cuda_pinned_place(dst_place)) {
memory::Copy(BOOST_GET_CONST(platform::CUDAPinnedPlace, dst_place), dst_ptr,
BOOST_GET_CONST(platform::CUDAPinnedPlace, src_place), src_ptr,
size);
}
else if (platform::is_cuda_pinned_place(src_place) && // NOLINT
platform::is_cpu_place(dst_place)) {
memory::Copy(BOOST_GET_CONST(platform::CPUPlace, dst_place), dst_ptr,

@ -141,7 +141,7 @@ class TestVarBase(unittest.TestCase):
_test_place(core.CPUPlace())
if core.is_compiled_with_cuda():
#_test_place(core.CUDAPinnedPlace())
_test_place(core.CUDAPinnedPlace())
_test_place(core.CUDAPlace(0))
def test_to_variable(self):

Loading…
Cancel
Save