parent
3c820064de
commit
f04b23adf9
@ -0,0 +1,87 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include <stdint.h>
|
||||
#include <sys/stat.h>
|
||||
#include <fstream>
|
||||
#include <numeric>
|
||||
#include <sstream>
|
||||
#include "paddle/fluid/framework/data_type.h"
|
||||
#include "paddle/fluid/framework/data_type_transform.h"
|
||||
#include "paddle/fluid/framework/framework.pb.h"
|
||||
#include "paddle/fluid/framework/lod_tensor.h"
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
#include "paddle/fluid/platform/device_context.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
constexpr char kSEP = '/';
|
||||
// write empty file named _SUCCESS
|
||||
const char SUCCESS[] = "_SUCCESS";
|
||||
|
||||
static bool FileExists(const std::string &filepath) {
|
||||
struct stat buffer;
|
||||
return (stat(filepath.c_str(), &buffer) == 0);
|
||||
}
|
||||
|
||||
static std::string DirName(const std::string &filepath) {
|
||||
auto pos = filepath.rfind(kSEP);
|
||||
if (pos == std::string::npos) {
|
||||
return "";
|
||||
}
|
||||
return filepath.substr(0, pos);
|
||||
}
|
||||
|
||||
class CheckpointLoadOp : public framework::OperatorBase {
|
||||
public:
|
||||
CheckpointLoadOp(const std::string &type,
|
||||
const framework::VariableNameMap &inputs,
|
||||
const framework::VariableNameMap &outputs,
|
||||
const framework::AttributeMap &attrs)
|
||||
: OperatorBase(type, inputs, outputs, attrs) {}
|
||||
|
||||
private:
|
||||
void RunImpl(const framework::Scope &scope,
|
||||
const platform::Place &place) const override {
|
||||
auto dir = Attr<std::string>("dir");
|
||||
bool is_present = FileExists(dir);
|
||||
if (!is_present) {
|
||||
return;
|
||||
}
|
||||
|
||||
// UPDATE LATER ...
|
||||
}
|
||||
};
|
||||
|
||||
class CheckpointLoadOpProtoMaker : public framework::OpProtoAndCheckerMaker {
|
||||
public:
|
||||
CheckpointLoadOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
|
||||
: OpProtoAndCheckerMaker(proto, op_checker) {
|
||||
AddAttr<std::string>(
|
||||
"dir",
|
||||
"(string)"
|
||||
"The \"file_path\" where the LoDTensor variables will be saved.")
|
||||
.AddCustomChecker(
|
||||
[](const std::string &path) { return !path.empty(); });
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
|
||||
REGISTER_OPERATOR(checkpoint_load, ops::CheckpointLoadOp,
|
||||
ops::CheckpointLoadOpProtoMaker);
|
Loading…
Reference in new issue