From 025e3e94d2b216cc278de103cbef27b851274bf5 Mon Sep 17 00:00:00 2001
From: Yu Yang <yuyang18@baidu.com>
Date: Tue, 20 Dec 2016 23:00:34 +0800
Subject: [PATCH] Add GradientMachine::start/finish to API

---
 demo/mnist/api_train.py         | 7 ++++++-
 paddle/api/GradientMachine.cpp  | 4 ++++
 paddle/api/PaddleAPI.h          | 9 +++++++++
 paddle/api/ParameterUpdater.cpp | 2 ++
 4 files changed, 21 insertions(+), 1 deletion(-)

diff --git a/demo/mnist/api_train.py b/demo/mnist/api_train.py
index 5d4ef90f10..b061cfb2b8 100644
--- a/demo/mnist/api_train.py
+++ b/demo/mnist/api_train.py
@@ -30,7 +30,12 @@ def main():
     updater = api.ParameterUpdater.createLocalUpdater(opt_config)
     assert isinstance(updater, api.ParameterUpdater)
     updater.init(m)
-    updater.startPass()
+    m.start()
+
+    for _ in xrange(100):
+        updater.startPass()
+
+    m.finish()
 
 
 if __name__ == '__main__':
diff --git a/paddle/api/GradientMachine.cpp b/paddle/api/GradientMachine.cpp
index 297eaa19bb..2cece21097 100644
--- a/paddle/api/GradientMachine.cpp
+++ b/paddle/api/GradientMachine.cpp
@@ -64,6 +64,10 @@ GradientMachine* GradientMachine::createByModelConfig(
   return GradientMachine::createFromPaddleModelPtr(confPtr, mode, types);
 }
 
+void GradientMachine::start() { m->machine->start(); }
+
+void GradientMachine::finish() { m->machine->finish(); }
+
 void GradientMachine::forward(const Arguments& inArgs,
                               Arguments* outArgs,
                               PassType passType) {
diff --git a/paddle/api/PaddleAPI.h b/paddle/api/PaddleAPI.h
index bd413eb1e9..c074325091 100644
--- a/paddle/api/PaddleAPI.h
+++ b/paddle/api/PaddleAPI.h
@@ -716,6 +716,13 @@ public:
       GradientMatchineCreateMode mode = CREATE_MODE_NORMAL,
       const std::vector<int>& parameterTypes = defaultParamTypes);
 
+  /**
+   * @brief finish
+   */
+  void finish();
+
+  void start();
+
   /**
    * The forward stage of GradientMachine.
    *
@@ -790,6 +797,8 @@ public:
 
   void startPass();
 
+  void finishPass();
+
 private:
   ParameterUpdaterPrivate* m;
 };
diff --git a/paddle/api/ParameterUpdater.cpp b/paddle/api/ParameterUpdater.cpp
index af5b746a7c..3b626c0507 100644
--- a/paddle/api/ParameterUpdater.cpp
+++ b/paddle/api/ParameterUpdater.cpp
@@ -33,3 +33,5 @@ void ParameterUpdater::init(const GradientMachine &gm) {
 }
 
 void ParameterUpdater::startPass() { m->updater->startPass(); }
+
+void ParameterUpdater::finishPass() {}