diff --git a/paddle/operators/expand_op.cc b/paddle/operators/expand_op.cc
index 3990b3751d..5d83b1d9d2 100644
--- a/paddle/operators/expand_op.cc
+++ b/paddle/operators/expand_op.cc
@@ -24,26 +24,28 @@ class ExpandOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  void InferShape(const framework::InferShapeContext& ctx) const override {
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "X must be initialized.");
-    std::vector<int> expand_times = Attr<std::vector<int>>("expandTimes");
-    auto x_dims = ctx.Input<Tensor>("X")->dims();
-
-    PADDLE_ENFORCE_EQ(x_dims.size(), expand_times.size(),
-                      "The number of expandTimes's value must be equal "
-                      "to the rank of X.");
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
+    std::vector<int> expand_times =
+        ctx->Attrs().Get<std::vector<int>>("expandTimes");
+    auto x_dims = ctx->GetInputDim("X");
+
+    PADDLE_ENFORCE_EQ(static_cast<size_t>(x_dims.size()), expand_times.size(),
+                      "The number of values of Attr(expandTimes) must be equal "
+                      "to the rank of Input(X).");
     PADDLE_ENFORCE_LE(x_dims.size(), 6,
-                      "The rank of X must not be greater than 6.");
+                      "The rank of Input(X) must not be greater than 6.");
 
     std::vector<int64_t> out_shape(x_dims.size());
     for (size_t i = 0; i < expand_times.size(); ++i) {
       PADDLE_ENFORCE_GE(expand_times[i], 1,
-                        "Each value of expandTimes should not be "
+                        "Each value of Attr(expandTimes) should not be "
                         "less than 1.");
       out_shape[i] = x_dims[i] * expand_times[i];
     }
-    auto* out = ctx.Output<framework::LoDTensor>("Out");
-    out->Resize(framework::make_ddim(out_shape));
+
+    ctx->SetOutputDim("Out", framework::make_ddim(out_shape));
+    ctx->ShareLoD("X", "Out");
   }
 };
 
@@ -52,20 +54,21 @@ class ExpandOpMaker : public framework::OpProtoAndCheckerMaker {
   ExpandOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X",
-             "The input tensor of expand op."
-             "The rank of X should be between in 1 and 6.");
+             "(Tensor, default Tensor<float>) A tensor with rank in [1, 6]. "
+             "X is the input tensor to be expanded.");
     AddOutput("Out",
-              "Output tensor of expand op."
-              "The rank of Out is same as X except that each dimension size "
-              "of Out equals to corresponding dimension size of X multiplying "
-              "corresponding value of expandTimes.");
+              "(Tensor, default Tensor<float>) A tensor with rank in [1, 6]. "
+              "The rank of Output(Out) is same as Input(X) except that each "
+              "dimension size of Output(Out) is equal to corresponding "
+              "dimension size of Input(X) multiplying corresponding value of "
+              "Attr(expandTimes).");
     AddAttr<std::vector<int>>("expandTimes",
                               "Expand times number for each dimension.");
     AddComment(R"DOC(
 Expand operator tiles the input by given times number. You should set times
 number for each dimension by providing attribute 'expandTimes'. The rank of X
-should be between in 1 and 6. Please notice that size of 'expandTimes' must be
-same with X's rank.
+should be in [1, 6]. Please notice that the size of 'expandTimes' must be the
+same as the rank of X.
 )DOC");
   }
 };
@@ -75,25 +78,27 @@ class ExpandGradOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  void InferShape(const framework::InferShapeContext& ctx) const override {
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "X must be initialized.");
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
-                            "Input(Out@GRAD) should not be null.");
-    auto x_dims = ctx.Input<Tensor>("X")->dims();
-    std::vector<int> expand_times = Attr<std::vector<int>>("expandTimes");
-    auto out_dims =
-        ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"))->dims();
-    auto* x_grad =
-        ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
+                   "Input(Out@GRAD) should not be null.");
+    auto x_dims = ctx->GetInputDim("X");
+    std::vector<int> expand_times =
+        ctx->Attrs().Get<std::vector<int>>("expandTimes");
+    auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));
 
     for (size_t i = 0; i < expand_times.size(); ++i) {
       PADDLE_ENFORCE_EQ(x_dims[i] * expand_times[i], out_dims[i],
                         "Each dimension size of Input(Out@GRAD) should be "
                         "equal to multiplication of crroresponding dimension "
-                        "size of Input(X) and expandTimes value.");
+                        "size of Input(X) and Attr(expandTimes) value.");
     }
 
-    if (x_grad) x_grad->Resize(x_dims);
+    auto x_grad_name = framework::GradVarName("X");
+
+    if (ctx->HasOutput(x_grad_name)) {
+      ctx->SetOutputDim(x_grad_name, x_dims);
+    }
   }
 };
 
