[oneDNN] Tensor copy fix to oneDNN tensors (#29771)

* - Tensor copy fix to oneDNN tensors

* - Fixes after review
revert-31562-mean
Jacek Czaja 5 years ago committed by GitHub
parent a400b76db7
commit 7b33720c90
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -43,20 +43,32 @@ void TensorCopy(const Tensor& src, const platform::Place& dst_place,
dst->Resize(src.dims());
dst->set_layout(src.layout());
#ifdef PADDLE_WITH_MKLDNN
dst->set_format(src.format());
#endif
auto src_place = src.place();
auto src_ptr = src.data<void>();
#ifdef PADDLE_WITH_MKLDNN
dst->set_format(src.format());
// oneDNN tensors due to padding may be of bigger size
// than numel()*size(type())
auto dst_ptr =
src.layout() == DataLayout::kMKLDNN
? dst->mutable_data(dst_place, src.type(), src.memory_size())
: dst->mutable_data(dst_place, src.type());
#else
auto dst_ptr = dst->mutable_data(dst_place, src.type());
#endif
if (src_ptr == dst_ptr && src_place == dst_place) {
VLOG(3) << "Skip copy the same data async from " << src_place << " to "
<< dst_place;
return;
}
#ifdef PADDLE_WITH_MKLDNN
auto size = src.layout() == DataLayout::kMKLDNN
? src.memory_size()
: src.numel() * SizeOfType(src.type());
#else
auto size = src.numel() * SizeOfType(src.type());
#endif
if (platform::is_cpu_place(src_place) && platform::is_cpu_place(dst_place)) {
memory::Copy(BOOST_GET_CONST(platform::CPUPlace, dst_place), dst_ptr,

Loading…
Cancel
Save