diff --git a/paddle/fluid/operators/activation_mkldnn_op.cc b/paddle/fluid/operators/activation_mkldnn_op.cc
index 46ed99bcf2..137bca5e2b 100644
--- a/paddle/fluid/operators/activation_mkldnn_op.cc
+++ b/paddle/fluid/operators/activation_mkldnn_op.cc
@@ -12,16 +12,20 @@
    See the License for the specific language governing permissions and
    limitations under the License. */
 
-#include "mkldnn.hpp"
 #include "paddle/fluid/operators/activation_op.h"
-#include "paddle/fluid/operators/mkldnn_activation_op.h"
 #include "paddle/fluid/platform/mkldnn_helper.h"
 
 namespace paddle {
 namespace operators {
 
-using paddle::framework::Tensor;
-using paddle::platform::MKLDNNDeviceContext;
+using framework::DataLayout;
+using framework::Tensor;
+using mkldnn::memory;
+using mkldnn::primitive;
+using mkldnn::stream;
+using platform::GetMKLDNNFormat;
+using platform::MKLDNNDeviceContext;
+using platform::to_void_cast;
 
 namespace {
 std::string gethash(const mkldnn::memory::dims &operand_dims,
@@ -35,188 +39,260 @@ std::string gethash(const mkldnn::memory::dims &operand_dims,
   };
   return dim2str(operand_dims) + std::to_string(algorithm);
 }
+}  // namespace
+
+template <typename Functor>
+class MKLDNNActivationKernel
+    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
+ public:
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    const auto *x = ctx.Input<Tensor>("X");
+    PADDLE_ENFORCE(x->layout() == DataLayout::kMKLDNN &&
+                       x->format() != memory::format::format_undef,
+                   "Wrong layout/format set for Input x tensor");
+
+    Functor functor;
+
+    auto attrs = functor.GetAttrs();
+    for (auto &attr : attrs) {
+      *attr.second = ctx.Attr<float>(attr.first);
+    }
+    functor(ctx);
+  }
+};
 
-template <typename T, typename ExecContext>
-void eltwise_forward(const ExecContext &ctx, mkldnn::algorithm algorithm,
-                     const T alpha = 0, const T beta = 0) {
+template <typename Functor>
+class MKLDNNActivationGradKernel
+    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
+ public:
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    const auto *diff_y = ctx.Input<Tensor>(framework::GradVarName("Out"));
+    PADDLE_ENFORCE(diff_y->layout() == DataLayout::kMKLDNN &&
+                       diff_y->format() != memory::format::format_undef,
+                   "Wrong layout/format set for Input OutGrad tensor");
+
+    Functor functor;
+
+    auto attrs = functor.GetAttrs();
+    for (auto &attr : attrs) {
+      *attr.second = ctx.Attr<float>(attr.first);
+    }
+    functor(ctx);
+  }
+};
+
+template <typename T>
+void eltwise_forward(const framework::ExecutionContext &ctx,
+                     mkldnn::algorithm algorithm, const T alpha = 0,
+                     const T beta = 0) {
   PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
                  "It must use CPUPlace.");
-
   auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
   const auto &mkldnn_engine = dev_ctx.GetEngine();
 
-  // get buffers
-  const auto *src = ctx.template Input<Tensor>("X");
-  const auto *src_data = src->template data<T>();
+  const auto *x = ctx.Input<Tensor>("X");
+  auto *y = ctx.Output<Tensor>("Out");
 
-  auto *dst = ctx.template Output<Tensor>("Out");
-  T *dst_data = dst->template mutable_data<T>(ctx.GetPlace());
+  const T *x_data = x->data<T>();
+  T *y_data = y->mutable_data<T>(ctx.GetPlace());
 
