From a4fd3756bbd95fb8c676af9aab7a22cfe87d9cc5 Mon Sep 17 00:00:00 2001
From: tangwei12 <tangwei12@baidu.com>
Date: Fri, 18 May 2018 09:46:14 +0800
Subject: [PATCH] Fix checkpoint load/save ops: path building, Serial handling, op registration

---
 paddle/fluid/operators/checkpoint_load_op.cc | 85 +++++++++++++-------
 paddle/fluid/operators/checkpoint_op_test.cc | 24 +++++-
 paddle/fluid/operators/checkpoint_save_op.cc | 36 +++++----
 3 files changed, 95 insertions(+), 50 deletions(-)

diff --git a/paddle/fluid/operators/checkpoint_load_op.cc b/paddle/fluid/operators/checkpoint_load_op.cc
index 5fd3a7af9c..d24c781999 100644
--- a/paddle/fluid/operators/checkpoint_load_op.cc
+++ b/paddle/fluid/operators/checkpoint_load_op.cc
@@ -17,6 +17,7 @@ limitations under the License. */
 #include <fstream>
 #include <numeric>
 #include <sstream>
+#include <streambuf>
 #include <string>
 #include "paddle/fluid/framework/data_type.h"
 #include "paddle/fluid/framework/data_type_transform.h"
@@ -43,7 +44,15 @@ static std::string GenePath(const std::string &dir, const std::string &file) {
-  file_path.append(file_path);
+  file_path.append(dir);
   file_path.append("/");
   file_path.append(file);
-  return full_path;
+  return file_path;
+}
+
+// Returns true iff `s` is non-empty and consists only of decimal digits.
+static bool IsNumber(const std::string &s) {
+  // Iterate as unsigned char: std::isdigit is undefined for negative values.
+  for (unsigned char c : s)
+    if (!std::isdigit(c)) return false;
+  return !s.empty();
 }
 
 static void LoadInputVars(const framework::Scope &scope,
@@ -62,7 +69,7 @@ static void LoadInputVars(const framework::Scope &scope,
                    "Cannot find variable %s for save_combine_op",
                    inp_var_names[i]);
     PADDLE_ENFORCE(var->IsType<framework::LoDTensor>(),
-                   "SaveCombineOp only supports LoDTensor, %s has wrong type",
+                   "LoadCombineOp only supports LoDTensor, %s has wrong type",
                    inp_var_names[i]);
 
     std::string var_file = GenePath(dir, inp_var_names[i]);
@@ -78,21 +85,25 @@ static void LoadInputVars(const framework::Scope &scope,
 
+// Restores string variables from checkpoint files. For every name in `argv`,
+// the first line of the file at GenePath(dir, name) is read into the
+// std::string held by the scope variable of that name. Enforces that the
+// variable exists and that the file opens. `place` is currently unused.
 static void LoadStringArgv(const framework::Scope &scope,
                            const platform::Place &place,
-                           const std::string &argv, const std::string &dir) {
-  platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
-  auto &dev_ctx = *pool.Get(place);
-
+                           const std::vector<std::string> &argv,
+                           const std::string &dir) {
   for (size_t i = 0; i < argv.size(); i++) {
-    auto *var = scope.FindVar(inp_var_names[i]);
+    auto *var = scope.FindVar(argv[i]);
+    PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s for load op",
+                   argv[i]);
     std::string *var_str = var->GetMutable<std::string>();
-
-    std::string var_file = GenePath(dir, argv);
+    std::string var_file = GenePath(dir, argv[i]);
     std::ifstream fin(var_file);
     PADDLE_ENFORCE(static_cast<bool>(fin), "Cannot open file %s for load op",
                    var_file);
-    std::getline(fin, var_str);
+    std::getline(fin, *var_str);
     fin.close();
-    VLOG(3) << " load String argv: " << argv << " value is: " << var_str;
+    // Stream the loaded text; streaming `var_str` would print an address.
+    VLOG(3) << " load String argv: " << argv[i] << " value is: " << *var_str;
   }
 }
 
@@ -108,22 +112,24 @@ class CheckpointLoadOp : public framework::OperatorBase {
   void RunImpl(const framework::Scope &scope,
                const platform::Place &place) const override {
     std::string dir = Attr<std::string>("dir");
-    std::string serial_num = Attr<std::string>("Serial");
+    std::string serial_num_attr = Attr<std::string>("Serial");
+
+    PADDLE_ENFORCE(IsNumber(serial_num_attr),
+                   "Checkpoint Serial must be a number");
 
     std::string serial_var_name = std::string(SERIAL_VAR);
     auto *serial_var = scope.FindVar(serial_var_name);
-
-    if (serial_var == nullptr) {
-      *serial_var = scope.Var(serial_var_name);
-      auto *serial_tmp = serial_var->GetMutable<std::string>();
-      serial_tmp->append("0");
-    }
+    PADDLE_ENFORCE(serial_var != nullptr,
+                   "Cannot find variable %s for checkpoint_load_op",
+                   serial_var_name);
 
     auto *serial_num = serial_var->GetMutable<std::string>();
-    VLOG(1) << "CheckpointLoadOp set " << SERIAL_NUMBER
-            << " value: " << serial_num;
+    // serial_num is a std::string*: assign and log through the pointer.
+    *serial_num = serial_num_attr;
+    VLOG(1) << "CheckpointLoadOp set " << SERIAL_VAR
+            << " value: " << *serial_num;
 
-    std::string success = GenePath(dir, serial_num);
+    std::string success = GenePath(dir, serial_num->c_str());
     VLOG(3) << "Load checkpoint from dir: " << success;
     success = GenePath(success, SUCCESS);
     bool is_present = FileExists(success);
@@ -137,11 +143,11 @@ class CheckpointLoadOp : public framework::OperatorBase {
     auto inp_var_names = Inputs("X");
     PADDLE_ENFORCE_GT(static_cast<int>(inp_var_names.size()), 0,
                       "The number of input variables should be greater than 0");
-    LoadInputVars(scope, place, &inp_var_names);
+    LoadInputVars(scope, place, inp_var_names, dir);
 
-    VLOG(3) << "Ready to load string argv to scope";
-    auto argv = Inputs("Argv");
-    LoadStringArgv(scope, place, &argv, &dir);
+    // VLOG(3) << "Ready to load string argv to scope";
+    // auto argv = Output("Argv");
+    // LoadStringArgv(scope, place, argv, dir);
   }
 };
 
@@ -153,14 +159,13 @@ class CheckpointLoadOpProtoMaker : public framework::OpProtoAndCheckerMaker {
         "X",
         "(vector) Input LoDTensors that need to be saved together in a file.")
         .AsDuplicable();
-    AddInput(
+    AddOutput(
         "Argv",
-        "(vector) Input LoDTensors that need to be saved together in a file.")
-        .AsDuplicable();
+        "(vector) Input LoDTensors that need to be saved together in a file.");
     AddComment(R"DOC(
 CheckpointLoad operator
 
-This operator will serialize and write a list of input LoDTensor variables 
+This operator will serialize and write a list of input LoDTensor variables
 to a file on disk.
 )DOC");
 
@@ -177,10 +182,32 @@ to a file on disk.
   }
 };
 
