From 7ab5626dee0e08262d0b3eaf8f89f05217cccde8 Mon Sep 17 00:00:00 2001
From: Jacek Czaja <jacek.czaja@intel.com>
Date: Thu, 13 Sep 2018 14:26:52 +0200
Subject: [PATCH 01/13] - Added initial pass for embedding-fc-lstm

- Added draft of new operator

- Added fused embedding fc lstm files

- First time embedding_fc_lstm_fuse_pass was invoked in
  test_text_classification

- Added Embedding pattern

- Not crashing

- Enabled draft of embedding_fc_lstm pass (does it job)

- First working (Seqcompute only) version

- Removed diagnostic comment

- First enabling of BatchCompute

- Disabling pass for embedding with is_sparse and is_distributed

- Cosmetics

- Style

- Style
---
 paddle/fluid/framework/ir/CMakeLists.txt      |   1 +
 .../ir/embedding_fc_lstm_fuse_pass.cc         | 242 +++++++
 .../ir/embedding_fc_lstm_fuse_pass.h          |  40 ++
 .../framework/ir/graph_pattern_detector.cc    |  18 +
 .../framework/ir/graph_pattern_detector.h     |  17 +
 paddle/fluid/inference/analysis/analyzer.h    |  17 +-
 .../operators/fused_embedding_fc_lstm_op.cc   | 608 ++++++++++++++++++
 .../operators/fused_embedding_fc_lstm_op.h    |  41 ++
 8 files changed, 976 insertions(+), 8 deletions(-)
 create mode 100644 paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.cc
 create mode 100644 paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.h
 create mode 100644 paddle/fluid/operators/fused_embedding_fc_lstm_op.cc
 create mode 100644 paddle/fluid/operators/fused_embedding_fc_lstm_op.h

diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt
index 4dca3ceb45..01733fdda2 100644
--- a/paddle/fluid/framework/ir/CMakeLists.txt
+++ b/paddle/fluid/framework/ir/CMakeLists.txt
@@ -34,6 +34,7 @@ endif()
 pass_library(attention_lstm_fuse_pass inference)
 pass_library(infer_clean_graph_pass inference)
 pass_library(fc_lstm_fuse_pass inference)
+pass_library(embedding_fc_lstm_fuse_pass inference)
 pass_library(fc_gru_fuse_pass inference)
 pass_library(seq_concat_fc_fuse_pass inference)
 
