|
|
|
@ -22,6 +22,7 @@ limitations under the License. */
|
|
|
|
|
#include "paddle/fluid/framework/framework.pb.h"
|
|
|
|
|
#include "paddle/fluid/framework/lod_tensor.h"
|
|
|
|
|
#include "paddle/fluid/framework/op_registry.h"
|
|
|
|
|
#include "paddle/fluid/framework/selected_rows.h"
|
|
|
|
|
#include "paddle/fluid/platform/device_context.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
@ -78,26 +79,37 @@ class SaveOp : public framework::OperatorBase {
|
|
|
|
|
|
|
|
|
|
MkDirRecursively(DirName(filename).c_str());
|
|
|
|
|
|
|
|
|
|
// FIXME(yuyang18): We save variable to local file now, but we should change
|
|
|
|
|
// it to save an output stream.
|
|
|
|
|
std::ofstream fout(filename);
|
|
|
|
|
PADDLE_ENFORCE(static_cast<bool>(fout), "Cannot open %s to write",
|
|
|
|
|
filename);
|
|
|
|
|
|
|
|
|
|
auto iname = Input("X");
|
|
|
|
|
auto *var = scope.FindVar(iname);
|
|
|
|
|
PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s for save_op",
|
|
|
|
|
iname);
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(var->IsType<framework::LoDTensor>(),
|
|
|
|
|
"SaveOp only support LoDTensor, %s has wrong type", iname);
|
|
|
|
|
if (var->IsType<framework::LoDTensor>()) {
|
|
|
|
|
SaveLodTensor(filename, place, var);
|
|
|
|
|
} else if (var->IsType<framework::SelectedRows>()) {
|
|
|
|
|
SaveSelectedRows(filename, place, var);
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
false,
|
|
|
|
|
"SaveOp only support LoDTensor and SelectedRows, %s has wrong type",
|
|
|
|
|
iname);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
SaveLodTensor(const string &filename, const platform::Place &place,
|
|
|
|
|
Variable *var) {
|
|
|
|
|
auto &tensor = var->Get<framework::LoDTensor>();
|
|
|
|
|
|
|
|
|
|
// get device context from pool
|
|
|
|
|
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
|
|
|
|
|
auto &dev_ctx = *pool.Get(place);
|
|
|
|
|
|
|
|
|
|
// FIXME(yuyang18): We save variable to local file now, but we should change
|
|
|
|
|
// it to save an output stream.
|
|
|
|
|
std::ofstream fout(filename);
|
|
|
|
|
PADDLE_ENFORCE(static_cast<bool>(fout), "Cannot open %s to write",
|
|
|
|
|
filename);
|
|
|
|
|
|
|
|
|
|
auto in_dtype = framework::ToDataType(tensor.type());
|
|
|
|
|
auto out_dtype = save_as_fp16 ? framework::proto::VarType::FP16 : in_dtype;
|
|
|
|
|
|
|
|
|
@ -112,17 +124,35 @@ class SaveOp : public framework::OperatorBase {
|
|
|
|
|
} else {
|
|
|
|
|
framework::SerializeToStream(fout, tensor, dev_ctx);
|
|
|
|
|
}
|
|
|
|
|
fout.close()
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
SaveSelectedRows(const string &filename, const platform::Place &place,
|
|
|
|
|
Variable *var) {
|
|
|
|
|
auto &selectedRows = var->Get<framework::SelectedRows>();
|
|
|
|
|
|
|
|
|
|
// get device context from pool
|
|
|
|
|
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
|
|
|
|
|
auto &dev_ctx = *pool.Get(place);
|
|
|
|
|
|
|
|
|
|
// FIXME(yuyang18): We save variable to local file now, but we should change
|
|
|
|
|
// it to save an output stream.
|
|
|
|
|
std::ofstream fout(filename);
|
|
|
|
|
PADDLE_ENFORCE(static_cast<bool>(fout), "Cannot open %s to write",
|
|
|
|
|
filename);
|
|
|
|
|
framework::SerializeToStream(fout, selectedRows, dev_ctx);
|
|
|
|
|
fout.close()
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class SaveOpProtoMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
public:
|
|
|
|
|
void Make() override {
|
|
|
|
|
AddInput("X", "(Tensor ) Input tensor to be saved");
|
|
|
|
|
AddInput("X", "(Tensor ) Input LoDTensor and SelectedRows to be saved");
|
|
|
|
|
AddComment(R"DOC(
|
|
|
|
|
Save operator
|
|
|
|
|
|
|
|
|
|
This operator will serialize and write a tensor variable to file on disk.
|
|
|
|
|
This operator will serialize and write a tensor/selected rows variable to file on disk.
|
|
|
|
|
)DOC");
|
|
|
|
|
AddAttr<bool>("overwrite",
|
|
|
|
|
"(boolean, default true)"
|
|
|
|
|