diff --git a/paddle/operators/smooth_l1_loss_op.cc b/paddle/operators/smooth_l1_loss_op.cc
new file mode 100644
index 0000000000..e9a3847417
--- /dev/null
+++ b/paddle/operators/smooth_l1_loss_op.cc
@@ -0,0 +1,119 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#include "paddle/operators/smooth_l1_loss_op.h"
+
+namespace paddle {
+namespace operators {
+
+class SmoothL1LossOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(const framework::InferShapeContext& ctx) const override {
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
+                            "Input of SmoothL1LossOp must be initialized.");
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"),
+                            "Target of SmoothL1LossOp must be initialized.");
+
+    auto* x = ctx.Input<framework::Tensor>("X");
+    auto* y = ctx.Input<framework::Tensor>("Y");
+    PADDLE_ENFORCE_EQ(x->dims(), y->dims(),
+                      "Dimensions of SmoothL1LossOp's input and target "
+                      "must be the same.");
+    PADDLE_ENFORCE_GE(framework::arity(x->dims()), 2,
+                      "Tensor rank of SmoothL1LossOp's input must be "
+                      "at least 2.");
+    auto* inside_weight = ctx.Input<framework::Tensor>("InsideWeight");
+    if (inside_weight) {
+      auto* outside_weight = ctx.Input<framework::Tensor>("OutsideWeight");
+      PADDLE_ENFORCE_NOT_NULL(outside_weight,
+                              "If weights are provided, both inside and "
+                              "outside weights must be specified.");
+      PADDLE_ENFORCE_EQ(inside_weight->dims(), x->dims(),
+                        "Dimensions of inside weight must be the same as X.");
+      PADDLE_ENFORCE_EQ(
+          outside_weight->dims(), x->dims(),
+          "Dimensions of outside weight must be the same as X.");
+    }
+
+    auto* diff = ctx.Output<framework::Tensor>("diff");
+    auto* out = ctx.Output<framework::Tensor>("Out");
+    diff->Resize(x->dims());
+    // the loss is a rank-2 tensor of shape [batch_size, 1]
+    out->Resize({x->dims()[0], 1});
+  }
+};
+
+template <typename AttrType>
+class SmoothL1LossOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  SmoothL1LossOpMaker(framework::OpProto* proto,
+                      framework::OpAttrChecker* op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("X", "Input of SmoothL1LossOp.");
+    AddInput("Y", "Target of SmoothL1LossOp.");
+    AddInput("InsideWeight", "Optional input to scale (X-Y).");
+    AddInput("OutsideWeight", "Optinal input to scale smooth l1 loss.");
+    AddOutput("diff", "Intermediate variable to cache Win*(X-Y).")
+        .AsIntermediate();
+    AddOutput("Out", "Final smooth l1 loss of inputs.");
+    AddComment(R"DOC(
+Compute the smooth L1 loss for input and target.
+
+The equation is:
+  Out = 0.5 * (sigma * (X - Y)) ^ 2   if abs(X - Y) < 1 / sigma^2
+        abs(X - Y) - 0.5 / sigma^2    otherwise
+)DOC");
+    AddAttr<AttrType>("sigma", "Hyper parameter; the default value is 3.0.")
+        .SetDefault(3.0);
+  }
+};
+
+class SmoothL1LossGradOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(const framework::InferShapeContext& ctx) const override {
+    auto in_dims = ctx.Input<framework::Tensor>("X")->dims();
+    auto out_dims =
+        ctx.Input<framework::Tensor>(framework::GradVarName("Out"))->dims();
+    auto* x_grad = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
+    auto* y_grad = ctx.Output<framework::Tensor>(framework::GradVarName("Y"));
+
+    PADDLE_ENFORCE_GE(framework::arity(out_dims), 2,
+                      "Tensor rank of output gradient should be at least 2.");
+    PADDLE_ENFORCE_EQ(out_dims[0], in_dims[0],
+                      "First dimension of output gradient must be "
+                      "the same as the input's.");
+    PADDLE_ENFORCE_EQ(out_dims[1], 1,
+                      "Second dimension of output gradient must be 1.");
+
+    if (x_grad) x_grad->Resize(in_dims);
+    if (y_grad) y_grad->Resize(in_dims);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP(smooth_l1_loss, ops::SmoothL1LossOp,
+            ops::SmoothL1LossOpMaker<float>, ops::SmoothL1LossGradOp);
+REGISTER_OP_CPU_KERNEL(
+    smooth_l1_loss, ops::SmoothL1LossKernel<paddle::platform::CPUPlace, float>);
+REGISTER_OP_CPU_KERNEL(
+    smooth_l1_loss_grad,
+    ops::SmoothL1LossGradKernel<paddle::platform::CPUPlace, float>);
diff --git a/paddle/operators/smooth_l1_loss_op.cu b/paddle/operators/smooth_l1_loss_op.cu
new file mode 100644
index 0000000000..1c3172f438
--- /dev/null
+++ b/paddle/operators/smooth_l1_loss_op.cu
@@ -0,0 +1,24 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#define EIGEN_USE_GPU
+
+#include "paddle/operators/smooth_l1_loss_op.h"
+
+namespace ops = paddle::operators;
+REGISTER_OP_GPU_KERNEL(
+    smooth_l1_loss, ops::SmoothL1LossKernel<paddle::platform::GPUPlace, float>);
+REGISTER_OP_GPU_KERNEL(
+    smooth_l1_loss_grad,
+    ops::SmoothL1LossGradKernel<paddle::platform::GPUPlace, float>);
diff --git a/paddle/operators/smooth_l1_loss_op.h b/paddle/operators/smooth_l1_loss_op.h
new file mode 100644
index 0000000000..ae91b9c893
--- /dev/null
+++ b/paddle/operators/smooth_l1_loss_op.h
@@ -0,0 +1,184 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#pragma once
+#include "paddle/framework/eigen.h"
+#include "paddle/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+template <typename T, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
+template <typename T, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
+
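+// Elementwise forward of the smooth L1 loss. As an illustrative example:
+// with sigma = 3 (so sigma2 = 9), a residual of 0.05 falls in the quadratic
+// branch (|0.05| < 1/9) and gives 0.5 * 9 * 0.05^2 = 0.01125, while a
+// residual of 0.5 falls in the linear branch and gives 0.5 - 0.5/9 ~= 0.4444.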
+template <typename T>
+struct SmoothL1LossForward {
+  __host__ __device__ SmoothL1LossForward(const T& sigma2) : sigma2(sigma2) {}
+
+  __host__ __device__ T operator()(const T& val) const {
+    T abs_val = std::abs(val);
+    if (abs_val < 1.0 / sigma2) {
+      return 0.5 * val * val * sigma2;
+    } else {
+      return abs_val - 0.5 / sigma2;
+    }
+  }
+
+  T sigma2;
+};
+
+template <typename Place, typename T, typename AttrType = T>
+class SmoothL1LossKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* in0 = context.Input<Tensor>("X");
+    auto* in1 = context.Input<Tensor>("Y");
+    auto* in2 = context.Input<Tensor>("InsideWeight");
+    auto* in3 = context.Input<Tensor>("OutsideWeight");
+    auto* out0 = context.Output<Tensor>("diff");
+    auto* out1 = context.Output<Tensor>("Out");
+
+    out0->mutable_data<T>(context.GetPlace());
+    out1->mutable_data<T>(context.GetPlace());
+    auto place = context.GetEigenDevice<Place>();
+
+    auto sigma = static_cast<T>(context.op_.GetAttr<AttrType>("sigma"));
+    T sigma2 = sigma * sigma;
+    bool has_weight = (in2 != nullptr) && (in3 != nullptr);
+
+    auto x = EigenVector<T>::Flatten(*in0);
+    auto y = EigenVector<T>::Flatten(*in1);
+    auto diff = EigenVector<T>::Flatten(*out0);
+
+    diff.device(place) = x - y;
+    // multiply inside weight
+    if (has_weight) {
+      auto inside_weight = EigenVector<T>::Flatten(*in2);
+      // cache the weighted diff, reused in the backward pass
+      diff.device(place) = diff * inside_weight;
+    }
+
+    auto in_counts = framework::product(in0->dims());
+    Tensor paddle_errors;
+    paddle_errors.mutable_data<T>({static_cast<int>(in_counts)},
+                                  context.GetPlace());
+    auto errors = EigenVector<T>::Flatten(paddle_errors);
+    // apply smooth l1 forward
+    errors.device(place) = diff.unaryExpr(SmoothL1LossForward<T>(sigma2));
+
+    // multiply outside weight
+    if (has_weight) {
+      auto outside_weight = EigenVector<T>::Flatten(*in3);
+      errors.device(place) = errors * outside_weight;
+    }
+    auto loss = EigenMatrix<T>::From(*out1, {in0->dims()[0], 1});
+    // first dimension of 'X' is the number of samples
+    auto errors_mat_view = EigenMatrix<T>::From(paddle_errors, in0->dims());
+    loss.device(place) = errors_mat_view.sum(Eigen::array<int, 1>({1}));
+  }
+};
+
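+// Elementwise derivative of the forward loss with respect to the (weighted)
+// difference: d/dx [0.5 * sigma2 * x^2] = sigma2 * x when |x| < 1 / sigma2,
+// and d/dx [|x| - 0.5 / sigma2] = sign(x) otherwise, which is exactly what
+// the functor below returns.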
+template <typename T>
+struct SmoothL1LossBackward {
+  __host__ __device__ SmoothL1LossBackward(const T& sigma2) : sigma2(sigma2) {}
+
+  __host__ __device__ T operator()(const T& val) const {
+    T abs_val = std::abs(val);
+    if (abs_val < 1.0 / sigma2) {
+      return sigma2 * val;
+    } else {
+      // (0 < val) - (val < 0) evaluates to sign(val): +1, -1, or 0
+      return (0 < val) - (val < 0);
+    }
+  }
+
+  T sigma2;
+};
+
+template <typename Place, typename T, typename AttrType = T>
+class SmoothL1LossGradKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* in0 = context.Input<Tensor>("InsideWeight");
+    auto* in1 = context.Input<Tensor>("OutsideWeight");
+    auto* in2 = context.Input<Tensor>("diff");
+    auto* og = context.Input<Tensor>(framework::GradVarName("Out"));
+    auto sigma = static_cast<T>(context.op_.GetAttr<AttrType>("sigma"));
+    T sigma2 = sigma * sigma;
+    bool has_weight = (in0 != nullptr) && (in1 != nullptr);
+
+    auto place = context.GetEigenDevice<Place>();
+
+    auto in_dims = in2->dims();
+    auto counts = framework::product(in_dims);
+    auto cols = counts / in_dims[0];
+    auto mat_dims = framework::make_ddim(
+        {static_cast<int>(in_dims[0]), static_cast<int>(cols)});
+
+    Tensor paddle_diff;
+    paddle_diff.mutable_data<T>({static_cast<int>(counts)}, context.GetPlace());
+    auto diff = EigenVector<T>::Flatten(paddle_diff);
+    // apply smooth l1 backward
+    diff.device(place) = EigenVector<T>::Flatten(*in2).unaryExpr(
+        SmoothL1LossBackward<T>(sigma2));
+
+    auto* out0 = context.Output<Tensor>(framework::GradVarName("X"));
+    auto* out1 = context.Output<Tensor>(framework::GradVarName("Y"));
+
+    // compute weights
+    Tensor paddle_weights;
+    paddle_weights.mutable_data<T>(mat_dims, context.GetPlace());
+    auto weights = EigenMatrix<T>::From(paddle_weights);
+    // initialize the weights to 1.0; setConstant below runs on the host, so
+    // for GPU the constant is set on a CPU tensor first and then copied over
+    if (platform::is_cpu_place(context.GetPlace())) {
+      weights.setConstant(static_cast<T>(1.0));
+    } else {
+      Tensor paddle_cpu_weights;
+      paddle_cpu_weights.mutable_data<T>(mat_dims, platform::CPUPlace());
+      EigenMatrix<T>::From(paddle_cpu_weights).setConstant(static_cast<T>(1.0));
+      paddle_weights.CopyFrom<T>(paddle_cpu_weights, context.GetPlace());
+    }
+    if (has_weight) {
+      auto inside_weight = EigenMatrix<T>::From(*in0, mat_dims);
+      auto outside_weight = EigenMatrix<T>::From(*in1, mat_dims);
+      weights.device(place) = inside_weight * outside_weight;
+    }
+
+    // compute gradients: broadcast the N x 1 output gradient across columns,
+    // then scale by the combined weights and the elementwise derivative
+    auto out_grad = EigenMatrix<T>::From(*og);
+    auto diff_mat_view = EigenMatrix<T>::From(paddle_diff, mat_dims);
+    auto gradients =
+        out_grad.broadcast(Eigen::array<int, 2>({1, static_cast<int>(cols)})) *
+        weights * diff_mat_view;
+
+    if (out0) {
+      out0->mutable_data<T>(context.GetPlace());
+      auto x_grad = EigenMatrix<T>::From(*out0, mat_dims);
+      x_grad.device(place) = gradients;
+    }
+
+    if (out1) {
+      out1->mutable_data<T>(context.GetPlace());
+      auto y_grad = EigenMatrix<T>::From(*out1, mat_dims);
+      y_grad.device(place) = -1 * gradients;
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc
index 3bc150ccb7..5aaa372664 100644
--- a/paddle/pybind/pybind.cc
+++ b/paddle/pybind/pybind.cc
@@ -48,6 +48,7 @@ USE_OP_ITSELF(identity);
 USE_OP(minus);
 USE_CPU_ONLY_OP(gather);
 USE_CPU_ONLY_OP(scatter);
+USE_OP(smooth_l1_loss);
 
 namespace paddle {
 namespace framework {
diff --git a/python/paddle/v2/framework/tests/CMakeLists.txt b/python/paddle/v2/framework/tests/CMakeLists.txt
index 661ebd8964..763f3a9f95 100644
--- a/python/paddle/v2/framework/tests/CMakeLists.txt
+++ b/python/paddle/v2/framework/tests/CMakeLists.txt
@@ -32,3 +32,4 @@ py_test(test_gradient_checker SRCS test_gradient_checker.py)
 py_test(test_lookup_table SRCS test_lookup_table.py)
 py_test(test_scale_and_identity_op SRCS test_scale_and_identity_op.py)
 py_test(mnist SRCS mnist.py)
+py_test(test_smooth_l1_loss_op SRCS test_smooth_l1_loss_op.py)
diff --git a/python/paddle/v2/framework/tests/test_smooth_l1_loss_op.py b/python/paddle/v2/framework/tests/test_smooth_l1_loss_op.py
new file mode 100644
index 0000000000..b3432e703e
--- /dev/null
+++ b/python/paddle/v2/framework/tests/test_smooth_l1_loss_op.py
@@ -0,0 +1,106 @@
+import unittest
+from op_test_util import OpTestMeta
+from gradient_checker import GradientChecker, create_op
+import functools
+import numpy as np
+from paddle.v2.framework.op import Operator
+
+
+def smooth_l1_loss_forward(val, sigma2):
+    abs_val = abs(val)
+    if abs_val < 1.0 / sigma2:
+        return 0.5 * val * val * sigma2
+    else:
+        return abs_val - 0.5 / sigma2
+
+
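+# A minimal NumPy sketch of the elementwise backward rule, mirroring the
+# SmoothL1LossBackward functor in smooth_l1_loss_op.h; it is illustrative
+# only and is not used by the tests below.
+def smooth_l1_loss_backward(val, sigma2):
+    abs_val = abs(val)
+    if abs_val < 1.0 / sigma2:
+        return sigma2 * val
+    else:
+        return np.sign(val)
+
+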
+class TestSmoothL1LossOp_f0(unittest.TestCase):
+    __metaclass__ = OpTestMeta
+
+    def setUp(self):
+        self.type = "smooth_l1_loss"
+        dims = (32, 64)
+        self.inputs = {
+            'X': np.random.random(dims).astype("float32"),
+            'Y': np.random.random(dims).astype("float32")
+        }
+        sigma = 3.0
+        self.attrs = {'sigma': sigma}
+        sigma2 = sigma * sigma
+        diff = self.inputs['X'] - self.inputs['Y']
+        loss = np.vectorize(smooth_l1_loss_forward)(diff, sigma2).sum(1)
+        loss = loss.reshape((dims[0], 1))
+        self.outputs = {'diff': diff, 'Out': loss}
+
+
+class TestSmoothL1LossOp_f1(unittest.TestCase):
+    __metaclass__ = OpTestMeta
+
+    def setUp(self):
+        self.type = "smooth_l1_loss"
+        dims = (32, 64)
+        self.inputs = {
+            'X': np.random.random(dims).astype("float32"),
+            'Y': np.random.random(dims).astype("float32"),
+            'InsideWeight': np.random.random(dims).astype("float32"),
+            'OutsideWeight': np.random.random(dims).astype("float32")
+        }
+        sigma = 3.0
+        self.attrs = {'sigma': sigma}
+        sigma2 = sigma * sigma
+        diff = self.inputs['X'] - self.inputs['Y']
+        diff = diff * self.inputs['InsideWeight']
+        loss = np.vectorize(smooth_l1_loss_forward)(diff, sigma2)
+        loss = loss * self.inputs['OutsideWeight']
+        loss = loss.sum(1).reshape((dims[0], 1))
+        self.outputs = {'diff': diff, 'Out': loss}
+
+
+class SmoothL1LossGradOpTest(GradientChecker):
+    def test_smooth_l1_loss_b0(self):
+        dims = (5, 7)
+        X = np.random.random(dims).astype("float32")
+        Y = np.random.random(dims).astype("float32")
+        InsideWeight = np.random.random(dims).astype("float32")
+        OutsideWeight = np.random.random(dims).astype("float32")
+        inputs = {
+            'X': X,
+            'Y': Y,
+            'InsideWeight': InsideWeight,
+            'OutsideWeight': OutsideWeight
+        }
+        op = Operator(
+            "smooth_l1_loss",
+            X='X',
+            Y='Y',
+            InsideWeight='InsideWeight',
+            OutsideWeight='OutsideWeight',
+            diff="diff",
+            Out="Out",
+            sigma=3.0)
+        self.compare_grad(
+            op, inputs, no_grad_set=set(['InsideWeight', 'OutsideWeight']))
+        self.check_grad(
+            op, inputs, set(["X", "Y"]), "Out", max_relative_error=0.08)
+
+    def test_smooth_l1_loss_b1(self):
+        dims = (5, 7)
+        X = np.random.random(dims).astype("float32")
+        Y = np.random.random(dims).astype("float32")
+        inputs = {'X': X, 'Y': Y}
+        op = Operator(
+            "smooth_l1_loss",
+            X='X',
+            Y='Y',
+            InsideWeight='InsideWeight',
+            OutsideWeight='OutsideWeight',
+            diff="diff",
+            Out="Out",
+            sigma=3.0)
+        self.compare_grad(
+            op, inputs, no_grad_set=set(['InsideWeight', 'OutsideWeight']))
+        self.check_grad(op, inputs, set(["X", "Y"]), "Out")
+
+
+if __name__ == '__main__':
+    unittest.main()