diff --git a/paddle/operators/expand_op.h b/paddle/operators/expand_op.h
index f9cd519c70..bd17567c88 100644
--- a/paddle/operators/expand_op.h
+++ b/paddle/operators/expand_op.h
@@ -45,6 +45,8 @@
 namespace paddle {
 namespace operators {
 
+using Tensor = framework::Tensor;
+
 template <typename T, int MajorType = Eigen::RowMajor,
           typename IndexType = Eigen::DenseIndex>
 using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
@@ -53,24 +55,24 @@ template <typename T, size_t D, int MajorType = Eigen::RowMajor,
 using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
 
 template <typename Place, typename T>
-class ExpandKernel : public framework::OpKernel {
+class ExpandKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    auto rank = context.Input<framework::Tensor>("X")->dims().size();
+    auto rank = context.Input<Tensor>("X")->dims().size();
     switch (rank) {
       REP_EXPAND_TEMPLATE(6)
       default:
         PADDLE_ENFORCE(false,
                        "Only support tensor with rank being between 1 and 6.");
-    };
+    }
   }
 
  protected:
   template <int Rank>
   void Expand(const framework::ExecutionContext& context) const {
-    auto* in0 = context.Input<framework::Tensor>("X");
+    auto* in0 = context.Input<Tensor>("X");
     auto& expand_times = context.Attr<std::vector<int>>("expandTimes");
-    auto* out0 = context.Output<framework::LoDTensor>("Out");
+    auto* out0 = context.Output<Tensor>("Out");
     Eigen::DSizes<int, Rank> bcast_dims;
     auto x_dims = in0->dims();
     for (size_t i = 0; i < expand_times.size(); ++i) {
@@ -85,10 +87,10 @@ class ExpandKernel : public framework::OpKernel {
 };
 
 template <typename Place, typename T>
-class ExpandGradKernel : public framework::OpKernel {
+class ExpandGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    auto* in0 = context.Input<framework::Tensor>("X");
+    auto* in0 = context.Input<Tensor>("X");
     auto& expand_times = context.Attr<std::vector<int>>("expandTimes");
     auto x_dims = in0->dims();
     std::vector<int> reshape_dims_vec;
@@ -111,23 +113,17 @@ class ExpandGradKernel : public framework::OpKernel {
     int dims = reshape_dims_vec.size() * 6 + reduce_dims_vec.size() - 7;
     // no need reduce, just copy
     if (reduce_dims_vec.size() == 0) {
-      auto* in0 =
-          context.Input<framework::LoDTensor>(framework::GradVarName("Out"));
-      auto* out0 =
-          context.Output<framework::LoDTensor>(framework::GradVarName("X"));
+      auto* in0 = context.Input<Tensor>(framework::GradVarName("Out"));
+      auto* out0 = context.Output<Tensor>(framework::GradVarName("X"));
       out0->mutable_data<T>(context.GetPlace());
-      if (platform::is_cpu_place(context.GetPlace())) {
-        out0->CopyFrom<T>(*in0, platform::CPUPlace());
-      } else {
-        out0->CopyFrom<T>(*in0, platform::GPUPlace());
-      }
+      out0->CopyFrom(*in0, context.GetPlace(), context.device_context());
     } else {
       switch (dims) {
         REP_EXPAND_GRAD_TEMPLATE(72)
         default:
           PADDLE_ENFORCE(
               false, "Only support tensor with rank being between 1 and 6.");
-      };
+      }
     }
   }
 
@@ -144,11 +140,9 @@ class ExpandGradKernel : public framework::OpKernel {
     PADDLE_ENFORCE_EQ(reduce_size, reduce_dims_vec.size(),
                       "Inconsistent size between template Dims and "
                       "reduce dimensions.");
-    auto* in0 =
-        context.Input<framework::LoDTensor>(framework::GradVarName("Out"));
-    auto* out0 =
-        context.Output<framework::LoDTensor>(framework::GradVarName("X"));
-    auto x = EigenVector<T>::Flatten(*(context.Input<framework::Tensor>("X")));
+    auto* in0 = context.Input<Tensor>(framework::GradVarName("Out"));
+    auto* out0 = context.Output<Tensor>(framework::GradVarName("X"));
+    auto x = EigenVector<T>::Flatten(*(context.Input<Tensor>("X")));
     out0->mutable_data<T>(context.GetPlace());
     auto x_grad = EigenVector<T>::Flatten(*out0);
     Eigen::DSizes<int, Dims / 6 + 1> reshape_dims;
@@ -165,5 +159,5 @@ class ExpandGradKernel : public framework::OpKernel {
   }
 };
 
-}  // operators
-}  // paddle
+}  // namespace operators
+}  // namespace paddle