diff --git a/paddle/gserver/layers/MKLDNNLayer.cpp b/paddle/gserver/layers/MKLDNNLayer.cpp
index 663a105098..4347ab821d 100644
--- a/paddle/gserver/layers/MKLDNNLayer.cpp
+++ b/paddle/gserver/layers/MKLDNNLayer.cpp
@@ -171,14 +171,16 @@ void MKLDNNLayer::resetWithMatrix(MKLDNNMatrixPtr& dnn,
 }
 
 void MKLDNNLayer::resetInValue(
-    MKLDNNMatrixPtr& in, const std::shared_ptr<memory::primitive_desc>& intPD) {
+    MKLDNNMatrixPtr& in,
+    const std::shared_ptr<memory::primitive_desc>& intPD,
+    size_t inputIdx) {
   cvtInVal_ = nullptr;
   extInVal_ = nullptr;
   in = nullptr;
   CHECK_GT(bs_ * ic_ * ih_ * iw_, 0);
   auto extPD = MKLDNNMatrix::createPrimitiveDesc(
       {bs_, ic_, ih_, iw_}, format::nchw, engine_);
-  const MatrixPtr& inMat = inputLayers_[0]->getOutputValue();
+  const MatrixPtr& inMat = inputLayers_[inputIdx]->getOutputValue();
   in = std::dynamic_pointer_cast<MKLDNNMatrix>(inMat);
   CHECK_EQ(inputIsOnlyMKLDNN(), in != nullptr);
   if (in == nullptr || in->getFormat() == format::nc) {
@@ -216,11 +218,12 @@ void MKLDNNLayer::resetOutValue(MKLDNNMatrixPtr& out,
 }
 
 void MKLDNNLayer::resetInGrad(MKLDNNMatrixPtr& in,
-                              memory::primitive_desc intPD) {
+                              memory::primitive_desc intPD,
+                              size_t inputIdx) {
   cvtInGrad_ = nullptr;
   extInGrad_ = nullptr;
   in = nullptr;
-  LayerPtr& input = inputLayers_[0];
+  LayerPtr& input = inputLayers_[inputIdx];
   if (input->getOutputGrad() == nullptr) {
     // no need input grad
     return;
@@ -245,7 +248,6 @@ void MKLDNNLayer::resetInGrad(MKLDNNMatrixPtr& in,
     return;
   }
   // need create reorder
-  // TODO(TJ): add macro definition to simplify it
   CHECK(extInVal_ != nullptr && isPaddleFormat(extInVal_->getFormat()))
       << "should have external input value and the format must be nchw(nc)";
   extInGrad_ = MKLDNNMatrix::create(extInVal_->getPrimitiveDesc(), inMat);
diff --git a/paddle/gserver/layers/MKLDNNLayer.h b/paddle/gserver/layers/MKLDNNLayer.h
index 2c21a5b2aa..7479c34c92 100644
--- a/paddle/gserver/layers/MKLDNNLayer.h
+++ b/paddle/gserver/layers/MKLDNNLayer.h
@@ -199,7 +199,8 @@ protected:
    */
   void resetInValue(
       MKLDNNMatrixPtr& in,
-      const std::shared_ptr<mkldnn::memory::primitive_desc>& intPD = nullptr);
+      const std::shared_ptr<mkldnn::memory::primitive_desc>& intPD = nullptr,
+      size_t inputIdx = 0);
 
   /**
    * reset output value from internal primitive desc.
@@ -212,7 +213,9 @@ protected:
    * reset input grad from internal primitive desc.
    * reset both internal and external buffer and create reorder if necessary.
    */
-  void resetInGrad(MKLDNNMatrixPtr& in, mkldnn::memory::primitive_desc intPD);
+  void resetInGrad(MKLDNNMatrixPtr& in,
+                   mkldnn::memory::primitive_desc intPD,
+                   size_t inputIdx = 0);
 
   /**
    * reset output grad from internal primitive desc.