-  // get memory dim
-  PADDLE_ENFORCE(src->dims().size() == 2 || src->dims().size() == 4,
+  PADDLE_ENFORCE(x->dims().size() == 2 || x->dims().size() == 4,
                  "Input dim must be with 2 or 4");
-  std::vector<int> src_tz = framework::vectorize2int(src->dims());
+
+  std::vector<int> src_tz = framework::vectorize2int(x->dims());
+
+  auto src_format =
+      src_tz.size() == 2 ? mkldnn::memory::format::nc : x->format();
 
   const std::string key = gethash(src_tz, algorithm);
   const std::string key_src_data =
       key + ctx.op().Output("Out") + "@eltwise_fwd_src_data";
-  const std::string key_src_mem = key + "@eltwise_fwd_src_mem";
-  const std::string key_dst_mem = key + "@eltwise_fwd_dst_mem";
-  const std::string key_fwd = key + "@eltwise_fwd";
+  const std::string key_src_layout =
+      key + ctx.op().Output("Out") + "@eltwise_fwd_src_layout";
+  const std::string key_with_layout = key + std::to_string(src_format);
+  const std::string key_src_mem = key_with_layout + "@eltwise_fwd_src_mem";
+  const std::string key_dst_mem = key_with_layout + "@eltwise_fwd_dst_mem";
+  const std::string key_fwd = key_with_layout + "@eltwise_fwd";
+  const std::string key_fwd_pd = key_with_layout + "@eltwise_fwd_pd";
+
+  // save input data and layout to be referred in backward path
+  auto p_src_data = std::make_shared<const T *>(x_data);
+  dev_ctx.SetBlob(key_src_data, p_src_data);
+  auto p_src_layout = std::make_shared<memory::format>(src_format);
+  dev_ctx.SetBlob(key_src_layout, p_src_layout);
 
   auto p_fwd = std::static_pointer_cast<mkldnn::eltwise_forward>(
       dev_ctx.GetBlob(key_fwd));
 
-  // save input data to be referred in backward path
-  auto p_src_data = std::make_shared<const T *>(src_data);
-  dev_ctx.SetBlob(key_src_data, p_src_data);
+  std::shared_ptr<memory> dst_memory;
 
   if (p_fwd == nullptr) {
-    // create memory description
-    auto data_md = src_tz.size() == 2
-                       ? platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
-                                                 mkldnn::memory::format::nc)
-                       : platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
-                                                 mkldnn::memory::format::nchw);
-
-    // create memory primitives
-    auto p_src_mem = std::make_shared<mkldnn::memory>(mkldnn::memory(
-        {data_md, mkldnn_engine}, platform::to_void_cast(src_data)));
-    dev_ctx.SetBlob(key_src_mem, p_src_mem);
-
-    auto p_dst_mem = std::make_shared<mkldnn::memory>(mkldnn::memory(
-        {data_md, mkldnn_engine}, platform::to_void_cast(dst_data)));
-    dev_ctx.SetBlob(key_dst_mem, p_dst_mem);
-
-    auto fwd_desc = mkldnn::eltwise_forward::desc(
-        mkldnn::prop_kind::forward_training, algorithm, data_md, alpha, beta);
-    auto p_fwd_pd = std::make_shared<mkldnn::eltwise_forward::primitive_desc>(
-        fwd_desc, mkldnn_engine);
-    const std::string key_fwd_pd = key + "eltwise_fwd_pd";
-    dev_ctx.SetBlob(key_fwd_pd, p_fwd_pd);
-    p_fwd = std::make_shared<mkldnn::eltwise_forward>(
-        *p_fwd_pd, *(p_src_mem.get()), *(p_dst_mem.get()));
+    // create mkldnn memory for input X
+    auto src_md = platform::MKLDNNMemDesc(
+        src_tz, platform::MKLDNNGetDataType<T>(), src_format);
+    auto src_memory = std::shared_ptr<memory>(
+        new memory({src_md, mkldnn_engine}, to_void_cast(x_data)));
+    // save src_memory to be referred in backward path
+    dev_ctx.SetBlob(key_src_mem, src_memory);
+
+    // create primitive descriptor for activation forward and save it
+    auto forward_desc = mkldnn::eltwise_forward::desc(
+        mkldnn::prop_kind::forward_training, algorithm,
+        src_memory->get_primitive_desc().desc(), alpha, beta);
+    auto forward_pd = std::make_shared<mkldnn::eltwise_forward::primitive_desc>(
+        forward_desc, mkldnn_engine);
+
+    // save prim desc into global device context to be referred in backward path
+    dev_ctx.SetBlob(key_fwd_pd, forward_pd);
+
+    // create mkldnn memory for output y
+    dst_memory =
+        std::make_shared<memory>(forward_pd->dst_primitive_desc(), y_data);
+
+    dev_ctx.SetBlob(key_dst_mem, dst_memory);
+
+    // create activation primitive
+    p_fwd = std::make_shared<mkldnn::eltwise_forward>(*forward_pd, *src_memory,
+                                                      *dst_memory);
     dev_ctx.SetBlob(key_fwd, p_fwd);
   } else {
     // primitives already exist
-    auto p_src_mem =
+    auto src_memory =
         std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob(key_src_mem));