diff --git a/paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.cc b/paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.cc
new file mode 100644
index 0000000000..38495125c3
--- /dev/null
+++ b/paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.cc
@@ -0,0 +1,242 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.h"
+#include <string>
+#include "paddle/fluid/framework/lod_tensor.h"
+
+#include "paddle/fluid/operators/math/blas.h"
+#include "paddle/fluid/operators/math/cpu_vec.h"
+#include "paddle/fluid/operators/math/fc_compute.h"
+#include "paddle/fluid/platform/cpu_info.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+static int BuildFusion(Graph* graph, const std::string& name_scope,
+                       Scope* scope, bool with_fc_bias) {
+  GraphPatternDetector gpd;
+  auto* pattern = gpd.mutable_pattern();
+
+  // Build pattern
+  PDNode* x = pattern->NewNode(patterns::PDNodeName(name_scope, "x"))
+                  ->assert_is_op_input("lookup_table")
+                  ->assert_var_not_persistable();
+  patterns::Embedding embedding_pattern(pattern, name_scope);
+  // TODO(jczaja): Intermediate can only be for val that are not used anywhere
+  //               but lookup table output may go into other LSTM (for reverse
+  //               direction)
+  auto* embedding_out = embedding_pattern(x);
+  patterns::FC fc_pattern(pattern, name_scope);
+
+  // fc_out is a tmp var, will be removed after fuse, so marked as intermediate.
+  auto* fc_out = fc_pattern(embedding_out, with_fc_bias)->AsIntermediate();
+  patterns::LSTM lstm_pattern(pattern, name_scope);
+  lstm_pattern(fc_out);
+
+  // Create New OpDesc
+  auto embedding_lstm_creator = [&](Node* embedding, Node* W, Node* lstm,
+                                    Node* input, Node* weight_x, Node* weight_h,
+                                    Node* bias, Node* hidden, Node* cell,
+                                    Node* xx, Node* fc_bias) {
+    OpDesc op_desc;
+    op_desc.SetType("fused_embedding_fc_lstm");
+#define SET_IN(Key, node__) op_desc.SetInput(#Key, {node__->Name()});
+    SET_IN(Ids, input);
+    SET_IN(WeightH, weight_h);
+    // Neet to have this passed as We need Wc data for peephole connections
+    SET_IN(Bias, bias);
+#undef SET_IN
+
+    // Multiply embeddings with Weights
+    PADDLE_ENFORCE(scope);
+    const std::string& embeddings = patterns::UniqueKey("Embeddings");
+    auto* embeddings_var = scope->Var(embeddings);
+    PADDLE_ENFORCE(embeddings_var);
+    auto* embeddings_tensor =
+        embeddings_var->GetMutable<framework::LoDTensor>();
+    // Get WeightX size: [single_embedding, fc_size]
+    // and embedding size: [dict_size, single_embedding]
+    // and create new size of embeddings eg. [dict_size , hidden_size]
+    auto* embedding_var = scope->FindVar(W->Name());
+    PADDLE_ENFORCE(embedding_var);
+    const auto& embedding_tensor = embedding_var->Get<framework::LoDTensor>();
+
+    const auto& weightx_tensor =
+        scope->FindVar(weight_x->Name())->Get<framework::LoDTensor>();
+    embeddings_tensor->Resize(
+        {embedding_tensor.dims()[0], weightx_tensor.dims()[1]});
+
+    // Multiplie embeddings via WeightsX and add bias
+    auto embedding_data = embedding_tensor.data<float>();
+    auto weightx_data = weightx_tensor.data<float>();
+    auto embeddings_data =
+        embeddings_tensor->mutable_data<float>(platform::CPUPlace());
+
+    // Adding biases to GEMM result to be
+    auto* lstm_bias_var = scope->FindVar(bias->Name());
+    PADDLE_ENFORCE(lstm_bias_var);
+    const auto& lstm_bias_tensor = lstm_bias_var->Get<framework::LoDTensor>();
+
+    auto alpha = 1.0f;
+    auto beta = 1.0f;
+    int m = embedding_tensor.dims()[0];
+    int n = weightx_tensor.dims()[1];
+    int k = embedding_tensor.dims()[1];
+
+    // Copy only gate biases values (only actual bias data, not peephole
+    // weights)
+    std::vector<float> combined_biases(n, 0.0f);
+    memcpy(&combined_biases[0], lstm_bias_tensor.data<float>(),
+           n * sizeof(float));
+
+    if (with_fc_bias) {
+      // Add FC-bias with LSTM-bias (into GEMM result to be)
+      auto* fc_bias_var = scope->FindVar(fc_bias->Name());
+      const auto& fc_bias_tensor = fc_bias_var->Get<framework::LoDTensor>();
+      for (int i = 0; i < fc_bias_tensor.numel(); i++) {
+        combined_biases[i] =
+            lstm_bias_tensor.data<float>()[i] + fc_bias_tensor.data<float>()[i];
+      }
+    }
+
+    // broadcast biases
+    std::vector<float> ones(m, 1.0f);
+    paddle::operators::math::CBlas<float>::GEMM(
+        CblasRowMajor, CblasNoTrans, CblasNoTrans, m, n, 1, alpha, &ones[0], 1,
+        &combined_biases[0], n, 0.0f, embeddings_data, n);
+
+    // Wx*embeddings
+    paddle::operators::math::CBlas<float>::GEMM(
+        CblasRowMajor, CblasNoTrans, CblasNoTrans, m, n, k, alpha,
+        embedding_data, k, weightx_data, n, beta, embeddings_data, n);
+    op_desc.SetInput("Embeddings", {embeddings});
+
+    // Create temp variables.
+    const std::string BatchedInput = patterns::UniqueKey("BatchedInput");
+    const std::string BatchedCellPreAct =
+        patterns::UniqueKey("BatchedCellPreAct");
+    const std::string BatchedGate = patterns::UniqueKey("BatchedGate");
+
+    scope->Var(BatchedInput)->GetMutable<framework::LoDTensor>();
+    scope->Var(BatchedCellPreAct)->GetMutable<framework::LoDTensor>();
+    scope->Var(BatchedGate)->GetMutable<framework::LoDTensor>();
+
+    op_desc.SetInput("H0", {});
+    op_desc.SetInput("C0", {});
+    op_desc.SetOutput("Hidden", {hidden->Name()});
+    op_desc.SetOutput("Cell", {cell->Name()});
+    op_desc.SetOutput("XX", {xx->Name()});
+    op_desc.SetOutput("BatchedGate", {BatchedGate});
+    op_desc.SetOutput("BatchCellPreAct", {BatchedCellPreAct});
+    op_desc.SetOutput("BatchedInput", {BatchedInput});
+    op_desc.SetAttr("is_reverse", lstm->Op()->GetAttr("is_reverse"));
+    op_desc.SetAttr("use_peepholes", lstm->Op()->GetAttr("use_peepholes"));
+    // TODO(TJ): get from attr
+    op_desc.SetAttr("use_seq", true);
+
+    PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
+    auto* scope = graph->Get<Scope*>(kParamScopeAttr);
+#define OP_SET_OUT(x)                            \
+  const std::string x = patterns::UniqueKey(#x); \
+  op_desc.SetOutput(#x, {x});                    \
+  scope->Var(x)->GetMutable<LoDTensor>()
+    OP_SET_OUT(BatchedCell);
+    OP_SET_OUT(BatchedHidden);
+    OP_SET_OUT(ReorderedH0);
+    OP_SET_OUT(ReorderedC0);
+#undef OP_SET_OUT
+
+    auto* op = graph->CreateOpNode(&op_desc);
+    IR_NODE_LINK_TO(input, op);
+    IR_NODE_LINK_TO(weight_x, op);
+    IR_NODE_LINK_TO(weight_h, op);
+    IR_NODE_LINK_TO(bias, op);
+    IR_NODE_LINK_TO(op, hidden);
+    return op;
+  };
+
+  int fusion_count{0};
+
+  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                     Graph* g) {
+    GET_IR_NODE_FROM_SUBGRAPH(lstm, lstm, lstm_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(Weight, Weight, lstm_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(Bias, Bias, lstm_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(Cell, Cell, lstm_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(Hidden, Hidden, lstm_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(lookup_table, lookup_table, embedding_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(W, W, embedding_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(w, w, fc_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(mul, mul, fc_pattern);
+
+    // TODO(jczaja): Add support for is_sparse / is_distributed
+    auto is_sparse = boost::get<bool>(lookup_table->Op()->GetAttr("is_sparse"));
+    auto is_distributed =
+        boost::get<bool>(lookup_table->Op()->GetAttr("is_distributed"));
+
+    if (is_sparse == true || is_distributed == true) {
+      return;
+    }
+
+    if (with_fc_bias) {
+      GET_IR_NODE_FROM_SUBGRAPH(fc_out, Out, fc_pattern);
+      GET_IR_NODE_FROM_SUBGRAPH(fc_bias, bias, fc_pattern);
+      GET_IR_NODE_FROM_SUBGRAPH(elementwise_add, elementwise_add, fc_pattern);
+      embedding_lstm_creator(lookup_table, W, lstm, subgraph.at(x), w, Weight,
+                             Bias, Hidden, Cell, fc_out, fc_bias);
+      // Remove unneeded nodes.
+      // TODO(jczaja): Proper removing of loopup table
+      std::unordered_set<const Node*> marked_nodes(
+          //{lookup_table, mul, lstm, elementwise_add, fc_bias, W});
+          {mul, lstm, elementwise_add, fc_bias});
+      GraphSafeRemoveNodes(graph, marked_nodes);
+    } else {
+      GET_IR_NODE_FROM_SUBGRAPH(fc_out, mul_out, fc_pattern);
+      embedding_lstm_creator(lookup_table, W, lstm, subgraph.at(x), w, Weight,
+                             Bias, Hidden, Cell, fc_out, nullptr);
+      // Remove unneeded nodes.
+      // TODO(jczaja): Proper removing of loopup table
+      // std::unordered_set<const Node*> marked_nodes({lookup_table, W, mul,
+      // lstm});
+      std::unordered_set<const Node*> marked_nodes({mul, lstm});
+      GraphSafeRemoveNodes(graph, marked_nodes);
+    }
+
+    ++fusion_count;
+  };
+
+  gpd(graph, handler);
+
+  return fusion_count;
+}
+
+std::unique_ptr<ir::Graph> EmbeddingFCLSTMFusePass::ApplyImpl(
+    std::unique_ptr<ir::Graph> graph) const {
+  FusePassBase::Init(name_scope_, graph.get());
+
+  int fusion_count = BuildFusion(graph.get(), name_scope_, param_scope(),
+                                 true /*with_fc_bias*/);
+
+  AddStatis(fusion_count);
+  return graph;
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(embedding_fc_lstm_fuse_pass,
+              paddle::framework::ir::EmbeddingFCLSTMFusePass);
diff --git a/paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.h b/paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.h
new file mode 100644
index 0000000000..e5ad3067ec
--- /dev/null
+++ b/paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.h
@@ -0,0 +1,40 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
+#include "paddle/fluid/framework/ir/graph.h"
+#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+// Fusing of Embedding , FC and LSTM op
+
+// Just FC without bias
+class EmbeddingFCLSTMFusePass : public FusePassBase {
+ public:
+  virtual ~EmbeddingFCLSTMFusePass() {}
+
+ protected:
+  std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
+
+  const std::string name_scope_{"embedding_fc_lstm_fuse"};
+};
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc
index 6d2c51b0e9..46c6a52c09 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.cc
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc
@@ -692,6 +692,24 @@ PDNode *patterns::FC::operator()(paddle::framework::ir::PDNode *x,
   }
 }
 
+PDNode *patterns::Embedding::operator()(PDNode *x) {
+  x->assert_is_op_input("lookup_table", "Ids");
+  auto *lookup_table_op =
+      pattern->NewNode(lookup_table_repr())->assert_is_op("lookup_table");
+#define NEW_NODE(arg__, io__)                    \
+  auto *arg__ = pattern->NewNode(arg__##_repr()) \
+                    ->assert_is_op_##io__("lookup_table", #arg__);
+
+  NEW_NODE(W, input);
+
+  NEW_NODE(Out, output);
+#undef NEW_NODE
+
+  lookup_table_op->LinksFrom({x, W});
+  lookup_table_op->LinksTo({Out});
+  return Out;
+}
+
 PDNode *patterns::LSTM::operator()(PDNode *x) {
   x->assert_is_op_input("lstm", "Input");
   auto *lstm_op = pattern->NewNode(lstm_repr())->assert_is_op("lstm");
diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h
index 69b486c29d..508113bf4f 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.h
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.h
@@ -418,6 +418,23 @@ struct FC : public PatternBase {
   PATTERN_DECL_NODE(Out);
 };
 
+// Embedding
+struct Embedding : public PatternBase {
+  Embedding(PDPattern* pattern, const std::string& name_scope)
+      : PatternBase(pattern, name_scope, "embedding") {}
+
+  PDNode* operator()(PDNode* x);
+
+  // declare operator node's name
+  PATTERN_DECL_NODE(lookup_table);
+  // Inputs
+  //
+  PATTERN_DECL_NODE(Ids);
+  PATTERN_DECL_NODE(W);  // embeddings
+  // Outputs
+  PATTERN_DECL_NODE(Out);
+};
+
 struct LSTM : public PatternBase {
   LSTM(PDPattern* pattern, const std::string& name_scope)
       : PatternBase(pattern, name_scope, "lstm") {}
diff --git a/paddle/fluid/inference/analysis/analyzer.h b/paddle/fluid/inference/analysis/analyzer.h
index 9bdbefc07c..0aa9367bf5 100644
--- a/paddle/fluid/inference/analysis/analyzer.h
+++ b/paddle/fluid/inference/analysis/analyzer.h
@@ -64,14 +64,15 @@ class Analyzer : public OrderedRegistry<PassManager> {
   // larger fusion.
   const std::vector<std::string> all_ir_passes_{{
       // Manual update the passes here.
-      "infer_clean_graph_pass",    //
-      "attention_lstm_fuse_pass",  //
-      "fc_lstm_fuse_pass",         //
-      "mul_lstm_fuse_pass",        //
-      "fc_gru_fuse_pass",          //
-      "mul_gru_fuse_pass",         //
-      "seq_concat_fc_fuse_pass",   //
-      "fc_fuse_pass",              //
+      "infer_clean_graph_pass",       //
+      "attention_lstm_fuse_pass",     //
+      "embedding_fc_lstm_fuse_pass",  //
+      "fc_lstm_fuse_pass",            //
+      "mul_lstm_fuse_pass",           //
+      "fc_gru_fuse_pass",             //
+      "mul_gru_fuse_pass",            //
+      "seq_concat_fc_fuse_pass",      //
+      "fc_fuse_pass",                 //
 #ifdef PADDLE_WITH_MKLDNN
       "conv_relu_mkldnn_fuse_pass",  //
 #endif
diff --git a/paddle/fluid/operators/fused_embedding_fc_lstm_op.cc b/paddle/fluid/operators/fused_embedding_fc_lstm_op.cc
new file mode 100644
index 0000000000..3c4cc77452
--- /dev/null
+++ b/paddle/fluid/operators/fused_embedding_fc_lstm_op.cc
@@ -0,0 +1,608 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/fused_embedding_fc_lstm_op.h"
+#include <string>
+#include "paddle/fluid/operators/math/blas.h"
+#include "paddle/fluid/operators/math/cpu_vec.h"
+#include "paddle/fluid/operators/math/fc_compute.h"
+#include "paddle/fluid/operators/math/sequence2batch.h"
+#include "paddle/fluid/platform/cpu_info.h"
+
+namespace paddle {
+namespace operators {
+
+void FusedEmbeddingFCLSTMOp::InferShape(
+    framework::InferShapeContext* ctx) const {
+  PADDLE_ENFORCE(ctx->HasInput("Embeddings"),
+                 "Assert only one Input(Embeddings) of LSTM.");
+  PADDLE_ENFORCE(ctx->HasInput("WeightH"),
+                 "Assert only one Input(WeightH) of LSTM.");
+  PADDLE_ENFORCE(ctx->HasInput("Bias"), "Assert only one Input(Bias) of LSTM.");
+  PADDLE_ENFORCE(ctx->HasOutput("XX"), "Assert only one Output(XX) of LSTM.");
+  PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
+                 "Assert only one Output(Hidden) of LSTM.");
+  PADDLE_ENFORCE(ctx->HasOutput("Cell"),
+                 "Assert only one Output(Cell) of LSTM.");
+  PADDLE_ENFORCE(ctx->HasInput("Ids"),
+                 "Input(Ids) of LookupTableOp should not be null.");
+
+  auto table_dims = ctx->GetInputDim("Embeddings");
+  auto ids_dims = ctx->GetInputDim("Ids");
+  int ids_rank = ids_dims.size();
+
+  PADDLE_ENFORCE_EQ(table_dims.size(), 2);
+  PADDLE_ENFORCE_EQ(ids_dims[ids_rank - 1], 1,
+                    "The last dimension of the 'Ids' tensor must be 1.");
+
+  auto x_dims = ctx->GetInputDim("Ids");
+  PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(Ids)'s rank must be 2.");
+
+  if (ctx->HasInput("H0")) {
+    PADDLE_ENFORCE(ctx->HasInput("C0"),
+                   "Input(Cell) and Input(Hidden) of LSTM should not "
+                   "be null at the same time.");
+    auto h_dims = ctx->GetInputDim("H0");
+    auto c_dims = ctx->GetInputDim("C0");
+    PADDLE_ENFORCE(h_dims == c_dims,
+                   "The dimension of Input(H0) and Input(C0) "
+                   "should be the same.");
+  }
+
+  auto embeddings_dims = ctx->GetInputDim("Embeddings");
+  PADDLE_ENFORCE_EQ(embeddings_dims.size(), 2,
+                    "The rank of Input(Embeddings) should be 2.");
+  //  PADDLE_ENFORCE_EQ(wx_dims[0], x_dims[1],
+  //                    "The first dimension of Input(Embeddings) "
+  //                    "should be %d.",
+  //                    x_dims[1]);
+
+  auto wh_dims = ctx->GetInputDim("WeightH");
+  int frame_size = wh_dims[1] / 4;
+  PADDLE_ENFORCE_EQ(wh_dims.size(), 2,
+                    "The rank of Input(WeightH) should be 2.");
+  PADDLE_ENFORCE_EQ(wh_dims[0], frame_size,
+                    "The first dimension of Input(WeightH) "
+                    "should be %d.",
+                    frame_size);
+  PADDLE_ENFORCE_EQ(wh_dims[1], 4 * frame_size,
+                    "The second dimension of Input(WeightH) "
+                    "should be 4 * %d.",
+                    frame_size);
+
+  auto b_dims = ctx->GetInputDim("Bias");
+  PADDLE_ENFORCE_EQ(b_dims.size(), 2, "The rank of Input(Bias) should be 2.");
+  PADDLE_ENFORCE_EQ(b_dims[0], 1,
+                    "The first dimension of Input(Bias) should be 1.");
+  PADDLE_ENFORCE_EQ(
+      b_dims[1], (ctx->Attrs().Get<bool>("use_peepholes") ? 7 : 4) * frame_size,
+      "The second dimension of Input(Bias) should be "
+      "7 * %d if enable peepholes connection or"
+      "4 * %d if disable peepholes",
+      frame_size, frame_size);
+
+  framework::DDim out_dims({x_dims[0], frame_size});
+  ctx->SetOutputDim("Hidden", out_dims);
+  ctx->SetOutputDim("Cell", out_dims);
+  ctx->ShareLoD("Ids", "Hidden");
+  ctx->ShareLoD("Ids", "Cell");
+  int xx_width;
+  if (ctx->Attrs().Get<bool>("use_seq")) {
+    xx_width = wh_dims[1];
+  } else {
+    xx_width = x_dims[1] > wh_dims[1] ? wh_dims[1] : x_dims[1];
+    PADDLE_ENFORCE(ctx->HasOutput("BatchedInput"),
+                   "Assert only one Output(BatchedInput) of LSTM.");
+    PADDLE_ENFORCE(ctx->HasOutput("BatchedHidden"),
+                   "Assert only one Output(BatchedHidden) of LSTM.");
+    PADDLE_ENFORCE(ctx->HasOutput("BatchedCell"),
+                   "Assert only one Output(BatchedCell) of LSTM.");
+    PADDLE_ENFORCE(ctx->HasOutput("ReorderedH0"),
+                   "Assert only one Output(ReorderedH0) of LSTM");
+    PADDLE_ENFORCE(ctx->HasOutput("ReorderedC0"),
+                   "Assert only one Output(ReorderedC0) of LSTM.");
+    ctx->SetOutputDim("BatchedInput", {x_dims[0], wh_dims[1]});
+    ctx->SetOutputDim("BatchedHidden", out_dims);
+    ctx->SetOutputDim("BatchedCell", out_dims);
+  }
+  ctx->SetOutputDim("XX", {x_dims[0], xx_width});
+  ctx->ShareLoD("Ids", "XX");
+}
+
+framework::OpKernelType FusedEmbeddingFCLSTMOp::GetExpectedKernelType(
+    const framework::ExecutionContext& ctx) const {
+  return framework::OpKernelType(
+      framework::ToDataType(
+          ctx.Input<framework::LoDTensor>("Embeddings")->type()),
+      ctx.device_context());
+}
+
+void FusedEmbeddingFCLSTMOpMaker::Make() {
+  AddInput("Ids",
+           "An input with type int32 or int64 "
+           "contains the ids to be looked up in W. "
+           "The last dimension size must be 1.");
+  AddInput("Embeddings",
+           "(Tensor) the learnable weights of X."
+           " - The shape is (M x 4D), where M is the dim size of x, D is the "
+           "hidden size. "
+           " - Weight = {W_cx, W_ix, W_fx, W_ox}");
+  AddInput("WeightH",
+           "(Tensor) same as LSTMOp, the learnable hidden-hidden weights."
+           " - The shape is (D x 4D), where D is the hidden size. "
+           " - Weight = {W_ch, W_ih, W_fh, W_oh}");
+  AddInput("Bias",
+           "(Tensor) the learnable weights. Almost same as LSTMOp"
+           "Note: we should add the fc bias into this (1x4D) in bias."
+           "input-hidden bias weight and peephole connections weight if "
+           "setting `use_peepholes` True. "
+           "1. `use_peepholes = False` "
+           " - The shape is (1 x 4D). "
+           " - Bias = {b_c, b_i, b_f, b_o}."
+           "2. `use_peepholes = True` "
+           " - The shape is (1 x 7D). "
+           " - Bias = {b_c, b_i, b_f, b_o, W_ic, W_fc, W_oc}.");
+  AddInput("H0",
+           "(Tensor, optional) (same as LSTMOp) the initial hidden state is an "
+           "optional "
+           "input. This is a tensor with shape (N x D), where N is the "
+           "batch size and D is the hidden size.")
+      .AsDispensable();
+  AddInput("C0",
+           "(Tensor, optional) (same as LSTMOp) (the initial cell state is an "
+           "optional "
+           "input. This is a tensor with shape (N x D), where N is the "
+           "batch size. `H0` and `C0` can be NULL but only at the same time.")
+      .AsDispensable();
+  AddOutput("Hidden",
+            "(LoDTensor) (same as LSTMOp) the hidden state of LSTM operator. "
+            "The shape is (T x D), and lod is the same with the `Input`.");
+  AddOutput("Cell",
+            "(LoDTensor) (same as LSTMOp) the cell state of LSTM operator. "
+            "The shape is (T x D), and lod is the same with the `Input`.");
+  AddOutput("XX",
+            "(LoDTensor) the result after X * WeightX (size is T x 4D)"
+            " or batched_X (size is T x M), this will be automatically chosen,"
+            " where T is the total time steps in this mini-batch,"
+            " D is the hidden size, M is the dim size of x input.")
+      .AsIntermediate();
+  AddOutput("BatchedInput", "(LoDTensor) (T x 4D).").AsIntermediate();
+  AddOutput("BatchedHidden", "(LoDTensor) (T x D).").AsIntermediate();
+  AddOutput("BatchedCell", "(LoDTensor) (T x D).").AsIntermediate();
+  AddOutput("ReorderedH0", "(LoDTensor) (N x D).").AsIntermediate();
+  AddOutput("ReorderedC0", "(LoDTensor) (N x D).").AsIntermediate();
+  AddAttr<bool>("use_peepholes",
+                "(bool, defalut: True) "
+                "whether to enable diagonal/peephole connections.")
+      .SetDefault(true);
+  AddAttr<bool>("is_reverse",
+                "(bool, defalut: False) "
+                "whether to compute reversed LSTM.")
+      .SetDefault(false);
+  AddAttr<bool>("use_seq",
+                "(bool, defalut: True) "
+                "whether to use seq mode to compute.")
+      .SetDefault(true);
+  AddAttr<std::string>("gate_activation",
+                       "(string, default: sigmoid)"
+                       "The activation for input gate, forget gate and output "
+                       "gate, `sigmoid` by default.")
+      .SetDefault("sigmoid")
+      .InEnum({"sigmoid", "tanh", "relu", "identity"});
+  AddAttr<std::string>("cell_activation",
+                       "(string, default: tanh)"
+                       "The activation for cell output, `tanh` by defalut.")
+      .SetDefault("tanh")
+      .InEnum({"sigmoid", "tanh", "relu", "identity"});
+  AddAttr<std::string>("candidate_activation",
+                       "(string, default: tanh)"
+                       "The activation for candidate hidden state, "
+                       "`tanh` by default.")
+      .SetDefault("tanh")
+      .InEnum({"sigmoid", "tanh", "relu", "identity"});
+  AddComment(R"DOC(
+Fusion Long-Short Term Memory (LSTM) Operator.
+This operator fuse the X into LSTM, more details can refer to LSTM op.
+)DOC");
+}
+
+template <typename T>
+class FusedEmbeddingFCLSTMKernel : public framework::OpKernel<T> {
+ public:
+#define INIT_VEC_FUNC                                                          \
+  std::function<void(const int, const T *, T *)> act_gate, act_cell, act_cand; \
+  auto& act_gate_str = ctx.Attr<std::string>("gate_activation");               \
+  auto& act_cell_str = ctx.Attr<std::string>("cell_activation");               \
+  auto& act_cand_str = ctx.Attr<std::string>("candidate_activation");          \
+  if (platform::jit::MayIUse(platform::jit::avx)) {                            \
+    math::VecActivations<T, platform::jit::avx> act_functor;                   \
+    act_gate = act_functor(act_gate_str);                                      \
+    act_cell = act_functor(act_cell_str);                                      \
+    act_cand = act_functor(act_cand_str);                                      \
+  } else {                                                                     \
+    math::VecActivations<T, platform::jit::isa_any> act_functor;               \
+    act_gate = act_functor(act_gate_str);                                      \
+    act_cell = act_functor(act_cell_str);                                      \
+    act_cand = act_functor(act_cand_str);                                      \
+  }
+
+#define INIT_BASE_INPUT_OUTPUT                        \
+  auto* ids = ctx.Input<LoDTensor>("Ids");            \
+  auto* h0 = ctx.Input<Tensor>("H0");                 \
+  auto* c0 = ctx.Input<Tensor>("C0");                 \
+  auto* embeddings = ctx.Input<Tensor>("Embeddings"); \
+  auto* wh = ctx.Input<Tensor>("WeightH");            \
+  auto* bias = ctx.Input<Tensor>("Bias");             \
+  auto* xx = ctx.Output<LoDTensor>("XX");             \
+  auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); \
+  auto* cell_out = ctx.Output<LoDTensor>("Cell");     \
+  bool is_reverse = ctx.Attr<bool>("is_reverse");     \
+  bool use_peepholes = ctx.Attr<bool>("use_peepholes");
+
+#define INIT_BASE_SIZES                       \
+  auto ids_dims = ids->dims();   /* T x M*/   \
+  auto ids_numel = ids->numel(); /* T x 1*/   \
+  auto wh_dims = wh->dims();     /* D x 4D*/  \
+  const int D = wh_dims[0];                   \
+  const int D2 = D * 2;                       \
+  const int D3 = D * 3;                       \
+  int64_t row_number = embeddings->dims()[0]; \
+  int64_t row_width = embeddings->dims()[1];  \
+  const int D4 = wh_dims[1];
+
+#define INIT_BASE_INPUT_DATAS                                        \
+  const int64_t* ids_data = ids->data<int64_t>();                    \
+  const T* embeddings_data = embeddings->data<T>();                  \
+  const T* wh_data = wh->data<T>();                                  \
+  /* diagonal weight*/                                               \
+  const T* wc_data = bias->data<T>() + D4;                           \
+  /* for peephole only*/                                             \
+  Tensor checked_cell;                                               \
+  T* checked_cell_data = nullptr;                                    \
+  auto place = ctx.GetPlace();                                       \
+  if (use_peepholes) {                                               \
+    /* w_ic * Ct-1, w_fc * Ct-1  ; w_oc * Ct => ih*/                 \
+    checked_cell_data = checked_cell.mutable_data<T>({2, D}, place); \
+  }
+
+/// Compute LSTM
+#define GEMM_WH_ADDON(bs, prev, out)                                           \
+  blas.GEMM(CblasNoTrans, CblasNoTrans, bs, D4, D, static_cast<T>(1), prev, D, \
+            wh_data, D4, static_cast<T>(1), out, D4)
+
+// gates: W_ch, W_ih, W_fh, W_oh
+#define GET_Ct(ct_1, gates, ct)                   \
+  /* C_t = C_t-1 * fgated + cand_gated * igated*/ \
+  act_cand(D, gates, gates);                      \
+  blas.VMUL(D, gates, gates + D, gates + D);      \
+  blas.VMUL(D, ct_1, gates + D2, gates + D2);     \
+  blas.VADD(D, gates + D, gates + D2, ct)
+
+#define GET_Ht(ct, gates, ht)        \
+  /* H_t = act_cell(C_t) * ogated */ \
+  act_cell(D, ct, gates + D2);       \
+  blas.VMUL(D, gates + D2, gates + D3, ht)
+
+#define GET_Ct_NOH0C0(gates, ct)     \
+  /* C_t = igated * cgated*/         \
+  act_gate(D, gates + D, gates + D); \
+  act_cand(D, gates, gates);         \
+  blas.VMUL(D, gates, gates + D, ct)
+
+#define COMPUTE_CtHt_NOH0C0(gates, ct, ht) \
+  GET_Ct_NOH0C0(gates, ct);                \
+  act_gate(D, gates + D3, gates + D3);     \
+  GET_Ht(ct, gates, ht)
+
+#define COMPUTE_CtHt_PEEPHOLE_NOH0C0(gates, ct, ht) \
+  GET_Ct_NOH0C0(gates, ct);                         \
+  /* get outgated, put W_oc * C_t on igated */      \
+  blas.VMUL(D, wc_data + D2, ct, gates + D);        \
+  blas.VADD(D, gates + D, gates + D3, gates + D3);  \
+  act_gate(D, gates + D3, gates + D3);              \
+  GET_Ht(ct, gates, ht)
+
+#define COMPUTE_CtHt(gates, ct_1, ct, ht) \
+  act_gate(D3, gates + D, gates + D);     \
+  GET_Ct(ct_1, gates, ct);                \
+  GET_Ht(ct, gates, ht)
+
+#define COMPUTE_CtHt_PEEPHOLE(gates, ct_1, ct, ht)        \
+  /* get fgated and igated*/                              \
+  blas.VMUL(D, wc_data, ct_1, checked_cell_data);         \
+  blas.VMUL(D, wc_data + D, ct_1, checked_cell_data + D); \
+  blas.VADD(D2, checked_cell_data, gates + D, gates + D); \
+  act_gate(D2, gates + D, gates + D);                     \
+  GET_Ct(ct_1, gates, ct);                                \
+  /* get ogated*/                                         \
+  blas.VMUL(D, wc_data + D2, ct, gates + D);              \
+  blas.VADD(D, gates + D, gates + D3, gates + D3);        \
+  act_gate(D, gates + D3, gates + D3);                    \
+  GET_Ht(ct, gates, ht)
+
+  void SeqCompute(const framework::ExecutionContext& ctx) const {
+    using DeviceContext = paddle::platform::CPUDeviceContext;
+    INIT_BASE_INPUT_OUTPUT
+    INIT_BASE_SIZES
+    INIT_VEC_FUNC
+    INIT_BASE_INPUT_DATAS
+
+    //  std::cout << "====> SeqCompute" << std::endl;
+    auto ids_lod = ids->lod();
+    const int total_T = ids_dims[0];
+    const int N = ids_lod[0].size() - 1;
+    const T* h0_data = h0 ? h0->data<T>() : nullptr;
+    const T* c0_data = c0 ? c0->data<T>() : nullptr;
+    T* xx_data = xx->mutable_data<T>(place);
+    T* h_out_data = hidden_out->mutable_data<T>(place);
+    T* c_out_data = cell_out->mutable_data<T>(place);
+    auto blas = math::GetBlas<DeviceContext, T>(ctx);
+
+    for (int64_t i = 0; i < ids_numel; ++i) {
+      PADDLE_ENFORCE_LT(ids_data[i], row_number);
+      PADDLE_ENFORCE_GE(ids_data[i], 0, "ids %d", i);
+      memcpy(xx_data + i * row_width, embeddings_data + ids_data[i] * row_width,
+             row_width * sizeof(T));
+    }
+
+    int xx_offset = D4;
+    int gate_offset = D;
+    if (is_reverse) {
+      const int offset = (total_T - 1) * D;
+      xx_data = xx_data + offset * 4;
+      h_out_data = h_out_data + offset;
+      c_out_data = c_out_data + offset;
+      xx_offset = -D4;
+      gate_offset = -D;
+    }
+
+#define MOVE_ONE_STEP                    \
+  prev_h_data = h_out_data;              \
+  prev_c_data = c_out_data;              \
+  xx_data = xx_data + xx_offset;         \
+  h_out_data = h_out_data + gate_offset; \
+  c_out_data = c_out_data + gate_offset
+
+#define PROCESS_H0C0_DEFINES                           \
+  int bid = is_reverse ? N - 1 - i : i;                \
+  int seq_len = ids_lod[0][bid + 1] - ids_lod[0][bid]; \
+  const T* prev_c_data = nullptr;                      \
+  const T* prev_h_data = nullptr;                      \
+  int tstart = 0
+
+#define PROCESS_H0C0_PEEPHOLE                                      \
+  PROCESS_H0C0_DEFINES;                                            \
+  if (h0_data) {                                                   \
+    prev_h_data = h0_data + bid * D;                               \
+    prev_c_data = c0_data + bid * D;                               \
+  } else {                                                         \
+    COMPUTE_CtHt_PEEPHOLE_NOH0C0(xx_data, c_out_data, h_out_data); \
+    MOVE_ONE_STEP;                                                 \
+    tstart = 1;                                                    \
+  }
+
+#define PROCESS_H0C0                                      \
+  PROCESS_H0C0_DEFINES;                                   \
+  if (h0_data) {                                          \
+    prev_h_data = h0_data + bid * D;                      \
+    prev_c_data = c0_data + bid * D;                      \
+  } else {                                                \
+    COMPUTE_CtHt_NOH0C0(xx_data, c_out_data, h_out_data); \
+    MOVE_ONE_STEP;                                        \
+    tstart = 1;                                           \
+  }
+
+    if (use_peepholes) {
+      for (int i = 0; i < N; ++i) {
+        PROCESS_H0C0_PEEPHOLE
+        for (int step = tstart; step < seq_len; ++step) {
+          GEMM_WH_ADDON(1, prev_h_data, xx_data);
+          COMPUTE_CtHt_PEEPHOLE(xx_data, prev_c_data, c_out_data, h_out_data);
+          MOVE_ONE_STEP;
+        }
+      }
+    } else {
+      for (int i = 0; i < N; ++i) {
+        PROCESS_H0C0
+        for (int step = tstart; step < seq_len; ++step) {
+          GEMM_WH_ADDON(1, prev_h_data, xx_data);
+          COMPUTE_CtHt(xx_data, prev_c_data, c_out_data, h_out_data);
+          MOVE_ONE_STEP;
+        }
+      }
+    }
+#undef PROCESS_H0C0_DEFINES
+#undef PROCESS_H0C0_PEEPHOLE
+#undef PROCESS_H0C0
+#undef MOVE_ONE_STEP
+  }
+
+  void BatchCompute(const framework::ExecutionContext& ctx) const {
+    using DeviceContext = platform::CPUDeviceContext;
+    INIT_BASE_INPUT_OUTPUT
+    if (ids->lod()[0].size() == 2) {
+      SeqCompute(ctx);
+      return;
+    }
+    INIT_BASE_SIZES
+    INIT_VEC_FUNC
+    INIT_BASE_INPUT_DATAS
+
+    // std::cout << "===> Batch Compute" << std::endl;
+
+    auto* reordered_h0 = ctx.Output<Tensor>("ReorderedH0");
+    auto* reordered_c0 = ctx.Output<Tensor>("ReorderedC0");
+    auto* batched_input = ctx.Output<LoDTensor>("BatchedInput");
+    auto* batched_c_out = ctx.Output<LoDTensor>("BatchedCell");
+    auto* batched_h_out = ctx.Output<LoDTensor>("BatchedHidden");
+    T* xx_data = xx->mutable_data<T>(place);
+    T* batched_input_data = batched_input->mutable_data<T>(place);
+    T* batched_c_out_data = batched_c_out->mutable_data<T>(place);
+    T* batched_h_out_data = batched_h_out->mutable_data<T>(place);
+    hidden_out->mutable_data<T>(place);
+    cell_out->mutable_data<T>(place);
+
+    math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
+    auto& dev_ctx = ctx.template device_context<DeviceContext>();
+    auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
+
+    for (int64_t i = 0; i < ids_numel; ++i) {
+      PADDLE_ENFORCE_LT(ids_data[i], row_number);
+      PADDLE_ENFORCE_GE(ids_data[i], 0, "ids %d", i);
+      memcpy(xx_data + i * row_width, embeddings_data + ids_data[i] * row_width,
+             row_width * sizeof(T));
+    }
+
+    to_batch(dev_ctx, *xx, batched_input, true, is_reverse);
+
+    auto batched_lod = batched_input->lod();
+    const auto& seq_order = batched_lod[2];
+    const int max_bs = seq_order.size();
+    reordered_h0->Resize({max_bs, D});
+    reordered_c0->Resize({max_bs, D});
+
+    int tstart = 0;
+    T* prev_h_data = nullptr;
+    T* prev_c_data = nullptr;
+    if (h0) {
+      // reorder h0, c0
+      T* reordered_h0_data = reordered_h0->mutable_data<T>(place);
+      T* reordered_c0_data = reordered_c0->mutable_data<T>(place);
+      const T* h0_data = h0->data<T>();
+      const T* c0_data = c0->data<T>();
+      prev_h_data = reordered_h0_data;
+      prev_c_data = reordered_c0_data;
+      size_t sz = sizeof(T) * D;
+      for (int i = 0; i < max_bs; ++i) {
+        std::memcpy(reordered_h0_data, h0_data + seq_order[i] * D, sz);
+        std::memcpy(reordered_c0_data, c0_data + seq_order[i] * D, sz);
+        reordered_h0_data += D;
+        reordered_c0_data += D;
+      }
+    } else {
+      // compute without h0, c0
+      T* cur_in_data = batched_input_data;
+      T* cur_h_out_data = batched_h_out_data;
+      T* cur_c_out_data = batched_c_out_data;
+      for (int i = 0; i < max_bs; ++i) {
+        GET_Ct_NOH0C0(cur_in_data, cur_c_out_data);
+        if (use_peepholes) {
+          blas.VMUL(D, wc_data + D2, cur_c_out_data, cur_in_data + D);
+          blas.VADD(D, cur_in_data + D, cur_in_data + D3, cur_in_data + D3);
+        }
+        act_gate(D, cur_in_data + D3, cur_in_data + D3);
+        GET_Ht(cur_c_out_data, cur_in_data, cur_h_out_data);
+        cur_in_data += D4;
+        cur_c_out_data += D;
+        cur_h_out_data += D;
+      }
+      tstart = 1;
+      prev_h_data = batched_h_out_data;
+      prev_c_data = batched_c_out_data;
+    }
+    const auto& batch_starts = batched_lod[0];
+    const int max_seq_len = batch_starts.size() - 1;
+    const int offset = tstart * max_bs * D;
+    batched_input_data = batched_input_data + offset * 4;
+    batched_h_out_data = batched_h_out_data + offset;
+    batched_c_out_data = batched_c_out_data + offset;
+
+#define DEFINE_CUR                        \
+  T* cur_in_data = batched_input_data;    \
+  T* cur_prev_c_data = prev_c_data;       \
+  T* cur_c_out_data = batched_c_out_data; \
+  T* cur_h_out_data = batched_h_out_data
+
+#define MOVE_ONE_BATCH  \
+  cur_in_data += D4;    \
+  cur_prev_c_data += D; \
+  cur_c_out_data += D;  \
+  cur_h_out_data += D
+
+#define MOVE_ONE_STEP                  \
+  prev_c_data = batched_c_out_data;    \
+  prev_h_data = batched_h_out_data;    \
+  batched_c_out_data = cur_c_out_data; \
+  batched_h_out_data = cur_h_out_data; \
+  batched_input_data = cur_in_data
+
+    if (use_peepholes) {
+      for (int step = tstart; step < max_seq_len; ++step) {
+        const int cur_bs = batch_starts[step + 1] - batch_starts[step];
+        GEMM_WH_ADDON(cur_bs, prev_h_data, batched_input_data);
+        DEFINE_CUR;
+        for (int i = 0; i < cur_bs; ++i) {
+          COMPUTE_CtHt_PEEPHOLE(cur_in_data, cur_prev_c_data, cur_c_out_data,
+                                cur_h_out_data);
+          MOVE_ONE_BATCH;
+        }
+        MOVE_ONE_STEP;
+      }
+    } else {
+      for (int step = tstart; step < max_seq_len; ++step) {
+        const int cur_bs = batch_starts[step + 1] - batch_starts[step];
+        GEMM_WH_ADDON(cur_bs, prev_h_data, batched_input_data);
+        DEFINE_CUR;
+        for (int i = 0; i < cur_bs; ++i) {
+          COMPUTE_CtHt(cur_in_data, cur_prev_c_data, cur_c_out_data,
+                       cur_h_out_data);
+          MOVE_ONE_BATCH;
+        }
+        MOVE_ONE_STEP;
+      }
+    }
+#undef MOVE_ONE_STEP
+#undef MOVE_ONE_BATCH
+#undef DEFINE_CUR
+
+    math::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
+    batched_h_out->set_lod(batched_lod);
+    to_seq(dev_ctx, *batched_h_out, hidden_out);
+    batched_c_out->set_lod(batched_lod);
+    to_seq(dev_ctx, *batched_c_out, cell_out);
+  }
+
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    if (ctx.Attr<bool>("use_seq")) {
+      SeqCompute(ctx);
+    } else {
+      BatchCompute(ctx);
+    }
+  }
+
+#undef COMPUTE_CtHt_PEEPHOLE
+#undef COMPUTE_CtHt
+#undef GET_Ct_NOH0C0
+#undef COMPUTE_CtHt_NOH0C0
+#undef COMPUTE_CtHt_PEEPHOLE_NOH0C0
+#undef GET_Ht
+#undef GET_Ct
+#undef GEMM_WH_ADDON
+#undef INIT_BASE_INPUT_DATAS
+#undef INIT_BASE_SIZES
+#undef INIT_BASE_INPUT_OUTPUT
+#undef INIT_VEC_FUNC
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OPERATOR(fused_embedding_fc_lstm, ops::FusedEmbeddingFCLSTMOp,
+                  ops::FusedEmbeddingFCLSTMOpMaker,
+                  paddle::framework::DefaultGradOpDescMaker<true>);
+
+REGISTER_OP_CPU_KERNEL(fused_embedding_fc_lstm,
+                       ops::FusedEmbeddingFCLSTMKernel<float>,
+                       ops::FusedEmbeddingFCLSTMKernel<double>);
diff --git a/paddle/fluid/operators/fused_embedding_fc_lstm_op.h b/paddle/fluid/operators/fused_embedding_fc_lstm_op.h
new file mode 100644
index 0000000000..2775b2ac04
--- /dev/null
+++ b/paddle/fluid/operators/fused_embedding_fc_lstm_op.h
@@ -0,0 +1,41 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include "paddle/fluid/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+using LoDTensor = framework::LoDTensor;
+using Tensor = framework::Tensor;
+
+class FusedEmbeddingFCLSTMOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override;
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override;
+};
+
+class FusedEmbeddingFCLSTMOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override;
+};
+
+}  // namespace operators
+}  // namespace paddle

From d5114c60b098a3c5f778d48b70d0683b093b49db Mon Sep 17 00:00:00 2001
From: Jacek Czaja <jacek.czaja@intel.com>
Date: Tue, 25 Sep 2018 11:00:30 +0200
Subject: [PATCH 02/13] - Reviewers suggesstions to fused_embedding_fc_lstm_op

---
 .../fluid/framework/ir/embedding_fc_lstm_fuse_pass.cc | 11 ++++++-----
 paddle/fluid/operators/fused_embedding_fc_lstm_op.cc  |  4 ----
 2 files changed, 6 insertions(+), 9 deletions(-)

diff --git a/paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.cc b/paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.cc
index 38495125c3..af3f23cbf9 100644
--- a/paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.cc
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #include "paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.h"
+#include <algorithm>
 #include <string>
 #include "paddle/fluid/framework/lod_tensor.h"
 
@@ -98,17 +99,17 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
 
     // Copy only gate biases values (only actual bias data, not peephole
     // weights)
-    std::vector<float> combined_biases(n, 0.0f);
-    memcpy(&combined_biases[0], lstm_bias_tensor.data<float>(),
-           n * sizeof(float));
+    std::vector<float> combined_biases;
+    combined_biases.reserve(n);
+    std::copy_n(lstm_bias_tensor.data<float>(), n,
+                std::back_inserter(combined_biases));
 
     if (with_fc_bias) {
       // Add FC-bias with LSTM-bias (into GEMM result to be)
       auto* fc_bias_var = scope->FindVar(fc_bias->Name());
       const auto& fc_bias_tensor = fc_bias_var->Get<framework::LoDTensor>();
       for (int i = 0; i < fc_bias_tensor.numel(); i++) {
-        combined_biases[i] =
-            lstm_bias_tensor.data<float>()[i] + fc_bias_tensor.data<float>()[i];
+        combined_biases[i] += fc_bias_tensor.data<float>()[i];
       }
     }
 
diff --git a/paddle/fluid/operators/fused_embedding_fc_lstm_op.cc b/paddle/fluid/operators/fused_embedding_fc_lstm_op.cc
index 3c4cc77452..0b917a4036 100644
--- a/paddle/fluid/operators/fused_embedding_fc_lstm_op.cc
+++ b/paddle/fluid/operators/fused_embedding_fc_lstm_op.cc
@@ -63,10 +63,6 @@ void FusedEmbeddingFCLSTMOp::InferShape(
   auto embeddings_dims = ctx->GetInputDim("Embeddings");
   PADDLE_ENFORCE_EQ(embeddings_dims.size(), 2,
                     "The rank of Input(Embeddings) should be 2.");
-  //  PADDLE_ENFORCE_EQ(wx_dims[0], x_dims[1],
-  //                    "The first dimension of Input(Embeddings) "
-  //                    "should be %d.",
-  //                    x_dims[1]);
 
   auto wh_dims = ctx->GetInputDim("WeightH");
   int frame_size = wh_dims[1] / 4;

From 910cd415f2147291f5cee83c103c1a1bd84e982f Mon Sep 17 00:00:00 2001
From: Jacek Czaja <jacek.czaja@intel.com>
Date: Thu, 27 Sep 2018 14:01:11 +0200
Subject: [PATCH 03/13] - Disabled embedding_fc_lstm_fuse by defult and  
 extended test_text_classification ot use new op

---
 paddle/fluid/inference/api/paddle_inference_api.h   |  2 +-
 .../api/analyzer_text_classification_tester.cc      | 13 +++++++++++++
 2 files changed, 14 insertions(+), 1 deletion(-)

diff --git a/paddle/fluid/inference/api/paddle_inference_api.h b/paddle/fluid/inference/api/paddle_inference_api.h
index 984358b2bd..77b04bb6f5 100644
--- a/paddle/fluid/inference/api/paddle_inference_api.h
+++ b/paddle/fluid/inference/api/paddle_inference_api.h
@@ -216,7 +216,7 @@ struct AnalysisConfig : public NativeConfig {
   bool enable_ir_optim = true;
   // Manually determine the IR passes to run.
   IrPassMode ir_mode{IrPassMode::kExclude};
-  std::vector<std::string> ir_passes;
+  std::vector<std::string> ir_passes{"embedding_fc_lstm_fuse_pass"};
 
   // NOTE this is just for internal development, please not use it.
   bool _use_mkldnn{false};
diff --git a/paddle/fluid/inference/tests/api/analyzer_text_classification_tester.cc b/paddle/fluid/inference/tests/api/analyzer_text_classification_tester.cc
index 340ef152f0..ca19475bda 100644
--- a/paddle/fluid/inference/tests/api/analyzer_text_classification_tester.cc
+++ b/paddle/fluid/inference/tests/api/analyzer_text_classification_tester.cc
@@ -104,5 +104,18 @@ TEST(Analyzer_Text_Classification, compare) {
   CompareNativeAndAnalysis(cfg, input_slots_all);
 }
 
+TEST(Analyzer_Text_Classification, compare_against_embedding_fc_lstm_fused) {
+  AnalysisConfig cfg;
+  SetConfig(&cfg);
+  // Enable embedding_fc_lstm_fuse_pass (disabled by default)
+  auto it = std::find(cfg.ir_passes.begin(), cfg.ir_passes.end(),
+                      "embedding_fc_lstm_fuse_pass");
+  if (it != cfg.ir_passes.end()) cfg.ir_passes.erase(it);
+
+  std::vector<std::vector<PaddleTensor>> input_slots_all;
+  SetInput(&input_slots_all);
+  CompareNativeAndAnalysis(cfg, input_slots_all);
+}
+
 }  // namespace inference
 }  // namespace paddle

From 1df69f7c9dc53e317babc32d0d91842a11fedd97 Mon Sep 17 00:00:00 2001
From: Jacek Czaja <jacek.czaja@intel.com>
Date: Fri, 28 Sep 2018 09:42:13 +0200
Subject: [PATCH 04/13] -  Fix to comment

test=develop
---
 paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.cc | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.cc b/paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.cc
index af3f23cbf9..b155da375f 100644
--- a/paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.cc
@@ -199,7 +199,7 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
       embedding_lstm_creator(lookup_table, W, lstm, subgraph.at(x), w, Weight,
                              Bias, Hidden, Cell, fc_out, fc_bias);
       // Remove unneeded nodes.
-      // TODO(jczaja): Proper removing of loopup table
+      // TODO(jczaja): Proper removing of lookup table
       std::unordered_set<const Node*> marked_nodes(
           //{lookup_table, mul, lstm, elementwise_add, fc_bias, W});
           {mul, lstm, elementwise_add, fc_bias});
@@ -209,7 +209,7 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
       embedding_lstm_creator(lookup_table, W, lstm, subgraph.at(x), w, Weight,
                              Bias, Hidden, Cell, fc_out, nullptr);
       // Remove unneeded nodes.
-      // TODO(jczaja): Proper removing of loopup table
+      // TODO(jczaja): Proper removing of lookup table
       // std::unordered_set<const Node*> marked_nodes({lookup_table, W, mul,
       // lstm});
       std::unordered_set<const Node*> marked_nodes({mul, lstm});

From 9ae5baebfa3939a3af07a3e4338a34bb5667c993 Mon Sep 17 00:00:00 2001
From: JiabinYang <marsyang199376@gmail.com>
Date: Fri, 28 Sep 2018 07:52:24 +0000
Subject: [PATCH 05/13] test=develop

---
 paddle/legacy/trainer/tests/CMakeLists.txt       |  6 +++++-
 .../recognize_digits/CMakeLists.txt              | 16 +++++++++++++---
 2 files changed, 18 insertions(+), 4 deletions(-)

diff --git a/paddle/legacy/trainer/tests/CMakeLists.txt b/paddle/legacy/trainer/tests/CMakeLists.txt
index 08548bea4c..fbefcced56 100644
--- a/paddle/legacy/trainer/tests/CMakeLists.txt
+++ b/paddle/legacy/trainer/tests/CMakeLists.txt
@@ -16,7 +16,11 @@ endfunction()
 trainer_test(test_Compare)
 trainer_test(test_PyDataProviderWrapper)
 trainer_test(test_recurrent_machine_generation)
-trainer_test(test_Trainer)
+if(NOT APPLE)
+  trainer_test(test_Trainer)
+else()
+  message(WARNING "These tests has been disabled in OSX for random fail: \n test_Trainer") 
+endif()
 
 ############### test_TrainerOnePass ##########################
 if(WITH_PYTHON)
diff --git a/python/paddle/fluid/tests/book/high-level-api/recognize_digits/CMakeLists.txt b/python/paddle/fluid/tests/book/high-level-api/recognize_digits/CMakeLists.txt
index 673c965b66..ad056aaa7b 100644
--- a/python/paddle/fluid/tests/book/high-level-api/recognize_digits/CMakeLists.txt
+++ b/python/paddle/fluid/tests/book/high-level-api/recognize_digits/CMakeLists.txt
@@ -2,6 +2,16 @@ file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py")
 string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
 
 # default test
-foreach(src ${TEST_OPS})
-    py_test(${src} SRCS ${src}.py)
-endforeach()
+if(NOT APPLE)
+    foreach(src ${TEST_OPS})
+        py_test(${src} SRCS ${src}.py)
+    endforeach()
+else()
+    foreach(src ${TEST_OPS})
+        if(${src} STREQUAL "test_recognize_digits_conv")
+            message(WARNING "These tests has been disabled in OSX for random fail: \n" ${src})
+        else()
+            py_test(${src} SRCS ${src}.py)
+        endif()
+    endforeach()
+endif()

From 358b38695356226875aa7495244e2ea70e8224e9 Mon Sep 17 00:00:00 2001
From: JiabinYang <marsyang199376@gmail.com>
Date: Fri, 28 Sep 2018 10:34:15 +0000
Subject: [PATCH 06/13] test=develop

---
 paddle/fluid/inference/api/api_impl_tester.cc | 16 +++++++++++-----
 1 file changed, 11 insertions(+), 5 deletions(-)

diff --git a/paddle/fluid/inference/api/api_impl_tester.cc b/paddle/fluid/inference/api/api_impl_tester.cc
index 106a941b29..bed7c87131 100644
--- a/paddle/fluid/inference/api/api_impl_tester.cc
+++ b/paddle/fluid/inference/api/api_impl_tester.cc
@@ -21,6 +21,12 @@ limitations under the License. */
 #include "paddle/fluid/inference/api/api_impl.h"
 #include "paddle/fluid/inference/tests/test_helper.h"
 
+#ifdef __clang__
+#define ACC_DIFF 4e-3
+#else
+#define ACC_DIFF 1e-3
+#endif
+
 DEFINE_string(dirname, "", "Directory of the inference model.");
 
 namespace paddle {
@@ -99,8 +105,8 @@ void MainWord2Vec(bool use_gpu) {
 
   float* lod_data = output1.data<float>();
   for (int i = 0; i < output1.numel(); ++i) {
-    EXPECT_LT(lod_data[i] - data[i], 1e-3);
-    EXPECT_GT(lod_data[i] - data[i], -1e-3);
+    EXPECT_LT(lod_data[i] - data[i], ACC_DIFF);
+    EXPECT_GT(lod_data[i] - data[i], -ACC_DIFF);
   }
 }
 
@@ -144,7 +150,7 @@ void MainImageClassification(bool use_gpu) {
   float* data = static_cast<float*>(outputs[0].data.data());
   float* lod_data = output1.data<float>();
   for (size_t j = 0; j < len / sizeof(float); ++j) {
-    EXPECT_NEAR(lod_data[j], data[j], 1e-3);
+    EXPECT_NEAR(lod_data[j], data[j], ACC_DIFF);
   }
 }
 
@@ -199,7 +205,7 @@ void MainThreadsWord2Vec(bool use_gpu) {
       float* ref_data = refs[tid].data<float>();
       EXPECT_EQ(refs[tid].numel(), static_cast<int64_t>(len / sizeof(float)));
       for (int i = 0; i < refs[tid].numel(); ++i) {
-        EXPECT_NEAR(ref_data[i], data[i], 1e-3);
+        EXPECT_NEAR(ref_data[i], data[i], ACC_DIFF);
       }
     });
   }
@@ -251,7 +257,7 @@ void MainThreadsImageClassification(bool use_gpu) {
       float* ref_data = refs[tid].data<float>();
       EXPECT_EQ((size_t)refs[tid].numel(), len / sizeof(float));
       for (int i = 0; i < refs[tid].numel(); ++i) {
-        EXPECT_NEAR(ref_data[i], data[i], 1e-3);
+        EXPECT_NEAR(ref_data[i], data[i], ACC_DIFF);
       }
     });
   }

From e202f33aa96ee8c44f9bac892881dce0fe5067be Mon Sep 17 00:00:00 2001
From: Jacek Czaja <jacek.czaja@intel.com>
Date: Fri, 28 Sep 2018 13:13:43 +0200
Subject: [PATCH 07/13] - Yet another clarification to comment

test=develop
---
 paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.cc | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.cc b/paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.cc
index b155da375f..ba11f19c92 100644
--- a/paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.cc
@@ -119,7 +119,7 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
         CblasRowMajor, CblasNoTrans, CblasNoTrans, m, n, 1, alpha, &ones[0], 1,
         &combined_biases[0], n, 0.0f, embeddings_data, n);
 
-    // Wx*embeddings
+    // Wx*embeddings + biases
     paddle::operators::math::CBlas<float>::GEMM(
         CblasRowMajor, CblasNoTrans, CblasNoTrans, m, n, k, alpha,
         embedding_data, k, weightx_data, n, beta, embeddings_data, n);

From ddd60581b7f442e8f232f83a760c3d4c537a16b1 Mon Sep 17 00:00:00 2001
From: Xin Pan <panxin.grad@gmail.com>
Date: Fri, 28 Sep 2018 10:19:55 +0800
Subject: [PATCH 08/13] clean up channel

test=develop
---
 paddle/fluid/framework/CMakeLists.txt         |    7 -
 paddle/fluid/framework/channel.h              |  291 -----
 paddle/fluid/framework/channel_impl.h         |  369 ------
 paddle/fluid/framework/channel_test.cc        | 1008 -----------------
 paddle/fluid/framework/concurrency_test.cc    |  292 -----
 paddle/fluid/framework/executor.cc            |    5 +-
 paddle/fluid/framework/framework.proto        |    7 -
 paddle/fluid/framework/tuple.h                |    1 -
 paddle/fluid/framework/var_desc.cc            |   54 +-
 paddle/fluid/framework/var_desc.h             |    4 -
 paddle/fluid/framework/var_type.h             |    6 -
 .../fluid/inference/analysis/analysis_pass.h  |    6 -
 paddle/fluid/operators/CMakeLists.txt         |    5 -
 paddle/fluid/operators/channel_close_op.cc    |   70 --
 paddle/fluid/operators/channel_create_op.cc   |  113 --
 paddle/fluid/operators/channel_recv_op.cc     |   98 --
 paddle/fluid/operators/channel_send_op.cc     |   76 --
 .../operators/concurrency/CMakeLists.txt      |    1 -
 .../operators/concurrency/channel_util.cc     |  111 --
 .../operators/concurrency/channel_util.h      |   38 -
 paddle/fluid/operators/select_op.cc           |  419 -------
 paddle/fluid/pybind/protobuf.cc               |    2 -
 paddle/fluid/pybind/pybind.cc                 |    1 -
 python/paddle/fluid/concurrency.py            |  454 --------
 python/paddle/fluid/framework.py              |    3 +-
 .../paddle/fluid/tests/no_test_concurrency.py |  260 -----
 .../paddle/fluid/tests/notest_concurrency.py  |   41 -
 27 files changed, 4 insertions(+), 3738 deletions(-)
 delete mode 100644 paddle/fluid/framework/channel.h
 delete mode 100644 paddle/fluid/framework/channel_impl.h
 delete mode 100644 paddle/fluid/framework/channel_test.cc
 delete mode 100644 paddle/fluid/framework/concurrency_test.cc
 delete mode 100644 paddle/fluid/operators/channel_close_op.cc
 delete mode 100644 paddle/fluid/operators/channel_create_op.cc
 delete mode 100644 paddle/fluid/operators/channel_recv_op.cc
 delete mode 100644 paddle/fluid/operators/channel_send_op.cc
 delete mode 100644 paddle/fluid/operators/concurrency/CMakeLists.txt
 delete mode 100644 paddle/fluid/operators/concurrency/channel_util.cc
 delete mode 100644 paddle/fluid/operators/concurrency/channel_util.h
 delete mode 100644 paddle/fluid/operators/select_op.cc
 delete mode 100644 python/paddle/fluid/concurrency.py
 delete mode 100644 python/paddle/fluid/tests/no_test_concurrency.py
 delete mode 100644 python/paddle/fluid/tests/notest_concurrency.py

diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt
index 39898dd236..de960dba8f 100644
--- a/paddle/fluid/framework/CMakeLists.txt
+++ b/paddle/fluid/framework/CMakeLists.txt
@@ -169,15 +169,8 @@ cc_test(selected_rows_test SRCS selected_rows_test.cc DEPS selected_rows)
 cc_test(op_kernel_type_test SRCS op_kernel_type_test.cc DEPS place device_context framework_proto)
 cc_test(cow_ptr_tests SRCS details/cow_ptr_test.cc)
 
-# cc_test(channel_test SRCS channel_test.cc)
 cc_test(tuple_test SRCS tuple_test.cc )
 
 if (NOT WIN32)
 cc_test(rw_lock_test SRCS rw_lock_test.cc)
 endif (NOT WIN32)
-
-# disable test temporarily.
-# TODO https://github.com/PaddlePaddle/Paddle/issues/11971
-# cc_test(concurrency_test SRCS concurrency_test.cc DEPS go_op channel_close_op channel_create_op
-#         channel_send_op channel_recv_op sum_op select_op elementwise_add_op compare_op
-#         conditional_block_op while_op assign_op print_op executor proto_desc)
diff --git a/paddle/fluid/framework/channel.h b/paddle/fluid/framework/channel.h
deleted file mode 100644
index 722bf8e8ec..0000000000
--- a/paddle/fluid/framework/channel.h
+++ /dev/null
@@ -1,291 +0,0 @@
-/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#pragma once
-
-#include <stddef.h>            // for size_t
-#include <condition_variable>  // NOLINT
-#include <typeindex>
-#include "paddle/fluid/platform/enforce.h"
-
-namespace paddle {
-namespace framework {
-
-enum class ChannelAction {
-  SEND = 0,
-  RECEIVE = 1,
-  CLOSE = 2,
-};
-
-// Channel is the abstract class of buffered and un-buffered channels.
-template <typename T>
-class Channel {
- public:
-  virtual bool CanSend() = 0;
-  virtual bool CanReceive() = 0;
-  virtual void Send(T*) = 0;
-  virtual bool Receive(T*) = 0;
-  virtual size_t Cap() = 0;
-  virtual void Lock() = 0;
-
-  virtual void Unlock() = 0;
-  virtual bool IsClosed() = 0;
-  virtual void Close() = 0;
-  virtual ~Channel() {}
-
-  virtual void AddToSendQ(const void* referrer, T* data,
-                          std::shared_ptr<std::condition_variable_any> cond,
-                          std::function<bool(ChannelAction)> cb) = 0;
-  virtual void AddToReceiveQ(const void* referrer, T* data,
-                             std::shared_ptr<std::condition_variable_any> cond,
-                             std::function<bool(ChannelAction)> cb) = 0;
-  virtual void RemoveFromSendQ(const void* referrer) = 0;
-  virtual void RemoveFromReceiveQ(const void* referrer) = 0;
-};
-
-// Forward declaration of channel implementations.
-template <typename T>
-class ChannelImpl;
-
-template <typename T>
-Channel<T>* MakeChannel(size_t buffer_size) {
-  return new ChannelImpl<T>(buffer_size);
-}
-
-template <typename T>
-void CloseChannel(Channel<T>* ch) {
-  ch->Close();
-}
-
-/*
- * The ChannelHolder class serves two main purposes:
- * 1. It acts as a unified wrapper for the different kinds of
- *    channels, i.e. Buffered and Unbuffered channels. This is
- *    similar to the ReaderHolder class.
- * 2. It also helps us in TypeHiding. This is similar to the
- *    PlaceHolder implementations in variable.h and tensor.h.
- */
-class ChannelHolder {
- public:
-  template <typename T>
-  void Reset(size_t buffer_size) {
-    holder_.reset(new PlaceholderImpl<T>(buffer_size));
-  }
-
-  template <typename T>
-  void Send(T* data) {
-    PADDLE_ENFORCE_EQ(IsInitialized(), true,
-                      "The Channel hasn't been initialized");
-    PADDLE_ENFORCE_EQ(
-        holder_->Type(), std::type_index(typeid(T)),
-        "Channel type is not same as the type of the data being sent");
-    // Static cast should be safe because we have ensured that types are same
-    Channel<T>* channel = static_cast<Channel<T>*>(holder_->Ptr());
-    PADDLE_ENFORCE_EQ(channel != nullptr, true, "Channel should not be null.");
-    channel->Send(data);
-  }
-
-  template <typename T>
-  bool Receive(T* data) {
-    PADDLE_ENFORCE_EQ(IsInitialized(), true,
-                      "The Channel hasn't been initialized");
-    PADDLE_ENFORCE_EQ(
-        holder_->Type(), std::type_index(typeid(T)),
-        "Channel type is not same as the type of the data being sent");
-    Channel<T>* channel = static_cast<Channel<T>*>(holder_->Ptr());
-    PADDLE_ENFORCE_EQ(channel != nullptr, true, "Channel should not be null.");
-    return channel->Receive(data);
-  }
-
-  bool IsClosed() {
-    PADDLE_ENFORCE_EQ(IsInitialized(), true,
-                      "The Channel hasn't been initialized");
-    return holder_->IsClosed();
-  }
-
-  bool CanSend() {
-    PADDLE_ENFORCE_EQ(IsInitialized(), true,
-                      "The Channel hasn't been initialized");
-    return holder_->CanSend();
-  }
-
-  bool CanReceive() {
-    PADDLE_ENFORCE_EQ(IsInitialized(), true,
-                      "The Channel hasn't been initialized");
-    return holder_->CanReceive();
-  }
-
-  void close() {
-    PADDLE_ENFORCE_EQ(IsInitialized(), true,
-                      "The Channel hasn't been initialized");
-    holder_->Close();
-  }
-
-  size_t Cap() {
-    PADDLE_ENFORCE_EQ(IsInitialized(), true,
-                      "The Channel hasn't been initialized");
-    return holder_->Cap();
-  }
-
-  void Lock() {
-    PADDLE_ENFORCE_EQ(IsInitialized(), true,
-                      "The Channel hasn't been initialized");
-    holder_->Lock();
-  }
-
-  void Unlock() {
-    PADDLE_ENFORCE_EQ(IsInitialized(), true,
-                      "The Channel hasn't been initialized");
-    holder_->Unlock();
-  }
-
-  template <typename T>
-  void AddToSendQ(const void* referrer, T* data,
-                  std::shared_ptr<std::condition_variable_any> cond,
-                  std::function<bool(ChannelAction)> cb) {
-    PADDLE_ENFORCE_EQ(IsInitialized(), true,
-                      "The Channel hasn't been initialized");
-    Channel<T>* channel = static_cast<Channel<T>*>(holder_->Ptr());
-    if (channel != nullptr) {
-      channel->AddToSendQ(referrer, data, cond, cb);
-    }
-  }
-
-  template <typename T>
-  void AddToReceiveQ(const void* referrer, T* data,
-                     std::shared_ptr<std::condition_variable_any> cond,
-                     std::function<bool(ChannelAction)> cb) {
-    PADDLE_ENFORCE_EQ(IsInitialized(), true,
-                      "The Channel hasn't been initialized");
-    Channel<T>* channel = static_cast<Channel<T>*>(holder_->Ptr());
-    if (channel != nullptr) {
-      channel->AddToReceiveQ(referrer, data, cond, cb);
-    }
-  }
-
-  void RemoveFromSendQ(const void* referrer) {
-    PADDLE_ENFORCE_EQ(IsInitialized(), true,
-                      "The Channel hasn't been initialized");
-    holder_->RemoveFromSendQ(referrer);
-  }
-
-  void RemoveFromReceiveQ(const void* referrer) {
-    PADDLE_ENFORCE_EQ(IsInitialized(), true,
-                      "The Channel hasn't been initialized");
-    holder_->RemoveFromReceiveQ(referrer);
-  }
-
-  inline bool IsInitialized() const { return holder_ != nullptr; }
-
-  inline const std::type_index Type() {
-    PADDLE_ENFORCE_EQ(IsInitialized(), true,
-                      "The Channel hasn't been initialized");
-    return holder_->Type();
-  }
-
- private:
-  /**
-   * @note    Placeholder hides type T, so it doesn't appear as a template
-   *          parameter of ChannelHolder.
-   */
-  struct Placeholder {
-    virtual ~Placeholder() {}
-    virtual const std::type_index Type() const = 0;
-    virtual void* Ptr() const = 0;
-    virtual bool IsClosed() = 0;
-    virtual bool CanSend() = 0;
-    virtual bool CanReceive() = 0;
-    virtual void RemoveFromSendQ(const void* referrer) = 0;
-    virtual void RemoveFromReceiveQ(const void* referrer) = 0;
-    virtual void Close() = 0;
-    virtual void Lock() = 0;
-    virtual void Unlock() = 0;
-    virtual size_t Cap() = 0;
-  };
-
-  template <typename T>
-  struct PlaceholderImpl : public Placeholder {
-    explicit PlaceholderImpl(size_t buffer_size)
-        : type_(std::type_index(typeid(T))) {
-      channel_.reset(MakeChannel<T>(buffer_size));
-    }
-
-    virtual const std::type_index Type() const { return type_; }
-
-    virtual void* Ptr() const { return static_cast<void*>(channel_.get()); }
-
-    virtual bool IsClosed() {
-      if (channel_) {
-        return channel_->IsClosed();
-      }
-      return false;
-    }
-
-    virtual bool CanSend() {
-      if (channel_) {
-        return channel_->CanSend();
-      }
-      return false;
-    }
-
-    virtual bool CanReceive() {
-      if (channel_) {
-        return channel_->CanReceive();
-      }
-      return false;
-    }
-
-    virtual void RemoveFromSendQ(const void* referrer) {
-      if (channel_) {
-        channel_->RemoveFromSendQ(referrer);
-      }
-    }
-
-    virtual void RemoveFromReceiveQ(const void* referrer) {
-      if (channel_) {
-        channel_->RemoveFromReceiveQ(referrer);
-      }
-    }
-
-    virtual void Close() {
-      if (channel_) channel_->Close();
-    }
-
-    virtual size_t Cap() {
-      if (channel_)
-        return channel_->Cap();
-      else
-        return -1;
-    }
-
-    virtual void Lock() {
-      if (channel_) channel_->Lock();
-    }
-
-    virtual void Unlock() {
-      if (channel_) channel_->Unlock();
-    }
-
-    std::unique_ptr<Channel<T>> channel_;
-    const std::type_index type_;
-  };
-
-  // Pointer to a PlaceholderImpl object
-  std::unique_ptr<Placeholder> holder_;
-};
-
-}  // namespace framework
-}  // namespace paddle
-
-#include "paddle/fluid/framework/channel_impl.h"
diff --git a/paddle/fluid/framework/channel_impl.h b/paddle/fluid/framework/channel_impl.h
deleted file mode 100644
index 26d454534e..0000000000
--- a/paddle/fluid/framework/channel_impl.h
+++ /dev/null
@@ -1,369 +0,0 @@
-/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#pragma once
-#include <stddef.h>  // for size_t
-#include <atomic>
-#include <condition_variable>  // NOLINT
-#include <deque>
-#include "paddle/fluid/framework/channel.h"
-#include "paddle/fluid/platform/enforce.h"
-
-namespace paddle {
-namespace framework {
-
-template <typename T>
-class ChannelImpl : public paddle::framework::Channel<T> {
-  friend Channel<T> *paddle::framework::MakeChannel<T>(size_t);
-  friend void paddle::framework::CloseChannel<T>(Channel<T> *);
-
- public:
-  virtual bool CanSend();
-  virtual bool CanReceive();
-  virtual void Send(T *);
-  virtual bool Receive(T *);
-  virtual size_t Cap() { return cap_; }
-  virtual void Lock();
-  virtual void Unlock();
-  virtual bool IsClosed();
-  virtual void Close();
-  explicit ChannelImpl(size_t);
-  virtual ~ChannelImpl();
-
-  virtual void AddToSendQ(const void *referrer, T *data,
-                          std::shared_ptr<std::condition_variable_any> cond,
-                          std::function<bool(ChannelAction)> cb);
-  virtual void AddToReceiveQ(const void *referrer, T *data,
-                             std::shared_ptr<std::condition_variable_any> cond,
-                             std::function<bool(ChannelAction)> cb);
-
-  virtual void RemoveFromSendQ(const void *referrer);
-  virtual void RemoveFromReceiveQ(const void *referrer);
-
- private:
-  struct QueueMessage {
-    T *data;
-    std::shared_ptr<std::condition_variable_any> cond;
-    bool chan_closed = false;
-    bool completed = false;
-    const void *referrer;  // TODO(thuan): figure out better way to do this
-    std::function<bool(ChannelAction)> callback;
-
-    explicit QueueMessage(T *item)
-        : data(item), cond(std::make_shared<std::condition_variable_any>()) {}
-
-    QueueMessage(T *item, std::shared_ptr<std::condition_variable_any> cond)
-        : data(item), cond(cond) {}
-
-    void Wait(std::unique_lock<std::recursive_mutex> &lock) {
-      cond->wait(lock, [this]() { return completed; });
-    }
-
-    void Notify() {
-      completed = true;
-      cond->notify_all();
-    }
-  };
-
-  void send_return() {
-    send_ctr--;
-    destructor_cond_.notify_all();
-  }
-
-  bool recv_return(bool value) {
-    recv_ctr--;
-    destructor_cond_.notify_all();
-    return value;
-  }
-
-  std::shared_ptr<QueueMessage> get_first_message(
-      std::deque<std::shared_ptr<QueueMessage>> *queue, ChannelAction action) {
-    while (!queue->empty()) {
-      // Check whether this message was added by Select
-      // If this was added by Select then execute the callback
-      // to check if you can execute this message. The callback
-      // can return false if some other case was executed in Select.
-      // In that case just discard this QueueMessage and process next.
-      std::shared_ptr<QueueMessage> m = queue->front();
-      queue->pop_front();
-      if (m->callback == nullptr || m->callback(action)) return m;
-    }
-    return nullptr;
-  }
-
-  size_t cap_;
-  std::recursive_mutex mu_;
-  bool closed_;
-  std::deque<T> buf_;
-  std::deque<std::shared_ptr<QueueMessage>> recvq;
-  std::deque<std::shared_ptr<QueueMessage>> sendq;
-  std::atomic<unsigned> send_ctr{0};
-  std::atomic<unsigned> recv_ctr{0};
-  std::condition_variable_any destructor_cond_;
-};
-
-template <typename T>
-ChannelImpl<T>::ChannelImpl(size_t capacity)
-    : cap_(capacity), closed_(false), send_ctr(0), recv_ctr(0) {
-  PADDLE_ENFORCE_GE(capacity, 0);
-}
-
-template <typename T>
-bool ChannelImpl<T>::CanSend() {
-  std::lock_guard<std::recursive_mutex> lock{mu_};
-  return !closed_ && (!recvq.empty() || buf_.size() < cap_);
-}
-
-template <typename T>
-bool ChannelImpl<T>::CanReceive() {
-  std::lock_guard<std::recursive_mutex> lock{mu_};
-  return !(closed_ && buf_.empty()) && (!sendq.empty() || buf_.size() > 0);
-}
-
-template <typename T>
-void ChannelImpl<T>::Send(T *item) {
-  send_ctr++;
-  std::unique_lock<std::recursive_mutex> lock{mu_};
-
-  // If channel is closed, throw exception
-  if (closed_) {
-    send_return();
-    lock.unlock();
-    PADDLE_THROW("Cannot send on closed channel");
-  }
-
-  // If there is a receiver, directly pass the value we want
-  // to send to the receiver, bypassing the channel buffer if any
-  if (!recvq.empty()) {
-    std::shared_ptr<QueueMessage> m =
-        get_first_message(&recvq, ChannelAction::SEND);
-
-    if (m != nullptr) {
-      *(m->data) = std::move(*item);
-      m->Notify();
-      send_return();
-      return;
-    } else {
-      Send(item);
-      send_return();
-      return;
-    }
-  }
-
-  // Unbuffered channel will always bypass this
-  // If buffered channel has space in buffer,
-  // write the element to the buffer.
-  if (buf_.size() < cap_) {
-    // Copy to buffer
-    buf_.push_back(std::move(*item));
-    send_return();
-    return;
-  }
-
-  // Block on channel, because some receiver will complete
-  // the operation for us
-  auto m = std::make_shared<QueueMessage>(item);
-  sendq.push_back(m);
-  m->Wait(lock);
-  if (m->chan_closed) {
-    send_return();
-    lock.unlock();
-    PADDLE_THROW("Cannot send on closed channel");
-  }
-  send_return();
-}
-
-template <typename T>
-bool ChannelImpl<T>::Receive(T *item) {
-  recv_ctr++;
-  std::unique_lock<std::recursive_mutex> lock{mu_};
-
-  // If channel is closed and buffer is empty or
-  // channel is unbuffered
-  if (closed_ && buf_.empty()) return recv_return(false);
-
-  // If there is a sender, directly receive the value we want
-  // from the sender. In case of a buffered channel, read from
-  // buffer and move front of send queue to the buffer
-  if (!sendq.empty()) {
-    std::shared_ptr<QueueMessage> m =
-        get_first_message(&sendq, ChannelAction::RECEIVE);
-    if (buf_.size() > 0) {
-      // Case 1 : Channel is Buffered
-      // Do Data transfer from front of buffer
-      // and add a QueueMessage to the buffer
-      *item = std::move(buf_.front());
-      buf_.pop_front();
-      // If first message from sendq is not null
-      // add it to the buffer and notify it
-      if (m != nullptr) {
-        // Copy to buffer
-        buf_.push_back(std::move(*(m->data)));
-        m->Notify();
-      }  // Ignore if there is no first message
-    } else {
-      // Case 2: Channel is Unbuffered
-      // Do data transfer from front of SendQ
-      // If front is nullptr, then recursively call itself
-      if (m != nullptr) {
-        *item = std::move(*(m->data));
-        m->Notify();
-      } else {
-        return recv_return(Receive(item));
-      }
-    }
-    return recv_return(true);
-  }
-
-  // If this is a buffered channel and there are items in buffer
-  if (buf_.size() > 0) {
-    // Directly read from buffer
-    *item = std::move(buf_.front());
-    buf_.pop_front();
-    // return true
-    return recv_return(true);
-  }
-
-  // No sender available, block on this channel
-  // Some receiver will complete the option for us
-  auto m = std::make_shared<QueueMessage>(item);
-  recvq.push_back(m);
-  m->Wait(lock);
-
-  return recv_return(!m->chan_closed);
-}
-
-template <typename T>
-void ChannelImpl<T>::Lock() {
-  mu_.lock();
-}
-
-template <typename T>
-void ChannelImpl<T>::Unlock() {
-  mu_.unlock();
-}
-
-template <typename T>
-bool ChannelImpl<T>::IsClosed() {
-  std::lock_guard<std::recursive_mutex> lock{mu_};
-  return closed_;
-}
-
-template <typename T>
-void ChannelImpl<T>::Close() {
-  std::unique_lock<std::recursive_mutex> lock{mu_};
-
-  if (closed_) {
-    // TODO(abhinavarora): closing an already closed channel should panic
-    lock.unlock();
-    return;
-  }
-
-  closed_ = true;
-
-  // Empty the readers
-  while (!recvq.empty()) {
-    std::shared_ptr<QueueMessage> m = recvq.front();
-    recvq.pop_front();
-    m->chan_closed = true;
-
-    // Execute callback function (if any)
-    if (m->callback != nullptr) {
-      m->callback(ChannelAction::CLOSE);
-    }
-
-    m->Notify();
-  }
-
-  // Empty the senders
-  while (!sendq.empty()) {
-    std::shared_ptr<QueueMessage> m = sendq.front();
-    sendq.pop_front();
-    m->chan_closed = true;
-
-    // Execute callback function (if any)
-    if (m->callback != nullptr) {
-      m->callback(ChannelAction::CLOSE);
-    }
-
-    m->Notify();
-  }
-}
-
-template <typename T>
-void ChannelImpl<T>::AddToSendQ(
-    const void *referrer, T *data,
-    std::shared_ptr<std::condition_variable_any> cond,
-    std::function<bool(ChannelAction)> cb) {
-  std::lock_guard<std::recursive_mutex> lock{mu_};
-  auto m = std::make_shared<QueueMessage>(data, cond);
-  m->referrer = referrer;
-  m->callback = cb;
-  sendq.push_back(m);
-}
-
-template <typename T>
-void ChannelImpl<T>::AddToReceiveQ(
-    const void *referrer, T *data,
-    std::shared_ptr<std::condition_variable_any> cond,
-    std::function<bool(ChannelAction)> cb) {
-  std::lock_guard<std::recursive_mutex> lock{mu_};
-  auto m = std::make_shared<QueueMessage>(data, cond);
-  m->referrer = referrer;
-  m->callback = cb;
-  recvq.push_back(m);
-}
-
-template <typename T>
-void ChannelImpl<T>::RemoveFromSendQ(const void *referrer) {
-  std::lock_guard<std::recursive_mutex> lock{mu_};
-
-  for (auto it = sendq.begin(); it != sendq.end();) {
-    std::shared_ptr<QueueMessage> sendMsg = (std::shared_ptr<QueueMessage>)*it;
-
-    if (sendMsg->referrer == referrer) {
-      it = sendq.erase(it);
-    } else {
-      ++it;
-    }
-  }
-}
-
-template <typename T>
-void ChannelImpl<T>::RemoveFromReceiveQ(const void *referrer) {
-  std::lock_guard<std::recursive_mutex> lock{mu_};
-
-  for (auto it = recvq.begin(); it != recvq.end();) {
-    std::shared_ptr<QueueMessage> recvMsg = (std::shared_ptr<QueueMessage>)*it;
-
-    if (recvMsg->referrer == referrer) {
-      it = recvq.erase(it);
-    } else {
-      ++it;
-    }
-  }
-}
-
-template <typename T>
-ChannelImpl<T>::~ChannelImpl() {
-  Close();
-  // The destructor must wait for all readers and writers to complete their task
-  // The channel has been closed, so we will not accept new readers and writers
-  std::unique_lock<std::recursive_mutex> lock{mu_};
-  destructor_cond_.wait(lock,
-                        [this]() { return send_ctr == 0 && recv_ctr == 0; });
-}
-
-}  // namespace framework
-}  // namespace paddle
diff --git a/paddle/fluid/framework/channel_test.cc b/paddle/fluid/framework/channel_test.cc
deleted file mode 100644
index 542d791f6b..0000000000
--- a/paddle/fluid/framework/channel_test.cc
+++ /dev/null
@@ -1,1008 +0,0 @@
-/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include "paddle/fluid/framework/channel.h"
-
-#include <chrono>  // NOLINT
-#include <thread>  // NOLINT
-#include "gtest/gtest.h"
-
-using paddle::framework::Channel;
-using paddle::framework::ChannelHolder;
-using paddle::framework::MakeChannel;
-using paddle::framework::CloseChannel;
-
-TEST(Channel, ChannelCapacityTest) {
-  const size_t buffer_size = 10;
-  auto ch = MakeChannel<size_t>(buffer_size);
-  EXPECT_EQ(ch->Cap(), buffer_size);
-  CloseChannel(ch);
-  delete ch;
-
-  ch = MakeChannel<size_t>(0);
-  EXPECT_EQ(ch->Cap(), 0U);
-  CloseChannel(ch);
-  delete ch;
-}
-
-void RecevingOrderEqualToSendingOrder(Channel<int> *ch, int num_items) {
-  unsigned sum_send = 0;
-  std::thread t([&]() {
-    for (int i = 0; i < num_items; i++) {
-      ch->Send(&i);
-      sum_send += i;
-    }
-  });
-  std::this_thread::sleep_for(std::chrono::milliseconds(200));
-  for (int i = 0; i < num_items; i++) {
-    int recv = -1;
-    EXPECT_EQ(ch->Receive(&recv), true);
-    EXPECT_EQ(recv, i);
-  }
-  std::this_thread::sleep_for(std::chrono::milliseconds(200));
-  CloseChannel(ch);
-  t.join();
-  unsigned expected_sum = (num_items * (num_items - 1)) / 2;
-  EXPECT_EQ(sum_send, expected_sum);
-  delete ch;
-}
-
-TEST(Channel, SufficientBufferSizeDoesntBlock) {
-  const size_t buffer_size = 10;
-  auto ch = MakeChannel<size_t>(buffer_size);
-  for (size_t i = 0; i < buffer_size; ++i) {
-    ch->Send(&i);
-  }
-
-  size_t out;
-  for (size_t i = 0; i < buffer_size; ++i) {
-    EXPECT_EQ(ch->Receive(&out), true);  // should not block
-    EXPECT_EQ(out, i);
-  }
-  CloseChannel(ch);
-  delete ch;
-}
-
-// This tests that a  channel must return false
-// on send and receive performed after closing the channel.
-// Receive will only return false after close when queue is empty.
-// By creating separate threads for sending and receiving, we make this
-// function able to test both buffered and unbuffered channels.
-void SendReceiveWithACloseChannelShouldPanic(Channel<size_t> *ch) {
-  const size_t data = 5;
-  std::thread send_thread{[&]() {
-    size_t i = data;
-    ch->Send(&i);  // should not block
-  }};
-
-  std::thread recv_thread{[&]() {
-    size_t i;
-    EXPECT_EQ(ch->Receive(&i), true);  // should not block
-    EXPECT_EQ(i, data);
-  }};
-
-  send_thread.join();
-  recv_thread.join();
-
-  // After closing send should panic. Receive should
-  // also  false as there is no data in queue.
-  CloseChannel(ch);
-  send_thread = std::thread{[&]() {
-    size_t i = data;
-    bool is_exception = false;
-    try {
-      ch->Send(&i);
-    } catch (paddle::platform::EnforceNotMet e) {
-      is_exception = true;
-    }
-    EXPECT_EQ(is_exception, true);
-  }};
-  recv_thread = std::thread{[&]() {
-    size_t i;
-    // should return false because channel is closed and queue is empty
-    EXPECT_EQ(ch->Receive(&i), false);
-  }};
-
-  send_thread.join();
-  recv_thread.join();
-}
-
-TEST(Channel, SendReceiveClosedBufferedChannelPanics) {
-  size_t buffer_size = 10;
-  auto ch = MakeChannel<size_t>(buffer_size);
-  SendReceiveWithACloseChannelShouldPanic(ch);
-  delete ch;
-}
-
-TEST(Channel, SendReceiveClosedUnBufferedChannelPanics) {
-  auto ch = MakeChannel<size_t>(0);
-  SendReceiveWithACloseChannelShouldPanic(ch);
-  delete ch;
-}
-
-TEST(Channel, ReceiveFromBufferedChannelReturnResidualValuesTest) {
-  const size_t buffer_size = 10;
-  auto ch = MakeChannel<size_t>(buffer_size);
-
-  for (size_t i = 0; i < buffer_size; ++i) {
-    ch->Send(&i);  // sending should not block
-  }
-
-  size_t out;
-  for (size_t i = 0; i < buffer_size / 2; ++i) {
-    EXPECT_EQ(ch->Receive(&out), true);  // receiving should not block
-    EXPECT_EQ(out, i);
-  }
-
-  CloseChannel(ch);
-
-  for (size_t i = buffer_size / 2; i < buffer_size; ++i) {
-    EXPECT_EQ(ch->Receive(&out),
-              true);  // receving should return residual values.
-    EXPECT_EQ(out, i);
-  }
-
-  for (size_t i = 0; i < buffer_size; ++i) {
-    EXPECT_EQ(ch->Receive(&out),
-              false);  // receiving on closed channel should return false
-  }
-  delete ch;
-}
-
-TEST(Channel, ConcurrentSendNonConcurrentReceiveWithSufficientBufferSize) {
-  const size_t buffer_size = 10;
-  auto ch = MakeChannel<size_t>(buffer_size);
-  std::thread t([&]() {
-    // Try to write more than buffer size.
-    for (size_t i = 0; i < 2 * buffer_size; ++i) {
-      if (i < buffer_size) {
-        ch->Send(&i);  // should block after 10 iterations
-      } else {
-        bool is_exception = false;
-        try {
-          ch->Send(&i);
-        } catch (paddle::platform::EnforceNotMet e) {
-          is_exception = true;
-        }
-        EXPECT_EQ(is_exception, true);
-      }
-    }
-  });
-  std::this_thread::sleep_for(std::chrono::milliseconds(200));  // wait 0.2 sec
-  CloseChannel(ch);
-  t.join();
-  delete ch;
-}
-
-TEST(Channel, RecevingOrderEqualToSendingOrderWithUnBufferedChannel) {
-  auto ch = MakeChannel<int>(0);
-  RecevingOrderEqualToSendingOrder(ch, 20);
-}
-
-TEST(Channel, RecevingOrderEqualToSendingOrderWithBufferedChannel1) {
-  // Test that Receive Order is same as Send Order when number of items
-  // sent is less than size of buffer
-  auto ch = MakeChannel<int>(10);
-  RecevingOrderEqualToSendingOrder(ch, 5);
-}
-
-TEST(Channel, RecevingOrderEqualToSendingOrderWithBufferedChannel2) {
-  // Test that Receive Order is same as Send Order when number of items
-  // sent is equal to size of buffer
-  auto ch = MakeChannel<int>(10);
-  RecevingOrderEqualToSendingOrder(ch, 10);
-}
-
-TEST(Channel, RecevingOrderEqualToSendingOrderWithBufferedChannel3) {
-  // Test that Receive Order is same as Send Order when number of items
-  // sent is greater than the size of buffer
-  auto ch = MakeChannel<int>(10);
-  RecevingOrderEqualToSendingOrder(ch, 20);
-}
-
-void ChannelCloseUnblocksReceiversTest(Channel<int> *ch) {
-  const size_t kNumThreads = 5;
-  std::thread t[kNumThreads];
-  bool thread_ended[kNumThreads];
-
-  // Launches threads that try to read and are blocked because of no writers
-  for (size_t i = 0; i < kNumThreads; i++) {
-    thread_ended[i] = false;
-    t[i] = std::thread(
-        [&](bool *p) {
-          int data;
-          EXPECT_EQ(ch->Receive(&data), false);
-          *p = true;
-        },
-        &thread_ended[i]);
-  }
-  std::this_thread::sleep_for(std::chrono::milliseconds(200));  // wait 0.2 sec
-
-  // Verify that all the threads are blocked
-  for (size_t i = 0; i < kNumThreads; i++) {
-    EXPECT_EQ(thread_ended[i], false);
-  }
-
-  // Explicitly close the channel
-  // This should unblock all receivers
-  CloseChannel(ch);
-
-  std::this_thread::sleep_for(std::chrono::milliseconds(200));  // wait 0.2 sec
-
-  // Verify that all threads got unblocked
-  for (size_t i = 0; i < kNumThreads; i++) {
-    EXPECT_EQ(thread_ended[i], true);
-  }
-
-  for (size_t i = 0; i < kNumThreads; i++) t[i].join();
-}
-
-void ChannelCloseUnblocksSendersTest(Channel<int> *ch, bool isBuffered) {
-  const size_t kNumThreads = 5;
-  std::thread t[kNumThreads];
-  bool thread_ended[kNumThreads];
-  bool send_success[kNumThreads];
-
-  // Launches threads that try to write and are blocked because of no readers
-  for (size_t i = 0; i < kNumThreads; i++) {
-    thread_ended[i] = false;
-    send_success[i] = false;
-    t[i] = std::thread(
-        [&](bool *ended, bool *success) {
-          int data = 10;
-          bool is_exception = false;
-          try {
-            ch->Send(&data);
-          } catch (paddle::platform::EnforceNotMet e) {
-            is_exception = true;
-          }
-          *success = !is_exception;
-          *ended = true;
-        },
-        &thread_ended[i], &send_success[i]);
-  }
-  std::this_thread::sleep_for(std::chrono::milliseconds(200));  // wait
-
-  if (isBuffered) {
-    // If ch is Buffered, atleast 4 threads must be blocked.
-    int ct = 0;
-    for (size_t i = 0; i < kNumThreads; i++) {
-      if (!thread_ended[i]) ct++;
-    }
-    EXPECT_GE(ct, 4);
-  } else {
-    // If ch is UnBuffered, all the threads should be blocked.
-    for (size_t i = 0; i < kNumThreads; i++) {
-      EXPECT_EQ(thread_ended[i], false);
-    }
-  }
-  // Explicitly close the thread
-  // This should unblock all senders
-  CloseChannel(ch);
-
-  std::this_thread::sleep_for(std::chrono::milliseconds(200));  // wait
-
-  // Verify that all threads got unblocked
-  for (size_t i = 0; i < kNumThreads; i++) {
-    EXPECT_EQ(thread_ended[i], true);
-  }
-
-  if (isBuffered) {
-    // Verify that only 1 send was successful
-    int ct = 0;
-    for (size_t i = 0; i < kNumThreads; i++) {
-      if (send_success[i]) ct++;
-    }
-    // Only 1 send must be successful
-    EXPECT_EQ(ct, 1);
-  }
-
-  for (size_t i = 0; i < kNumThreads; i++) t[i].join();
-}
-
-// This tests that closing a buffered channel also unblocks
-//  any receivers waiting on the channel
-TEST(Channel, BufferedChannelCloseUnblocksReceiversTest) {
-  auto ch = MakeChannel<int>(1);
-  ChannelCloseUnblocksReceiversTest(ch);
-  delete ch;
-}
-
-// This tests that closing a buffered channel also unblocks
-//  any senders waiting for channel to have write space
-TEST(Channel, BufferedChannelCloseUnblocksSendersTest) {
-  auto ch = MakeChannel<int>(1);
-  ChannelCloseUnblocksSendersTest(ch, true);
-  delete ch;
-}
-
-// This tests that closing an unbuffered channel also unblocks
-//  unblocks any receivers waiting for senders
-TEST(Channel, UnbufferedChannelCloseUnblocksReceiversTest) {
-  auto ch = MakeChannel<int>(0);
-  ChannelCloseUnblocksReceiversTest(ch);
-  delete ch;
-}
-
-// This tests that closing an unbuffered channel also unblocks
-//  unblocks any senders waiting for senders
-TEST(Channel, UnbufferedChannelCloseUnblocksSendersTest) {
-  auto ch = MakeChannel<int>(0);
-  ChannelCloseUnblocksSendersTest(ch, false);
-  delete ch;
-}
-
-TEST(Channel, UnbufferedLessReceiveMoreSendTest) {
-  auto ch = MakeChannel<int>(0);
-  unsigned sum_send = 0;
-  // Send should block after three iterations
-  // since we only have three receivers.
-  std::thread t([&]() {
-    // Try to send more number of times
-    // than receivers
-    for (int i = 0; i < 4; i++) {
-      try {
-        ch->Send(&i);
-        sum_send += i;
-      } catch (paddle::platform::EnforceNotMet e) {
-      }
-    }
-  });
-  for (int i = 0; i < 3; i++) {
-    int recv;
-    ch->Receive(&recv);
-    EXPECT_EQ(recv, i);
-  }
-  std::this_thread::sleep_for(std::chrono::milliseconds(200));  // wait 0.2 sec
-  EXPECT_EQ(sum_send, 3U);
-
-  CloseChannel(ch);
-  t.join();
-  delete ch;
-}
-
-TEST(Channel, UnbufferedMoreReceiveLessSendTest) {
-  auto ch = MakeChannel<int>(0);
-  unsigned sum_send = 0;
-  unsigned sum_receive = 0;
-  // The receiver should block after 5
-  // iterations, since there are only 5 senders.
-  std::thread t([&]() {
-    for (int i = 0; i < 8; i++) {
-      int recv;
-      ch->Receive(&recv);  // should block after the fifth iteration.
-      EXPECT_EQ(recv, i);
-      sum_receive += i;
-    }
-  });
-  for (int i = 0; i < 5; i++) {
-    ch->Send(&i);
-    sum_send += i;
-  }
-  std::this_thread::sleep_for(std::chrono::milliseconds(200));  // wait 0.2 sec
-  EXPECT_EQ(sum_send, 10U);
-  EXPECT_EQ(sum_receive, 10U);
-  // send three more elements
-  for (int i = 5; i < 8; i++) {
-    ch->Send(&i);
-    sum_send += i;
-  }
-
-  CloseChannel(ch);
-  t.join();
-  EXPECT_EQ(sum_send, 28U);
-  EXPECT_EQ(sum_receive, 28U);
-  delete ch;
-}
-
-// This tests that destroying a channel unblocks
-//  any senders waiting for channel to have write space
-void ChannelDestroyUnblockSenders(Channel<int> *ch, bool isBuffered) {
-  const size_t kNumThreads = 5;
-  std::thread t[kNumThreads];
-  bool thread_ended[kNumThreads];
-  bool send_success[kNumThreads];
-
-  // Launches threads that try to write and are blocked because of no readers
-  for (size_t i = 0; i < kNumThreads; i++) {
-    thread_ended[i] = false;
-    send_success[i] = false;
-    t[i] = std::thread(
-        [&](bool *ended, bool *success) {
-          int data = 10;
-          bool is_exception = false;
-          try {
-            ch->Send(&data);
-          } catch (paddle::platform::EnforceNotMet e) {
-            is_exception = true;
-          }
-          *success = !is_exception;
-          *ended = true;
-        },
-        &thread_ended[i], &send_success[i]);
-  }
-
-  std::this_thread::sleep_for(std::chrono::milliseconds(200));  // wait 0.2 sec
-
-  if (isBuffered) {
-    // If channel is buffered, verify that atleast 4 threads are blocked
-    int ct = 0;
-    for (size_t i = 0; i < kNumThreads; i++) {
-      if (thread_ended[i] == false) ct++;
-    }
-    // Atleast 4 threads must be blocked
-    EXPECT_GE(ct, 4);
-  } else {
-    // Verify that all the threads are blocked
-    for (size_t i = 0; i < kNumThreads; i++) {
-      EXPECT_EQ(thread_ended[i], false);
-    }
-  }
-  // Explicitly destroy the channel
-  delete ch;
-  std::this_thread::sleep_for(std::chrono::milliseconds(200));  // wait
-
-  // Verify that all threads got unblocked
-  for (size_t i = 0; i < kNumThreads; i++) {
-    EXPECT_EQ(thread_ended[i], true);
-  }
-
-  // Count number of successful sends
-  int ct = 0;
-  for (size_t i = 0; i < kNumThreads; i++) {
-    if (send_success[i]) ct++;
-  }
-
-  if (isBuffered) {
-    // Only 1 send must be successful
-    EXPECT_EQ(ct, 1);
-  } else {
-    // In unbuffered channel, no send should be successful
-    EXPECT_EQ(ct, 0);
-  }
-
-  // Join all threads
-  for (size_t i = 0; i < kNumThreads; i++) t[i].join();
-}
-
-// This tests that destroying a channel also unblocks
-//  any receivers waiting on the channel
-void ChannelDestroyUnblockReceivers(Channel<int> *ch) {
-  const size_t kNumThreads = 5;
-  std::thread t[kNumThreads];
-  bool thread_ended[kNumThreads];
-
-  // Launches threads that try to read and are blocked because of no writers
-  for (size_t i = 0; i < kNumThreads; i++) {
-    thread_ended[i] = false;
-    t[i] = std::thread(
-        [&](bool *p) {
-          int data;
-          // All reads should return false
-          EXPECT_EQ(ch->Receive(&data), false);
-          *p = true;
-        },
-        &thread_ended[i]);
-  }
-  std::this_thread::sleep_for(std::chrono::milliseconds(100));  // wait
-
-  // Verify that all threads are blocked
-  for (size_t i = 0; i < kNumThreads; i++) {
-    EXPECT_EQ(thread_ended[i], false);
-  }
-  // delete the channel
-  delete ch;
-  std::this_thread::sleep_for(std::chrono::milliseconds(200));  // wait
-  // Verify that all threads got unblocked
-  for (size_t i = 0; i < kNumThreads; i++) {
-    EXPECT_EQ(thread_ended[i], true);
-  }
-
-  for (size_t i = 0; i < kNumThreads; i++) t[i].join();
-}
-
-TEST(Channel, BufferedChannelDestroyUnblocksReceiversTest) {
-  size_t buffer_size = 1;
-  auto ch = MakeChannel<int>(buffer_size);
-  ChannelDestroyUnblockReceivers(ch);
-}
-
-TEST(Channel, BufferedChannelDestroyUnblocksSendersTest) {
-  size_t buffer_size = 1;
-  auto ch = MakeChannel<int>(buffer_size);
-  ChannelDestroyUnblockSenders(ch, true);
-}
-
-// This tests that destroying an unbuffered channel also unblocks
-//  unblocks any receivers waiting for senders
-TEST(Channel, UnbufferedChannelDestroyUnblocksReceiversTest) {
-  auto ch = MakeChannel<int>(0);
-  ChannelDestroyUnblockReceivers(ch);
-}
-
-TEST(Channel, UnbufferedChannelDestroyUnblocksSendersTest) {
-  auto ch = MakeChannel<int>(0);
-  ChannelDestroyUnblockSenders(ch, false);
-}
-
-TEST(ChannelHolder, ChannelHolderCapacityTest) {
-  const size_t buffer_size = 10;
-  ChannelHolder *ch = new ChannelHolder();
-  ch->Reset<int>(buffer_size);
-  EXPECT_EQ(ch->Cap(), buffer_size);
-  delete ch;
-
-  ch = new ChannelHolder();
-  ch->Reset<int>(0);
-  EXPECT_EQ(ch->Cap(), 0U);
-  delete ch;
-}
-
-void ChannelHolderSendReceive(ChannelHolder *ch) {
-  unsigned sum_send = 0;
-  std::thread t([&]() {
-    for (int i = 0; i < 5; i++) {
-      ch->Send(&i);
-      sum_send += i;
-    }
-  });
-  for (int i = 0; i < 5; i++) {
-    int recv;
-    EXPECT_EQ(ch->Receive(&recv), true);
-    EXPECT_EQ(recv, i);
-  }
-
-  ch->close();
-  t.join();
-  EXPECT_EQ(sum_send, 10U);
-}
-
-TEST(ChannelHolder, ChannelHolderBufferedSendReceiveTest) {
-  ChannelHolder *ch = new ChannelHolder();
-  ch->Reset<int>(10);
-  ChannelHolderSendReceive(ch);
-  delete ch;
-}
-
-TEST(ChannelHolder, ChannelHolderUnBufferedSendReceiveTest) {
-  ChannelHolder *ch = new ChannelHolder();
-  ch->Reset<int>(0);
-  ChannelHolderSendReceive(ch);
-  delete ch;
-}
-
-TEST(ChannelHolder, ChannelUninitializedTest) {
-  ChannelHolder *ch = new ChannelHolder();
-  EXPECT_EQ(ch->IsInitialized(), false);
-  int i = 10;
-  bool send_exception = false;
-  try {
-    ch->Send(&i);
-  } catch (paddle::platform::EnforceNotMet e) {
-    send_exception = true;
-  }
-  EXPECT_EQ(send_exception, true);
-
-  bool recv_exception = false;
-  try {
-    ch->Receive(&i);
-  } catch (paddle::platform::EnforceNotMet e) {
-    recv_exception = true;
-  }
-  EXPECT_EQ(recv_exception, true);
-
-  bool is_exception = false;
-  try {
-    ch->Type();
-  } catch (paddle::platform::EnforceNotMet e) {
-    is_exception = true;
-  }
-  EXPECT_EQ(is_exception, true);
-  delete ch;
-}
-
-TEST(ChannelHolder, ChannelInitializedTest) {
-  ChannelHolder *ch = new ChannelHolder();
-  ch->Reset<int>(2);
-  EXPECT_EQ(ch->IsInitialized(), true);
-  // Channel should remain intialized even after close
-  ch->close();
-  EXPECT_EQ(ch->IsInitialized(), true);
-  delete ch;
-}
-
-TEST(ChannelHolder, TypeMismatchSendTest) {
-  // Test with unbuffered channel
-  ChannelHolder *ch = new ChannelHolder();
-  ch->Reset<int>(0);
-  bool is_exception = false;
-  bool boolean_data = true;
-  try {
-    ch->Send(&boolean_data);
-  } catch (paddle::platform::EnforceNotMet e) {
-    is_exception = true;
-  }
-  EXPECT_EQ(is_exception, true);
-  delete ch;
-
-  // Test with Buffered Channel
-  ch = new ChannelHolder();
-  ch->Reset<float>(10);
-  is_exception = false;
-  int int_data = 23;
-  try {
-    ch->Send(&int_data);
-  } catch (paddle::platform::EnforceNotMet e) {
-    is_exception = true;
-  }
-  EXPECT_EQ(is_exception, true);
-  delete ch;
-}
-
-TEST(ChannelHolder, TypeMismatchReceiveTest) {
-  // Test with unbuffered channel
-  ChannelHolder *ch = new ChannelHolder();
-  ch->Reset<int>(0);
-  bool is_exception = false;
-  bool float_data;
-  try {
-    ch->Receive(&float_data);
-  } catch (paddle::platform::EnforceNotMet e) {
-    is_exception = true;
-  }
-  EXPECT_EQ(is_exception, true);
-  delete ch;
-
-  // Test with Buffered Channel
-  ch = new ChannelHolder();
-  ch->Reset<float>(10);
-  is_exception = false;
-  int int_data = 23;
-  try {
-    ch->Receive(&int_data);
-  } catch (paddle::platform::EnforceNotMet e) {
-    is_exception = true;
-  }
-  EXPECT_EQ(is_exception, true);
-  delete ch;
-}
-
-void ChannelHolderCloseUnblocksReceiversTest(ChannelHolder *ch) {
-  const size_t kNumThreads = 5;
-  std::thread t[kNumThreads];
-  bool thread_ended[kNumThreads];
-
-  // Launches threads that try to read and are blocked because of no writers
-  for (size_t i = 0; i < kNumThreads; i++) {
-    thread_ended[i] = false;
-    t[i] = std::thread(
-        [&](bool *p) {
-          int data;
-          EXPECT_EQ(ch->Receive(&data), false);
-          *p = true;
-        },
-        &thread_ended[i]);
-  }
-  std::this_thread::sleep_for(std::chrono::milliseconds(200));  // wait 0.2 sec
-
-  // Verify that all the threads are blocked
-  for (size_t i = 0; i < kNumThreads; i++) {
-    EXPECT_EQ(thread_ended[i], false);
-  }
-
-  // Explicitly close the channel
-  // This should unblock all receivers
-  ch->close();
-
-  std::this_thread::sleep_for(std::chrono::milliseconds(200));  // wait 0.2 sec
-
-  // Verify that all threads got unblocked
-  for (size_t i = 0; i < kNumThreads; i++) {
-    EXPECT_EQ(thread_ended[i], true);
-  }
-
-  for (size_t i = 0; i < kNumThreads; i++) t[i].join();
-}
-
-void ChannelHolderCloseUnblocksSendersTest(ChannelHolder *ch, bool isBuffered) {
-  const size_t kNumThreads = 5;
-  std::thread t[kNumThreads];
-  bool thread_ended[kNumThreads];
-  bool send_success[kNumThreads];
-
-  // Launches threads that try to write and are blocked because of no readers
-  for (size_t i = 0; i < kNumThreads; i++) {
-    thread_ended[i] = false;
-    send_success[i] = false;
-    t[i] = std::thread(
-        [&](bool *ended, bool *success) {
-          int data = 10;
-          bool is_exception = false;
-          try {
-            ch->Send(&data);
-          } catch (paddle::platform::EnforceNotMet e) {
-            is_exception = true;
-          }
-          *success = !is_exception;
-          *ended = true;
-        },
-        &thread_ended[i], &send_success[i]);
-  }
-  std::this_thread::sleep_for(std::chrono::milliseconds(200));  // wait
-
-  if (isBuffered) {
-    // If ch is Buffered, atleast 4 threads must be blocked.
-    int ct = 0;
-    for (size_t i = 0; i < kNumThreads; i++) {
-      if (!thread_ended[i]) ct++;
-    }
-    EXPECT_GE(ct, 4);
-  } else {
-    // If ch is UnBuffered, all the threads should be blocked.
-    for (size_t i = 0; i < kNumThreads; i++) {
-      EXPECT_EQ(thread_ended[i], false);
-    }
-  }
-  // Explicitly close the thread
-  // This should unblock all senders
-  ch->close();
-
-  std::this_thread::sleep_for(std::chrono::milliseconds(200));  // wait
-
-  // Verify that all threads got unblocked
-  for (size_t i = 0; i < kNumThreads; i++) {
-    EXPECT_EQ(thread_ended[i], true);
-  }
-
-  if (isBuffered) {
-    // Verify that only 1 send was successful
-    int ct = 0;
-    for (size_t i = 0; i < kNumThreads; i++) {
-      if (send_success[i]) ct++;
-    }
-    // Only 1 send must be successful
-    EXPECT_EQ(ct, 1);
-  }
-
-  for (size_t i = 0; i < kNumThreads; i++) t[i].join();
-}
-
-// This tests that closing a channelholder unblocks
-//  any receivers waiting on the channel
-TEST(ChannelHolder, ChannelHolderCloseUnblocksReceiversTest) {
-  // Check for buffered channel
-  ChannelHolder *ch = new ChannelHolder();
-  ch->Reset<int>(1);
-  ChannelHolderCloseUnblocksReceiversTest(ch);
-  delete ch;
-
-  // Check for unbuffered channel
-  ch = new ChannelHolder();
-  ch->Reset<int>(0);
-  ChannelHolderCloseUnblocksReceiversTest(ch);
-  delete ch;
-}
-
-// This tests that closing a channelholder unblocks
-//  any senders waiting for channel to have write space
-TEST(Channel, ChannelHolderCloseUnblocksSendersTest) {
-  // Check for buffered channel
-  ChannelHolder *ch = new ChannelHolder();
-  ch->Reset<int>(1);
-  ChannelHolderCloseUnblocksSendersTest(ch, true);
-  delete ch;
-
-  // Check for unbuffered channel
-  ch = new ChannelHolder();
-  ch->Reset<int>(0);
-  ChannelHolderCloseUnblocksSendersTest(ch, false);
-  delete ch;
-}
-
-// This tests that destroying a channelholder unblocks
-//  any senders waiting for channel
-void ChannelHolderDestroyUnblockSenders(ChannelHolder *ch, bool isBuffered) {
-  const size_t kNumThreads = 5;
-  std::thread t[kNumThreads];
-  bool thread_ended[kNumThreads];
-  bool send_success[kNumThreads];
-
-  // Launches threads that try to write and are blocked because of no readers
-  for (size_t i = 0; i < kNumThreads; i++) {
-    thread_ended[i] = false;
-    send_success[i] = false;
-    t[i] = std::thread(
-        [&](bool *ended, bool *success) {
-          int data = 10;
-          bool is_exception = false;
-          try {
-            ch->Send(&data);
-          } catch (paddle::platform::EnforceNotMet e) {
-            is_exception = true;
-          }
-          *success = !is_exception;
-          *ended = true;
-        },
-        &thread_ended[i], &send_success[i]);
-  }
-
-  std::this_thread::sleep_for(std::chrono::milliseconds(200));  // wait 0.2 sec
-  if (isBuffered) {
-    // If channel is buffered, verify that atleast 4 threads are blocked
-    int ct = 0;
-    for (size_t i = 0; i < kNumThreads; i++) {
-      if (thread_ended[i] == false) ct++;
-    }
-    // Atleast 4 threads must be blocked
-    EXPECT_GE(ct, 4);
-  } else {
-    // Verify that all the threads are blocked
-    for (size_t i = 0; i < kNumThreads; i++) {
-      EXPECT_EQ(thread_ended[i], false);
-    }
-  }
-  // Explicitly destroy the channel
-  delete ch;
-  std::this_thread::sleep_for(std::chrono::milliseconds(200));  // wait
-
-  // Verify that all threads got unblocked
-  for (size_t i = 0; i < kNumThreads; i++) {
-    EXPECT_EQ(thread_ended[i], true);
-  }
-
-  // Count number of successfuld sends
-  int ct = 0;
-  for (size_t i = 0; i < kNumThreads; i++) {
-    if (send_success[i]) ct++;
-  }
-
-  if (isBuffered) {
-    // Only 1 send must be successful
-    EXPECT_EQ(ct, 1);
-  } else {
-    // In unbuffered channel, no send should be successful
-    EXPECT_EQ(ct, 0);
-  }
-
-  // Join all threads
-  for (size_t i = 0; i < kNumThreads; i++) t[i].join();
-}
-
-// This tests that destroying a channelholder also unblocks
-//  any receivers waiting on the channel
-void ChannelHolderDestroyUnblockReceivers(ChannelHolder *ch) {
-  const size_t kNumThreads = 5;
-  std::thread t[kNumThreads];
-  bool thread_ended[kNumThreads];
-
-  // Launches threads that try to read and are blocked because of no writers
-  for (size_t i = 0; i < kNumThreads; i++) {
-    thread_ended[i] = false;
-    t[i] = std::thread(
-        [&](bool *p) {
-          int data;
-          // All reads should return false
-          EXPECT_EQ(ch->Receive(&data), false);
-          *p = true;
-        },
-        &thread_ended[i]);
-  }
-  std::this_thread::sleep_for(std::chrono::milliseconds(200));  // wait
-
-  // Verify that all threads are blocked
-  for (size_t i = 0; i < kNumThreads; i++) {
-    EXPECT_EQ(thread_ended[i], false);
-  }
-  // delete the channel
-  delete ch;
-  std::this_thread::sleep_for(std::chrono::milliseconds(200));  // wait
-  // Verify that all threads got unblocked
-  for (size_t i = 0; i < kNumThreads; i++) {
-    EXPECT_EQ(thread_ended[i], true);
-  }
-
-  for (size_t i = 0; i < kNumThreads; i++) t[i].join();
-}
-
-TEST(ChannelHolder, ChannelHolderDestroyUnblocksReceiversTest) {
-  // Check for Buffered Channel
-  ChannelHolder *ch = new ChannelHolder();
-  ch->Reset<int>(1);
-  ChannelHolderDestroyUnblockReceivers(ch);
-  // ch is already deleted already deleted in
-  // ChannelHolderDestroyUnblockReceivers
-
-  // Check for Unbuffered channel
-  ch = new ChannelHolder();
-  ch->Reset<int>(0);
-  ChannelHolderDestroyUnblockReceivers(ch);
-}
-
-TEST(ChannelHolder, ChannelHolderDestroyUnblocksSendersTest) {
-  // Check for Buffered Channel
-  ChannelHolder *ch = new ChannelHolder();
-  ch->Reset<int>(1);
-  ChannelHolderDestroyUnblockSenders(ch, true);
-  // ch is already deleted already deleted in
-  // ChannelHolderDestroyUnblockReceivers
-
-  // Check for Unbuffered channel
-  ch = new ChannelHolder();
-  ch->Reset<int>(0);
-  ChannelHolderDestroyUnblockSenders(ch, false);
-}
-
-// This tests that closing a channelholder many times.
-void ChannelHolderManyTimesClose(ChannelHolder *ch) {
-  const int kNumThreads = 15;
-  std::thread t[kNumThreads];
-  bool thread_ended[kNumThreads];
-
-  // Launches threads that try to send data to channel.
-  for (size_t i = 0; i < kNumThreads / 3; i++) {
-    thread_ended[i] = false;
-    t[i] = std::thread(
-        [&](bool *ended) {
-          int data = 10;
-          ch->Send(&data);
-          *ended = true;
-        },
-        &thread_ended[i]);
-  }
-
-  // Launches threads that try to receive data to channel.
-  for (size_t i = kNumThreads / 3; i < 2 * kNumThreads / 3; i++) {
-    thread_ended[i] = false;
-    t[i] = std::thread(
-        [&](bool *p) {
-          int data;
-          if (ch->Receive(&data)) {
-            EXPECT_EQ(data, 10);
-          }
-          *p = true;
-        },
-        &thread_ended[i]);
-  }
-
-  // Launches threads that try to close the channel.
-  for (size_t i = 2 * kNumThreads / 3; i < kNumThreads; i++) {
-    thread_ended[i] = false;
-    t[i] = std::thread(
-        [&](bool *p) {
-          if (!ch->IsClosed()) {
-            ch->close();
-          }
-          *p = true;
-        },
-        &thread_ended[i]);
-  }
-
-  std::this_thread::sleep_for(std::chrono::milliseconds(100));  // wait
-
-  // Verify that all threads are unblocked
-  for (size_t i = 0; i < kNumThreads; i++) {
-    EXPECT_EQ(thread_ended[i], true);
-  }
-  EXPECT_TRUE(ch->IsClosed());
-  // delete the channel
-  delete ch;
-  for (size_t i = 0; i < kNumThreads; i++) t[i].join();
-}
-
-TEST(ChannelHolder, ChannelHolderManyTimesCloseTest) {
-  // Check for Buffered Channel
-  ChannelHolder *ch = new ChannelHolder();
-  ch->Reset<int>(10);
-  ChannelHolderManyTimesClose(ch);
-}
diff --git a/paddle/fluid/framework/concurrency_test.cc b/paddle/fluid/framework/concurrency_test.cc
deleted file mode 100644
index bbf67f5ba9..0000000000
--- a/paddle/fluid/framework/concurrency_test.cc
+++ /dev/null
@@ -1,292 +0,0 @@
-/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include <thread>  // NOLINT
-
-#include "gtest/gtest.h"
-#include "paddle/fluid/framework/block_desc.h"
-#include "paddle/fluid/framework/channel.h"
-#include "paddle/fluid/framework/executor.h"
-#include "paddle/fluid/framework/op_registry.h"
-
-USE_NO_KERNEL_OP(go);
-USE_NO_KERNEL_OP(channel_close);
-USE_NO_KERNEL_OP(channel_create);
-USE_NO_KERNEL_OP(channel_recv);
-USE_NO_KERNEL_OP(channel_send);
-USE_NO_KERNEL_OP(elementwise_add);
-USE_NO_KERNEL_OP(select);
-USE_NO_KERNEL_OP(conditional_block);
-USE_NO_KERNEL_OP(equal);
-USE_NO_KERNEL_OP(assign);
-USE_NO_KERNEL_OP(while);
-USE_NO_KERNEL_OP(print);
-
-namespace f = paddle::framework;
-namespace p = paddle::platform;
-
-namespace paddle {
-namespace framework {
-
-template <typename T>
-LoDTensor *CreateVariable(Scope *scope, const p::CPUPlace &place,
-                          std::string name, T value) {
-  // Create LoDTensor<int> of dim [1]
-  auto var = scope->Var(name);
-  auto tensor = var->GetMutable<LoDTensor>();
-  tensor->Resize({1});
-  T *expect = tensor->mutable_data<T>(place);
-  expect[0] = value;
-  return tensor;
-}
-
-void AddOp(const std::string &type, const VariableNameMap &inputs,
-           const VariableNameMap &outputs, AttributeMap attrs,
-           BlockDesc *block) {
-  // insert op
-  auto op = block->AppendOp();
-  op->SetType(type);
-  for (auto &kv : inputs) {
-    op->SetInput(kv.first, kv.second);
-  }
-  for (auto &kv : outputs) {
-    op->SetOutput(kv.first, kv.second);
-  }
-  op->SetAttrMap(attrs);
-}
-
-void AddCase(ProgramDesc *program, Scope *scope, p::CPUPlace *place,
-             BlockDesc *casesBlock, int caseId, int caseType,
-             std::string caseChannel, std::string caseVarName,
-             std::function<void(BlockDesc *, Scope *)> func) {
-  std::string caseCondName = std::string("caseCond") + std::to_string(caseId);
-  std::string caseCondXVarName =
-      std::string("caseCondX") + std::to_string(caseId);
-
-  BlockDesc *caseBlock = program->AppendBlock(*casesBlock);
-  func(caseBlock, scope);
-
-  CreateVariable(scope, *place, caseCondName, false);
-  CreateVariable(scope, *place, caseCondXVarName, caseId);
-  CreateVariable(scope, *place, caseVarName, caseId);
-
-  scope->Var("step_scope");
-
-  AddOp("equal", {{"X", {caseCondXVarName}}, {"Y", {"caseToExecute"}}},
-        {{"Out", {caseCondName}}}, {}, casesBlock);
-
-  AddOp("conditional_block", {{"X", {caseCondName}}, {"Params", {}}},
-        {{"Out", {}}, {"Scope", {"step_scope"}}},
-        {{"sub_block", caseBlock}, {"is_scalar_condition", true}}, casesBlock);
-}
-
-void AddFibonacciSelect(Scope *scope, p::CPUPlace *place, ProgramDesc *program,
-                        BlockDesc *parentBlock, std::string dataChanName,
-                        std::string quitChanName) {
-  BlockDesc *whileBlock = program->AppendBlock(*parentBlock);
-
-  CreateVariable(scope, *place, "whileExitCond", true);
-  CreateVariable(scope, *place, "caseToExecute", -1);
-  CreateVariable(scope, *place, "case1var", 0);
-
-  CreateVariable(scope, *place, "xtemp", 0);
-
-  // TODO(thuan): Need to create fibXToSend, since channel send moves the actual
-  // data,
-  // which causes the data to be no longer accessible to do the fib calculation
-  // TODO(abhinav): Change channel send to do a copy instead of a move!
-  CreateVariable(scope, *place, "fibXToSend", 0);
-
-  CreateVariable(scope, *place, "fibX", 0);
-  CreateVariable(scope, *place, "fibY", 1);
-  CreateVariable(scope, *place, "quitVar", 0);
-
-  BlockDesc *casesBlock = program->AppendBlock(*whileBlock);
-  std::function<void(BlockDesc * caseBlock)> f = [](BlockDesc *caseBlock) {};
-
-  // TODO(thuan): Remove this once we change channel send to do a copy instead
-  // of move
-  AddOp("assign", {{"X", {"fibX"}}}, {{"Out", {"fibXToSend"}}}, {}, whileBlock);
-
-  // Case 0: Send to dataChanName
-  std::function<void(BlockDesc * caseBlock, Scope * scope)> case0Func = [&](
-      BlockDesc *caseBlock, Scope *scope) {
-    AddOp("assign", {{"X", {"fibX"}}}, {{"Out", {"xtemp"}}}, {}, caseBlock);
-    AddOp("assign", {{"X", {"fibY"}}}, {{"Out", {"fibX"}}}, {}, caseBlock);
-    AddOp("elementwise_add", {{"X", {"xtemp"}}, {"Y", {"fibY"}}},
-          {{"Out", {"fibY"}}}, {}, caseBlock);
-  };
-  AddCase(program, scope, place, casesBlock, 0, 1, dataChanName, "fibXToSend",
-          case0Func);
-  std::string case0Config =
-      std::string("0,1,") + dataChanName + std::string(",fibXToSend");
-
-  // Case 1: Receive from quitChanName
-  std::function<void(BlockDesc * caseBlock, Scope * scope)> case2Func = [&](
-      BlockDesc *caseBlock, Scope *scope) {
-    // Exit the while loop after we receive from quit channel.
-    // We assign a false to "whileExitCond" variable, which will
-    // break out of while_op loop
-    CreateVariable(scope, *place, "whileFalse", false);
-    AddOp("assign", {{"X", {"whileFalse"}}}, {{"Out", {"whileExitCond"}}}, {},
-          caseBlock);
-  };
-  AddCase(program, scope, place, casesBlock, 1, 2, quitChanName, "quitVar",
-          case2Func);
-  std::string case1Config =
-      std::string("1,2,") + quitChanName + std::string(",quitVar");
-
-  // Select block
-  AddOp("select", {{"X", {dataChanName, quitChanName}},
-                   {"case_to_execute", {"caseToExecute"}}},
-        {{"Out", {}}},
-        {{"sub_block", casesBlock},
-         {"cases", std::vector<std::string>{case0Config, case1Config}}},
-        whileBlock);
-
-  scope->Var("stepScopes");
-  AddOp("while",
-        {{"X", {dataChanName, quitChanName}}, {"Condition", {"whileExitCond"}}},
-        {{"Out", {}}, {"StepScopes", {"stepScopes"}}},
-        {{"sub_block", whileBlock}}, parentBlock);
-}
-
-TEST(Concurrency, Go_Op) {
-  Scope scope;
-  p::CPUPlace place;
-
-  // Initialize scope variables
-  p::CPUDeviceContext ctx(place);
-
-  // Create channel variable
-  scope.Var("Channel");
-
-  // Create Variables, x0 will be put into channel,
-  // result will be pulled from channel
-  CreateVariable(&scope, place, "Status", false);
-  CreateVariable(&scope, place, "x0", 99);
-  CreateVariable(&scope, place, "result", 0);
-
-  framework::Executor executor(place);
-  ProgramDesc program;
-  BlockDesc *block = program.MutableBlock(0);
-
-  // Create channel OP
-  AddOp("channel_create", {}, {{"Out", {"Channel"}}},
-        {{"capacity", 10}, {"data_type", f::proto::VarType::LOD_TENSOR}},
-        block);
-
-  // Create Go Op routine
-  BlockDesc *goOpBlock = program.AppendBlock(program.Block(0));
-  AddOp("channel_send", {{"Channel", {"Channel"}}, {"X", {"x0"}}},
-        {{"Status", {"Status"}}}, {}, goOpBlock);
-
-  // Create Go Op
-  AddOp("go", {{"X", {"Channel", "x0"}}}, {}, {{"sub_block", goOpBlock}},
-        block);
-
-  // Create Channel Receive Op
-  AddOp("channel_recv", {{"Channel", {"Channel"}}},
-        {{"Status", {"Status"}}, {"Out", {"result"}}}, {}, block);
-
-  // Create Channel Close Op
-  AddOp("channel_close", {{"Channel", {"Channel"}}}, {}, {}, block);
-
-  // Check the result tensor to make sure it is set to 0
-  const LoDTensor &tensor = (scope.FindVar("result"))->Get<LoDTensor>();
-  auto *initialData = tensor.data<int>();
-  EXPECT_EQ(initialData[0], 0);
-
-  executor.Run(program, &scope, 0, true, true);
-
-  // After we call executor.run, the Go operator should do a channel_send to
-  // set the "result" variable to 99.
-  auto *finalData = tensor.data<int>();
-  EXPECT_EQ(finalData[0], 99);
-}
-
-/**
- * This test implements the fibonacci function using go_op and select_op
- */
-TEST(Concurrency, Select) {
-  Scope scope;
-  p::CPUPlace place;
-
-  // Initialize scope variables
-  p::CPUDeviceContext ctx(place);
-
-  CreateVariable(&scope, place, "Status", false);
-  CreateVariable(&scope, place, "result", 0);
-  CreateVariable(&scope, place, "currentXFib", 0);
-
-  framework::Executor executor(place);
-  ProgramDesc program;
-  BlockDesc *block = program.MutableBlock(0);
-
-  // Create channel OP
-  std::string dataChanName = "Channel";
-  scope.Var(dataChanName);
-  AddOp("channel_create", {}, {{"Out", {dataChanName}}},
-        {{"capacity", 0}, {"data_type", f::proto::VarType::LOD_TENSOR}}, block);
-
-  std::string quitChanName = "Quit";
-  scope.Var(quitChanName);
-  AddOp("channel_create", {}, {{"Out", {quitChanName}}},
-        {{"capacity", 0}, {"data_type", f::proto::VarType::LOD_TENSOR}}, block);
-
-  // Create Go Op routine, which loops 10 times over fibonacci sequence
-  CreateVariable(&scope, place, "xReceiveVar", 0);
-
-  BlockDesc *goOpBlock = program.AppendBlock(program.Block(0));
-  for (int i = 0; i < 10; ++i) {
-    AddOp("channel_recv", {{"Channel", {dataChanName}}},
-          {{"Status", {"Status"}}, {"Out", {"currentXFib"}}}, {}, goOpBlock);
-    AddOp("print", {{"In", {"currentXFib"}}}, {{"Out", {"currentXFib"}}},
-          {{"first_n", 100},
-           {"summarize", -1},
-           {"print_tensor_name", false},
-           {"print_tensor_type", true},
-           {"print_tensor_shape", false},
-           {"print_tensor_lod", false},
-           {"print_phase", std::string("FORWARD")},
-           {"message", std::string("X: ")}},
-          goOpBlock);
-  }
-
-  CreateVariable(&scope, place, "quitSignal", 0);
-  AddOp("channel_send", {{"Channel", {quitChanName}}, {"X", {"quitSignal"}}},
-        {{"Status", {"Status"}}}, {}, goOpBlock);
-
-  // Create Go Op
-  AddOp("go", {{"X", {dataChanName, quitChanName}}}, {},
-        {{"sub_block", goOpBlock}}, block);
-
-  AddFibonacciSelect(&scope, &place, &program, block, dataChanName,
-                     quitChanName);
-
-  // Create Channel Close Op
-  AddOp("channel_close", {{"Channel", {dataChanName}}}, {}, {}, block);
-  AddOp("channel_close", {{"Channel", {quitChanName}}}, {}, {}, block);
-
-  executor.Run(program, &scope, 0, true, true);
-
-  // After we call executor.run, "result" variable should be equal to 34
-  // (which is 10 loops through fibonacci sequence)
-  const LoDTensor &tensor = (scope.FindVar("currentXFib"))->Get<LoDTensor>();
-  auto *finalData = tensor.data<int>();
-  EXPECT_EQ(finalData[0], 34);
-}
-
-}  // namespace framework
-}  // namespace paddle
diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc
index 8d8042a056..70ec6e90a4 100644
--- a/paddle/fluid/framework/executor.cc
+++ b/paddle/fluid/framework/executor.cc
@@ -14,7 +14,6 @@ limitations under the License. */
 
 #include "paddle/fluid/framework/executor.h"
 
-#include "paddle/fluid/framework/channel.h"
 #include "paddle/fluid/framework/feed_fetch_method.h"
 #include "paddle/fluid/framework/lod_rank_table.h"
 #include "paddle/fluid/framework/lod_tensor_array.h"
@@ -76,15 +75,13 @@ void InitializeVariable(Variable* var, proto::VarType::Type var_type) {
     var->GetMutable<platform::PlaceList>();
   } else if (var_type == proto::VarType::READER) {
     var->GetMutable<ReaderHolder>();
-  } else if (var_type == proto::VarType::CHANNEL) {
-    var->GetMutable<ChannelHolder>();
   } else if (var_type == proto::VarType::RAW) {
     // GetMutable will be called in operator
   } else {
     PADDLE_THROW(
         "Variable type %d is not in "
         "[LOD_TENSOR, SELECTED_ROWS, FEED_MINIBATCH, FETCH_LIST, "
-        "LOD_RANK_TABLE, PLACE_LIST, READER, CHANNEL, RAW]",
+        "LOD_RANK_TABLE, PLACE_LIST, READER, RAW]",
         var_type);
   }
 }
diff --git a/paddle/fluid/framework/framework.proto b/paddle/fluid/framework/framework.proto
index 460401df54..25f0ba4184 100644
--- a/paddle/fluid/framework/framework.proto
+++ b/paddle/fluid/framework/framework.proto
@@ -126,7 +126,6 @@ message VarType {
     LOD_TENSOR_ARRAY = 13;
     PLACE_LIST = 14;
     READER = 15;
-    CHANNEL = 16;
     // Any runtime decided variable type is raw
     // raw variables should manage their own allocations
     // in operators like nccl_op
@@ -158,12 +157,6 @@ message VarType {
   message ReaderDesc { repeated LoDTensorDesc lod_tensor = 1; }
   optional ReaderDesc reader = 5;
 
-  message ChannelDesc {
-    required Type data_type = 1;
-    required int64 capacity = 2;
-  }
-  optional ChannelDesc channel = 6;
-
   message Tuple { repeated Type element_type = 1; }
   optional Tuple tuple = 7;
 }
diff --git a/paddle/fluid/framework/tuple.h b/paddle/fluid/framework/tuple.h
index f6c6a1fec1..508ee931c6 100644
--- a/paddle/fluid/framework/tuple.h
+++ b/paddle/fluid/framework/tuple.h
@@ -17,7 +17,6 @@ limitations under the License. */
 #include <stdexcept>
 #include <string>
 #include <vector>
-#include "paddle/fluid/framework/channel.h"
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/framework/var_desc.h"
diff --git a/paddle/fluid/framework/var_desc.cc b/paddle/fluid/framework/var_desc.cc
index 1aa0ae0f7c..7e3f002b53 100644
--- a/paddle/fluid/framework/var_desc.cc
+++ b/paddle/fluid/framework/var_desc.cc
@@ -88,13 +88,7 @@ std::vector<std::vector<int64_t>> VarDesc::GetShapes() const {
 }
 
 void VarDesc::SetDataType(proto::VarType::Type data_type) {
-  switch (desc_.type().type()) {
-    case proto::VarType::CHANNEL:
-      mutable_channel_desc()->set_data_type(data_type);
-      break;
-    default:
-      mutable_tensor_desc()->set_data_type(data_type);
-  }
+  mutable_tensor_desc()->set_data_type(data_type);
 }
 
 void VarDesc::SetDataTypes(
@@ -115,13 +109,7 @@ void VarDesc::SetDataTypes(
 }
 
 proto::VarType::Type VarDesc::GetDataType() const {
-  switch (desc_.type().type()) {
-    case proto::VarType::CHANNEL:
-      return channel_desc().data_type();
-      break;
-    default:
-      return tensor_desc().data_type();
-  }
+  return tensor_desc().data_type();
 }
 
 std::vector<proto::VarType::Type> VarDesc::GetDataTypes() const {
@@ -134,17 +122,6 @@ std::vector<proto::VarType::Type> VarDesc::GetDataTypes() const {
   return res;
 }
 
-void VarDesc::SetCapacity(int64_t capacity) {
-  switch (desc_.type().type()) {
-    case proto::VarType::CHANNEL:
-      desc_.mutable_type()->mutable_channel()->set_capacity(capacity);
-      break;
-    default:
-      PADDLE_THROW("Setting 'capacity' is not supported by the type of var %s.",
-                   this->Name());
-  }
-}
-
 void VarDesc::SetLoDLevel(int32_t lod_level) {
   switch (desc_.type().type()) {
     case proto::VarType::LOD_TENSOR:
@@ -214,19 +191,6 @@ std::vector<int32_t> VarDesc::GetLoDLevels() const {
   }
 }
 
-const proto::VarType::ChannelDesc &VarDesc::channel_desc() const {
-  PADDLE_ENFORCE(desc_.has_type(), "The var's type hasn't been set.");
-  PADDLE_ENFORCE(desc_.type().has_type(), "The var type hasn't been set.");
-  switch (desc_.type().type()) {
-    case proto::VarType::CHANNEL:
-      return desc_.type().channel();
-    default:
-      PADDLE_THROW(
-          "Getting 'channel_desc' is not supported by the type of var %s.",
-          this->Name());
-  }
-}
-
 const proto::VarType::TensorDesc &VarDesc::tensor_desc() const {
   PADDLE_ENFORCE(desc_.has_type(), "The var's type hasn't been set.");
   PADDLE_ENFORCE(desc_.type().has_type(), "The var type hasn't been set.");
@@ -262,20 +226,6 @@ std::vector<proto::VarType::TensorDesc> VarDesc::tensor_descs() const {
   }
 }
 
-proto::VarType::ChannelDesc *VarDesc::mutable_channel_desc() {
-  PADDLE_ENFORCE(desc_.has_type(), "The var type hasn't been set.");
-  PADDLE_ENFORCE(desc_.type().has_type(), "The var type hasn't been set.");
-  switch (desc_.type().type()) {
-    case proto::VarType::CHANNEL:
-      return desc_.mutable_type()->mutable_channel();
-    default:
-      PADDLE_THROW(
-          "Getting 'mutable_channel_desc' is not supported by the type of var "
-          "%s.",
-          this->Name());
-  }
-}
-
 proto::VarType::TensorDesc *VarDesc::mutable_tensor_desc() {
   PADDLE_ENFORCE(desc_.has_type(), "The var type hasn't been set.");
   PADDLE_ENFORCE(desc_.type().has_type(), "The var type hasn't been set.");
diff --git a/paddle/fluid/framework/var_desc.h b/paddle/fluid/framework/var_desc.h
index 9f7a21ef42..e33849ef50 100644
--- a/paddle/fluid/framework/var_desc.h
+++ b/paddle/fluid/framework/var_desc.h
@@ -87,8 +87,6 @@ class VarDesc {
   void SetDataTypes(
       const std::vector<proto::VarType::Type> &multiple_data_type);
 
-  void SetCapacity(int64_t capacity);
-
   proto::VarType::Type GetDataType() const;
 
   std::vector<proto::VarType::Type> GetDataTypes() const;
@@ -110,10 +108,8 @@ class VarDesc {
   void SetPersistable(bool persistable) { desc_.set_persistable(persistable); }
 
  private:
-  const proto::VarType::ChannelDesc &channel_desc() const;
   const proto::VarType::TensorDesc &tensor_desc() const;
   std::vector<proto::VarType::TensorDesc> tensor_descs() const;
-  proto::VarType::ChannelDesc *mutable_channel_desc();
   proto::VarType::TensorDesc *mutable_tensor_desc();
   std::vector<proto::VarType::TensorDesc *> mutable_tensor_descs();
 
diff --git a/paddle/fluid/framework/var_type.h b/paddle/fluid/framework/var_type.h
index e9550dbfb9..3b6f1cdb8f 100644
--- a/paddle/fluid/framework/var_type.h
+++ b/paddle/fluid/framework/var_type.h
@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #pragma once
-#include "paddle/fluid/framework/channel.h"
 #include "paddle/fluid/framework/framework.pb.h"
 #include "paddle/fluid/framework/lod_rank_table.h"
 #include "paddle/fluid/framework/lod_tensor.h"
@@ -41,8 +40,6 @@ inline proto::VarType::Type ToVarType(std::type_index type) {
     return proto::VarType_Type_SELECTED_ROWS;
   } else if (IsType<ReaderHolder>(type)) {
     return proto::VarType_Type_READER;
-  } else if (IsType<ChannelHolder>(type)) {
-    return proto::VarType_Type_CHANNEL;
   } else {
     PADDLE_THROW("ToVarType:Unsupported type %s", type.name());
   }
@@ -66,9 +63,6 @@ inline void VisitVarType(const framework::Variable& var, Visitor visitor) {
     case proto::VarType_Type_READER:
       visitor(var.Get<ReaderHolder>());
       return;
-    case proto::VarType_Type_CHANNEL:
-      visitor(var.Get<ChannelHolder>());
-      return;
     default:
       PADDLE_THROW("Not supported visit type, %d", ToVarType(var.Type()));
   }
diff --git a/paddle/fluid/inference/analysis/analysis_pass.h b/paddle/fluid/inference/analysis/analysis_pass.h
index b6edb5529a..13805ea4ac 100644
--- a/paddle/fluid/inference/analysis/analysis_pass.h
+++ b/paddle/fluid/inference/analysis/analysis_pass.h
@@ -41,12 +41,6 @@ class AnalysisPass {
   // all passes have run.
   virtual bool Finalize() { return false; }
 
-  // Get a Pass appropriate to print the Node this pass operates on.
-  virtual AnalysisPass *CreatePrinterPass(std::ostream &os,
-                                          const std::string &banner) const {
-    return nullptr;
-  }
-
   // Create a debugger Pass that draw the DFG by graphviz toolkit.
   virtual AnalysisPass *CreateGraphvizDebugerPass() const { return nullptr; }
 
diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt
index 9c67df7bdf..fa41266d62 100644
--- a/paddle/fluid/operators/CMakeLists.txt
+++ b/paddle/fluid/operators/CMakeLists.txt
@@ -313,11 +313,6 @@ op_library(save_combine_op DEPS lod_tensor)
 op_library(load_combine_op DEPS lod_tensor)
 op_library(concat_op DEPS concat)
 
-# FIXME(thuan): Move CSP operators to paddle/fluid/framework/operators/concurrency
-add_subdirectory(concurrency)
-op_library(channel_send_op DEPS concurrency)
-op_library(channel_recv_op DEPS concurrency)
-
 list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS})
 
 foreach(src ${GENERAL_OPS})
diff --git a/paddle/fluid/operators/channel_close_op.cc b/paddle/fluid/operators/channel_close_op.cc
deleted file mode 100644
index 8e2db250a0..0000000000
--- a/paddle/fluid/operators/channel_close_op.cc
+++ /dev/null
@@ -1,70 +0,0 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include "paddle/fluid/framework/channel.h"
-#include "paddle/fluid/framework/op_registry.h"
-
-namespace pf = paddle::framework;
-static constexpr char kChannel[] = "Channel";
-
-namespace paddle {
-namespace operators {
-
-class ChannelCloseOp : public framework::OperatorBase {
- public:
-  ChannelCloseOp(const std::string &type,
-                 const framework::VariableNameMap &inputs,
-                 const framework::VariableNameMap &outputs,
-                 const framework::AttributeMap &attrs)
-      : framework::OperatorBase(type, inputs, outputs, attrs) {}
-
- private:
-  void RunImpl(const framework::Scope &scope,
-               const platform::Place &dev_place) const override {
-    auto &inp = *scope.FindVar(Input(kChannel));
-
-    // Get the mutable version of the channel variable and closes it.
-    pf::ChannelHolder *ch = inp.GetMutable<framework::ChannelHolder>();
-    ch->close();
-  }
-};
-
-class ChannelCloseOpOpInferShape : public framework::InferShapeBase {
- public:
-  void operator()(framework::InferShapeContext *context) const override {
-    PADDLE_ENFORCE(context->HasInput("Channel"),
-                   "The input of ChannelClose op must be set");
-  }
-};
-
-class ChannelCloseOpMaker : public framework::OpProtoAndCheckerMaker {
- public:
-  void Make() override {
-    AddInput(kChannel,
-             "The Channel Variable that should be closed by"
-             " the ChannelClose Op.");
-    AddComment(R"DOC(
-Channel Close Operator.
-
-This operator closes an open channel.
-)DOC");
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
-
-REGISTER_OPERATOR(channel_close, paddle::operators::ChannelCloseOp,
-                  paddle::framework::EmptyGradOpMaker,
-                  paddle::operators::ChannelCloseOpMaker);
diff --git a/paddle/fluid/operators/channel_create_op.cc b/paddle/fluid/operators/channel_create_op.cc
deleted file mode 100644
index a7f59e4088..0000000000
--- a/paddle/fluid/operators/channel_create_op.cc
+++ /dev/null
@@ -1,113 +0,0 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include "paddle/fluid/framework/channel.h"
-#include "paddle/fluid/framework/lod_rank_table.h"
-#include "paddle/fluid/framework/lod_tensor_array.h"
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/framework/reader.h"
-
-namespace pf = paddle::framework;
-
-static constexpr char kOutput[] = "Out";
-
-namespace paddle {
-namespace operators {
-
-class ChannelCreateOp : public framework::OperatorBase {
- public:
-  ChannelCreateOp(const std::string &type,
-                  const framework::VariableNameMap &inputs,
-                  const framework::VariableNameMap &outputs,
-                  const framework::AttributeMap &attrs)
-      : framework::OperatorBase(type, inputs, outputs, attrs) {}
-
- private:
-  void RunImpl(const framework::Scope &scope,
-               const platform::Place &dev_place) const override {
-    auto &out = *scope.FindVar(Output(kOutput));
-
-    // Determine the datatype and capacity of the channel to be created
-    // from the attributes provided.
-    auto dtype =
-        static_cast<framework::proto::VarType::Type>(Attr<int>("data_type"));
-    auto capacity = Attr<int>("capacity");
-
-    // Based on the datatype, create a new channel holder initialized with
-    // the given capacity. When capacity is 0, an unbuffered channel is
-    // created.
-    pf::ChannelHolder *ch = out.GetMutable<framework::ChannelHolder>();
-    if (dtype == framework::proto::VarType::LOD_TENSOR) {
-      ch->Reset<pf::LoDTensor>(capacity);
-    } else if (dtype == framework::proto::VarType::SELECTED_ROWS) {
-      ch->Reset<pf::SelectedRows>(capacity);
-    } else if (dtype == framework::proto::VarType::LOD_RANK_TABLE) {
-      ch->Reset<pf::LoDRankTable>(capacity);
-    } else if (dtype == framework::proto::VarType::LOD_TENSOR_ARRAY) {
-      ch->Reset<pf::LoDTensorArray>(capacity);
-    } else if (dtype == framework::proto::VarType::READER) {
-      ch->Reset<pf::ReaderHolder>(capacity);
-    } else if (dtype == framework::proto::VarType::CHANNEL) {
-      ch->Reset<pf::ChannelHolder>(capacity);
-    } else if (dtype == framework::proto::VarType::BOOL) {
-      ch->Reset<bool>(capacity);
-    } else if (dtype == framework::proto::VarType::INT32) {
-      ch->Reset<int>(capacity);
-    } else if (dtype == framework::proto::VarType::INT64) {
-      ch->Reset<int64_t>(capacity);
-    } else if (dtype == framework::proto::VarType::FP32) {
-      ch->Reset<float>(capacity);
-    } else if (dtype == framework::proto::VarType::FP64) {
-      ch->Reset<double>(capacity);
-    } else {
-      PADDLE_THROW(
-          "Data type %d is not in "
-          "[LOD_TENSOR, SELECTED_ROWS, LOD_RANK_TABLE, LOD_TENSOR_ARRAY, "
-          "READER, CHANNEL, BOOL, INT32, INT64, FP32, FP64]",
-          dtype);
-    }
-  }
-};
-
-class ChannelCreateOpOpInferShape : public framework::InferShapeBase {
- public:
-  void operator()(framework::InferShapeContext *context) const override {
-    PADDLE_ENFORCE(context->HasOutput(kOutput),
-                   "The output of ChannelCreate op must be set");
-    context->SetOutputDim(kOutput, {1});
-  }
-};
-
-class ChannelCreateOpMaker : public framework::OpProtoAndCheckerMaker {
- public:
-  void Make() override {
-    AddOutput(kOutput,
-              "The object of a Channel type created by ChannelCreate Op.");
-    AddAttr<int>("capacity", "The size of the buffer of Channel.")
-        .SetDefault(0);
-    AddAttr<int>("data_type", "The data type of elements inside the Channel.");
-    AddComment(R"DOC(
-Channel Create Operator.
-
-This operator creates an object of the VarType Channel and returns it.
-)DOC");
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
-
-REGISTER_OPERATOR(channel_create, paddle::operators::ChannelCreateOp,
-                  paddle::framework::EmptyGradOpMaker,
-                  paddle::operators::ChannelCreateOpMaker);
diff --git a/paddle/fluid/operators/channel_recv_op.cc b/paddle/fluid/operators/channel_recv_op.cc
deleted file mode 100644
index 101015e837..0000000000
--- a/paddle/fluid/operators/channel_recv_op.cc
+++ /dev/null
@@ -1,98 +0,0 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include "paddle/fluid/framework/channel.h"
-#include <paddle/fluid/framework/lod_rank_table.h>
-#include <paddle/fluid/framework/lod_tensor_array.h>
-#include <paddle/fluid/framework/reader.h>
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/framework/var_type.h"
-#include "paddle/fluid/operators/concurrency/channel_util.h"
-#include "paddle/fluid/operators/math/math_function.h"
-
-static constexpr char Channel[] = "Channel";
-static constexpr char Status[] = "Status";
-static constexpr char Out[] = "Out";
-
-namespace paddle {
-namespace operators {
-
-void SetReceiveStatus(const platform::Place &dev_place,
-                      framework::Variable *status_var, bool status) {
-  auto cpu = platform::CPUPlace();
-  auto status_tensor =
-      status_var->GetMutable<framework::LoDTensor>()->mutable_data<bool>({1},
-                                                                         cpu);
-  status_tensor[0] = status;
-}
-
-class ChannelRecvOp : public framework::OperatorBase {
- public:
-  ChannelRecvOp(const std::string &type,
-                const framework::VariableNameMap &inputs,
-                const framework::VariableNameMap &outputs,
-                const framework::AttributeMap &attrs)
-      : framework::OperatorBase(type, inputs, outputs, attrs) {}
-
-  void InferShape(framework::InferShapeContext *ctx) const {
-    PADDLE_ENFORCE(ctx->HasInput(Channel),
-                   "Input(Channel) of ChannelRecvOp should not be null.");
-    PADDLE_ENFORCE(ctx->HasOutput(Out),
-                   "Input(Channel) of ChannelRecvOp should not be null.");
-    PADDLE_ENFORCE(ctx->HasOutput(Status),
-                   "Output(Status) of ChannelRecvOp should not be null.");
-    ctx->SetOutputDim("Status", {1});
-  }
-
- private:
-  void RunImpl(const framework::Scope &scope,
-               const platform::Place &dev_place) const override {
-    // Get the channel holder created by channel_create op, passed as input.
-    framework::ChannelHolder *ch =
-        scope.FindVar(Input(Channel))->GetMutable<framework::ChannelHolder>();
-    auto output_var = scope.FindVar(Output(Out));
-    // Receive the data from the channel.
-    bool ok = concurrency::ChannelReceive(ch, output_var);
-
-    // Set the status output of the `ChannelReceive` call.
-    SetReceiveStatus(dev_place, scope.FindVar(Output(Status)), ok);
-  }
-};
-
-class ChannelRecvOpMaker : public framework::OpProtoAndCheckerMaker {
- public:
-  void Make() override {
-    AddInput(Channel,
-             "(Channel) A variable which \"receives\" the a value sent"
-             "to it by a channel_send op.")
-        .AsDuplicable();
-    AddOutput(Out,
-              "(Variable) Output Variable that will hold the data received"
-              " from the Channel")
-        .AsDuplicable();
-    AddOutput(Status,
-              "(Tensor) An LoD Tensor that returns a boolean status of the"
-              "result of the receive operation.")
-        .AsDuplicable();
-    AddComment(R"DOC(
-)DOC");
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
-
-REGISTER_OPERATOR(channel_recv, paddle::operators::ChannelRecvOp,
-                  paddle::framework::EmptyGradOpMaker,
-                  paddle::operators::ChannelRecvOpMaker);
diff --git a/paddle/fluid/operators/channel_send_op.cc b/paddle/fluid/operators/channel_send_op.cc
deleted file mode 100644
index 67d6deb511..0000000000
--- a/paddle/fluid/operators/channel_send_op.cc
+++ /dev/null
@@ -1,76 +0,0 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include "paddle/fluid/framework/channel.h"
-#include <paddle/fluid/framework/lod_rank_table.h>
-#include <paddle/fluid/framework/lod_tensor_array.h>
-#include <paddle/fluid/framework/reader.h>
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/framework/var_type.h"
-#include "paddle/fluid/operators/concurrency/channel_util.h"
-#include "paddle/fluid/operators/math/math_function.h"
-
-static constexpr char Channel[] = "Channel";
-static constexpr char X[] = "X";
-
-namespace paddle {
-namespace operators {
-
-class ChannelSendOp : public framework::OperatorBase {
- public:
-  ChannelSendOp(const std::string &type,
-                const framework::VariableNameMap &inputs,
-                const framework::VariableNameMap &outputs,
-                const framework::AttributeMap &attrs)
-      : framework::OperatorBase(type, inputs, outputs, attrs) {}
-
-  void InferShape(framework::InferShapeContext *ctx) const {
-    PADDLE_ENFORCE(ctx->HasInput(Channel),
-                   "Input(Channel) of ChannelSendOp should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput(X),
-                   "Input(X) of ChannelSendOp should not be null.");
-  }
-
- private:
-  void RunImpl(const framework::Scope &scope,
-               const platform::Place &dev_place) const override {
-    // Get the channel holder created by channel_create op, passed as input.
-    framework::ChannelHolder *ch =
-        scope.FindVar(Input(Channel))->GetMutable<framework::ChannelHolder>();
-    auto input_var = scope.FindVar(Input(X));
-
-    // Send the input data through the channel.
-    concurrency::ChannelSend(ch, input_var);
-  }
-};
-
-class ChannelSendOpMaker : public framework::OpProtoAndCheckerMaker {
- public:
-  void Make() override {
-    AddInput(Channel,
-             "(Channel) A variable which \"sends\" the passed in value to "
-             "a listening receiver.")
-        .AsDuplicable();
-    AddInput(X, "(Variable) The value which gets sent by the channel.")
-        .AsDuplicable();
-    AddComment(R"DOC(
-)DOC");
-  }
-};
-}  // namespace operators
-}  // namespace paddle
-
-REGISTER_OPERATOR(channel_send, paddle::operators::ChannelSendOp,
-                  paddle::framework::EmptyGradOpMaker,
-                  paddle::operators::ChannelSendOpMaker);
diff --git a/paddle/fluid/operators/concurrency/CMakeLists.txt b/paddle/fluid/operators/concurrency/CMakeLists.txt
deleted file mode 100644
index e4617440d1..0000000000
--- a/paddle/fluid/operators/concurrency/CMakeLists.txt
+++ /dev/null
@@ -1 +0,0 @@
-cc_library(concurrency SRCS channel_util.cc DEPS device_context framework_proto boost eigen3)
diff --git a/paddle/fluid/operators/concurrency/channel_util.cc b/paddle/fluid/operators/concurrency/channel_util.cc
deleted file mode 100644
index fba4abf189..0000000000
--- a/paddle/fluid/operators/concurrency/channel_util.cc
+++ /dev/null
@@ -1,111 +0,0 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include "paddle/fluid/operators/concurrency/channel_util.h"
-#include "paddle/fluid/framework/var_type.h"
-
-namespace poc = paddle::operators::concurrency;
-
-void poc::ChannelSend(framework::ChannelHolder *ch, framework::Variable *var) {
-  auto type = framework::ToVarType(var->Type());
-  if (type == framework::proto::VarType_Type_LOD_TENSOR)
-    ch->Send(var->GetMutable<framework::LoDTensor>());
-  else if (type == framework::proto::VarType_Type_LOD_RANK_TABLE)
-    ch->Send(var->GetMutable<framework::LoDRankTable>());
-  else if (type == framework::proto::VarType_Type_LOD_TENSOR_ARRAY)
-    ch->Send(var->GetMutable<framework::LoDTensorArray>());
-  else if (type == framework::proto::VarType_Type_SELECTED_ROWS)
-    ch->Send(var->GetMutable<framework::SelectedRows>());
-  else if (type == framework::proto::VarType_Type_READER)
-    ch->Send(var->GetMutable<framework::ReaderHolder>());
-  else if (type == framework::proto::VarType_Type_CHANNEL)
-    ch->Send(var->GetMutable<framework::ChannelHolder>());
-  else
-    PADDLE_THROW("ChannelSend:Unsupported type");
-}
-
-bool poc::ChannelReceive(framework::ChannelHolder *ch,
-                         framework::Variable *var) {
-  // Get type of channel and use that to call mutable data for Variable
-  auto type = framework::ToVarType(ch->Type());
-  if (type == framework::proto::VarType_Type_LOD_TENSOR)
-    return ch->Receive(var->GetMutable<framework::LoDTensor>());
-  else if (type == framework::proto::VarType_Type_LOD_RANK_TABLE)
-    return ch->Receive(var->GetMutable<framework::LoDRankTable>());
-  else if (type == framework::proto::VarType_Type_LOD_TENSOR_ARRAY)
-    return ch->Receive(var->GetMutable<framework::LoDTensorArray>());
-  else if (type == framework::proto::VarType_Type_SELECTED_ROWS)
-    return ch->Receive(var->GetMutable<framework::SelectedRows>());
-  else if (type == framework::proto::VarType_Type_READER)
-    return ch->Receive(var->GetMutable<framework::ReaderHolder>());
-  else if (type == framework::proto::VarType_Type_CHANNEL)
-    return ch->Receive(var->GetMutable<framework::ChannelHolder>());
-  else
-    PADDLE_THROW("ChannelReceive:Unsupported type");
-}
-
-void poc::ChannelAddToSendQ(framework::ChannelHolder *ch, const void *referrer,
-                            framework::Variable *var,
-                            std::shared_ptr<std::condition_variable_any> cond,
-                            std::function<bool(framework::ChannelAction)> cb) {
-  auto type = framework::ToVarType(var->Type());
-  if (type == framework::proto::VarType_Type_LOD_TENSOR) {
-    ch->AddToSendQ(referrer, var->GetMutable<framework::LoDTensor>(), cond, cb);
-  } else if (type == framework::proto::VarType_Type_LOD_RANK_TABLE) {
-    ch->AddToSendQ(referrer, var->GetMutable<framework::LoDRankTable>(), cond,
-                   cb);
-  } else if (type == framework::proto::VarType_Type_LOD_TENSOR_ARRAY) {
-    ch->AddToSendQ(referrer, var->GetMutable<framework::LoDTensorArray>(), cond,
-                   cb);
-  } else if (type == framework::proto::VarType_Type_SELECTED_ROWS) {
-    ch->AddToSendQ(referrer, var->GetMutable<framework::SelectedRows>(), cond,
-                   cb);
-  } else if (type == framework::proto::VarType_Type_READER) {
-    ch->AddToSendQ(referrer, var->GetMutable<framework::ReaderHolder>(), cond,
-                   cb);
-  } else if (type == framework::proto::VarType_Type_CHANNEL) {
-    ch->AddToSendQ(referrer, var->GetMutable<framework::ChannelHolder>(), cond,
-                   cb);
-  } else {
-    PADDLE_THROW("ChannelAddToSendQ:Unsupported type");
-  }
-}
-
-void poc::ChannelAddToReceiveQ(
-    framework::ChannelHolder *ch, const void *referrer,
-    framework::Variable *var, std::shared_ptr<std::condition_variable_any> cond,
-    std::function<bool(framework::ChannelAction)> cb) {
-  auto type = framework::ToVarType(var->Type());
-  if (type == framework::proto::VarType_Type_LOD_TENSOR) {
-    ch->AddToReceiveQ(referrer, var->GetMutable<framework::LoDTensor>(), cond,
-                      cb);
-  } else if (type == framework::proto::VarType_Type_LOD_RANK_TABLE) {
-    ch->AddToReceiveQ(referrer, var->GetMutable<framework::LoDRankTable>(),
-                      cond, cb);
-  } else if (type == framework::proto::VarType_Type_LOD_TENSOR_ARRAY) {
-    ch->AddToReceiveQ(referrer, var->GetMutable<framework::LoDTensorArray>(),
-                      cond, cb);
-  } else if (type == framework::proto::VarType_Type_SELECTED_ROWS) {
-    ch->AddToReceiveQ(referrer, var->GetMutable<framework::SelectedRows>(),
-                      cond, cb);
-  } else if (type == framework::proto::VarType_Type_READER) {
-    ch->AddToReceiveQ(referrer, var->GetMutable<framework::ReaderHolder>(),
-                      cond, cb);
-  } else if (type == framework::proto::VarType_Type_CHANNEL) {
-    ch->AddToReceiveQ(referrer, var->GetMutable<framework::ChannelHolder>(),
-                      cond, cb);
-  } else {
-    PADDLE_THROW("ChannelAddToReceiveQ:Unsupported type");
-  }
-}
diff --git a/paddle/fluid/operators/concurrency/channel_util.h b/paddle/fluid/operators/concurrency/channel_util.h
deleted file mode 100644
index cd18ca78c6..0000000000
--- a/paddle/fluid/operators/concurrency/channel_util.h
+++ /dev/null
@@ -1,38 +0,0 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#pragma once
-
-#include "paddle/fluid/framework/channel.h"
-#include "paddle/fluid/framework/variable.h"
-
-namespace paddle {
-namespace operators {
-namespace concurrency {
-
-void ChannelSend(framework::ChannelHolder *ch, framework::Variable *var);
-bool ChannelReceive(framework::ChannelHolder *ch, framework::Variable *var);
-
-void ChannelAddToSendQ(framework::ChannelHolder *ch, const void *referrer,
-                       framework::Variable *var,
-                       std::shared_ptr<std::condition_variable_any> cond,
-                       std::function<bool(framework::ChannelAction)> cb);
-void ChannelAddToReceiveQ(framework::ChannelHolder *ch, const void *referrer,
-                          framework::Variable *var,
-                          std::shared_ptr<std::condition_variable_any> cond,
-                          std::function<bool(framework::ChannelAction)> cb);
-
-}  // namespace concurrency
-}  // namespace operators
-}  // namespace paddle
diff --git a/paddle/fluid/operators/select_op.cc b/paddle/fluid/operators/select_op.cc
deleted file mode 100644
index e71841d4d1..0000000000
--- a/paddle/fluid/operators/select_op.cc
+++ /dev/null
@@ -1,419 +0,0 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include <memory>
-#include <thread>  // NOLINT
-#include <vector>
-#include "paddle/fluid/framework/channel.h"
-#include "paddle/fluid/framework/executor.h"
-#include "paddle/fluid/framework/lod_tensor.h"
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/concurrency/channel_util.h"
-
-#include <boost/tokenizer.hpp>
-
-namespace paddle {
-namespace operators {
-
-static constexpr char kX[] = "X";
-static constexpr char kCaseToExecute[] = "case_to_execute";
-static constexpr char kOutputs[] = "Out";
-
-static constexpr char kCases[] = "cases";
-static constexpr char kCasesBlock[] = "sub_block";
-
-class SelectOp : public framework::OperatorBase {
- public:
-  SelectOp(const std::string &type, const framework::VariableNameMap &inputs,
-           const framework::VariableNameMap &outputs,
-           const framework::AttributeMap &attrs)
-      : framework::OperatorBase(type, inputs, outputs, attrs) {}
-
- private:
-  enum class SelectOpCaseType {
-    DEFAULT = 0,
-    SEND = 1,
-    RECEIVE = 2,
-  };
-
-  struct SelectOpCase {
-    int caseIndex;
-    SelectOpCaseType caseType;
-    std::string channelName;
-    std::string varName;
-
-    SelectOpCase() {}
-
-    SelectOpCase(int caseIndex, SelectOpCaseType caseType,
-                 std::string channelName, std::string varName)
-        : caseIndex(caseIndex),
-          caseType(caseType),
-          channelName(channelName),
-          varName(varName) {}
-  };
-
-  void RunImpl(const framework::Scope &scope,
-               const platform::Place &dev_place) const override {
-    std::vector<std::string> casesConfigs =
-        Attr<std::vector<std::string>>(kCases);
-
-    framework::BlockDesc *casesBlock =
-        Attr<framework::BlockDesc *>(kCasesBlock);
-
-    framework::Scope &casesBlockScope = scope.NewScope();
-
-    std::string caseToExecuteVarName = Input(kCaseToExecute);
-    framework::Variable *caseToExecuteVar =
-        casesBlockScope.FindVar(caseToExecuteVarName);
-
-    // Construct cases from "conditional_block_op"(s) in the casesBlock
-    std::vector<std::shared_ptr<SelectOpCase>> cases =
-        ParseAndShuffleCases(&casesConfigs);
-
-    // Get all unique channels involved in select
-    std::set<framework::ChannelHolder *> channelsSet;
-    for (auto c : cases) {
-      if (!c->channelName.empty()) {
-        auto channelVar = scope.FindVar(c->channelName);
-        framework::ChannelHolder *ch =
-            channelVar->GetMutable<framework::ChannelHolder>();
-
-        if (channelsSet.find(ch) == channelsSet.end()) {
-          channelsSet.insert(ch);
-        }
-      }
-    }
-
-    // Order all channels by their pointer address
-    std::vector<framework::ChannelHolder *> channels(channelsSet.begin(),
-                                                     channelsSet.end());
-    std::sort(channels.begin(), channels.end());
-
-    // Poll all cases
-    int32_t caseToExecute = pollCases(&scope, &cases, channels);
-
-    // At this point, the case to execute has already been determined,
-    // so we can proceed with executing the cases block
-    framework::LoDTensor *caseToExecuteTensor =
-        caseToExecuteVar->GetMutable<framework::LoDTensor>();
-    caseToExecuteTensor->data<int32_t>()[0] = caseToExecute;
-
-    // Execute the cases block, only one case will be executed since we set the
-    // case_to_execute value to the index of the case we want to execute
-    framework::Executor executor(dev_place);
-    framework::ProgramDesc *program = casesBlock->Program();
-    executor.Run(*program, &casesBlockScope, casesBlock->ID(),
-                 false /*create_local_scope*/);
-  }
-
-  /**
-   * Goes through all operators in the casesConfigs and processes
-   * "conditional_block" operators.  These operators are mapped to our
-   * SelectOpCase objects.  We randomize the case orders, and set the
-   * default case (if any exists) as the last case)
-   * @param casesBlock
-   * @return
-   */
-  std::vector<std::shared_ptr<SelectOpCase>> ParseAndShuffleCases(
-      std::vector<std::string> *casesConfigs) const {
-    std::vector<std::shared_ptr<SelectOpCase>> cases;
-    std::shared_ptr<SelectOpCase> defaultCase;
-
-    if (casesConfigs != nullptr) {
-      boost::char_delimiters_separator<char> sep(false, ",", "");
-      for (std::vector<std::string>::iterator itr = casesConfigs->begin();
-           itr < casesConfigs->end(); ++itr) {
-        std::string caseConfig = *itr;
-        boost::tokenizer<> tokens(caseConfig, sep);
-
-        boost::tokenizer<>::iterator tok_iter = tokens.begin();
-        PADDLE_ENFORCE(tok_iter != tokens.end(), "Cannot get case index");
-        std::string caseIndexString = *tok_iter;
-        int caseIndex = std::stoi(caseIndexString);
-
-        ++tok_iter;
-        PADDLE_ENFORCE(tok_iter != tokens.end(), "Cannot get case type");
-        std::string caseTypeString = *tok_iter;
-        SelectOpCaseType caseType = (SelectOpCaseType)std::stoi(caseTypeString);
-
-        std::string caseChannel;
-        std::string caseChannelVar;
-
-        ++tok_iter;
-        if (caseType != SelectOpCaseType::DEFAULT) {
-          PADDLE_ENFORCE(tok_iter != tokens.end(), "Cannot get case channel");
-          caseChannel = *tok_iter;
-
-          ++tok_iter;
-          PADDLE_ENFORCE(tok_iter != tokens.end(),
-                         "Cannot get case channel variable");
-          caseChannelVar = *tok_iter;
-        }
-
-        auto c = std::make_shared<SelectOpCase>(caseIndex, caseType,
-                                                caseChannel, caseChannelVar);
-
-        if (caseType == SelectOpCaseType::DEFAULT) {
-          PADDLE_ENFORCE(defaultCase == nullptr,
-                         "Select can only contain one default case.");
-          defaultCase = c;
-        } else {
-          cases.push_back(c);
-        }
-      }
-    }
-
-    // Randomly sort cases, with default case being last
-    std::random_shuffle(cases.begin(), cases.end());
-    if (defaultCase != nullptr) {
-      cases.push_back(defaultCase);
-    }
-
-    return cases;
-  }
-
-  /**
-   * This method will recursively poll the cases and determines if any case
-   * condition is true.
-   * If none of the cases conditions are true (and there is no default case),
-   * then block
-   * the thread.  The thread may be woken up by a channel operation, at which
-   * point we
-   * execute the case.
-   * @param scope
-   * @param cases
-   * @param channels
-   * @return
-   */
-  int32_t pollCases(const framework::Scope *scope,
-                    std::vector<std::shared_ptr<SelectOpCase>> *cases,
-                    std::vector<framework::ChannelHolder *> channels) const {
-    // Lock all involved channels
-    lockChannels(channels);
-
-    std::atomic<int> caseToExecute(-1);
-
-    std::vector<std::shared_ptr<SelectOpCase>>::iterator it = cases->begin();
-    while (it != cases->end()) {
-      std::shared_ptr<SelectOpCase> c = *it;
-
-      auto chVar = scope->FindVar(c->channelName);
-      framework::ChannelHolder *ch =
-          chVar->GetMutable<framework::ChannelHolder>();
-
-      switch (c->caseType) {
-        case SelectOpCaseType::SEND:
-          PADDLE_ENFORCE(!ch->IsClosed(), "Cannot send to a closed channel");
-          if (ch->CanSend()) {
-            // We can send to channel directly, send the data to channel
-            // and execute case
-            auto chVar = scope->FindVar(c->varName);
-            concurrency::ChannelSend(ch, chVar);
-            caseToExecute = c->caseIndex;
-          }
-          break;
-        case SelectOpCaseType::RECEIVE:
-          if (ch->CanReceive()) {
-            // We can receive from channel directly, send the data to channel
-            // and execute case
-            auto chVar = scope->FindVar(c->varName);
-            concurrency::ChannelReceive(ch, chVar);
-            caseToExecute = c->caseIndex;
-          }
-          break;
-        case SelectOpCaseType::DEFAULT:
-          caseToExecute = c->caseIndex;
-          break;
-      }
-
-      if (caseToExecute != -1) {
-        // We found a case to execute, stop looking at other case statements
-        break;
-      }
-
-      ++it;
-    }
-
-    if (caseToExecute == -1) {
-      // None of the cases are eligible to execute, enqueue current thread
-      // into all the sending/receiving queue of each involved channel
-      std::atomic<bool> completed(false);
-      std::recursive_mutex mutex;
-      std::unique_lock<std::recursive_mutex> lock{mutex};
-      // std::condition_variable_any selectCond;
-      auto selectCond = std::make_shared<std::condition_variable_any>();
-
-      std::recursive_mutex callbackMutex;
-      pushThreadOnChannelQueues(scope, cases, selectCond, &caseToExecute,
-                                &completed, &callbackMutex);
-
-      // TODO(thuan): Atomically unlock all channels and sleep current thread
-      unlockChannels(channels);
-      selectCond->wait(lock, [&completed]() { return completed.load(); });
-
-      // Select has been woken up by case operation
-      lockChannels(channels);
-      removeThreadOnChannelQueues(scope, cases);
-
-      if (caseToExecute == -1) {
-        // Recursively poll cases, since we were woken up by a channel close
-        // TODO(thuan): Need to test if this is a valid case
-        unlockChannels(channels);
-        return pollCases(scope, cases, channels);
-      }
-    }
-
-    // At this point, caseToExecute != -1, and we can proceed with executing
-    // the case block
-    unlockChannels(channels);
-
-    return caseToExecute;
-  }
-
-  void lockChannels(std::vector<framework::ChannelHolder *> chs) const {
-    std::vector<framework::ChannelHolder *>::iterator it = chs.begin();
-    while (it != chs.end()) {
-      framework::ChannelHolder *ch = *it;
-      ch->Lock();
-      ++it;
-    }
-  }
-
-  void unlockChannels(std::vector<framework::ChannelHolder *> chs) const {
-    std::vector<framework::ChannelHolder *>::reverse_iterator it = chs.rbegin();
-    while (it != chs.rend()) {
-      framework::ChannelHolder *ch = *it;
-      ch->Unlock();
-      ++it;
-    }
-  }
-
-  void pushThreadOnChannelQueues(
-      const framework::Scope *scope,
-      std::vector<std::shared_ptr<SelectOpCase>> *cases,
-      std::shared_ptr<std::condition_variable_any> rCond,
-      std::atomic<int> *caseToExecute, std::atomic<bool> *completed,
-      std::recursive_mutex *callbackMutex) const {
-    std::vector<std::shared_ptr<SelectOpCase>>::iterator it = cases->begin();
-    while (it != cases->end()) {
-      std::shared_ptr<SelectOpCase> c = *it;
-
-      auto chVar = scope->FindVar(c->channelName);
-      framework::ChannelHolder *ch =
-          chVar->GetMutable<framework::ChannelHolder>();
-
-      std::function<bool(framework::ChannelAction channelAction)> cb =
-          [&caseToExecute, &completed, &callbackMutex,
-           c](framework::ChannelAction channelAction) {
-            std::lock_guard<std::recursive_mutex> lock{*callbackMutex};
-
-            bool canProcess = false;
-            if (!(*completed)) {
-              // If the channel wasn't closed, we set the caseToExecute index
-              // as this current case
-              if (channelAction != framework::ChannelAction::CLOSE) {
-                *caseToExecute = c->caseIndex;
-              }
-              // This will allow our conditional variable to break out of wait
-              *completed = true;
-              canProcess = true;
-            }
-
-            return canProcess;
-          };
-
-      switch (c->caseType) {
-        case SelectOpCaseType::SEND: {
-          auto chOutputVar = scope->FindVar(c->varName);
-          concurrency::ChannelAddToSendQ(ch, this, chOutputVar, rCond, cb);
-          break;
-        }
-        case SelectOpCaseType::RECEIVE: {
-          auto chOutputVar = scope->FindVar(c->varName);
-          concurrency::ChannelAddToReceiveQ(ch, this, chOutputVar, rCond, cb);
-          break;
-        }
-        default:
-          break;
-      }
-      ++it;
-    }
-  }
-
-  void removeThreadOnChannelQueues(
-      const framework::Scope *scope,
-      std::vector<std::shared_ptr<SelectOpCase>> *cases) const {
-    std::vector<std::shared_ptr<SelectOpCase>>::iterator it = cases->begin();
-    while (it != cases->end()) {
-      std::shared_ptr<SelectOpCase> c = *it;
-
-      auto chVar = scope->FindVar(c->channelName);
-      framework::ChannelHolder *ch =
-          chVar->GetMutable<framework::ChannelHolder>();
-      switch (c->caseType) {
-        case SelectOpCaseType::SEND: {
-          ch->RemoveFromSendQ(this);
-          break;
-        }
-        case SelectOpCaseType::RECEIVE: {
-          ch->RemoveFromReceiveQ(this);
-          break;
-        }
-        default:
-          break;
-      }
-      ++it;
-    }
-  }
-};
-
-class SelectOpMaker : public framework::OpProtoAndCheckerMaker {
- public:
-  void Make() override {
-    AddInput(kX,
-             "A set of variables, which are required by operators inside the "
-             "cases of Select Op")
-        .AsDuplicable();
-    AddInput(kCaseToExecute,
-             "(Int) The variable the sets the index of the case to execute, "
-             "after evaluating the channels being sent to and received from")
-        .AsDuplicable();
-    AddOutput(kOutputs,
-              "A set of variables, which will be assigned with values "
-              "generated by the operators inside the cases of Select Op.")
-        .AsDuplicable();
-    AddAttr<std::vector<std::string>>(kCases,
-                                      "(String vector) Serialized list of"
-                                      "all cases in the select op. Each"
-                                      "case is serialized as: "
-                                      "'<index>,<type>,<channel>,<value>'"
-                                      "where type is 0 for default, 1 for"
-                                      "send, and 2 for receive"
-                                      "No channel and values are needed for"
-                                      "default cases.");
-    AddAttr<framework::BlockDesc *>(kCasesBlock,
-                                    "The cases block inside select_op");
-    AddComment(R"DOC(
-)DOC");
-  }
-};
-
-// TODO(thuan): Implement Gradient Operator for SELECT_OP
-
-}  // namespace operators
-}  // namespace paddle
-
-REGISTER_OPERATOR(select, paddle::operators::SelectOp,
-                  paddle::framework::EmptyGradOpMaker,
-                  paddle::operators::SelectOpMaker);
diff --git a/paddle/fluid/pybind/protobuf.cc b/paddle/fluid/pybind/protobuf.cc
index a5bc441220..3b22718a8c 100644
--- a/paddle/fluid/pybind/protobuf.cc
+++ b/paddle/fluid/pybind/protobuf.cc
@@ -214,7 +214,6 @@ void BindVarDsec(pybind11::module *m) {
       .def("set_shapes", &pd::VarDesc::SetShapes)
       .def("set_dtype", &pd::VarDesc::SetDataType)
       .def("set_dtypes", &pd::VarDesc::SetDataTypes)
-      .def("set_capacity", &pd::VarDesc::SetCapacity)
       .def("shape", &pd::VarDesc::GetShape,
            pybind11::return_value_policy::reference)
       .def("shapes", &pd::VarDesc::GetShapes,
@@ -251,7 +250,6 @@ void BindVarDsec(pybind11::module *m) {
       .value("STEP_SCOPES", pd::proto::VarType::STEP_SCOPES)
       .value("LOD_RANK_TABLE", pd::proto::VarType::LOD_RANK_TABLE)
       .value("LOD_TENSOR_ARRAY", pd::proto::VarType::LOD_TENSOR_ARRAY)
-      .value("CHANNEL", pd::proto::VarType::CHANNEL)
       .value("PLACE_LIST", pd::proto::VarType::PLACE_LIST)
       .value("READER", pd::proto::VarType::READER)
       .value("RAW", pd::proto::VarType::RAW);
diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc
index ef2f1f2a20..295af1c583 100644
--- a/paddle/fluid/pybind/pybind.cc
+++ b/paddle/fluid/pybind/pybind.cc
@@ -21,7 +21,6 @@ limitations under the License. */
 #include <utility>
 #include <vector>
 
-#include "paddle/fluid/framework/channel.h"
 #include "paddle/fluid/framework/executor.h"
 #include "paddle/fluid/framework/feed_fetch_method.h"
 #include "paddle/fluid/framework/framework.pb.h"
diff --git a/python/paddle/fluid/concurrency.py b/python/paddle/fluid/concurrency.py
deleted file mode 100644
index e375fdef9c..0000000000
--- a/python/paddle/fluid/concurrency.py
+++ /dev/null
@@ -1,454 +0,0 @@
-#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from __future__ import print_function
-
-from .layers.control_flow import BlockGuard, equal
-from .framework import Operator
-from .layer_helper import LayerHelper, unique_name
-from .layers import fill_constant
-from . import core
-
-__all__ = [
-    'make_channel', 'channel_send', 'channel_recv', 'channel_close', 'Select'
-]
-
-
-class Go(BlockGuard):
-    def __init__(self, name=None):
-        self.helper = LayerHelper("go", name=name)
-        super(Go, self).__init__(self.helper.main_program)
-
-    def __enter__(self):
-        super(Go, self).__enter__()
-
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        if exc_type is not None:
-            return False
-        self._construct_go_op()
-        return super(Go, self).__exit__(exc_type, exc_val, exc_tb)
-
-    def _construct_go_op(self):
-        main_program = self.helper.main_program
-        go_block = main_program.current_block()
-        parent_block = main_program.block(main_program.current_block()
-                                          .parent_idx)
-
-        inner_outputs = set()
-        x_name_list = set()
-        for op in go_block.ops:
-            # Iterate over all operators, get all the inputs
-            # and add as input to the Go operator.
-            for iname in op.input_names:
-                for in_var_name in op.input(iname):
-                    if in_var_name not in inner_outputs:
-                        x_name_list.add(in_var_name)
-
-            for oname in op.output_names:
-                for out_var_name in op.output(oname):
-                    inner_outputs.add(out_var_name)
-
-        # Iterate over all operators , get all the outputs
-        # add to the output list of Go operator only if
-        # they exist in the parent block.
-        out_vars = []
-        for inner_out_name in inner_outputs:
-            if inner_out_name in parent_block.vars:
-                out_vars.append(parent_block.var(inner_out_name))
-
-        parent_block.append_op(
-            type='go',
-            inputs={
-                'X': [
-                    parent_block._var_recursive(x_name)
-                    for x_name in x_name_list
-                ]
-            },
-            outputs={},
-            attrs={'sub_block': go_block})
-
-
-class SelectCase(object):
-    DEFAULT = 0
-    SEND = 1
-    RECEIVE = 2
-
-    def __init__(self,
-                 select,
-                 case_idx,
-                 case_to_execute,
-                 channel_action_fn=None,
-                 channel=None,
-                 value=None,
-                 is_copy=False):
-        self.select = select
-        self.helper = LayerHelper('conditional_block')
-        self.main_program = self.helper.main_program
-        self.is_scalar_condition = True
-
-        self.case_to_execute = case_to_execute
-        self.idx = case_idx
-
-        # Since we aren't going to use the `channel_send` or `channel_recv`
-        # functions directly, we just need to capture the name.
-        self.action = (self.SEND
-                       if channel_action_fn.__name__ == ('channel_send') else
-                       self.RECEIVE) if channel_action_fn else self.DEFAULT
-
-        X = value
-        if self.action == self.SEND and is_copy:
-            # We create of copy of the data we want to send
-            copied_X = self.select.parent_block.create_var(
-                name=unique_name.generate(value.name + '_copy'),
-                type=value.type,
-                dtype=value.dtype,
-                shape=value.shape,
-                lod_level=value.lod_level,
-                capacity=value.capacity
-                if hasattr(value, 'capacity') else None, )
-
-            self.select.parent_block.append_op(
-                type="assign", inputs={"X": value}, outputs={"Out": copied_X})
-            X = copied_X
-
-        self.value = X
-        self.channel = channel
-
-    def __enter__(self):
-        self.block = self.main_program._create_block()
-
-    def construct_op(self):
-        main_program = self.helper.main_program
-        cases_block = main_program.current_block()
-
-        inner_outputs = set()
-        input_set = set()
-        params = set()
-
-        for op in self.block.ops:
-            # Iterate over all operators, get all the inputs
-            # and add as input to the SelectCase operator.
-            for iname in op.input_names:
-                for in_var_name in op.input(iname):
-                    if in_var_name not in inner_outputs:
-                        input_set.add(in_var_name)
-
-            for oname in op.output_names:
-                for out_var_name in op.output(oname):
-                    inner_outputs.add(out_var_name)
-
-        param_list = [
-            cases_block.var(each_name) for each_name in params
-            if each_name not in input_set
-        ]
-
-        # Iterate over all operators, get all the outputs
-        # add to the output list of SelectCase operator only if
-        # they exist in the parent block.
-        out_vars = []
-        for inner_out_name in inner_outputs:
-            if inner_out_name in cases_block.vars:
-                out_vars.append(cases_block.var(inner_out_name))
-
-        # First, create an op that will determine whether or not this is the
-        # conditional variable to execute.
-        should_execute_block = equal(
-            fill_constant(
-                shape=[1], dtype=core.VarDesc.VarType.INT32, value=self.idx),
-            self.case_to_execute)
-
-        step_scope = cases_block.create_var(
-            type=core.VarDesc.VarType.STEP_SCOPES)
-
-        cases_block.append_op(
-            type='conditional_block',
-            inputs={'X': [should_execute_block],
-                    'Params': param_list},
-            outputs={'Out': out_vars,
-                     'Scope': [step_scope]},
-            attrs={
-                'sub_block': self.block,
-                'is_scalar_condition': self.is_scalar_condition
-            })
-
-        return '%s,%s,%s,%s' % (self.idx, self.action, self.channel.name
-                                if self.channel else '', self.value.name
-                                if self.value else '')
-
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        self.main_program._rollback()
-        if exc_type is not None:
-            return False  # re-raise exception
-        return True
-
-
-class Select(BlockGuard):
-    def __init__(self, name=None):
-        self.helper = LayerHelper('select', name=name)
-        self.parent_block = self.helper.main_program.current_block()
-        self.cases = []
-
-        super(Select, self).__init__(self.helper.main_program)
-        self.case_to_execute = fill_constant(
-            shape=[1], dtype=core.VarDesc.VarType.INT32, value=-1)
-
-    def __enter__(self):
-        super(Select, self).__enter__()
-        return self
-
-    def case(self, channel_action_fn, channel, value, is_copy=False):
-        """Create a new block for this condition.
-        """
-        select_case = SelectCase(self,
-                                 len(self.cases), self.case_to_execute,
-                                 channel_action_fn, channel, value, is_copy)
-
-        self.cases.append(select_case)
-
-        return select_case
-
-    def default(self):
-        """Create a default case block for this condition.
-        """
-        default_case = SelectCase(self, len(self.cases), self.case_to_execute)
-
-        self.cases.append(default_case)
-
-        return default_case
-
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        if exc_type is not None:
-            return False
-
-        # Create a select op and another block to wrap its
-        # case blocks.
-        select_block = self.helper.main_program.current_block()
-        parent_block = self.helper.main_program.block(select_block.parent_idx)
-
-        # Construct each case op, inside the newly created select block.
-        serialized_cases = []
-        for case in self.cases:
-            serialized_cases.append(case.construct_op())
-
-        intermediate = set()
-        params = set()
-
-        for case_block in select_block.ops:
-            if case_block.attrs and 'sub_block' in case_block.attrs:
-                for each_op in case_block.attrs['sub_block'].ops:
-                    assert isinstance(each_op, Operator)
-                    for iname in each_op.input_names:
-                        for in_var_name in each_op.input(iname):
-                            if in_var_name not in intermediate:
-                                params.add(in_var_name)
-
-                    for oname in each_op.output_names:
-                        for out_var_name in each_op.output(oname):
-                            intermediate.add(out_var_name)
-
-        out_list = [
-            parent_block.var(var_name) for var_name in parent_block.vars
-            if var_name in intermediate
-        ]
-
-        X = [select_block._var_recursive(x_name) for x_name in params]
-
-        # Needs to be used by `equal` inside the cases block.
-        X.append(self.case_to_execute)
-
-        # Construct the select op.
-        parent_block.append_op(
-            type='select',
-            inputs={'X': X,
-                    'case_to_execute': self.case_to_execute},
-            attrs={'sub_block': select_block,
-                   'cases': serialized_cases},
-            outputs={'Out': out_list})
-
-        return super(Select, self).__exit__(exc_type, exc_val, exc_tb)
-
-
-def make_channel(dtype, capacity=0):
-    """
-    Helps implementation of a concurrent program by creating a "channel" of
-    a defined data type. Channels allow for the passing of data in
-    concurrent scenarios - such as when using threads to divide computation.
-    Channels can be used to "send" and "receive" such data concurrently.
-
-    There are two kinds of channels: unbuffered and buffered. Unbuffered
-    channels have no capacity - and thus, block on send and only unblock only
-    once what they have sent has been received.
-
-    On the other hand, buffered channels are initialized with a capacity -
-    and do not block on sends.
-
-    Use this method in combination with `channel_send`, `channel_recv`,
-    `channel_close`, and `Go` to design a concurrent Paddle program.
-
-    Args:
-        dtype (ParamAttr|string): Data type of the data sent in the channel.
-        This data type should be the string name of a numpy data type.
-        capacity (ParamAttr|int): Size of the channel. Defaults to 0 for
-        to create an unbuffered channel.
-
-    Returns:
-        Variable: The channel variable that can be used to send an receive data
-                  of the defined dtype.
-
-    Examples:
-        .. code-block:: python
-
-          ch = fluid.make_channel(dtype='int32', capacity=10)
-          ...
-          # Code to execute in a Go block, which receives the channel data.
-          fluid.channel_send(ch, 100)
-          fluid.channel_close(ch)
-    """
-    helper = LayerHelper('channel_create', **locals())
-    main_program = helper.main_program
-    make_channel_block = main_program.current_block()
-
-    # Make a channel variable (using the channel data type) and make sure it
-    # persists into the global scope.
-    channel = helper.create_variable(
-        name=unique_name.generate('channel'),
-        type=core.VarDesc.VarType.CHANNEL,
-        persistable=True)
-
-    create_channel_op = make_channel_block.append_op(
-        type="channel_create",
-        outputs={"Out": channel},
-        attrs={"data_type": dtype,
-               "capacity": capacity})
-
-    return channel
-
-
-def channel_send(channel, value, is_copy=False):
-    """
-    Sends a value through a channel variable. Used by an unbuffered or buffered
-    channel to pass data from within or to a concurrent Go block, where
-    `channel_recv` to used to get the passed value.
-
-    Args:
-        channel (Variable|Channel): Channel variable created using
-        `make_channel`.
-        value (Variable): Value to send to channel
-        is_copy (bool): Copy data while channel send. If False, then data
-        is moved. The input cannot be used after move. (default False)
-    Returns:
-        Variable: The boolean status on whether or not the channel
-                  successfully sent the passed value.
-
-    Examples:
-        .. code-block:: python
-
-          ch = fluid.make_channel(dtype='int32', capacity=10)
-          ...
-          # Code to execute in a Go block, which receives the channel data.
-          fluid.channel_send(ch, 100)
-    """
-    helper = LayerHelper('channel_send', **locals())
-    main_program = helper.main_program
-    channel_send_block = main_program.current_block()
-
-    X = value
-
-    if is_copy:
-        copied_X = helper.create_variable(
-            name=unique_name.generate(value.name + '_copy'),
-            type=value.type,
-            dtype=value.dtype,
-            shape=value.shape,
-            lod_level=value.lod_level,
-            capacity=value.capacity if hasattr(value, 'capacity') else None)
-
-        assign_op = channel_send_block.append_op(
-            type="assign", inputs={"X": value}, outputs={"Out": copied_X})
-        X = copied_X
-
-    channel_send_block.append_op(
-        type="channel_send", inputs={
-            "Channel": channel,
-            "X": X,
-        })
-
-
-def channel_recv(channel, return_value):
-    """
-    Receives a value through a channel variable. Used by an unbuffered or
-    buffered channel within a concurrent Go block to get data from originally
-    sent using `channel_send`, or from outside such a block where
-    `channel_send` is used to send the value.
-
-    Args:
-        channel (Variable|Channel): Channel variable created using
-        `make_channel`.
-        return_value (Variable): Variable to set as a result of running channel_recv_op
-
-    Returns:
-        Variable: The received value from the channel.
-        Variable: The boolean status on whether or not the channel
-                  successfully received the passed value.
-
-    Examples:
-        .. code-block:: python
-
-          ch = fluid.make_channel(dtype='int32', capacity=10)
-          with fluid.Go():
-            returned_value, return_status = fluid.channel_recv(ch, 'int32')
-
-          # Code to send data through the channel.
-    """
-    helper = LayerHelper('channel_recv', **locals())
-    main_program = helper.main_program
-    channel_recv_block = main_program.current_block()
-
-    status = helper.create_variable(
-        name=unique_name.generate('status'),
-        type=core.VarDesc.VarType.LOD_TENSOR,
-        dtype=core.VarDesc.VarType.BOOL)
-
-    channel_recv_op = channel_recv_block.append_op(
-        type="channel_recv",
-        inputs={"Channel": channel},
-        outputs={"Out": return_value,
-                 "Status": status})
-
-    return return_value, status
-
-
-def channel_close(channel):
-    """
-    Closes a channel created using `make_channel`.
-
-    Args:
-        channel (Variable|Channel): Channel variable created using
-        `make_channel`.
-
-    Examples:
-        .. code-block:: python
-
-          ch = fluid.make_channel(dtype='int32', capacity=10)
-          ...
-          # Code to receive and send data through a channel
-          ...
-          fluid.channel_close(ch)
-    """
-    helper = LayerHelper('channel_close', **locals())
-    main_program = helper.main_program
-    channel_close_block = main_program.current_block()
-
-    channel_close_op = channel_close_block.append_op(
-        type="channel_close", inputs={"Channel": channel})
diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py
index d795b92d79..63988af993 100644
--- a/python/paddle/fluid/framework.py
+++ b/python/paddle/fluid/framework.py
@@ -541,8 +541,7 @@ class Operator(object):
         'feed', 'fetch', 'save', 'load', 'recurrent', 'go',
         'rnn_memory_helper_grad', 'conditional_block', 'while', 'send', 'recv',
         'listen_and_serv', 'parallel_do', 'save_combine', 'load_combine',
-        'ncclInit', 'channel_create', 'channel_close', 'channel_send',
-        'channel_recv', 'select', 'checkpoint_notify', 'gen_nccl_id'
+        'ncclInit', 'select', 'checkpoint_notify', 'gen_nccl_id'
     }
 
     def __init__(self,
diff --git a/python/paddle/fluid/tests/no_test_concurrency.py b/python/paddle/fluid/tests/no_test_concurrency.py
deleted file mode 100644
index b5d7676f4a..0000000000
--- a/python/paddle/fluid/tests/no_test_concurrency.py
+++ /dev/null
@@ -1,260 +0,0 @@
-#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from __future__ import print_function
-
-import unittest
-import paddle.fluid as fluid
-import paddle.fluid.core as core
-from paddle.fluid import framework, unique_name, layer_helper
-from paddle.fluid.executor import Executor
-from paddle.fluid.layers import fill_constant, assign, While, elementwise_add, Print
-
-
-class TestRoutineOp(unittest.TestCase):
-    def test_simple_routine(self):
-        ch = fluid.make_channel(dtype=core.VarDesc.VarType.LOD_TENSOR)
-
-        # Create LOD_TENSOR<INT64> and put it into the scope.  This placeholder
-        # variable will be filled in and returned by fluid.channel_recv
-        result = self._create_tensor('return_value',
-                                     core.VarDesc.VarType.LOD_TENSOR,
-                                     core.VarDesc.VarType.INT64)
-
-        with fluid.Go():
-            input_value = fill_constant(
-                shape=[1], dtype=core.VarDesc.VarType.FP64, value=1234)
-            fluid.channel_send(ch, input_value)
-
-        result, status = fluid.channel_recv(ch, result)
-        fluid.channel_close(ch)
-
-        cpu = core.CPUPlace()
-        exe = Executor(cpu)
-
-        outs = exe.run(fetch_list=[result])
-        self.assertEqual(outs[0], 1234)
-
-    def test_daisy_chain(self):
-        '''
-        Mimics classic Daisy-chain test:  https://talks.golang.org/2012/concurrency.slide#39
-        '''
-        n = 100
-
-        leftmost = fluid.make_channel(dtype=core.VarDesc.VarType.LOD_TENSOR)
-        left = leftmost
-
-        # TODO(thuan): Use fluid.While() after scope capture is implemented.
-        # https://github.com/PaddlePaddle/Paddle/issues/8502
-        for i in range(n):
-            right = fluid.make_channel(dtype=core.VarDesc.VarType.LOD_TENSOR)
-            with fluid.Go():
-                one_tensor = self._create_one_dim_tensor(1)
-                result = self._create_tensor('return_value',
-                                             core.VarDesc.VarType.LOD_TENSOR,
-                                             core.VarDesc.VarType.INT64)
-
-                result, status = fluid.channel_recv(right, result)
-                one_added = fluid.layers.elementwise_add(x=one_tensor, y=result)
-                fluid.channel_send(left, one_added)
-            left = right
-
-        # Trigger the channel propagation by sending a "1" to rightmost channel
-        with fluid.Go():
-            one_tensor = self._create_one_dim_tensor(1)
-            fluid.channel_send(right, one_tensor)
-
-        leftmost_result = self._create_tensor('return_value',
-                                              core.VarDesc.VarType.LOD_TENSOR,
-                                              core.VarDesc.VarType.INT64)
-        leftmost_result, status = fluid.channel_recv(leftmost, leftmost_result)
-
-        cpu = core.CPUPlace()
-        exe = Executor(cpu)
-        leftmost_data = exe.run(fetch_list=[leftmost_result])
-
-        # The leftmost_data should be equal to the number of channels + 1
-        self.assertEqual(leftmost_data[0][0], n + 1)
-
-    def _create_one_dim_tensor(self, value):
-        one_dim_tensor = fill_constant(shape=[1], dtype='int', value=value)
-        one_dim_tensor.stop_gradient = True
-        return one_dim_tensor
-
-    def _create_tensor(self, name, type, dtype):
-        return framework.default_main_program().current_block().create_var(
-            name=unique_name.generate(name), type=type, dtype=dtype)
-
-    def _create_persistable_tensor(self, name, type, dtype):
-        return framework.default_main_program().current_block().create_var(
-            name=unique_name.generate(name),
-            type=type,
-            dtype=dtype,
-            persistable=True)
-
-    def test_select(self):
-        with framework.program_guard(framework.Program()):
-            ch1 = fluid.make_channel(
-                dtype=core.VarDesc.VarType.LOD_TENSOR, capacity=1)
-
-            result1 = self._create_tensor('return_value',
-                                          core.VarDesc.VarType.LOD_TENSOR,
-                                          core.VarDesc.VarType.FP64)
-
-            input_value = fill_constant(
-                shape=[1], dtype=core.VarDesc.VarType.FP64, value=10)
-
-            with fluid.Select() as select:
-                with select.case(fluid.channel_send, ch1, input_value):
-                    # Execute something.
-                    pass
-
-                with select.default():
-                    pass
-
-            # This should not block because we are using a buffered channel.
-            result1, status = fluid.channel_recv(ch1, result1)
-            fluid.channel_close(ch1)
-
-            cpu = core.CPUPlace()
-            exe = Executor(cpu)
-
-            result = exe.run(fetch_list=[result1])
-            self.assertEqual(result[0][0], 10)
-
-    def test_fibonacci(self):
-        """
-        Mimics Fibonacci Go example: https://tour.golang.org/concurrency/5
-        """
-        with framework.program_guard(framework.Program()):
-            quit_ch_input_var = self._create_persistable_tensor(
-                'quit_ch_input', core.VarDesc.VarType.LOD_TENSOR,
-                core.VarDesc.VarType.INT32)
-            quit_ch_input = fill_constant(
-                shape=[1],
-                dtype=core.VarDesc.VarType.INT32,
-                value=0,
-                out=quit_ch_input_var)
-
-            result = self._create_persistable_tensor(
-                'result', core.VarDesc.VarType.LOD_TENSOR,
-                core.VarDesc.VarType.INT32)
-            fill_constant(
-                shape=[1],
-                dtype=core.VarDesc.VarType.INT32,
-                value=0,
-                out=result)
-
-            x = fill_constant(
-                shape=[1], dtype=core.VarDesc.VarType.INT32, value=0)
-            y = fill_constant(
-                shape=[1], dtype=core.VarDesc.VarType.INT32, value=1)
-
-            while_cond = fill_constant(
-                shape=[1], dtype=core.VarDesc.VarType.BOOL, value=True)
-
-            while_false = fill_constant(
-                shape=[1], dtype=core.VarDesc.VarType.BOOL, value=False)
-
-            x_tmp = fill_constant(
-                shape=[1], dtype=core.VarDesc.VarType.INT32, value=0)
-
-            def fibonacci(channel, quit_channel):
-                while_op = While(cond=while_cond)
-                with while_op.block():
-                    result2 = fill_constant(
-                        shape=[1], dtype=core.VarDesc.VarType.INT32, value=0)
-
-                    with fluid.Select() as select:
-                        with select.case(
-                                fluid.channel_send, channel, x, is_copy=True):
-                            assign(input=x, output=x_tmp)
-                            assign(input=y, output=x)
-                            assign(elementwise_add(x=x_tmp, y=y), output=y)
-
-                        with select.case(fluid.channel_recv, quit_channel,
-                                         result2):
-                            # Quit
-                            helper = layer_helper.LayerHelper('assign')
-                            helper.append_op(
-                                type='assign',
-                                inputs={'X': [while_false]},
-                                outputs={'Out': [while_cond]})
-
-            ch1 = fluid.make_channel(dtype=core.VarDesc.VarType.LOD_TENSOR)
-            quit_ch = fluid.make_channel(dtype=core.VarDesc.VarType.LOD_TENSOR)
-
-            with fluid.Go():
-                for i in range(10):
-                    fluid.channel_recv(ch1, result)
-                    Print(result)
-
-                fluid.channel_send(quit_ch, quit_ch_input)
-
-            fibonacci(ch1, quit_ch)
-
-            fluid.channel_close(ch1)
-            fluid.channel_close(quit_ch)
-
-            cpu = core.CPUPlace()
-            exe = Executor(cpu)
-
-            exe_result = exe.run(fetch_list=[result])
-            self.assertEqual(exe_result[0][0], 34)
-
-    def test_ping_pong(self):
-        """
-        Mimics Ping Pong example: https://gobyexample.com/channel-directions
-        """
-        with framework.program_guard(framework.Program()):
-            result = self._create_tensor('return_value',
-                                         core.VarDesc.VarType.LOD_TENSOR,
-                                         core.VarDesc.VarType.FP64)
-
-            ping_result = self._create_tensor('ping_return_value',
-                                              core.VarDesc.VarType.LOD_TENSOR,
-                                              core.VarDesc.VarType.FP64)
-
-            def ping(ch, message):
-                fluid.channel_send(ch, message, is_copy=True)
-
-            def pong(ch1, ch2):
-                fluid.channel_recv(ch1, ping_result)
-                fluid.channel_send(ch2, ping_result, is_copy=True)
-
-            pings = fluid.make_channel(
-                dtype=core.VarDesc.VarType.LOD_TENSOR, capacity=1)
-            pongs = fluid.make_channel(
-                dtype=core.VarDesc.VarType.LOD_TENSOR, capacity=1)
-
-            msg = fill_constant(
-                shape=[1], dtype=core.VarDesc.VarType.FP64, value=9)
-
-            ping(pings, msg)
-            pong(pings, pongs)
-
-            fluid.channel_recv(pongs, result)
-
-            fluid.channel_close(pings)
-            fluid.channel_close(pongs)
-
-            cpu = core.CPUPlace()
-            exe = Executor(cpu)
-
-            exe_result = exe.run(fetch_list=[result])
-            self.assertEqual(exe_result[0][0], 9)
-
-
-if __name__ == '__main__':
-    unittest.main()
diff --git a/python/paddle/fluid/tests/notest_concurrency.py b/python/paddle/fluid/tests/notest_concurrency.py
deleted file mode 100644
index fd9da4cce0..0000000000
--- a/python/paddle/fluid/tests/notest_concurrency.py
+++ /dev/null
@@ -1,41 +0,0 @@
-#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from __future__ import print_function
-
-import unittest
-import paddle.fluid as fluid
-import paddle.fluid.core as core
-from paddle.fluid.executor import Executor
-
-
-class TestRoutineOp(unittest.TestCase):
-    def test_simple_routine(self):
-        ch = fluid.make_channel(
-            dtype=core.VarDesc.VarType.BOOL, name="CreateChannel")
-        with fluid.Go():
-            fluid.channel_send(ch, True)
-
-        result = fluid.channel_recv(ch)
-        fluid.channel_close(ch)
-
-        cpu = core.CPUPlace()
-        exe = Executor(cpu)
-
-        outs = exe.run(fetch_list=[result])
-        self.assertEqual(outs[0], True)
-
-
-if __name__ == '__main__':
-    unittest.main()

From 5fb72d840a7f6e1cb2edb0129a5ec3c3d06aae0d Mon Sep 17 00:00:00 2001
From: Xin Pan <panxin.grad@gmail.com>
Date: Fri, 28 Sep 2018 13:37:51 +0800
Subject: [PATCH 09/13] add header

test=develop
---
 paddle/fluid/operators/distributed/request_handler.h | 1 +
 1 file changed, 1 insertion(+)

diff --git a/paddle/fluid/operators/distributed/request_handler.h b/paddle/fluid/operators/distributed/request_handler.h
index 3dbbd75b1e..5be7095acd 100644
--- a/paddle/fluid/operators/distributed/request_handler.h
+++ b/paddle/fluid/operators/distributed/request_handler.h
@@ -15,6 +15,7 @@
 #pragma once
 
 #include <time.h>
+#include <condition_variable>  // NOLINT
 
 #include <functional>
 #include <string>

From 6746b1fdf3b8fc8426fd5c74032cbe9a97dc6377 Mon Sep 17 00:00:00 2001
From: Xin Pan <panxin.grad@gmail.com>
Date: Sat, 29 Sep 2018 09:47:08 +0800
Subject: [PATCH 10/13] add missing header

test=develop
---
 paddle/fluid/framework/naive_executor.cc         | 7 ++++---
 paddle/fluid/operators/distributed/rpc_server.cc | 1 +
 2 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/paddle/fluid/framework/naive_executor.cc b/paddle/fluid/framework/naive_executor.cc
index f681d4ecef..2171213d4d 100644
--- a/paddle/fluid/framework/naive_executor.cc
+++ b/paddle/fluid/framework/naive_executor.cc
@@ -12,11 +12,14 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/fluid/framework/naive_executor.h"
+#include <string>
+#include <vector>
+
 #include "paddle/fluid/framework/channel.h"
 #include "paddle/fluid/framework/feed_fetch_method.h"
 #include "paddle/fluid/framework/lod_rank_table.h"
 #include "paddle/fluid/framework/lod_tensor_array.h"
+#include "paddle/fluid/framework/naive_executor.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/reader.h"
 #include "paddle/fluid/string/pretty_log.h"
@@ -44,8 +47,6 @@ static void InitializeVariable(Variable *var, proto::VarType::Type var_type) {
     var->GetMutable<platform::PlaceList>();
   } else if (var_type == proto::VarType::READER) {
     var->GetMutable<ReaderHolder>();
-  } else if (var_type == proto::VarType::CHANNEL) {
-    var->GetMutable<ChannelHolder>();
   } else if (var_type == proto::VarType::RAW) {
     // GetMutable will be called in operator
   } else {
diff --git a/paddle/fluid/operators/distributed/rpc_server.cc b/paddle/fluid/operators/distributed/rpc_server.cc
index 084480ae48..4758dff96c 100644
--- a/paddle/fluid/operators/distributed/rpc_server.cc
+++ b/paddle/fluid/operators/distributed/rpc_server.cc
@@ -12,6 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include <atomic>
 #include <fstream>
 #include <iostream>
 #include <limits>

From 748be49e778f328f2ee7f5c5864ceeaefb2db840 Mon Sep 17 00:00:00 2001
From: qingqing01 <dangqingqing@baidu.com>
Date: Sat, 29 Sep 2018 10:21:43 +0800
Subject: [PATCH 11/13] Fix random fail in Python3 (#13666)

---
 python/paddle/fluid/contrib/tests/test_quantize_transpiler.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/paddle/fluid/contrib/tests/test_quantize_transpiler.py b/python/paddle/fluid/contrib/tests/test_quantize_transpiler.py
index 9af3a6c9fd..095e78c053 100644
--- a/python/paddle/fluid/contrib/tests/test_quantize_transpiler.py
+++ b/python/paddle/fluid/contrib/tests/test_quantize_transpiler.py
@@ -238,7 +238,7 @@ class TestQuantizeTranspiler(unittest.TestCase):
             test_loss2, = exe.run(program=test_program,
                                   feed=feeder.feed(test_data),
                                   fetch_list=[loss])
-            self.assertAlmostEqual(test_loss1, test_loss2, delta=1e-3)
+            self.assertAlmostEqual(test_loss1, test_loss2, delta=5e-3)
             w_freeze = np.array(fluid.global_scope().find_var('conv2d_1.w_0')
                                 .get_tensor())
             self.assertEqual(np.sum(w_freeze), np.sum(w_quant))

From 33b68fdf25a408024e0d1f196327df3c68a029bb Mon Sep 17 00:00:00 2001
From: Xin Pan <panxin.grad@gmail.com>
Date: Sat, 29 Sep 2018 10:41:57 +0800
Subject: [PATCH 12/13] fix compile error

test=develop
---
 paddle/fluid/operators/distributed/rpc_server.cc | 1 -
 paddle/fluid/operators/distributed/rpc_server.h  | 1 +
 2 files changed, 1 insertion(+), 1 deletion(-)

diff --git a/paddle/fluid/operators/distributed/rpc_server.cc b/paddle/fluid/operators/distributed/rpc_server.cc
index 4758dff96c..084480ae48 100644
--- a/paddle/fluid/operators/distributed/rpc_server.cc
+++ b/paddle/fluid/operators/distributed/rpc_server.cc
@@ -12,7 +12,6 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include <atomic>
 #include <fstream>
 #include <iostream>
 #include <limits>
diff --git a/paddle/fluid/operators/distributed/rpc_server.h b/paddle/fluid/operators/distributed/rpc_server.h
index d88e8c640f..f3e61e1575 100644
--- a/paddle/fluid/operators/distributed/rpc_server.h
+++ b/paddle/fluid/operators/distributed/rpc_server.h
@@ -14,6 +14,7 @@
 
 #pragma once
 
+#include <atomic>
 #include <set>
 #include <string>
 #include <thread>  // NOLINT

From 642905958aca0c39f561a7ac623a9e2144d2fb0f Mon Sep 17 00:00:00 2001
From: Xin Pan <panxin.grad@gmail.com>
Date: Sat, 29 Sep 2018 10:56:07 +0800
Subject: [PATCH 13/13] fix compile error

test=develop
---
 paddle/fluid/framework/naive_executor.cc         | 1 -
 paddle/fluid/operators/distributed/grpc_client.h | 1 +
 2 files changed, 1 insertion(+), 1 deletion(-)

diff --git a/paddle/fluid/framework/naive_executor.cc b/paddle/fluid/framework/naive_executor.cc
index 2171213d4d..53d39513f3 100644
--- a/paddle/fluid/framework/naive_executor.cc
+++ b/paddle/fluid/framework/naive_executor.cc
@@ -15,7 +15,6 @@
 #include <string>
 #include <vector>
 
-#include "paddle/fluid/framework/channel.h"
 #include "paddle/fluid/framework/feed_fetch_method.h"
 #include "paddle/fluid/framework/lod_rank_table.h"
 #include "paddle/fluid/framework/lod_tensor_array.h"
diff --git a/paddle/fluid/operators/distributed/grpc_client.h b/paddle/fluid/operators/distributed/grpc_client.h
index 75a3662316..d8e9cee85b 100644
--- a/paddle/fluid/operators/distributed/grpc_client.h
+++ b/paddle/fluid/operators/distributed/grpc_client.h
@@ -15,6 +15,7 @@ limitations under the License. */
 #pragma once
 
 #include <time.h>
+#include <atomic>
 
 #include <chrono>              // NOLINT
 #include <condition_variable>  // NOLINT