|
|
|
@ -46,19 +46,6 @@ class LoadOp : public framework::OperatorBase {
|
|
|
|
|
auto *tensor = out_var->GetMutable<framework::LoDTensor>();
|
|
|
|
|
|
|
|
|
|
DeserializeFromStream(fin, tensor, *dev_ctx);
|
|
|
|
|
|
|
|
|
|
if (platform::is_gpu_place(place)) {
|
|
|
|
|
// copy CPU to GPU
|
|
|
|
|
framework::LoDTensor cpu_tensor;
|
|
|
|
|
cpu_tensor.ShareDataWith(*tensor);
|
|
|
|
|
cpu_tensor.set_lod(tensor->lod());
|
|
|
|
|
|
|
|
|
|
// reset tensor
|
|
|
|
|
out_var->Clear();
|
|
|
|
|
tensor = out_var->GetMutable<framework::LoDTensor>();
|
|
|
|
|
tensor->set_lod(cpu_tensor.lod());
|
|
|
|
|
TensorCopy(cpu_tensor, place, *dev_ctx, tensor);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|