-    PADDLE_ENFORCE(p_src_mem != nullptr,
-                   "Fail to find eltwise p_src_mem in device context.");
-    auto p_dst_mem =
+    PADDLE_ENFORCE(src_memory != nullptr,
+                   "Fail to find eltwise src_memory in device context.");
+    dst_memory =
         std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob(key_dst_mem));
-    PADDLE_ENFORCE(p_dst_mem != nullptr,
-                   "Fail to find eltwise p_src_mem in device context.");
+    PADDLE_ENFORCE(dst_memory != nullptr,
+                   "Fail to find eltwise dst_memory in device context.");
 
-    p_src_mem->set_data_handle(platform::to_void_reinterpret_cast(src_data));
-    p_dst_mem->set_data_handle(dst_data);
+    src_memory->set_data_handle(platform::to_void_cast(x_data));
+    dst_memory->set_data_handle(y_data);
   }
 
   // push primitive to stream and wait until it's executed
-  std::vector<mkldnn::primitive> pipeline = {*(p_fwd.get())};
-  mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
+  std::vector<primitive> pipeline;
+  pipeline.push_back(*p_fwd);
+  stream(stream::kind::eager).submit(pipeline).wait();
+
+  y->set_layout(DataLayout::kMKLDNN);
+  y->set_format(GetMKLDNNFormat(*dst_memory));
 }
 
