diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt
index 4dca3ceb45..01733fdda2 100644
--- a/paddle/fluid/framework/ir/CMakeLists.txt
+++ b/paddle/fluid/framework/ir/CMakeLists.txt
@@ -34,6 +34,7 @@ endif()
 pass_library(attention_lstm_fuse_pass inference)
 pass_library(infer_clean_graph_pass inference)
 pass_library(fc_lstm_fuse_pass inference)
+pass_library(embedding_fc_lstm_fuse_pass inference)
 pass_library(fc_gru_fuse_pass inference)
 pass_library(seq_concat_fc_fuse_pass inference)
 
diff --git a/paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.cc b/paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.cc
new file mode 100644
index 0000000000..38495125c3
--- /dev/null
+++ b/paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.cc
@@ -0,0 +1,242 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.h"
+#include <cstring>
+#include <string>
+#include <unordered_set>
+#include <vector>
+#include "paddle/fluid/framework/lod_tensor.h"
+
+#include "paddle/fluid/operators/math/blas.h"
+#include "paddle/fluid/operators/math/cpu_vec.h"
+#include "paddle/fluid/operators/math/fc_compute.h"
+#include "paddle/fluid/platform/cpu_info.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
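+// Detect the lookup_table -> FC -> LSTM chain and replace it with a single
+// fused_embedding_fc_lstm op. The FC weights (and bias, if present) are
+// folded into the embedding table ahead of time, so at inference a single
+// table lookup per token replaces the lookup followed by a GEMM.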
+static int BuildFusion(Graph* graph, const std::string& name_scope,
+                       Scope* scope, bool with_fc_bias) {
+  GraphPatternDetector gpd;
+  auto* pattern = gpd.mutable_pattern();
+
+  // Build pattern
+  PDNode* x = pattern->NewNode(patterns::PDNodeName(name_scope, "x"))
+                  ->assert_is_op_input("lookup_table")
+                  ->assert_var_not_persistable();
+  patterns::Embedding embedding_pattern(pattern, name_scope);
+  // TODO(jczaja): Intermediate can only be used for values that are consumed
+  //               nowhere else, but the lookup table output may also feed
+  //               another LSTM (e.g. the reverse direction), so it is not
+  //               marked as intermediate here.
+  auto* embedding_out = embedding_pattern(x);
+  patterns::FC fc_pattern(pattern, name_scope);
+
+  // fc_out is a temporary variable that is removed after the fuse, so mark
+  // it as intermediate.
+  auto* fc_out = fc_pattern(embedding_out, with_fc_bias)->AsIntermediate();
+  patterns::LSTM lstm_pattern(pattern, name_scope);
+  lstm_pattern(fc_out);
+
+  // Create New OpDesc
+  auto embedding_lstm_creator = [&](Node* embedding, Node* W, Node* lstm,
+                                    Node* input, Node* weight_x, Node* weight_h,
+                                    Node* bias, Node* hidden, Node* cell,
+                                    Node* xx, Node* fc_bias) {
+    OpDesc op_desc;
+    op_desc.SetType("fused_embedding_fc_lstm");
+#define SET_IN(Key, node__) op_desc.SetInput(#Key, {node__->Name()});
+    SET_IN(Ids, input);
+    SET_IN(WeightH, weight_h);
+    // Bias needs to be passed on because its Wc data is used for peephole
+    // connections.
+    SET_IN(Bias, bias);
+#undef SET_IN
+
+    // Multiply embeddings by WeightX (folding the FC into the lookup table)
+    PADDLE_ENFORCE(scope);
+    const std::string& embeddings = patterns::UniqueKey("Embeddings");
+    auto* embeddings_var = scope->Var(embeddings);
+    PADDLE_ENFORCE(embeddings_var);
+    auto* embeddings_tensor =
+        embeddings_var->GetMutable<framework::LoDTensor>();
+    // Get WeightX size: [single_embedding, fc_size]
+    // and embedding size: [dict_size, single_embedding],
+    // then create new embeddings of size [dict_size, fc_size]
+    auto* embedding_var = scope->FindVar(W->Name());
+    PADDLE_ENFORCE(embedding_var);
+    const auto& embedding_tensor = embedding_var->Get<framework::LoDTensor>();
+
+    const auto& weightx_tensor =
+        scope->FindVar(weight_x->Name())->Get<framework::LoDTensor>();
+    embeddings_tensor->Resize(
+        {embedding_tensor.dims()[0], weightx_tensor.dims()[1]});
+
+    // Multiply embeddings by WeightX and add bias
+    auto embedding_data = embedding_tensor.data<float>();
+    auto weightx_data = weightx_tensor.data<float>();
+    auto embeddings_data =
+        embeddings_tensor->mutable_data<float>(platform::CPUPlace());
+
+    // Gather the biases to be folded into the GEMM result below
+    auto* lstm_bias_var = scope->FindVar(bias->Name());
+    PADDLE_ENFORCE(lstm_bias_var);
+    const auto& lstm_bias_tensor = lstm_bias_var->Get<framework::LoDTensor>();
+
+    auto alpha = 1.0f;
+    auto beta = 1.0f;
+    int m = embedding_tensor.dims()[0];
+    int n = weightx_tensor.dims()[1];
+    int k = embedding_tensor.dims()[1];
+
+    // Copy only gate biases values (only actual bias data, not peephole
+    // weights)
+    std::vector<float> combined_biases(n, 0.0f);
+    memcpy(&combined_biases[0], lstm_bias_tensor.data<float>(),
+           n * sizeof(float));
+
+    if (with_fc_bias) {
+      // Add FC-bias with LSTM-bias (into GEMM result to be)
+      auto* fc_bias_var = scope->FindVar(fc_bias->Name());
+      const auto& fc_bias_tensor = fc_bias_var->Get<framework::LoDTensor>();
+      for (int i = 0; i < fc_bias_tensor.numel(); i++) {
+        combined_biases[i] =
+            lstm_bias_tensor.data<float>()[i] + fc_bias_tensor.data<float>()[i];
+      }
+    }
+
+    // Broadcast the combined biases into every row of the new embeddings
+    std::vector<float> ones(m, 1.0f);
+    paddle::operators::math::CBlas<float>::GEMM(
+        CblasRowMajor, CblasNoTrans, CblasNoTrans, m, n, 1, alpha, &ones[0], 1,
+        &combined_biases[0], n, 0.0f, embeddings_data, n);
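+    // The rank-1 GEMM above (ones (m x 1) times combined_biases (1 x n),
+    // beta == 0.0f) overwrites embeddings_data with the bias row; the GEMM
+    // below then accumulates on top of it (beta == 1.0f).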
+
+    // embeddings * WeightX
+    paddle::operators::math::CBlas<float>::GEMM(
+        CblasRowMajor, CblasNoTrans, CblasNoTrans, m, n, k, alpha,
+        embedding_data, k, weightx_data, n, beta, embeddings_data, n);
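+    // embeddings_data now holds embedding_tensor * WeightX + biases, i.e. the
+    // precomputed FC output for every dictionary entry.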
+    op_desc.SetInput("Embeddings", {embeddings});
+
+    // Create temp variables.
+    const std::string BatchedInput = patterns::UniqueKey("BatchedInput");
+    const std::string BatchedCellPreAct =
+        patterns::UniqueKey("BatchedCellPreAct");
+    const std::string BatchedGate = patterns::UniqueKey("BatchedGate");
+
+    scope->Var(BatchedInput)->GetMutable<framework::LoDTensor>();
+    scope->Var(BatchedCellPreAct)->GetMutable<framework::LoDTensor>();
+    scope->Var(BatchedGate)->GetMutable<framework::LoDTensor>();
+
+    op_desc.SetInput("H0", {});
+    op_desc.SetInput("C0", {});
+    op_desc.SetOutput("Hidden", {hidden->Name()});
+    op_desc.SetOutput("Cell", {cell->Name()});
+    op_desc.SetOutput("XX", {xx->Name()});
+    op_desc.SetOutput("BatchedGate", {BatchedGate});
+    op_desc.SetOutput("BatchCellPreAct", {BatchedCellPreAct});
+    op_desc.SetOutput("BatchedInput", {BatchedInput});
+    op_desc.SetAttr("is_reverse", lstm->Op()->GetAttr("is_reverse"));
+    op_desc.SetAttr("use_peepholes", lstm->Op()->GetAttr("use_peepholes"));
+    // TODO(TJ): get from attr
+    op_desc.SetAttr("use_seq", true);
+
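+    // Re-fetch the parameter scope from the graph; this shadows the outer
+    // `scope` argument, and both should refer to the same parameter scope.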
+    PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
+    auto* scope = graph->Get<Scope*>(kParamScopeAttr);
+#define OP_SET_OUT(x)                            \
+  const std::string x = patterns::UniqueKey(#x); \
+  op_desc.SetOutput(#x, {x});                    \
+  scope->Var(x)->GetMutable<LoDTensor>()
+    OP_SET_OUT(BatchedCell);
+    OP_SET_OUT(BatchedHidden);
+    OP_SET_OUT(ReorderedH0);
+    OP_SET_OUT(ReorderedC0);
+#undef OP_SET_OUT
+
+    auto* op = graph->CreateOpNode(&op_desc);
+    IR_NODE_LINK_TO(input, op);
+    IR_NODE_LINK_TO(weight_x, op);
+    IR_NODE_LINK_TO(weight_h, op);
+    IR_NODE_LINK_TO(bias, op);
+    IR_NODE_LINK_TO(op, hidden);
+    return op;
+  };
+
+  int fusion_count{0};
+
+  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                     Graph* g) {
+    GET_IR_NODE_FROM_SUBGRAPH(lstm, lstm, lstm_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(Weight, Weight, lstm_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(Bias, Bias, lstm_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(Cell, Cell, lstm_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(Hidden, Hidden, lstm_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(lookup_table, lookup_table, embedding_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(W, W, embedding_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(w, w, fc_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(mul, mul, fc_pattern);
+
+    // TODO(jczaja): Add support for is_sparse / is_distributed
+    auto is_sparse = boost::get<bool>(lookup_table->Op()->GetAttr("is_sparse"));
+    auto is_distributed =
+        boost::get<bool>(lookup_table->Op()->GetAttr("is_distributed"));
+
+    if (is_sparse || is_distributed) {
+      return;
+    }
+
+    if (with_fc_bias) {
+      GET_IR_NODE_FROM_SUBGRAPH(fc_out, Out, fc_pattern);
+      GET_IR_NODE_FROM_SUBGRAPH(fc_bias, bias, fc_pattern);
+      GET_IR_NODE_FROM_SUBGRAPH(elementwise_add, elementwise_add, fc_pattern);
+      embedding_lstm_creator(lookup_table, W, lstm, subgraph.at(x), w, Weight,
+                             Bias, Hidden, Cell, fc_out, fc_bias);
+      // Remove unneeded nodes.
+      // TODO(jczaja): Proper removal of lookup table
+      std::unordered_set<const Node*> marked_nodes(
+          //{lookup_table, mul, lstm, elementwise_add, fc_bias, W});
+          {mul, lstm, elementwise_add, fc_bias});
+      GraphSafeRemoveNodes(graph, marked_nodes);
+    } else {
+      GET_IR_NODE_FROM_SUBGRAPH(fc_out, mul_out, fc_pattern);
+      embedding_lstm_creator(lookup_table, W, lstm, subgraph.at(x), w, Weight,
+                             Bias, Hidden, Cell, fc_out, nullptr);
+      // Remove unneeded nodes.
+      // TODO(jczaja): Proper removal of lookup table
+      // std::unordered_set<const Node*> marked_nodes({lookup_table, W, mul,
+      // lstm});
+      std::unordered_set<const Node*> marked_nodes({mul, lstm});
+      GraphSafeRemoveNodes(graph, marked_nodes);
+    }
+
+    ++fusion_count;
+  };
+
+  gpd(graph, handler);
+
+  return fusion_count;
+}
+
+std::unique_ptr<ir::Graph> EmbeddingFCLSTMFusePass::ApplyImpl(
+    std::unique_ptr<ir::Graph> graph) const {
+  FusePassBase::Init(name_scope_, graph.get());
+
+  int fusion_count = BuildFusion(graph.get(), name_scope_, param_scope(),
+                                 true /*with_fc_bias*/);
+
+  AddStatis(fusion_count);
+  return graph;
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(embedding_fc_lstm_fuse_pass,
+              paddle::framework::ir::EmbeddingFCLSTMFusePass);
diff --git a/paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.h b/paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.h
new file mode 100644
index 0000000000..e5ad3067ec
--- /dev/null
+++ b/paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.h
@@ -0,0 +1,40 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
+#include "paddle/fluid/framework/ir/graph.h"
+#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+// Fuses the lookup_table (embedding), FC and LSTM subgraph into a single
+// fused_embedding_fc_lstm op (the FC bias is fused as well).
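+//
+// Offline, the pass precomputes E' = E * WeightX + bias (bias broadcast over
+// rows), so that for every id in the dictionary:
+//   lookup_table(E', id) == fc(lookup_table(E, id))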
+class EmbeddingFCLSTMFusePass : public FusePassBase {
+ public:
+  virtual ~EmbeddingFCLSTMFusePass() {}
+
+ protected:
+  std::unique_ptr<ir::Graph> ApplyImpl(
+      std::unique_ptr<ir::Graph> graph) const override;
+
+  const std::string name_scope_{"embedding_fc_lstm_fuse"};
+};
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc
index 6d2c51b0e9..46c6a52c09 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.cc
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc
@@ -692,6 +692,24 @@ PDNode *patterns::FC::operator()(paddle::framework::ir::PDNode *x,
   }
 }
 
+PDNode *patterns::Embedding::operator()(PDNode *x) {
+  x->assert_is_op_input("lookup_table", "Ids");
+  auto *lookup_table_op =
+      pattern->NewNode(lookup_table_repr())->assert_is_op("lookup_table");
+#define NEW_NODE(arg__, io__)                    \
+  auto *arg__ = pattern->NewNode(arg__##_repr()) \
+                    ->assert_is_op_##io__("lookup_table", #arg__);
+
+  NEW_NODE(W, input);
+
+  NEW_NODE(Out, output);
+#undef NEW_NODE
+
+  lookup_table_op->LinksFrom({x, W});
+  lookup_table_op->LinksTo({Out});
+  return Out;
+}
+
 PDNode *patterns::LSTM::operator()(PDNode *x) {
   x->assert_is_op_input("lstm", "Input");
   auto *lstm_op = pattern->NewNode(lstm_repr())->assert_is_op("lstm");
diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h
index 69b486c29d..508113bf4f 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.h
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.h
@@ -418,6 +418,23 @@ struct FC : public PatternBase {
   PATTERN_DECL_NODE(Out);
 };
 
+// Embedding
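+//
+// Matches lookup_table(Ids = x, W) -> Out and exposes the Out node so that
+// callers can chain further patterns (e.g. FC) onto the embedding output.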
+struct Embedding : public PatternBase {
+  Embedding(PDPattern* pattern, const std::string& name_scope)
+      : PatternBase(pattern, name_scope, "embedding") {}
+
+  PDNode* operator()(PDNode* x);
+
+  // declare operator node's name
+  PATTERN_DECL_NODE(lookup_table);
+  // Inputs
+  PATTERN_DECL_NODE(Ids);
+  PATTERN_DECL_NODE(W);  // embeddings
+  // Outputs
+  PATTERN_DECL_NODE(Out);
+};
+
 struct LSTM : public PatternBase {
   LSTM(PDPattern* pattern, const std::string& name_scope)
       : PatternBase(pattern, name_scope, "lstm") {}
diff --git a/paddle/fluid/inference/analysis/analyzer.h b/paddle/fluid/inference/analysis/analyzer.h
index 9bdbefc07c..0aa9367bf5 100644
--- a/paddle/fluid/inference/analysis/analyzer.h
+++ b/paddle/fluid/inference/analysis/analyzer.h
@@ -64,14 +64,15 @@ class Analyzer : public OrderedRegistry<PassManager> {
   // larger fusion.
   const std::vector<std::string> all_ir_passes_{{
       // Manual update the passes here.
-      "infer_clean_graph_pass",    //
-      "attention_lstm_fuse_pass",  //
-      "fc_lstm_fuse_pass",         //
-      "mul_lstm_fuse_pass",        //
-      "fc_gru_fuse_pass",          //
-      "mul_gru_fuse_pass",         //
-      "seq_concat_fc_fuse_pass",   //
-      "fc_fuse_pass",              //
+      "infer_clean_graph_pass",       //
+      "attention_lstm_fuse_pass",     //
+      "embedding_fc_lstm_fuse_pass",  //
+      "fc_lstm_fuse_pass",            //
+      "mul_lstm_fuse_pass",           //
+      "fc_gru_fuse_pass",             //
+      "mul_gru_fuse_pass",            //
+      "seq_concat_fc_fuse_pass",      //
+      "fc_fuse_pass",                 //
 #ifdef PADDLE_WITH_MKLDNN
       "conv_relu_mkldnn_fuse_pass",  //
 #endif
diff --git a/paddle/fluid/operators/fused_embedding_fc_lstm_op.cc b/paddle/fluid/operators/fused_embedding_fc_lstm_op.cc
new file mode 100644
index 0000000000..3c4cc77452
--- /dev/null
+++ b/paddle/fluid/operators/fused_embedding_fc_lstm_op.cc
@@ -0,0 +1,608 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/fused_embedding_fc_lstm_op.h"
+#include <string>
+#include "paddle/fluid/operators/math/blas.h"
+#include "paddle/fluid/operators/math/cpu_vec.h"
+#include "paddle/fluid/operators/math/fc_compute.h"
+#include "paddle/fluid/operators/math/sequence2batch.h"
+#include "paddle/fluid/platform/cpu_info.h"
+
+namespace paddle {
+namespace operators {
+
+void FusedEmbeddingFCLSTMOp::InferShape(
+    framework::InferShapeContext* ctx) const {
+  PADDLE_ENFORCE(ctx->HasInput("Embeddings"),
+                 "Assert only one Input(Embeddings) of LSTM.");
+  PADDLE_ENFORCE(ctx->HasInput("WeightH"),
+                 "Assert only one Input(WeightH) of LSTM.");
+  PADDLE_ENFORCE(ctx->HasInput("Bias"), "Assert only one Input(Bias) of LSTM.");
+  PADDLE_ENFORCE(ctx->HasOutput("XX"), "Assert only one Output(XX) of LSTM.");
+  PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
+                 "Assert only one Output(Hidden) of LSTM.");
+  PADDLE_ENFORCE(ctx->HasOutput("Cell"),
+                 "Assert only one Output(Cell) of LSTM.");
+  PADDLE_ENFORCE(ctx->HasInput("Ids"),
+                 "Input(Ids) of LookupTableOp should not be null.");
+
+  auto table_dims = ctx->GetInputDim("Embeddings");
+  auto ids_dims = ctx->GetInputDim("Ids");
+  int ids_rank = ids_dims.size();
+
+  PADDLE_ENFORCE_EQ(table_dims.size(), 2);
+  PADDLE_ENFORCE_EQ(ids_dims[ids_rank - 1], 1,
+                    "The last dimension of the 'Ids' tensor must be 1.");
+
+  auto x_dims = ctx->GetInputDim("Ids");
+  PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(Ids)'s rank must be 2.");
+
+  if (ctx->HasInput("H0")) {
+    PADDLE_ENFORCE(ctx->HasInput("C0"),
+                   "Input(Cell) and Input(Hidden) of LSTM should not "
+                   "be null at the same time.");
+    auto h_dims = ctx->GetInputDim("H0");
+    auto c_dims = ctx->GetInputDim("C0");
+    PADDLE_ENFORCE(h_dims == c_dims,
+                   "The dimension of Input(H0) and Input(C0) "
+                   "should be the same.");
+  }
+
+  auto embeddings_dims = ctx->GetInputDim("Embeddings");
+  PADDLE_ENFORCE_EQ(embeddings_dims.size(), 2,
+                    "The rank of Input(Embeddings) should be 2.");
+  //  PADDLE_ENFORCE_EQ(wx_dims[0], x_dims[1],
+  //                    "The first dimension of Input(Embeddings) "
+  //                    "should be %d.",
+  //                    x_dims[1]);
+
+  auto wh_dims = ctx->GetInputDim("WeightH");
+  int frame_size = wh_dims[1] / 4;
+  PADDLE_ENFORCE_EQ(wh_dims.size(), 2,
+                    "The rank of Input(WeightH) should be 2.");
+  PADDLE_ENFORCE_EQ(wh_dims[0], frame_size,
+                    "The first dimension of Input(WeightH) "
+                    "should be %d.",
+                    frame_size);
+  PADDLE_ENFORCE_EQ(wh_dims[1], 4 * frame_size,
+                    "The second dimension of Input(WeightH) "
+                    "should be 4 * %d.",
+                    frame_size);
+
+  auto b_dims = ctx->GetInputDim("Bias");
+  PADDLE_ENFORCE_EQ(b_dims.size(), 2, "The rank of Input(Bias) should be 2.");
+  PADDLE_ENFORCE_EQ(b_dims[0], 1,
+                    "The first dimension of Input(Bias) should be 1.");
+  PADDLE_ENFORCE_EQ(
+      b_dims[1], (ctx->Attrs().Get<bool>("use_peepholes") ? 7 : 4) * frame_size,
+      "The second dimension of Input(Bias) should be "
+      "7 * %d if enable peepholes connection or"
+      "4 * %d if disable peepholes",
+      frame_size, frame_size);
+
+  framework::DDim out_dims({x_dims[0], frame_size});
+  ctx->SetOutputDim("Hidden", out_dims);
+  ctx->SetOutputDim("Cell", out_dims);
+  ctx->ShareLoD("Ids", "Hidden");
+  ctx->ShareLoD("Ids", "Cell");
+  int xx_width;
+  if (ctx->Attrs().Get<bool>("use_seq")) {
+    xx_width = wh_dims[1];
+  } else {
+    xx_width = x_dims[1] > wh_dims[1] ? wh_dims[1] : x_dims[1];
+    PADDLE_ENFORCE(ctx->HasOutput("BatchedInput"),
+                   "Assert only one Output(BatchedInput) of LSTM.");
+    PADDLE_ENFORCE(ctx->HasOutput("BatchedHidden"),
+                   "Assert only one Output(BatchedHidden) of LSTM.");
+    PADDLE_ENFORCE(ctx->HasOutput("BatchedCell"),
+                   "Assert only one Output(BatchedCell) of LSTM.");
+    PADDLE_ENFORCE(ctx->HasOutput("ReorderedH0"),
+                   "Assert only one Output(ReorderedH0) of LSTM");
+    PADDLE_ENFORCE(ctx->HasOutput("ReorderedC0"),
+                   "Assert only one Output(ReorderedC0) of LSTM.");
+    ctx->SetOutputDim("BatchedInput", {x_dims[0], wh_dims[1]});
+    ctx->SetOutputDim("BatchedHidden", out_dims);
+    ctx->SetOutputDim("BatchedCell", out_dims);
+  }
+  ctx->SetOutputDim("XX", {x_dims[0], xx_width});
+  ctx->ShareLoD("Ids", "XX");
+}
+
+framework::OpKernelType FusedEmbeddingFCLSTMOp::GetExpectedKernelType(
+    const framework::ExecutionContext& ctx) const {
+  return framework::OpKernelType(
+      framework::ToDataType(
+          ctx.Input<framework::LoDTensor>("Embeddings")->type()),
+      ctx.device_context());
+}
+
+void FusedEmbeddingFCLSTMOpMaker::Make() {
+  AddInput("Ids",
+           "An input with type int32 or int64 "
+           "contains the ids to be looked up in W. "
+           "The last dimension size must be 1.");
+  AddInput("Embeddings",
+           "(Tensor) the learnable weights of X."
+           " - The shape is (M x 4D), where M is the dim size of x, D is the "
+           "hidden size. "
+           " - Weight = {W_cx, W_ix, W_fx, W_ox}");
+  AddInput("WeightH",
+           "(Tensor) same as LSTMOp, the learnable hidden-hidden weights."
+           " - The shape is (D x 4D), where D is the hidden size. "
+           " - Weight = {W_ch, W_ih, W_fh, W_oh}");
+  AddInput("Bias",
+           "(Tensor) the learnable weights. Almost same as LSTMOp"
+           "Note: we should add the fc bias into this (1x4D) in bias."
+           "input-hidden bias weight and peephole connections weight if "
+           "setting `use_peepholes` True. "
+           "1. `use_peepholes = False` "
+           " - The shape is (1 x 4D). "
+           " - Bias = {b_c, b_i, b_f, b_o}."
+           "2. `use_peepholes = True` "
+           " - The shape is (1 x 7D). "
+           " - Bias = {b_c, b_i, b_f, b_o, W_ic, W_fc, W_oc}.");
+  AddInput("H0",
+           "(Tensor, optional) (same as LSTMOp) the initial hidden state is an "
+           "optional "
+           "input. This is a tensor with shape (N x D), where N is the "
+           "batch size and D is the hidden size.")
+      .AsDispensable();
+  AddInput("C0",
+           "(Tensor, optional) (same as LSTMOp) (the initial cell state is an "
+           "optional "
+           "input. This is a tensor with shape (N x D), where N is the "
+           "batch size. `H0` and `C0` can be NULL but only at the same time.")
+      .AsDispensable();
+  AddOutput("Hidden",
+            "(LoDTensor) (same as LSTMOp) the hidden state of LSTM operator. "
+            "The shape is (T x D), and lod is the same with the `Input`.");
+  AddOutput("Cell",
+            "(LoDTensor) (same as LSTMOp) the cell state of LSTM operator. "
+            "The shape is (T x D), and lod is the same with the `Input`.");
+  AddOutput("XX",
+            "(LoDTensor) the result after X * WeightX (size is T x 4D)"
+            " or batched_X (size is T x M), this will be automatically chosen,"
+            " where T is the total time steps in this mini-batch,"
+            " D is the hidden size, M is the dim size of x input.")
+      .AsIntermediate();
+  AddOutput("BatchedInput", "(LoDTensor) (T x 4D).").AsIntermediate();
+  AddOutput("BatchedHidden", "(LoDTensor) (T x D).").AsIntermediate();
+  AddOutput("BatchedCell", "(LoDTensor) (T x D).").AsIntermediate();
+  AddOutput("ReorderedH0", "(LoDTensor) (N x D).").AsIntermediate();
+  AddOutput("ReorderedC0", "(LoDTensor) (N x D).").AsIntermediate();
+  AddAttr<bool>("use_peepholes",
+                "(bool, defalut: True) "
+                "whether to enable diagonal/peephole connections.")
+      .SetDefault(true);
+  AddAttr<bool>("is_reverse",
+                "(bool, defalut: False) "
+                "whether to compute reversed LSTM.")
+      .SetDefault(false);
+  AddAttr<bool>("use_seq",
+                "(bool, defalut: True) "
+                "whether to use seq mode to compute.")
+      .SetDefault(true);
+  AddAttr<std::string>("gate_activation",
+                       "(string, default: sigmoid)"
+                       "The activation for input gate, forget gate and output "
+                       "gate, `sigmoid` by default.")
+      .SetDefault("sigmoid")
+      .InEnum({"sigmoid", "tanh", "relu", "identity"});
+  AddAttr<std::string>("cell_activation",
+                       "(string, default: tanh)"
+                       "The activation for cell output, `tanh` by defalut.")
+      .SetDefault("tanh")
+      .InEnum({"sigmoid", "tanh", "relu", "identity"});
+  AddAttr<std::string>("candidate_activation",
+                       "(string, default: tanh)"
+                       "The activation for candidate hidden state, "
+                       "`tanh` by default.")
+      .SetDefault("tanh")
+      .InEnum({"sigmoid", "tanh", "relu", "identity"});
+  AddComment(R"DOC(
+Fusion Long-Short Term Memory (LSTM) Operator.
+This operator fuse the X into LSTM, more details can refer to LSTM op.
+)DOC");
+}
+
+template <typename T>
+class FusedEmbeddingFCLSTMKernel : public framework::OpKernel<T> {
+ public:
+#define INIT_VEC_FUNC                                                          \
+  std::function<void(const int, const T *, T *)> act_gate, act_cell, act_cand; \
+  auto& act_gate_str = ctx.Attr<std::string>("gate_activation");               \
+  auto& act_cell_str = ctx.Attr<std::string>("cell_activation");               \
+  auto& act_cand_str = ctx.Attr<std::string>("candidate_activation");          \
+  if (platform::jit::MayIUse(platform::jit::avx)) {                            \
+    math::VecActivations<T, platform::jit::avx> act_functor;                   \
+    act_gate = act_functor(act_gate_str);                                      \
+    act_cell = act_functor(act_cell_str);                                      \
+    act_cand = act_functor(act_cand_str);                                      \
+  } else {                                                                     \
+    math::VecActivations<T, platform::jit::isa_any> act_functor;               \
+    act_gate = act_functor(act_gate_str);                                      \
+    act_cell = act_functor(act_cell_str);                                      \
+    act_cand = act_functor(act_cand_str);                                      \
+  }
+
+#define INIT_BASE_INPUT_OUTPUT                        \
+  auto* ids = ctx.Input<LoDTensor>("Ids");            \
+  auto* h0 = ctx.Input<Tensor>("H0");                 \
+  auto* c0 = ctx.Input<Tensor>("C0");                 \
+  auto* embeddings = ctx.Input<Tensor>("Embeddings"); \
+  auto* wh = ctx.Input<Tensor>("WeightH");            \
+  auto* bias = ctx.Input<Tensor>("Bias");             \
+  auto* xx = ctx.Output<LoDTensor>("XX");             \
+  auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); \
+  auto* cell_out = ctx.Output<LoDTensor>("Cell");     \
+  bool is_reverse = ctx.Attr<bool>("is_reverse");     \
+  bool use_peepholes = ctx.Attr<bool>("use_peepholes");
+
+#define INIT_BASE_SIZES                       \
+  auto ids_dims = ids->dims();   /* T x 1 */  \
+  auto ids_numel = ids->numel(); /* T */      \
+  auto wh_dims = wh->dims();     /* D x 4D*/  \
+  const int D = wh_dims[0];                   \
+  const int D2 = D * 2;                       \
+  const int D3 = D * 3;                       \
+  int64_t row_number = embeddings->dims()[0]; \
+  int64_t row_width = embeddings->dims()[1];  \
+  const int D4 = wh_dims[1];
+
+#define INIT_BASE_INPUT_DATAS                                        \
+  const int64_t* ids_data = ids->data<int64_t>();                    \
+  const T* embeddings_data = embeddings->data<T>();                  \
+  const T* wh_data = wh->data<T>();                                  \
+  /* diagonal weight*/                                               \
+  const T* wc_data = bias->data<T>() + D4;                           \
+  /* for peephole only*/                                             \
+  Tensor checked_cell;                                               \
+  T* checked_cell_data = nullptr;                                    \
+  auto place = ctx.GetPlace();                                       \
+  if (use_peepholes) {                                               \
+    /* w_ic * Ct-1, w_fc * Ct-1  ; w_oc * Ct => ih*/                 \
+    checked_cell_data = checked_cell.mutable_data<T>({2, D}, place); \
+  }
+
+/// Compute LSTM
+#define GEMM_WH_ADDON(bs, prev, out)                                           \
+  blas.GEMM(CblasNoTrans, CblasNoTrans, bs, D4, D, static_cast<T>(1), prev, D, \
+            wh_data, D4, static_cast<T>(1), out, D4)
+
+// Gate layout in each 4D-wide row: [candidate | input | forget | output],
+// matching the {W_ch, W_ih, W_fh, W_oh} weight order.
+#define GET_Ct(ct_1, gates, ct)                   \
+  /* C_t = C_t-1 * fgated + cand_gated * igated*/ \
+  act_cand(D, gates, gates);                      \
+  blas.VMUL(D, gates, gates + D, gates + D);      \
+  blas.VMUL(D, ct_1, gates + D2, gates + D2);     \
+  blas.VADD(D, gates + D, gates + D2, ct)
+
+#define GET_Ht(ct, gates, ht)        \
+  /* H_t = act_cell(C_t) * ogated */ \
+  act_cell(D, ct, gates + D2);       \
+  blas.VMUL(D, gates + D2, gates + D3, ht)
+
+#define GET_Ct_NOH0C0(gates, ct)     \
+  /* C_t = igated * cgated*/         \
+  act_gate(D, gates + D, gates + D); \
+  act_cand(D, gates, gates);         \
+  blas.VMUL(D, gates, gates + D, ct)
+
+#define COMPUTE_CtHt_NOH0C0(gates, ct, ht) \
+  GET_Ct_NOH0C0(gates, ct);                \
+  act_gate(D, gates + D3, gates + D3);     \
+  GET_Ht(ct, gates, ht)
+
+#define COMPUTE_CtHt_PEEPHOLE_NOH0C0(gates, ct, ht) \
+  GET_Ct_NOH0C0(gates, ct);                         \
+  /* get outgated, put W_oc * C_t on igated */      \
+  blas.VMUL(D, wc_data + D2, ct, gates + D);        \
+  blas.VADD(D, gates + D, gates + D3, gates + D3);  \
+  act_gate(D, gates + D3, gates + D3);              \
+  GET_Ht(ct, gates, ht)
+
+#define COMPUTE_CtHt(gates, ct_1, ct, ht) \
+  act_gate(D3, gates + D, gates + D);     \
+  GET_Ct(ct_1, gates, ct);                \
+  GET_Ht(ct, gates, ht)
+
+#define COMPUTE_CtHt_PEEPHOLE(gates, ct_1, ct, ht)        \
+  /* get fgated and igated*/                              \
+  blas.VMUL(D, wc_data, ct_1, checked_cell_data);         \
+  blas.VMUL(D, wc_data + D, ct_1, checked_cell_data + D); \
+  blas.VADD(D2, checked_cell_data, gates + D, gates + D); \
+  act_gate(D2, gates + D, gates + D);                     \
+  GET_Ct(ct_1, gates, ct);                                \
+  /* get ogated*/                                         \
+  blas.VMUL(D, wc_data + D2, ct, gates + D);              \
+  blas.VADD(D, gates + D, gates + D3, gates + D3);        \
+  act_gate(D, gates + D3, gates + D3);                    \
+  GET_Ht(ct, gates, ht)
+
+  void SeqCompute(const framework::ExecutionContext& ctx) const {
+    using DeviceContext = paddle::platform::CPUDeviceContext;
+    INIT_BASE_INPUT_OUTPUT
+    INIT_BASE_SIZES
+    INIT_VEC_FUNC
+    INIT_BASE_INPUT_DATAS
+
+    //  std::cout << "====> SeqCompute" << std::endl;
+    auto ids_lod = ids->lod();
+    const int total_T = ids_dims[0];
+    const int N = ids_lod[0].size() - 1;
+    const T* h0_data = h0 ? h0->data<T>() : nullptr;
+    const T* c0_data = c0 ? c0->data<T>() : nullptr;
+    T* xx_data = xx->mutable_data<T>(place);
+    T* h_out_data = hidden_out->mutable_data<T>(place);
+    T* c_out_data = cell_out->mutable_data<T>(place);
+    auto blas = math::GetBlas<DeviceContext, T>(ctx);
+
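+    // The embedding lookup doubles as the FC: each row of the precomputed
+    // Embeddings table already holds x * WeightX + bias for that token id.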
+    for (int64_t i = 0; i < ids_numel; ++i) {
+      PADDLE_ENFORCE_LT(ids_data[i], row_number);
+      PADDLE_ENFORCE_GE(ids_data[i], 0, "ids %d", i);
+      memcpy(xx_data + i * row_width, embeddings_data + ids_data[i] * row_width,
+             row_width * sizeof(T));
+    }
+
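+    // For a reverse LSTM, start at the last timestep and walk backwards;
+    // xx rows are 4D wide, hence the offset * 4 applied to xx_data.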
+    int xx_offset = D4;
+    int gate_offset = D;
+    if (is_reverse) {
+      const int offset = (total_T - 1) * D;
+      xx_data = xx_data + offset * 4;
+      h_out_data = h_out_data + offset;
+      c_out_data = c_out_data + offset;
+      xx_offset = -D4;
+      gate_offset = -D;
+    }
+
+#define MOVE_ONE_STEP                    \
+  prev_h_data = h_out_data;              \
+  prev_c_data = c_out_data;              \
+  xx_data = xx_data + xx_offset;         \
+  h_out_data = h_out_data + gate_offset; \
+  c_out_data = c_out_data + gate_offset
+
+#define PROCESS_H0C0_DEFINES                           \
+  int bid = is_reverse ? N - 1 - i : i;                \
+  int seq_len = ids_lod[0][bid + 1] - ids_lod[0][bid]; \
+  const T* prev_c_data = nullptr;                      \
+  const T* prev_h_data = nullptr;                      \
+  int tstart = 0
+
+#define PROCESS_H0C0_PEEPHOLE                                      \
+  PROCESS_H0C0_DEFINES;                                            \
+  if (h0_data) {                                                   \
+    prev_h_data = h0_data + bid * D;                               \
+    prev_c_data = c0_data + bid * D;                               \
+  } else {                                                         \
+    COMPUTE_CtHt_PEEPHOLE_NOH0C0(xx_data, c_out_data, h_out_data); \
+    MOVE_ONE_STEP;                                                 \
+    tstart = 1;                                                    \
+  }
+
+#define PROCESS_H0C0                                      \
+  PROCESS_H0C0_DEFINES;                                   \
+  if (h0_data) {                                          \
+    prev_h_data = h0_data + bid * D;                      \
+    prev_c_data = c0_data + bid * D;                      \
+  } else {                                                \
+    COMPUTE_CtHt_NOH0C0(xx_data, c_out_data, h_out_data); \
+    MOVE_ONE_STEP;                                        \
+    tstart = 1;                                           \
+  }
+
+    if (use_peepholes) {
+      for (int i = 0; i < N; ++i) {
+        PROCESS_H0C0_PEEPHOLE
+        for (int step = tstart; step < seq_len; ++step) {
+          GEMM_WH_ADDON(1, prev_h_data, xx_data);
+          COMPUTE_CtHt_PEEPHOLE(xx_data, prev_c_data, c_out_data, h_out_data);
+          MOVE_ONE_STEP;
+        }
+      }
+    } else {
+      for (int i = 0; i < N; ++i) {
+        PROCESS_H0C0
+        for (int step = tstart; step < seq_len; ++step) {
+          GEMM_WH_ADDON(1, prev_h_data, xx_data);
+          COMPUTE_CtHt(xx_data, prev_c_data, c_out_data, h_out_data);
+          MOVE_ONE_STEP;
+        }
+      }
+    }
+#undef PROCESS_H0C0_DEFINES
+#undef PROCESS_H0C0_PEEPHOLE
+#undef PROCESS_H0C0
+#undef MOVE_ONE_STEP
+  }
+
+  void BatchCompute(const framework::ExecutionContext& ctx) const {
+    using DeviceContext = platform::CPUDeviceContext;
+    INIT_BASE_INPUT_OUTPUT
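+    // A LoD with two offsets ({0, T}) means there is only one sequence, so
+    // the sequential path is cheaper than batching.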
+    if (ids->lod()[0].size() == 2) {
+      SeqCompute(ctx);
+      return;
+    }
+    INIT_BASE_SIZES
+    INIT_VEC_FUNC
+    INIT_BASE_INPUT_DATAS
+
+    // std::cout << "===> Batch Compute" << std::endl;
+
+    auto* reordered_h0 = ctx.Output<Tensor>("ReorderedH0");
+    auto* reordered_c0 = ctx.Output<Tensor>("ReorderedC0");
+    auto* batched_input = ctx.Output<LoDTensor>("BatchedInput");
+    auto* batched_c_out = ctx.Output<LoDTensor>("BatchedCell");
+    auto* batched_h_out = ctx.Output<LoDTensor>("BatchedHidden");
+    T* xx_data = xx->mutable_data<T>(place);
+    T* batched_input_data = batched_input->mutable_data<T>(place);
+    T* batched_c_out_data = batched_c_out->mutable_data<T>(place);
+    T* batched_h_out_data = batched_h_out->mutable_data<T>(place);
+    hidden_out->mutable_data<T>(place);
+    cell_out->mutable_data<T>(place);
+
+    math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
+    auto& dev_ctx = ctx.template device_context<DeviceContext>();
+    auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
+
+    for (int64_t i = 0; i < ids_numel; ++i) {
+      PADDLE_ENFORCE_LT(ids_data[i], row_number);
+      PADDLE_ENFORCE_GE(ids_data[i], 0, "ids %d", i);
+      memcpy(xx_data + i * row_width, embeddings_data + ids_data[i] * row_width,
+             row_width * sizeof(T));
+    }
+
+    to_batch(dev_ctx, *xx, batched_input, true, is_reverse);
+
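+    // LoDTensor2BatchFunctor stores the batch starts in lod[0] and the
+    // original indices of the (length-sorted) sequences in lod[2].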
+    auto batched_lod = batched_input->lod();
+    const auto& seq_order = batched_lod[2];
+    const int max_bs = seq_order.size();
+    reordered_h0->Resize({max_bs, D});
+    reordered_c0->Resize({max_bs, D});
+
+    int tstart = 0;
+    T* prev_h_data = nullptr;
+    T* prev_c_data = nullptr;
+    if (h0) {
+      // reorder h0, c0
+      T* reordered_h0_data = reordered_h0->mutable_data<T>(place);
+      T* reordered_c0_data = reordered_c0->mutable_data<T>(place);
+      const T* h0_data = h0->data<T>();
+      const T* c0_data = c0->data<T>();
+      prev_h_data = reordered_h0_data;
+      prev_c_data = reordered_c0_data;
+      size_t sz = sizeof(T) * D;
+      for (int i = 0; i < max_bs; ++i) {
+        std::memcpy(reordered_h0_data, h0_data + seq_order[i] * D, sz);
+        std::memcpy(reordered_c0_data, c0_data + seq_order[i] * D, sz);
+        reordered_h0_data += D;
+        reordered_c0_data += D;
+      }
+    } else {
+      // compute without h0, c0
+      T* cur_in_data = batched_input_data;
+      T* cur_h_out_data = batched_h_out_data;
+      T* cur_c_out_data = batched_c_out_data;
+      for (int i = 0; i < max_bs; ++i) {
+        GET_Ct_NOH0C0(cur_in_data, cur_c_out_data);
+        if (use_peepholes) {
+          blas.VMUL(D, wc_data + D2, cur_c_out_data, cur_in_data + D);
+          blas.VADD(D, cur_in_data + D, cur_in_data + D3, cur_in_data + D3);
+        }
+        act_gate(D, cur_in_data + D3, cur_in_data + D3);
+        GET_Ht(cur_c_out_data, cur_in_data, cur_h_out_data);
+        cur_in_data += D4;
+        cur_c_out_data += D;
+        cur_h_out_data += D;
+      }
+      tstart = 1;
+      prev_h_data = batched_h_out_data;
+      prev_c_data = batched_c_out_data;
+    }
+    const auto& batch_starts = batched_lod[0];
+    const int max_seq_len = batch_starts.size() - 1;
+    const int offset = tstart * max_bs * D;
+    batched_input_data = batched_input_data + offset * 4;
+    batched_h_out_data = batched_h_out_data + offset;
+    batched_c_out_data = batched_c_out_data + offset;
+
+#define DEFINE_CUR                        \
+  T* cur_in_data = batched_input_data;    \
+  T* cur_prev_c_data = prev_c_data;       \
+  T* cur_c_out_data = batched_c_out_data; \
+  T* cur_h_out_data = batched_h_out_data
+
+#define MOVE_ONE_BATCH  \
+  cur_in_data += D4;    \
+  cur_prev_c_data += D; \
+  cur_c_out_data += D;  \
+  cur_h_out_data += D
+
+#define MOVE_ONE_STEP                  \
+  prev_c_data = batched_c_out_data;    \
+  prev_h_data = batched_h_out_data;    \
+  batched_c_out_data = cur_c_out_data; \
+  batched_h_out_data = cur_h_out_data; \
+  batched_input_data = cur_in_data
+
+    if (use_peepholes) {
+      for (int step = tstart; step < max_seq_len; ++step) {
+        const int cur_bs = batch_starts[step + 1] - batch_starts[step];
+        GEMM_WH_ADDON(cur_bs, prev_h_data, batched_input_data);
+        DEFINE_CUR;
+        for (int i = 0; i < cur_bs; ++i) {
+          COMPUTE_CtHt_PEEPHOLE(cur_in_data, cur_prev_c_data, cur_c_out_data,
+                                cur_h_out_data);
+          MOVE_ONE_BATCH;
+        }
+        MOVE_ONE_STEP;
+      }
+    } else {
+      for (int step = tstart; step < max_seq_len; ++step) {
+        const int cur_bs = batch_starts[step + 1] - batch_starts[step];
+        GEMM_WH_ADDON(cur_bs, prev_h_data, batched_input_data);
+        DEFINE_CUR;
+        for (int i = 0; i < cur_bs; ++i) {
+          COMPUTE_CtHt(cur_in_data, cur_prev_c_data, cur_c_out_data,
+                       cur_h_out_data);
+          MOVE_ONE_BATCH;
+        }
+        MOVE_ONE_STEP;
+      }
+    }
+#undef MOVE_ONE_STEP
+#undef MOVE_ONE_BATCH
+#undef DEFINE_CUR
+
+    math::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
+    batched_h_out->set_lod(batched_lod);
+    to_seq(dev_ctx, *batched_h_out, hidden_out);
+    batched_c_out->set_lod(batched_lod);
+    to_seq(dev_ctx, *batched_c_out, cell_out);
+  }
+
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    if (ctx.Attr<bool>("use_seq")) {
+      SeqCompute(ctx);
+    } else {
+      BatchCompute(ctx);
+    }
+  }
+
+#undef COMPUTE_CtHt_PEEPHOLE
+#undef COMPUTE_CtHt
+#undef GET_Ct_NOH0C0
+#undef COMPUTE_CtHt_NOH0C0
+#undef COMPUTE_CtHt_PEEPHOLE_NOH0C0
+#undef GET_Ht
+#undef GET_Ct
+#undef GEMM_WH_ADDON
+#undef INIT_BASE_INPUT_DATAS
+#undef INIT_BASE_SIZES
+#undef INIT_BASE_INPUT_OUTPUT
+#undef INIT_VEC_FUNC
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OPERATOR(fused_embedding_fc_lstm, ops::FusedEmbeddingFCLSTMOp,
+                  ops::FusedEmbeddingFCLSTMOpMaker,
+                  paddle::framework::DefaultGradOpDescMaker<true>);
+
+REGISTER_OP_CPU_KERNEL(fused_embedding_fc_lstm,
+                       ops::FusedEmbeddingFCLSTMKernel<float>,
+                       ops::FusedEmbeddingFCLSTMKernel<double>);
diff --git a/paddle/fluid/operators/fused_embedding_fc_lstm_op.h b/paddle/fluid/operators/fused_embedding_fc_lstm_op.h
new file mode 100644
index 0000000000..2775b2ac04
--- /dev/null
+++ b/paddle/fluid/operators/fused_embedding_fc_lstm_op.h
@@ -0,0 +1,41 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include "paddle/fluid/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+using LoDTensor = framework::LoDTensor;
+using Tensor = framework::Tensor;
+
+class FusedEmbeddingFCLSTMOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override;
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override;
+};
+
+class FusedEmbeddingFCLSTMOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override;
+};
+
+}  // namespace operators
+}  // namespace paddle