|
|
|
@ -108,15 +108,22 @@ class CheckpointLoadOp : public framework::OperatorBase {
|
|
|
|
|
void RunImpl(const framework::Scope &scope,
|
|
|
|
|
const platform::Place &place) const override {
|
|
|
|
|
std::string dir = Attr<std::string>("dir");
|
|
|
|
|
int serial_num = Attr<int>("Serial");
|
|
|
|
|
std::string serial_num = Attr<std::string>("Serial");
|
|
|
|
|
|
|
|
|
|
std::string serial_var_name = std::string(SERIAL_VAR);
|
|
|
|
|
auto *serial_var = scope.FindVar(serial_var_name);
|
|
|
|
|
auto *serial_num;
|
|
|
|
|
if (serial_var == nullptr) {
|
|
|
|
|
*serial_var = scope.Var(serial_var_name);
|
|
|
|
|
*serial_num = serial_var->GetMutable<std::string>();
|
|
|
|
|
serial_num->append("0");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto *serial_var = scope.FindVar(SERIAL_VAR);
|
|
|
|
|
serial_var = serial_num;
|
|
|
|
|
*serial_num = serial_var->GetMutable<std::string>();
|
|
|
|
|
VLOG(1) << "CheckpointLoadOp set " << SERIAL_NUMBER
|
|
|
|
|
<< " value: " << serial_num;
|
|
|
|
|
|
|
|
|
|
std::string success;
|
|
|
|
|
= GenePath(dir, std::to_string(serial_num));
|
|
|
|
|
std::string success = GenePath(dir, serial_num);
|
|
|
|
|
VLOG(3) << "Load checkpoint from dir: " << success;
|
|
|
|
|
success = GenePath(success, SUCCESS);
|
|
|
|
|
bool is_present = FileExists(success);
|
|
|
|
@ -157,9 +164,10 @@ This operator will serialize and write a list of input LoDTensor variables
|
|
|
|
|
to a file on disk.
|
|
|
|
|
)DOC");
|
|
|
|
|
|
|
|
|
|
AddAttr<int>("Serial",
|
|
|
|
|
"(int)"
|
|
|
|
|
"The serial number of the checkpoint will to be load.");
|
|
|
|
|
AddAttr<std::string>(
|
|
|
|
|
"Serial",
|
|
|
|
|
"(std::string)"
|
|
|
|
|
"The serial number of the checkpoint will to be load.");
|
|
|
|
|
AddAttr<std::string>(
|
|
|
|
|
"dir",
|
|
|
|
|
"(string)"
|
|
|
|
|