+class CheckpointLoadOpVarTypeInference : public framework::VarTypeInference {
+ public:
+  void operator()(const framework::OpDesc &op_desc,
+                  framework::BlockDesc *block) const override {
+    auto out_var_name = op_desc.Output("Argv").front();
+    auto &out_var = block->FindRecursiveOrCreateVar(out_var_name);
+    auto var_type = framework::proto::VarType::RAW;
+    out_var.SetType(var_type);
+  }
+};
+
+class CheckpointLoadOpShapeInference : public framework::InferShapeBase {
+ public:
+  void operator()(framework::InferShapeContext *ctx) const override {}
+};
+
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
 
 REGISTER_OPERATOR(checkpoint_load, ops::CheckpointLoadOp,
-                  ops::CheckpointLoadOpProtoMaker);
+                  paddle::framework::EmptyGradOpMaker,
+                  ops::CheckpointLoadOpProtoMaker,
+                  ops::CheckpointLoadOpVarTypeInference,
+                  ops::CheckpointLoadOpShapeInference);
+
+// REGISTER_OPERATOR(checkpoint_load, ops::CheckpointLoadOp,
+//                  ops::CheckpointLoadOpProtoMaker);
diff --git a/paddle/fluid/operators/checkpoint_op_test.cc b/paddle/fluid/operators/checkpoint_op_test.cc
index 75bfc3f840..2acce227d2 100644
--- a/paddle/fluid/operators/checkpoint_op_test.cc
+++ b/paddle/fluid/operators/checkpoint_op_test.cc
@@ -44,7 +44,7 @@ TEST(CheckpointSaveOp, CPU) {
   attrs.insert({"dir", std::string("ckpt")});
 
   auto save_op = paddle::framework::OpRegistry::CreateOp(
-      "checkpoint_save", {{"X", {"test_var"}}}, attrs);
+      "checkpoint_save", {{"X", {"test_var"}}}, {}, attrs);
   save_op->Run(scope, place);
 }
 
