|
|
|
@ -87,7 +87,7 @@ class CheckpointSaveOp : public framework::OperatorBase {
|
|
|
|
|
std::string *serial_num = serial_var->GetMutable<std::string>();
|
|
|
|
|
serial_num->append("0");
|
|
|
|
|
dir.append("/");
|
|
|
|
|
dir.append(serial_num);
|
|
|
|
|
dir.append(serial_num->c_str());
|
|
|
|
|
MkDirRecursively(dir.c_str());
|
|
|
|
|
|
|
|
|
|
auto inp_var_names = Inputs("X");
|
|
|
|
@ -159,10 +159,29 @@ to a file on disk.
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class CheckpointSaveOpVarTypeInference : public framework::VarTypeInference {
|
|
|
|
|
public:
|
|
|
|
|
void operator()(const framework::OpDesc &op_desc,
|
|
|
|
|
framework::BlockDesc *block) const override {
|
|
|
|
|
auto out_var_name = op_desc.Output("Serial").front();
|
|
|
|
|
auto &out_var = block->FindRecursiveOrCreateVar(out_var_name);
|
|
|
|
|
auto var_type = framework::proto::VarType::RAW;
|
|
|
|
|
out_var.SetType(var_type);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class CheckpointSaveOpShapeInference : public framework::InferShapeBase {
|
|
|
|
|
public:
|
|
|
|
|
void operator()(framework::InferShapeContext *ctx) const override {}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
} // namespace operators
|
|
|
|
|
} // namespace paddle
|
|
|
|
|
|
|
|
|
|
namespace ops = paddle::operators;
|
|
|
|
|
|
|
|
|
|
REGISTER_OPERATOR(checkpoint_save, ops::CheckpointSaveOp,
|
|
|
|
|
ops::CheckpointSaveOpProtoMaker);
|
|
|
|
|
REGISTER_OPERATOR(send_vars, ops::CheckpointSaveOp,
|
|
|
|
|
paddle::framework::EmptyGradOpMaker,
|
|
|
|
|
ops::CheckpointSaveOpProtoMaker,
|
|
|
|
|
ops::CheckpointSaveOpVarTypeInference,
|
|
|
|
|
ops::CheckpointSaveOpShapeInference);
|
|
|
|
|