Feature/save op (#5090)
* Init * Stash * Polish SaveLoadOp * Fix CI * Polish code * Save GPU Tensor * Stash * Fix CI (branch: revert-4814-Add_sequence_project_op)
parent
f8c6dadae1
commit
efc2464f6c
@ -1,39 +0,0 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
syntax = "proto2";
option optimize_for = LITE_RUNTIME;
package paddle.framework;

import "framework.proto";

/**
 * This file contains necessary information for model, checkpoint.
 * etc.
 */

// One LoD (level-of-detail) level: a flat list of sequence offsets/lengths.
message LoDInfo { repeated int64 level = 1; }

/**
 * Save the LoDTensorDesc information through LoDTensorProto, its data memory
 * is copied to c buffer immediately. See model_format.md for details.
 */

message LoDTensorProto {
  optional DataType data_type = 1;
  repeated int64 dims = 2; // [UNK, 640, 480] is saved as [-1, 640, 480]
  repeated LoDInfo levels = 3;
  optional int32 lod_level = 4 [ default = 0 ];
  optional int32 version = 5;
}
|
@ -0,0 +1,132 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/framework/op_registry.h"
|
||||
|
||||
#include <fstream>
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
class LoadOp : public framework::OperatorBase {
|
||||
public:
|
||||
LoadOp(const std::string &type, const framework::VariableNameMap &inputs,
|
||||
const framework::VariableNameMap &outputs,
|
||||
const framework::AttributeMap &attrs)
|
||||
: OperatorBase(type, inputs, outputs, attrs) {}
|
||||
void Run(const framework::Scope &scope,
|
||||
const platform::DeviceContext &dev_ctx) const override {
|
||||
auto filename = Attr<std::string>("file_path");
|
||||
std::ifstream fin(filename);
|
||||
PADDLE_ENFORCE(static_cast<bool>(fin), "Cannot open file %s for load op",
|
||||
filename);
|
||||
|
||||
auto out_var_name = Output("Out");
|
||||
auto *out_var = scope.FindVar(out_var_name);
|
||||
PADDLE_ENFORCE(out_var != nullptr, "Output variable %s cannot be found",
|
||||
out_var_name);
|
||||
|
||||
auto *tensor = out_var->GetMutable<framework::LoDTensor>();
|
||||
|
||||
uint32_t version;
|
||||
fin.read(reinterpret_cast<char *>(&version), sizeof(version));
|
||||
PADDLE_ENFORCE_EQ(version, 0U, "Only version 0 is supported");
|
||||
framework::TensorDesc desc;
|
||||
{ // int32_t size
|
||||
// proto buffer
|
||||
int32_t size;
|
||||
fin.read(reinterpret_cast<char *>(&size), sizeof(size));
|
||||
std::unique_ptr<char[]> buf(new char[size]);
|
||||
fin.read(reinterpret_cast<char *>(buf.get()), size);
|
||||
PADDLE_ENFORCE(desc.ParseFromArray(buf.get(), size),
|
||||
"Cannot parse tensor desc");
|
||||
}
|
||||
{ // read tensor
|
||||
std::vector<int64_t> dims;
|
||||
dims.reserve(static_cast<size_t>(desc.dims().size()));
|
||||
std::copy(desc.dims().begin(), desc.dims().end(),
|
||||
std::back_inserter(dims));
|
||||
tensor->Resize(framework::make_ddim(dims));
|
||||
|
||||
void *buf;
|
||||
platform::Place cpu = platform::CPUPlace();
|
||||
switch (desc.data_type()) {
|
||||
case framework::FP32:
|
||||
buf = tensor->mutable_data<float>(cpu);
|
||||
break;
|
||||
case framework::FP64:
|
||||
buf = tensor->mutable_data<double>(cpu);
|
||||
break;
|
||||
case framework::INT32:
|
||||
buf = tensor->mutable_data<int>(cpu);
|
||||
break;
|
||||
case framework::INT64:
|
||||
buf = tensor->mutable_data<int64_t>(cpu);
|
||||
break;
|
||||
default:
|
||||
PADDLE_THROW("DataType %d not supported", desc.data_type());
|
||||
}
|
||||
fin.read(static_cast<char *>(buf), tensor->memory_size());
|
||||
}
|
||||
{ // read lod
|
||||
uint64_t lod_level;
|
||||
fin.read(reinterpret_cast<char *>(&lod_level), sizeof(lod_level));
|
||||
auto &lod = *tensor->mutable_lod();
|
||||
lod.resize(lod_level);
|
||||
for (uint64_t i = 0; i < lod_level; ++i) {
|
||||
uint64_t size;
|
||||
fin.read(reinterpret_cast<char *>(&size), sizeof(size));
|
||||
std::vector<size_t> tmp(size / sizeof(size_t));
|
||||
fin.read(reinterpret_cast<char *>(tmp.data()),
|
||||
static_cast<std::streamsize>(size));
|
||||
lod[i] = tmp;
|
||||
}
|
||||
}
|
||||
|
||||
auto place = dev_ctx.GetPlace();
|
||||
if (platform::is_gpu_place(place)) {
|
||||
// copy CPU to GPU
|
||||
framework::LoDTensor cpu_tensor;
|
||||
cpu_tensor.ShareDataWith(*tensor);
|
||||
cpu_tensor.set_lod(tensor->lod());
|
||||
|
||||
// reset tensor
|
||||
out_var->Clear();
|
||||
tensor = out_var->GetMutable<framework::LoDTensor>();
|
||||
tensor->set_lod(cpu_tensor.lod());
|
||||
tensor->CopyFrom(cpu_tensor, place, dev_ctx);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Proto maker for the "load" operator: declares the single output "Out" and
// the required "file_path" attribute (checked to be non-empty).
class LoadOpProtoMaker : public framework::OpProtoAndCheckerMaker {
 public:
  LoadOpProtoMaker(framework::OpProto *proto,
                   framework::OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddOutput("Out", "The tensor need to be loaded");
    AddComment(R"DOC(Load Operator

Load operator will load a tensor variable from disk file.
)DOC");
    AddAttr<std::string>("file_path",
                         "Variable will be loaded from \"file_path\".")
        .AddCustomChecker(
            // Reject empty paths at op-creation time.
            [](const std::string &path) { return !path.empty(); });
  }
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;

REGISTER_OPERATOR(load, ops::LoadOp, ops::LoadOpProtoMaker);
|
@ -0,0 +1,63 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "paddle/framework/op_registry.h"
|
||||
|
||||
USE_NO_KERNEL_OP(save);
|
||||
USE_NO_KERNEL_OP(load);
|
||||
|
||||
// Round-trips a 10x10 int LoDTensor (carrying one LoD level) through the
// "save" and "load" operators on CPU, then verifies that both the element
// data and the LoD information survive unchanged.
TEST(SaveLoadOp, CPU) {
  paddle::framework::Scope scope;
  paddle::platform::CPUPlace place;
  paddle::platform::CPUDeviceContext ctx(place);
  auto var = scope.Var("test_var");
  auto tensor = var->GetMutable<paddle::framework::LoDTensor>();
  tensor->Resize({10, 10});
  paddle::framework::LoD expect_lod;
  expect_lod.resize(1);
  expect_lod[0].push_back(0);
  expect_lod[0].push_back(1);
  expect_lod[0].push_back(2);
  expect_lod[0].push_back(3);

  tensor->set_lod(expect_lod);
  int* expect = tensor->mutable_data<int>(place);
  // product() returns a signed int64_t; use a matching signed index instead
  // of size_t to avoid a signed/unsigned comparison, and hoist the value.
  const int64_t numel = paddle::framework::product(tensor->dims());
  for (int64_t i = 0; i < numel; ++i) {
    expect[i] = static_cast<int>(i);
  }
  paddle::framework::AttributeMap attrs;
  attrs.insert({"file_path", std::string("tensor.save")});

  auto save_op = paddle::framework::OpRegistry::CreateOp(
      "save", {{"X", {"test_var"}}}, {}, attrs);
  save_op->Run(scope, ctx);

  auto load_var = scope.Var("out_var");
  auto target = load_var->GetMutable<paddle::framework::LoDTensor>();
  auto load_op = paddle::framework::OpRegistry::CreateOp(
      "load", {}, {{"Out", {"out_var"}}}, attrs);
  load_op->Run(scope, ctx);
  int* actual = target->data<int>();
  for (int64_t i = 0; i < numel; ++i) {
    EXPECT_EQ(expect[i], actual[i]);
  }
  auto& actual_lod = target->lod();
  EXPECT_EQ(expect_lod.size(), actual_lod.size());
  for (size_t i = 0; i < expect_lod.size(); ++i) {
    for (size_t j = 0; j < expect_lod[i].size(); ++j) {
      EXPECT_EQ(expect_lod[i][j], actual_lod[i][j]);
    }
  }
}
|
@ -0,0 +1,184 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include <stdint.h>
|
||||
#include <sys/stat.h>
|
||||
#include <fstream>
|
||||
#include <numeric>
|
||||
|
||||
#include "paddle/framework/data_type.h"
|
||||
#include "paddle/framework/framework.pb.h"
|
||||
#include "paddle/framework/lod_tensor.h"
|
||||
#include "paddle/framework/op_registry.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
// TODO(yuyang18): If the functions below are needed by other files, move them
|
||||
// to paddle::filesystem namespace.
|
||||
constexpr char kSEP = '/';
|
||||
// Returns true iff `filepath` names an existing filesystem entry (file or
// directory), as reported by stat(2).
static bool FileExists(const std::string &filepath) {
  struct stat file_info;
  return stat(filepath.c_str(), &file_info) == 0;
}
|
||||
|
||||
// Returns everything before the last '/' in `filepath`, or an empty string
// when the path has no directory component (no separator present).
static std::string DirName(const std::string &filepath) {
  const auto sep_pos = filepath.find_last_of('/');
  if (sep_pos == std::string::npos) {
    return "";
  }
  return filepath.substr(0, sep_pos);
}
|
||||
|
||||
// Creates a single directory with mode 0755. Failure is tolerated only when
// the directory already exists (errno == EEXIST); any other mkdir error
// aborts via the enforce macro.
static void MkDir(const char *path) {
  if (mkdir(path, 0755)) {
    PADDLE_ENFORCE_EQ(errno, EEXIST, "%s mkdir failed!", path);
  }
}
|
||||
|
||||
// Creates `fullpath` and all missing parent directories (like `mkdir -p`):
// recurses on the parent first, then creates the leaf directory.
static void MkDirRecursively(const char *fullpath) {
  if (*fullpath == '\0') return;  // empty string
  if (FileExists(fullpath)) return;

  MkDirRecursively(DirName(fullpath).c_str());
  MkDir(fullpath);
}
|
||||
|
||||
class SaveOp : public framework::OperatorBase {
|
||||
public:
|
||||
SaveOp(const std::string &type, const framework::VariableNameMap &inputs,
|
||||
const framework::VariableNameMap &outputs,
|
||||
const framework::AttributeMap &attrs)
|
||||
: OperatorBase(type, inputs, outputs, attrs) {}
|
||||
void Run(const framework::Scope &scope,
|
||||
const platform::DeviceContext &dev_ctx) const override {
|
||||
auto filename = Attr<std::string>("file_path");
|
||||
auto overwrite = Attr<bool>("overwrite");
|
||||
|
||||
if (FileExists(filename) && !overwrite) {
|
||||
PADDLE_THROW("%s is existed, cannot save to it when overwrite=false",
|
||||
filename, overwrite);
|
||||
}
|
||||
|
||||
MkDirRecursively(DirName(filename).c_str());
|
||||
|
||||
// FIXME(yuyang18): We save variable to local file now, but we should change
|
||||
// it to save an output stream.
|
||||
std::ofstream fout(filename);
|
||||
PADDLE_ENFORCE(static_cast<bool>(fout), "Cannot open %s to write",
|
||||
filename);
|
||||
|
||||
auto iname = Input("X");
|
||||
auto *var = scope.FindVar(iname);
|
||||
PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s for save_op",
|
||||
iname);
|
||||
|
||||
PADDLE_ENFORCE(var->IsType<framework::LoDTensor>(),
|
||||
"SaveOp only support LoDTensor, %s has wrong type", iname);
|
||||
|
||||
auto &tensor = var->Get<framework::LoDTensor>();
|
||||
|
||||
{ // the 1st field, uint32_t version
|
||||
constexpr uint32_t version = 0;
|
||||
fout.write(reinterpret_cast<const char *>(&version), sizeof(version));
|
||||
}
|
||||
{ // the 2nd field, tensor description
|
||||
// int32_t size
|
||||
// void* protobuf message
|
||||
framework::TensorDesc desc;
|
||||
desc.set_data_type(framework::ToDataType(tensor.type()));
|
||||
auto dims = framework::vectorize(tensor.dims());
|
||||
auto *pb_dims = desc.mutable_dims();
|
||||
pb_dims->Resize(static_cast<int>(dims.size()), 0);
|
||||
std::copy(dims.begin(), dims.end(), pb_dims->begin());
|
||||
int32_t size = desc.ByteSize();
|
||||
fout.write(reinterpret_cast<const char *>(&size), sizeof(size));
|
||||
auto out = desc.SerializeAsString();
|
||||
fout.write(out.data(), size);
|
||||
}
|
||||
{ // the 3rd field, tensor data
|
||||
uint64_t size = tensor.memory_size();
|
||||
auto *data_ptr = tensor.data<void>();
|
||||
PADDLE_ENFORCE(size < std::numeric_limits<std::streamsize>::max(),
|
||||
"Index overflow when writing tensor");
|
||||
if (platform::is_gpu_place(tensor.place())) {
|
||||
#ifdef PADDLE_WITH_CUDA
|
||||
constexpr size_t kBufSize = 1024 * 1024 * 64; // 64MB
|
||||
std::unique_ptr<char[]> buf(new char[kBufSize]);
|
||||
auto &gpu_dev_ctx =
|
||||
static_cast<const platform::CUDADeviceContext &>(dev_ctx);
|
||||
platform::CPUPlace cpu;
|
||||
uintptr_t data = reinterpret_cast<uintptr_t>(data_ptr);
|
||||
while (size != 0) {
|
||||
size_t size_to_write = std::min(kBufSize, static_cast<size_t>(size));
|
||||
memory::Copy(cpu, buf.get(),
|
||||
boost::get<platform::GPUPlace>(tensor.place()),
|
||||
reinterpret_cast<const void *>(data), size_to_write,
|
||||
gpu_dev_ctx.stream());
|
||||
gpu_dev_ctx.Wait();
|
||||
fout.write(buf.get(), size_to_write);
|
||||
data += size_to_write;
|
||||
size -= size_to_write;
|
||||
}
|
||||
#else
|
||||
PADDLE_THROW("Unexpected branch");
|
||||
#endif
|
||||
} else {
|
||||
fout.write(static_cast<const char *>(data_ptr),
|
||||
static_cast<std::streamsize>(size));
|
||||
}
|
||||
}
|
||||
{ // the 4th field, lod information
|
||||
// uint64_t lod_level
|
||||
// uint64_t lod_level_1 size in byte.
|
||||
// int* lod_level_1 data
|
||||
// ...
|
||||
auto lod = tensor.lod();
|
||||
uint64_t size = lod.size();
|
||||
fout.write(reinterpret_cast<const char *>(&size), sizeof(size));
|
||||
|
||||
for (auto &each : lod) {
|
||||
size = each.size() * sizeof(framework::LoD::value_type::value_type);
|
||||
fout.write(reinterpret_cast<const char *>(&size), sizeof(size));
|
||||
fout.write(reinterpret_cast<const char *>(each.data()),
|
||||
static_cast<std::streamsize>(size));
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Proto maker for the "save" operator: single input "X", an "overwrite"
// flag (default true), and the required non-empty "file_path" attribute.
class SaveOpProtoMaker : public framework::OpProtoAndCheckerMaker {
 public:
  SaveOpProtoMaker(framework::OpProto *proto,
                   framework::OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "The tensor need to be saved");
    AddComment(R"DOC(Save operator

Save operator will serialize and write a tensor variable to disk file.
)DOC");
    AddAttr<bool>("overwrite", "Overwrite the output file if exist")
        .SetDefault(true);
    AddAttr<std::string>("file_path",
                         "Variable will be saved to \"file_path\".")
        .AddCustomChecker(
            // Reject empty paths at op-creation time.
            [](const std::string &path) { return !path.empty(); });
  }
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;

REGISTER_OPERATOR(save, ops::SaveOp, ops::SaveOpProtoMaker);
|
@ -1,147 +0,0 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/framework/eigen.h"
|
||||
#include "paddle/framework/op_registry.h"
|
||||
|
||||
#include <fstream>
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
using framework::Tensor;
using framework::LoDTensor;

// Maps a variable name to its on-disk path inside `folder_path`:
// "<folder>/__<name>__".
inline static std::string VarToFileName(const std::string& folder_path,
                                        const std::string& var_name) {
  return folder_path + "/__" + var_name + "__";
}

// Legacy save operator: serializes every input LoDTensor into its own file
// under the "folderPath" attribute via LoDTensor::SerializeToString.
class SaveOp : public framework::OperatorBase {
 public:
  SaveOp(const std::string& type, const framework::VariableNameMap& inputs,
         const framework::VariableNameMap& outputs,
         const framework::AttributeMap& attrs)
      : OperatorBase(type, inputs, outputs, attrs) {}

  void Run(const framework::Scope& scope,
           const platform::DeviceContext& dev_ctx) const override {
    const auto& var_names = this->Inputs("X");
    // Validate all inputs up front so we fail before writing any file.
    for (const auto& name : var_names) {
      PADDLE_ENFORCE_NOT_NULL(scope.FindVar(name),
                              "Can not find variable '%s' in the scope.", name);
    }
    std::string folder_path = this->Attr<std::string>("folderPath");
    PADDLE_ENFORCE(!folder_path.empty(),
                   "'folderPath' of SaveOp shouldn't be empty.");

    VLOG(1) << "Save variables to folder: " << folder_path;
    for (const auto& name : var_names) {
      std::string file_name = VarToFileName(folder_path, name);
      std::ofstream fout(file_name, std::ofstream::out);
      PADDLE_ENFORCE(fout.is_open(), "Fail to create file %s.", file_name);
      const LoDTensor& tensor = scope.FindVar(name)->Get<LoDTensor>();
      std::string bytes = tensor.SerializeToString();
      fout << bytes;
      fout.close();
    }
    VLOG(1) << "Compelete saving variables. Items count: " << var_names.size();
  }
};

// Proto maker for the legacy "save" op: duplicable input "X" plus the
// destination "folderPath" attribute.
class SaveOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  SaveOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X",
             "(tensor), the tensor count can be 1~INT_MAX, tensors names which "
             "values will be saved.")
        .AsDuplicable();
    AddAttr<std::string>("folderPath", "the folderPath for save model.");
    AddComment(R"DOC(
Save the input tensors to a binary file based on input tensor names and absolute path.

All the inputs can carry the LoD (Level of Details) information,
or not.
)DOC");
  }
};
|
||||
|
||||
// Legacy restore operator: for every output variable, reads the file
// written by the legacy SaveOp from "folderPath" and deserializes it into
// the corresponding LoDTensor.
class RestoreOp : public framework::OperatorBase {
 public:
  RestoreOp(const std::string& type, const framework::VariableNameMap& inputs,
            const framework::VariableNameMap& outputs,
            const framework::AttributeMap& attrs)
      : OperatorBase(type, inputs, outputs, attrs) {}

  void Run(const framework::Scope& scope,
           const platform::DeviceContext& dev_ctx) const override {
    const auto& var_names = this->Outputs("Out");
    // Validate all output variables up front.
    for (const auto& name : var_names) {
      PADDLE_ENFORCE_NOT_NULL(scope.FindVar(name),
                              "Can not find variable '%s' in the scope.", name);
    }
    std::string folder_path = this->Attr<std::string>("folderPath");
    PADDLE_ENFORCE(!folder_path.empty(),
                   "'folderPath' of RestoreOp shouldn't be empty.");

    VLOG(1) << "Try loading variables from folder: " << folder_path;

    for (const auto& name : var_names) {
      std::string file_name = VarToFileName(folder_path, name);
      std::ifstream fin(file_name, std::ifstream::in);
      PADDLE_ENFORCE(fin.is_open(), "Fail to open file %s.", file_name);
      // Slurp the whole file into `cache`, one page-sized chunk at a time.
      const size_t kBufferSize = 4096;  // equal to linux page size
      char buffer[kBufferSize];
      std::string cache;
      while (!fin.eof()) {
        fin.read(buffer, kBufferSize);
        cache.append(buffer, fin.gcount());
      }
      LoDTensor* tensor = scope.FindVar(name)->GetMutable<LoDTensor>();
      tensor->DeserializeFromString(cache, dev_ctx.GetPlace());
      fin.close();
    }
    VLOG(1) << "Complete loading variables.";
  }
};

// Proto maker for the legacy "restore" op: duplicable output "Out", the
// source "folderPath" attribute, and an output data type (default FP32).
class RestoreOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  RestoreOpMaker(framework::OpProto* proto,
                 framework::OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddOutput("Out",
              "(tensor), the tensor count can be 1~INT_MAX, tensors which "
              "values will be restores.")
        .AsDuplicable();
    AddAttr<std::string>("folderPath", "the folderPath for model file.");
    AddAttr<int>("data_type", "output tensor data type")
        .SetDefault(framework::DataType::FP32);
    AddComment(R"DOC(
Restore the tensors from model file based on absolute path.

All the tensors outputs may carry the LoD (Level of Details) information,
or not.
)DOC");
  }
};

} // namespace operators
} // namespace paddle

REGISTER_OPERATOR(save, paddle::operators::SaveOp,
                  paddle::framework::EmptyGradOpMaker,
                  paddle::operators::SaveOpMaker);

REGISTER_OPERATOR(restore, paddle::operators::RestoreOp,
                  paddle::framework::EmptyGradOpMaker,
                  paddle::operators::RestoreOpMaker);
|
@ -1,71 +0,0 @@
|
||||
import paddle.v2.framework.core as core
|
||||
import paddle.v2.framework.framework as framework
|
||||
import paddle.v2.framework.executor as executor
|
||||
|
||||
import numpy as np
|
||||
import unittest
|
||||
import os
|
||||
import sys
|
||||
import shutil
|
||||
|
||||
FOLDER_PATH = "./tmp_test_dir"
|
||||
|
||||
|
||||
class TestSaveRestoreOp(unittest.TestCase):
    def test_save_restore_op(self):
        """Round-trip two tensors through the legacy `save`/`restore` ops.

        The program saves tensor_1/tensor_2 to FOLDER_PATH, clobbers both
        variables with `fill_constant`, restores them from disk, and asserts
        the fetched values equal the original numpy arrays.
        """
        tensor_1_val = np.random.rand(3, 9).astype("float32")
        tensor_2_val = np.random.randint(0, 20, size=(4, 2)).astype("int32")
        place = core.CPUPlace()

        program = framework.Program()
        block = program.global_block()
        v_a = block.create_var(
            dtype="float32", shape=[3, 9], lod_level=0, name="tensor_1")
        v_b = block.create_var(
            dtype="int32", shape=[4, 2], lod_level=0, name="tensor_2")

        t_1 = core.LoDTensor()
        t_1.set(tensor_1_val, place)
        t_2 = core.LoDTensor()
        t_2.set(tensor_2_val, place)
        block.append_op(
            type="save",
            inputs={"X": [v_a, v_b]},
            attrs={"folderPath": FOLDER_PATH})
        # Overwrite both variables so a passing check can only come from a
        # successful restore, not from stale in-memory values.
        block.append_op(
            type="fill_constant",
            outputs={"Out": [v_a]},
            attrs={"shape": [2, 2],
                   "value": 0.0})
        block.append_op(
            type="fill_constant",
            outputs={"Out": [v_b]},
            attrs={"shape": [2, 2],
                   "value": 0.0})
        block.append_op(
            type="restore",
            outputs={"Out": [v_a, v_b]},
            attrs={"folderPath": FOLDER_PATH})

        # Start from a fresh, empty save directory.
        if os.path.exists(FOLDER_PATH):
            shutil.rmtree(FOLDER_PATH)
        os.makedirs(FOLDER_PATH)

        exe = executor.Executor(place)
        out = exe.run(program,
                      feed={"tensor_1": t_1,
                            "tensor_2": t_2},
                      fetch_list=[v_a, v_b])

        self.assertTrue(os.path.isdir(FOLDER_PATH))
        self.assertTrue(os.path.isfile(FOLDER_PATH + "/__tensor_1__"))
        self.assertTrue(os.path.isfile(FOLDER_PATH + "/__tensor_2__"))

        self.assertTrue(np.array_equal(np.array(out[0]), tensor_1_val))
        self.assertTrue(np.array_equal(np.array(out[1]), tensor_2_val))

        shutil.rmtree(FOLDER_PATH)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Reference in new issue