|
|
|
@ -41,18 +41,19 @@ class SaveOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
auto *input_var = ctx.InputVar("X");
|
|
|
|
|
auto iname = ctx.InputNames("X").data();
|
|
|
|
|
PADDLE_ENFORCE(input_var != nullptr, "Cannot find variable %s for save_op",
|
|
|
|
|
iname);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(
|
|
|
|
|
input_var, platform::errors::InvalidArgument(
|
|
|
|
|
"The variable %s to be saved cannot be found.", iname));
|
|
|
|
|
|
|
|
|
|
if (input_var->IsType<framework::LoDTensor>()) {
|
|
|
|
|
SaveLodTensor(ctx, place, input_var);
|
|
|
|
|
} else if (input_var->IsType<framework::SelectedRows>()) {
|
|
|
|
|
SaveSelectedRows(ctx, place, input_var);
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
false,
|
|
|
|
|
"SaveOp only support LoDTensor and SelectedRows, %s has wrong type",
|
|
|
|
|
iname);
|
|
|
|
|
PADDLE_THROW(platform::errors::InvalidArgument(
|
|
|
|
|
"Save operator only supports saving LoDTensor and SelectedRows "
|
|
|
|
|
"variable, %s has wrong type",
|
|
|
|
|
iname));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -62,10 +63,11 @@ class SaveOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
auto filename = ctx.Attr<std::string>("file_path");
|
|
|
|
|
auto overwrite = ctx.Attr<bool>("overwrite");
|
|
|
|
|
|
|
|
|
|
if (FileExists(filename) && !overwrite) {
|
|
|
|
|
PADDLE_THROW("%s is existed, cannot save to it when overwrite=false",
|
|
|
|
|
filename, overwrite);
|
|
|
|
|
}
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
FileExists(filename) && !overwrite, false,
|
|
|
|
|
platform::errors::PreconditionNotMet(
|
|
|
|
|
"%s exists!, cannot save to it when overwrite is set to false.",
|
|
|
|
|
filename, overwrite));
|
|
|
|
|
|
|
|
|
|
MkDirRecursively(DirName(filename).c_str());
|
|
|
|
|
|
|
|
|
@ -78,8 +80,9 @@ class SaveOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
// FIXME(yuyang18): We save variable to local file now, but we should change
|
|
|
|
|
// it to save an output stream.
|
|
|
|
|
std::ofstream fout(filename, std::ios::binary);
|
|
|
|
|
PADDLE_ENFORCE(static_cast<bool>(fout), "Cannot open %s to write",
|
|
|
|
|
filename);
|
|
|
|
|
PADDLE_ENFORCE_EQ(static_cast<bool>(fout), true,
|
|
|
|
|
platform::errors::Unavailable(
|
|
|
|
|
"Cannot open %s to save variables.", filename));
|
|
|
|
|
|
|
|
|
|
auto save_as_fp16 = ctx.Attr<bool>("save_as_fp16");
|
|
|
|
|
auto in_dtype = tensor.type();
|
|
|
|
@ -117,10 +120,11 @@ class SaveOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (FileExists(filename) && !overwrite) {
|
|
|
|
|
PADDLE_THROW("%s is existed, cannot save to it when overwrite=false",
|
|
|
|
|
filename, overwrite);
|
|
|
|
|
}
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
FileExists(filename) && !overwrite, false,
|
|
|
|
|
platform::errors::PreconditionNotMet(
|
|
|
|
|
"%s exists!, cannot save to it when overwrite is set to false.",
|
|
|
|
|
filename, overwrite));
|
|
|
|
|
|
|
|
|
|
VLOG(4) << "SaveSelectedRows get File name: " << filename;
|
|
|
|
|
|
|
|
|
@ -135,8 +139,9 @@ class SaveOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
// FIXME(yuyang18): We save variable to local file now, but we should change
|
|
|
|
|
// it to save an output stream.
|
|
|
|
|
std::ofstream fout(filename, std::ios::binary);
|
|
|
|
|
PADDLE_ENFORCE(static_cast<bool>(fout), "Cannot open %s to write",
|
|
|
|
|
filename);
|
|
|
|
|
PADDLE_ENFORCE_EQ(static_cast<bool>(fout), true,
|
|
|
|
|
platform::errors::Unavailable(
|
|
|
|
|
"Cannot open %s to save variables.", filename));
|
|
|
|
|
framework::SerializeToStream(fout, selectedRows, dev_ctx);
|
|
|
|
|
fout.close();
|
|
|
|
|
}
|
|
|
|
|