|
|
|
@ -57,11 +57,12 @@ static void MkDirRecursively(const char *fullpath) {
|
|
|
|
|
MkDir(fullpath);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
class CkptSaveOp : public framework::OperatorBase {
|
|
|
|
|
class CheckpointSaveOp : public framework::OperatorBase {
|
|
|
|
|
public:
|
|
|
|
|
CkptSaveOp(const std::string &type, const framework::VariableNameMap &inputs,
|
|
|
|
|
const framework::VariableNameMap &outputs,
|
|
|
|
|
const framework::AttributeMap &attrs)
|
|
|
|
|
CheckpointSaveOp(const std::string &type,
|
|
|
|
|
const framework::VariableNameMap &inputs,
|
|
|
|
|
const framework::VariableNameMap &outputs,
|
|
|
|
|
const framework::AttributeMap &attrs)
|
|
|
|
|
: OperatorBase(type, inputs, outputs, attrs) {}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
@ -122,9 +123,9 @@ class CkptSaveOp : public framework::OperatorBase {
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class CkptSaveOpProtoMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
class CheckpointSaveOpProtoMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
public:
|
|
|
|
|
CkptSaveOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
|
|
|
|
|
CheckpointSaveOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
|
|
|
|
|
: OpProtoAndCheckerMaker(proto, op_checker) {
|
|
|
|
|
AddInput(
|
|
|
|
|
"X",
|
|
|
|
@ -155,4 +156,5 @@ to a file on disk.
|
|
|
|
|
|
|
|
|
|
namespace ops = paddle::operators;
|
|
|
|
|
|
|
|
|
|
REGISTER_OPERATOR(ckpt_save, ops::CkptSaveOp, ops::CkptSaveOpProtoMaker);
|
|
|
|
|
REGISTER_OPERATOR(checkpoint_save, ops::CheckpointSaveOp,
|
|
|
|
|
ops::CheckpointSaveOpProtoMaker);
|