diff --git a/cmake/external/anakin.cmake b/cmake/external/anakin.cmake
index 78be074909..dc6730662f 100644
--- a/cmake/external/anakin.cmake
+++ b/cmake/external/anakin.cmake
@@ -52,9 +52,8 @@ ExternalProject_Add(
     extern_anakin
     ${EXTERNAL_PROJECT_LOG_ARGS}
     DEPENDS             ${MKLML_PROJECT}
-    # Anakin codes error on Intel(R) Xeon(R) Gold 5117 CPU, temporary do not compile avx512 related code.
-    GIT_REPOSITORY      "https://github.com/luotao1/Anakin"
-    GIT_TAG             "211d1fc5d813d70c0c14072f9083cf25f40940ea"
+    GIT_REPOSITORY      "https://github.com/PaddlePaddle/Anakin"
+    GIT_TAG             "9424277cf9ae180a14aff09560d3cd60a49c76d2"
     PREFIX              ${ANAKIN_SOURCE_DIR}
     UPDATE_COMMAND      ""
     CMAKE_ARGS          -DUSE_GPU_PLACE=YES
diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec
index 7ae0f445a8..106198362f 100644
--- a/paddle/fluid/API.spec
+++ b/paddle/fluid/API.spec
@@ -55,9 +55,10 @@ paddle.fluid.Inferencer.__init__ ArgSpec(args=['self', 'infer_func', 'param_path
 paddle.fluid.Inferencer.infer ArgSpec(args=['self', 'inputs', 'return_numpy'], varargs=None, keywords=None, defaults=(True,))
 paddle.fluid.DistributeTranspiler.__init__ ArgSpec(args=['self', 'config'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.DistributeTranspiler.get_pserver_program ArgSpec(args=['self', 'endpoint'], varargs=None, keywords=None, defaults=None)
-paddle.fluid.DistributeTranspiler.get_startup_program ArgSpec(args=['self', 'endpoint', 'pserver_program', 'startup_program'], varargs=None, keywords=None, defaults=(None,))
+paddle.fluid.DistributeTranspiler.get_pserver_programs ArgSpec(args=['self', 'endpoint'], varargs=None, keywords=None, defaults=None)
+paddle.fluid.DistributeTranspiler.get_startup_program ArgSpec(args=['self', 'endpoint', 'pserver_program', 'startup_program'], varargs=None, keywords=None, defaults=(None, None))
 paddle.fluid.DistributeTranspiler.get_trainer_program ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
-paddle.fluid.DistributeTranspiler.transpile ArgSpec(args=['self', 'trainer_id', 'program', 'pservers', 'trainers', 'sync_mode'], varargs=None, keywords=None, defaults=(None, '127.0.0.1:6174', 1, True))
+paddle.fluid.DistributeTranspiler.transpile ArgSpec(args=['self', 'trainer_id', 'program', 'pservers', 'trainers', 'sync_mode', 'startup_program'], varargs=None, keywords=None, defaults=(None, '127.0.0.1:6174', 1, True, None))
 paddle.fluid.InferenceTranspiler.__init__ 
 paddle.fluid.InferenceTranspiler.transpile ArgSpec(args=['self', 'program', 'place', 'scope'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.memory_optimize ArgSpec(args=['input_program', 'skip_opt_set', 'print_log', 'level'], varargs=None, keywords=None, defaults=(None, False, 0))
@@ -299,6 +300,7 @@ paddle.fluid.layers.ssd_loss ArgSpec(args=['location', 'confidence', 'gt_box', '
 paddle.fluid.layers.detection_map ArgSpec(args=['detect_res', 'label', 'class_num', 'background_label', 'overlap_threshold', 'evaluate_difficult', 'has_state', 'input_states', 'out_states', 'ap_version'], varargs=None, keywords=None, defaults=(0, 0.3, True, None, None, None, 'integral'))
 paddle.fluid.layers.rpn_target_assign ArgSpec(args=['loc', 'scores', 'anchor_box', 'gt_box', 'rpn_batch_size_per_im', 'fg_fraction', 'rpn_positive_overlap', 'rpn_negative_overlap'], varargs=None, keywords=None, defaults=(256, 0.25, 0.7, 0.3))
 paddle.fluid.layers.anchor_generator ArgSpec(args=['input', 'anchor_sizes', 'aspect_ratios', 'variance', 'stride', 'offset', 'name'], varargs=None, keywords=None, defaults=(None, None, [0.1, 0.1, 0.2, 0.2], None, 0.5, None))
+paddle.fluid.layers.generate_proposals ArgSpec(args=['scores', 'bbox_deltas', 'im_info', 'anchors', 'variances', 'pre_nms_top_n', 'post_nms_top_n', 'nms_thresh', 'min_size', 'eta', 'name'], varargs=None, keywords=None, defaults=(6000, 1000, 0.5, 0.1, 1.0, None))
 paddle.fluid.layers.iou_similarity ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None)
 paddle.fluid.layers.box_coder ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None)
 paddle.fluid.layers.polygon_box_transform ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None)
@@ -334,9 +336,10 @@ paddle.fluid.contrib.BeamSearchDecoder.update_array ArgSpec(args=['self', 'array
 paddle.fluid.contrib.memory_usage ArgSpec(args=['program', 'batch_size'], varargs=None, keywords=None, defaults=None)
 paddle.fluid.transpiler.DistributeTranspiler.__init__ ArgSpec(args=['self', 'config'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.transpiler.DistributeTranspiler.get_pserver_program ArgSpec(args=['self', 'endpoint'], varargs=None, keywords=None, defaults=None)
-paddle.fluid.transpiler.DistributeTranspiler.get_startup_program ArgSpec(args=['self', 'endpoint', 'pserver_program', 'startup_program'], varargs=None, keywords=None, defaults=(None,))
+paddle.fluid.transpiler.DistributeTranspiler.get_pserver_programs ArgSpec(args=['self', 'endpoint'], varargs=None, keywords=None, defaults=None)
+paddle.fluid.transpiler.DistributeTranspiler.get_startup_program ArgSpec(args=['self', 'endpoint', 'pserver_program', 'startup_program'], varargs=None, keywords=None, defaults=(None, None))
 paddle.fluid.transpiler.DistributeTranspiler.get_trainer_program ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
-paddle.fluid.transpiler.DistributeTranspiler.transpile ArgSpec(args=['self', 'trainer_id', 'program', 'pservers', 'trainers', 'sync_mode'], varargs=None, keywords=None, defaults=(None, '127.0.0.1:6174', 1, True))
+paddle.fluid.transpiler.DistributeTranspiler.transpile ArgSpec(args=['self', 'trainer_id', 'program', 'pservers', 'trainers', 'sync_mode', 'startup_program'], varargs=None, keywords=None, defaults=(None, '127.0.0.1:6174', 1, True, None))
 paddle.fluid.transpiler.InferenceTranspiler.__init__ 
 paddle.fluid.transpiler.InferenceTranspiler.transpile ArgSpec(args=['self', 'program', 'place', 'scope'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.transpiler.memory_optimize ArgSpec(args=['input_program', 'skip_opt_set', 'print_log', 'level'], varargs=None, keywords=None, defaults=(None, False, 0))
diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt
index 2c62d4ed6b..0668ff43c8 100644
--- a/paddle/fluid/framework/CMakeLists.txt
+++ b/paddle/fluid/framework/CMakeLists.txt
@@ -107,11 +107,11 @@ cc_library(lod_rank_table SRCS lod_rank_table.cc DEPS lod_tensor)
 cc_library(feed_fetch_method SRCS feed_fetch_method.cc DEPS lod_tensor scope glog)
 
 if(WITH_DISTRIBUTE)
-  cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method sendrecvop_grpc cares grpc++_unsecure grpc_unsecure gpr)
+  cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method sendrecvop_grpc cares grpc++_unsecure grpc_unsecure gpr graph_to_program_pass)
   set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
   set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
 else()
-  cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method)
+  cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass)
 endif()
 
 if (NOT WIN32)
diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt
index da0955a9a0..bfc649017f 100644
--- a/paddle/fluid/framework/ir/CMakeLists.txt
+++ b/paddle/fluid/framework/ir/CMakeLists.txt
@@ -3,14 +3,18 @@ cc_library(graph SRCS graph.cc DEPS node)
 cc_library(graph_helper SRCS graph_helper.cc DEPS graph)
 cc_library(pass SRCS pass.cc DEPS graph node graph_helper)
 cc_library(graph_viz_pass SRCS graph_viz_pass.cc DEPS graph pass graph_helper)
+cc_library(graph_to_program_pass SRCS graph_to_program_pass.cc DEPS graph pass graph_helper)
 cc_library(graph_traits SRCS graph_traits.cc DEPS graph)
-cc_library(graph_pattern_detecter SRCS graph_pattern_detecter.cc DEPS graph graph_helper graph_traits)
-cc_library(fc_fuse_pass SRCS fc_fuse_pass.cc DEPS graph graph_pattern_detecter)
+cc_library(graph_pattern_detector SRCS graph_pattern_detector.cc DEPS graph graph_helper graph_traits)
+cc_library(fc_fuse_pass SRCS fc_fuse_pass.cc DEPS graph graph_pattern_detector)
+cc_library(attention_lstm_fuse_pass SRCS attention_lstm_fuse_pass.cc DEPS graph graph_pattern_detector)
 cc_library(infer_clean_graph_pass SRCS infer_clean_graph_pass.cc DEPS graph pass)
-
+cc_library(fc_lstm_fuse_pass SRCS fc_lstm_fuse_pass.cc DEPS graph graph_pattern_detector)
+cc_library(seq_concat_fc_fuse_pass SRCS seq_concat_fc_fuse_pass.cc DEPS graph graph_pattern_detector)
 
 cc_test(pass_test SRCS pass_test.cc DEPS graph pass graph_helper)
 cc_test(graph_test SRCS graph_test.cc DEPS graph graph_helper op_registry)
 cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_registry)
-cc_test(test_graph_pattern_detecter SRCS graph_pattern_detecter_tester.cc DEPS graph_pattern_detecter)
-cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass graph_pattern_detecter graph pass graph_traits framework_proto)
+cc_test(graph_to_program_pass_test SRCS graph_to_program_pass_test.cc DEPS graph_to_program_pass)
+cc_test(test_graph_pattern_detector SRCS graph_pattern_detector_tester.cc DEPS graph_pattern_detector)
+cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass graph_pattern_detector graph pass graph_traits framework_proto)
diff --git a/paddle/fluid/framework/ir/attention_lstm_fuse_pass.cc b/paddle/fluid/framework/ir/attention_lstm_fuse_pass.cc
new file mode 100644
index 0000000000..2876de88f1
--- /dev/null
+++ b/paddle/fluid/framework/ir/attention_lstm_fuse_pass.cc
@@ -0,0 +1,273 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/ir/attention_lstm_fuse_pass.h"
+#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
+#include "paddle/fluid/framework/ir/graph_viz_pass.h"
+#include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/fluid/inference/api/helper.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+struct Param {
+  std::string X = "concat_0.tmp_0";
+  std::string C0 = "cell_init";
+  std::string H0 = "hidden_init";
+  std::string AttentionWeight = "attention_fc.w_0";
+  std::string AttentionBias = "attention_fc.b_0";
+  std::string AttentionScalar = "attention_output.w_0";
+  std::string AttentionScalarBias = "attention_output.b_0";
+  std::string LSTMWeight = "attention_w.new";
+  std::string LSTMBias = "attention_b.new";
+  std::string Hidden = "array_to_lod_tensor_0.tmp_0";
+  std::string Cell = "at.cell.new";
+  std::string AttentionedX = "at.x.new";
+  std::string AttentionFCOut = "at.fc.new";
+  std::string LSTMX = "at.lstmx.new";
+  std::string LSTMOUT = "at.lstmout.new";
+};
+
+void PrepareParameters(Graph* graph, const Param& param);
+
+void FindWhileOp(Graph* graph) {
+  GraphPatternDetector gpd;
+  std::unordered_set<int> fused_external_ops(
+      {35, 36, 37, 38, 43, 44, 49, 45, 46, 47, 41, 42, 53, 54, 48,
+       57, 55, 56, 52, 74, 80, 77, 78, 79, 50, 77, 39, 40, 51});
+
+  gpd.mutable_pattern()->NewNode(
+      [&](Node* n) { return fused_external_ops.count(n->id()); }, "while");
+
+  if (!graph->Has(kGraphvizMarkedNodeAttr)) {
+    graph->Set(kGraphvizMarkedNodeAttr, new GraphVizPass::marked_nodes_t);
+  }
+  auto& marked_nodes =
+      graph->Get<GraphVizPass::marked_nodes_t>(kGraphvizMarkedNodeAttr);
+
+  auto handle = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                    Graph* g) {
+    auto* while_pat_node = gpd.pattern().RetriveNode("while");
+    auto* while_node = subgraph.at(while_pat_node);
+    marked_nodes.insert(while_node);
+  };
+  gpd(graph, handle);
+
+  Param param;
+  // Add AttentionLSTM node
+  OpDesc op_desc;
+  op_desc.SetType("attention_lstm");
+
+#define OP_SET_IN(x) op_desc.SetInput(#x, {param.x});
+#define OP_SET_OUT(x) op_desc.SetOutput(#x, {param.x});
+  OP_SET_IN(X);
+  OP_SET_IN(C0);
+  OP_SET_IN(H0);
+  OP_SET_IN(AttentionWeight);
+  OP_SET_IN(AttentionBias);
+  OP_SET_IN(AttentionScalar);
+  OP_SET_IN(AttentionScalarBias);
+  OP_SET_IN(LSTMWeight);
+  OP_SET_IN(LSTMBias);
+
+  OP_SET_OUT(Hidden);
+  OP_SET_OUT(Cell);
+  OP_SET_OUT(AttentionedX);
+  OP_SET_OUT(AttentionFCOut);
+  OP_SET_OUT(LSTMX);
+  OP_SET_OUT(LSTMOUT);
+#undef OP_SET_IN
+#undef OP_SET_OUT
+
+  auto* X = graph->RetriveNode(34);
+  auto* LSTMOUT = graph->RetriveNode(81);
+  auto* cell_init = graph->RetriveNode(6);
+  auto* hidden_init = graph->RetriveNode(8);
+
+#define LINK_TO(node0, node1)      \
+  node0->outputs.push_back(node1); \
+  node1->inputs.push_back(node0);
+
+  auto* lstm_op = graph->CreateOpNode(&op_desc);
+  PrepareParameters(graph, param);
+
+  LINK_TO(X, lstm_op);
+  LINK_TO(cell_init, lstm_op);
+  LINK_TO(hidden_init, lstm_op);
+  LINK_TO(lstm_op, LSTMOUT);
+
+  GraphSafeRemoveNodes(graph, marked_nodes);
+}
+
+#define CHECK_P1(x) PADDLE_ENFORCE_NOT_NULL(x);
+#define CHECK_P2(x0, x1) \
+  CHECK_P1(x0);          \
+  CHECK_P1(x1);
+#define CHECK_P3(x0, x1, x2) \
+  CHECK_P2(x0, x1);          \
+  CHECK_P1(x2);
+#define CHECK_P4(x0, x1, x2, x3) \
+  CHECK_P3(x0, x1, x2);          \
+  CHECK_P1(x3);
+#define CHECK_P5(x0, x1, x2, x3, x4) \
+  CHECK_P4(x0, x1, x2, x3);          \
+  CHECK_P1(x4);
+
+void PrepareLSTMWeight(const LoDTensor& W_forget_w0,
+                       const LoDTensor& W_forget_w1,
+                       const LoDTensor& W_input_w0, const LoDTensor& W_input_w1,
+                       const LoDTensor& W_output_w0,
+                       const LoDTensor& W_output_w1, const LoDTensor& W_cell_w0,
+                       const LoDTensor& W_cell_w1, LoDTensor* out);
+
+void PrepareLSTMBias(const LoDTensor& B_forget, const LoDTensor& B_input,
+                     const LoDTensor& B_output, const LoDTensor& B_cell,
+                     LoDTensor* out);
+
+void PrepareParameters(Graph* graph, const Param& param) {
+  // Check parameters
+  PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
+  auto* scope = graph->Get<Scope*>(kParamScopeAttr);
+
+  // Create new parameters.
+  scope->Var(param.LSTMWeight)->GetMutable<LoDTensor>();
+  scope->Var(param.LSTMBias)->GetMutable<LoDTensor>();
+  scope->Var(param.Hidden)->GetMutable<LoDTensor>();
+  scope->Var(param.Cell)->GetMutable<LoDTensor>();
+  scope->Var(param.AttentionedX)->GetMutable<LoDTensor>();
+  scope->Var(param.AttentionFCOut)->GetMutable<LoDTensor>();
+  scope->Var(param.LSTMX)->GetMutable<LoDTensor>();
+  scope->Var(param.LSTMOUT)->GetMutable<LoDTensor>();
+
+#define GATE_W(name__)                                               \
+  auto* W_##name__##_w0 = scope->FindVar(#name__ ".w_0");            \
+  auto* W_##name__##_w1 = scope->FindVar(#name__ ".w_1");            \
+  auto* W_##name__##_b0 = scope->FindVar(#name__ ".b_0");            \
+  CHECK_P3(W_##name__##_w0, W_##name__##_w1, W_##name__##_b0);       \
+  VLOG(4) << #name__ "_w0"                                           \
+          << " shape: " << W_##name__##_w0->Get<LoDTensor>().dims(); \
+  VLOG(4) << #name__ "_w1"                                           \
+          << " shape: " << W_##name__##_w1->Get<LoDTensor>().dims(); \
+  VLOG(4) << #name__ "_b0"                                           \
+          << " shape: " << W_##name__##_b0->Get<LoDTensor>().dims(); \
+  auto& W_##name__##_w0_t = W_##name__##_w0->Get<LoDTensor>();       \
+  auto& W_##name__##_w1_t = W_##name__##_w1->Get<LoDTensor>();       \
+  auto& W_##name__##_b0_t = W_##name__##_b0->Get<LoDTensor>();
+
+  GATE_W(forget);
+  GATE_W(input);
+  GATE_W(output);
+  GATE_W(c);
+#undef GATE_W
+
+  auto* attention_fc_w = scope->FindVar("attention_fc.w_0");
+  auto* attention_fc_b = scope->FindVar("attention_fc.b_0");
+  auto* attention_output_w = scope->FindVar("attention_output.w_0");
+  auto* attention_output_b = scope->FindVar("attention_output.b_0");
+  CHECK_P4(attention_fc_w, attention_fc_b, attention_output_w,
+           attention_output_b);
+
+  auto* lstm_weight = scope->Var(param.LSTMWeight);
+  auto* lstm_weight_t = lstm_weight->GetMutable<LoDTensor>();
+  auto* lstm_bias = scope->Var(param.LSTMBias);
+  auto* lstm_bias_t = lstm_bias->GetMutable<LoDTensor>();
+
+  // reshape attention_bias
+  auto* attention_bias_t =
+      scope->FindVar(param.AttentionBias)->GetMutable<LoDTensor>();
+  PADDLE_ENFORCE_EQ(attention_bias_t->dims().size(), 1);
+  attention_bias_t->Resize(make_ddim({1, attention_bias_t->dims()[0]}));
+
+  auto* attention_scalar_bias_t =
+      scope->FindVar(param.AttentionScalarBias)->GetMutable<LoDTensor>();
+  attention_scalar_bias_t->Resize(
+      make_ddim({1, attention_scalar_bias_t->dims()[0]}));
+
+  PrepareLSTMWeight(W_forget_w0_t, W_forget_w1_t, W_input_w0_t, W_input_w1_t,
+                    W_output_w0_t, W_output_w1_t, W_c_w0_t, W_c_w1_t,
+                    lstm_weight_t);
+  PrepareLSTMBias(W_forget_b0_t, W_input_b0_t, W_output_b0_t, W_c_b0_t,
+                  lstm_bias_t);
+}
+
+// Prepare parameters
+void PrepareLSTMWeight(const LoDTensor& W_forget_w0,
+                       const LoDTensor& W_forget_w1,
+                       const LoDTensor& W_input_w0, const LoDTensor& W_input_w1,
+                       const LoDTensor& W_output_w0,
+                       const LoDTensor& W_output_w1, const LoDTensor& W_cell_w0,
+                       const LoDTensor& W_cell_w1, LoDTensor* out) {
+  int D = W_forget_w0.dims()[0];
+  int M = W_forget_w1.dims()[0];
+  out->Resize(make_ddim({D + M, 4 * D}));
+  VLOG(3) << "LSTMWeight resized to " << out->dims();
+
+  float* out_data = out->mutable_data<float>(platform::CPUPlace());
+  std::array<const float*, 4> tensors(
+      {W_forget_w0.data<float>(), W_input_w0.data<float>(),
+       W_output_w0.data<float>(), W_cell_w0.data<float>()});
+  std::array<const float*, 4> tensors1(
+      {W_forget_w1.data<float>(), W_input_w1.data<float>(),
+       W_output_w1.data<float>(), W_cell_w1.data<float>()});
+
+  for (int row = 0; row < D; row++) {
+    for (int col = 0; col < 4; col++) {
+      float* dst = out_data + 4 * D * row + D * col;
+      const float* src = tensors[col] + D * row;
+      memcpy(dst, src, D * sizeof(float));
+    }
+  }
+
+  for (int row = 0; row < M; row++) {
+    for (int col = 0; col < 4; col++) {
+      float* dst = out_data + 4 * D * (D + row) + D * col;
+      const float* src = tensors1[col] + D * row;
+      memcpy(dst, src, D * sizeof(float));
+    }
+  }
+}
+
+void PrepareLSTMBias(const LoDTensor& B_forget, const LoDTensor& B_input,
+                     const LoDTensor& B_output, const LoDTensor& B_cell,
+                     LoDTensor* out) {
+  std::array<const float*, 4> tensors(
+      {B_forget.data<float>(), B_input.data<float>(), B_output.data<float>(),
+       B_cell.data<float>()});
+
+  PADDLE_ENFORCE_EQ(B_forget.dims().size(), 1);
+  int D = B_forget.dims()[0];
+  out->Resize(make_ddim({1, 4 * D}));
+  auto* out_data = out->mutable_data<float>(platform::CPUPlace());
+  for (size_t i = 0; i < tensors.size(); i++) {
+    memcpy(out_data + D * i, tensors[i], D * sizeof(float));
+  }
+}
+
+// Parameters
+
+std::unique_ptr<ir::Graph> AttentionLSTMFusePass::ApplyImpl(
+    std::unique_ptr<ir::Graph> graph) const {
+  PDPattern external_pattern, subblock_pattern;
+
+  FindWhileOp(graph.get());
+  return graph;
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(attention_lstm_fuse_pass,
+              paddle::framework::ir::AttentionLSTMFusePass);
diff --git a/paddle/fluid/inference/analysis/dot.cc b/paddle/fluid/framework/ir/attention_lstm_fuse_pass.h
similarity index 62%
rename from paddle/fluid/inference/analysis/dot.cc
rename to paddle/fluid/framework/ir/attention_lstm_fuse_pass.h
index d5471ffcb5..a756dfc1b9 100644
--- a/paddle/fluid/inference/analysis/dot.cc
+++ b/paddle/fluid/framework/ir/attention_lstm_fuse_pass.h
@@ -1,4 +1,4 @@
-//   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -12,12 +12,19 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/fluid/inference/analysis/dot.h"
+#pragma once
+
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
 
 namespace paddle {
-namespace inference {
-namespace analysis {
-size_t Dot::counter = 0;
-}  // namespace analysis
-}  // namespace inference
+namespace framework {
+namespace ir {
+
+class AttentionLSTMFusePass : public FusePassBase {
+ protected:
+  std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
+};
+
+}  // namespace ir
+}  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/framework/ir/fc_fuse_pass.cc b/paddle/fluid/framework/ir/fc_fuse_pass.cc
index f4327742ea..201160f29d 100644
--- a/paddle/fluid/framework/ir/fc_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/fc_fuse_pass.cc
@@ -100,12 +100,10 @@ void BuildFCPattern(PDPattern* pattern) {
       },
       "elementwise_add_out");
 
-  pattern->AddEdge(mul_parameter_var, mul_op);
-  pattern->AddEdge(mul_tmp_input_var, mul_op);
-  pattern->AddEdge(mul_op, mul_out_var);
-  pattern->AddEdge(mul_out_var, elementwise_add_op);
-  pattern->AddEdge(elementwise_add_tmp_var, elementwise_add_op);
-  pattern->AddEdge(elementwise_add_op, elementwise_add_out_var);
+  mul_op->LinksFrom({mul_parameter_var, mul_tmp_input_var})
+      .LinksTo({mul_out_var});
+  elementwise_add_op->LinksFrom({mul_out_var, elementwise_add_tmp_var})
+      .LinksTo({elementwise_add_out_var});
 }
 
 // Replace the node `from` in the links to `to`
@@ -125,7 +123,7 @@ std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
 
   std::unordered_set<Node*> nodes2delete;
 
-  GraphPatternDetecter gpd;
+  GraphPatternDetector gpd;
   BuildFCPattern(gpd.mutable_pattern());
 
 #define GET_NODE(id)                                             \
@@ -134,7 +132,7 @@ std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
   auto* id = subgraph.at(gpd.pattern().RetriveNode(#id));        \
   PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", #id);
 
-  auto handler = [&](const GraphPatternDetecter::subgraph_t& subgraph,
+  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
     VLOG(4) << "handle FC fuse";
     // Currently, there is no FC op available, so I will just simulate the
diff --git a/paddle/fluid/framework/ir/fc_fuse_pass.h b/paddle/fluid/framework/ir/fc_fuse_pass.h
index eb43dd4486..31ed0e362f 100644
--- a/paddle/fluid/framework/ir/fc_fuse_pass.h
+++ b/paddle/fluid/framework/ir/fc_fuse_pass.h
@@ -13,7 +13,7 @@
 // limitations under the License.
 
 #include "paddle/fluid/framework/ir/graph.h"
-#include "paddle/fluid/framework/ir/graph_pattern_detecter.h"
+#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
 #include "paddle/fluid/framework/ir/pass.h"
 
 namespace paddle {
diff --git a/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc b/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc
new file mode 100644
index 0000000000..daecf3b407
--- /dev/null
+++ b/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc
@@ -0,0 +1,126 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/ir/fc_lstm_fuse_pass.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl(
+    std::unique_ptr<ir::Graph> graph) const {
+  GraphPatternDetector gpd;
+  auto* pattern = gpd.mutable_pattern();
+
+  std::unordered_set<int> fused_ops({// first lstm
+                                     13, 15, 16,
+                                     // second lstm
+                                     23, 25, 26});
+
+  pattern->NewNode([&](Node* x) { return fused_ops.count(x->id()); },
+                   "any_node");
+
+  std::unordered_set<Node*> marked_nodes;
+
+  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                     Graph* g) {
+
+    auto* id = subgraph.at(gpd.pattern().RetriveNode("any_node"));
+    marked_nodes.insert(id);
+  };
+  gpd(graph.get(), handler);
+
+  // Create New OpDesc
+  auto lstm_creator = [&](int lstm, int input, int weight_x, int weight_h,
+                          int bias, int hidden, int cell, int xx) {
+#define GET_NODE(x) auto* x##_n = graph->RetriveNode(x);
+    GET_NODE(input);
+    GET_NODE(weight_x);
+    GET_NODE(weight_h);
+    GET_NODE(bias);
+    GET_NODE(hidden);
+    GET_NODE(cell);
+    GET_NODE(xx);
+    GET_NODE(lstm);
+
+    OpDesc op_desc;
+    op_desc.SetType("fusion_lstm");
+#define SET_IN(Key, node__) op_desc.SetInput(#Key, {node__##_n->Name()});
+    SET_IN(X, input);
+    SET_IN(WeightX, weight_x);
+    SET_IN(WeightH, weight_h);
+    SET_IN(Bias, bias);
+#undef GET_NODE
+#undef SET_IN
+
+    LOG(INFO) << "hidden_n: " << hidden_n->Name();
+    LOG(INFO) << "cell: " << cell_n->Name();
+    LOG(INFO) << "xx: " << xx_n->Name();
+
+    op_desc.SetInput("H0", {});
+    op_desc.SetInput("C0", {});
+    op_desc.SetOutput("Hidden", {hidden_n->Name()});
+    op_desc.SetOutput("Cell", {cell_n->Name()});
+    op_desc.SetOutput("XX", {xx_n->Name()});
+    op_desc.SetOutput("BatchedGate", {"blstm_0.tmp_2"});
+    op_desc.SetOutput("BatchCellPreAct", {"blstm_1.tmp_2"});
+    op_desc.SetAttr("is_reverse", lstm_n->Op()->GetAttr("is_reverse"));
+    op_desc.SetAttr("use_peepholes", false);
+    auto* op = graph->CreateOpNode(&op_desc);
+
+#define LINK_TO(a, b)      \
+  a->outputs.push_back(b); \
+  b->inputs.push_back(a);
+    LINK_TO(input_n, op);
+    LINK_TO(weight_x_n, op);
+    LINK_TO(weight_h_n, op);
+    LINK_TO(bias_n, op);
+    LINK_TO(op, hidden_n);
+#undef LINK_TO
+    return op;
+
+  };
+
+  lstm_creator(16, 12, 14, 18, 17, 22, 21, 19);
+  lstm_creator(26, 12, 24, 28, 27, 32, 31, 29);
+
+  // remove all the nodes
+
+  for (auto* node : marked_nodes) {
+    graph->RemoveNode(const_cast<Node*>(node));
+  }
+
+  for (auto* node : graph->Nodes()) {
+    for (auto it = node->inputs.begin(); it != node->inputs.end();) {
+      if (marked_nodes.count(*it)) {
+        it = const_cast<Node*>(node)->inputs.erase(it);
+      } else
+        it++;
+    }
+    for (auto it = node->outputs.begin(); it != node->outputs.end();) {
+      if (marked_nodes.count(*it)) {
+        it = const_cast<Node*>(node)->outputs.erase(it);
+      } else
+        it++;
+    }
+  }
+
+  return graph;
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(fc_lstm_fuse_pass, paddle::framework::ir::FCLstmFusePass);
diff --git a/paddle/fluid/framework/ir/fc_lstm_fuse_pass.h b/paddle/fluid/framework/ir/fc_lstm_fuse_pass.h
new file mode 100644
index 0000000000..74b08ae558
--- /dev/null
+++ b/paddle/fluid/framework/ir/fc_lstm_fuse_pass.h
@@ -0,0 +1,33 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/ir/graph.h"
+#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
+#include "paddle/fluid/framework/ir/pass.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+class FCLstmFusePass : public Pass {
+ public:
+  virtual ~FCLstmFusePass() {}
+
+ protected:
+  std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
+};
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/ir/fuse_pass_base.h b/paddle/fluid/framework/ir/fuse_pass_base.h
new file mode 100644
index 0000000000..bf6a0ae827
--- /dev/null
+++ b/paddle/fluid/framework/ir/fuse_pass_base.h
@@ -0,0 +1,44 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/fluid/framework/ir/graph.h"
+#include "paddle/fluid/framework/ir/pass.h"
+#include "paddle/fluid/framework/scope.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+static const char kParamScopeAttr[] = "param_scope";
+
+class FusePassBase : public Pass {
+ public:
+  void Init(Graph* graph) const { graph_ = graph; }
+
+  Scope* param_scope() const {
+    PADDLE_ENFORCE(graph_->Has(kParamScopeAttr));
+    return graph_->Get<framework::Scope*>(kParamScopeAttr);
+  }
+
+  virtual ~FusePassBase() {}
+
+ protected:
+  mutable Graph* graph_;
+};
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h
index 0d27be5fc0..b696489565 100644
--- a/paddle/fluid/framework/ir/graph.h
+++ b/paddle/fluid/framework/ir/graph.h
@@ -99,13 +99,13 @@ class Graph {
   // Create a normal variable with non-null VarDesc.
   ir::Node *CreateVarNode(VarDesc *var_desc) {
     PADDLE_ENFORCE(var_desc);
-    return AddNode(new ir::Node(var_desc));
+    return AddNode(new ir::Node(var_desc, node_count_++));
   }
 
   // Create a normal runnable operator with OpDesc.
   ir::Node *CreateOpNode(OpDesc *op_desc) {
     PADDLE_ENFORCE(op_desc);
-    return AddNode(new ir::Node(op_desc));
+    return AddNode(new ir::Node(op_desc, node_count_++));
   }
 
   // Create a control dependency var that connects 2 operations. The
@@ -115,13 +115,14 @@ class Graph {
     // TODO(panyx0718): control var name should be really unique.
     const std::string name = string::Sprintf(
         "%s@%llu", ir::Node::kControlDepVarName, node_set_.size());
-    return AddNode(new ir::Node(name, ir::Node::Type::kVariable));
+    return AddNode(
+        new ir::Node(name, ir::Node::Type::kVariable, node_count_++));
   }
 
   // A more free style way of creating a graph node. Mostly use for test
   // or "copy" from another node. Avoid using it if possible.
   ir::Node *CreateEmptyNode(const std::string &name, ir::Node::Type type) {
-    return AddNode(new ir::Node(name, type));
+    return AddNode(new ir::Node(name, type, node_count_++));
   }
 
   // Clear all node information of the graph and return the ownership of the
@@ -142,12 +143,20 @@ class Graph {
     nodes_.erase(node);
   }
 
+  Node *RetriveNode(int id) {
+    auto it = id2node_.find(id);
+    if (it != id2node_.end()) return it->second;
+    return nullptr;
+  }
+
  private:
   // This method takes ownership of `node`.
   ir::Node *AddNode(ir::Node *node) {
     PADDLE_ENFORCE(node_set_.find(node) == node_set_.end());
     nodes_[node].reset(node);
     node_set_.insert(node);
+    PADDLE_ENFORCE(!id2node_.count(node->id()), "duplicate id %d", node->id());
+    id2node_[node->id()] = node;
     return node;
   }
 
@@ -157,6 +166,8 @@ class Graph {
   std::map<std::string, std::function<void(void)>> attr_dels_;
   std::map<ir::Node *, std::unique_ptr<ir::Node>> nodes_;
   std::unordered_set<ir::Node *> node_set_;
+  std::map<int, Node *> id2node_;
+  int node_count_{0};
 };
 
 bool IsControlDepVar(const ir::Node &var);
diff --git a/paddle/fluid/framework/ir/graph_helper.cc b/paddle/fluid/framework/ir/graph_helper.cc
index dc81a2cac5..62f94a1c0e 100644
--- a/paddle/fluid/framework/ir/graph_helper.cc
+++ b/paddle/fluid/framework/ir/graph_helper.cc
@@ -103,10 +103,10 @@ std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildOperationAdjList(
     for (auto &var : n->inputs) {
       for (auto &adj_n : var->inputs) {
         PADDLE_ENFORCE(adj_n->NodeType() == ir::Node::Type::kOperation);
-        adj_list[n].insert(adj_n);
         VLOG(4) << "adj " << adj_n->Name() << reinterpret_cast<void *>(adj_n)
                 << " -> " << n->Name() << reinterpret_cast<void *>(n)
                 << "  via " << var->Name() << reinterpret_cast<void *>(var);
+        adj_list[n].insert(adj_n);
       }
     }
   }
diff --git a/paddle/fluid/framework/ir/graph_pattern_detecter.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc
similarity index 71%
rename from paddle/fluid/framework/ir/graph_pattern_detecter.cc
rename to paddle/fluid/framework/ir/graph_pattern_detector.cc
index e197861251..dce4be8ff0 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detecter.cc
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc
@@ -17,7 +17,7 @@
 #include <vector>
 
 #include "paddle/fluid/framework/ir/graph_helper.h"
-#include "paddle/fluid/framework/ir/graph_pattern_detecter.h"
+#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
 #include "paddle/fluid/framework/ir/graph_traits.h"
 #include "paddle/fluid/platform/enforce.h"
 
@@ -34,7 +34,7 @@ PDNode* PDPattern::NewNode(PDNode::teller_t&& teller, const std::string& name) {
                       name);
   }
 
-  nodes_.emplace_back(new PDNode(std::move(teller), name));
+  nodes_.emplace_back(new PDNode(std::move(teller), this, name));
   auto* cur = nodes_.back().get();
   node_map_[name] = cur;
   return cur;
@@ -56,19 +56,22 @@ void PDPattern::AddEdge(PDNode* a, PDNode* b) {
   edges_.emplace_back(a, b);
 }
 
-void GraphPatternDetecter::operator()(Graph* graph,
-                                      GraphPatternDetecter::handle_t handler) {
+void GraphPatternDetector::operator()(Graph* graph,
+                                      GraphPatternDetector::handle_t handler) {
   if (!MarkPDNodesInGraph(*graph)) return;
   auto subgraphs = DetectPatterns();
   UniquePatterns(&subgraphs);
   RemoveOverlappedMatch(&subgraphs);
 
+  LOG(INFO) << "detect " << subgraphs.size() << " subgraph matches the pattern";
+  int id = 0;
   for (auto& g : subgraphs) {
+    LOG(INFO) << "optimizing #" << id++ << " subgraph";
     handler(g, graph);
   }
 }
 
-bool GraphPatternDetecter::MarkPDNodesInGraph(const ir::Graph& graph) {
+bool GraphPatternDetector::MarkPDNodesInGraph(const ir::Graph& graph) {
   VLOG(4) << "mark pdnodes in graph";
   if (graph.Nodes().empty()) return false;
 
@@ -114,13 +117,15 @@ bool IsNodesLink(Node* a, Node* b) {
   return false;
 }
 
-std::vector<GraphPatternDetecter::subgraph_t>
-GraphPatternDetecter::DetectPatterns() {
+std::vector<GraphPatternDetector::subgraph_t>
+GraphPatternDetector::DetectPatterns() {
   // Init empty subgraphs.
-  std::vector<GraphPatternDetecter::subgraph_t> result;
+  std::vector<GraphPatternDetector::subgraph_t> result;
   std::vector<HitGroup> init_groups;
-  PADDLE_ENFORCE(!pattern_.edges().empty(), "At least one edge is needed");
-  auto* first_pnode = pattern_.edges().front().first;
+  std::array<std::vector<HitGroup>, 2> bi_records;
+  // PADDLE_ENFORCE(!pattern_.edges().empty(), "At least one edge is needed");
+  auto* first_pnode = pattern_.edges().empty() ? pattern().nodes().front().get()
+                                               : pattern_.edges().front().first;
   if (!pdnodes2nodes_.count(first_pnode)) return result;
   for (auto* node : pdnodes2nodes_[first_pnode]) {
     HitGroup group;
@@ -129,7 +134,6 @@ GraphPatternDetecter::DetectPatterns() {
   }
 
   int step = 0;
-  std::array<std::vector<HitGroup>, 2> bi_records;
   bi_records[0] = std::move(init_groups);
 
   // Extend a PDNode to subgraphs by deducing the connection relations defined
@@ -141,6 +145,7 @@ GraphPatternDetecter::DetectPatterns() {
     auto& pre_groups = bi_records[step % 2];
     auto& cur_groups = bi_records[1 - (step++ % 2)];
     cur_groups.clear();
+    if (pre_groups.empty()) break;
     // source -> target
     for (Node* source : pdnodes2nodes_[edge.first]) {
       for (Node* target : pdnodes2nodes_[edge.second]) {
@@ -163,7 +168,7 @@ GraphPatternDetecter::DetectPatterns() {
   }
 
   for (auto& group : bi_records[step % 2]) {
-    GraphPatternDetecter::subgraph_t subgraph;
+    GraphPatternDetector::subgraph_t subgraph;
     for (auto& role : group.roles) {
       subgraph.emplace(role.first, role.second);
     }
@@ -172,10 +177,10 @@ GraphPatternDetecter::DetectPatterns() {
   return result;
 }
 
-void GraphPatternDetecter::UniquePatterns(
-    std::vector<GraphPatternDetecter::subgraph_t>* subgraphs) {
+void GraphPatternDetector::UniquePatterns(
+    std::vector<GraphPatternDetector::subgraph_t>* subgraphs) {
   if (subgraphs->empty()) return;
-  std::vector<GraphPatternDetecter::subgraph_t> result;
+  std::vector<GraphPatternDetector::subgraph_t> result;
 
   std::unordered_set<size_t> set;
   for (auto& g : *subgraphs) {
@@ -192,7 +197,7 @@ void GraphPatternDetecter::UniquePatterns(
   *subgraphs = result;
 }
 
-void GraphPatternDetecter::RemoveOverlappedMatch(
+void GraphPatternDetector::RemoveOverlappedMatch(
     std::vector<subgraph_t>* subgraphs) {
   std::vector<subgraph_t> result;
   std::unordered_set<Node*> node_set;
@@ -215,6 +220,46 @@ void GraphPatternDetecter::RemoveOverlappedMatch(
   *subgraphs = result;
 }
 
+std::string PDPattern::DotString() const {
+  using inference::analysis::Dot;
+  Dot dot;
+  int id = 0;
+  // Create Nodes
+  std::unordered_map<PDNode*, std::string> node2dot;
+  for (const auto& node : nodes()) {
+    std::string node_id = "Node" + std::to_string(id++);
+    dot.AddNode(node_id, {}, node->name());
+    node2dot[node.get()] = node_id;
+  }
+  // Create Edges
+  for (const auto& edge : edges()) {
+    if (!node2dot.count(edge.first) || !node2dot.count(edge.second)) {
+      LOG(ERROR) << "no node " << edge.first << " " << edge.second;
+      continue;
+    }
+    auto& src = node2dot.at(edge.first);
+    auto& trg = node2dot.at(edge.second);
+    dot.AddEdge(src, trg, {});
+  }
+  return dot.Build();
+}
+
+PDNode& PDNode::LinksTo(const std::vector<PDNode*>& others) {
+  // extend outlinks.
+  for (PDNode* x : others) {
+    pattern_->AddEdge(this, x);
+  }
+  return *this;
+}
+
+PDNode& PDNode::LinksFrom(const std::vector<PDNode*>& others) {
+  // extend outlinks.
+  for (PDNode* x : others) {
+    pattern_->AddEdge(x, this);
+  }
+  return *this;
+}
+
 }  // namespace ir
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/framework/ir/graph_pattern_detecter.h b/paddle/fluid/framework/ir/graph_pattern_detector.h
similarity index 72%
rename from paddle/fluid/framework/ir/graph_pattern_detecter.h
rename to paddle/fluid/framework/ir/graph_pattern_detector.h
index 68c39902b5..0ac34a57aa 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detecter.h
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.h
@@ -21,12 +21,14 @@
 #include <numeric>
 #include "paddle/fluid/framework/ir/graph.h"
 #include "paddle/fluid/framework/ir/node.h"
+#include "paddle/fluid/inference/analysis/dot.h"
 
 namespace paddle {
 namespace framework {
 namespace ir {
+class PDPattern;
 
-// Some basic torminolygies:
+// Some basic terminologies:
 //   - PDPattern: a pattern defined as a data flow graph.
 //   - PDNode: the node in the pattern, each PDNode represents an `ir::Node`
 //     that meets some conditions defined in `PDNode.teller`.
@@ -36,30 +38,43 @@ namespace ir {
 struct PDNode {
   // tell whether an ir::Node* is a candidation for a PDNode.
   using teller_t = std::function<bool(Node*)>;
+  enum class Type { kOp, kVar };
 
-  PDNode(teller_t&& teller, const std::string& name = "")
-      : teller_(teller), name_(name) {
-    PADDLE_ENFORCE(teller_ != nullptr, "invalid teller functer is set.");
-  }
-
-  PDNode(PDNode&& other) = default;
-
-  std::vector<PDNode*> inlinks;
-  std::vector<PDNode*> outlinks;
+  // this link to others
+  PDNode& LinksTo(const std::vector<PDNode*>& others);
+  PDNode& LinksFrom(const std::vector<PDNode*>& others);
 
   bool Tell(Node* node) const {
     PADDLE_ENFORCE(teller_ != nullptr, "teller should be set for a PDNode");
     return teller_(node);
   }
 
+  bool IsOp() const { return type_ == Type::kOp; }
+  bool IsVar() const { return type_ == Type::kVar; }
+
   const std::string& name() const { return name_; }
 
   PDNode(const PDNode&) = delete;
   PDNode& operator=(const PDNode&) = delete;
 
  private:
+  PDNode(teller_t&& teller, PDPattern* pattern, const std::string& name = "",
+         Type type = Type::kVar)
+      : teller_(std::move(teller)),
+        pattern_(pattern),
+        name_(name),
+        type_(type) {
+    PADDLE_ENFORCE(teller_ != nullptr, "invalid teller functer is set.");
+  }
+
+  PDNode(PDNode&& other) = default;
+
+  friend class PDPattern;
+
   teller_t teller_;
+  PDPattern* pattern_;
   std::string name_;
+  Type type_;
 };
 
 /*
@@ -102,6 +117,8 @@ class PDPattern {
   const std::vector<std::unique_ptr<PDNode>>& nodes() const { return nodes_; }
   const std::vector<edge_t>& edges() const { return edges_; }
 
+  std::string DotString() const;
+
  private:
 #ifdef PADDLE_WITH_TESTING
   FRIEND_TEST(PDPattern, AddEdge);
@@ -117,7 +134,7 @@ class PDPattern {
 };
 
 /*
- * GraphPatternDetecter helps to detect the specific patterns in the graph.
+ * GraphPatternDetector helps to detect the specific patterns in the graph.
  * Input a pattern, output a list of the matched subgraphs/nodes.
  * This helper can be used to support fuse(conv+batchnorm => batchnorm e.g.).
  *
@@ -129,7 +146,7 @@ class PDPattern {
  *
  * Usage:
  *    // Create a detector
- *    GraphPatternDetecter detector;
+ *    GraphPatternDetector detector;
  *    // Define the detector's pattern, by adding PDNode and define the edges.
  *    auto* node0 = detector.mutable_pattern().AddNode(...)
  *    auto* node1 = detector.mutable_pattern().AddNode(...)
@@ -138,11 +155,11 @@ class PDPattern {
  *    detector.mutable_pattern().AddEdge(node0, node1);
  *    // Create an handler, to define the behavior of treating the filtered
  *    // subgraphs that comply with the patterns.
- *    GraphPatternDetecter::handle_t handler = some labmda
+ *    GraphPatternDetector::handle_t handler = some labmda
  *    // Execute the detector.
  *    detector(&graph, handler);
  */
-class GraphPatternDetecter {
+class GraphPatternDetector {
  public:
   using subgraph_t = std::unordered_map<PDNode*, Node*>;
 
@@ -177,10 +194,62 @@ class GraphPatternDetecter {
   using hit_rcd_t =
       std::pair<Node* /*node in graph*/, PDNode* /*node in pattern*/>;
   PDPattern pattern_;
-  std::vector<hit_rcd_t> marked_records_;
   std::unordered_map<const PDNode*, std::unordered_set<Node*>> pdnodes2nodes_;
 };
 
+// some helper methods.
+
+// Op's input.
+static bool VarLinksToOp(Node* node, const std::string& op_type) {
+  for (auto* out : node->outputs) {
+    if (out->IsOp() && out->Op()->Type() == op_type) {
+      return true;
+    }
+  }
+  return false;
+}
+
+// Op's output.
+static bool VarLinksFromOp(Node* node, const std::string& op_type) {
+  for (auto* out : node->inputs) {
+    if (out->IsOp() && out->Op()->Type() == op_type) {
+      return true;
+    }
+  }
+  return false;
+}
+
+// Check whether a var node is a op node's nth input.
+static bool IsNthInput(Node* var, Node* op, const std::string& argument,
+                       size_t nth) {
+  PADDLE_ENFORCE(var->IsVar());
+  PADDLE_ENFORCE(op->IsOp());
+  if (op->inputs.size() <= nth) return false;
+  return var->Name() == op->Op()->Input(argument)[nth];
+}
+
+static void GraphSafeRemoveNodes(Graph* graph,
+                                 const std::unordered_set<const Node*>& nodes) {
+  for (auto* node : nodes) {
+    graph->RemoveNode(const_cast<Node*>(node));
+  }
+
+  for (auto* node : graph->Nodes()) {
+    for (auto it = node->inputs.begin(); it != node->inputs.end();) {
+      if (nodes.count(*it)) {
+        it = const_cast<Node*>(node)->inputs.erase(it);
+      } else
+        it++;
+    }
+    for (auto it = node->outputs.begin(); it != node->outputs.end();) {
+      if (nodes.count(*it)) {
+        it = const_cast<Node*>(node)->outputs.erase(it);
+      } else
+        it++;
+    }
+  }
+}
+
 }  // namespace ir
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/framework/ir/graph_pattern_detecter_tester.cc b/paddle/fluid/framework/ir/graph_pattern_detector_tester.cc
similarity index 95%
rename from paddle/fluid/framework/ir/graph_pattern_detecter_tester.cc
rename to paddle/fluid/framework/ir/graph_pattern_detector_tester.cc
index 06f9df5546..a4d0646230 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detecter_tester.cc
+++ b/paddle/fluid/framework/ir/graph_pattern_detector_tester.cc
@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/fluid/framework/ir/graph_pattern_detecter.h"
+#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
 
 #include <gtest/gtest.h>
 
@@ -82,7 +82,7 @@ TEST(PDPattern, AddEdge) {
 }
 
 TEST(GraphPatternDetecter, MarkPDNodesInGraph) {
-  GraphPatternDetecter x;
+  GraphPatternDetector x;
   // mark o2, o3, v2
 
   // The pattern is a graph:
@@ -131,7 +131,7 @@ TEST(GraphPatternDetecter, MultiSubgraph) {
   Graph graph(program);
   BuildGraph(&graph);
 
-  GraphPatternDetecter x;
+  GraphPatternDetector x;
 
   // The pattern is a graph:
   //   op -> var
@@ -149,8 +149,8 @@ TEST(GraphPatternDetecter, MultiSubgraph) {
   x.mutable_pattern()->AddEdge(any_var, any_op1);
 
   int count = 0;
-  GraphPatternDetecter::handle_t handle = [&](
-      const GraphPatternDetecter::subgraph_t& s, Graph* g) {
+  GraphPatternDetector::handle_t handle = [&](
+      const GraphPatternDetector::subgraph_t& s, Graph* g) {
     LOG(INFO) << "Detect " << s.at(any_op)->Name() << " -> "
               << s.at(any_var)->Name() << " -> " << s.at(any_op1)->Name();
     count++;
diff --git a/paddle/fluid/framework/ir/graph_to_program_pass.cc b/paddle/fluid/framework/ir/graph_to_program_pass.cc
new file mode 100644
index 0000000000..414d8f79b1
--- /dev/null
+++ b/paddle/fluid/framework/ir/graph_to_program_pass.cc
@@ -0,0 +1,65 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/framework/ir/graph_to_program_pass.h"
+
+#include <map>
+#include <string>
+#include <vector>
+
+#include "paddle/fluid/framework/ir/graph.h"
+#include "paddle/fluid/framework/ir/graph_helper.h"
+
+#include "paddle/fluid/framework/program_desc.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+std::unique_ptr<Graph> GraphToProgramPass::ApplyImpl(
+    std::unique_ptr<Graph> graph) const {
+  ProgramDesc& program = Get<ProgramDesc>("program");
+
+  std::unique_ptr<proto::ProgramDesc> program_pb(
+      new proto::ProgramDesc(*program.Proto()));
+
+  auto block = program_pb->mutable_blocks(kRootBlockIndex);
+  block->clear_vars();
+  std::unordered_set<std::string> visited_vars;
+  for (ir::Node* n : graph->Nodes()) {
+    if (n->NodeType() == ir::Node::Type::kVariable) {
+      if (n->Var() && visited_vars.count(n->Var()->Name()) == 0) {
+        visited_vars.insert(n->Var()->Name());
+        block->add_vars()->MergeFrom(*n->Var()->Proto());
+      }
+    }
+  }
+
+  block->clear_ops();
+  std::vector<ir::Node*> nodes = TopologySortOperations(*graph);
+  for (ir::Node* n : nodes) {
+    if (!n->Op()) {
+      continue;
+    }
+    block->add_ops()->MergeFrom(*n->Op()->Proto());
+  }
+
+  program.CopyFrom(*program_pb);
+  return graph;
+}
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(graph_to_program_pass, paddle::framework::ir::GraphToProgramPass);
diff --git a/paddle/fluid/framework/ir/graph_to_program_pass.h b/paddle/fluid/framework/ir/graph_to_program_pass.h
new file mode 100644
index 0000000000..124ec5a8e7
--- /dev/null
+++ b/paddle/fluid/framework/ir/graph_to_program_pass.h
@@ -0,0 +1,30 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "paddle/fluid/framework/ir/pass.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+class GraphToProgramPass : public Pass {
+ protected:
+  std::unique_ptr<Graph> ApplyImpl(std::unique_ptr<Graph> graph) const override;
+};
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/ir/graph_to_program_pass_test.cc b/paddle/fluid/framework/ir/graph_to_program_pass_test.cc
new file mode 100644
index 0000000000..88ad17a0c6
--- /dev/null
+++ b/paddle/fluid/framework/ir/graph_to_program_pass_test.cc
@@ -0,0 +1,110 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/framework/ir/graph_to_program_pass.h"
+
+#include <string>
+#include <vector>
+#include "gtest/gtest.h"
+#include "paddle/fluid/framework/program_desc.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+void BuildNoCircleGraph(Graph* g) {
+  OpDesc op1;
+  op1.SetType("op1");
+  OpDesc op2;
+  op2.SetType("op2");
+  OpDesc op3;
+  op3.SetType("op3");
+  OpDesc op4;
+  op4.SetType("op4");
+  OpDesc op5;
+  op5.SetType("op5");
+  VarDesc var1("var1");
+  VarDesc var2("var2");
+  VarDesc var3("var3");
+  VarDesc var4("var4");
+
+  ir::Node* o1 = g->CreateOpNode(&op1);
+  ir::Node* o2 = g->CreateOpNode(&op2);
+  ir::Node* o3 = g->CreateOpNode(&op3);
+  ir::Node* o4 = g->CreateOpNode(&op4);
+  ir::Node* o5 = g->CreateOpNode(&op5);
+  ir::Node* v1 = g->CreateVarNode(&var1);
+  ir::Node* v2 = g->CreateVarNode(&var2);
+  ir::Node* v3 = g->CreateVarNode(&var3);
+  ir::Node* v4 = g->CreateVarNode(&var4);
+
+  // o1->v1->o2
+  o1->outputs.push_back(v1);
+  o2->inputs.push_back(v1);
+  v1->inputs.push_back(o1);
+  v1->outputs.push_back(o2);
+  // o2->v2->o3
+  // o2->v2->o4
+  o2->outputs.push_back(v2);
+  o3->inputs.push_back(v2);
+  o4->inputs.push_back(v2);
+  v2->outputs.push_back(o3);
+  v2->outputs.push_back(o4);
+  v2->inputs.push_back(o2);
+  // o2->v3->o5
+  o2->outputs.push_back(v3);
+  o5->inputs.push_back(v3);
+  v3->inputs.push_back(o2);
+  v3->outputs.push_back(o5);
+  // o3-v4->o5
+  o3->outputs.push_back(v4);
+  o5->inputs.push_back(v4);
+  v4->inputs.push_back(o3);
+  v4->outputs.push_back(o5);
+}
+
+TEST(GraphToProgramPass, Basic) {
+  ProgramDesc prog;
+  std::unique_ptr<Graph> g(new Graph(prog));
+  BuildNoCircleGraph(g.get());
+
+  auto pass = paddle::framework::ir::PassRegistry::Instance().Get(
+      "graph_to_program_pass");
+
+  ProgramDesc compiled_prog;
+  pass->SetNotOwned<paddle::framework::ProgramDesc>("program", &compiled_prog);
+  pass->Apply(std::move(g));
+  std::vector<OpDesc*> ops = compiled_prog.Block(0).AllOps();
+  EXPECT_EQ(ops[0]->Type(), "op1");
+  EXPECT_EQ(ops[1]->Type(), "op2");
+  if (ops[2]->Type() == "op3") {
+    EXPECT_EQ(ops[3]->Type(), "op4");
+  } else if (ops[2]->Type() == "op4") {
+    EXPECT_EQ(ops[3]->Type(), "op3");
+  }
+  EXPECT_EQ(ops[4]->Type(), "op5");
+
+  std::unordered_set<std::string> vars;
+  for (VarDesc* v : compiled_prog.Block(0).AllVars()) {
+    vars.insert(v->Name());
+  }
+  EXPECT_TRUE(vars.find("var1") != vars.end());
+  EXPECT_TRUE(vars.find("var2") != vars.end());
+  EXPECT_TRUE(vars.find("var3") != vars.end());
+}
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+USE_PASS(graph_to_program_pass);
diff --git a/paddle/fluid/framework/ir/graph_viz_pass.cc b/paddle/fluid/framework/ir/graph_viz_pass.cc
index e7ff0c1dac..3a114c6a23 100644
--- a/paddle/fluid/framework/ir/graph_viz_pass.cc
+++ b/paddle/fluid/framework/ir/graph_viz_pass.cc
@@ -16,11 +16,13 @@ limitations under the License. */
 #include <unordered_set>
 
 #include "paddle/fluid/framework/ir/graph_viz_pass.h"
+#include "paddle/fluid/inference/analysis/dot.h"
 
 namespace paddle {
 namespace framework {
 namespace ir {
 static const char kGraphVizPath[] = "graph_viz_path";
+using inference::analysis::Dot;
 
 std::unique_ptr<ir::Graph> GraphVizPass::ApplyImpl(
     std::unique_ptr<ir::Graph> graph) const {
@@ -30,41 +32,65 @@ std::unique_ptr<ir::Graph> GraphVizPass::ApplyImpl(
   PADDLE_ENFORCE(fout->good());
   std::ostream& sout = *fout;
 
-  size_t var_id = 0;
-  std::unordered_map<const ir::Node*, size_t> vars;
-
-  sout << "digraph G {\n";
-
-  for (const ir::Node* n : graph->Nodes()) {
-    if (n->NodeType() != ir::Node::Type::kVariable) continue;
-    size_t cur_var_id = var_id++;
-    vars[n] = cur_var_id;
-
-    sout << "var_" << cur_var_id << " [label=\"" << n->Name() << "\"]"
-         << std::endl;
-  }
-
-  size_t op_id = 0;
-  for (const ir::Node* n : graph->Nodes()) {
-    if (n->NodeType() != ir::Node::Type::kOperation) continue;
-    std::string op_name = "op_" + std::to_string(op_id++);
-    sout << op_name << " [label=\"" << n->Name() << "\", shape=rect]"
-         << std::endl;
-    for (auto in : n->inputs) {
-      std::string var_name = "var_" + std::to_string(vars[in]);
-      sout << var_name << " -> " << op_name << std::endl;
+  std::unordered_map<const ir::Node*, std::string> node2dot;
+
+  Dot dot;
+
+  std::vector<Dot::Attr> op_attrs({Dot::Attr("style", "filled"),
+                                   Dot::Attr("shape", "box"),
+                                   Dot::Attr("fillcolor", "red")});
+  std::vector<Dot::Attr> var_attrs({Dot::Attr("style", "filled,rounded"),
+                                    // Dot::Attr("shape", "diamond"),
+                                    Dot::Attr("fillcolor", "yellow")});
+
+  std::vector<Dot::Attr> marked_op_attrs({Dot::Attr("style", "filled"),
+                                          Dot::Attr("shape", "box"),
+                                          Dot::Attr("fillcolor", "lightgray")});
+  std::vector<Dot::Attr> marked_var_attrs(
+      {Dot::Attr("style", "filled,rounded"),
+       // Dot::Attr("shape", "diamond"),
+       Dot::Attr("fillcolor", "lightgray")});
+
+  auto marked_nodes = ConsumeMarkedNodes(graph.get());
+  // Create nodes
+  for (const Node* n : graph->Nodes()) {
+    std::string node_id = n->Name() + "(" + std::to_string(n->id()) + ")";
+    if (n->IsOp()) {
+      decltype(op_attrs) attr =
+          marked_nodes.count(n) ? marked_op_attrs : op_attrs;
+      dot.AddNode(node_id, attr, node_id);
+    } else if (n->IsVar()) {
+      decltype(op_attrs) attr =
+          marked_nodes.count(n) ? marked_var_attrs : var_attrs;
+      dot.AddNode(node_id, attr, node_id);
     }
-
-    for (auto out : n->outputs) {
-      std::string var_name = "var_" + std::to_string(vars[out]);
-      sout << op_name << " -> " << var_name << std::endl;
+    node2dot[n] = node_id;
+  }
+  // Create edges
+  for (const Node* n : graph->Nodes()) {
+    const auto& src_id = node2dot.at(n);
+    for (auto* out : n->outputs) {
+      const auto& trg_id = node2dot.at(out);
+      dot.AddEdge(src_id, trg_id, {});
     }
   }
 
-  sout << "}\n";
+  sout << dot.Build();
+
   return graph;
 }
 
+GraphVizPass::marked_nodes_t GraphVizPass::ConsumeMarkedNodes(
+    Graph* graph) const {
+  marked_nodes_t res;
+  if (graph->Has(kGraphvizMarkedNodeAttr)) {
+    auto& attr = graph->Get<marked_nodes_t>(kGraphvizMarkedNodeAttr);
+    res = attr;
+    attr.clear();
+  }
+  return res;
+}
+
 }  // namespace ir
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/framework/ir/graph_viz_pass.h b/paddle/fluid/framework/ir/graph_viz_pass.h
index 1fd8c8a26e..8d885cb9e4 100644
--- a/paddle/fluid/framework/ir/graph_viz_pass.h
+++ b/paddle/fluid/framework/ir/graph_viz_pass.h
@@ -27,10 +27,19 @@ namespace paddle {
 namespace framework {
 namespace ir {
 
+const char kGraphvizMarkedNodeAttr[] = "__graphviz__marked_node__";
+
 class GraphVizPass : public Pass {
+ public:
+  using marked_nodes_t = std::unordered_set<const Node*>;
+
  protected:
   std::unique_ptr<ir::Graph> ApplyImpl(
       std::unique_ptr<ir::Graph> graph) const override;
+
+  // Tell whether there are any marked nodes in the graph. Consume the
+  // corresponding attribute.
+  marked_nodes_t ConsumeMarkedNodes(Graph* graph) const;
 };
 
 }  // namespace ir
diff --git a/paddle/fluid/framework/ir/node.h b/paddle/fluid/framework/ir/node.h
index aab3180e7e..6d40e38522 100644
--- a/paddle/fluid/framework/ir/node.h
+++ b/paddle/fluid/framework/ir/node.h
@@ -29,20 +29,26 @@ class Node {
   enum class Type { kOperation, kVariable };
   static constexpr char kControlDepVarName[] = "__control_var";
 
-  explicit Node(const std::string& name, Type type)
-      : name_(name), var_desc_(nullptr), op_desc_(nullptr), type_(type) {}
+  explicit Node(const std::string& name, Type type, int id = -1)
+      : name_(name),
+        var_desc_(nullptr),
+        op_desc_(nullptr),
+        type_(type),
+        id_(id) {}
 
-  explicit Node(VarDesc* var_desc)
+  explicit Node(VarDesc* var_desc, int id = -1)
       : name_(var_desc->Name()),
         var_desc_(new VarDesc(*var_desc)),
         op_desc_(nullptr),
-        type_(Type::kVariable) {}
+        type_(Type::kVariable),
+        id_(id) {}
 
-  explicit Node(OpDesc* op_desc)
+  explicit Node(OpDesc* op_desc, int id = -1)
       : name_(op_desc->Type()),
         var_desc_(nullptr),
         op_desc_(new OpDesc(*op_desc, op_desc->Block())),
-        type_(Type::kOperation) {}
+        type_(Type::kOperation),
+        id_(id) {}
 
   Type NodeType() const { return type_; }
 
@@ -58,6 +64,8 @@ class Node {
     return op_desc_.get();
   }
 
+  int id() const { return id_; }
+
   bool IsOp() const { return type_ == Type::kOperation; }
   bool IsVar() const { return type_ == Type::kVariable; }
 
@@ -69,6 +77,7 @@ class Node {
   std::unique_ptr<VarDesc> var_desc_;
   std::unique_ptr<OpDesc> op_desc_;
   Type type_;
+  int id_;
 
  private:
   DISABLE_COPY_AND_ASSIGN(Node);
diff --git a/paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.cc b/paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.cc
new file mode 100644
index 0000000000..9bb5c232e5
--- /dev/null
+++ b/paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.cc
@@ -0,0 +1,256 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.h"
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
+#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
+#include "paddle/fluid/framework/ir/graph_viz_pass.h"
+#include "paddle/fluid/framework/lod_tensor.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+struct FuseExpr {};
+
+// sequence expand, concat fuse pattern, return concat's output
+PDNode* BuildSeqExpandConcatPattern(PDPattern* pattern) {
+  // The following operators will be fused:
+  // concat
+  // sequence_expand
+  // sequence_expand
+
+  // The following variables will be treat as inputs:
+  // concat mid input, 0th input for fused op
+  // sequence_expand input, 1th input for fused op
+  // sequence_expand input, 2th input for fused op
+
+  // The following variables will be treat as outputs:
+  // concat output
+
+  // So the following variables will be removed:
+  // sequence-expand output
+  // sequence-expand output
+
+  // Three operators
+  auto* sequence_expand0 = pattern->NewNode(
+      [](Node* x) {
+        return x && x->IsOp() && x->Op()->Type() == "sequence_expand";
+      },
+      "sequence_expand0");
+
+  auto* sequence_expand1 = pattern->NewNode(
+      [](Node* x) {
+        return x && x->IsOp() && x->Op()->Type() == "sequence_expand";
+      },
+      "sequence_expand1");
+
+  auto* concat = pattern->NewNode(
+      [](Node* x) {
+        return x && x->IsOp() && x->Op()->Type() == "concat" &&  // basic check
+               x->Op()->Input("X").size() == 3;                  // Special case
+      },
+      "concat");
+
+  auto* sequence_expand0_in = pattern->NewNode(
+      [](Node* x) {
+        return x && x->IsVar() && VarLinksToOp(x, "sequence_expand");
+      },
+      "sequence_expand0_in");
+  auto* sequence_expand1_in = pattern->NewNode(
+      [](Node* x) {
+        return x && x->IsVar() && VarLinksToOp(x, "sequence_expand");
+      },
+      "sequence_expand1_in");
+
+  // The variables
+  auto* sequence_expand0_out = pattern->NewNode(
+      [](Node* x) {
+        return x && x->IsVar() &&
+               VarLinksFromOp(x, "sequence_expand") &&  // basic check
+               VarLinksToOp(x, "concat") &&             // is concat's input
+               IsNthInput(x, x->outputs[0], "X", 1);    // X[0]
+      },
+      "sequence_expand0_out");
+
+  auto* sequence_expand1_out = pattern->NewNode(
+      [](Node* x) {
+        return x && x->IsVar() &&
+               VarLinksFromOp(x, "sequence_expand") &&  // basic check
+               VarLinksToOp(x, "concat") &&             // is concat's input
+               IsNthInput(x, x->outputs[0], "X", 2);    // x[2]
+      },
+      "sequence_expand1_out");
+
+  auto* concat_in0 = pattern->NewNode(
+      [](Node* x) { return x && x->IsVar() && VarLinksToOp(x, "concat"); },
+      "concat_in0");
+
+  auto* concat_out = pattern->NewNode(
+      [](Node* x) { return x && x->IsVar() && VarLinksFromOp(x, "concat"); },
+      "concat_out");
+
+  // Links
+  sequence_expand0->LinksFrom({sequence_expand0_in})
+      .LinksTo({sequence_expand0_out});
+  sequence_expand1->LinksFrom({sequence_expand1_in})
+      .LinksTo({sequence_expand1_out});
+  concat->LinksFrom({sequence_expand0_out, sequence_expand1_out, concat_in0})
+      .LinksTo({concat_out});
+  return concat_out;
+}
+
+PDNode* BuildFCPattern(PDPattern* pattern, PDNode* fc_x) {
+  PDNode* fc_w = pattern->NewNode(
+      [](Node* x) {
+        return x && x->IsVar() &&                 // basic
+               VarLinksToOp(x, "mul") &&          // link
+               x->Var()->Proto()->persistable();  // is a parameter
+      },
+      "fc_w");
+
+  PDNode* mul_out = pattern->NewNode(
+      [](Node* x) {
+        return x && x->IsVar() &&                     // basic
+               VarLinksFromOp(x, "mul") &&            // link
+               VarLinksToOp(x, "elementwise_add") &&  //
+               !x->Var()->Proto()->persistable();     // is a parameter
+      },
+      "mul_out");
+
+  PDNode* fc_mul = pattern->NewNode(
+      [](Node* x) {
+        return x && x->IsOp() && x->Op()->Type() == "mul";  // basic
+      },
+      "fc_mul");
+
+  PDNode* fc_bias = pattern->NewNode(
+      [](Node* x) {
+        return x && x->IsVar() &&                     // basic
+               VarLinksToOp(x, "elementwise_add") &&  // link
+               x->Var()->Proto()->persistable();      // is a parameter
+      },
+      "fc_bias");
+
+  PDNode* elementwise_add = pattern->NewNode(
+      [](Node* x) {
+        return x && x->IsOp() && x->Op()->Type() == "elementwise_add";
+      },
+      "elementwise_add");
+
+  PDNode* add_out = pattern->NewNode(
+      [](Node* x) {
+        return x && x->IsVar() &&                       // basic
+               VarLinksFromOp(x, "elementwise_add") &&  // link
+               !x->Var()->Proto()->persistable();       // is a parameter
+      },
+      "add_out");
+
+  std::set<std::string> acts({"sigmoid", "tanh", "relu", "identity"});
+  PDNode* act = pattern->NewNode(
+      [=](Node* x) {
+        return x && x->IsOp() && acts.count(x->Op()->Type());
+
+      },
+      "act");
+
+  PDNode* fc_out = pattern->NewNode(
+      [](Node* x) {
+        return x && x->IsVar() &&                  // basic
+               !x->Var()->Proto()->persistable();  // is a parameter
+      },
+      "fc_out");
+
+  fc_mul->LinksFrom({fc_w, fc_x}).LinksTo({mul_out});
+  elementwise_add->LinksFrom({mul_out, fc_bias}).LinksTo({add_out});
+  act->LinksFrom({add_out}).LinksTo({fc_out});
+  return fc_out;
+}
+
+std::unique_ptr<ir::Graph> SeqConcatFcFusePass::ApplyImpl(
+    std::unique_ptr<ir::Graph> graph) const {
+  FusePassBase::Init(graph.get());
+  GraphPatternDetector detector;
+  auto* pattern = detector.mutable_pattern();
+  auto* concat_out = BuildSeqExpandConcatPattern(pattern);
+  BuildFCPattern(pattern, concat_out);
+
+#define GET_NODE(id, pattern)                              \
+  PADDLE_ENFORCE(subgraph.count(pattern.RetriveNode(#id)), \
+                 "pattern has no Node called %s", #id);    \
+  auto* id = subgraph.at(pattern.RetriveNode(#id));        \
+  PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", #id);
+
+  detector(graph.get(), [&](const GraphPatternDetector::subgraph_t& subgraph,
+                            Graph* graph) {
+    VLOG(4) << "get one concat pattern";
+    // fc
+    GET_NODE(fc_w, detector.pattern());
+    GET_NODE(fc_bias, detector.pattern());
+    GET_NODE(act, detector.pattern());
+    GET_NODE(fc_out, detector.pattern());
+
+    // concat
+    GET_NODE(concat_in0, detector.pattern());
+    GET_NODE(sequence_expand0_in, detector.pattern());
+    GET_NODE(sequence_expand1_in, detector.pattern());
+
+    OpDesc op_desc;
+    op_desc.SetType("fusion_seqexpand_concat_fc");
+    op_desc.SetInput("X", {concat_in0->Name(), sequence_expand0_in->Name(),
+                           sequence_expand1_in->Name()});
+    op_desc.SetInput("FCWeight", {fc_w->Name()});
+    op_desc.SetInput("FCBias", {fc_bias->Name()});
+    const std::string fc_out_tmp = fc_out->Name() + ".tmp";
+    param_scope()->Var(fc_out_tmp)->GetMutable<framework::LoDTensor>();
+    op_desc.SetOutput("FCOut", {fc_out_tmp});
+    op_desc.SetOutput("Out", {fc_out->Name()});
+    op_desc.SetAttr("fc_activation", act->Op()->Type());
+
+    auto* op_node = graph->CreateOpNode(&op_desc);
+// Add links
+#define NODE_LINKS(a, b)   \
+  a->outputs.push_back(b); \
+  b->inputs.push_back(a);
+    NODE_LINKS(fc_w, op_node);
+    NODE_LINKS(fc_bias, op_node);
+    NODE_LINKS(concat_in0, op_node);
+    NODE_LINKS(sequence_expand0_in, op_node);
+    NODE_LINKS(sequence_expand1_in, op_node);
+    NODE_LINKS(op_node, fc_out);
+
+    // Clean nodes.
+    std::unordered_set<const Node*> marked_nodes;
+    for (auto& item : subgraph) {
+      marked_nodes.insert(item.second);
+    }
+    marked_nodes.erase(fc_w);
+    marked_nodes.erase(fc_bias);
+    marked_nodes.erase(concat_in0);
+    marked_nodes.erase(sequence_expand0_in);
+    marked_nodes.erase(sequence_expand1_in);
+    marked_nodes.erase(fc_out);
+
+    GraphSafeRemoveNodes(graph, marked_nodes);
+  });
+
+  return graph;
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(seq_concat_fc_fuse_pass,
+              paddle::framework::ir::SeqConcatFcFusePass);
diff --git a/paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.h b/paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.h
new file mode 100644
index 0000000000..9f5fd1a29a
--- /dev/null
+++ b/paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.h
@@ -0,0 +1,33 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
+#include "paddle/fluid/framework/ir/graph.h"
+#include "paddle/fluid/framework/ir/pass.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+class SeqConcatFcFusePass : public FusePassBase {
+ public:
+  virtual ~SeqConcatFcFusePass() {}
+
+ protected:
+  std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
+};
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/op_desc.cc b/paddle/fluid/framework/op_desc.cc
index 122dc161b4..555faba962 100644
--- a/paddle/fluid/framework/op_desc.cc
+++ b/paddle/fluid/framework/op_desc.cc
@@ -95,6 +95,12 @@ OpDesc::OpDesc(const std::string &type, const VariableNameMap &inputs,
   need_update_ = true;
 }
 
+OpDesc::OpDesc(const OpDesc &other, BlockDesc *block) {
+  CopyFrom(other);
+  block_ = block;
+  need_update_ = true;
+}
+
 void OpDesc::CopyFrom(const OpDesc &op_desc) {
   desc_.set_type(op_desc.Type());
   inputs_ = op_desc.inputs_;
@@ -131,8 +137,9 @@ OpDesc::OpDesc(const proto::OpDesc &desc, BlockDesc *block)
   for (const proto::OpDesc::Attr &attr : desc_.attrs()) {
     std::string attr_name = attr.name();
     // The sub_block referred to by the BLOCK attr hasn't been added
-    // to ProgramDesc class yet, we skip setting BLOCK attr here.
-    if (attr.type() != proto::AttrType::BLOCK) {
+    // to ProgramDesc class yet, we skip setting BLOCK/BLOCKS attr here.
+    if (attr.type() != proto::AttrType::BLOCK &&
+        attr.type() != proto::AttrType::BLOCKS) {
       attrs_[attr_name] = GetAttrValue(attr);
     }
   }
diff --git a/paddle/fluid/framework/op_desc.h b/paddle/fluid/framework/op_desc.h
index 2422392e24..b4205aba83 100644
--- a/paddle/fluid/framework/op_desc.h
+++ b/paddle/fluid/framework/op_desc.h
@@ -37,11 +37,7 @@ class OpDesc {
 
   explicit OpDesc(BlockDesc *block) : block_(block) {}
 
-  OpDesc(const OpDesc &other, BlockDesc *block) {
-    *this = other;
-    block_ = block;
-    need_update_ = true;
-  }
+  OpDesc(const OpDesc &other, BlockDesc *block);
 
   void CopyFrom(const OpDesc &op_desc);
 
diff --git a/paddle/fluid/framework/program_desc.cc b/paddle/fluid/framework/program_desc.cc
index 344c001a69..a63944eaee 100644
--- a/paddle/fluid/framework/program_desc.cc
+++ b/paddle/fluid/framework/program_desc.cc
@@ -80,6 +80,12 @@ ProgramDesc::ProgramDesc(const proto::ProgramDesc &desc) {
   InitFromProto();
 }
 
+void ProgramDesc::CopyFrom(const proto::ProgramDesc &desc) {
+  blocks_.clear();
+  desc_ = desc;
+  InitFromProto();
+}
+
 ProgramDesc::ProgramDesc(const std::string &binary_str) {
   PADDLE_ENFORCE(desc_.ParseFromString(binary_str),
                  "Fail to parse program_desc from binary string.");
@@ -111,10 +117,16 @@ void ProgramDesc::InitFromProto() {
 
 const std::vector<std::string> ProgramDesc::GetFeedTargetNames() {
   auto &global_block = Block(0);
+  // The order of feed_target_names must follow the index specified in `col`.
+  // since feed operator's order doesn't necessary follow 'col'.
   std::vector<std::string> feed_target_names;
   for (auto *op : global_block.AllOps()) {
     if (op->Type() == kFeedOpType) {
-      feed_target_names.insert(feed_target_names.begin(), op->Output("Out")[0]);
+      int col = boost::get<int>(op->GetAttr("col"));
+      if (col >= feed_target_names.size()) {
+        feed_target_names.resize(col + 1);
+      }
+      feed_target_names[col] = op->Output("Out")[0];
     }
   }
   return feed_target_names;
@@ -122,10 +134,16 @@ const std::vector<std::string> ProgramDesc::GetFeedTargetNames() {
 
 const std::vector<std::string> ProgramDesc::GetFetchTargetNames() {
   auto &global_block = Block(0);
+  // The order of fetch_target_names must follow the index specified in `col`.
+  // since fetch operator's order doesn't necessary follow 'col'.
   std::vector<std::string> fetch_target_names;
   for (auto *op : global_block.AllOps()) {
     if (op->Type() == kFetchOpType) {
-      fetch_target_names.push_back(op->Input("X")[0]);
+      int col = boost::get<int>(op->GetAttr("col"));
+      if (col >= fetch_target_names.size()) {
+        fetch_target_names.resize(col + 1);
+      }
+      fetch_target_names[col] = op->Input("X")[0];
     }
   }
   return fetch_target_names;
diff --git a/paddle/fluid/framework/program_desc.h b/paddle/fluid/framework/program_desc.h
index f3afc85eb9..a0e81cade1 100644
--- a/paddle/fluid/framework/program_desc.h
+++ b/paddle/fluid/framework/program_desc.h
@@ -53,6 +53,8 @@ class ProgramDesc {
 
   void Flush();
 
+  void CopyFrom(const proto::ProgramDesc &desc);
+
   proto::ProgramDesc *Proto();
 
   // The output variable of feed_op is referenced as feed_target.
diff --git a/paddle/fluid/inference/CMakeLists.txt b/paddle/fluid/inference/CMakeLists.txt
index ba7645aa02..a4f6364ae5 100644
--- a/paddle/fluid/inference/CMakeLists.txt
+++ b/paddle/fluid/inference/CMakeLists.txt
@@ -10,7 +10,7 @@ set(FLUID_CORE_MODULES proto_desc memory lod_tensor executor)
 # TODO(panyx0718): Should this be called paddle_fluid_inference_api_internal?
 cc_library(paddle_fluid_api
     SRCS io.cc
-    DEPS ${FLUID_CORE_MODULES} ${GLOB_OP_LIB})
+    DEPS ${FLUID_CORE_MODULES} ${GLOB_OP_LIB} graph_to_program_pass)
 
 get_property(fluid_modules GLOBAL PROPERTY FLUID_MODULES)
 
diff --git a/paddle/fluid/inference/analysis/CMakeLists.txt b/paddle/fluid/inference/analysis/CMakeLists.txt
index 4feaed2b0d..779ede5e46 100644
--- a/paddle/fluid/inference/analysis/CMakeLists.txt
+++ b/paddle/fluid/inference/analysis/CMakeLists.txt
@@ -1,5 +1,8 @@
 cc_library(ir_pass_manager SRCS ir_pass_manager.cc DEPS graph pass)
-cc_library(analysis SRCS pass_manager.cc dot.cc node.cc data_flow_graph.cc graph_traits.cc subgraph_splitter.cc
+set(analysis_deps
+    framework_proto proto_desc ir_pass_manager graph pass paddle_fluid_api executor)
+
+cc_library(analysis SRCS pass_manager.cc node.cc data_flow_graph.cc graph_traits.cc subgraph_splitter.cc
   analyzer.cc
   helper.cc
   # passes
@@ -10,11 +13,11 @@ cc_library(analysis SRCS pass_manager.cc dot.cc node.cc data_flow_graph.cc graph
   tensorrt_subgraph_node_mark_pass.cc
   fluid_to_ir_pass.cc
   model_store_pass.cc
-  DEPS framework_proto proto_desc ir_pass_manager graph pass)
+  DEPS ${analysis_deps})
 
 cc_test(test_node SRCS node_tester.cc DEPS analysis)
 cc_test(test_dot SRCS dot_tester.cc DEPS analysis)
-cc_binary(inference_analyzer SRCS analyzer_main.cc DEPS analysis)
+cc_binary(inference_analyzer SRCS analyzer_main.cc DEPS analysis paddle_fluid)
 
 set(PYTHON_TESTS_DIR ${PADDLE_BINARY_DIR}/python/paddle/fluid/tests)
 
@@ -31,7 +34,7 @@ function (inference_analysis_test TARGET)
         endif()
         cc_test(${TARGET}
                 SRCS "${analysis_test_SRCS}"
-                DEPS analysis graph fc_fuse_pass graph_viz_pass infer_clean_graph_pass graph_pattern_detecter pass ${analysis_test_EXTRA_DEPS}
+                DEPS analysis graph fc_fuse_pass graph_viz_pass infer_clean_graph_pass graph_pattern_detector pass ${analysis_test_EXTRA_DEPS}
                 ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model ${mem_opt})
         set_tests_properties(${TARGET} PROPERTIES DEPENDS test_word2vec)
     endif(WITH_TESTING)
@@ -58,20 +61,25 @@ endif()
 
 inference_analysis_test(test_analyzer SRCS analyzer_tester.cc
     EXTRA_DEPS paddle_inference_api paddle_fluid_api ir_pass_manager analysis
+    analysis_predictor
 		# ir
 		fc_fuse_pass
+		fc_lstm_fuse_pass
+    seq_concat_fc_fuse_pass
 		graph_viz_pass
 		infer_clean_graph_pass
-		graph_pattern_detecter
-        infer_clean_graph_pass
+		graph_pattern_detector
+    infer_clean_graph_pass
+    attention_lstm_fuse_pass
+    paddle_inference_api
 		pass
     ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model
         --infer_ditu_rnn_model=${DITU_INSTALL_DIR}/model
         --infer_ditu_rnn_data=${DITU_INSTALL_DIR}/data.txt)
 
 inference_analysis_test(test_data_flow_graph SRCS data_flow_graph_tester.cc)
-inference_analysis_test(test_data_flow_graph_to_fluid_pass SRCS data_flow_graph_to_fluid_pass_tester.cc)
-inference_analysis_test(test_fluid_to_ir_pass SRCS fluid_to_ir_pass_tester.cc)
+inference_analysis_test(test_data_flow_graph_to_fluid_pass SRCS data_flow_graph_to_fluid_pass_tester.cc EXTRA_DEPS paddle_inference_api)
+inference_analysis_test(test_fluid_to_ir_pass SRCS fluid_to_ir_pass_tester.cc EXTRA_DEPS paddle_fluid)
 inference_analysis_test(test_fluid_to_data_flow_graph_pass SRCS fluid_to_data_flow_graph_pass_tester.cc)
 inference_analysis_test(test_subgraph_splitter SRCS subgraph_splitter_tester.cc)
 inference_analysis_test(test_dfg_graphviz_draw_pass SRCS dfg_graphviz_draw_pass_tester.cc)
diff --git a/paddle/fluid/inference/analysis/analyzer.cc b/paddle/fluid/inference/analysis/analyzer.cc
index 0d94ccb64e..05b606cd0f 100644
--- a/paddle/fluid/inference/analysis/analyzer.cc
+++ b/paddle/fluid/inference/analysis/analyzer.cc
@@ -102,6 +102,19 @@ class DfgPassManagerImpl final : public DfgPassManager {
 Analyzer::Analyzer() { Register("manager1", new DfgPassManagerImpl); }
 
 void Analyzer::Run(Argument* argument) {
+  // Ungly support fluid-to-ir-pass
+  argument->Set(kFluidToIrPassesAttr,
+                new std::vector<std::string>({
+                    // Manual update the passes here.
+                    "graph_viz_pass",                              //
+                    "infer_clean_graph_pass", "graph_viz_pass",    //
+                    "attention_lstm_fuse_pass", "graph_viz_pass",  //
+                    "fc_lstm_fuse_pass", "graph_viz_pass",         //
+                    "seq_concat_fc_fuse_pass", "graph_viz_pass",   //
+                    "fc_fuse_pass", "graph_viz_pass"               //
+
+                }));
+
   for (auto& x : data_) {
     PADDLE_ENFORCE(x->Initialize(argument));
     x->RunAll();
diff --git a/paddle/fluid/inference/analysis/analyzer_tester.cc b/paddle/fluid/inference/analysis/analyzer_tester.cc
index baa7600283..263fbb0449 100644
--- a/paddle/fluid/inference/analysis/analyzer_tester.cc
+++ b/paddle/fluid/inference/analysis/analyzer_tester.cc
@@ -20,6 +20,7 @@
 #include "paddle/fluid/inference/analysis/ut_helper.h"
 #include "paddle/fluid/inference/api/helper.h"
 #include "paddle/fluid/inference/api/paddle_inference_api.h"
+#include "paddle/fluid/platform/profiler.h"
 
 DEFINE_string(infer_ditu_rnn_model, "", "model path for ditu RNN");
 DEFINE_string(infer_ditu_rnn_data, "", "data path for ditu RNN");
@@ -264,39 +265,24 @@ void TestDituRNNPrediction(const std::string &model_path,
                            const std::string &data_path, int batch_size,
                            bool use_analysis, bool activate_ir,
                            int num_times = 1) {
-  FLAGS_IA_enable_ir = activate_ir;
-  FLAGS_IA_enable_tensorrt_subgraph_engine = false;
-  FLAGS_IA_output_storage_path = "./analysis.out";
-
-  std::string model_out;
-  if (use_analysis) {
-    Argument argument(model_path);
-    argument.model_output_store_path.reset(new std::string("./analysis.out"));
-
-    Analyzer analyzer;
-    analyzer.Run(&argument);
-
-    // Should get the transformed model stored to ./analysis.out
-    model_out = "./analysis.out";
-    ASSERT_TRUE(PathExists(model_out));
-  } else {
-    model_out = FLAGS_infer_ditu_rnn_model;
-  }
-
   NativeConfig config;
-  config.prog_file = model_out + "/__model__";
-  config.param_file = model_out + "/param";
+  config.prog_file = FLAGS_infer_ditu_rnn_model + "/__model__";
+  config.param_file = FLAGS_infer_ditu_rnn_model + "/param";
   config.use_gpu = false;
   config.device = 0;
   config.specify_input_name = true;
 
-  auto predictor =
+  auto base_predictor =
       CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kNative>(config);
+  auto predictor =
+      CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kAnalysis>(config);
   std::vector<PaddleTensor> input_slots;
   DataRecord data(data_path, batch_size);
   // Prepare inputs.
   PrepareInputs(&input_slots, &data, batch_size);
-  std::vector<PaddleTensor> outputs;
+  std::vector<PaddleTensor> outputs, base_outputs;
+
+  base_predictor->Run(input_slots, &base_outputs);
 
   Timer timer;
   timer.tic();
@@ -308,37 +294,25 @@ void TestDituRNNPrediction(const std::string &model_path,
             << ", latency: " << timer.toc() / num_times << "ms";
   LOG(INFO) << "=====================================";
 
-  for (auto &out : outputs) {
+  PADDLE_ENFORCE_GT(outputs.size(), 0);
+  PADDLE_ENFORCE_EQ(outputs.size(), base_outputs.size());
+  for (size_t i = 0; i < outputs.size(); i++) {
+    auto &out = outputs[i];
+    auto &base_out = base_outputs[i];
     size_t size = std::accumulate(out.shape.begin(), out.shape.end(), 1,
                                   [](int a, int b) { return a * b; });
+    size_t size1 = std::accumulate(base_out.shape.begin(), base_out.shape.end(),
+                                   1, [](int a, int b) { return a * b; });
+    PADDLE_ENFORCE_EQ(size, size1);
+    PADDLE_ENFORCE_GT(size, 0);
     float *data = static_cast<float *>(out.data.data());
-    for (size_t i = 0;
-         i < std::min(sizeof(ditu_rnn_target_data) / sizeof(float), size);
-         i++) {
-      EXPECT_NEAR(data[i], ditu_rnn_target_data[i], 1e-3);
+    float *base_data = static_cast<float *>(base_out.data.data());
+    for (size_t i = 0; i < size; i++) {
+      EXPECT_NEAR(data[i], base_data[i], 1e-3);
     }
   }
 }
 
-// Turn on the IR pass supportion, run a real inference and check the result.
-TEST(Analyzer, SupportIRPass) {
-  FLAGS_IA_enable_ir = true;
-  FLAGS_IA_enable_tensorrt_subgraph_engine = false;
-  FLAGS_IA_output_storage_path = "./analysis.out";
-
-  Argument argument(FLAGS_inference_model_dir);
-  argument.model_output_store_path.reset(new std::string("./analysis.out"));
-
-  Analyzer analyzer;
-  analyzer.Run(&argument);
-
-  // Should get the transformed model stored to ./analysis.out
-  ASSERT_TRUE(PathExists("./analysis.out"));
-
-  // Inference from this path.
-  TestWord2vecPrediction("./analysis.out");
-}
-
 // Directly infer with the original model.
 TEST(Analyzer, DituRNN_without_analysis) {
   TestDituRNNPrediction(FLAGS_infer_ditu_rnn_model, FLAGS_infer_ditu_rnn_data,
@@ -365,5 +339,8 @@ TEST(Analyzer, DituRNN_with_analysis_with_IR) {
 }  // namespace paddle
 
 USE_PASS(fc_fuse_pass);
+USE_PASS(seq_concat_fc_fuse_pass);
+USE_PASS(fc_lstm_fuse_pass);
 USE_PASS(graph_viz_pass);
 USE_PASS(infer_clean_graph_pass);
+USE_PASS(attention_lstm_fuse_pass);
diff --git a/paddle/fluid/inference/analysis/argument.h b/paddle/fluid/inference/analysis/argument.h
index a17d6281a2..4401d5c5a3 100644
--- a/paddle/fluid/inference/analysis/argument.h
+++ b/paddle/fluid/inference/analysis/argument.h
@@ -26,6 +26,7 @@
 #include <string>
 #include "paddle/fluid/framework/program_desc.h"
 #include "paddle/fluid/inference/analysis/data_flow_graph.h"
+#include "paddle/fluid/platform/variant.h"
 
 namespace paddle {
 namespace inference {
@@ -58,6 +59,46 @@ struct Argument {
 
   // The output storage path of ModelStorePass.
   std::unique_ptr<std::string> model_output_store_path;
+
+  // Support for any other attributes.
+  template <typename T>
+  void Set(const std::string& key, T* data) {
+    PADDLE_ENFORCE_NOT_NULL(data);
+    PADDLE_ENFORCE(!attrs_.count(key), "duplicate attr called %s", key);
+    attrs_[key] = data;
+    attr_deleters_[key] = [data, key, this]() {
+      VLOG(3) << "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx";
+      VLOG(3) << "argument delete attr: " << key;
+      delete data;
+    };
+  }
+
+  bool Has(const std::string& name) const { return attrs_.count(name); }
+
+  template <typename T>
+  T* Release(const std::string& key) {
+    PADDLE_ENFORCE(attrs_.count(key));
+    auto* res = boost::any_cast<T*>(attrs_.at(key));
+    attrs_.erase(key);
+    attr_deleters_.erase(key);
+    return res;
+  }
+
+  template <typename T>
+  T& Get(const std::string& key) {
+    PADDLE_ENFORCE(Has(key));
+    return *boost::any_cast<T*>(attrs_.at(key));
+  }
+
+  ~Argument() {
+    for (auto& item : attr_deleters_) {
+      item.second();
+    }
+  }
+
+ private:
+  std::unordered_map<std::string, boost::any> attrs_;
+  std::unordered_map<std::string, std::function<void()>> attr_deleters_;
 };
 
 #define UNLIKELY(condition) __builtin_expect(static_cast<bool>(condition), 0)
diff --git a/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc b/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc
index 8c7dd146e4..8ca402da31 100644
--- a/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc
+++ b/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc
@@ -19,6 +19,7 @@
 #include "paddle/fluid/framework/proto_desc.h"
 #include "paddle/fluid/inference/analysis/analyzer.h"
 #include "paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h"
+#include "paddle/fluid/inference/io.h"
 
 namespace paddle {
 namespace inference {
@@ -65,6 +66,10 @@ void DataFlowGraphToFluidPass::Run(DataFlowGraph *graph) {
     }
   }
 
+  if (argument_->Has("param_scope")) {
+    LOG(WARNING) << "parameter changes in the scope takes effect";
+  }
+
   PADDLE_ENFORCE(argument_->transformed_program_desc.get());
 }
 
diff --git a/paddle/fluid/inference/analysis/dot.h b/paddle/fluid/inference/analysis/dot.h
index 4bf1840fdd..4693729cb4 100644
--- a/paddle/fluid/inference/analysis/dot.h
+++ b/paddle/fluid/inference/analysis/dot.h
@@ -29,13 +29,13 @@ namespace paddle {
 namespace inference {
 namespace analysis {
 
+static size_t dot_node_counter{0};
+
 /*
  * A Dot template that helps to build a DOT graph definition.
  */
 class Dot {
  public:
-  static size_t counter;
-
   struct Attr {
     std::string key;
     std::string value;
@@ -57,7 +57,7 @@ class Dot {
     Node(const std::string& name, const std::vector<Attr>& attrs)
         : name(name),
           attrs(attrs),
-          id_("node_" + std::to_string(Dot::counter++)) {}
+          id_("node_" + std::to_string(dot_node_counter++)) {}
 
     std::string id() const { return id_; }
 
@@ -65,6 +65,10 @@ class Dot {
       std::stringstream ss;
       CHECK(!name.empty());
       ss << id_;
+      if (attrs.empty()) {
+        ss << "[label=" << '"' << name << '"' << "]";
+        return ss.str();
+      }
       for (size_t i = 0; i < attrs.size(); i++) {
         if (i == 0) {
           ss << "[label=" << '"' << name << '"' << " ";
@@ -108,9 +112,11 @@ class Dot {
 
   explicit Dot(const std::vector<Attr>& attrs) : attrs_(attrs) {}
 
-  void AddNode(const std::string& name, const std::vector<Attr>& attrs) {
-    CHECK(!nodes_.count(name)) << "duplicate Node '" << name << "'";
-    nodes_.emplace(name, Node{name, attrs});
+  void AddNode(const std::string& id, const std::vector<Attr>& attrs,
+               std::string label = "") {
+    CHECK(!nodes_.count(id)) << "duplicate Node '" << id << "'";
+    if (label.empty()) label = id;
+    nodes_.emplace(id, Node{label, attrs});
   }
 
   void AddEdge(const std::string& source, const std::string& target,
diff --git a/paddle/fluid/inference/analysis/fluid_to_ir_pass.cc b/paddle/fluid/inference/analysis/fluid_to_ir_pass.cc
index 073f497528..5e53fff392 100644
--- a/paddle/fluid/inference/analysis/fluid_to_ir_pass.cc
+++ b/paddle/fluid/inference/analysis/fluid_to_ir_pass.cc
@@ -13,3 +13,47 @@
 // limitations under the License.
 
 #include "paddle/fluid/inference/analysis/fluid_to_ir_pass.h"
+#include "paddle/fluid/framework/executor.h"
+#include "paddle/fluid/inference/io.h"
+#include "paddle/fluid/platform/device_context.h"
+#include "paddle/fluid/platform/place.h"
+
+namespace paddle {
+namespace inference {
+namespace analysis {
+
+void FluidToIrPass::EnableParamModify(const std::string &model_dir,
+                                      const std::string &prog_file,
+                                      const std::string &param_file) {
+  PADDLE_ENFORCE(argument_);
+  argument_->Set("param_scope", new framework::Scope);
+  // Load parameters.
+  VLOG(3) << "Loading parameters from " << model_dir;
+  LoadParams(&argument_->Get<framework::Scope>("param_scope"), model_dir,
+             prog_file, param_file);
+}
+
+bool FluidToIrPass::LoadParams(framework::Scope *scope, const std::string &dir,
+                               const std::string &prog_file,
+                               const std::string &param_file) {
+  platform::CPUPlace place;
+  platform::CPUDeviceContext ctx(place);
+  framework::Executor executor(place);
+  PADDLE_ENFORCE(argument_->origin_program_desc.get());
+  framework::ProgramDesc program(*argument_->origin_program_desc);
+  if ((!prog_file.empty()) && (!param_file.empty())) {
+    LOG(INFO) << "load single model file from " << prog_file;
+    Load(&executor, scope, prog_file, param_file);
+  } else if (!dir.empty()) {
+    LOG(INFO) << "load from dir " << dir;
+    Load(&executor, scope, dir);
+  } else {
+    LOG(ERROR) << "failed to load parameters";
+    return false;
+  }
+  return true;
+}
+
+}  // namespace analysis
+}  // namespace inference
+}  // namespace paddle
diff --git a/paddle/fluid/inference/analysis/fluid_to_ir_pass.h b/paddle/fluid/inference/analysis/fluid_to_ir_pass.h
index fa3f8d313b..29008105f8 100644
--- a/paddle/fluid/inference/analysis/fluid_to_ir_pass.h
+++ b/paddle/fluid/inference/analysis/fluid_to_ir_pass.h
@@ -21,12 +21,17 @@ namespace paddle {
 namespace inference {
 namespace analysis {
 
+static const char kFluidToIrPassesAttr[] = "__fluid_to_ir_passes__";
+
 class FluidToIrPass final : public DataFlowGraphPass {
  public:
   FluidToIrPass() = default;
 
   bool Initialize(Argument *argument) override {
     ANALYSIS_ARGUMENT_CHECK_FIELD(argument);
+    PADDLE_ENFORCE(argument->Has(kFluidToIrPassesAttr),
+                   "argument need the attr %s", kFluidToIrPassesAttr);
+    argument_ = argument;
     if (argument->origin_program_desc) {
       LOG(WARNING) << "argument's origin_program_desc is already set, might "
                       "duplicate called";
@@ -46,12 +51,21 @@ class FluidToIrPass final : public DataFlowGraphPass {
     if (!argument->main_dfg) {
       argument->main_dfg.reset(new DataFlowGraph);
     }
-    // Persist the ProgramDesc in graph's attribute. The IR graph just keep the
-    // address, will segfault if the original ProgramDesc destroys.
-    auto &ir_program_p = argument->main_dfg->Attr("ir_program_desc").Pointer();
-    ir_program_p = new framework::ProgramDesc(program);
+    argument->Set("ir_program_desc", new framework::ProgramDesc(program));
+
+    LOG(INFO) << "Loading parameters";
+    // Load parameters to argument if needed.
+    if (argument->fluid_model_dir || (argument->fluid_model_program_path &&
+                                      argument->fluid_model_param_path)) {
+#define SAFE_GET(ATTR) std::string ATTR = argument->ATTR ? *argument->ATTR : "";
+      SAFE_GET(fluid_model_dir);
+      SAFE_GET(fluid_model_program_path);
+      SAFE_GET(fluid_model_param_path);
+#undef SAFE_GET
+      EnableParamModify(fluid_model_dir, fluid_model_program_path,
+                        fluid_model_param_path);
+    }
 
-    argument_ = argument;
     return true;
   }
 
@@ -59,20 +73,36 @@ class FluidToIrPass final : public DataFlowGraphPass {
 
   void Run(DataFlowGraph *graph) override {
     // Call all the IR Passes
-    IRPassManager ir_passes(*static_cast<framework::ProgramDesc *>(
-        argument_->main_dfg->Attr("ir_program_desc").Pointer()));
-    ir_passes.Apply(std::vector<std::string>(
-        {// Manual update the passes here.
-         "graph_viz_pass", "infer_clean_graph_pass", "graph_viz_pass",
-         "fc_fuse_pass", "graph_viz_pass"}));
+    IRPassManager ir_passes(
+        argument_->Get<framework::ProgramDesc>("ir_program_desc"), nullptr);
+    // Pass the scope from analysis to IR if needed.
+    if (argument_->Has("param_scope")) {
+      // Here the address is passed, attention that IR doesn't own the scope, so
+      // the real scope in analysis should live during the IR phase.
+      ir_passes.graph().Set(
+          "param_scope", new framework::Scope *(
+                             &argument_->Get<framework::Scope>("param_scope")));
+    }
+
+    const auto &ir_passes_to_apply =
+        argument_->Get<std::vector<std::string>>(kFluidToIrPassesAttr);
+    ir_passes.Apply(ir_passes_to_apply);
 
     PADDLE_ENFORCE(argument_->main_dfg.get());
     argument_->main_dfg->Build(ir_passes.graph());
-    // PADDLE_ENFORCE(argument_->main_dfg->IsFullyConnected());
   }
 
+  void EnableParamModify(const std::string &model_dir,
+                         const std::string &prog_file,
+                         const std::string &param_file);
+
   std::string repr() const override { return "fluid-to-ir-pass"; }
 
+ private:
+  // Load parameters from a single file or from a directory.
+  bool LoadParams(framework::Scope *scope, const std::string &dir,
+                  const std::string &prog_file, const std::string &param_file);
+
  private:
   Argument *argument_{nullptr};
 };
diff --git a/paddle/fluid/inference/analysis/fluid_to_ir_pass_tester.cc b/paddle/fluid/inference/analysis/fluid_to_ir_pass_tester.cc
index af934f261b..6a13c60e7b 100644
--- a/paddle/fluid/inference/analysis/fluid_to_ir_pass_tester.cc
+++ b/paddle/fluid/inference/analysis/fluid_to_ir_pass_tester.cc
@@ -24,6 +24,8 @@ namespace analysis {
 TEST(FluidToIrPass, Test) {
   FluidToIrPass pass;
   Argument argument(FLAGS_inference_model_dir);
+  argument.Set(kFluidToIrPassesAttr,
+               new std::vector<std::string>({"infer_clean_graph_pass"}));
   pass.Initialize(&argument);
   pass.Run(argument.main_dfg.get());
 }
@@ -32,6 +34,9 @@ TEST(FluidToIrPass, Test) {
 }  // namespace inference
 }  // namespace paddle
 
-USE_PASS(fc_fuse_pass);
 USE_PASS(graph_viz_pass);
 USE_PASS(infer_clean_graph_pass);
+USE_PASS(attention_lstm_fuse_pass);
+USE_PASS(fc_lstm_fuse_pass);
+USE_PASS(seq_concat_fc_fuse_pass);
+USE_PASS(fc_fuse_pass);
diff --git a/paddle/fluid/inference/analysis/ir_pass_manager.cc b/paddle/fluid/inference/analysis/ir_pass_manager.cc
index d849b637bc..5da5241e49 100644
--- a/paddle/fluid/inference/analysis/ir_pass_manager.cc
+++ b/paddle/fluid/inference/analysis/ir_pass_manager.cc
@@ -14,20 +14,24 @@
 
 #include "paddle/fluid/inference/analysis/ir_pass_manager.h"
 #include <string>
+#include "paddle/fluid/framework/ir/graph.h"
+#include "paddle/fluid/framework/scope.h"
 
 namespace paddle {
 namespace inference {
 namespace analysis {
 
-IRPassManager::IRPassManager(const ProgramDesc& program) {
+IRPassManager::IRPassManager(const ProgramDesc &program,
+                             framework::Scope *scope)
+    : program_(program) {
   graph_.reset(new framework::ir::Graph(program));
+  if (scope) graph_->Set("param_scope", new framework::Scope *(scope));
 }
 
-void IRPassManager::Apply(const std::vector<std::string>& passes) {
-  graph_->Set("graph_viz_path", new std::string("./1.dot"));
+void IRPassManager::Apply(const std::vector<std::string> &passes) {
   // Apply all the passes
   std::string pre_pass;
-  for (const std::string& pass_name : passes) {
+  for (const std::string &pass_name : passes) {
     LOG(WARNING) << "Running IR pass [" << pass_name << "]";
     auto pass = framework::ir::PassRegistry::Instance().Get(pass_name);
     if (pass_name == "graph_viz_pass") {
diff --git a/paddle/fluid/inference/analysis/ir_pass_manager.h b/paddle/fluid/inference/analysis/ir_pass_manager.h
index 3338e37ecf..bb230283b7 100644
--- a/paddle/fluid/inference/analysis/ir_pass_manager.h
+++ b/paddle/fluid/inference/analysis/ir_pass_manager.h
@@ -23,6 +23,7 @@
 #include "paddle/fluid/framework/ir/graph.h"
 #include "paddle/fluid/framework/ir/pass.h"
 #include "paddle/fluid/framework/program_desc.h"
+#include "paddle/fluid/framework/scope.h"
 
 namespace paddle {
 namespace inference {
@@ -31,14 +32,15 @@ using framework::ProgramDesc;
 
 class IRPassManager final {
  public:
-  IRPassManager(const ProgramDesc& program);
+  IRPassManager(const ProgramDesc &program, framework::Scope *scope);
 
-  void Apply(const std::vector<std::string>& passes);
+  void Apply(const std::vector<std::string> &passes);
 
-  framework::ir::Graph& graph() const { return *graph_; }
+  framework::ir::Graph &graph() const { return *graph_; }
 
  private:
   std::unique_ptr<framework::ir::Graph> graph_;
+  ProgramDesc program_;
 };
 
 }  // namespace analysis
diff --git a/paddle/fluid/inference/analysis/pass_manager.cc b/paddle/fluid/inference/analysis/pass_manager.cc
index cfdca33882..ff5ec94265 100644
--- a/paddle/fluid/inference/analysis/pass_manager.cc
+++ b/paddle/fluid/inference/analysis/pass_manager.cc
@@ -33,9 +33,9 @@ bool PassManager::Initialize(Argument* argument) {
 
 void DfgPassManager::RunAll() {
   PADDLE_ENFORCE(argument_);
-  LOG(INFO) << "Total " << data_.size() << " passes";
+  LOG(INFO) << "Total " << data_.size() << " Analysys passes";
   for (auto& pass : data_) {
-    LOG(WARNING) << "Running pass [" << pass->repr() << "]";
+    LOG(WARNING) << "Running Analysis pass [" << pass->repr() << "]";
     pass->Run(argument_->main_dfg.get());
   }
 }
diff --git a/paddle/fluid/inference/api/CMakeLists.txt b/paddle/fluid/inference/api/CMakeLists.txt
index 0ca1af455c..adfe439244 100644
--- a/paddle/fluid/inference/api/CMakeLists.txt
+++ b/paddle/fluid/inference/api/CMakeLists.txt
@@ -20,7 +20,7 @@ endif(APPLE)
 
 set(inference_deps paddle_inference_api paddle_fluid_api analysis pass ir_pass_manager
   graph_viz_pass fc_fuse_pass
-    infer_clean_graph_pass
+  infer_clean_graph_pass
   )
 
 if(WITH_GPU AND TENSORRT_FOUND)
@@ -46,7 +46,8 @@ function(inference_api_test TARGET_NAME)
     endif(WITH_TESTING)
 endfunction(inference_api_test)
 
-cc_library(paddle_inference_api SRCS api.cc api_impl.cc DEPS lod_tensor)
+cc_library(paddle_inference_api SRCS api.cc api_impl.cc helper.cc DEPS lod_tensor)
+cc_library(analysis_predictor SRCS analysis_predictor.cc DEPS paddle_inference_api)
 
 cc_test(test_paddle_inference_api
         SRCS api_tester.cc
diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc
new file mode 100644
index 0000000000..0b29b23382
--- /dev/null
+++ b/paddle/fluid/inference/api/analysis_predictor.cc
@@ -0,0 +1,165 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <memory>
+#include "paddle/fluid/framework/ir/pass.h"
+#include "paddle/fluid/framework/scope.h"
+#include "paddle/fluid/inference/analysis/analyzer.h"
+#include "paddle/fluid/inference/api/api_impl.h"
+#include "paddle/fluid/inference/api/paddle_inference_api.h"
+#include "paddle/fluid/inference/utils/singleton.h"
+
+namespace paddle {
+
+using inference::analysis::Argument;
+using inference::Singleton;
+using inference::analysis::Analyzer;
+using framework::proto::ProgramDesc;
+
+/* This predictor is based on the original native predictor with IR and Analysis
+ * support. It will optimize IR and Parameters in the runtime.
+ * TODO(Superjomn) Replace the Navive predictor?
+ */
+class AnalysisPredictor : public NativePaddlePredictor {
+ public:
+  explicit AnalysisPredictor(const NativeConfig& config)
+      : NativePaddlePredictor(config), config_(config) {}
+
+  bool Init(const std::shared_ptr<framework::Scope>& parent_scope) {
+    VLOG(3) << "Predictor::init()";
+    if (config_.use_gpu) {
+      place_ = paddle::platform::CUDAPlace(config_.device);
+    } else {
+      place_ = paddle::platform::CPUPlace();
+    }
+    PADDLE_ENFORCE(!parent_scope);
+    if (parent_scope) {
+      scope_ = parent_scope;
+      sub_scope_ = &(parent_scope->NewScope());
+    } else {
+      paddle::framework::InitDevices(false);
+      scope_.reset(new paddle::framework::Scope());
+    }
+
+    executor_.reset(new paddle::framework::Executor(place_));
+
+    // Initialize the inference program
+    if (!config_.model_dir.empty()) {
+      // Parameters are saved in separate files sited in
+      // the specified `dirname`.
+      inference_program_ = paddle::inference::Load(
+          executor_.get(), scope_.get(), config_.model_dir);
+    } else if (!config_.prog_file.empty() && !config_.param_file.empty()) {
+      // All parameters are saved in a single file.
+      // The file names should be consistent with that used
+      // in Python API `fluid.io.save_inference_model`.
+      inference_program_ = paddle::inference::Load(
+          executor_.get(), scope_.get(), config_.prog_file, config_.param_file);
+    } else {
+      LOG(ERROR) << "fail to load inference model.";
+      return false;
+    }
+
+    OptimizeInferenceProgram();
+    ctx_ = executor_->Prepare(*inference_program_, 0);
+
+    VLOG(5) << "to create variables";
+    PADDLE_ENFORCE(scope_.get());
+    executor_->CreateVariables(*inference_program_,
+                               sub_scope_ ? sub_scope_ : scope_.get(), 0);
+
+    // Get the feed_target_names and fetch_target_names
+    feed_target_names_ = inference_program_->GetFeedTargetNames();
+    fetch_target_names_ = inference_program_->GetFetchTargetNames();
+    return true;
+  }
+
+  bool Run(const std::vector<PaddleTensor>& inputs,
+           std::vector<PaddleTensor>* output_data,
+           int batch_size = -1) override {
+    return NativePaddlePredictor::Run(inputs, output_data, batch_size);
+  }
+
+  void OptimizeInferenceProgram() {
+    LOG(INFO) << "optimize begin";
+    FLAGS_IA_enable_ir = true;
+    FLAGS_IA_enable_tensorrt_subgraph_engine = false;
+    FLAGS_IA_output_storage_path = "";  // Don't output the model.
+    // Analyze inference_program
+    Argument argument;
+    if (!config_.model_dir.empty()) {
+      argument.fluid_model_dir.reset(new std::string(config_.model_dir));
+    } else {
+      PADDLE_ENFORCE(
+          !config_.param_file.empty(),
+          "Either model_dir or (param_file, prog_file) should be set.");
+      PADDLE_ENFORCE(!config_.prog_file.empty());
+      argument.fluid_model_program_path.reset(
+          new std::string(config_.prog_file));
+      argument.fluid_model_param_path.reset(
+          new std::string(config_.param_file));
+    }
+    argument.origin_program_desc.reset(
+        new ProgramDesc(*inference_program_->Proto()));
+    Singleton<Analyzer>::Global().Run(&argument);
+    CHECK(argument.transformed_program_desc);
+    VLOG(5) << "to prepare executor";
+    // LOG(INFO) << "transformed_parogram_desc " <<
+    // argument.transformed_program_desc->DebugString();
+    inference_program_.reset(
+        new framework::ProgramDesc(*argument.transformed_program_desc));
+    PADDLE_ENFORCE(argument.Has("param_scope"));
+    // Update scope.
+    scope_.reset(argument.Release<framework::Scope>("param_scope"));
+    LOG(INFO) << "optimize end ==";
+  }
+
+ private:
+  NativeConfig config_;
+};
+
+template <>
+std::unique_ptr<PaddlePredictor> CreatePaddlePredictor<
+    NativeConfig, PaddleEngineKind::kAnalysis>(const NativeConfig& config) {
+  VLOG(3) << "create NativePredictor";
+  if (config.use_gpu) {
+    // 1. GPU memeroy
+    PADDLE_ENFORCE_GT(
+        config.fraction_of_gpu_memory, 0.f,
+        "fraction_of_gpu_memory in the config should be set to range (0., 1.]");
+    PADDLE_ENFORCE_GE(config.device, 0, "Invalid device id %d", config.device);
+    std::vector<std::string> flags;
+    if (config.fraction_of_gpu_memory >= 0.0f ||
+        config.fraction_of_gpu_memory <= 0.95f) {
+      flags.push_back("dummpy");
+      std::string flag = "--fraction_of_gpu_memory_to_use=" +
+                         std::to_string(config.fraction_of_gpu_memory);
+      flags.push_back(flag);
+      VLOG(3) << "set flag: " << flag;
+      framework::InitGflags(flags);
+    }
+  }
+
+  std::unique_ptr<PaddlePredictor> predictor(new AnalysisPredictor(config));
+  if (!dynamic_cast<AnalysisPredictor*>(predictor.get())->Init(nullptr)) {
+    return nullptr;
+  }
+  return predictor;
+}
+
+}  // namespace paddle
+
+USE_PASS(fc_fuse_pass);
+USE_PASS(graph_viz_pass);
+USE_PASS(infer_clean_graph_pass);
diff --git a/paddle/fluid/inference/api/helper.cc b/paddle/fluid/inference/api/helper.cc
new file mode 100644
index 0000000000..9cc491e10d
--- /dev/null
+++ b/paddle/fluid/inference/api/helper.cc
@@ -0,0 +1,44 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/inference/api/helper.h"
+
+namespace paddle {
+namespace inference {
+
+template <>
+std::string to_string<std::vector<float>>(
+    const std::vector<std::vector<float>> &vec) {
+  std::stringstream ss;
+  for (const auto &piece : vec) {
+    ss << to_string(piece) << "\n";
+  }
+  return ss.str();
+}
+
+template <>
+std::string to_string<std::vector<std::vector<float>>>(
+    const std::vector<std::vector<std::vector<float>>> &vec) {
+  std::stringstream ss;
+  for (const auto &line : vec) {
+    for (const auto &rcd : line) {
+      ss << to_string(rcd) << ";\t";
+    }
+    ss << '\n';
+  }
+  return ss.str();
+}
+
+}  // namespace inference
+}  // namespace paddle
diff --git a/paddle/fluid/inference/api/helper.h b/paddle/fluid/inference/api/helper.h
index 2c166cc062..e44b1b74bc 100644
--- a/paddle/fluid/inference/api/helper.h
+++ b/paddle/fluid/inference/api/helper.h
@@ -44,7 +44,8 @@ class Timer {
   }
 };
 
-void split(const std::string &str, char sep, std::vector<std::string> *pieces) {
+static void split(const std::string &str, char sep,
+                  std::vector<std::string> *pieces) {
   pieces->clear();
   if (str.empty()) {
     return;
@@ -60,7 +61,8 @@ void split(const std::string &str, char sep, std::vector<std::string> *pieces) {
     pieces->push_back(str.substr(pos));
   }
 }
-void split_to_float(const std::string &str, char sep, std::vector<float> *fs) {
+static void split_to_float(const std::string &str, char sep,
+                           std::vector<float> *fs) {
   std::vector<std::string> pieces;
   split(str, sep, &pieces);
   std::transform(pieces.begin(), pieces.end(), std::back_inserter(*fs),
@@ -76,27 +78,14 @@ std::string to_string(const std::vector<T> &vec) {
 }
 template <>
 std::string to_string<std::vector<float>>(
-    const std::vector<std::vector<float>> &vec) {
-  std::stringstream ss;
-  for (const auto &piece : vec) {
-    ss << to_string(piece) << "\n";
-  }
-  return ss.str();
-}
+    const std::vector<std::vector<float>> &vec);
+
 template <>
 std::string to_string<std::vector<std::vector<float>>>(
-    const std::vector<std::vector<std::vector<float>>> &vec) {
-  std::stringstream ss;
-  for (const auto &line : vec) {
-    for (const auto &rcd : line) {
-      ss << to_string(rcd) << ";\t";
-    }
-    ss << '\n';
-  }
-  return ss.str();
-}
+    const std::vector<std::vector<std::vector<float>>> &vec);
+
 // clang-format off
-void TensorAssignData(PaddleTensor *tensor, const std::vector<std::vector<float>> &data) {
+static void TensorAssignData(PaddleTensor *tensor, const std::vector<std::vector<float>> &data) {
   // Assign buffer
   int dim = std::accumulate(tensor->shape.begin(), tensor->shape.end(), 1, [](int a, int b) { return a * b; });
   tensor->data.Resize(sizeof(float) * dim);
diff --git a/paddle/fluid/inference/api/paddle_inference_api.h b/paddle/fluid/inference/api/paddle_inference_api.h
index 36fd0727aa..1baa64c249 100644
--- a/paddle/fluid/inference/api/paddle_inference_api.h
+++ b/paddle/fluid/inference/api/paddle_inference_api.h
@@ -77,6 +77,7 @@ enum class PaddleEngineKind {
   kNative = 0,         // Use the native Fluid facility.
   kAnakin,             // Use Anakin for inference.
   kAutoMixedTensorRT,  // Automatically mix Fluid with TensorRT.
+  kAnalysis
   // TODO(Superjomn) support following engines latter.
   // kTensorRT,           // Use TensorRT for inference.
   // kAutoMixedAnakin,    // Automatically mix Fluid with Anakin.
diff --git a/paddle/fluid/inference/io.cc b/paddle/fluid/inference/io.cc
index 181868977d..cef7b2a7e3 100644
--- a/paddle/fluid/inference/io.cc
+++ b/paddle/fluid/inference/io.cc
@@ -143,5 +143,21 @@ std::unique_ptr<framework::ProgramDesc> Load(
   return main_program;
 }
 
+void SaveVars(const framework::Scope& scope,
+              const std::vector<std::string>& vars, const std::string& dirname,
+              bool predicate) {
+  framework::ProgramDesc prog;
+  auto* block = prog.MutableBlock(0);
+  auto* op = block->AppendOp();
+  op->SetType("save_combine");
+  op->SetInput("X", vars);
+  op->SetAttr("file_path", dirname + "/param");
+  op->CheckAttrs();
+
+  platform::CPUPlace place;
+  framework::Executor exe(place);
+  exe.Run(prog, const_cast<framework::Scope*>(&scope), 0, true, true);
+}
+
 }  // namespace inference
 }  // namespace paddle
diff --git a/paddle/fluid/inference/io.h b/paddle/fluid/inference/io.h
index 01b50b3670..ab492577c1 100644
--- a/paddle/fluid/inference/io.h
+++ b/paddle/fluid/inference/io.h
@@ -41,5 +41,10 @@ std::unique_ptr<framework::ProgramDesc> Load(framework::Executor* executor,
                                              const std::string& prog_filename,
                                              const std::string& param_filename);
 
+// Save the variables from a scope to disk.
+void SaveVars(const framework::Scope& scope,
+              const std::vector<std::string>& vars, const std::string& dirname,
+              bool predicate = true);
+
 }  // namespace inference
 }  // namespace paddle
diff --git a/paddle/fluid/inference/tests/test_helper.h b/paddle/fluid/inference/tests/test_helper.h
index 695790a37d..94f0550df5 100644
--- a/paddle/fluid/inference/tests/test_helper.h
+++ b/paddle/fluid/inference/tests/test_helper.h
@@ -18,6 +18,7 @@ limitations under the License. */
 #include <string>
 #include <vector>
 
+#include "paddle/fluid/framework/ir/graph_to_program_pass.h"
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/inference/io.h"
 #include "paddle/fluid/platform/profiler.h"
@@ -135,6 +136,15 @@ std::vector<std::vector<int64_t>> GetFeedTargetShapes(
   return feed_target_shapes;
 }
 
+void Compile(paddle::framework::ProgramDesc* program) {
+  std::unique_ptr<paddle::framework::ir::Graph> g(
+      new paddle::framework::ir::Graph(*program));
+  auto pass = paddle::framework::ir::PassRegistry::Instance().Get(
+      "graph_to_program_pass");
+  pass->SetNotOwned<paddle::framework::ProgramDesc>("program", program);
+  pass->Apply(std::move(g));
+}
+
 template <typename Place, bool CreateVars = true, bool PrepareContext = false>
 void TestInference(const std::string& dirname,
                    const std::vector<paddle::framework::LoDTensor*>& cpu_feeds,
@@ -172,6 +182,8 @@ void TestInference(const std::string& dirname,
         paddle::platform::DeviceContextPool::Instance().Get(place));
     inference_program = InitProgram(&executor, scope, dirname, is_combined);
   }
+  Compile(inference_program.get());
+
   // Disable the profiler and print the timing information
   paddle::platform::DisableProfiler(paddle::platform::EventSortingKey::kDefault,
                                     "load_program_profiler");
@@ -249,3 +261,5 @@ void TestInference(const std::string& dirname,
 
   delete scope;
 }
+
+USE_PASS(graph_to_program_pass);
diff --git a/paddle/fluid/operators/attention_lstm_op.cc b/paddle/fluid/operators/attention_lstm_op.cc
index 8bab37c583..a02128c5a5 100644
--- a/paddle/fluid/operators/attention_lstm_op.cc
+++ b/paddle/fluid/operators/attention_lstm_op.cc
@@ -56,7 +56,7 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
   const int D = w_dims[1] / 4;
   PADDLE_ENFORCE_EQ(w_dims.size(), 2, "Input(LSTMWeight)'s rank must be 2.");
   PADDLE_ENFORCE_EQ(w_dims[0], D + M,
-                    "LSTMWeight dims should be (%d + %d) * %d.", D + M, 4 * D);
+                    "LSTMWeight dims should be (%d + %d) * %d.", D, M, 4 * D);
 
   auto b_dims = ctx->GetInputDim("LSTMBias");
   PADDLE_ENFORCE_EQ(b_dims.size(), 2, "Input(LSTMBias)'s rank must be 2.");
diff --git a/paddle/fluid/operators/detection/CMakeLists.txt b/paddle/fluid/operators/detection/CMakeLists.txt
index a44d84cd7b..1301c8ae2b 100644
--- a/paddle/fluid/operators/detection/CMakeLists.txt
+++ b/paddle/fluid/operators/detection/CMakeLists.txt
@@ -29,6 +29,6 @@ target_assign_op.cu)
 detection_library(polygon_box_transform_op SRCS polygon_box_transform_op.cc
 polygon_box_transform_op.cu)
 detection_library(rpn_target_assign_op SRCS rpn_target_assign_op.cc)
-
-# Export local libraries to parent
+detection_library(generate_proposals_op SRCS generate_proposals_op.cc)
+#Export local libraries to parent
 set(DETECTION_LIBRARY ${LOCAL_DETECTION_LIBS} PARENT_SCOPE)
diff --git a/paddle/fluid/operators/detection/generate_proposals_op.cc b/paddle/fluid/operators/detection/generate_proposals_op.cc
new file mode 100644
index 0000000000..d29b015338
--- /dev/null
+++ b/paddle/fluid/operators/detection/generate_proposals_op.cc
@@ -0,0 +1,485 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <string>
+#include <vector>
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/gather.h"
+#include "paddle/fluid/operators/math/math_function.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+using LoDTensor = framework::LoDTensor;
+
+struct AppendProposalsFunctor {
+  LoDTensor *out_;
+  int64_t offset_;
+  Tensor *to_add_;
+
+  AppendProposalsFunctor(LoDTensor *out, int64_t offset, Tensor *to_add)
+      : out_(out), offset_(offset), to_add_(to_add) {}
+
+  template <typename T>
+  void operator()() const {
+    auto *out_data = out_->data<T>();
+    auto *to_add_data = to_add_->data<T>();
+    memcpy(out_data + offset_, to_add_data, to_add_->numel() * sizeof(T));
+  }
+};
+
+class GenerateProposalsOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("Scores"), "Input(Scores) shouldn't be null.");
+    PADDLE_ENFORCE(ctx->HasInput("BboxDeltas"),
+                   "Input(BboxDeltas) shouldn't be null.");
+    PADDLE_ENFORCE(ctx->HasInput("ImInfo"), "Input(ImInfo) shouldn't be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Anchors"),
+                   "Input(Anchors) shouldn't be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Variances"),
+                   "Input(Variances) shouldn't be null.");
+
+    auto scores_dims = ctx->GetInputDim("Scores");
+    auto bbox_deltas_dims = ctx->GetInputDim("BboxDeltas");
+    auto im_info_dims = ctx->GetInputDim("ImInfo");
+    auto anchors_dims = ctx->GetInputDim("Anchors");
+    auto variances_dims = ctx->GetInputDim("Variances");
+
+    ctx->SetOutputDim("RpnRois", {-1, 4});
+    ctx->SetOutputDim("RpnRoiProbs", {-1, 1});
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext &ctx) const override {
+    return framework::OpKernelType(
+        framework::ToDataType(ctx.Input<Tensor>("Anchors")->type()),
+        platform::CPUPlace());
+  }
+};
+
+template <class T>
+void BoxCoder(const platform::DeviceContext &ctx, Tensor *all_anchors,
+              Tensor *bbox_deltas, Tensor *variances, Tensor *proposals) {
+  T *proposals_data = proposals->mutable_data<T>(ctx.GetPlace());
+
+  int64_t row = all_anchors->dims()[0];
+  int64_t len = all_anchors->dims()[1];
+
+  auto *bbox_deltas_data = bbox_deltas->data<T>();
+  auto *anchor_data = all_anchors->data<T>();
+  const T *variances_data = nullptr;
+  if (variances) {
+    variances_data = variances->data<T>();
+  }
+
+  for (int64_t i = 0; i < row; ++i) {
+    T anchor_width = anchor_data[i * len + 2] - anchor_data[i * len];
+    T anchor_height = anchor_data[i * len + 3] - anchor_data[i * len + 1];
+
+    T anchor_center_x = (anchor_data[i * len + 2] + anchor_data[i * len]) / 2;
+    T anchor_center_y =
+        (anchor_data[i * len + 3] + anchor_data[i * len + 1]) / 2;
+
+    T bbox_center_x = 0, bbox_center_y = 0;
+    T bbox_width = 0, bbox_height = 0;
+
+    if (variances) {
+      bbox_center_x =
+          variances_data[i * len] * bbox_deltas_data[i * len] * anchor_width +
+          anchor_center_x;
+      bbox_center_y = variances_data[i * len + 1] *
+                          bbox_deltas_data[i * len + 1] * anchor_height +
+                      anchor_center_y;
+      bbox_width = std::exp(variances_data[i * len + 2] *
+                            bbox_deltas_data[i * len + 2]) *
+                   anchor_width;
+      bbox_height = std::exp(variances_data[i * len + 3] *
+                             bbox_deltas_data[i * len + 3]) *
+                    anchor_height;
+    } else {
+      bbox_center_x =
+          bbox_deltas_data[i * len] * anchor_width + anchor_center_x;
+      bbox_center_y =
+          bbox_deltas_data[i * len + 1] * anchor_height + anchor_center_y;
+      bbox_width = std::exp(bbox_deltas_data[i * len + 2]) * anchor_width;
+      bbox_height = std::exp(bbox_deltas_data[i * len + 3]) * anchor_height;
+    }
+
+    proposals_data[i * len] = bbox_center_x - bbox_width / 2;
+    proposals_data[i * len + 1] = bbox_center_y - bbox_height / 2;
+    proposals_data[i * len + 2] = bbox_center_x + bbox_width / 2;
+    proposals_data[i * len + 3] = bbox_center_y + bbox_height / 2;
+  }
+  // return proposals;
+}
+
+template <class T>
+void ClipTiledBoxes(const platform::DeviceContext &ctx, const Tensor &im_info,
+                    Tensor *boxes) {
+  T *boxes_data = boxes->mutable_data<T>(ctx.GetPlace());
+  const T *im_info_data = im_info.data<T>();
+  for (int64_t i = 0; i < boxes->numel(); ++i) {
+    if (i % 4 == 0) {
+      boxes_data[i] =
+          std::max(std::min(boxes_data[i], im_info_data[1] - 1), 0.0f);
+    } else if (i % 4 == 1) {
+      boxes_data[i] =
+          std::max(std::min(boxes_data[i], im_info_data[0] - 1), 0.0f);
+    } else if (i % 4 == 2) {
+      boxes_data[i] =
+          std::max(std::min(boxes_data[i], im_info_data[1] - 1), 0.0f);
+    } else {
+      boxes_data[i] =
+          std::max(std::min(boxes_data[i], im_info_data[0] - 1), 0.0f);
+    }
+  }
+}
+
+template <class T>
+void FilterBoxes(const platform::DeviceContext &ctx, Tensor *boxes,
+                 float min_size, const Tensor &im_info, Tensor *keep) {
+  const T *im_info_data = im_info.data<T>();
+  T *boxes_data = boxes->mutable_data<T>(ctx.GetPlace());
+  min_size *= im_info_data[2];
+  keep->Resize({boxes->dims()[0], 1});
+  int *keep_data = keep->mutable_data<int>(ctx.GetPlace());
+
+  int keep_len = 0;
+  for (int i = 0; i < boxes->dims()[0]; ++i) {
+    T ws = boxes_data[4 * i + 2] - boxes_data[4 * i] + 1;
+    T hs = boxes_data[4 * i + 3] - boxes_data[4 * i + 1] + 1;
+    T x_ctr = boxes_data[4 * i] + ws / 2;
+    T y_ctr = boxes_data[4 * i + 1] + hs / 2;
+    if (ws >= min_size && hs >= min_size && x_ctr <= im_info_data[1] &&
+        y_ctr <= im_info_data[0]) {
+      keep_data[keep_len++] = i;
+    }
+  }
+  keep->Resize({keep_len});
+}
+
+bool SortScorePairDescend(const std::pair<float, int> &pair1,
+                          const std::pair<float, int> &pair2) {
+  return pair1.first > pair2.first;
+}
+
+template <class T>
+void GetMaxScoreIndex(const std::vector<T> &scores,
+                      std::vector<std::pair<T, int>> *sorted_indices) {
+  for (size_t i = 0; i < scores.size(); ++i) {
+    sorted_indices->push_back(std::make_pair(scores[i], i));
+  }
+  // Sort the score pair according to the scores in descending order
+  std::stable_sort(sorted_indices->begin(), sorted_indices->end(),
+                   SortScorePairDescend);
+}
+
+template <class T>
+T BBoxArea(const T *box, const bool normalized) {
+  if (box[2] < box[0] || box[3] < box[1]) {
+    // If coordinate values are is invalid
+    // (e.g. xmax < xmin or ymax < ymin), return 0.
+    return static_cast<T>(0.);
+  } else {
+    const T w = box[2] - box[0];
+    const T h = box[3] - box[1];
+    if (normalized) {
+      return w * h;
+    } else {
+      // If coordinate values are not within range [0, 1].
+      return (w + 1) * (h + 1);
+    }
+  }
+}
+
+template <class T>
+T JaccardOverlap(const T *box1, const T *box2, const bool normalized) {
+  if (box2[0] > box1[2] || box2[2] < box1[0] || box2[1] > box1[3] ||
+      box2[3] < box1[1]) {
+    return static_cast<T>(0.);
+  } else {
+    const T inter_xmin = std::max(box1[0], box2[0]);
+    const T inter_ymin = std::max(box1[1], box2[1]);
+    const T inter_xmax = std::min(box1[2], box2[2]);
+    const T inter_ymax = std::min(box1[3], box2[3]);
+    const T inter_w = inter_xmax - inter_xmin;
+    const T inter_h = inter_ymax - inter_ymin;
+    const T inter_area = inter_w * inter_h;
+    const T bbox1_area = BBoxArea<T>(box1, normalized);
+    const T bbox2_area = BBoxArea<T>(box2, normalized);
+    return inter_area / (bbox1_area + bbox2_area - inter_area);
+  }
+}
+
+template <class T>
+Tensor NMS(const platform::DeviceContext &ctx, Tensor *bbox, Tensor *scores,
+           const T nms_threshold, const float eta) {
+  PADDLE_ENFORCE_NOT_NULL(bbox);
+  int64_t num_boxes = bbox->dims()[0];
+  // 4: [xmin ymin xmax ymax]
+  int64_t box_size = bbox->dims()[1];
+
+  std::vector<T> scores_data(num_boxes);
+  std::copy_n(scores->data<T>(), num_boxes, scores_data.begin());
+  std::vector<std::pair<T, int>> sorted_indices;
+  GetMaxScoreIndex<T>(scores_data, &sorted_indices);
+
+  std::vector<int> selected_indices;
+  int selected_num = 0;
+  T adaptive_threshold = nms_threshold;
+  const T *bbox_data = bbox->data<T>();
+  bool flag;
+  while (sorted_indices.size() != 0) {
+    int idx = sorted_indices.front().second;
+    flag = true;
+    for (size_t k = 0; k < selected_indices.size(); ++k) {
+      if (flag) {
+        const int kept_idx = selected_indices[k];
+        T overlap = JaccardOverlap<T>(bbox_data + idx * box_size,
+                                      bbox_data + kept_idx * box_size, false);
+        flag = (overlap <= adaptive_threshold);
+      } else {
+        break;
+      }
+    }
+    if (flag) {
+      selected_indices.push_back(idx);
+      selected_num++;
+    }
+    sorted_indices.erase(sorted_indices.begin());
+    if (flag && eta < 1 && adaptive_threshold > 0.5) {
+      adaptive_threshold *= eta;
+    }
+  }
+  Tensor keep_nms;
+  keep_nms.Resize({selected_num});
+  int *keep_data = keep_nms.mutable_data<int>(ctx.GetPlace());
+  for (int i = 0; i < selected_num; ++i) {
+    keep_data[i] = selected_indices[i];
+  }
+
+  return keep_nms;
+}
+
+template <typename DeviceContext, typename T>
+class GenerateProposalsKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &context) const override {
+    auto *scores = context.Input<Tensor>("Scores");
+    auto *bbox_deltas = context.Input<Tensor>("BboxDeltas");
+    auto *im_info = context.Input<Tensor>("ImInfo");
+    auto *anchors = context.Input<Tensor>("Anchors");
+    auto *variances = context.Input<Tensor>("Variances");
+
+    auto *rpn_rois = context.Output<LoDTensor>("RpnRois");
+    auto *rpn_roi_probs = context.Output<LoDTensor>("RpnRoiProbs");
+
+    int pre_nms_top_n = context.Attr<int>("pre_nms_topN");
+    int post_nms_top_n = context.Attr<int>("post_nms_topN");
+    float nms_thresh = context.Attr<float>("nms_thresh");
+    float min_size = context.Attr<float>("min_size");
+    float eta = context.Attr<float>("eta");
+
+    auto &dev_ctx = context.template device_context<DeviceContext>();
+
+    auto scores_dim = scores->dims();
+    int64_t num = scores_dim[0];
+    int64_t c_score = scores_dim[1];
+    int64_t h_score = scores_dim[2];
+    int64_t w_score = scores_dim[3];
+
+    auto bbox_dim = bbox_deltas->dims();
+    int64_t c_bbox = bbox_dim[1];
+    int64_t h_bbox = bbox_dim[2];
+    int64_t w_bbox = bbox_dim[3];
+
+    rpn_rois->mutable_data<T>({bbox_deltas->numel() / 4, 4},
+                              context.GetPlace());
+    rpn_roi_probs->mutable_data<T>({scores->numel() / 4, 1},
+                                   context.GetPlace());
+
+    Tensor bbox_deltas_swap, scores_swap;
+    bbox_deltas_swap.mutable_data<T>({num, h_bbox, w_bbox, c_bbox},
+                                     dev_ctx.GetPlace());
+    scores_swap.mutable_data<T>({num, h_score, w_score, c_score},
+                                dev_ctx.GetPlace());
+
+    math::Transpose<DeviceContext, T, 4> trans;
+    std::vector<int> axis = {0, 2, 3, 1};
+    trans(dev_ctx, *bbox_deltas, &bbox_deltas_swap, axis);
+    trans(dev_ctx, *scores, &scores_swap, axis);
+
+    framework::LoD lod;
+    std::vector<size_t> lod0(1, 0);
+    Tensor *anchor = const_cast<framework::Tensor *>(anchors);
+    anchor->Resize({anchors->numel() / 4, 4});
+    Tensor *var = const_cast<framework::Tensor *>(variances);
+    var->Resize({var->numel() / 4, 4});
+
+    int64_t num_proposals = 0;
+    for (int64_t i = 0; i < num; ++i) {
+      Tensor im_info_slice = im_info->Slice(i, i + 1);
+      Tensor bbox_deltas_slice = bbox_deltas_swap.Slice(i, i + 1);
+      Tensor scores_slice = scores_swap.Slice(i, i + 1);
+
+      bbox_deltas_slice.Resize({h_bbox * w_bbox * c_bbox / 4, 4});
+      scores_slice.Resize({h_score * w_score * c_score, 1});
+
+      std::pair<Tensor, Tensor> tensor_pair =
+          ProposalForOneImage(dev_ctx, im_info_slice, *anchor, *var,
+                              bbox_deltas_slice, scores_slice, pre_nms_top_n,
+                              post_nms_top_n, nms_thresh, min_size, eta);
+      Tensor proposals = tensor_pair.first;
+      Tensor scores = tensor_pair.second;
+
+      framework::VisitDataType(
+          framework::ToDataType(rpn_rois->type()),
+          AppendProposalsFunctor(rpn_rois, 4 * num_proposals, &proposals));
+      framework::VisitDataType(
+          framework::ToDataType(rpn_roi_probs->type()),
+          AppendProposalsFunctor(rpn_roi_probs, num_proposals, &scores));
+
+      num_proposals += proposals.dims()[0];
+      lod0.emplace_back(num_proposals);
+    }
+
+    lod.emplace_back(lod0);
+    rpn_rois->set_lod(lod);
+    rpn_roi_probs->set_lod(lod);
+    rpn_rois->Resize({num_proposals, 4});
+    rpn_roi_probs->Resize({num_proposals, 1});
+  }
+
+  std::pair<Tensor, Tensor> ProposalForOneImage(
+      const DeviceContext &ctx, const Tensor &im_info_slice,
+      const Tensor &anchors, const Tensor &variances,
+      const Tensor &bbox_deltas_slice,  // [M, 4]
+      const Tensor &scores_slice,       // [N, 1]
+      int pre_nms_top_n, int post_nms_top_n, float nms_thresh, float min_size,
+      float eta) const {
+    auto *scores_data = scores_slice.data<T>();
+
+    // Sort index
+    Tensor index_t;
+    index_t.Resize({scores_slice.numel()});
+    int *index = index_t.mutable_data<int>(ctx.GetPlace());
+    for (int i = 0; i < scores_slice.numel(); ++i) {
+      index[i] = i;
+    }
+    std::function<bool(const int64_t &, const int64_t &)> compare =
+        [scores_data](const int64_t &i, const int64_t &j) {
+          return scores_data[i] > scores_data[j];
+        };
+
+    if (pre_nms_top_n <= 0 || pre_nms_top_n >= scores_slice.numel()) {
+      std::sort(index, index + scores_slice.numel(), compare);
+    } else {
+      std::nth_element(index, index + pre_nms_top_n,
+                       index + scores_slice.numel(), compare);
+      index_t.Resize({pre_nms_top_n});
+    }
+
+    Tensor scores_sel, bbox_sel, anchor_sel, var_sel;
+    scores_sel.mutable_data<T>({index_t.numel(), 1}, ctx.GetPlace());
+    bbox_sel.mutable_data<T>({index_t.numel(), 4}, ctx.GetPlace());
+    anchor_sel.mutable_data<T>({index_t.numel(), 4}, ctx.GetPlace());
+    var_sel.mutable_data<T>({index_t.numel(), 4}, ctx.GetPlace());
+
+    CPUGather<T>(ctx, scores_slice, index_t, &scores_sel);
+    CPUGather<T>(ctx, bbox_deltas_slice, index_t, &bbox_sel);
+    CPUGather<T>(ctx, anchors, index_t, &anchor_sel);
+    CPUGather<T>(ctx, variances, index_t, &var_sel);
+
+    Tensor proposals;
+    proposals.mutable_data<T>({index_t.numel(), 4}, ctx.GetPlace());
+    BoxCoder<T>(ctx, &anchor_sel, &bbox_sel, &var_sel, &proposals);
+
+    ClipTiledBoxes<T>(ctx, im_info_slice, &proposals);
+
+    Tensor keep;
+    FilterBoxes<T>(ctx, &proposals, min_size, im_info_slice, &keep);
+
+    Tensor scores_filter;
+    bbox_sel.mutable_data<T>({keep.numel(), 4}, ctx.GetPlace());
+    scores_filter.mutable_data<T>({keep.numel(), 1}, ctx.GetPlace());
+    CPUGather<T>(ctx, proposals, keep, &bbox_sel);
+    CPUGather<T>(ctx, scores_sel, keep, &scores_filter);
+    if (nms_thresh <= 0) {
+      return std::make_pair(bbox_sel, scores_sel);
+    }
+
+    Tensor keep_nms = NMS<T>(ctx, &bbox_sel, &scores_filter, nms_thresh, eta);
+
+    if (post_nms_top_n > 0 && post_nms_top_n < keep_nms.numel()) {
+      keep_nms.Resize({post_nms_top_n});
+    }
+
+    proposals.mutable_data<T>({keep_nms.numel(), 4}, ctx.GetPlace());
+    scores_sel.mutable_data<T>({keep_nms.numel(), 1}, ctx.GetPlace());
+    CPUGather<T>(ctx, bbox_sel, keep_nms, &proposals);
+    CPUGather<T>(ctx, scores_filter, keep_nms, &scores_sel);
+
+    return std::make_pair(proposals, scores_sel);
+  }
+};
+
+class GenerateProposalsOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("Scores", "The scores of anchors should be foreground.");
+    AddInput("BboxDeltas", "bbox_deltas.");
+    AddInput("ImInfo", "Information for image reshape.");
+    AddInput("Anchors", "All anchors.");
+    AddInput("Variances", " variances");
+
+    AddOutput("RpnRois", "Anchors.");
+    AddOutput("RpnRoiProbs", "Anchors.");
+    AddAttr<int>("pre_nms_topN", "pre_nms_topN");
+    AddAttr<int>("post_nms_topN", "post_nms_topN");
+    AddAttr<float>("nms_thresh", "nms_thres");
+    AddAttr<float>("min_size", "min size");
+    AddAttr<float>("eta", "eta");
+    AddComment(R"DOC(
+Generate Proposals OP
+
+This operator proposes rois according to each box with their probability to be a foreground object and 
+the box can be calculated by anchors. Bbox_deltais and scores are the output of RPN. Final proposals
+could be used to train detection net.
+
+Scores is the probability for each box to be an object. In format of (N, A, H, W) where N is batch size, A is number
+of anchors, H and W are height and width of the feature map.
+BboxDeltas is the differece between predicted box locatoin and anchor location. In format of (N, 4*A, H, W)
+
+For generating proposals, this operator transposes and resizes scores and bbox_deltas in size of (H*W*A, 1) and (H*W*A, 4) and 
+ calculate box locations as proposals candidates. Then clip boxes to image and remove predicted boxes with small area. 
+Finally, apply nms to get final proposals as output.
+)DOC");
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OPERATOR(generate_proposals, ops::GenerateProposalsOp,
+                  ops::GenerateProposalsOpMaker,
+                  paddle::framework::EmptyGradOpMaker);
+REGISTER_OP_CPU_KERNEL(
+    generate_proposals,
+    ops::GenerateProposalsKernel<paddle::platform::CPUDeviceContext, float>);
diff --git a/paddle/fluid/operators/fusion_gru_op.cc b/paddle/fluid/operators/fusion_gru_op.cc
new file mode 100644
index 0000000000..3a34aa86b6
--- /dev/null
+++ b/paddle/fluid/operators/fusion_gru_op.cc
@@ -0,0 +1,332 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/fusion_gru_op.h"
+#include <string>
+#include "paddle/fluid/framework/eigen.h"
+#include "paddle/fluid/operators/math/blas.h"
+#include "paddle/fluid/operators/math/detail/activation_functions.h"
+#include "paddle/fluid/operators/math/detail/gru_cpu_kernel.h"
+#include "paddle/fluid/operators/math/detail/gru_kernel.h"
+#include "paddle/fluid/operators/math/fc_compute.h"
+#include "paddle/fluid/operators/math/gru_compute.h"
+#include "paddle/fluid/operators/math/math_function.h"
+#include "paddle/fluid/operators/math/sequence2batch.h"
+
+namespace paddle {
+namespace operators {
+
+void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
+  PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of GRU should not be null.");
+  PADDLE_ENFORCE(ctx->HasInput("WeightX"),
+                 "Input(WeightX) of GRU should not be null.");
+  PADDLE_ENFORCE(ctx->HasInput("WeightH"),
+                 "Input(WeightH) of GRU should not be null.");
+
+  PADDLE_ENFORCE(ctx->HasOutput("XX"), "Output(XX) of GRU should not be null.");
+  PADDLE_ENFORCE(ctx->HasOutput("BatchedGate"),
+                 "Output(BatchedGate) of GRU should not be null.");
+  PADDLE_ENFORCE(ctx->HasOutput("BatchResetHiddenPrev"),
+                 "Output(BatchResetHiddenPrev) of GRU should not be null.");
+  PADDLE_ENFORCE(ctx->HasOutput("BatchedHidden"),
+                 "Output(BatchedHidden) of GRU should not be null.");
+  PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
+                 "Output(Hidden) of GRU should not be null.");
+
+  auto x_dims = ctx->GetInputDim("X");
+  PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2.");
+
+  auto wx_dims = ctx->GetInputDim("WeightX");
+  PADDLE_ENFORCE_EQ(wx_dims.size(), 2,
+                    "The rank of Input(WeightX) should be 2.");
+  PADDLE_ENFORCE_EQ(wx_dims[0], x_dims[1],
+                    "The first dimension of Input(WeightX) "
+                    "should be %d.",
+                    x_dims[1]);
+
+  int frame_size = wx_dims[1] / 3;
+  auto wh_dims = ctx->GetInputDim("WeightH");
+  PADDLE_ENFORCE_EQ(wh_dims.size(), 2,
+                    "The rank of Input(WeightH) should be 2.");
+  PADDLE_ENFORCE_EQ(wh_dims[0], frame_size,
+                    "The first dimension of Input(WeightH) "
+                    "should be %d.",
+                    frame_size);
+  PADDLE_ENFORCE_EQ(wh_dims[1], 3 * frame_size,
+                    "The second dimension of Input(WeightH) "
+                    "should be 3 * %d.",
+                    frame_size);
+
+  if (ctx->HasInput("H0")) {
+    auto h0_dims = ctx->GetInputDim("H0");
+    PADDLE_ENFORCE_EQ(h0_dims[1], frame_size,
+                      "The width of H0 must be equal to frame_size.");
+  }
+  if (ctx->HasInput("Bias")) {
+    auto b_dims = ctx->GetInputDim("Bias");
+    PADDLE_ENFORCE_EQ(b_dims.size(), 2, "The rank of Input(Bias) should be 2.");
+    PADDLE_ENFORCE_EQ(b_dims[0], 1,
+                      "The first dimension of Input(Bias) should be 1.");
+    PADDLE_ENFORCE_EQ(b_dims[1], frame_size * 3,
+                      "The shape of Bias must be [1, frame_size * 3].");
+  }
+  framework::DDim out_dims({x_dims[0], frame_size});
+  ctx->SetOutputDim("Hidden", out_dims);
+  ctx->SetOutputDim("BatchedGate", {x_dims[0], wx_dims[1]});
+  ctx->SetOutputDim("BatchedHidden", out_dims);
+  ctx->SetOutputDim("BatchResetHiddenPrev", out_dims);
+  ctx->ShareLoD("X", "Hidden");
+
+  int xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1];
+  ctx->SetOutputDim("XX", {x_dims[0], xx_width});
+  ctx->ShareLoD("X", "XX");
+}
+
+framework::OpKernelType FusionGRUOp::GetExpectedKernelType(
+    const framework::ExecutionContext& ctx) const {
+  return framework::OpKernelType(
+      framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
+      ctx.device_context());
+}
+
+void FusionGRUOpMaker::Make() {
+  AddInput("X",
+           "(LoDTensor) the input is a LodTensor, which support "
+           "variable-time length input sequence. The underlying tensor in "
+           "this LoDTensor is a matrix with shape (T X M), where T is the "
+           "total time steps in this mini-batch, M is the dim size of x.");
+  AddInput("H0",
+           "(Tensor, optional) The initial hidden state is an optional "
+           "input. This is a tensor with shape (N x D), where N is the "
+           "batch size, D is the hidden size.")
+      .AsDispensable();
+  AddInput("WeightX",
+           "(Tensor) The FC weight with shape (M x 3D),"
+           "where M is the dim size of x, D is the hidden size. ");
+  AddInput("WeightH",
+           "(Tensor) (D x 3D) Same as GRUOp, where D is the hidden size. ");
+  AddInput("Bias",
+           "(Tensor, optional) (1 x 3D)."
+           "Almost same as GRUOp."
+           "Note: if have FC bias it should be added on this bias.")
+      .AsDispensable();
+  AddOutput("XX",
+            "(LoDTensor) the result after X * WeightX (size is T x 4D)"
+            " or batched_X (size is T x M), this will be automatically chosen,"
+            " where T is the total time steps in this mini-batch,"
+            " D is the hidden size, M is the dim size of x input.")
+      .AsIntermediate();
+  AddOutput("BatchedGate", "(LoDTensor) Same as GRUOp").AsIntermediate();
+  AddOutput("BatchResetHiddenPrev", "(LoDTensor) (T x 3D) Same as GRUOp.")
+      .AsIntermediate();
+  AddOutput("BatchedHidden", "(LoDTensor) (T X D) Same as GRUOp.")
+      .AsIntermediate();
+  AddOutput("Hidden", "(LoDTensor) (T x D) Same as GRUOp");
+  AddAttr<std::string>("activation",
+                       "(string, default tanh) "
+                       "The activation type used for output candidate {h}_t.")
+      .SetDefault("tanh");
+  AddAttr<std::string>(
+      "gate_activation",
+      "(string, default sigmoid) "
+      "The activation type used in update gate and reset gate.")
+      .SetDefault("sigmoid");
+  AddAttr<bool>("is_reverse",
+                "(bool, defalut: False) "
+                "whether to compute reversed GRU.")
+      .SetDefault(false);
+  AddComment(R"DOC(
+The Fusion complete GRU Operator.
+This operator fuse the fully-connected operator into GRU, 
+more details can refer to GRU op.
+)DOC");
+}
+
+template <typename DeviceContext, typename T>
+inline void ReorderInitState(const DeviceContext& ctx,
+                             const framework::Tensor& src,
+                             framework::Vector<size_t> index_lod,
+                             framework::Tensor* dst, bool indexed_src) {
+  math::CopyMatrixRowsFunctor<DeviceContext, T> row_shuffle;
+  dst->mutable_data<T>(src.dims(), ctx.GetPlace());
+  row_shuffle(ctx, src, index_lod, dst, indexed_src);
+}
+
+template <typename DeviceContext, typename T>
+class FusionGRUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* x = ctx.Input<LoDTensor>("X");
+    auto* wx = ctx.Input<Tensor>("WeightX");
+    auto* wh = ctx.Input<Tensor>("WeightH");
+    auto* bias = ctx.Input<Tensor>("Bias");
+    auto* h0 = ctx.Input<Tensor>("H0");
+
+    auto* xx = ctx.Output<LoDTensor>("XX");
+    auto* batched_gate = ctx.Output<LoDTensor>("BatchedGate");
+    auto* batch_reset_hidden_prev =
+        ctx.Output<LoDTensor>("BatchResetHiddenPrev");
+    auto* batch_hidden = ctx.Output<LoDTensor>("BatchedHidden");
+    auto* hidden_out = ctx.Output<LoDTensor>("Hidden");
+    bool is_reverse = ctx.Attr<bool>("is_reverse");
+
+    T* xx_data = xx->mutable_data<T>(ctx.GetPlace());
+    T* batched_gate_data = batched_gate->mutable_data<T>(ctx.GetPlace());
+    batch_reset_hidden_prev->mutable_data<T>(ctx.GetPlace());
+    batch_hidden->mutable_data<T>(ctx.GetPlace());
+    hidden_out->mutable_data<T>(ctx.GetPlace());
+
+    const T* x_data = x->data<T>();
+    const T* wx_data = wx->data<T>();
+    const T* wh_data = wh->data<T>();
+    auto x_dims = x->dims();
+    auto wx_dims = wx->dims();
+    auto& dev_ctx = ctx.template device_context<DeviceContext>();
+    auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
+    math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
+    if (x_dims[1] > wx_dims[1]) {
+      math::FCCompute<DeviceContext, T>(blas, x_dims[0], wx_dims[1], x_dims[1],
+                                        x_data, wx_data, xx_data,
+                                        bias ? bias->data<T>() : NULL);
+      to_batch(dev_ctx, *xx, batched_gate, true, is_reverse);
+    } else {
+      to_batch(dev_ctx, *x, xx, true, is_reverse);
+      batched_gate->set_lod(xx->lod());
+      math::FCCompute<DeviceContext, T>(blas, x_dims[0], wx_dims[1], x_dims[1],
+                                        xx_data, wx_data, batched_gate_data,
+                                        bias ? bias->data<T>() : NULL);
+    }
+
+    int frame_size = static_cast<int>(wx_dims[1] / 3);
+    math::GRUMetaValue<T> gru_value;
+    gru_value.gate_weight = const_cast<T*>(wh_data);
+    gru_value.state_weight =
+        const_cast<T*>(wh_data + 2 * frame_size * frame_size);
+    Tensor ordered_h0;
+
+    framework::Vector<size_t> order(batched_gate->lod()[2]);
+
+    if (h0) {
+      ReorderInitState<DeviceContext, T>(
+          ctx.template device_context<DeviceContext>(), *h0, order, &ordered_h0,
+          true);
+      gru_value.prev_out_value = ordered_h0.data<T>();
+    } else {
+      gru_value.prev_out_value = nullptr;
+    }
+    auto batch_starts = batched_gate->lod()[0];
+    size_t seq_len = batch_starts.size() - 1;
+    auto active_node =
+        math::detail::GetActivationType(ctx.Attr<std::string>("activation"));
+    auto active_gate = math::detail::GetActivationType(
+        ctx.Attr<std::string>("gate_activation"));
+
+#ifdef PADDLE_WITH_MKLML
+    // use MKL packed to speedup GEMM
+    if (FLAGS_paddle_num_threads >= 4) {
+      auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
+      T* packed_gate = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/,
+                                       frame_size * 2 /*width of weight*/,
+                                       frame_size /*height of height*/);
+      PADDLE_ENFORCE(packed_gate);
+      blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*cur bs?*/, frame_size * 2,
+                     frame_size, T(1.0), gru_value.gate_weight, frame_size * 2,
+                     packed_gate);
+      T* packed_state = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/,
+                                        frame_size /*width of weight*/,
+                                        frame_size /*height of height*/);
+      PADDLE_ENFORCE(packed_state);
+      blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*cur bs?*/, frame_size,
+                     frame_size, T(1.0), gru_value.state_weight, frame_size,
+                     packed_state);
+      for (size_t n = 0; n < seq_len; n++) {
+        int bstart = static_cast<int>(batch_starts[n]);
+        int bend = static_cast<int>(batch_starts[n + 1]);
+        int cur_batch_size = bend - bstart;
+
+        Tensor gate_t = batched_gate->Slice(bstart, bend);
+        Tensor reset_hidden_prev_t =
+            batch_reset_hidden_prev->Slice(bstart, bend);
+        Tensor hidden_t = batch_hidden->Slice(bstart, bend);
+        gru_value.output_value = hidden_t.data<T>();
+        gru_value.gate_value = gate_t.data<T>();
+        gru_value.reset_output_value = reset_hidden_prev_t.data<T>();
+
+        if (gru_value.prev_out_value) {
+          blas.GEMM_COMPUTE(
+              CblasNoTrans, CblasPacked, cur_batch_size, frame_size * 2,
+              frame_size, gru_value.prev_out_value, frame_size, packed_gate,
+              frame_size * 2, T(1), gru_value.gate_value, frame_size * 3);
+        }
+
+        math::detail::forward_reset_output(
+            math::detail::forward::gru_resetOutput<T>(), gru_value, frame_size,
+            cur_batch_size, active_gate);
+
+        if (gru_value.prev_out_value) {
+          blas.GEMM_COMPUTE(
+              CblasNoTrans, CblasPacked, cur_batch_size, frame_size, frame_size,
+              gru_value.reset_output_value, frame_size, packed_state,
+              frame_size, T(1), gru_value.gate_value + frame_size * 2,
+              frame_size * 3);
+        }
+
+        math::detail::forward_final_output(
+            math::detail::forward::gru_finalOutput<T>(), gru_value, frame_size,
+            cur_batch_size, active_node);
+
+        gru_value.prev_out_value = gru_value.output_value;
+      }
+
+      blas.GEMM_FREE(packed_gate);
+      blas.GEMM_FREE(packed_state);
+    } else {
+#endif
+      for (size_t n = 0; n < seq_len; n++) {
+        int bstart = static_cast<int>(batch_starts[n]);
+        int bend = static_cast<int>(batch_starts[n + 1]);
+        int cur_batch_size = bend - bstart;
+
+        Tensor gate_t = batched_gate->Slice(bstart, bend);
+        Tensor reset_hidden_prev_t =
+            batch_reset_hidden_prev->Slice(bstart, bend);
+        Tensor hidden_t = batch_hidden->Slice(bstart, bend);
+        gru_value.output_value = hidden_t.data<T>();
+        gru_value.gate_value = gate_t.data<T>();
+        gru_value.reset_output_value = reset_hidden_prev_t.data<T>();
+
+        math::GRUUnitFunctor<DeviceContext, T>::compute(
+            dev_ctx, gru_value, frame_size, cur_batch_size, active_node,
+            active_gate);
+
+        gru_value.prev_out_value = gru_value.output_value;
+      }
+#ifdef PADDLE_WITH_MKLML
+    }
+#endif
+    math::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
+    batch_hidden->set_lod(batched_gate->lod());
+    to_seq(dev_ctx, *batch_hidden, hidden_out);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OPERATOR(fusion_gru, ops::FusionGRUOp, ops::FusionGRUOpMaker,
+                  paddle::framework::DefaultGradOpDescMaker<true>);
+REGISTER_OP_CPU_KERNEL(
+    fusion_gru, ops::FusionGRUKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::FusionGRUKernel<paddle::platform::CPUDeviceContext, double>);
diff --git a/paddle/fluid/operators/fusion_gru_op.h b/paddle/fluid/operators/fusion_gru_op.h
new file mode 100644
index 0000000000..eaa59cd412
--- /dev/null
+++ b/paddle/fluid/operators/fusion_gru_op.h
@@ -0,0 +1,41 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include "paddle/fluid/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+using LoDTensor = framework::LoDTensor;
+using Tensor = framework::Tensor;
+
+class FusionGRUOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override;
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override;
+};
+
+class FusionGRUOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override;
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/operators/fusion_seqexpand_concat_fc_op.cc b/paddle/fluid/operators/fusion_seqexpand_concat_fc_op.cc
index 90aba5fe89..0cd3d3887c 100644
--- a/paddle/fluid/operators/fusion_seqexpand_concat_fc_op.cc
+++ b/paddle/fluid/operators/fusion_seqexpand_concat_fc_op.cc
@@ -49,9 +49,14 @@ void FusionSeqExpandConcatFCOp::InferShape(
                     "FC height should be sum of all inputs width.");
   if (ctx->HasInput("FCBias")) {
     auto b_dims = ctx->GetInputDim("FCBias");
-    PADDLE_ENFORCE_EQ(b_dims.size(), 2, "Input(FCBias)'s rank must be 2.");
-    PADDLE_ENFORCE_EQ(b_dims[0], 1, "FCBias shapes must be 1 * %d.", D);
-    PADDLE_ENFORCE_EQ(b_dims[1], D, "FCBias shapes must be 1 * %d.", D);
+    PADDLE_ENFORCE(b_dims.size() == 1 || b_dims.size() == 2,
+                   "b_dims should be 1 or 2, get %d", b_dims.size());
+    if (b_dims.size() == 1) {
+      PADDLE_ENFORCE_EQ(b_dims[0], D, "FCBias shapes must be %d.", D);
+    } else {
+      PADDLE_ENFORCE_EQ(b_dims[0], 1, "FCBias shapes must be 1x%d.", D);
+      PADDLE_ENFORCE_EQ(b_dims[1], D, "FCBias shapes must be 1x%d.", D);
+    }
   }
 
   ctx->SetOutputDim("Out", {ins_dims[0][0], D});
diff --git a/paddle/fluid/operators/math/cpu_vec.h b/paddle/fluid/operators/math/cpu_vec.h
index 0bae926e98..5693761e9f 100644
--- a/paddle/fluid/operators/math/cpu_vec.h
+++ b/paddle/fluid/operators/math/cpu_vec.h
@@ -14,6 +14,7 @@ limitations under the License. */
 
 #pragma once
 #include <cmath>
+#include <functional>
 #include <string>
 #include "paddle/fluid/platform/cpu_info.h"
 #ifdef __AVX__
diff --git a/paddle/fluid/operators/math/sequence2batch.cc b/paddle/fluid/operators/math/sequence2batch.cc
index b546b87282..e4ffeedb5a 100644
--- a/paddle/fluid/operators/math/sequence2batch.cc
+++ b/paddle/fluid/operators/math/sequence2batch.cc
@@ -38,13 +38,14 @@ class CopyMatrixRowsFunctor<platform::CPUDeviceContext, T> {
     auto width = dst_dims[1];
     auto* src_data = src.data<T>();
     auto* dst_data = dst->data<T>();
-    for (int i = 0; i < height; ++i) {
-      if (is_src_index) {
-        memcpy(dst_data + i * width, src_data + index[i] * width,
-               width * sizeof(T));
-      } else {
-        memcpy(dst_data + index[i] * width, src_data + i * width,
-               width * sizeof(T));
+    const int sz = width * sizeof(T);
+    if (is_src_index) {
+      for (int i = 0; i < height; ++i) {
+        memcpy(dst_data + i * width, src_data + index[i] * width, sz);
+      }
+    } else {
+      for (int i = 0; i < height; ++i) {
+        memcpy(dst_data + index[i] * width, src_data + i * width, sz);
       }
     }
   }
diff --git a/paddle/fluid/operators/parallel_do_op.cc b/paddle/fluid/operators/parallel_do_op.cc
index eb09470f37..97c36a83fc 100644
--- a/paddle/fluid/operators/parallel_do_op.cc
+++ b/paddle/fluid/operators/parallel_do_op.cc
@@ -355,6 +355,7 @@ class ParallelDoGradOpDescMaker : public framework::SingleGradOpDescMaker {
         grad->SetInput(framework::GradVarName(output_param), og_names);
       }
     }
+    grad->SetInput("Communicator", {"nccl_com__do_not_change_"});
     grad->SetAttrMap(this->Attrs());
     grad->SetBlockAttr(kParallelBlock, grad_block_[0]);
 
diff --git a/paddle/fluid/platform/init.cc b/paddle/fluid/platform/init.cc
index 020ce4d6f5..4c99f4be32 100644
--- a/paddle/fluid/platform/init.cc
+++ b/paddle/fluid/platform/init.cc
@@ -85,9 +85,6 @@ void InitDevices(bool init_p2p) {
   } catch (const std::exception &exp) {
     LOG(WARNING) << "Compiled with WITH_GPU, but no GPU found in runtime.";
   }
-#else
-  LOG(WARNING)
-      << "'CUDA' is not supported, Please re-compile with WITH_GPU option";
 #endif
   InitDevices(init_p2p, devices);
 }
@@ -101,9 +98,6 @@ void InitDevices(bool init_p2p, const std::vector<int> devices) {
   } catch (const std::exception &exp) {
     LOG(WARNING) << "Compiled with WITH_GPU, but no GPU found in runtime.";
   }
-#else
-  LOG(WARNING)
-      << "'CUDA' is not supported, Please re-compile with WITH_GPU option";
 #endif
 
   for (size_t i = 0; i < devices.size(); ++i) {
diff --git a/python/paddle/dataset/image.py b/python/paddle/dataset/image.py
index b32736ee7c..920dbf3b4e 100644
--- a/python/paddle/dataset/image.py
+++ b/python/paddle/dataset/image.py
@@ -203,7 +203,7 @@ def resize_short(im, size):
         h_new = size * h // w
     else:
         w_new = size * w // h
-    im = cv2.resize(im, (h_new, w_new), interpolation=cv2.INTER_CUBIC)
+    im = cv2.resize(im, (w_new, h_new), interpolation=cv2.INTER_CUBIC)
     return im
 
 
@@ -345,7 +345,6 @@ def simple_transform(im,
         if np.random.randint(2) == 0:
             im = left_right_flip(im, is_color)
     else:
-        im = center_crop(im, crop_size, is_color)
         im = center_crop(im, crop_size, is_color=is_color)
     if len(im.shape) == 3:
         im = to_chw(im)
diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py
index 7207147884..a5bc1fa8f8 100644
--- a/python/paddle/fluid/layers/detection.py
+++ b/python/paddle/fluid/layers/detection.py
@@ -39,6 +39,7 @@ __all__ = [
     'detection_map',
     'rpn_target_assign',
     'anchor_generator',
+    'generate_proposals',
 ]
 
 __auto__ = [
@@ -1253,3 +1254,73 @@ def anchor_generator(input,
     anchor.stop_gradient = True
     var.stop_gradient = True
     return anchor, var
+
+
+def generate_proposals(scores,
+                       bbox_deltas,
+                       im_info,
+                       anchors,
+                       variances,
+                       pre_nms_top_n=6000,
+                       post_nms_top_n=1000,
+                       nms_thresh=0.5,
+                       min_size=0.1,
+                       eta=1.0,
+                       name=None):
+    """
+    ** Generate proposal labels Faster-RCNN **
+	
+	This operation proposes RoIs according to each box with their probability to be a foreground object and 
+	the box can be calculated by anchors. Bbox_deltais and scores to be an object are the output of RPN. Final proposals
+	could be used to train detection net.
+
+	For generating proposals, this operation performs following steps:
+
+	1. Transposes and resizes scores and bbox_deltas in size of (H*W*A, 1) and (H*W*A, 4)
+ 	2. Calculate box locations as proposals candidates. 
+	3. Clip boxes to image
+	4. Remove predicted boxes with small area. 
+	5. Apply NMS to get final proposals as output.
+	
+      
+	Args:
+		scores(Variable): A 4-D Tensor with shape [N, A, H, W] represents the probability for each box to be an object.
+			N is batch size, A is number of anchors, H and W are height and width of the feature map.
+		bbox_deltas(Variable): A 4-D Tensor with shape [N, 4*A, H, W] represents the differece between predicted box locatoin and anchor location. 
+		im_info(Variable): A 2-D Tensor with shape [N, 3] represents origin image information for N batch. Info contains height, width and scale
+			between origin image size and the size of feature map.
+		anchors(Variable):   A 4-D Tensor represents the anchors with a layout of [H, W, A, 4]. H and W are height and width of the feature map,
+              		num_anchors is the box count of each position. Each anchor is in (xmin, ymin, xmax, ymax) format an unnormalized.
+		variances(Variable): The expanded variances of anchors with a layout of [H, W, num_priors, 4]. Each variance is in (xcenter, ycenter, w, h) format.
+		pre_nms_top_n(float): Number of total bboxes to be kept per image before NMS. 6000 by default.
+		post_nms_top_n(float): Number of total bboxes to be kept per image after NMS. 1000 by default.
+		nms_thresh(float): Threshold in NMS, 0.5 by default.
+		min_size(float): Remove predicted boxes with either height or width < min_size. 0.1 by default.
+		eta(float): Apply in adaptive NMS, if adaptive threshold > 0.5, adaptive_threshold = adaptive_threshold * eta in each iteration.
+    """
+    helper = LayerHelper('generate_proposals', **locals())
+
+    rpn_rois = helper.create_tmp_variable(dtype=bbox_deltas.dtype)
+    rpn_roi_probs = helper.create_tmp_variable(dtype=scores.dtype)
+    helper.append_op(
+        type="generate_proposals",
+        inputs={
+            'Scores': scores,
+            'BboxDeltas': bbox_deltas,
+            'ImInfo': im_info,
+            'Anchors': anchors,
+            'Variances': variances
+        },
+        attrs={
+            'pre_nms_topN': pre_nms_top_n,
+            'post_nms_topN': post_nms_top_n,
+            'nms_thresh': nms_thresh,
+            'min_size': min_size,
+            'eta': eta
+        },
+        outputs={'RpnRois': rpn_rois,
+                 'RpnRoiProbs': rpn_roi_probs})
+    rpn_rois.stop_gradient = True
+    rpn_roi_probs.stop_gradient = True
+
+    return rpn_rois, rpn_roi_probs
diff --git a/python/paddle/fluid/tests/book_memory_optimization/test_memopt_image_classification_train.py b/python/paddle/fluid/tests/book_memory_optimization/test_memopt_image_classification_train.py
index 3951e7b8ca..a231bbfbc8 100644
--- a/python/paddle/fluid/tests/book_memory_optimization/test_memopt_image_classification_train.py
+++ b/python/paddle/fluid/tests/book_memory_optimization/test_memopt_image_classification_train.py
@@ -125,8 +125,8 @@ opts = optimizer.minimize(avg_cost)
 batch_size = fluid.layers.create_tensor(dtype='int64')
 batch_acc = fluid.layers.accuracy(input=predict, label=label, total=batch_size)
 
-# fluid.memory_optimize(fluid.default_main_program(), level=0)
-fluid.release_memory(fluid.default_main_program())
+fluid.memory_optimize(fluid.default_main_program(), level=0)
+# fluid.release_memory(fluid.default_main_program())
 
 BATCH_SIZE = 16
 PASS_NUM = 1
diff --git a/python/paddle/fluid/tests/book_memory_optimization/test_memopt_machine_translation.py b/python/paddle/fluid/tests/book_memory_optimization/test_memopt_machine_translation.py
index 1ad51936b5..e520c89650 100644
--- a/python/paddle/fluid/tests/book_memory_optimization/test_memopt_machine_translation.py
+++ b/python/paddle/fluid/tests/book_memory_optimization/test_memopt_machine_translation.py
@@ -92,8 +92,8 @@ def main():
     optimizer = fluid.optimizer.Adagrad(learning_rate=1e-4)
     optimizer.minimize(avg_cost)
 
-    # fluid.memory_optimize(fluid.default_main_program())
-    fluid.release_memory(fluid.default_main_program())
+    fluid.memory_optimize(fluid.default_main_program())
+    # fluid.release_memory(fluid.default_main_program())
 
     # fix the order of training data
     train_data = paddle.batch(
diff --git a/python/paddle/fluid/tests/test_detection.py b/python/paddle/fluid/tests/test_detection.py
index 1467e72caa..b71b440d3c 100644
--- a/python/paddle/fluid/tests/test_detection.py
+++ b/python/paddle/fluid/tests/test_detection.py
@@ -201,5 +201,44 @@ class TestDetectionMAP(unittest.TestCase):
         print(str(program))
 
 
+class TestGenerateProposals(unittest.TestCase):
+    def test_generate_proposals(self):
+        data_shape = [20, 64, 64]
+        images = fluid.layers.data(
+            name='images', shape=data_shape, dtype='float32')
+        im_info = fluid.layers.data(
+            name='im_info', shape=[1, 3], dtype='float32')
+        anchors, variances = fluid.layers.anchor_generator(
+            name='anchor_generator',
+            input=images,
+            anchor_sizes=[32, 64],
+            aspect_ratios=[1.0],
+            variance=[0.1, 0.1, 0.2, 0.2],
+            stride=[16.0, 16.0],
+            offset=0.5)
+        num_anchors = anchors.shape[2]
+        scores = fluid.layers.data(
+            name='scores', shape=[1, num_anchors, 8, 8], dtype='float32')
+        bbox_deltas = fluid.layers.data(
+            name='bbox_deltas',
+            shape=[1, num_anchors * 4, 8, 8],
+            dtype='float32')
+        rpn_rois, rpn_roi_probs = fluid.layers.generate_proposals(
+            name='generate_proposals',
+            scores=scores,
+            bbox_deltas=bbox_deltas,
+            im_info=im_info,
+            anchors=anchors,
+            variances=variances,
+            pre_nms_top_n=6000,
+            post_nms_top_n=1000,
+            nms_thresh=0.5,
+            min_size=0.1,
+            eta=1.0)
+        self.assertIsNotNone(rpn_rois)
+        self.assertIsNotNone(rpn_roi_probs)
+        print(rpn_rois.shape)
+
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/dist_transformer.py b/python/paddle/fluid/tests/unittests/dist_transformer.py
index 239adcb9d5..179c2540f8 100644
--- a/python/paddle/fluid/tests/unittests/dist_transformer.py
+++ b/python/paddle/fluid/tests/unittests/dist_transformer.py
@@ -18,54 +18,129 @@ import numpy as np
 import argparse
 import time
 import math
+import os
+import sys
+import six
+import argparse
+import ast
+import multiprocessing
+import time
+from functools import partial
+from os.path import expanduser
+import glob
+import random
+import tarfile
 
 import paddle
 import paddle.fluid as fluid
+import paddle.fluid.layers as layers
 from paddle.fluid import core
-import os
-import sys
-import six
-import transformer_model
-import paddle.dataset.wmt16 as wmt16
+from test_dist_base import TestDistRunnerBase, runtime_main
+from paddle.compat import long_type
+
+import hashlib
+
+from paddle.fluid.transpiler.details import program_to_code
+
+const_para_attr = fluid.ParamAttr(initializer=fluid.initializer.Constant(0.001))
+const_bias_attr = const_para_attr
 
 # Fix seed for test
 fluid.default_startup_program().random_seed = 1
 fluid.default_main_program().random_seed = 1
 
-WMT16_RECORDIO_FILE = "/tmp/wmt16.recordio"
 
+#from transformer_config import ModelHyperParams, TrainTaskConfig, merge_cfg_from_list
+class TrainTaskConfig(object):
+    # only support GPU currently
+    use_gpu = True
+    # the epoch number to train.
+    pass_num = 1
+    # the number of sequences contained in a mini-batch.
+    # deprecated, set batch_size in args.
+    batch_size = 20
+    # the hyper parameters for Adam optimizer.
+    # This static learning_rate will be multiplied to the LearningRateScheduler
+    # derived learning rate the to get the final learning rate.
+    learning_rate = 1
+    beta1 = 0.9
+    beta2 = 0.98
+    eps = 1e-9
+    # the parameters for learning rate scheduling.
+    warmup_steps = 4000
+    # the weight used to mix up the ground-truth distribution and the fixed
+    # uniform distribution in label smoothing when training.
+    # Set this as zero if label smoothing is not wanted.
+    label_smooth_eps = 0.1
+    # the directory for saving trained models.
+    model_dir = "trained_models"
+    # the directory for saving checkpoints.
+    ckpt_dir = "trained_ckpts"
+    # the directory for loading checkpoint.
+    # If provided, continue training from the checkpoint.
+    ckpt_path = None
+    # the parameter to initialize the learning rate scheduler.
+    # It should be provided if use checkpoints, since the checkpoint doesn't
+    # include the training step counter currently.
+    start_step = 0
 
-class ModelHyperParams(object):
-    # Dictionary size for source and target language. This model directly uses
-    # paddle.dataset.wmt16 in which <bos>, <eos> and <unk> token has
-    # alreay been added, but the <pad> token is not added. Transformer requires
-    # sequences in a mini-batch are padded to have the same length. A <pad> token is
-    # added into the original dictionary in paddle.dateset.wmt16.
+    check_acc = True
 
-    # size of source word dictionary.
-    src_vocab_size = 10000
-    # index for <pad> token in source language.
-    src_pad_idx = src_vocab_size
+    data_path = expanduser("~") + (
+        "/.cache/paddle/dataset/test_dist_transformer/")
+    src_vocab_fpath = data_path + "vocab.bpe.32000"
+    trg_vocab_fpath = data_path + "vocab.bpe.32000"
+    train_file_pattern = data_path + "train.tok.clean.bpe.32000.en-de"
+    val_file_pattern = data_path + "newstest2013.tok.bpe.32000.en-de"
+    pool_size = 2000
+    sort_type = None
+    local = True
+    shuffle = False
+    shuffle_batch = False
+    special_token = ['<s>', '<e>', '<unk>']
+    token_delimiter = ' '
+    use_token_batch = False
 
-    # size of target word dictionay
-    trg_vocab_size = 10000
-    # index for <pad> token in target language.
-    trg_pad_idx = trg_vocab_size
 
-    # position value corresponding to the <pad> token.
-    pos_pad_idx = 0
+class InferTaskConfig(object):
+    use_gpu = True
+    # the number of examples in one run for sequence generation.
+    batch_size = 10
+    # the parameters for beam search.
+    beam_size = 5
+    max_out_len = 256
+    # the number of decoded sentences to output.
+    n_best = 1
+    # the flags indicating whether to output the special tokens.
+    output_bos = False
+    output_eos = False
+    output_unk = True
+    # the directory for loading the trained model.
+    model_path = "trained_models/pass_1.infer.model"
 
-    # max length of sequences. It should plus 1 to include position
-    # padding token for position encoding.
-    max_length = 50
 
+class ModelHyperParams(object):
+    # These following five vocabularies related configurations will be set
+    # automatically according to the passed vocabulary path and special tokens.
+    # size of source word dictionary.
+    src_vocab_size = 10000
+    # size of target word dictionay
+    trg_vocab_size = 10000
+    # index for <bos> token
+    bos_idx = 0
+    # index for <eos> token
+    eos_idx = 1
+    # index for <unk> token
+    unk_idx = 2
+    # max length of sequences deciding the size of position encoding table.
+    # Start from 1 and count start and end tokens in.
+    max_length = 256
     # the dimension for word embeddings, which is also the last dimension of
     # the input and output of multi-head attention, position-wise feed-forward
     # networks, encoder and decoder.
-
     d_model = 512
     # size of the hidden layer in position-wise feed-forward networks.
-    d_inner_hid = 1024
+    d_inner_hid = 2048
     # the dimension that keys are projected to for dot-product attention.
     d_key = 64
     # the dimension that values are projected to for dot-product attention.
@@ -75,95 +150,1521 @@ class ModelHyperParams(object):
     # number of sub-layers to be stacked in the encoder and decoder.
     n_layer = 6
     # dropout rate used by all dropout layers.
-    dropout = 0.1
+    dropout = 0.0  # no random
+    # random seed used in dropout for CE.
+    dropout_seed = None
+    # the flag indicating whether to share embedding and softmax weights.
+    # vocabularies in source and target should be same for weight sharing.
+    weight_sharing = True
 
 
-def prepare_batch_input(insts, src_pad_idx, trg_pad_idx, n_head):
+def merge_cfg_from_list(cfg_list, g_cfgs):
+    """
+    Set the above global configurations using the cfg_list.
+    """
+    assert len(cfg_list) % 2 == 0
+    for key, value in zip(cfg_list[0::2], cfg_list[1::2]):
+        for g_cfg in g_cfgs:
+            if hasattr(g_cfg, key):
+                try:
+                    value = eval(value)
+                except Exception:  # for file path
+                    pass
+                setattr(g_cfg, key, value)
+                break
+
+
+# The placeholder for batch_size in compile time. Must be -1 currently to be
+# consistent with some ops' infer-shape output in compile time, such as the
+# sequence_expand op used in beamsearch decoder.
+batch_size = -1
+# The placeholder for squence length in compile time.
+seq_len = ModelHyperParams.max_length
+# Here list the data shapes and data types of all inputs.
+# The shapes here act as placeholder and are set to pass the infer-shape in
+# compile time.
+input_descs = {
+    # The actual data shape of src_word is:
+    # [batch_size * max_src_len_in_batch, 1]
+    "src_word": [(batch_size, seq_len, long_type(1)), "int64", 2],
+    # The actual data shape of src_pos is:
+    # [batch_size * max_src_len_in_batch, 1]
+    "src_pos": [(batch_size, seq_len, long_type(1)), "int64"],
+    # This input is used to remove attention weights on paddings in the
+    # encoder.
+    # The actual data shape of src_slf_attn_bias is:
+    # [batch_size, n_head, max_src_len_in_batch, max_src_len_in_batch]
+    "src_slf_attn_bias": [(batch_size, ModelHyperParams.n_head, seq_len,
+                           seq_len), "float32"],
+    # The actual data shape of trg_word is:
+    # [batch_size * max_trg_len_in_batch, 1]
+    "trg_word": [(batch_size, seq_len, long_type(1)), "int64",
+                 2],  # lod_level is only used in fast decoder.
+    # The actual data shape of trg_pos is:
+    # [batch_size * max_trg_len_in_batch, 1]
+    "trg_pos": [(batch_size, seq_len, long_type(1)), "int64"],
+    # This input is used to remove attention weights on paddings and
+    # subsequent words in the decoder.
+    # The actual data shape of trg_slf_attn_bias is:
+    # [batch_size, n_head, max_trg_len_in_batch, max_trg_len_in_batch]
+    "trg_slf_attn_bias": [(batch_size, ModelHyperParams.n_head, seq_len,
+                           seq_len), "float32"],
+    # This input is used to remove attention weights on paddings of the source
+    # input in the encoder-decoder attention.
+    # The actual data shape of trg_src_attn_bias is:
+    # [batch_size, n_head, max_trg_len_in_batch, max_src_len_in_batch]
+    "trg_src_attn_bias": [(batch_size, ModelHyperParams.n_head, seq_len,
+                           seq_len), "float32"],
+    # This input is used in independent decoder program for inference.
+    # The actual data shape of enc_output is:
+    # [batch_size, max_src_len_in_batch, d_model]
+    "enc_output": [(batch_size, seq_len, ModelHyperParams.d_model), "float32"],
+    # The actual data shape of label_word is:
+    # [batch_size * max_trg_len_in_batch, 1]
+    "lbl_word": [(batch_size * seq_len, long_type(1)), "int64"],
+    # This input is used to mask out the loss of paddding tokens.
+    # The actual data shape of label_weight is:
+    # [batch_size * max_trg_len_in_batch, 1]
+    "lbl_weight": [(batch_size * seq_len, long_type(1)), "float32"],
+    # These inputs are used to change the shape tensor in beam-search decoder.
+    "trg_slf_attn_pre_softmax_shape_delta": [(long_type(2), ), "int32"],
+    "trg_slf_attn_post_softmax_shape_delta": [(long_type(4), ), "int32"],
+    "init_score": [(batch_size, long_type(1)), "float32"],
+}
+
+# Names of word embedding table which might be reused for weight sharing.
+word_emb_param_names = (
+    "src_word_emb_table",
+    "trg_word_emb_table", )
+# Names of position encoding table which will be initialized externally.
+pos_enc_param_names = (
+    "src_pos_enc_table",
+    "trg_pos_enc_table", )
+# separated inputs for different usages.
+encoder_data_input_fields = (
+    "src_word",
+    "src_pos",
+    "src_slf_attn_bias", )
+decoder_data_input_fields = (
+    "trg_word",
+    "trg_pos",
+    "trg_slf_attn_bias",
+    "trg_src_attn_bias",
+    "enc_output", )
+label_data_input_fields = (
+    "lbl_word",
+    "lbl_weight", )
+# In fast decoder, trg_pos (only containing the current time step) is generated
+# by ops and trg_slf_attn_bias is not needed.
+fast_decoder_data_input_fields = (
+    "trg_word",
+    "init_score",
+    "trg_src_attn_bias", )
+
+# fast_decoder_util_input_fields = (
+#     "trg_slf_attn_pre_softmax_shape_delta",
+#     "trg_slf_attn_post_softmax_shape_delta", )
+
+
+#from optim import LearningRateScheduler
+class LearningRateScheduler(object):
+    """
+    Wrapper for learning rate scheduling as described in the Transformer paper.
+    LearningRateScheduler adapts the learning rate externally and the adapted
+    learning rate will be feeded into the main_program as input data.
+    """
+
+    def __init__(self,
+                 d_model,
+                 warmup_steps,
+                 learning_rate=0.001,
+                 current_steps=0,
+                 name="learning_rate"):
+        self.current_steps = current_steps
+        self.warmup_steps = warmup_steps
+        self.d_model = d_model
+        self.static_lr = learning_rate
+        self.learning_rate = layers.create_global_var(
+            name=name,
+            shape=[1],
+            value=float(learning_rate),
+            dtype="float32",
+            persistable=True)
+
+    def update_learning_rate(self):
+        self.current_steps += 1
+        lr_value = np.power(self.d_model, -0.5) * np.min([
+            np.power(self.current_steps, -0.5),
+            np.power(self.warmup_steps, -1.5) * self.current_steps
+        ]) * self.static_lr
+        return np.array([lr_value], dtype="float32")
+
+
+#from transformer_train import train_loop
+def pad_batch_data(insts,
+                   pad_idx,
+                   n_head,
+                   is_target=False,
+                   is_label=False,
+                   return_attn_bias=True,
+                   return_max_len=True,
+                   return_num_token=False):
     """
     Pad the instances to the max sequence length in batch, and generate the
-    corresponding position data and attention bias. Then, convert the numpy
-    data to tensors and return a dict mapping names to tensors.
+    corresponding position data and attention bias.
     """
+    return_list = []
+    max_len = max(len(inst) for inst in insts)
+    num_token = reduce(lambda x, y: x + y,
+                       [len(inst) for inst in insts]) if return_num_token else 0
+    # Any token included in dict can be used to pad, since the paddings' loss
+    # will be masked out by weights and make no effect on parameter gradients.
+    inst_data = np.array(
+        [inst + [pad_idx] * (max_len - len(inst)) for inst in insts])
+    return_list += [inst_data.astype("int64").reshape([-1, 1])]
+    if is_label:  # label weight
+        inst_weight = np.array(
+            [[1.] * len(inst) + [0.] * (max_len - len(inst)) for inst in insts])
+        return_list += [inst_weight.astype("float32").reshape([-1, 1])]
+    else:  # position data
+        inst_pos = np.array([
+            range(1, len(inst) + 1) + [0] * (max_len - len(inst))
+            for inst in insts
+        ])
+        return_list += [inst_pos.astype("int64").reshape([-1, 1])]
+    if return_attn_bias:
+        if is_target:
+            # This is used to avoid attention on paddings and subsequent
+            # words.
+            slf_attn_bias_data = np.ones((inst_data.shape[0], max_len, max_len))
+            slf_attn_bias_data = np.triu(slf_attn_bias_data,
+                                         1).reshape([-1, 1, max_len, max_len])
+            slf_attn_bias_data = np.tile(slf_attn_bias_data,
+                                         [1, n_head, 1, 1]) * [-1e9]
+        else:
+            # This is used to avoid attention on paddings.
+            slf_attn_bias_data = np.array([[0] * len(inst) + [-1e9] *
+                                           (max_len - len(inst))
+                                           for inst in insts])
+            slf_attn_bias_data = np.tile(
+                slf_attn_bias_data.reshape([-1, 1, 1, max_len]),
+                [1, n_head, max_len, 1])
+        return_list += [slf_attn_bias_data.astype("float32")]
+    if return_max_len:
+        return_list += [max_len]
+    if return_num_token:
+        return_list += [num_token]
+    return return_list if len(return_list) > 1 else return_list[0]
+
+
+def prepare_batch_input(insts, data_input_names, src_pad_idx, trg_pad_idx,
+                        n_head, d_model):
+    """
+    Put all padded data needed by training into a dict.
+    """
+    src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data(
+        [inst[0] for inst in insts], src_pad_idx, n_head, is_target=False)
+    src_word = src_word.reshape(-1, src_max_len, 1)
+    src_pos = src_pos.reshape(-1, src_max_len, 1)
+    trg_word, trg_pos, trg_slf_attn_bias, trg_max_len = pad_batch_data(
+        [inst[1] for inst in insts], trg_pad_idx, n_head, is_target=True)
+    trg_word = trg_word.reshape(-1, trg_max_len, 1)
+    trg_pos = trg_pos.reshape(-1, trg_max_len, 1)
 
-    def __pad_batch_data(insts,
-                         pad_idx,
-                         is_target=False,
-                         return_pos=True,
-                         return_attn_bias=True,
-                         return_max_len=True):
-        """
-        Pad the instances to the max sequence length in batch, and generate the
-        corresponding position data and attention bias.
-        """
-        return_list = []
-        max_len = max(len(inst) for inst in insts)
-        inst_data = np.array(
-            [inst + [pad_idx] * (max_len - len(inst)) for inst in insts])
-        return_list += [inst_data.astype("int64").reshape([-1, 1])]
-        if return_pos:
-            inst_pos = np.array([[
-                pos_i + 1 if w_i != pad_idx else 0
-                for pos_i, w_i in enumerate(inst)
-            ] for inst in inst_data])
-
-            return_list += [inst_pos.astype("int64").reshape([-1, 1])]
-        if return_attn_bias:
-            if is_target:
-                # This is used to avoid attention on paddings and subsequent
-                # words.
-                slf_attn_bias_data = np.ones((inst_data.shape[0], max_len,
-                                              max_len))
-                slf_attn_bias_data = np.triu(slf_attn_bias_data, 1).reshape(
-                    [-1, 1, max_len, max_len])
-                slf_attn_bias_data = np.tile(slf_attn_bias_data,
-                                             [1, n_head, 1, 1]) * [-1e9]
-            else:
-                # This is used to avoid attention on paddings.
-                slf_attn_bias_data = np.array([[0] * len(inst) + [-1e9] *
-                                               (max_len - len(inst))
-                                               for inst in insts])
-                slf_attn_bias_data = np.tile(
-                    slf_attn_bias_data.reshape([-1, 1, 1, max_len]),
-                    [1, n_head, max_len, 1])
-            return_list += [slf_attn_bias_data.astype("float32")]
-        if return_max_len:
-            return_list += [max_len]
-        return return_list if len(return_list) > 1 else return_list[0]
-
-    src_word, src_pos, src_slf_attn_bias, src_max_len = __pad_batch_data(
-        [inst[0] for inst in insts], src_pad_idx, is_target=False)
-    trg_word, trg_pos, trg_slf_attn_bias, trg_max_len = __pad_batch_data(
-        [inst[1] for inst in insts], trg_pad_idx, is_target=True)
     trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :],
                                 [1, 1, trg_max_len, 1]).astype("float32")
-    lbl_word = __pad_batch_data([inst[2] for inst in insts], trg_pad_idx, False,
-                                False, False, False)
-    lbl_weight = (lbl_word != trg_pad_idx).astype("float32").reshape([-1, 1])
 
+    lbl_word, lbl_weight, num_token = pad_batch_data(
+        [inst[2] for inst in insts],
+        trg_pad_idx,
+        n_head,
+        is_target=False,
+        is_label=True,
+        return_attn_bias=False,
+        return_max_len=False,
+        return_num_token=True)
+
+    data_input_dict = dict(
+        zip(data_input_names, [
+            src_word, src_pos, src_slf_attn_bias, trg_word, trg_pos,
+            trg_slf_attn_bias, trg_src_attn_bias, lbl_word, lbl_weight
+        ]))
+    return data_input_dict, np.asarray([num_token], dtype="float32")
+
+
+def read_multiple(reader, count, clip_last=True):
+    """
+    Stack data from reader for multi-devices.
+    """
+
+    def __impl__():
+        res = []
+        for item in reader():
+            res.append(item)
+            if len(res) == count:
+                yield res
+                res = []
+        if len(res) == count:
+            yield res
+        elif not clip_last:
+            data = []
+            for item in res:
+                data += item
+            if len(data) > count:
+                inst_num_per_part = len(data) // count
+                yield [
+                    data[inst_num_per_part * i:inst_num_per_part * (i + 1)]
+                    for i in range(count)
+                ]
+
+    return __impl__
+
+
+def split_data(data, num_part):
+    """
+    Split data for each device.
+    """
+    if len(data) == num_part:
+        return data
+    data = data[0]
+    inst_num_per_part = len(data) // num_part
     return [
-        src_word, src_pos, trg_word, trg_pos, src_slf_attn_bias,
-        trg_slf_attn_bias, trg_src_attn_bias, lbl_word, lbl_weight
+        data[inst_num_per_part * i:inst_num_per_part * (i + 1)]
+        for i in range(num_part)
     ]
 
 
-def transformer(use_feed):
-    assert not use_feed, "transfomer doesn't support feed yet"
-    return transformer_model.transformer(
-        ModelHyperParams.src_vocab_size + 1,
-        ModelHyperParams.trg_vocab_size + 1, ModelHyperParams.max_length + 1,
-        ModelHyperParams.n_layer, ModelHyperParams.n_head,
-        ModelHyperParams.d_key, ModelHyperParams.d_value,
-        ModelHyperParams.d_model, ModelHyperParams.d_inner_hid,
-        ModelHyperParams.dropout, ModelHyperParams.src_pad_idx,
-        ModelHyperParams.trg_pad_idx, ModelHyperParams.pos_pad_idx)
+def test_context(train_progm, avg_cost, train_exe, dev_count, data_input_names,
+                 sum_cost, token_num):
+    # Context to do validation.
+    test_program = train_progm.clone()
+    with fluid.program_guard(test_program):
+        test_program = fluid.io.get_inference_program([avg_cost])
+
+    val_data = DataReader(
+        src_vocab_fpath=TrainTaskConfig.src_vocab_fpath,
+        trg_vocab_fpath=TrainTaskConfig.trg_vocab_fpath,
+        fpattern=TrainTaskConfig.val_file_pattern,
+        token_delimiter=TrainTaskConfig.token_delimiter,
+        use_token_batch=TrainTaskConfig.use_token_batch,
+        batch_size=TrainTaskConfig.batch_size *
+        (1 if TrainTaskConfig.use_token_batch else dev_count),
+        pool_size=TrainTaskConfig.pool_size,
+        sort_type=TrainTaskConfig.sort_type,
+        start_mark=TrainTaskConfig.special_token[0],
+        end_mark=TrainTaskConfig.special_token[1],
+        unk_mark=TrainTaskConfig.special_token[2],
+        # count start and end tokens out
+        max_length=ModelHyperParams.max_length - 2,
+        clip_last_batch=False,
+        shuffle=False,
+        shuffle_batch=False)
+
+    build_strategy = fluid.BuildStrategy()
+
+    strategy = fluid.ExecutionStrategy()
+    strategy.num_threads = 1
+
+    test_exe = fluid.ParallelExecutor(
+        use_cuda=TrainTaskConfig.use_gpu,
+        main_program=test_program,
+        share_vars_from=train_exe,
+        build_strategy=build_strategy,
+        exec_strategy=strategy)
+
+    def test(exe=test_exe):
+        test_total_cost = 0
+        test_total_token = 0
+        test_data = read_multiple(
+            reader=val_data.batch_generator,
+            count=dev_count if TrainTaskConfig.use_token_batch else 1)
+        for batch_id, data in enumerate(test_data()):
+            feed_list = []
+            for place_id, data_buffer in enumerate(
+                    split_data(
+                        data, num_part=dev_count)):
+                data_input_dict, _ = prepare_batch_input(
+                    data_buffer, data_input_names, ModelHyperParams.eos_idx,
+                    ModelHyperParams.eos_idx, ModelHyperParams.n_head,
+                    ModelHyperParams.d_model)
+                feed_list.append(data_input_dict)
+
+            outs = exe.run(feed=feed_list,
+                           fetch_list=[sum_cost.name, token_num.name])
+            sum_cost_val, token_num_val = np.array(outs[0]), np.array(outs[1])
+            test_total_cost += sum_cost_val.sum()
+            test_total_token += token_num_val.sum()
+        test_avg_cost = test_total_cost / test_total_token
+        test_ppl = np.exp([min(test_avg_cost, 100)])
+        return test_avg_cost, test_ppl
+
+    return test
+
+
+def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler,
+               token_num, predict):
+    # Initialize the parameters.
+    if TrainTaskConfig.ckpt_path:
+        lr_scheduler.current_steps = TrainTaskConfig.start_step
+    else:
+        exe.run(fluid.framework.default_startup_program())
+
+    train_data = DataReader(
+        src_vocab_fpath=TrainTaskConfig.src_vocab_fpath,
+        trg_vocab_fpath=TrainTaskConfig.trg_vocab_fpath,
+        fpattern=TrainTaskConfig.train_file_pattern,
+        token_delimiter=TrainTaskConfig.token_delimiter,
+        use_token_batch=TrainTaskConfig.use_token_batch,
+        batch_size=TrainTaskConfig.batch_size *
+        (1 if TrainTaskConfig.use_token_batch else dev_count),
+        pool_size=TrainTaskConfig.pool_size,
+        sort_type=TrainTaskConfig.sort_type,
+        shuffle=TrainTaskConfig.shuffle,
+        shuffle_batch=TrainTaskConfig.shuffle_batch,
+        start_mark=TrainTaskConfig.special_token[0],
+        end_mark=TrainTaskConfig.special_token[1],
+        unk_mark=TrainTaskConfig.special_token[2],
+        # count start and end tokens out
+        max_length=ModelHyperParams.max_length - 2,
+        clip_last_batch=False)
+    train_data = read_multiple(
+        reader=train_data.batch_generator,
+        count=dev_count if TrainTaskConfig.use_token_batch else 1)
+
+    build_strategy = fluid.BuildStrategy()
+    # Since the token number differs among devices, customize gradient scale to
+    # use token average cost among multi-devices. and the gradient scale is
+    # `1 / token_number` for average cost.
+    build_strategy.gradient_scale_strategy = fluid.BuildStrategy.GradientScaleStrategy.Customized
+
+    strategy = fluid.ExecutionStrategy()
+    strategy.num_threads = 1
+
+    train_exe = fluid.ParallelExecutor(
+        use_cuda=TrainTaskConfig.use_gpu,
+        loss_name=sum_cost.name,
+        main_program=train_progm,
+        build_strategy=build_strategy,
+        exec_strategy=strategy)
+
+    data_input_names = encoder_data_input_fields + decoder_data_input_fields[:
+                                                                             -1] + label_data_input_fields
+
+    if TrainTaskConfig.val_file_pattern is not None:
+        test = test_context(train_progm, avg_cost, train_exe, dev_count,
+                            data_input_names, sum_cost, token_num)
+
+    # the best cross-entropy value with label smoothing
+    loss_normalizer = -((1. - TrainTaskConfig.label_smooth_eps) * np.log(
+        (1. - TrainTaskConfig.label_smooth_eps
+         )) + TrainTaskConfig.label_smooth_eps *
+                        np.log(TrainTaskConfig.label_smooth_eps / (
+                            ModelHyperParams.trg_vocab_size - 1) + 1e-20))
+    init = False
+    for pass_id in xrange(TrainTaskConfig.pass_num):
+        pass_start_time = time.time()
+        for batch_id, data in enumerate(train_data()):
+            if batch_id >= 5:
+                break
+
+            feed_list = []
+            total_num_token = 0
+
+            #if TrainTaskConfig.local:
+            #    lr_rate = lr_scheduler.update_learning_rate()
+            #for place_id, data_buffer in enumerate(
+            #        split_data(
+            #            data, num_part=dev_count)):
+
+            if TrainTaskConfig.local:
+                lr_rate = lr_scheduler.update_learning_rate()
+
+            for place_id, data_buffer in enumerate(
+                    split_data(
+                        data, num_part=dev_count)):
+                data_input_dict, num_token = prepare_batch_input(
+                    data_buffer, data_input_names, ModelHyperParams.eos_idx,
+                    ModelHyperParams.eos_idx, ModelHyperParams.n_head,
+                    ModelHyperParams.d_model)
+                total_num_token += num_token
+                feed_kv_pairs = data_input_dict.items()
+                if TrainTaskConfig.local:
+                    feed_kv_pairs += {
+                        lr_scheduler.learning_rate.name: lr_rate
+                    }.items()
+                feed_list.append(dict(feed_kv_pairs))
+
+                if not init:
+                    for pos_enc_param_name in pos_enc_param_names:
+                        pos_enc = position_encoding_init(
+                            ModelHyperParams.max_length + 1,
+                            ModelHyperParams.d_model)
+                        feed_list[place_id][pos_enc_param_name] = pos_enc
+
+            if not TrainTaskConfig.check_acc:
+                for feed_dict in feed_list:
+                    feed_dict[sum_cost.name + "@GRAD"] = 1. / total_num_token
+            else:
+                b = 100 * TrainTaskConfig.batch_size
+                a = np.asarray([b], dtype="float32")
+                for feed_dict in feed_list:
+                    feed_dict[sum_cost.name + "@GRAD"] = 1. / a
+
+            outs = train_exe.run(fetch_list=[sum_cost.name, token_num.name],
+                                 feed=feed_list)
+
+            sum_cost_val, token_num_val = np.array(outs[0]), np.array(outs[1])
+            total_sum_cost = sum_cost_val.sum()
+            total_token_num = token_num_val.sum()
+            total_avg_cost = total_sum_cost / total_token_num
+
+            init = True
+
+            # Validate and save the model for inference.
+            if TrainTaskConfig.val_file_pattern is not None:
+                val_avg_cost, val_ppl = test()
+                print("[%f]" % val_avg_cost)
+            else:
+                assert (False)
+
+
+#import transformer_reader as reader
+class SortType(object):
+    GLOBAL = 'global'
+    POOL = 'pool'
+    NONE = "none"
+
+
+class Converter(object):
+    def __init__(self, vocab, beg, end, unk, delimiter):
+        self._vocab = vocab
+        self._beg = beg
+        self._end = end
+        self._unk = unk
+        self._delimiter = delimiter
+
+    def __call__(self, sentence):
+        return [self._beg] + [
+            self._vocab.get(w, self._unk)
+            for w in sentence.split(self._delimiter)
+        ] + [self._end]
+
+
+class ComposedConverter(object):
+    def __init__(self, converters):
+        self._converters = converters
+
+    def __call__(self, parallel_sentence):
+        return [
+            self._converters[i](parallel_sentence[i])
+            for i in range(len(self._converters))
+        ]
+
+
+class SentenceBatchCreator(object):
+    def __init__(self, batch_size):
+        self.batch = []
+        self._batch_size = batch_size
+
+    def append(self, info):
+        self.batch.append(info)
+        if len(self.batch) == self._batch_size:
+            tmp = self.batch
+            self.batch = []
+            return tmp
+
+
+class TokenBatchCreator(object):
+    def __init__(self, batch_size):
+        self.batch = []
+        self.max_len = -1
+        self._batch_size = batch_size
+
+    def append(self, info):
+        cur_len = info.max_len
+        max_len = max(self.max_len, cur_len)
+        if max_len * (len(self.batch) + 1) > self._batch_size:
+            result = self.batch
+            self.batch = [info]
+            self.max_len = cur_len
+            return result
+        else:
+            self.max_len = max_len
+            self.batch.append(info)
+
+
+class SampleInfo(object):
+    def __init__(self, i, max_len, min_len):
+        self.i = i
+        self.min_len = min_len
+        self.max_len = max_len
+
+
+class MinMaxFilter(object):
+    def __init__(self, max_len, min_len, underlying_creator):
+        self._min_len = min_len
+        self._max_len = max_len
+        self._creator = underlying_creator
+
+    def append(self, info):
+        if info.max_len > self._max_len or info.min_len < self._min_len:
+            return
+        else:
+            return self._creator.append(info)
+
+    @property
+    def batch(self):
+        return self._creator.batch
+
+
+class DataReader(object):
+    """
+    The data reader loads all data from files and produces batches of data
+    in the way corresponding to settings.
+
+    An example of returning a generator producing data batches whose data
+    is shuffled in each pass and sorted in each pool:
+
+    ```
+    train_data = DataReader(
+        src_vocab_fpath='data/src_vocab_file',
+        trg_vocab_fpath='data/trg_vocab_file',
+        fpattern='data/part-*',
+        use_token_batch=True,
+        batch_size=2000,
+        pool_size=10000,
+        sort_type=SortType.POOL,
+        shuffle=True,
+        shuffle_batch=True,
+        start_mark='<s>',
+        end_mark='<e>',
+        unk_mark='<unk>',
+        clip_last_batch=False).batch_generator
+    ```
+
+    :param src_vocab_fpath: The path of vocabulary file of source language.
+    :type src_vocab_fpath: basestring
+    :param trg_vocab_fpath: The path of vocabulary file of target language.
+    :type trg_vocab_fpath: basestring
+    :param fpattern: The pattern to match data files.
+    :type fpattern: basestring
+    :param batch_size: The number of sequences contained in a mini-batch.
+        or the maximum number of tokens (include paddings) contained in a
+        mini-batch.
+    :type batch_size: int
+    :param pool_size: The size of pool buffer.
+    :type pool_size: int
+    :param sort_type: The grain to sort by length: 'global' for all
+        instances; 'pool' for instances in pool; 'none' for no sort.
+    :type sort_type: basestring
+    :param clip_last_batch: Whether to clip the last uncompleted batch.
+    :type clip_last_batch: bool
+    :param tar_fname: The data file in tar if fpattern matches a tar file.
+    :type tar_fname: basestring
+    :param min_length: The minimum length used to filt sequences.
+    :type min_length: int
+    :param max_length: The maximum length used to filt sequences.
+    :type max_length: int
+    :param shuffle: Whether to shuffle all instances.
+    :type shuffle: bool
+    :param shuffle_batch: Whether to shuffle the generated batches.
+    :type shuffle_batch: bool
+    :param use_token_batch: Whether to produce batch data according to
+        token number.
+    :type use_token_batch: bool
+    :param field_delimiter: The delimiter used to split source and target in
+        each line of data file.
+    :type field_delimiter: basestring
+    :param token_delimiter: The delimiter used to split tokens in source or
+        target sentences.
+    :type token_delimiter: basestring
+    :param start_mark: The token representing for the beginning of
+        sentences in dictionary.
+    :type start_mark: basestring
+    :param end_mark: The token representing for the end of sentences
+        in dictionary.
+    :type end_mark: basestring
+    :param unk_mark: The token representing for unknown word in dictionary.
+    :type unk_mark: basestring
+    :param seed: The seed for random.
+    :type seed: int
+    """
+
+    def __init__(self,
+                 src_vocab_fpath,
+                 trg_vocab_fpath,
+                 fpattern,
+                 batch_size,
+                 pool_size,
+                 sort_type=SortType.GLOBAL,
+                 clip_last_batch=True,
+                 tar_fname=None,
+                 min_length=0,
+                 max_length=100,
+                 shuffle=True,
+                 shuffle_batch=False,
+                 use_token_batch=False,
+                 field_delimiter="\t",
+                 token_delimiter=" ",
+                 start_mark="<s>",
+                 end_mark="<e>",
+                 unk_mark="<unk>",
+                 seed=0):
+        self._src_vocab = self.load_dict(src_vocab_fpath)
+        self._only_src = True
+        if trg_vocab_fpath is not None:
+            self._trg_vocab = self.load_dict(trg_vocab_fpath)
+            self._only_src = False
+        self._pool_size = pool_size
+        self._batch_size = batch_size
+        self._use_token_batch = use_token_batch
+        self._sort_type = sort_type
+        self._clip_last_batch = clip_last_batch
+        self._shuffle = shuffle
+        self._shuffle_batch = shuffle_batch
+        self._min_length = min_length
+        self._max_length = max_length
+        self._field_delimiter = field_delimiter
+        self._token_delimiter = token_delimiter
+        self.load_src_trg_ids(end_mark, fpattern, start_mark, tar_fname,
+                              unk_mark)
+        self._random = random.Random(x=seed)
+
+    def load_src_trg_ids(self, end_mark, fpattern, start_mark, tar_fname,
+                         unk_mark):
+        converters = [
+            Converter(
+                vocab=self._src_vocab,
+                beg=self._src_vocab[start_mark],
+                end=self._src_vocab[end_mark],
+                unk=self._src_vocab[unk_mark],
+                delimiter=self._token_delimiter)
+        ]
+        if not self._only_src:
+            converters.append(
+                Converter(
+                    vocab=self._trg_vocab,
+                    beg=self._trg_vocab[start_mark],
+                    end=self._trg_vocab[end_mark],
+                    unk=self._trg_vocab[unk_mark],
+                    delimiter=self._token_delimiter))
+
+        converters = ComposedConverter(converters)
+
+        self._src_seq_ids = []
+        self._trg_seq_ids = None if self._only_src else []
+        self._sample_infos = []
+
+        for i, line in enumerate(self._load_lines(fpattern, tar_fname)):
+            src_trg_ids = converters(line)
+            self._src_seq_ids.append(src_trg_ids[0])
+            lens = [len(src_trg_ids[0])]
+            if not self._only_src:
+                self._trg_seq_ids.append(src_trg_ids[1])
+                lens.append(len(src_trg_ids[1]))
+            self._sample_infos.append(SampleInfo(i, max(lens), min(lens)))
+
+    def _load_lines(self, fpattern, tar_fname):
+        fpaths = glob.glob(fpattern)
+
+        if len(fpaths) == 1 and tarfile.is_tarfile(fpaths[0]):
+            if tar_fname is None:
+                raise Exception("If tar file provided, please set tar_fname.")
+
+            f = tarfile.open(fpaths[0], "r")
+            for line in f.extractfile(tar_fname):
+                fields = line.strip("\n").split(self._field_delimiter)
+                if (not self._only_src and len(fields) == 2) or (
+                        self._only_src and len(fields) == 1):
+                    yield fields
+        else:
+            for fpath in fpaths:
+                if not os.path.isfile(fpath):
+                    raise IOError("Invalid file: %s" % fpath)
+
+                with open(fpath, "r") as f:
+                    for line in f:
+                        fields = line.strip("\n").split(self._field_delimiter)
+                        if (not self._only_src and len(fields) == 2) or (
+                                self._only_src and len(fields) == 1):
+                            yield fields
+
+    @staticmethod
+    def load_dict(dict_path, reverse=False):
+        word_dict = {}
+        with open(dict_path, "r") as fdict:
+            for idx, line in enumerate(fdict):
+                if reverse:
+                    word_dict[idx] = line.strip("\n")
+                else:
+                    word_dict[line.strip("\n")] = idx
+        return word_dict
+
+    def batch_generator(self):
+        # global sort or global shuffle
+        if self._sort_type == SortType.GLOBAL:
+            infos = sorted(
+                self._sample_infos, key=lambda x: x.max_len, reverse=True)
+        else:
+            if self._shuffle:
+                infos = self._sample_infos
+                self._random.shuffle(infos)
+            else:
+                infos = self._sample_infos
+
+            if self._sort_type == SortType.POOL:
+                for i in range(0, len(infos), self._pool_size):
+                    infos[i:i + self._pool_size] = sorted(
+                        infos[i:i + self._pool_size], key=lambda x: x.max_len)
+
+        # concat batch
+        batches = []
+        batch_creator = TokenBatchCreator(
+            self._batch_size
+        ) if self._use_token_batch else SentenceBatchCreator(self._batch_size)
+        batch_creator = MinMaxFilter(self._max_length, self._min_length,
+                                     batch_creator)
+
+        for info in infos:
+            batch = batch_creator.append(info)
+            if batch is not None:
+                batches.append(batch)
+
+        if not self._clip_last_batch and len(batch_creator.batch) != 0:
+            batches.append(batch_creator.batch)
+
+        if self._shuffle_batch:
+            self._random.shuffle(batches)
+
+        for batch in batches:
+            batch_ids = [info.i for info in batch]
+
+            if self._only_src:
+                yield [[self._src_seq_ids[idx]] for idx in batch_ids]
+            else:
+                yield [(self._src_seq_ids[idx], self._trg_seq_ids[idx][:-1],
+                        self._trg_seq_ids[idx][1:]) for idx in batch_ids]
+
+
+#from transformer_model import transformer
+def position_encoding_init(n_position, d_pos_vec):
+    """
+    Generate the initial values for the sinusoid position encoding table.
+    """
+    position_enc = np.array([[
+        pos / np.power(10000, 2 * (j // 2) / d_pos_vec)
+        for j in range(d_pos_vec)
+    ] if pos != 0 else np.zeros(d_pos_vec) for pos in range(n_position)])
+    position_enc[1:, 0::2] = np.sin(position_enc[1:, 0::2])  # dim 2i
+    position_enc[1:, 1::2] = np.cos(position_enc[1:, 1::2])  # dim 2i+1
+    return position_enc.astype("float32")
+
+
+def multi_head_attention(queries,
+                         keys,
+                         values,
+                         attn_bias,
+                         d_key,
+                         d_value,
+                         d_model,
+                         n_head=1,
+                         dropout_rate=0.,
+                         cache=None):
+    """
+    Multi-Head Attention. Note that attn_bias is added to the logit before
+    computing softmax activiation to mask certain selected positions so that
+    they will not considered in attention weights.
+    """
+    if not (len(queries.shape) == len(keys.shape) == len(values.shape) == 3):
+        raise ValueError(
+            "Inputs: quries, keys and values should all be 3-D tensors.")
+
+    def __compute_qkv(queries, keys, values, n_head, d_key, d_value):
+        """
+        Add linear projection to queries, keys, and values.
+        """
+        q = layers.fc(input=queries,
+                      size=d_key * n_head,
+                      num_flatten_dims=2,
+                      param_attr=const_para_attr,
+                      bias_attr=const_bias_attr)
+        k = layers.fc(input=keys,
+                      size=d_key * n_head,
+                      num_flatten_dims=2,
+                      param_attr=const_para_attr,
+                      bias_attr=const_bias_attr)
+        v = layers.fc(input=values,
+                      size=d_value * n_head,
+                      num_flatten_dims=2,
+                      param_attr=const_para_attr,
+                      bias_attr=const_bias_attr)
+        return q, k, v
+
+    def __split_heads(x, n_head):
+        """
+        Reshape the last dimension of inpunt tensor x so that it becomes two
+        dimensions and then transpose. Specifically, input a tensor with shape
+        [bs, max_sequence_length, n_head * hidden_dim] then output a tensor
+        with shape [bs, n_head, max_sequence_length, hidden_dim].
+        """
+        if n_head == 1:
+            return x
+
+        hidden_size = x.shape[-1]
+        # The value 0 in shape attr means copying the corresponding dimension
+        # size of the input as the output dimension size.
+        reshaped = layers.reshape(
+            x=x, shape=[0, 0, n_head, hidden_size // n_head])
+
+        # permuate the dimensions into:
+        # [batch_size, n_head, max_sequence_len, hidden_size_per_head]
+        return layers.transpose(x=reshaped, perm=[0, 2, 1, 3])
+
+    def __combine_heads(x):
+        """
+        Transpose and then reshape the last two dimensions of inpunt tensor x
+        so that it becomes one dimension, which is reverse to __split_heads.
+        """
+        if len(x.shape) == 3: return x
+        if len(x.shape) != 4:
+            raise ValueError("Input(x) should be a 4-D Tensor.")
+
+        trans_x = layers.transpose(x, perm=[0, 2, 1, 3])
+        # The value 0 in shape attr means copying the corresponding dimension
+        # size of the input as the output dimension size.
+        return layers.reshape(
+            x=trans_x,
+            shape=map(int, [0, 0, trans_x.shape[2] * trans_x.shape[3]]))
+
+    def scaled_dot_product_attention(q, k, v, attn_bias, d_model, dropout_rate):
+        """
+        Scaled Dot-Product Attention
+        """
+        scaled_q = layers.scale(x=q, scale=d_model**-0.5)
+        product = layers.matmul(x=scaled_q, y=k, transpose_y=True)
+        if attn_bias:
+            product += attn_bias
+        weights = layers.softmax(product)
+        if dropout_rate:
+            weights = layers.dropout(
+                weights,
+                dropout_prob=dropout_rate,
+                seed=ModelHyperParams.dropout_seed,
+                is_test=False)
+        out = layers.matmul(weights, v)
+        return out
+
+    q, k, v = __compute_qkv(queries, keys, values, n_head, d_key, d_value)
+
+    if cache is not None:  # use cache and concat time steps
+        k = cache["k"] = layers.concat([cache["k"], k], axis=1)
+        v = cache["v"] = layers.concat([cache["v"], v], axis=1)
+
+    q = __split_heads(q, n_head)
+    k = __split_heads(k, n_head)
+    v = __split_heads(v, n_head)
+
+    ctx_multiheads = scaled_dot_product_attention(q, k, v, attn_bias, d_model,
+                                                  dropout_rate)
+
+    out = __combine_heads(ctx_multiheads)
+
+    # Project back to the model size.
+    proj_out = layers.fc(input=out,
+                         size=d_model,
+                         num_flatten_dims=2,
+                         param_attr=const_para_attr,
+                         bias_attr=const_bias_attr)
+    return proj_out
+
+
+def positionwise_feed_forward(x, d_inner_hid, d_hid):
+    """
+    Position-wise Feed-Forward Networks.
+    This module consists of two linear transformations with a ReLU activation
+    in between, which is applied to each position separately and identically.
+    """
+    hidden = layers.fc(input=x,
+                       size=d_inner_hid,
+                       num_flatten_dims=2,
+                       act="relu",
+                       param_attr=const_para_attr,
+                       bias_attr=const_bias_attr)
+    out = layers.fc(input=hidden,
+                    size=d_hid,
+                    num_flatten_dims=2,
+                    param_attr=const_para_attr,
+                    bias_attr=const_bias_attr)
+    return out
+
+
+def pre_post_process_layer(prev_out, out, process_cmd, dropout_rate=0.):
+    """
+    Add residual connection, layer normalization and droput to the out tensor
+    optionally according to the value of process_cmd.
+    This will be used before or after multi-head attention and position-wise
+    feed-forward networks.
+    """
+    for cmd in process_cmd:
+        if cmd == "a":  # add residual connection
+            out = out + prev_out if prev_out else out
+        elif cmd == "n":  # add layer normalization
+            out = layers.layer_norm(
+                out,
+                begin_norm_axis=len(out.shape) - 1,
+                param_attr=fluid.initializer.Constant(1.),
+                bias_attr=fluid.initializer.Constant(0.))
+        elif cmd == "d":  # add dropout
+            if dropout_rate:
+                out = layers.dropout(
+                    out,
+                    dropout_prob=dropout_rate,
+                    seed=ModelHyperParams.dropout_seed,
+                    is_test=False)
+    return out
+
+
+pre_process_layer = partial(pre_post_process_layer, None)
+post_process_layer = pre_post_process_layer
+
+
+def prepare_encoder(src_word,
+                    src_pos,
+                    src_vocab_size,
+                    src_emb_dim,
+                    src_max_len,
+                    dropout_rate=0.,
+                    word_emb_param_name=None,
+                    pos_enc_param_name=None):
+    """Add word embeddings and position encodings.
+    The output tensor has a shape of:
+    [batch_size, max_src_length_in_batch, d_model].
+    This module is used at the bottom of the encoder stacks.
+    """
+    if TrainTaskConfig.check_acc:
+        src_word_emb = layers.embedding(
+            src_word,
+            size=[src_vocab_size, src_emb_dim],
+            param_attr=fluid.ParamAttr(
+                name=word_emb_param_name,
+                initializer=fluid.initializer.ConstantInitializer(0.001)))
+    else:
+        src_word_emb = layers.embedding(
+            src_word,
+            size=[src_vocab_size, src_emb_dim],
+            param_attr=fluid.ParamAttr(
+                name=word_emb_param_name,
+                initializer=fluid.initializer.Normal(0., src_emb_dim**-0.5)))
+
+    src_word_emb = layers.scale(x=src_word_emb, scale=src_emb_dim**0.5)
+    src_pos_enc = layers.embedding(
+        src_pos,
+        size=[src_max_len, src_emb_dim],
+        param_attr=fluid.ParamAttr(
+            name=pos_enc_param_name,
+            trainable=False,
+            initializer=fluid.initializer.ConstantInitializer(0.001)))
+    enc_input = src_word_emb + src_pos_enc
+    return layers.dropout(
+        enc_input,
+        dropout_prob=dropout_rate,
+        seed=ModelHyperParams.dropout_seed,
+        is_test=False) if dropout_rate else enc_input
+
+
+prepare_encoder = partial(
+    prepare_encoder, pos_enc_param_name=pos_enc_param_names[0])
+prepare_decoder = partial(
+    prepare_encoder, pos_enc_param_name=pos_enc_param_names[1])
+
+
+def encoder_layer(enc_input,
+                  attn_bias,
+                  n_head,
+                  d_key,
+                  d_value,
+                  d_model,
+                  d_inner_hid,
+                  dropout_rate=0.):
+    """The encoder layers that can be stacked to form a deep encoder.
+    This module consits of a multi-head (self) attention followed by
+    position-wise feed-forward networks and both the two components companied
+    with the post_process_layer to add residual connection, layer normalization
+    and droput.
+    """
+    attn_output = multi_head_attention(enc_input, enc_input, enc_input,
+                                       attn_bias, d_key, d_value, d_model,
+                                       n_head, dropout_rate)
+    attn_output = post_process_layer(enc_input, attn_output, "dan",
+                                     dropout_rate)
+    ffd_output = positionwise_feed_forward(attn_output, d_inner_hid, d_model)
+    return post_process_layer(attn_output, ffd_output, "dan", dropout_rate)
+
+
+def encoder(enc_input,
+            attn_bias,
+            n_layer,
+            n_head,
+            d_key,
+            d_value,
+            d_model,
+            d_inner_hid,
+            dropout_rate=0.):
+    """
+    The encoder is composed of a stack of identical layers returned by calling
+    encoder_layer.
+    """
+    for i in range(n_layer):
+        enc_output = encoder_layer(enc_input, attn_bias, n_head, d_key, d_value,
+                                   d_model, d_inner_hid, dropout_rate)
+        enc_input = enc_output
+    return enc_output
+
+
+def decoder_layer(dec_input,
+                  enc_output,
+                  slf_attn_bias,
+                  dec_enc_attn_bias,
+                  n_head,
+                  d_key,
+                  d_value,
+                  d_model,
+                  d_inner_hid,
+                  dropout_rate=0.,
+                  cache=None):
+    """ The layer to be stacked in decoder part.
+    The structure of this module is similar to that in the encoder part except
+    a multi-head attention is added to implement encoder-decoder attention.
+    """
+    slf_attn_output = multi_head_attention(
+        dec_input,
+        dec_input,
+        dec_input,
+        slf_attn_bias,
+        d_key,
+        d_value,
+        d_model,
+        n_head,
+        dropout_rate,
+        cache, )
+    slf_attn_output = post_process_layer(
+        dec_input,
+        slf_attn_output,
+        "dan",  # residual connection + dropout + layer normalization
+        dropout_rate, )
+    enc_attn_output = multi_head_attention(
+        slf_attn_output,
+        enc_output,
+        enc_output,
+        dec_enc_attn_bias,
+        d_key,
+        d_value,
+        d_model,
+        n_head,
+        dropout_rate, )
+    enc_attn_output = post_process_layer(
+        slf_attn_output,
+        enc_attn_output,
+        "dan",  # residual connection + dropout + layer normalization
+        dropout_rate, )
+    ffd_output = positionwise_feed_forward(
+        enc_attn_output,
+        d_inner_hid,
+        d_model, )
+    dec_output = post_process_layer(
+        enc_attn_output,
+        ffd_output,
+        "dan",  # residual connection + dropout + layer normalization
+        dropout_rate, )
+    return dec_output
 
 
-def get_model():
-    avg_cost = transformer(use_feed=False)
-    optimizer = fluid.optimizer.Adam()
-    optimizer.minimize(avg_cost)
-    fluid.memory_optimize(fluid.default_main_program())
-    return avg_cost
+def decoder(dec_input,
+            enc_output,
+            dec_slf_attn_bias,
+            dec_enc_attn_bias,
+            n_layer,
+            n_head,
+            d_key,
+            d_value,
+            d_model,
+            d_inner_hid,
+            dropout_rate=0.,
+            caches=None):
+    """
+    The decoder is composed of a stack of identical decoder_layer layers.
+    """
+    for i in range(n_layer):
+        cache = None
+        if caches is not None:
+            cache = caches[i]
+
+        dec_output = decoder_layer(
+            dec_input,
+            enc_output,
+            dec_slf_attn_bias,
+            dec_enc_attn_bias,
+            n_head,
+            d_key,
+            d_value,
+            d_model,
+            d_inner_hid,
+            dropout_rate,
+            cache=cache)
+        dec_input = dec_output
+    return dec_output
+
+
+def make_all_inputs(input_fields):
+    """
+    Define the input data layers for the transformer model.
+    """
+    inputs = []
+    for input_field in input_fields:
+        input_var = layers.data(
+            name=input_field,
+            shape=input_descs[input_field][0],
+            dtype=input_descs[input_field][1],
+            lod_level=input_descs[input_field][2]
+            if len(input_descs[input_field]) == 3 else 0,
+            append_batch_size=False)
+        inputs.append(input_var)
+    return inputs
+
+
+def transformer(
+        src_vocab_size,
+        trg_vocab_size,
+        max_length,
+        n_layer,
+        n_head,
+        d_key,
+        d_value,
+        d_model,
+        d_inner_hid,
+        dropout_rate,
+        weight_sharing,
+        label_smooth_eps, ):
+    if weight_sharing:
+        assert src_vocab_size == src_vocab_size, (
+            "Vocabularies in source and target should be same for weight sharing."
+        )
+    enc_inputs = make_all_inputs(encoder_data_input_fields)
+
+    enc_output = wrap_encoder(
+        src_vocab_size,
+        max_length,
+        n_layer,
+        n_head,
+        d_key,
+        d_value,
+        d_model,
+        d_inner_hid,
+        dropout_rate,
+        weight_sharing,
+        enc_inputs, )
+
+    dec_inputs = make_all_inputs(decoder_data_input_fields[:-1])
+
+    predict = wrap_decoder(
+        trg_vocab_size,
+        max_length,
+        n_layer,
+        n_head,
+        d_key,
+        d_value,
+        d_model,
+        d_inner_hid,
+        dropout_rate,
+        weight_sharing,
+        dec_inputs,
+        enc_output, )
+
+    # Padding index do not contribute to the total loss. The weights is used to
+    # cancel padding index in calculating the loss.
+    label, weights = make_all_inputs(label_data_input_fields)
+    if label_smooth_eps:
+        label = layers.label_smooth(
+            label=layers.one_hot(
+                input=label, depth=trg_vocab_size),
+            epsilon=label_smooth_eps)
+
+    cost = layers.softmax_with_cross_entropy(
+        logits=layers.reshape(
+            predict, shape=[-1, trg_vocab_size]),
+        label=label,
+        soft_label=True if label_smooth_eps else False)
+    weighted_cost = cost * weights
+    sum_cost = layers.reduce_sum(weighted_cost)
+    token_num = layers.reduce_sum(weights)
+    avg_cost = sum_cost / token_num
+    avg_cost.stop_gradient = True
+    return sum_cost, avg_cost, predict, token_num
+
+
+def wrap_encoder(src_vocab_size,
+                 max_length,
+                 n_layer,
+                 n_head,
+                 d_key,
+                 d_value,
+                 d_model,
+                 d_inner_hid,
+                 dropout_rate,
+                 weight_sharing,
+                 enc_inputs=None):
+    """
+    The wrapper assembles together all needed layers for the encoder.
+    """
+    if enc_inputs is None:
+        # This is used to implement independent encoder program in inference.
+        src_word, src_pos, src_slf_attn_bias = \
+            make_all_inputs(encoder_data_input_fields)
+    else:
+        src_word, src_pos, src_slf_attn_bias = \
+            enc_inputs
+    enc_input = prepare_encoder(
+        src_word,
+        src_pos,
+        src_vocab_size,
+        d_model,
+        max_length,
+        dropout_rate,
+        word_emb_param_name=word_emb_param_names[0])
+    enc_output = encoder(enc_input, src_slf_attn_bias, n_layer, n_head, d_key,
+                         d_value, d_model, d_inner_hid, dropout_rate)
+    return enc_output
+
+
+def wrap_decoder(trg_vocab_size,
+                 max_length,
+                 n_layer,
+                 n_head,
+                 d_key,
+                 d_value,
+                 d_model,
+                 d_inner_hid,
+                 dropout_rate,
+                 weight_sharing,
+                 dec_inputs=None,
+                 enc_output=None,
+                 caches=None):
+    """
+    The wrapper assembles together all needed layers for the decoder.
+    """
+    if dec_inputs is None:
+        # This is used to implement independent decoder program in inference.
+        trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, \
+        enc_output = make_all_inputs(
+            decoder_data_input_fields + decoder_util_input_fields)
+    else:
+        trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias = dec_inputs
+
+    dec_input = prepare_decoder(
+        trg_word,
+        trg_pos,
+        trg_vocab_size,
+        d_model,
+        max_length,
+        dropout_rate,
+        word_emb_param_name=word_emb_param_names[0]
+        if weight_sharing else word_emb_param_names[1])
+    dec_output = decoder(
+        dec_input,
+        enc_output,
+        trg_slf_attn_bias,
+        trg_src_attn_bias,
+        n_layer,
+        n_head,
+        d_key,
+        d_value,
+        d_model,
+        d_inner_hid,
+        dropout_rate,
+        caches=caches)
+    # Return logits for training and probs for inference.
+    if weight_sharing:
+        predict = layers.matmul(
+            x=dec_output,
+            y=fluid.get_var(word_emb_param_names[0]),
+            transpose_y=True)
+    else:
+        predict = layers.fc(input=dec_output,
+                            size=trg_vocab_size,
+                            num_flatten_dims=2,
+                            param_attr=const_para_attr,
+                            bias_attr=const_bias_attr)
+    if dec_inputs is None:
+        predict = layers.softmax(predict)
+    return predict
+
+
+def fast_decode(
+        src_vocab_size,
+        trg_vocab_size,
+        max_in_len,
+        n_layer,
+        n_head,
+        d_key,
+        d_value,
+        d_model,
+        d_inner_hid,
+        dropout_rate,
+        weight_sharing,
+        beam_size,
+        max_out_len,
+        eos_idx, ):
+    """
+    Use beam search to decode. Caches will be used to store states of history
+    steps which can make the decoding faster.
+    """
+    enc_output = wrap_encoder(src_vocab_size, max_in_len, n_layer, n_head,
+                              d_key, d_value, d_model, d_inner_hid,
+                              dropout_rate, weight_sharing)
+    start_tokens, init_scores, trg_src_attn_bias = \
+        make_all_inputs(fast_decoder_data_input_fields )
+
+    def beam_search():
+        max_len = layers.fill_constant(
+            shape=[1], dtype=start_tokens.dtype, value=max_out_len)
+        step_idx = layers.fill_constant(
+            shape=[1], dtype=start_tokens.dtype, value=0)
+        cond = layers.less_than(x=step_idx, y=max_len)
+        while_op = layers.While(cond)
+        # array states will be stored for each step.
+        ids = layers.array_write(
+            layers.reshape(start_tokens, (-1, 1)), step_idx)
+        scores = layers.array_write(init_scores, step_idx)
+        # cell states will be overwrited at each step.
+        # caches contains states of history steps to reduce redundant
+        # computation in decoder.
+        caches = [{
+            "k": layers.fill_constant_batch_size_like(
+                input=start_tokens,
+                shape=[-1, 0, d_model],
+                dtype=enc_output.dtype,
+                value=0),
+            "v": layers.fill_constant_batch_size_like(
+                input=start_tokens,
+                shape=[-1, 0, d_model],
+                dtype=enc_output.dtype,
+                value=0)
+        } for i in range(n_layer)]
+        with while_op.block():
+            pre_ids = layers.array_read(array=ids, i=step_idx)
+            pre_ids = layers.reshape(pre_ids, (-1, 1, 1))
+            pre_scores = layers.array_read(array=scores, i=step_idx)
+            # sequence_expand can gather sequences according to lod thus can be
+            # used in beam search to sift states corresponding to selected ids.
+            pre_src_attn_bias = layers.sequence_expand(
+                x=trg_src_attn_bias, y=pre_scores)
+            pre_enc_output = layers.sequence_expand(x=enc_output, y=pre_scores)
+            pre_caches = [{
+                "k": layers.sequence_expand(
+                    x=cache["k"], y=pre_scores),
+                "v": layers.sequence_expand(
+                    x=cache["v"], y=pre_scores),
+            } for cache in caches]
+            pre_pos = layers.elementwise_mul(
+                x=layers.fill_constant_batch_size_like(
+                    input=pre_enc_output,  # cann't use pre_ids here since it has lod
+                    value=1,
+                    shape=[-1, 1, 1],
+                    dtype=pre_ids.dtype),
+                y=layers.increment(
+                    x=step_idx, value=1.0, in_place=False),
+                axis=0)
+            logits = wrap_decoder(
+                trg_vocab_size,
+                max_in_len,
+                n_layer,
+                n_head,
+                d_key,
+                d_value,
+                d_model,
+                d_inner_hid,
+                dropout_rate,
+                weight_sharing,
+                dec_inputs=(pre_ids, pre_pos, None, pre_src_attn_bias),
+                enc_output=pre_enc_output,
+                caches=pre_caches)
+            logits = layers.reshape(logits, (-1, trg_vocab_size))
+
+            topk_scores, topk_indices = layers.topk(
+                input=layers.softmax(logits), k=beam_size)
+            accu_scores = layers.elementwise_add(
+                x=layers.log(topk_scores),
+                y=layers.reshape(
+                    pre_scores, shape=[-1]),
+                axis=0)
+            # beam_search op uses lod to distinguish branches.
+            topk_indices = layers.lod_reset(topk_indices, pre_ids)
+            selected_ids, selected_scores = layers.beam_search(
+                pre_ids=pre_ids,
+                pre_scores=pre_scores,
+                ids=topk_indices,
+                scores=accu_scores,
+                beam_size=beam_size,
+                end_id=eos_idx)
+
+            layers.increment(x=step_idx, value=1.0, in_place=True)
+            # update states
+            layers.array_write(selected_ids, i=step_idx, array=ids)
+            layers.array_write(selected_scores, i=step_idx, array=scores)
+            layers.assign(pre_src_attn_bias, trg_src_attn_bias)
+            layers.assign(pre_enc_output, enc_output)
+            for i in range(n_layer):
+                layers.assign(pre_caches[i]["k"], caches[i]["k"])
+                layers.assign(pre_caches[i]["v"], caches[i]["v"])
+            length_cond = layers.less_than(x=step_idx, y=max_len)
+            finish_cond = layers.logical_not(layers.is_empty(x=selected_ids))
+            layers.logical_and(x=length_cond, y=finish_cond, out=cond)
+
+        finished_ids, finished_scores = layers.beam_search_decode(
+            ids, scores, beam_size=beam_size, end_id=eos_idx)
+        return finished_ids, finished_scores
+
+    finished_ids, finished_scores = beam_search()
+    return finished_ids, finished_scores
+
+
+def get_model(is_dist, is_async):
+    sum_cost, avg_cost, predict, token_num = transformer(
+        ModelHyperParams.src_vocab_size, ModelHyperParams.trg_vocab_size,
+        ModelHyperParams.max_length + 1, ModelHyperParams.n_layer,
+        ModelHyperParams.n_head, ModelHyperParams.d_key,
+        ModelHyperParams.d_value, ModelHyperParams.d_model,
+        ModelHyperParams.d_inner_hid, ModelHyperParams.dropout,
+        ModelHyperParams.weight_sharing, TrainTaskConfig.label_smooth_eps)
+
+    local_lr_scheduler = LearningRateScheduler(ModelHyperParams.d_model,
+                                               TrainTaskConfig.warmup_steps,
+                                               TrainTaskConfig.learning_rate)
+
+    if not is_dist:
+        optimizer = fluid.optimizer.Adam(
+            learning_rate=local_lr_scheduler.learning_rate,
+            beta1=TrainTaskConfig.beta1,
+            beta2=TrainTaskConfig.beta2,
+            epsilon=TrainTaskConfig.eps)
+        optimizer.minimize(sum_cost)
+    elif is_async:
+        optimizer = fluid.optimizer.SGD(0.003)
+        optimizer.minimize(sum_cost)
+    else:
+        lr_decay = fluid.layers\
+         .learning_rate_scheduler\
+         .noam_decay(ModelHyperParams.d_model,
+            TrainTaskConfig.warmup_steps)
+
+        optimizer = fluid.optimizer.Adam(
+            learning_rate=lr_decay,
+            beta1=TrainTaskConfig.beta1,
+            beta2=TrainTaskConfig.beta2,
+            epsilon=TrainTaskConfig.eps)
+        optimizer.minimize(sum_cost)
+
+    return sum_cost, avg_cost, predict, token_num, local_lr_scheduler
 
 
 def get_transpiler(trainer_id, main_program, pserver_endpoints, trainers):
@@ -176,10 +1677,23 @@ def get_transpiler(trainer_id, main_program, pserver_endpoints, trainers):
     return t
 
 
-class DistTransformer2x2(object):
+def update_args():
+    src_dict = DataReader.load_dict(TrainTaskConfig.src_vocab_fpath)
+    trg_dict = DataReader.load_dict(TrainTaskConfig.trg_vocab_fpath)
+    dict_args = [
+        "src_vocab_size", str(len(src_dict)), "trg_vocab_size",
+        str(len(trg_dict)), "bos_idx",
+        str(src_dict[TrainTaskConfig.special_token[0]]), "eos_idx",
+        str(src_dict[TrainTaskConfig.special_token[1]]), "unk_idx",
+        str(src_dict[TrainTaskConfig.special_token[2]])
+    ]
+    merge_cfg_from_list(dict_args, [TrainTaskConfig, ModelHyperParams])
+
+
+class DistTransformer2x2(TestDistRunnerBase):
     def run_pserver(self, pserver_endpoints, trainers, current_endpoint,
-                    trainer_id):
-        get_model()
+                    trainer_id, sync_mode):
+        get_model(True, not sync_mode)
         t = get_transpiler(trainer_id,
                            fluid.default_main_program(), pserver_endpoints,
                            trainers)
@@ -196,7 +1710,6 @@ class DistTransformer2x2(object):
         while True:
             assert retry_times >= 0, "wait ps ready failed"
             time.sleep(3)
-            print("waiting ps ready: ", pid)
             try:
                 # the listen_and_serv_op would touch a file which contains the listen port
                 # on the /tmp directory until it was ready to process all the RPC call.
@@ -205,63 +1718,35 @@ class DistTransformer2x2(object):
             except os.error:
                 retry_times -= 1
 
-    def run_trainer(self, place, endpoints, trainer_id, trainers, is_dist=True):
-        avg_cost = get_model()
+    def run_trainer(self,
+                    place,
+                    endpoints,
+                    trainer_id,
+                    trainers,
+                    is_dist=True,
+                    sync_mode=True):
+
+        sum_cost, avg_cost, predict, token_num, local_lr_scheduler = get_model(
+            is_dist, not sync_mode)
+
         if is_dist:
             t = get_transpiler(trainer_id,
                                fluid.default_main_program(), endpoints,
                                trainers)
             trainer_prog = t.get_trainer_program()
+            TrainTaskConfig.batch_size = 10
+            TrainTaskConfig.train_file_pattern = TrainTaskConfig.data_path + "train.tok.clean.bpe.32000.en-de.train_{}".format(
+                trainer_id)
         else:
+            TrainTaskConfig.batch_size = 20
             trainer_prog = fluid.default_main_program()
 
         startup_exe = fluid.Executor(place)
-        startup_exe.run(fluid.default_startup_program())
-
-        strategy = fluid.ExecutionStrategy()
-        strategy.num_threads = 1
-        strategy.allow_op_delay = False
-        exe = fluid.ParallelExecutor(
-            True, loss_name=avg_cost.name, exec_strategy=strategy)
-
-        first_loss, = exe.run(fetch_list=[avg_cost.name])
-        print(first_loss)
-        for i in six.moves.xrange(5):
-            _ = exe.run(fetch_list=[avg_cost.name])
-        last_loss, = exe.run(fetch_list=[avg_cost.name])
-        print(last_loss)
-
-
-def main(role="pserver",
-         endpoints="127.0.0.1:9123",
-         trainer_id=0,
-         current_endpoint="127.0.0.1:9123",
-         trainers=1,
-         is_dist=True):
-
-    reader = paddle.batch(
-        wmt16.train(ModelHyperParams.src_vocab_size,
-                    ModelHyperParams.trg_vocab_size),
-        batch_size=transformer_model.batch_size)
-
-    with fluid.recordio_writer.create_recordio_writer(
-            WMT16_RECORDIO_FILE) as writer:
-        for batch in reader():
-            for tensor in prepare_batch_input(
-                    batch, ModelHyperParams.src_pad_idx,
-                    ModelHyperParams.trg_pad_idx, ModelHyperParams.n_head):
-                t = fluid.LoDTensor()
-                t.set(tensor, fluid.CPUPlace())
-                writer.append_tensor(t)
-            writer.complete_append_tensor()
-
-    model = DistTransformer2x2()
-    if role == "pserver":
-        model.run_pserver(endpoints, trainers, current_endpoint, trainer_id)
-    else:
-        p = fluid.CUDAPlace(0) if core.is_compiled_with_cuda(
-        ) else fluid.CPUPlace()
-        model.run_trainer(p, endpoints, trainer_id, trainers, is_dist)
+
+        TrainTaskConfig.local = not is_dist
+
+        train_loop(startup_exe, trainer_prog, 1, sum_cost, avg_cost,
+                   local_lr_scheduler, token_num, predict)
 
 
 if __name__ == "__main__":
@@ -269,18 +1754,6 @@ if __name__ == "__main__":
         print(
             "Usage: python dist_transformer.py [pserver/trainer] [endpoints] [trainer_id] [current_endpoint] [trainers] [is_dist] [sync_mode]"
         )
-    role = sys.argv[1]
-    endpoints = sys.argv[2]
-    trainer_id = int(sys.argv[3])
-    current_endpoint = sys.argv[4]
-    trainers = int(sys.argv[5])
-    is_dist = True if sys.argv[6] == "TRUE" else False
-    # FIXME(typhoonzero): refine this test.
-    is_async = True if sys.argv[7] == "TRUE" else False
-    main(
-        role=role,
-        endpoints=endpoints,
-        trainer_id=trainer_id,
-        current_endpoint=current_endpoint,
-        trainers=trainers,
-        is_dist=is_dist)
+
+    update_args()
+    runtime_main(DistTransformer2x2)
diff --git a/python/paddle/fluid/tests/unittests/test_dist_transformer.py b/python/paddle/fluid/tests/unittests/test_dist_transformer.py
index 62fcf5953f..a8e6ce4cfe 100644
--- a/python/paddle/fluid/tests/unittests/test_dist_transformer.py
+++ b/python/paddle/fluid/tests/unittests/test_dist_transformer.py
@@ -15,17 +15,55 @@
 from __future__ import print_function
 
 import unittest
+import paddle
 from test_dist_base import TestDistBase
 
 
-class TestDistTransformer2x2(TestDistBase):
+def download_files():
+    url_prefix = 'http://paddle-unittest-data.cdn.bcebos.com/dist_transformer/'
+    vocab_url = url_prefix + 'vocab.bpe.32000'
+    vocab_md5 = 'a86d345ca6e27f6591d0dccb1b9be853'
+    paddle.dataset.common.download(vocab_url, 'test_dist_transformer',
+                                   vocab_md5)
+
+    local_train_url = url_prefix + 'train.tok.clean.bpe.32000.en-de'
+    local_train_md5 = '033eb02b9449e6dd823f050782ac8914'
+    paddle.dataset.common.download(local_train_url, 'test_dist_transformer',
+                                   local_train_md5)
+
+    train0_url = url_prefix + 'train.tok.clean.bpe.32000.en-de.train_0'
+    train0_md5 = 'ddce7f602f352a0405267285379a38b1'
+    paddle.dataset.common.download(train0_url, 'test_dist_transformer',
+                                   train0_md5)
+
+    train1_url = url_prefix + 'train.tok.clean.bpe.32000.en-de.train_1'
+    train1_md5 = '8757798200180285b1a619cd7f408747'
+    paddle.dataset.common.download(train1_url, 'test_dist_transformer',
+                                   train1_md5)
+
+    test_url = url_prefix + 'newstest2013.tok.bpe.32000.en-de'
+    test_md5 = '9dd74a266dbdb25314183899f269b4a2'
+    paddle.dataset.common.download(test_url, 'test_dist_transformer', test_md5)
+
+
+class TestDistTransformer2x2Sync(TestDistBase):
     def _setup_config(self):
         self._sync_mode = True
 
     def test_transformer(self):
-        # TODO(paddle-dev): check if the delta is OK.
-        # Usually start around ~8000 and converge to ~5000
-        self.check_with_place("dist_transformer.py", delta=400)
+        download_files()
+        #Note: loss on test dataset of the first 5 batch are:
+        # 10.518872, 10.518871, 10.518868, 10.518862, 10.518855
+        self.check_with_place("dist_transformer.py", delta=1e-7)
+
+
+class TestDistTransformer2x2Async(TestDistBase):
+    def _setup_config(self):
+        self._sync_mode = False
+
+    def test_transformer(self):
+        download_files()
+        self.check_with_place("dist_transformer.py", delta=1.0)
 
 
 if __name__ == "__main__":
diff --git a/python/paddle/fluid/tests/unittests/test_fusion_gru_op.py b/python/paddle/fluid/tests/unittests/test_fusion_gru_op.py
new file mode 100644
index 0000000000..764f83b534
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_fusion_gru_op.py
@@ -0,0 +1,133 @@
+#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+import math
+from op_test import OpTest
+from test_gru_op import gru
+from test_fusion_lstm_op import fc, ACTIVATION
+
+
+def fusion_gru(
+        x,  # T x M
+        lod,  # 1 x N
+        h0,  # N x D
+        wx,  # M x 3D
+        wh,  # D x 3D
+        bias,  # 1 x 3D
+        is_reverse,
+        act_state,
+        act_gate):
+    return gru(fc(x, wx, bias),
+               lod,
+               h0,
+               wh,
+               np.zeros(
+                   (1, wh.shape[1]), dtype='float64'),
+               is_reverse,
+               act_state,
+               act_gate)
+
+
+class TestFusionGRUOp(OpTest):
+    def set_confs(self):
+        pass
+
+    def setUp(self):
+        self.op_type = "fusion_gru"
+        self.lod = [[2, 4, 3]]
+        self.M = 3
+        self.D = 5
+        self.is_reverse = False
+        self.with_h0 = True
+        self.with_bias = True
+        self.act_state = 'tanh'
+        self.act_gate = 'sigmoid'
+        self.set_confs()
+
+        T = sum(self.lod[0])
+        N = len(self.lod[0])
+
+        x = np.random.rand(T, self.M).astype('float64')
+        wx = np.random.rand(self.M, 3 * self.D).astype('float64')
+        wh = np.random.rand(self.D, 3 * self.D).astype('float64')
+        bias = np.random.rand(
+            1, 3 * self.D).astype('float64') if self.with_bias else np.zeros(
+                (1, 3 * self.D), dtype='float64')
+        h0 = np.random.rand(
+            N, self.D).astype('float64') if self.with_h0 else np.zeros(
+                (N, self.D), dtype='float64')
+
+        _, _, _, hidden = fusion_gru(
+            x, self.lod, h0, wx, wh, bias, self.is_reverse,
+            ACTIVATION[self.act_state], ACTIVATION[self.act_gate])
+
+        self.inputs = {'X': (x, self.lod), 'WeightX': wx, 'WeightH': wh}
+
+        if self.with_bias:
+            self.inputs['Bias'] = bias
+
+        if self.with_h0:
+            self.inputs['H0'] = h0
+
+        self.outputs = {'Hidden': (hidden, self.lod)}
+
+        self.attrs = {
+            'activation': self.act_state,
+            'gate_activation': self.act_gate,
+            'is_reverse': self.is_reverse
+        }
+
+    def test_check_output(self):
+        self.check_output(atol=1e-8)
+
+
+class TestFusionGRUOpNoInitial(TestFusionGRUOp):
+    def set_confs(self):
+        self.with_h0 = False
+
+
+class TestFusionGRUOpNoBias(TestFusionGRUOp):
+    def set_confs(self):
+        self.with_bias = False
+
+
+class TestFusionGRUOpReverse(TestFusionGRUOp):
+    def set_confs(self):
+        self.is_reverse = True
+
+
+class TestFusionGRUOpMD1(TestFusionGRUOp):
+    def set_confs(self):
+        self.M = 36
+        self.D = 8
+
+
+class TestFusionGRUOpMD2(TestFusionGRUOp):
+    def set_confs(self):
+        self.M = 8
+        self.D = 8
+
+
+class TestFusionGRUOpBS1(TestFusionGRUOp):
+    def set_confs(self):
+        self.lod = [[3]]
+        self.D = 16
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_generate_proposals.py b/python/paddle/fluid/tests/unittests/test_generate_proposals.py
new file mode 100644
index 0000000000..3fbd2ce95a
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_generate_proposals.py
@@ -0,0 +1,320 @@
+#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://w_idxw.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import numpy as np
+import sys
+import math
+import paddle.fluid as fluid
+from op_test import OpTest
+from test_multiclass_nms_op import nms
+from test_anchor_generator_op import anchor_generator_in_python
+import copy
+
+
+def generate_proposals_in_python(scores, bbox_deltas, im_info, anchors,
+                                 variances, pre_nms_topN, post_nms_topN,
+                                 nms_thresh, min_size, eta):
+    all_anchors = anchors.reshape(-1, 4)
+    rois = np.empty((0, 5), dtype=np.float32)
+    roi_probs = np.empty((0, 1), dtype=np.float32)
+
+    rpn_rois = []
+    rpn_roi_probs = []
+    lod = []
+    num_images = scores.shape[0]
+    for img_idx in range(num_images):
+        img_i_boxes, img_i_probs = proposal_for_one_image(
+            im_info[img_idx, :], all_anchors, variances,
+            bbox_deltas[img_idx, :, :, :], scores[img_idx, :, :, :],
+            pre_nms_topN, post_nms_topN, nms_thresh, min_size, eta)
+        lod.append(img_i_probs.shape[0])
+        rpn_rois.append(img_i_boxes)
+        rpn_roi_probs.append(img_i_probs)
+
+    return rpn_rois, rpn_roi_probs, lod
+
+
+def proposal_for_one_image(im_info, all_anchors, variances, bbox_deltas, scores,
+                           pre_nms_topN, post_nms_topN, nms_thresh, min_size,
+                           eta):
+    # Transpose and reshape predicted bbox transformations to get them
+    # into the same order as the anchors:
+    #   - bbox deltas will be (4 * A, H, W) format from conv output
+    #   - transpose to (H, W, 4 * A)
+    #   - reshape to (H * W * A, 4) where rows are ordered by (H, W, A)
+    #     in slowest to fastest order to match the enumerated anchors
+    bbox_deltas = bbox_deltas.transpose((1, 2, 0)).reshape(-1, 4)
+    all_anchors = all_anchors.reshape(-1, 4)
+    variances = variances.reshape(-1, 4)
+    # Same story for the scores:
+    #   - scores are (A, H, W) format from conv output
+    #   - transpose to (H, W, A)
+    #   - reshape to (H * W * A, 1) where rows are ordered by (H, W, A)
+    #     to match the order of anchors and bbox_deltas
+    scores = scores.transpose((1, 2, 0)).reshape(-1, 1)
+
+    # sort all (proposal, score) pairs by score from highest to lowest
+    # take top pre_nms_topN (e.g. 6000)
+    if pre_nms_topN <= 0 or pre_nms_topN >= len(scores):
+        order = np.argsort(-scores.squeeze())
+    else:
+        # Avoid sorting possibly large arrays;
+        # First partition to get top K unsorted
+        # and then sort just thoes
+        inds = np.argpartition(-scores.squeeze(), pre_nms_topN)[:pre_nms_topN]
+        order = np.argsort(-scores[inds].squeeze())
+        order = inds[order]
+    scores = scores[order, :]
+    bbox_deltas = bbox_deltas[order, :]
+    all_anchors = all_anchors[order, :]
+    proposals = box_coder(all_anchors, bbox_deltas, variances)
+    # clip proposals to image (may result in proposals with zero area
+    # that will be removed in the next step)
+    proposals = clip_tiled_boxes(proposals, im_info[:2])
+    # remove predicted boxes with height or width < min_size
+    keep = filter_boxes(proposals, min_size, im_info)
+    proposals = proposals[keep, :]
+    scores = scores[keep, :]
+
+    # apply loose nms (e.g. threshold = 0.7)
+    # take post_nms_topN (e.g. 1000)
+    # return the top proposals
+    if nms_thresh > 0:
+        keep = nms(boxes=proposals,
+                   scores=scores,
+                   nms_threshold=nms_thresh,
+                   eta=eta)
+        if post_nms_topN > 0 and post_nms_topN < len(keep):
+            keep = keep[:post_nms_topN]
+        proposals = proposals[keep, :]
+        scores = scores[keep, :]
+
+    return proposals, scores
+
+
+def box_coder(all_anchors, bbox_deltas, variances):
+    """
+    Decode proposals by anchors and bbox_deltas from RPN 
+    """
+    #proposals: xmin, ymin, xmax, ymax
+    proposals = np.zeros_like(bbox_deltas, dtype=np.float32)
+
+    #anchor_loc: width, height, center_x, center_y
+    anchor_loc = np.zeros_like(bbox_deltas, dtype=np.float32)
+
+    anchor_loc[:, 0] = all_anchors[:, 2] - all_anchors[:, 0]
+    anchor_loc[:, 1] = all_anchors[:, 3] - all_anchors[:, 1]
+    anchor_loc[:, 2] = (all_anchors[:, 2] + all_anchors[:, 0]) / 2
+    anchor_loc[:, 3] = (all_anchors[:, 3] + all_anchors[:, 1]) / 2
+
+    #predicted bbox: bbox_center_x, bbox_center_y, bbox_width, bbox_height 
+    pred_bbox = np.zeros_like(bbox_deltas, dtype=np.float32)
+    if variances is not None:
+        for i in range(bbox_deltas.shape[0]):
+            pred_bbox[i, 0] = variances[i, 0] * bbox_deltas[i, 0] * anchor_loc[
+                i, 0] + anchor_loc[i, 2]
+            pred_bbox[i, 1] = variances[i, 1] * bbox_deltas[i, 1] * anchor_loc[
+                i, 1] + anchor_loc[i, 3]
+            pred_bbox[i, 2] = math.exp(variances[i, 2] *
+                                       bbox_deltas[i, 2]) * anchor_loc[i, 0]
+            pred_bbox[i, 3] = math.exp(variances[i, 3] *
+                                       bbox_deltas[i, 3]) * anchor_loc[i, 1]
+    else:
+        for i in range(bbox_deltas.shape[0]):
+            pred_bbox[i, 0] = bbox_deltas[i, 0] * anchor_loc[i, 0] + anchor_loc[
+                i, 2]
+            pred_bbox[i, 1] = bbox_deltas[i, 1] * anchor_loc[i, 1] + anchor_loc[
+                i, 3]
+            pred_bbox[i, 2] = math.exp(bbox_deltas[i, 2]) * anchor_loc[i, 0]
+            pred_bbox[i, 3] = math.exp(bbox_deltas[i, 3]) * anchor_loc[i, 1]
+
+    proposals[:, 0] = pred_bbox[:, 0] - pred_bbox[:, 2] / 2
+    proposals[:, 1] = pred_bbox[:, 1] - pred_bbox[:, 3] / 2
+    proposals[:, 2] = pred_bbox[:, 0] + pred_bbox[:, 2] / 2
+    proposals[:, 3] = pred_bbox[:, 1] + pred_bbox[:, 3] / 2
+
+    return proposals
+
+
+def clip_tiled_boxes(boxes, im_shape):
+    """Clip boxes to image boundaries. im_shape is [height, width] and boxes
+    has shape (N, 4 * num_tiled_boxes)."""
+    assert boxes.shape[1] % 4 == 0, \
+        'boxes.shape[1] is {:d}, but must be divisible by 4.'.format(
+        boxes.shape[1]
+    )
+    # x1 >= 0
+    boxes[:, 0::4] = np.maximum(np.minimum(boxes[:, 0::4], im_shape[1] - 1), 0)
+    # y1 >= 0
+    boxes[:, 1::4] = np.maximum(np.minimum(boxes[:, 1::4], im_shape[0] - 1), 0)
+    # x2 < im_shape[1]
+    boxes[:, 2::4] = np.maximum(np.minimum(boxes[:, 2::4], im_shape[1] - 1), 0)
+    # y2 < im_shape[0]
+    boxes[:, 3::4] = np.maximum(np.minimum(boxes[:, 3::4], im_shape[0] - 1), 0)
+    return boxes
+
+
+def filter_boxes(boxes, min_size, im_info):
+    """Only keep boxes with both sides >= min_size and center within the image.
+    """
+    # Scale min_size to match image scale
+    min_size *= im_info[2]
+    ws = boxes[:, 2] - boxes[:, 0] + 1
+    hs = boxes[:, 3] - boxes[:, 1] + 1
+    x_ctr = boxes[:, 0] + ws / 2.
+    y_ctr = boxes[:, 1] + hs / 2.
+    keep = np.where((ws >= min_size) & (hs >= min_size) & (x_ctr < im_info[1]) &
+                    (y_ctr < im_info[0]))[0]
+    return keep
+
+
+def iou(box_a, box_b):
+    """
+	Apply intersection-over-union overlap between box_a and box_b
+    """
+    xmin_a = min(box_a[0], box_a[2])
+    ymin_a = min(box_a[1], box_a[3])
+    xmax_a = max(box_a[0], box_a[2])
+    ymax_a = max(box_a[1], box_a[3])
+
+    xmin_b = min(box_b[0], box_b[2])
+    ymin_b = min(box_b[1], box_b[3])
+    xmax_b = max(box_b[0], box_b[2])
+    ymax_b = max(box_b[1], box_b[3])
+
+    area_a = (ymax_a - ymin_a + 1) * (xmax_a - xmin_a + 1)
+    area_b = (ymax_b - ymin_b + 1) * (xmax_b - xmin_b + 1)
+    if area_a <= 0 and area_b <= 0:
+        return 0.0
+
+    xa = max(xmin_a, xmin_b)
+    ya = max(ymin_a, ymin_b)
+    xb = min(xmax_a, xmax_b)
+    yb = min(ymax_a, ymax_b)
+
+    inter_area = max(xb - xa, 0.0) * max(yb - ya, 0.0)
+
+    iou_ratio = inter_area / (area_a + area_b - inter_area)
+
+    return iou_ratio
+
+
+def nms(boxes, scores, nms_threshold, eta=1.0):
+    """Apply non-maximum suppression at test time to avoid detecting too many
+    overlapping bounding boxes for a given object.
+    Args:
+        boxes: (tensor) The location preds for the img, Shape: [num_priors,4].
+        scores: (tensor) The class predscores for the img, Shape:[num_priors].
+        nms_threshold: (float) The overlap thresh for suppressing unnecessary
+            boxes.
+        eta: (float) The parameter for adaptive NMS.
+    Return:
+        The indices of the kept boxes with respect to num_priors.
+    """
+    all_scores = copy.deepcopy(scores)
+    all_scores = all_scores.flatten()
+
+    sorted_indices = np.argsort(-all_scores, axis=0, kind='mergesort')
+    sorted_scores = all_scores[sorted_indices]
+    selected_indices = []
+    adaptive_threshold = nms_threshold
+    for i in range(sorted_scores.shape[0]):
+        idx = sorted_indices[i]
+        keep = True
+        for k in range(len(selected_indices)):
+            if keep:
+                kept_idx = selected_indices[k]
+                overlap = iou(boxes[idx], boxes[kept_idx])
+                keep = True if overlap <= adaptive_threshold else False
+            else:
+                break
+        if keep:
+            selected_indices.append(idx)
+        if keep and eta < 1 and adaptive_threshold > 0.5:
+            adaptive_threshold *= eta
+    return selected_indices
+
+
+class TestGenerateProposalsOp(OpTest):
+    def set_data(self):
+        self.init_test_params()
+        self.init_test_input()
+        self.init_test_output()
+        self.inputs = {
+            'Scores': self.scores,
+            'BboxDeltas': self.bbox_deltas,
+            'ImInfo': self.im_info.astype(np.float32),
+            'Anchors': self.anchors,
+            'Variances': self.variances
+        }
+
+        self.attrs = {
+            'pre_nms_topN': self.pre_nms_topN,
+            'post_nms_topN': self.post_nms_topN,
+            'nms_thresh': self.nms_thresh,
+            'min_size': self.min_size,
+            'eta': self.eta
+        }
+
+        print("lod = ", self.lod)
+        self.outputs = {
+            'RpnRois': (self.rpn_rois[0], [self.lod]),
+            'RpnRoiProbs': (self.rpn_roi_probs[0], [self.lod])
+        }
+
+    def test_check_output(self):
+        self.check_output()
+
+    def setUp(self):
+        self.op_type = "generate_proposals"
+        self.set_data()
+
+    def init_test_params(self):
+        self.pre_nms_topN = 12000  # train 12000, test 2000
+        self.post_nms_topN = 5000  # train 6000, test 1000
+        self.nms_thresh = 0.7
+        self.min_size = 3.0
+        self.eta = 0.8
+
+    def init_test_input(self):
+        batch_size = 1
+        input_channels = 20
+        layer_h = 16
+        layer_w = 16
+        input_feat = np.random.random(
+            (batch_size, input_channels, layer_h, layer_w)).astype('float32')
+        self.anchors, self.variances = anchor_generator_in_python(
+            input_feat=input_feat,
+            anchor_sizes=[16., 32.],
+            aspect_ratios=[0.5, 1.0],
+            variances=[1.0, 1.0, 1.0, 1.0],
+            stride=[16.0, 16.0],
+            offset=0.5)
+        self.im_info = np.array([[64., 64., 8.]])  #im_height, im_width, scale
+        num_anchors = self.anchors.shape[2]
+        self.scores = np.random.random(
+            (batch_size, num_anchors, layer_h, layer_w)).astype('float32')
+        self.bbox_deltas = np.random.random(
+            (batch_size, num_anchors * 4, layer_h, layer_w)).astype('float32')
+
+    def init_test_output(self):
+        self.rpn_rois, self.rpn_roi_probs, self.lod = generate_proposals_in_python(
+            self.scores, self.bbox_deltas, self.im_info, self.anchors,
+            self.variances, self.pre_nms_topN, self.post_nms_topN,
+            self.nms_thresh, self.min_size, self.eta)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_gru_op.py b/python/paddle/fluid/tests/unittests/test_gru_op.py
index 001fd7efb1..9f6f03f9cf 100644
--- a/python/paddle/fluid/tests/unittests/test_gru_op.py
+++ b/python/paddle/fluid/tests/unittests/test_gru_op.py
@@ -19,22 +19,19 @@ import numpy as np
 import math
 import functools
 from op_test import OpTest
-from test_lstm_op import identity, sigmoid, tanh, relu
-
-
-class TestGRUOp(OpTest):
-    lod = [[2, 4, 3]]
-    batch_size = sum(lod[0])
-    frame_size = 5
-    activate = {
-        'identity': identity,
-        'sigmoid': sigmoid,
-        'tanh': tanh,
-        'relu': relu
-    }
-
-    @staticmethod
-    def seq_to_batch(lod, is_reverse):
+from test_lstm_op import ACTIVATION
+
+
+def gru(
+        input,  # T x 3D
+        lod,  # 1 x N
+        h0,  # N x D
+        weight,  # D x 3D
+        bias,  # 1 x 3D
+        is_reverse,
+        act_state,
+        act_gate):
+    def _seq_to_batch(lod, is_reverse):
         idx_in_seq_list = []
         seq_lens = lod[0]
         seq_starts = [0]
@@ -56,121 +53,125 @@ class TestGRUOp(OpTest):
             idx_in_seq_list.append(idx_in_seq)
         return idx_in_seq_list, sorted_seqs
 
-    def gru_step(self, x, h_p, w, b):
-        batch_size = x.shape[0]
-        frame_size = w.shape[0]
-        g = x + np.tile(b, (batch_size, 1))
-        w_u_r = w.flatten()[:frame_size * frame_size * 2].reshape(
-            (frame_size, frame_size * 2))
-        u_r = self.activate[self.attrs['gate_activation']](np.dot(
-            h_p, w_u_r) + g[:, :frame_size * 2])
-        u = u_r[:, :frame_size]
-        r = u_r[:, frame_size:frame_size * 2]
+    def _step(x, h_p, w, b, act_state, act_gate):
+        T = x.shape[0]
+        D = w.shape[0]
+        g = x + np.tile(b, (T, 1))
+        w_u_r = w.flatten()[:D * D * 2].reshape((D, D * 2))
+        u_r = act_gate(np.dot(h_p, w_u_r) + g[:, :D * 2])
+        u = u_r[:, :D]
+        r = u_r[:, D:D * 2]
         r_h_p = r * h_p
-        w_c = w.flatten()[frame_size * frame_size * 2:].reshape(
-            (frame_size, frame_size))
-        c = self.activate[self.attrs['activation']](np.dot(r_h_p, w_c) +
-                                                    g[:, frame_size * 2:])
+        w_c = w.flatten()[D * D * 2:].reshape((D, D))
+        c = act_state(np.dot(r_h_p, w_c) + g[:, D * 2:])
         g = np.hstack((u_r, c))
         h = u * c + (1 - u) * h_p
         return g, r_h_p, h
 
-    def gru(self):
-        input, lod = self.inputs['Input']
-        w = self.inputs['Weight']
-        b = self.inputs['Bias'] if 'Bias' in self.inputs else np.zeros(
-            (1, self.frame_size * 3))
-        batch_gate = self.outputs['BatchGate']
-        batch_reset_hidden_prev = self.outputs['BatchResetHiddenPrev']
-        batch_hidden = self.outputs['BatchHidden']
-        hidden = self.outputs['Hidden']
-        idx_in_seq_list = self.idx_in_seq_list
-        h_p = self.inputs['H0'][
-            self.sorted_seqs] if 'H0' in self.inputs else np.zeros(
-                (len(idx_in_seq_list[0]), self.frame_size))
-        num_batch = len(idx_in_seq_list)
-        end_idx = 0
-        for batch_idx in range(num_batch):
-            x = input[idx_in_seq_list[batch_idx]]
-            g, r_h_p, h = self.gru_step(x, h_p, w, b)
-            if batch_idx < (num_batch - 1):
-                h_p = h[:len(idx_in_seq_list[batch_idx + 1])]
-            start_idx = end_idx
-            end_idx = start_idx + len(idx_in_seq_list[batch_idx])
-            batch_gate[start_idx:end_idx] = g
-            batch_reset_hidden_prev[start_idx:end_idx] = r_h_p
-            batch_hidden[start_idx:end_idx] = h
-            hidden[idx_in_seq_list[batch_idx]] = h
-        return batch_gate, batch_reset_hidden_prev, hidden
-
-    def set_data(self):
-        lod = self.lod
-        self.idx_in_seq_list, self.sorted_seqs = self.seq_to_batch(
-            lod, self.is_reverse)
-        batch_size = self.batch_size
-        frame_size = self.frame_size
-        input = np.random.rand(batch_size, frame_size * 3).astype('float64')
-        h0 = np.random.rand(len(self.idx_in_seq_list[0]),
-                            frame_size).astype('float64')
-        weight = np.random.rand(frame_size, frame_size * 3).astype('float64')
-        bias = np.random.rand(1, frame_size * 3).astype('float64')
-
-        self.inputs = {
-            'Input': (input, lod),
-            'H0': h0,
-            'Weight': weight,
-            'Bias': bias
-        }
+    T = sum(lod[0])
+    N = len(lod[0])
+    D = weight.shape[0]
+    batch_gate = np.zeros((T, 3 * D), dtype='float64')
+    batch_reset_hidden_prev = np.zeros((T, D), dtype='float64')
+    batch_hidden = np.zeros((T, D), dtype='float64')
+    hidden = np.zeros((T, D), dtype='float64')
+
+    idx_in_seq_list, sorted_seqs = _seq_to_batch(lod, is_reverse)
+    h_p = h0[sorted_seqs]
+    max_seq_len = len(idx_in_seq_list)
+    assert len(idx_in_seq_list[0]) == N
+    end_idx = 0
+    for batch_idx in range(max_seq_len):
+        x = input[idx_in_seq_list[batch_idx]]
+        g, r_h_p, h = _step(x, h_p, weight, bias, act_state, act_gate)
+        if batch_idx < (max_seq_len - 1):
+            h_p = h[:len(idx_in_seq_list[batch_idx + 1])]
+        start_idx = end_idx
+        end_idx = start_idx + len(idx_in_seq_list[batch_idx])
+        batch_gate[start_idx:end_idx] = g
+        batch_reset_hidden_prev[start_idx:end_idx] = r_h_p
+        batch_hidden[start_idx:end_idx] = h
+        hidden[idx_in_seq_list[batch_idx]] = h
+    return batch_gate, batch_reset_hidden_prev, batch_hidden, hidden
 
-        self.outputs = {
-            'BatchGate': np.zeros(
-                (batch_size, frame_size * 3), dtype='float64'),
-            'BatchResetHiddenPrev': np.zeros(
-                (batch_size, frame_size), dtype='float64'),
-            'BatchHidden': np.zeros(
-                (batch_size, frame_size), dtype='float64'),
-            'Hidden': np.zeros(
-                (batch_size, frame_size), dtype='float64')
-        }
 
+class TestGRUOp(OpTest):
     def set_confs(self):
-        self.is_reverse = False
-        self.attrs = {
-            'activation': 'tanh',
-            'gate_activation': 'sigmoid',
-            'is_reverse': self.is_reverse
-        }
+        pass
 
     def setUp(self):
         self.op_type = "gru"
+        self.lod = [[2, 4, 3]]
+        self.D = 5
+        self.is_reverse = False
+        self.with_h0 = True
+        self.with_bias = True
+        self.act_state = 'tanh'
+        self.act_gate = 'sigmoid'
         self.set_confs()
-        self.set_data()
-        self.gru()
+
+        T = sum(self.lod[0])
+        N = len(self.lod[0])
+
+        input = np.random.rand(T, 3 * self.D).astype('float64')
+        weight = np.random.rand(self.D, 3 * self.D).astype('float64')
+        bias = np.random.rand(
+            1, 3 * self.D).astype('float64') if self.with_bias else np.zeros(
+                (1, 3 * self.D), dtype='float64')
+        h0 = np.random.rand(
+            N, self.D).astype('float64') if self.with_h0 else np.zeros(
+                (N, self.D), dtype='float64')
+
+        batch_gate, batch_reset_hidden_prev, batch_hidden, hidden = gru(
+            input, self.lod, h0, weight, bias, self.is_reverse,
+            ACTIVATION[self.act_state], ACTIVATION[self.act_gate])
+        self.inputs = {'Input': (input, self.lod), 'Weight': weight}
+
+        if self.with_bias:
+            self.inputs['Bias'] = bias
+
+        if self.with_h0:
+            self.inputs['H0'] = h0
+
+        self.outputs = {
+            'Hidden': (hidden, self.lod),
+            'BatchGate': batch_gate,
+            'BatchResetHiddenPrev': batch_reset_hidden_prev,
+            'BatchHidden': batch_hidden,
+        }
+
+        self.attrs = {
+            'activation': self.act_state,
+            'gate_activation': self.act_gate,
+            'is_reverse': self.is_reverse
+        }
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(atol=1e-8)
 
     def test_check_grad(self):
         self.check_grad(['Input', 'H0', 'Weight', 'Bias'], ['Hidden'])
 
 
 class TestGRUOpNoInitial(TestGRUOp):
-    def set_data(self):
-        super(TestGRUOpNoInitial, self).set_data()
-        self.inputs.pop('H0')
+    def set_confs(self):
+        self.with_h0 = False
 
     def test_check_grad(self):
         self.check_grad(['Input', 'Weight', 'Bias'], ['Hidden'])
 
 
+class TestGRUOpNoBias(TestGRUOp):
+    def set_confs(self):
+        self.with_bias = False
+
+    def test_check_grad(self):
+        self.check_grad(['Input', 'H0', 'Weight'], ['Hidden'])
+
+
 class TestGRUOpReverse(TestGRUOp):
     def set_confs(self):
         self.is_reverse = True
-        self.attrs = {
-            'activation': 'tanh',
-            'gate_activation': 'sigmoid',
-            'is_reverse': self.is_reverse
-        }
 
 
 if __name__ == "__main__":
diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py
index 21c51bd139..4eb87b6a77 100644
--- a/python/paddle/fluid/transpiler/distribute_transpiler.py
+++ b/python/paddle/fluid/transpiler/distribute_transpiler.py
@@ -31,6 +31,7 @@ Steps to transpile pserver:
 """
 
 import math
+import sys
 import numpy as np
 import collections
 import six
@@ -181,7 +182,8 @@ class DistributeTranspiler(object):
                   program=None,
                   pservers="127.0.0.1:6174",
                   trainers=1,
-                  sync_mode=True):
+                  sync_mode=True,
+                  startup_program=None):
         """
         Run the transpiler.
 
@@ -194,13 +196,17 @@ class DistributeTranspiler(object):
                 list.
             trainers (int): number of trainers in the distributed job.
             sync_mode (bool): Do sync training or not, default is True.
+            startup_program (Program|None): startup_program to transpile,
+                default is fluid.default_main_program().
         """
         if program is None:
             program = default_main_program()
+        if startup_program is None:
+            startup_program = default_startup_program()
         self.origin_program = program
-        self.origin_startup_program = default_startup_program().clone()
+        self.startup_program = startup_program
+        self.origin_startup_program = self.startup_program.clone()
 
-        self.startup_program = default_startup_program()
         self.trainer_num = trainers
         self.sync_mode = sync_mode
         self.trainer_id = trainer_id
@@ -376,21 +382,18 @@ class DistributeTranspiler(object):
 
         return self.origin_program
 
-    def _get_trainer_startup_program(self,
-                                     recv_vars,
-                                     eplist,
-                                     startup_program=None):
+    def _get_trainer_startup_program(self, recv_vars, eplist):
         """
         Get transpiled trainer side startup program.
 
         Args:
-            startup_program(Program): Startup program.
+            recv_vars (list): Variable list to recv for current trainer_id
+            eplist (list): A list of strings indicating 
 
         Returns:
             Program: trainer side startup program.
         """
-        if startup_program is None:
-            startup_program = self.startup_program
+        startup_program = self.startup_program
 
         # FIXME(gongwb): delete not need ops.
         # note that: some parameter is not trainable and those ops can't be deleted.
@@ -438,7 +441,18 @@ class DistributeTranspiler(object):
             #add concat ops to merge splited parameters received from parameter servers.
             if len(splited_var) <= 1:
                 continue
-            orig_param = startup_program.global_block().vars[varname]
+            # NOTE: if enable memory optimization, origin vars maybe removed.
+            if startup_program.global_block().vars.has_key(varname):
+                orig_param = startup_program.global_block().vars[varname]
+            else:
+                origin_param_var = self.origin_program.global_block().vars[
+                    varname]
+                orig_param = startup_program.global_block().create_var(
+                    name=varname,
+                    persistable=origin_param_var.persistable,
+                    type=origin_param_var.type,
+                    dtype=origin_param_var.dtype,
+                    shape=origin_param_var.shape)
             startup_program.global_block().append_op(
                 type="concat",
                 inputs={"X": splited_var},
@@ -461,7 +475,9 @@ class DistributeTranspiler(object):
         # NOTE: assume blocks of the same variable is not distributed
         # on the same pserver, only change param/grad varnames for
         # trainers to fetch.
-
+        sys.stderr.write("get_pserver_program() is deprecated, call\
+            get_pserver_programs() to get pserver main and startup\
+            in a single call.")
         # step1
         pserver_program = Program()
         pserver_program.random_seed = self.origin_program.random_seed
@@ -651,32 +667,58 @@ class DistributeTranspiler(object):
             endpoint)
 
         pserver_program._sync_with_cpp()
+        # save pserver program to generate pserver side startup relatively.
+        self.pserver_program = pserver_program
         return pserver_program
 
+    def get_pserver_programs(self, endpoint):
+        """
+        Get pserver side main program and startup program for distributed training.
+
+        Args:
+            endpoint (str): current pserver endpoint.
+        
+        Returns:
+            tuple: (main_program, startup_program), of type "Program"
+        """
+        pserver_prog = self.get_pserver_program(endpoint)
+        pserver_startup = self.get_startup_program(endpoint)
+        return pserver_prog, pserver_startup
+
     def get_startup_program(self,
                             endpoint,
-                            pserver_program,
+                            pserver_program=None,
                             startup_program=None):
         """
+        **Deprecated**
+
         Get startup program for current parameter server.
         Modify operator input variables if there are variables that
         were split to several blocks.
 
         Args:
             endpoint (str): current pserver endpoint.
-            pserver_program (Program): call get_pserver_program first and
-                pass the result here.
-            startup_program (Program): if pass None, will use
-                default_startup_program
+            pserver_program (Program): deprecated, call get_pserver_program first.
+            startup_program (Program): deprecated, should pass startup_program
+                when initalizing 
 
         Returns:
             Program: parameter server side startup program.
         """
+        sys.stderr.write("get_startup_program() is deprecated, call\
+            get_pserver_programs() to get pserver main and startup\
+            in a single call.")
+        if pserver_program != None:
+            sys.stderr.write("passing pserver_program to get_startup_program()\
+                is deprecated, you can use new API get_pserver_programs() to\
+                get both pserver main program and startup program.")
+        if startup_program != None:
+            sys.stderr.write("passing startup_program to get_startup_program()\
+                is deprecated, use fluid.program_guard() or pass this argument\
+                to transpile() call.")
+
         s_prog = Program()
-        if not startup_program:
-            orig_s_prog = default_startup_program()
-        else:
-            orig_s_prog = startup_program
+        orig_s_prog = self.startup_program
         s_prog.random_seed = orig_s_prog.random_seed
         params = self.param_grad_ep_mapping[endpoint]["params"]