|
|
|
@ -23,6 +23,7 @@ limitations under the License. */
|
|
|
|
|
#include "paddle/fluid/framework/lod_tensor.h"
|
|
|
|
|
#include "paddle/fluid/framework/op_registry.h"
|
|
|
|
|
#include "paddle/fluid/framework/selected_rows.h"
|
|
|
|
|
#include "paddle/fluid/framework/variable.h"
|
|
|
|
|
#include "paddle/fluid/platform/device_context.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
@ -70,7 +71,6 @@ class SaveOp : public framework::OperatorBase {
|
|
|
|
|
const platform::Place &place) const override {
|
|
|
|
|
auto filename = Attr<std::string>("file_path");
|
|
|
|
|
auto overwrite = Attr<bool>("overwrite");
|
|
|
|
|
auto save_as_fp16 = Attr<bool>("save_as_fp16");
|
|
|
|
|
|
|
|
|
|
if (FileExists(filename) && !overwrite) {
|
|
|
|
|
PADDLE_THROW("%s is existed, cannot save to it when overwrite=false",
|
|
|
|
@ -97,7 +97,7 @@ class SaveOp : public framework::OperatorBase {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
SaveLodTensor(const std::string &filename, const platform::Place &place,
|
|
|
|
|
Variable *var) {
|
|
|
|
|
framework::Variable *var) {
|
|
|
|
|
auto &tensor = var->Get<framework::LoDTensor>();
|
|
|
|
|
|
|
|
|
|
// get device context from pool
|
|
|
|
@ -110,6 +110,7 @@ class SaveOp : public framework::OperatorBase {
|
|
|
|
|
PADDLE_ENFORCE(static_cast<bool>(fout), "Cannot open %s to write",
|
|
|
|
|
filename);
|
|
|
|
|
|
|
|
|
|
auto save_as_fp16 = Attr<bool>("save_as_fp16");
|
|
|
|
|
auto in_dtype = framework::ToDataType(tensor.type());
|
|
|
|
|
auto out_dtype = save_as_fp16 ? framework::proto::VarType::FP16 : in_dtype;
|
|
|
|
|
|
|
|
|
@ -124,11 +125,11 @@ class SaveOp : public framework::OperatorBase {
|
|
|
|
|
} else {
|
|
|
|
|
framework::SerializeToStream(fout, tensor, dev_ctx);
|
|
|
|
|
}
|
|
|
|
|
fout.close()
|
|
|
|
|
fout.close();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
SaveSelectedRows(const std::string &filename, const platform::Place &place,
|
|
|
|
|
Variable *var) {
|
|
|
|
|
framework::Variable *var) {
|
|
|
|
|
auto &selectedRows = var->Get<framework::SelectedRows>();
|
|
|
|
|
|
|
|
|
|
// get device context from pool
|
|
|
|
|