diff --git a/paddle/operators/adadelta_op.cc b/paddle/operators/adadelta_op.cc
new file mode 100644
index 0000000000..bd8c93b4a1
--- /dev/null
+++ b/paddle/operators/adadelta_op.cc
@@ -0,0 +1,115 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/operators/adadelta_op.h"
+
+namespace paddle {
+namespace operators {
+
+class AdadeltaOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(framework::InferShapeContextBase *ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("Param"),
+                   "Input(Param) of AdadeltaOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Grad"),
+                   "Input(Grad) of AdadeltaOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("AvgSquaredGrad"),
+                   "Input(AvgSquaredGrad) of AdadeltaOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("AvgSquaredUpdate"),
+                   "Input(AvgSquaredUpdate) of AdadeltaOp should not be null.");
+
+    PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
+                   "Output(ParamOut) of AdadeltaOp should not be null.");
+    PADDLE_ENFORCE(
+        ctx->HasOutput("AvgSquaredGradOut"),
+        "Output(AvgSquaredGradOut) of AdadeltaOp should not be null.");
+    PADDLE_ENFORCE(
+        ctx->HasOutput("AvgSquaredUpdateOut"),
+        "Output(AvgSquaredUpdateOut) of AdadeltaOp should not be null.");
+
+    auto param_dim = ctx->GetInputDim("Param");
+    PADDLE_ENFORCE_EQ(
+        param_dim, ctx->GetInputDim("Grad"),
+        "param and grad input of AdadeltaOp should have same dimension");
+    PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("AvgSquaredGrad"),
+                      "Param and AvgSquaredGrad input of AdadeltaOp "
+                      "should have same dimension");
+    PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("AvgSquaredUpdate"),
+                      "Param and AvgSquaredUpdate input of AdadeltaOp "
+                      "should have same dimension");
+
+    ctx->SetOutputDim("ParamOut", param_dim);
+    ctx->SetOutputDim("AvgSquaredGradOut", param_dim);
+    ctx->SetOutputDim("AvgSquaredUpdateOut", param_dim);
+  }
+};
+
+class AdadeltaOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  AdadeltaOpMaker(framework::OpProto *proto,
+                  framework::OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("Param", "(Tensor) Input parameter");
+    AddInput("Grad", "(Tensor) Input gradient");
+    AddInput("AvgSquaredGrad",
+             "(Tensor) Input expectation of squared gradient");
+    AddInput("AvgSquaredUpdate",
+             "(Tensor) Input expectation of squared parameter updates");
+
+    AddOutput("ParamOut", "(Tensor) Output parameter");
+    AddOutput("AvgSquaredGradOut",
+              "(Tensor) Output expectation of squared gradient");
+    AddOutput("AvgSquaredUpdateOut",
+              "(Tensor) Output expectation of squared parameter updates");
+
+    AddAttr<float>("rho",
+                   "(float, default 0.95) Exponential decay rate "
+                   "for squared gradients.")
+        .SetDefault(0.95f);
+    AddAttr<float>("epsilon",
+                   "(float, default 1.0e-6) Constant for "
+                   "numerical stability")
+        .SetDefault(1.0e-6f);
+    AddComment(R"DOC(
+Adadelta Updates Operator.
+
+This implements the Adadelta optimizer[1]. Adadelta is a per-dimension
+adaptive learning rate method for gradient descent.
+
+Adadelta updates:
+
+avg_squared_grad_out = rho * avg_squared_grad + (1 - rho) * grad * grad
+param_update =  - sqrt((avg_squared_update + epsilon) /
+                       (avg_squared_grad_out + epsilon)) * grad
+avg_squared_update_out = rho * avg_squared_update + (1 - rho) * param_update**2
+param_out = param + param_update
+
+References:
+  [1] ADADELTA: An Adaptive Learning Rate Method
+      https://arxiv.org/abs/1212.5701
+
+)DOC");
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_WITHOUT_GRADIENT(adadelta, ops::AdadeltaOp, ops::AdadeltaOpMaker);
+REGISTER_OP_CPU_KERNEL(
+    adadelta, ops::AdadeltaOpKernel<paddle::platform::CPUPlace, float>);
diff --git a/paddle/operators/adadelta_op.cu b/paddle/operators/adadelta_op.cu
new file mode 100644
index 0000000000..3af1c8c8e9
--- /dev/null
+++ b/paddle/operators/adadelta_op.cu
@@ -0,0 +1,20 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#define EIGEN_USE_GPU
+#include "paddle/operators/adadelta_op.h"
+
+namespace ops = paddle::operators;
+REGISTER_OP_GPU_KERNEL(
+    adadelta, ops::AdadeltaOpKernel<paddle::platform::GPUPlace, float>);
diff --git a/paddle/operators/adadelta_op.h b/paddle/operators/adadelta_op.h
new file mode 100644
index 0000000000..d29e15c435
--- /dev/null
+++ b/paddle/operators/adadelta_op.h
@@ -0,0 +1,69 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include "paddle/framework/eigen.h"
+#include "paddle/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename Place, typename T>
+class AdadeltaOpKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto param_out_tensor = ctx.Output<framework::Tensor>("ParamOut");
+    auto avg_squared_grad_out_tensor =
+        ctx.Output<framework::Tensor>("AvgSquaredGradOut");
+    auto avg_squared_update_out_tensor =
+        ctx.Output<framework::Tensor>("AvgSquaredUpdateOut");
+
+    param_out_tensor->mutable_data<T>(ctx.GetPlace());
+    avg_squared_grad_out_tensor->mutable_data<T>(ctx.GetPlace());
+    avg_squared_update_out_tensor->mutable_data<T>(ctx.GetPlace());
+
+    float rho = ctx.Attr<float>("rho");
+    float epsilon = ctx.Attr<float>("epsilon");
+
+    auto param = framework::EigenVector<T>::Flatten(
+        *ctx.Input<framework::Tensor>("Param"));
+    auto grad = framework::EigenVector<T>::Flatten(
+        *ctx.Input<framework::Tensor>("Grad"));
+    // Squared gradient accumulator
+    auto avg_squared_grad = framework::EigenVector<T>::Flatten(
+        *ctx.Input<framework::Tensor>("AvgSquaredGrad"));
+    // Squared updates accumulator
+    auto avg_squared_update = framework::EigenVector<T>::Flatten(
+        *ctx.Input<framework::Tensor>("AvgSquaredUpdate"));
+    auto param_out = framework::EigenVector<T>::Flatten(*param_out_tensor);
+    auto avg_squared_grad_out =
+        framework::EigenVector<T>::Flatten(*avg_squared_grad_out_tensor);
+    auto avg_squared_update_out =
+        framework::EigenVector<T>::Flatten(*avg_squared_update_out_tensor);
+    auto place = ctx.GetEigenDevice<Place>();
+
+    avg_squared_grad_out.device(place) =
+        rho * avg_squared_grad + (1 - rho) * grad.square();
+    auto update =
+        -((avg_squared_update + epsilon) / (avg_squared_grad_out + epsilon))
+             .sqrt() *
+        grad;
+    avg_squared_update_out.device(place) =
+        rho * avg_squared_update + (1 - rho) * update.square();
+    param_out.device(place) = param + update;
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/python/paddle/v2/framework/tests/test_adadelta_op.py b/python/paddle/v2/framework/tests/test_adadelta_op.py
new file mode 100644
index 0000000000..7105593a98
--- /dev/null
+++ b/python/paddle/v2/framework/tests/test_adadelta_op.py
@@ -0,0 +1,96 @@
+import unittest
+import numpy as np
+from op_test import OpTest
+
+
+class TestAdadeltaOp1(OpTest):
+    def setUp(self):
+        self.op_type = "adadelta"
+        param = np.random.uniform(-1, 1, (102, 105)).astype("float32")
+        grad = np.random.uniform(-1, 1, (102, 105)).astype("float32")
+        # The squared gradient is positive
+        avg_squared_grad = np.random.random((102, 105)).astype("float32")
+        # The squared update is positive
+        avg_squared_update = np.random.random((102, 105)).astype("float32")
+
+        rho = 0.95
+        epsilon = 1e-6
+
+        self.inputs = {
+            'Param': param,
+            'Grad': grad,
+            'AvgSquaredGrad': avg_squared_grad,
+            'AvgSquaredUpdate': avg_squared_update
+        }
+
+        self.attrs = {'rho': rho, 'epsilon': epsilon}
+
+        avg_squared_grad_out = rho * avg_squared_grad + \
+            (1 - rho) * np.square(grad)
+        update = -np.multiply(
+            np.sqrt(
+                np.divide(avg_squared_update + epsilon, avg_squared_grad_out +
+                          epsilon)), grad)
+
+        avg_squared_update_out = rho * avg_squared_update + \
+            (1 - rho) * np.square(update)
+
+        param_out = param + update
+
+        self.outputs = {
+            'ParamOut': param_out,
+            'AvgSquaredGradOut': avg_squared_grad_out,
+            'AvgSquaredUpdateOut': avg_squared_update_out
+        }
+
+    def test_check_output(self):
+        self.check_output()
+
+
+class TestAdadeltaOp2(OpTest):
+    '''Test Adadelta op with default attribute values
+    '''
+
+    def setUp(self):
+        self.op_type = "adadelta"
+        param = np.random.uniform(-1, 1, (102, 105)).astype("float32")
+        grad = np.random.uniform(-1, 1, (102, 105)).astype("float32")
+        # The squared gradient is positive
+        avg_squared_grad = np.random.random((102, 105)).astype("float32")
+        # The squared update is positive
+        avg_squared_update = np.random.random((102, 105)).astype("float32")
+
+        rho = 0.95
+        epsilon = 1e-6
+
+        self.inputs = {
+            'Param': param,
+            'Grad': grad,
+            'AvgSquaredGrad': avg_squared_grad,
+            'AvgSquaredUpdate': avg_squared_update
+        }
+
+        avg_squared_grad_out = rho * avg_squared_grad + \
+            (1 - rho) * np.square(grad)
+        update = -np.multiply(
+            np.sqrt(
+                np.divide(avg_squared_update + epsilon, avg_squared_grad_out +
+                          epsilon)), grad)
+
+        avg_squared_update_out = rho * avg_squared_update + \
+            (1 - rho) * np.square(update)
+
+        param_out = param + update
+
+        self.outputs = {
+            'ParamOut': param_out,
+            'AvgSquaredGradOut': avg_squared_grad_out,
+            'AvgSquaredUpdateOut': avg_squared_update_out
+        }
+
+    def test_check_output(self):
+        self.check_output()
+
+
+if __name__ == "__main__":
+    unittest.main()