fix serial number

shanyi15-patch-3
tangwei12 7 years ago
parent dbd023771f
commit 22df4c278c

@ -114,7 +114,7 @@ class CheckpointLoadOp : public framework::OperatorBase {
std::string dir = Attr<std::string>("dir");
std::string serial_num_attr = Attr<std::string>("Serial");
PADDLE_ENFORCE(!IsNumber(serial_num_attr),
PADDLE_ENFORCE(IsNumber(serial_num_attr),
"Checkpoint Serial must be a number");
std::string serial_var_name = std::string(SERIAL_VAR);

@ -96,8 +96,7 @@ class CheckpointSaveOp : public framework::OperatorBase {
int serials = 0;
if (!serial_num->empty()) {
std::string::size_type sz;
serials = std::stoi(serial_num->data, &sz);
serials = std::stoi(serial_num->data());
serials += 1;
}

@ -545,6 +545,7 @@ class DistributeTranspiler:
startup_prog.global_block().append_op(
type="checkpoint_load",
inputs={"X": load_vars},
outputs={"Argv": []},
attrs={"dir": checkpoint_load_dir,
"Serial": serial_number})
return startup_prog
@ -616,6 +617,7 @@ class DistributeTranspiler:
s_prog.global_block().append_op(
type="checkpoint_load",
inputs={"X": load_vars},
outputs={"Argv": []},
attrs={"dir": checkpoint_load_dir,
"Serial": serial_number})
@ -640,7 +642,7 @@ class DistributeTranspiler:
"""
is _SUCCESS in this dir
"""
if not os.path.isdir(cur_dir):
if not os.path.isdir(os.path.join(checkpoint_dir, cur_dir)):
return -1
try:

Loading…
Cancel
Save