|
|
|
@ -1,3 +1,4 @@
|
|
|
|
|
|
|
|
|
|
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
|
|
|
|
|
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
@ -45,22 +46,19 @@ class LoadOp : public framework::OperatorBase {
|
|
|
|
|
out_var_name);
|
|
|
|
|
|
|
|
|
|
if (out_var->IsType<framework::LoDTensor>()) {
|
|
|
|
|
SaveLodTensor(filename, place, out_var);
|
|
|
|
|
LoadLodTensor(filename, place, out_var);
|
|
|
|
|
} else if (out_var->IsType<framework::SelectedRows>()) {
|
|
|
|
|
SaveSelectedRows(filename, scope, place, out_var);
|
|
|
|
|
LoadSelectedRows(filename, scope, place, out_var);
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
false,
|
|
|
|
|
"Load only support LoDTensor and SelectedRows, %s has wrong type",
|
|
|
|
|
iname);
|
|
|
|
|
out_var_name);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void LoadLodTensor(const std::string &filename, const platform::Place &place,
|
|
|
|
|
framework::Variable *var) const {
|
|
|
|
|
auto &tensor = var->Get<framework::LoDTensor>();
|
|
|
|
|
|
|
|
|
|
// get device context from pool
|
|
|
|
|
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
|
|
|
|
|
auto &dev_ctx = *pool.Get(place);
|
|
|
|
@ -68,10 +66,10 @@ class LoadOp : public framework::OperatorBase {
|
|
|
|
|
// FIXME(yuyang18): We save variable to local file now, but we should change
|
|
|
|
|
// it to save an output stream.
|
|
|
|
|
std::ifstream fin(filename);
|
|
|
|
|
PADDLE_ENFORCE(static_cast<bool>(fout), "Cannot open %s to write",
|
|
|
|
|
PADDLE_ENFORCE(static_cast<bool>(fin), "Cannot open %s to read",
|
|
|
|
|
filename);
|
|
|
|
|
|
|
|
|
|
auto *tensor = out_var->GetMutable<framework::LoDTensor>();
|
|
|
|
|
auto *tensor = var->GetMutable<framework::LoDTensor>();
|
|
|
|
|
|
|
|
|
|
DeserializeFromStream(fin, tensor, *dev_ctx);
|
|
|
|
|
|
|
|
|
@ -90,10 +88,11 @@ class LoadOp : public framework::OperatorBase {
|
|
|
|
|
&fp16_tensor);
|
|
|
|
|
|
|
|
|
|
// reset output tensor
|
|
|
|
|
out_var->Clear();
|
|
|
|
|
tensor = out_var->GetMutable<framework::LoDTensor>();
|
|
|
|
|
var->Clear();
|
|
|
|
|
tensor = var->GetMutable<framework::LoDTensor>();
|
|
|
|
|
tensor->set_lod(fp16_tensor.lod());
|
|
|
|
|
tensor->ShareDataWith(fp16_tensor);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void LoadSelectedRows(const std::string &filename,
|
|
|
|
@ -110,7 +109,7 @@ class LoadOp : public framework::OperatorBase {
|
|
|
|
|
// FIXME(yuyang18): We save variable to local file now, but we should change
|
|
|
|
|
// it to save an output stream.
|
|
|
|
|
std::ifstream fin(filename);
|
|
|
|
|
PADDLE_ENFORCE(static_cast<bool>(fin), "Cannot open %s to write",
|
|
|
|
|
PADDLE_ENFORCE(static_cast<bool>(fin), "Cannot open %s to read",
|
|
|
|
|
filename);
|
|
|
|
|
framework::DeserializeFromStream(fin, selectedRows, dev_ctx);
|
|
|
|
|
}
|
|
|
|
|