-template <typename T, typename ExecContext>
-void eltwise_grad(const ExecContext &ctx, mkldnn::algorithm algorithm,
-                  const T alpha = 0, const T beta = 0) {
+template <typename T>
+void eltwise_grad(const framework::ExecutionContext &ctx,
+                  mkldnn::algorithm algorithm, const T alpha = 0,
+                  const T beta = 0) {
   auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
   const auto &mkldnn_engine = dev_ctx.GetEngine();
 
-  // get buffers
-  const auto *out = ctx.template Input<Tensor>("Out");
-
-  auto *dout = ctx.template Input<Tensor>(framework::GradVarName("Out"));
-  const auto *diff_dst = dout->template data<T>();
+  const auto *diff_y = ctx.Input<Tensor>(framework::GradVarName("Out"));
+  auto *diff_x = ctx.Output<Tensor>(framework::GradVarName("X"));
 
-  auto *dx =
-      ctx.template Output<framework::Tensor>(framework::GradVarName("X"));
-  const T *diff_src = dx->template mutable_data<T>(ctx.GetPlace());
+  const T *diff_y_data = diff_y->data<T>();
+  T *diff_x_data = diff_x->mutable_data<T>(ctx.GetPlace());
 
-  // get memory dim
-  std::vector<int> src_tz = framework::vectorize2int(out->dims());
+  std::vector<int> diff_dst_tz = framework::vectorize2int(diff_y->dims());
 
-  const std::string key = gethash(src_tz, algorithm);
-  const std::string key_diff_src_mem = key + "@eltwise_diff_src_mem";
-  const std::string key_diff_dst_mem = key + "@eltwise_diff_dst_mem";
-  const std::string key_grad = key + "@eltwise_grad";
+  auto diff_y_format =
+      diff_dst_tz.size() == 2 ? mkldnn::memory::format::nc : diff_y->format();
 
+  const std::string key = gethash(diff_dst_tz, algorithm);
   const std::string key_src_data =
       key + ctx.op().Input("Out") + "@eltwise_fwd_src_data";
+  const std::string key_src_layout =
+      key + ctx.op().Input("Out") + "@eltwise_fwd_src_layout";
+  const auto p_src_layout =
+      std::static_pointer_cast<memory::format>(dev_ctx.GetBlob(key_src_layout));
+  const std::string key_src_mem =
+      key + std::to_string(*p_src_layout) + "@eltwise_fwd_src_mem";
+  const std::string key_fwd_pd =
+      key + std::to_string(*p_src_layout) + "@eltwise_fwd_pd";
+  const std::string key_with_layouts =
+      key + std::to_string(*p_src_layout) + "-" + std::to_string(diff_y_format);
+  const std::string key_diff_src_mem =
+      key_with_layouts + "@eltwise_diff_src_mem";
+  const std::string key_diff_dst_mem =
+      key_with_layouts + "@eltwise_diff_dst_mem";
+  const std::string key_grad = key_with_layouts + "@eltwise_grad";
+
   const auto p_src_data =
       std::static_pointer_cast<T *>(dev_ctx.GetBlob(key_src_data));
 
-  const std::string key_src_mem = key + "@eltwise_fwd_src_mem";
-  auto p_src_mem =
+  auto src_memory =
       std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob(key_src_mem));
-  p_src_mem->set_data_handle(*p_src_data.get());
+  PADDLE_ENFORCE(src_memory != nullptr,
+                 "Fail to find src_memory in device context");
+  src_memory->set_data_handle(*p_src_data.get());
+
+  std::shared_ptr<memory> diff_src_memory;
 
-  auto p_grad = std::static_pointer_cast<mkldnn::eltwise_forward::primitive>(
+  auto p_grad = std::static_pointer_cast<mkldnn::eltwise_backward>(
       dev_ctx.GetBlob(key_grad));
 
   if (p_grad == nullptr) {
-    // create memory description
-    auto data_md = src_tz.size() == 2
-                       ? platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
-                                                 mkldnn::memory::format::nc)
-                       : platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
-                                                 mkldnn::memory::format::nchw);
-
-    // create memory primitives
-    std::shared_ptr<void> p_diff_src_mem =
-        std::make_shared<mkldnn::memory>(mkldnn::memory(
-            {data_md, mkldnn_engine}, platform::to_void_cast(diff_src)));
-    dev_ctx.SetBlob(key_diff_src_mem, p_diff_src_mem);
-    std::shared_ptr<void> p_diff_dst_mem =
-        std::make_shared<mkldnn::memory>(mkldnn::memory(
-            {data_md, mkldnn_engine}, platform::to_void_cast(diff_dst)));
-    dev_ctx.SetBlob(key_diff_dst_mem, p_diff_dst_mem);
-
-    auto bwd_desc = mkldnn::eltwise_backward::desc(algorithm, data_md, data_md,
-                                                   alpha, beta);
-
-    const std::string key_fwd_pd = key + "eltwise_fwd_pd";
-    auto *p_fwd_pd = static_cast<mkldnn::eltwise_forward::primitive_desc *>(
-        dev_ctx.GetBlob(key_fwd_pd).get());
-
-    auto eltwise_bwd_prim_desc = mkldnn::eltwise_backward::primitive_desc(
-        bwd_desc, mkldnn_engine, *p_fwd_pd);
-
+    // create mkldnn memory for input diff_y
+    auto diff_dst_md = platform::MKLDNNMemDesc(
+        diff_dst_tz, platform::MKLDNNGetDataType<T>(), diff_y_format);
+    auto diff_dst_memory = std::shared_ptr<memory>(
+        new memory({diff_dst_md, mkldnn_engine}, to_void_cast(diff_y_data)));
+    dev_ctx.SetBlob(key_diff_dst_mem, diff_dst_memory);
+
+    // retrieve eltwise primitive desc from device context
+    auto forward_pd =
+        std::static_pointer_cast<mkldnn::eltwise_forward::primitive_desc>(
+            dev_ctx.GetBlob(key_fwd_pd));
+    PADDLE_ENFORCE(forward_pd != nullptr,
+                   "Fail to find eltwise_fwd_pd in device context");
+
+    // ceate primitive descriptor for activation backward
+    auto backward_desc = mkldnn::eltwise_backward::desc(
+        algorithm, diff_dst_memory->get_primitive_desc().desc(),
+        src_memory->get_primitive_desc().desc(), alpha, beta);
+    auto backward_pd = mkldnn::eltwise_backward::primitive_desc(
+        backward_desc, mkldnn_engine, *forward_pd);
+
+    // create mkldnn memory for output diff_src
+    diff_src_memory = std::make_shared<memory>(
+        backward_pd.diff_src_primitive_desc(), diff_x_data);
+    dev_ctx.SetBlob(key_diff_src_mem, diff_src_memory);
+
+    // create activation backward primitive
     p_grad = std::make_shared<mkldnn::eltwise_backward>(
-        eltwise_bwd_prim_desc, *static_cast<mkldnn::memory *>(p_src_mem.get()),
-        *(static_cast<mkldnn::memory *>(p_diff_dst_mem.get())),
-        *(static_cast<mkldnn::memory *>(p_diff_src_mem.get())));
+        backward_pd, *src_memory, *diff_dst_memory, *diff_src_memory);
+    dev_ctx.SetBlob(key_grad, p_grad);
   } else {
     // primitives already exist
-    auto p_diff_src_mem = std::static_pointer_cast<mkldnn::memory>(
+    diff_src_memory = std::static_pointer_cast<mkldnn::memory>(
         dev_ctx.GetBlob(key_diff_src_mem));
-    auto p_diff_dst_mem = std::static_pointer_cast<mkldnn::memory>(
+    auto diff_dst_memory = std::static_pointer_cast<mkldnn::memory>(
         dev_ctx.GetBlob(key_diff_dst_mem));
 
-    p_diff_src_mem->set_data_handle(
-        platform::to_void_reinterpret_cast(diff_src));
-    p_diff_dst_mem->set_data_handle(
-        platform::to_void_reinterpret_cast(diff_dst));
+    diff_src_memory->set_data_handle(
+        platform::to_void_reinterpret_cast(diff_x_data));
+    diff_dst_memory->set_data_handle(
+        platform::to_void_reinterpret_cast(diff_y_data));
   }
 
   // push primitive to stream and wait until it's executed
-  std::vector<mkldnn::primitive> pipeline = {*(p_grad.get())};
-  mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
+  std::vector<primitive> pipeline;
+  pipeline.push_back(*p_grad);
+  stream(stream::kind::eager).submit(pipeline).wait();
+
+  diff_x->set_layout(DataLayout::kMKLDNN);
+  diff_x->set_format(GetMKLDNNFormat(*diff_src_memory));
 }
-}  // anonymous namespace
 
 template <typename T, mkldnn::algorithm algorithm>
 struct MKLDNNActivationFunc : public BaseActivationFunctor<T> {
-  template <typename ExecContext>
-  void operator()(const ExecContext &ctx) const {
+  void operator()(const framework::ExecutionContext &ctx) const {
     eltwise_forward<T>(ctx, algorithm);
   }
 };
 
 template <typename T, mkldnn::algorithm algorithm>
 struct MKLDNNActivationGradFunc : public BaseActivationFunctor<T> {
-  template <typename ExecContext>
-  void operator()(const ExecContext &ctx) const {
+  void operator()(const framework::ExecutionContext &ctx) const {
     eltwise_grad<T>(ctx, algorithm);
   }
 };
diff --git a/paddle/fluid/operators/activation_op.cc b/paddle/fluid/operators/activation_op.cc
index a06ca7952f..b6b498a616 100644
--- a/paddle/fluid/operators/activation_op.cc
+++ b/paddle/fluid/operators/activation_op.cc
@@ -19,18 +19,20 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
-#define REGISTER_ACTIVATION_OP_MAKER(OP_NAME, OP_COMMENT)             \
-  class OP_NAME##OpMaker                                              \
-      : public ::paddle::framework::OpProtoAndCheckerMaker {          \
-   public:                                                            \
-    void Make() override {                                            \
-      AddInput("X", "Input of " #OP_NAME " operator");                \
-      AddOutput("Out", "Output of " #OP_NAME " operator").Reuse("X"); \
-      AddAttr<bool>("use_mkldnn",                                     \
-                    "(default false) Only used in mkldnn kernel")     \
-          .SetDefault(false);                                         \
-      AddComment(OP_COMMENT);                                         \
-    }                                                                 \
+using paddle::framework::Tensor;
+
+#define REGISTER_ACTIVATION_OP_MAKER(OP_NAME, OP_COMMENT)               \
+  class OP_NAME##OpMaker                                                \
+      : public ::paddle::framework::OpProtoAndCheckerMaker {            \
+   public:                                                              \
+    void Make() override {                                              \
+      AddInput("X", "Input of " #OP_NAME " operator");                  \
+      AddOutput("Out", "Output of " #OP_NAME " operator").Reuse("X");   \
+      AddAttr<bool>("use_mkldnn",                                       \
+                    "(bool, default false) Only used in mkldnn kernel") \
+          .SetDefault(false);                                           \
+      AddComment(#OP_COMMENT);                                          \
+    }                                                                   \
   }
 
 #define REGISTER_ACTIVATION_OP_GRAD_MAKER(OP_NAME, KERNEL_TYPE)              \
@@ -58,7 +60,6 @@ framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx,
                                       const framework::OperatorWithKernel& oper,
                                       const std::string& name) {
   framework::LibraryType library{framework::LibraryType::kPlain};
-
   framework::DataLayout layout = framework::DataLayout::kAnyLayout;
 #ifdef PADDLE_WITH_MKLDNN
   auto it = oper.Attrs().find("use_mkldnn");
@@ -82,6 +83,7 @@ class ActivationOp : public framework::OperatorWithKernel {
     ctx->ShareLoD("X", /*->*/ "Out");
   }
 
+ protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     return GetKernelType(ctx, *this, "X");
@@ -96,6 +98,7 @@ class ActivationOpGrad : public framework::OperatorWithKernel {
     ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("Out"));
   }
 
+ protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     return GetKernelType(ctx, *this, "Out");