diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt
index ab1d214333..50f5f34021 100644
--- a/paddle/fluid/operators/CMakeLists.txt
+++ b/paddle/fluid/operators/CMakeLists.txt
@@ -265,6 +265,7 @@ op_library(recurrent_op DEPS executor)
 op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale)
 op_library(cos_sim_op DEPS cos_sim_functor)
 op_library(parallel_do_op DEPS executor)
+op_library(unsqueeze_op DEPS reshape_op)
 
 if (WITH_GPU)
     op_library(conv_op DEPS vol2col depthwise_conv im2col)
diff --git a/paddle/fluid/operators/unsqueeze_op.cc b/paddle/fluid/operators/unsqueeze_op.cc
index 373dac8bab..c503988676 100644
--- a/paddle/fluid/operators/unsqueeze_op.cc
+++ b/paddle/fluid/operators/unsqueeze_op.cc
@@ -12,41 +12,35 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/fluid/operators/unsqueeze_op.h"
 #include <string>
 #include <vector>
+#include "paddle/fluid/framework/op_registry.h"
 
 namespace paddle {
 namespace operators {
 
-using framework::OpKernelType;
-using framework::Tensor;
-
-class UnsqueezeOp : public framework::OperatorWithKernel {
+class UnsqueezeOpInferShape : public framework::InferShapeBase {
  public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
-
-  void InferShape(framework::InferShapeContext* ctx) const override {
+  void operator()(framework::InferShapeContext *ctx) const override {
     PADDLE_ENFORCE(ctx->HasInput("X"),
                    "Input(X) of UnsqueezeOp should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput("Out"),
                    "Output(Out) of UnsqueezeOp should not be null.");
 
-    const auto& axes = ctx->Attrs().Get<std::vector<int>>("axes");
+    const auto &axes = ctx->Attrs().Get<std::vector<int>>("axes");
     PADDLE_ENFORCE(!axes.empty(),
                    "The unsqueeze axes information must be set by Attr(axes).");
 
-    const auto& x_dims = ctx->GetInputDim("X");
+    const auto &x_dims = ctx->GetInputDim("X");
     // Validity Check: input tensor dims (<6).
-    PADDLE_ENFORCE(x_dims.size() < 6,
+    PADDLE_ENFORCE(static_cast<int>(x_dims.size()) <= 6,
                    "Invalid dimensions, dynamic dimensions should within "
-                   "[0, 5] dimensions (Eigen limit).");
+                   "[1, 6] dimensions (Eigen limit).");
     // Validity Check: the range of unsqueeze aixs.
-    // TODO(chenweihang): Don't consider negative axis?.
-    for (unsigned int idx = 0; idx < axes.size(); ++idx) {
-      PADDLE_ENFORCE(axes[idx] < 6,
+    for (int axis : axes) {
+      PADDLE_ENFORCE(axis < 6,
                      "Invalid dimensions, input axis should within "
-                     "[0, 5] dimensions (Eigen limit).");
+                     "[1, 6] dimensions (Eigen limit).");
     }
 
     auto out_dims = GetOutputShape(axes, x_dims);
@@ -54,33 +48,7 @@ class UnsqueezeOp : public framework::OperatorWithKernel {
   }
 
   static framework::DDim GetOutputShape(const std::vector<int> unsqz_dims,
-                                        const framework::DDim& in_dims) {
-    /*
-     * STL version
-     * Test Error! don't know why?.
-    std::vector<int64_t> output_shape;
-
-    // Contruct base output shape
-    for(int idx = 0; idx < in_dims.size(); ++idx) {
-      output_shape.emplace_back(in_dims[idx]);
-    }
-    // Validity Check: output dimensions limit.
-    PADDLE_ENFORCE(unsqz_dims.size() + output_shape.size() < 6,
-                   "The Attr(axes) size is too large. The output shape should "
-                   "be less than 6 (Eigne limit).");
-    // Insert the unsqueeze axis in turn.
-    auto it = output_shape.begin();
-    for (int axis : unsqz_dims) {
-      int cur = axis < 0 ? (axis + output_shape.size() + 1)
-                         : axis;
-      // Vaildity Check: the axis bound
-      PADDLE_ENFORCE(cur >= 0 && cur <= static_cast<int>(output_shape.size()),
-                     "The unsqueeze dims must be within range of current
-    rank.");
-      output_shape.emplace(it + axis, 1);
-    }
-    */
-
+                                        const framework::DDim &in_dims) {
     unsigned int unsqz_mask = 0;
     unsigned int front = 0, back = 0;
     int output_dims_size = in_dims.size();
@@ -93,17 +61,17 @@ class UnsqueezeOp : public framework::OperatorWithKernel {
           cur >= 0 && cur <= output_dims_size,
           "The unsqueeze dims must be within range of current rank.");
       // Save the front part.
-      front = unsqz_mask & ((1 << axis) - 1);
+      front = unsqz_mask & ((1 << cur) - 1);
       // Move the back part.
-      back = unsqz_mask & ~((1 << axis) - 1);
+      back = unsqz_mask & ~((1 << cur) - 1);
       back <<= 1;
       // Merge two part.
-      back |= (1 << axis);
+      back |= (1 << cur);
       unsqz_mask = front | back;
       // Add the output size.
       output_dims_size++;
       // Validity Check: rank range.
-      PADDLE_ENFORCE(output_dims_size < 6,
+      PADDLE_ENFORCE(output_dims_size <= 6,
                      "The output tensor's rank should be less than 6.");
     }
 
@@ -121,6 +89,31 @@ class UnsqueezeOp : public framework::OperatorWithKernel {
   }
 };
 
+class UnsqueezeOp : public framework::OperatorBase {
+ public:
+  UnsqueezeOp(const std::string &type, const framework::VariableNameMap &inputs,
+              const framework::VariableNameMap &outputs,
+              const framework::AttributeMap &attrs)
+      : OperatorBase(type, inputs, outputs, attrs) {}
+
+ private:
+  void RunImpl(const framework::Scope &scope,
+               const platform::Place &place) const override {
+    auto &axes = Attr<std::vector<int>>("axes");
+    auto x_dims = scope.FindVar(Input("X"))->Get<framework::LoDTensor>().dims();
+    auto out_dims = UnsqueezeOpInferShape::GetOutputShape(axes, x_dims);
+
+    framework::AttributeMap attrs;
+    attrs["shape"] = framework::vectorize2int(out_dims);
+    attrs["inplace"] = Attr<bool>("inplace");
+    // Invoke Reshape op.
+    auto reshape_op = framework::OpRegistry::CreateOp(
+        "reshape", {{"X", {Input("X")}}, {"Shape", {}}},
+        {{"Out", {Output("Out")}}}, attrs);
+    reshape_op->Run(scope, place);
+  }
+};
+
 class UnsqueezeOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() override {
@@ -150,42 +143,49 @@ class UnsqueezeOpMaker : public framework::OpProtoAndCheckerMaker {
   }
 };
 
-class UnsqueezeGradOp : public framework::OperatorWithKernel {
+class UnsqueezeGradInferShape : public framework::InferShapeBase {
  public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
-
-  void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("X"),
-                   "Input(X) of UnsqueezeGradOp should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
-                   "Output(Out@GRAD) of UnsqueezeGradOp should not be null.");
+  void operator()(framework::InferShapeContext *ctx) const override {
     ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
+    ctx->ShareLoD("X", framework::GradVarName("X"));
   }
+};
 
- protected:
-  framework::OpKernelType GetExpectedKernelType(
-      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
-        ctx.device_context());
+class UnsqueezeGradOp : public framework::OperatorBase {
+ public:
+  UnsqueezeGradOp(const std::string &type,
+                  const framework::VariableNameMap &inputs,
+                  const framework::VariableNameMap &outputs,
+                  const framework::AttributeMap &attrs)
+      : OperatorBase(type, inputs, outputs, attrs) {}
+
+ private:
+  void RunImpl(const framework::Scope &scope,
+               const platform::Place &place) const override {
+    auto dx_name = Output(framework::GradVarName("X"));
+    auto dout_name = Input(framework::GradVarName("Out"));
+    auto x_dims = scope.FindVar(Input("X"))->Get<framework::LoDTensor>().dims();
+
+    framework::AttributeMap attrs;
+    attrs["shape"] = framework::vectorize2int(x_dims);
+    attrs["inplace"] = Attr<bool>("inplace");
+
+    auto reshape_op = framework::OpRegistry::CreateOp(
+        "reshape", {{"X", {dout_name}}, {"Shape", {}}}, {{"Out", {dx_name}}},
+        attrs);
+    reshape_op->Run(scope, place);
   }
 };
 
 }  // namespace operators
 }  // namespace paddle
 
+// Tell linker to use reshape op.
+USE_OP(reshape);
+
 namespace ops = paddle::operators;
 REGISTER_OPERATOR(unsqueeze, ops::UnsqueezeOp, ops::UnsqueezeOpMaker,
+                  ops::UnsqueezeOpInferShape,
                   paddle::framework::DefaultGradOpDescMaker<true>);
-REGISTER_OPERATOR(unsqueeze_grad, ops::UnsqueezeGradOp);
-REGISTER_OP_CPU_KERNEL(
-    unsqueeze, ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, double>,
-    ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int>,
-    ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int64_t>);
-REGISTER_OP_CPU_KERNEL(
-    unsqueeze_grad,
-    ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, double>,
-    ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, int>,
-    ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
+REGISTER_OPERATOR(unsqueeze_grad, ops::UnsqueezeGradOp,
+                  ops::UnsqueezeGradInferShape);
diff --git a/paddle/fluid/operators/unsqueeze_op.cu b/paddle/fluid/operators/unsqueeze_op.cu
deleted file mode 100644
index 4d111190cd..0000000000
--- a/paddle/fluid/operators/unsqueeze_op.cu
+++ /dev/null
@@ -1,30 +0,0 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#define EIGEN_USE_GPU
-
-#include "paddle/fluid/operators/unsqueeze_op.h"
-
-namespace ops = paddle::operators;
-REGISTER_OP_CUDA_KERNEL(
-    unsqueeze, ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, double>,
-    ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int>,
-    ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int64_t>);
-REGISTER_OP_CUDA_KERNEL(
-    unsqueeze_grad,
-    ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, double>,
-    ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, int>,
-    ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, int64_t>);
diff --git a/paddle/fluid/operators/unsqueeze_op.h b/paddle/fluid/operators/unsqueeze_op.h
deleted file mode 100644
index aa45fb3113..0000000000
--- a/paddle/fluid/operators/unsqueeze_op.h
+++ /dev/null
@@ -1,72 +0,0 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#pragma once
-
-#include <vector>
-
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/framework/operator.h"
-
-namespace paddle {
-namespace operators {
-
-using Tensor = framework::Tensor;
-
-template <typename DeviceContext, typename T>
-class UnsqueezeKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext &ctx) const override {
-    auto *out = ctx.Output<framework::LoDTensor>("Out");
-    auto *in = ctx.Input<framework::LoDTensor>("X");
-
-    framework::DDim out_dims = out->dims();
-
-    bool inplace = ctx.Attr<bool>("inplace");
-    out->Resize(out_dims);
-    if (!inplace) {
-      out->mutable_data<T>(ctx.GetPlace());
-      framework::TensorCopySync(*in, ctx.GetPlace(), out);
-      out->Resize(out_dims);
-    } else {
-      out->ShareDataWith(*in);
-      out->Resize(out_dims);
-    }
-  }
-};
-
-template <typename DeviceContext, typename T>
-class UnsqueezeGradKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext &ctx) const override {
-    auto *d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
-    auto *d_x = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
-
-    d_x->mutable_data<T>(ctx.GetPlace());
-    bool inplace = ctx.Attr<bool>("inplace");
-
-    auto in_dims = d_x->dims();
-    if (!inplace) {
-      framework::TensorCopy(*d_out, ctx.GetPlace(), ctx.device_context(), d_x);
-      ctx.device_context().Wait();
-      d_x->Resize(in_dims);
-    } else {
-      d_x->ShareDataWith(*d_out);
-      d_x->Resize(in_dims);
-    }
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
diff --git a/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py b/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py
index af273ca5a1..eff90f4618 100644
--- a/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py
+++ b/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py
@@ -27,7 +27,7 @@ class TestUnsqueezeOp(OpTest):
 
         self.op_type = "unsqueeze"
         self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
-        self.attrs = {"axes": axes, "inpalce": False}
+        self.attrs = {"axes": axes, "inplace": False}
         self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
 
     def test_check_output(self):
@@ -37,23 +37,42 @@ class TestUnsqueezeOp(OpTest):
         self.check_grad(["X"], "Out")
 
 
-# Correct: There is mins axis.
+# Correct: Single input index.
+class TestUnsqueezeOp1(OpTest):
+    def setUp(self):
+        ori_shape = (3, 5)
+        axes = (-1, )
+        new_shape = (3, 5, 1)
+
+        self.op_type = "unsqueeze"
+        self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
+        self.attrs = {"axes": axes, "inplace": False}
+        self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(["X"], "Out")
+
+
+# Correct: Mixed input axis.
 class TestUnsqueezeOp2(OpTest):
     def setUp(self):
         ori_shape = (3, 5)
-        axes = (0, -2)
-        new_shape = (1, 3, 1, 5)
+        axes = (0, -1)
+        new_shape = (1, 3, 5, 1)
 
         self.op_type = "unsqueeze"
         self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
-        self.attrs = {"axes": axes, "inpalce": False}
+        self.attrs = {"axes": axes, "inplace": False}
         self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
 
-        def test_check_output(self):
-            self.check_output()
+    def test_check_output(self):
+        self.check_output()
 
-        def test_check_grad(self):
-            self.check_grad(["X"], "Out")
+    def test_check_grad(self):
+        self.check_grad(["X"], "Out")
 
 
 # Correct: There is duplicated axis.
@@ -65,83 +84,84 @@ class TestUnsqueezeOp3(OpTest):
 
         self.op_type = "unsqueeze"
         self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
-        self.attrs = {"axes": axes, "inpalce": False}
+        self.attrs = {"axes": axes, "inplace": False}
         self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
 
-        def test_check_output(self):
-            self.check_output()
+    def test_check_output(self):
+        self.check_output()
 
-        def test_check_grad(self):
-            self.check_grad(["X"], "Out")
+    def test_check_grad(self):
+        self.check_grad(["X"], "Out")
 
 
-# Error: Output dimension is error.
-class TestUnsqueezeOp4(OpTest):
+# Correct: Inplace.
+class TestUnsqueezeOpInplace1(OpTest):
     def setUp(self):
-        ori_shape = (3, 2, 5)
-        axes = (0, 3)
-        new_shape = (1, 3, 2, 2, 5)
+        ori_shape = (3, 5)
+        axes = (0, 2)
+        new_shape = (1, 3, 1, 5)
 
         self.op_type = "unsqueeze"
         self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
-        self.attrs = {"axes": axes, "inpalce": False}
+        self.attrs = {"axes": axes, "inplace": True}
         self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
 
-        def test_check_output(self):
-            self.check_output()
+    def test_check_output(self):
+        self.check_output()
 
-        def test_check_grad(self):
-            self.check_grad(["X"], "Out")
+    def test_check_grad(self):
+        self.check_grad(["X"], "Out")
 
 
-# Error: Input axes is invalid case 1.
-class TestUnsqueezeOp5(OpTest):
+# Correct: Inplace. There is mins index.
+class TestUnsqueezeOpInplace2(OpTest):
     def setUp(self):
-        ori_shape = (3, 2, 5)
-        axes = (0, 5)
+        ori_shape = (3, 5)
+        axes = (0, -2)
         new_shape = (1, 3, 1, 5)
 
         self.op_type = "unsqueeze"
         self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
-        self.attrs = {"axes": axes, "inpalce": False}
+        self.attrs = {"axes": axes, "inplace": True}
         self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
 
-        def test_check_output(self):
-            self.check_output()
+    def test_check_output(self):
+        self.check_output()
 
-        def test_check_grad(self):
-            self.check_grad(["X"], "Out")
+    def test_check_grad(self):
+        self.check_grad(["X"], "Out")
 
 
-# Error: Input axes is invalid case 2.
-class TestUnsqueezeOp5(OpTest):
+# Correct: Inplace. There is duplicated axis.
+class TestUnsqueezeOpInplace3(OpTest):
     def setUp(self):
         ori_shape = (3, 2, 5)
-        axes = (0, 2, 10)
-        new_shape = (1, 3, 1, 5)
+        axes = (0, 3, 3)
+        new_shape = (1, 3, 2, 1, 1, 5)
 
         self.op_type = "unsqueeze"
         self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
-        self.attrs = {"axes": axes, "inpalce": False}
+        self.attrs = {"axes": axes, "inplace": True}
         self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
 
-        def test_check_output(self):
-            self.check_output()
+    def test_check_output(self):
+        self.check_output()
 
-        def test_check_grad(self):
-            self.check_grad(["X"], "Out")
+    def test_check_grad(self):
+        self.check_grad(["X"], "Out")
 
 
-# Correct: Inplace.
-class TestUnsqueezeOpInplace1(OpTest):
+'''
+# Error: Output dimension is error.
+class TestUnsqueezeOp4(OpTest):
     def setUp(self):
         ori_shape = (3, 5)
-        axes = (0, 2)
-        new_shape = (1, 3, 1, 5)
+        axes = (0, 3)
+        new_shape = (1, 3, 1, 1, 5)
 
         self.op_type = "unsqueeze"
         self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
-        self.attrs = {"axes": axes, "inplace": True}
+        self.attrs = {"axes": axes, "inplace": False}
         self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
 
     def test_check_output(self):
@@ -150,25 +170,60 @@ class TestUnsqueezeOpInplace1(OpTest):
     def test_check_grad(self):
         self.check_grad(["X"], "Out")
 
-
-# Correct: Inplace. There is duplicated axis.
-class TestUnsqueezeOpInplace2(OpTest):
+# Error: Input axis is large than output range.
+class TestUnsqueezeOp5(OpTest):
     def setUp(self):
-        ori_shape = (3, 2, 5)
-        axes = (0, 3, 3)
-        new_shape = (1, 3, 2, 1, 1, 5)
+        ori_shape = (3, 5)
+        axes = (0, 4)
+        new_shape = (1, 3, 5, 1)
 
         self.op_type = "unsqueeze"
         self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
-        self.attrs = {"axes": axes, "inpalce": True}
+        self.attrs = {"axes": axes, "inplace": False}
         self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
 
-        def test_check_output(self):
-            self.check_output()
+    def test_check_output(self):
+        self.check_output()
 
         def test_check_grad(self):
             self.check_grad(["X"], "Out")
 
+# Error: Input axes is large than Eigen limit.
+class TestUnsqueezeOp6(OpTest):
+    def setUp(self):
+        ori_shape = (3, 5)
+        axes = (0, 2, 10)
+        new_shape = (1, 3, 1, 5, 1)
+
+        self.op_type = "unsqueeze"
+        self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
+        self.attrs = {"axes": axes, "inplace": False}
+        self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(["X"], "Out")
+
+# Error: Input axes size is large than Eigen limit.
+class TestUnsqueezeOp7(OpTest):
+    def setUp(self):
+        ori_shape = (3, 5)
+        axes = (0, 2, 2, 2, 2, 2)
+        new_shape = (1, 3, 1, 1, 5, 1)
+
+        self.op_type = "unsqueeze"
+        self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
+        self.attrs = {"axes": axes, "inplace": False}
+        self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(["X"], "Out")
+'''
 
 if __name__ == "__main__":
     unittest.main()