@@ -52,13 +52,29 @@ TEST(CheckpointLoadOp, CPU) {
   paddle::framework::Scope scope;
   paddle::platform::CPUPlace place;
 
-  scope.Var("test_var");
+  auto var = scope.Var("test_var");
+  auto tensor = var->GetMutable<paddle::framework::LoDTensor>();
+  tensor->Resize({3, 10});
+  paddle::framework::LoD expect_lod;
+  expect_lod.resize(1);
+  expect_lod[0].push_back(0);
+  expect_lod[0].push_back(1);
+  expect_lod[0].push_back(2);
+  expect_lod[0].push_back(3);
+
+  tensor->set_lod(expect_lod);
+  float* expect = tensor->mutable_data<float>(place);
+  for (int64_t i = 0; i < tensor->numel(); ++i) {
+    expect[i] = static_cast<float>(paddle::platform::float16(i));
+  }
+
+  scope.Var("SERIAL_NUMBER");
 
   paddle::framework::AttributeMap attrs;
   attrs.insert({"dir", std::string("ckpt")});
+  attrs.insert({"Serial", std::string("0")});  // op enforces a numeric Serial
 
   auto load_op = paddle::framework::OpRegistry::CreateOp(
-      "checkpoint_load", {{"X", {"test_var"}}}, {{"Serial", {"SERIAL_NUMBER"}}},
-      attrs);
+      "checkpoint_load", {{"X", {"test_var"}}}, {{"Argv", {}}}, attrs);
   load_op->Run(scope, place);
 }
diff --git a/paddle/fluid/operators/checkpoint_save_op.cc b/paddle/fluid/operators/checkpoint_save_op.cc
index 5fccefeed2..bab979e407 100644
--- a/paddle/fluid/operators/checkpoint_save_op.cc
+++ b/paddle/fluid/operators/checkpoint_save_op.cc
@@ -33,12 +33,20 @@ constexpr char kSEP = '/';
 const char SUCCESS[] = "_SUCCESS";
 const char SERIAL_VAR[] = "SERIAL_NUMBER";
 
+// Returns true iff `s` is non-empty and consists only of decimal digits.
+static bool IsNumber(const std::string &s) {
+  // Iterate as unsigned char: std::isdigit is undefined for negative values.
+  for (unsigned char c : s)
+    if (!std::isdigit(c)) return false;
+  return !s.empty();
+}
+
 static std::string GenePath(const std::string &dir, const std::string &file) {
   std::string file_path;
-  file_path.append(file_path);
+  file_path.append(dir);
   file_path.append("/");
   file_path.append(file);
-  return full_path;
+  return file_path;
 }
 
 static bool FileExists(const std::string &filepath) {
@@ -79,28 +85,31 @@ class CheckpointSaveOp : public framework::OperatorBase {
  private:
   void RunImpl(const framework::Scope &scope,
                const platform::Place &place) const override {
-    auto dir = Attr<std::string>("dir");
+    auto ck_dir = Attr<std::string>("dir");
     auto overwrite = Attr<bool>("overwrite");
 
     std::string serial_var_name = std::string(SERIAL_VAR);
     auto *serial_var = scope.FindVar(serial_var_name);
-
-    if (serial_var == nullptr) {
-      *serial_var = scope.Var(serial_var_name);
-      *serial_tmp = serial_var->GetMutable<std::string>();
-      serial_tmp->append("0");
-    }
+    // Fail with a clear message instead of dereferencing a null Variable.
+    PADDLE_ENFORCE(serial_var != nullptr,
+                   "Cannot find variable %s for checkpoint_save_op",
+                   serial_var_name);
+
+    // serial_num is a std::string*; read and write the text through it.
     auto *serial_num = serial_var->GetMutable<std::string>();
-    VLOG(1) << "CheckpointSaveOp get " << SERIAL_NUMBER
-            << " value: " << serial_num;
+    VLOG(1) << "CheckpointSaveOp get " << SERIAL_VAR
+            << " value: " << *serial_num;
 
-    auto *serial_num = serial_var->GetMutable<std::string>();
-    serial_num->append("1");
+    // Fall back to serial "0" when the variable does not hold a number.
+    if (!IsNumber(*serial_num)) {
+      *serial_num = "0";
+    }
 
-    dir = GenePath(dir, serial_num);
+    std::string dir = GenePath(ck_dir, serial_num->c_str());
+    VLOG(1) << "CheckpointSaveOp current dir: " << dir;
     bool is_present = FileExists(dir);
     if (is_present && !overwrite) {
-      PADDLE_THROW("%s exists!, checkpoint save cannot to  overwrite it", dir,
+      PADDLE_THROW("%s exists!, checkpoint save cannot to overwrite it", dir,
                    overwrite);
     }
     MkDirRecursively(dir.c_str());
@@ -150,7 +152,7 @@ class CheckpointSaveOpProtoMaker : public framework::OpProtoAndCheckerMaker {
     AddComment(R"DOC(
 CheckpointSave operator
 
-This operator will serialize and write a list of input LoDTensor variables 
+This operator will serialize and write a list of input LoDTensor variables
 to a file on disk.
 )DOC");
     AddAttr<bool>("overwrite",