shanyi15-patch-3
tangwei12 7 years ago
parent a4fd3756bb
commit f688652f1e

@ -114,7 +114,7 @@ class CheckpointLoadOp : public framework::OperatorBase {
std::string dir = Attr<std::string>("dir");
std::string serial_num_attr = Attr<std::string>("Serial");
PADDLE_ENFORCE(IsNumber(serial_num_attr),
PADDLE_ENFORCE(!IsNumber(serial_num_attr),
"Checkpoint Serial must be a number");
std::string serial_var_name = std::string(SERIAL_VAR);
@ -124,7 +124,8 @@ class CheckpointLoadOp : public framework::OperatorBase {
serial_var_name);
auto *serial_num = serial_var->GetMutable<std::string>();
serial_num = serial_num_attr;
serial_num->clear();
serial_num->append(serial_num_attr);
VLOG(1) << "CheckpointLoadOp set " << SERIAL_VAR
<< " value: " << serial_num;

@ -69,6 +69,8 @@ TEST(CheckpointLoadOp, CPU) {
}
scope.Var("SERIAL_NUMBER");
auto* serial_num = scope.FindVar("SERIAL_NUMBER")->GetMutable<std::string>();
serial_num->append("0");
paddle::framework::AttributeMap attrs;
attrs.insert({"dir", std::string("ckpt")});

@ -94,8 +94,8 @@ class CheckpointSaveOp : public framework::OperatorBase {
VLOG(1) << "CheckpointSaveOp get " << SERIAL_VAR
<< " value: " << serial_num;
if (!IsNumber(serial_num)) {
serial_num = "0";
if (serial_num->empty()) {
serial_num->append("0");
}
std::string dir = GenePath(ck_dir, serial_num->c_str());

Loading…
Cancel
Save