|
|
|
@ -18,6 +18,7 @@ limitations under the License. */
|
|
|
|
|
#include <numeric>
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/framework/data_type.h"
|
|
|
|
|
#include "paddle/fluid/framework/data_type_transform.h"
|
|
|
|
|
#include "paddle/fluid/framework/framework.pb.h"
|
|
|
|
|
#include "paddle/fluid/framework/lod_tensor.h"
|
|
|
|
|
#include "paddle/fluid/framework/op_registry.h"
|
|
|
|
@ -68,6 +69,7 @@ class SaveOp : public framework::OperatorBase {
|
|
|
|
|
const platform::Place &place) const override {
|
|
|
|
|
auto filename = Attr<std::string>("file_path");
|
|
|
|
|
auto overwrite = Attr<bool>("overwrite");
|
|
|
|
|
auto save_as_fp16 = Attr<bool>("save_as_fp16");
|
|
|
|
|
|
|
|
|
|
if (FileExists(filename) && !overwrite) {
|
|
|
|
|
PADDLE_THROW("%s is existed, cannot save to it when overwrite=false",
|
|
|
|
@ -96,8 +98,19 @@ class SaveOp : public framework::OperatorBase {
|
|
|
|
|
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
|
|
|
|
|
auto &dev_ctx = *pool.Get(place);
|
|
|
|
|
|
|
|
|
|
auto in_dtype = framework::ToDataType(tensor.type());
|
|
|
|
|
auto out_dtype = save_as_fp16 ? framework::proto::VarType::FP16 : in_dtype;
|
|
|
|
|
|
|
|
|
|
if (in_dtype != out_dtype) {
|
|
|
|
|
auto in_kernel_type = framework::OpKernelType(in_dtype, place);
|
|
|
|
|
auto out_kernel_type = framework::OpKernelType(out_dtype, place);
|
|
|
|
|
framework::LoDTensor out;
|
|
|
|
|
framework::TransDataType(in_kernel_type, out_kernel_type, tensor, &out);
|
|
|
|
|
framework::SerializeToStream(fout, out, dev_ctx);
|
|
|
|
|
} else {
|
|
|
|
|
framework::SerializeToStream(fout, tensor, dev_ctx);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class SaveOpProtoMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
@ -114,6 +127,12 @@ This operator will serialize and write a tensor variable to file on disk.
|
|
|
|
|
"(boolean, default true)"
|
|
|
|
|
"Overwrite the output file if exist")
|
|
|
|
|
.SetDefault(true);
|
|
|
|
|
AddAttr<bool>("save_as_fp16",
|
|
|
|
|
"(boolean, default false)"
|
|
|
|
|
"If true, the tensor will be converted to float16 data "
|
|
|
|
|
"type and then saved. Otherwise, the tensor will be "
|
|
|
|
|
"directly saved without data type conversion.")
|
|
|
|
|
.SetDefault(false);
|
|
|
|
|
AddAttr<std::string>("file_path",
|
|
|
|
|
"(string)"
|
|
|
|
|
"The \"file_path\" where the variable will be saved.")
|
|
|
|
|