Fix deserialize bug

guochaorong-patch-1
yuyang18 7 years ago
parent 97b774dfe5
commit 47ad8d4909
No known key found for this signature in database
GPG Key ID: 6DFF29878217BE5F

@ -15,6 +15,7 @@
#include <algorithm>
#include <limits>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
namespace paddle {
namespace framework {
@ -261,7 +262,8 @@ void TensorToStream(std::ostream& os, const Tensor& tensor,
os.write(out.data(), size);
}
{ // the 3rd field, tensor data
uint64_t size = tensor.memory_size();
uint64_t size = tensor.numel() * framework::SizeOfType(tensor.type());
auto* data_ptr = tensor.data<void>();
PADDLE_ENFORCE(size < std::numeric_limits<std::streamsize>::max(),
"Index overflow when writing tensor");
@ -331,6 +333,9 @@ void TensorFromStream(std::istream& is, Tensor* tensor,
tensor->Resize(framework::make_ddim(dims));
void* buf;
auto ctx = platform::CPUDeviceContext();
size_t size =
tensor->numel() *
framework::SizeOfType(framework::ToTypeIndex(desc.data_type()));
if (platform::is_gpu_place(dev_ctx.GetPlace())) {
#ifdef PADDLE_WITH_CUDA
Tensor cpu_tensor;
@ -338,7 +343,7 @@ void TensorFromStream(std::istream& is, Tensor* tensor,
framework::VisitDataType(
desc.data_type(),
DeserializedDataFunctor(&buf, &cpu_tensor, ctx.GetPlace()));
is.read(static_cast<char*>(buf), cpu_tensor.memory_size());
is.read(static_cast<char*>(buf), size);
auto dst_place = dev_ctx.GetPlace();
framework::TensorCopy(cpu_tensor, dst_place, dev_ctx, tensor);
#else
@ -348,7 +353,7 @@ void TensorFromStream(std::istream& is, Tensor* tensor,
framework::VisitDataType(
desc.data_type(),
DeserializedDataFunctor(&buf, tensor, ctx.GetPlace()));
is.read(static_cast<char*>(buf), tensor->memory_size());
is.read(static_cast<char*>(buf), size);
}
}
}

Loading…
Cancel
Save