diff --git a/paddle/fluid/operators/yolov3_loss_op.cc b/paddle/fluid/operators/yolov3_loss_op.cc
index 3bd0db8b59..495a8f6c01 100644
--- a/paddle/fluid/operators/yolov3_loss_op.cc
+++ b/paddle/fluid/operators/yolov3_loss_op.cc
@@ -204,7 +204,11 @@ namespace ops = paddle::operators;
 REGISTER_OPERATOR(yolov3_loss, ops::Yolov3LossOp, ops::Yolov3LossOpMaker,
                   ops::Yolov3LossGradMaker);
 REGISTER_OPERATOR(yolov3_loss_grad, ops::Yolov3LossOpGrad);
-REGISTER_OP_CPU_KERNEL(yolov3_loss, ops::Yolov3LossKernel<float>,
-                       ops::Yolov3LossKernel<double>);
-REGISTER_OP_CPU_KERNEL(yolov3_loss_grad, ops::Yolov3LossGradKernel<float>,
-                       ops::Yolov3LossGradKernel<double>);
+REGISTER_OP_CPU_KERNEL(
+    yolov3_loss,
+    ops::Yolov3LossKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::Yolov3LossKernel<paddle::platform::CPUDeviceContext, double>);
+REGISTER_OP_CPU_KERNEL(
+    yolov3_loss_grad,
+    ops::Yolov3LossGradKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::Yolov3LossGradKernel<paddle::platform::CPUDeviceContext, double>);
diff --git a/paddle/fluid/operators/yolov3_loss_op.h b/paddle/fluid/operators/yolov3_loss_op.h
index 5de5b4efc7..f086e89a99 100644
--- a/paddle/fluid/operators/yolov3_loss_op.h
+++ b/paddle/fluid/operators/yolov3_loss_op.h
@@ -13,6 +13,7 @@
 #include <algorithm>
 #include <vector>
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/math/math_function.h"
 
 namespace paddle {
 namespace operators {
@@ -32,183 +33,6 @@ static inline bool isZero(T x) {
   return fabs(x) < 1e-6;
 }
 
-template <typename T>
-static inline void CalcL1LossWithWeight(const Tensor& x, const Tensor& y,
-                                        const Tensor& weight,
-                                        const T loss_weight, T* loss) {
-  int n = x.dims()[0];
-  int stride = x.numel() / n;
-  const T* x_data = x.data<T>();
-  const T* y_data = y.data<T>();
-  const T* weight_data = weight.data<T>();
-
-  for (int i = 0; i < n; i++) {
-    for (int j = 0; j < stride; j++) {
-      loss[i] += fabs(y_data[j] - x_data[j]) * weight_data[j] * loss_weight;
-    }
-    x_data += stride;
-    y_data += stride;
-    weight_data += stride;
-  }
-}
-
-template <typename T>
-static void CalcL1LossGradWithWeight(const T* loss_grad, Tensor* grad,
-                                     const Tensor& x, const Tensor& y,
-                                     const Tensor& weight) {
-  int n = x.dims()[0];
-  int stride = x.numel() / n;
-  T* grad_data = grad->data<T>();
-  const T* x_data = x.data<T>();
-  const T* y_data = y.data<T>();
-  const T* weight_data = weight.data<T>();
-
-  for (int i = 0; i < n; i++) {
-    for (int j = 0; j < stride; j++) {
-      grad_data[j] = weight_data[j] * loss_grad[i];
-      if (x_data[j] < y_data[j]) grad_data[j] *= -1.0;
-    }
-    grad_data += stride;
-    x_data += stride;
-    y_data += stride;
-    weight_data += stride;
-  }
-}
-
-template <typename T>
-static inline void CalcMSEWithWeight(const Tensor& x, const Tensor& y,
-                                     const Tensor& weight, const T loss_weight,
-                                     T* loss) {
-  int n = x.dims()[0];
-  int stride = x.numel() / n;
-  const T* x_data = x.data<T>();
-  const T* y_data = y.data<T>();
-  const T* weight_data = weight.data<T>();
-
-  for (int i = 0; i < n; i++) {
-    for (int j = 0; j < stride; j++) {
-      loss[i] += pow(y_data[j] - x_data[j], 2) * weight_data[j] * loss_weight;
-    }
-    x_data += stride;
-    y_data += stride;
-    weight_data += stride;
-  }
-}
-
-template <typename T>
-static void CalcMSEGradWithWeight(const T* loss_grad, Tensor* grad,
-                                  const Tensor& x, const Tensor& y,
-                                  const Tensor& weight) {
-  int n = x.dims()[0];
-  int stride = x.numel() / n;
-  T* grad_data = grad->data<T>();
-  const T* x_data = x.data<T>();
-  const T* y_data = y.data<T>();
-  const T* weight_data = weight.data<T>();
-
-  for (int i = 0; i < n; i++) {
-    for (int j = 0; j < stride; j++) {
-      grad_data[j] =
-          2.0 * weight_data[j] * (x_data[j] - y_data[j]) * loss_grad[i];
-    }
-    grad_data += stride;
-    x_data += stride;
-    y_data += stride;
-    weight_data += stride;
-  }
-}
-
-template <typename T>
-static inline void CalcSCEWithWeight(const Tensor& x, const Tensor& label,
-                                     const Tensor& weight, const T loss_weight,
-                                     T* loss) {
-  int n = x.dims()[0];
-  int stride = x.numel() / n;
-  const T* x_data = x.data<T>();
-  const T* label_data = label.data<T>();
-  const T* weight_data = weight.data<T>();
-
-  for (int i = 0; i < n; i++) {
-    for (int j = 0; j < stride; j++) {
-      T term1 = (x_data[j] > 0) ? x_data[j] : 0;
-      T term2 = x_data[j] * label_data[j];
-      T term3 = std::log(1.0 + std::exp(-std::abs(x_data[j])));
-      loss[i] += (term1 - term2 + term3) * weight_data[j] * loss_weight;
-    }
-    x_data += stride;
-    label_data += stride;
-    weight_data += stride;
-  }
-}
-
-template <typename T>
-static inline void CalcSCEGradWithWeight(const T* loss_grad, Tensor* grad,
-                                         const Tensor& x, const Tensor& label,
-                                         const Tensor& weight) {
-  int n = x.dims()[0];
-  int stride = x.numel() / n;
-  T* grad_data = grad->data<T>();
-  const T* x_data = x.data<T>();
-  const T* label_data = label.data<T>();
-  const T* weight_data = weight.data<T>();
-
-  for (int i = 0; i < n; i++) {
-    for (int j = 0; j < stride; j++) {
-      grad_data[j] = (1.0 / (1.0 + std::exp(-x_data[j])) - label_data[j]) *
-                     weight_data[j] * loss_grad[i];
-    }
-    grad_data += stride;
-    x_data += stride;
-    label_data += stride;
-    weight_data += stride;
-  }
-}
-
-// template <typename T>
-// static void SplitPredResult(const Tensor& input, Tensor* pred_conf,
-//                             Tensor* pred_class, Tensor* pred_x, Tensor*
-//                             pred_y,
-//                             Tensor* pred_w, Tensor* pred_h,
-//                             const int anchor_num, const int class_num) {
-//   const int n = input.dims()[0];
-//   const int h = input.dims()[2];
-//   const int w = input.dims()[3];
-//   const int box_attr_num = 5 + class_num;
-//
-//   auto input_t = EigenTensor<T, 4>::From(input);
-//   auto pred_conf_t = EigenTensor<T, 4>::From(*pred_conf);
-//   auto pred_class_t = EigenTensor<T, 5>::From(*pred_class);
-//   auto pred_x_t = EigenTensor<T, 4>::From(*pred_x);
-//   auto pred_y_t = EigenTensor<T, 4>::From(*pred_y);
-//   auto pred_w_t = EigenTensor<T, 4>::From(*pred_w);
-//   auto pred_h_t = EigenTensor<T, 4>::From(*pred_h);
-//
-//   for (int i = 0; i < n; i++) {
-//     for (int an_idx = 0; an_idx < anchor_num; an_idx++) {
-//       for (int j = 0; j < h; j++) {
-//         for (int k = 0; k < w; k++) {
-//           pred_x_t(i, an_idx, j, k) = input_t(i, box_attr_num * an_idx, j,
-//           k);
-//           pred_y_t(i, an_idx, j, k) =
-//               input_t(i, box_attr_num * an_idx + 1, j, k);
-//           pred_w_t(i, an_idx, j, k) =
-//               input_t(i, box_attr_num * an_idx + 2, j, k);
-//           pred_h_t(i, an_idx, j, k) =
-//               input_t(i, box_attr_num * an_idx + 3, j, k);
-//
-//           pred_conf_t(i, an_idx, j, k) =
-//               input_t(i, box_attr_num * an_idx + 4, j, k);
-//
-//           for (int c = 0; c < class_num; c++) {
-//             pred_class_t(i, an_idx, j, k, c) =
-//                 input_t(i, box_attr_num * an_idx + 5 + c, j, k);
-//           }
-//         }
-//       }
-//     }
-//   }
-// }
-
 template <typename T>
 static T CalcBoxIoU(std::vector<T> box1, std::vector<T> box2) {
   T b1_x1 = box1[0] - box1[2] / 2;
@@ -242,30 +66,36 @@ static void PreProcessGTBox(const Tensor& gt_box, const Tensor& gt_label,
                             Tensor* tconf, Tensor* tclass) {
   const int n = gt_box.dims()[0];
   const int b = gt_box.dims()[1];
-  const int anchor_num = anchors.size() / 2;
-  auto gt_box_t = EigenTensor<T, 3>::From(gt_box);
-  auto gt_label_t = EigenTensor<int, 2>::From(gt_label);
-  auto conf_mask_t = EigenTensor<T, 4>::From(*conf_mask).setConstant(1.0);
-  auto obj_mask_t = EigenTensor<T, 4>::From(*obj_mask).setConstant(0.0);
-  auto tx_t = EigenTensor<T, 4>::From(*tx).setConstant(0.0);
-  auto ty_t = EigenTensor<T, 4>::From(*ty).setConstant(0.0);
-  auto tw_t = EigenTensor<T, 4>::From(*tw).setConstant(0.0);
-  auto th_t = EigenTensor<T, 4>::From(*th).setConstant(0.0);
-  auto tweight_t = EigenTensor<T, 4>::From(*tweight).setConstant(0.0);
-  auto tconf_t = EigenTensor<T, 4>::From(*tconf).setConstant(0.0);
-  auto tclass_t = EigenTensor<T, 5>::From(*tclass).setConstant(0.0);
+  const int an_num = anchors.size() / 2;
+  const int h = tclass->dims()[2];
+  const int w = tclass->dims()[3];
+  const int class_num = tclass->dims()[4];
+
+  const T* gt_box_data = gt_box.data<T>();
+  const int* gt_label_data = gt_label.data<int>();
+  T* conf_mask_data = conf_mask->data<T>();
+  T* obj_mask_data = obj_mask->data<T>();
+  T* tx_data = tx->data<T>();
+  T* ty_data = ty->data<T>();
+  T* tw_data = tw->data<T>();
+  T* th_data = th->data<T>();
+  T* tweight_data = tweight->data<T>();
+  T* tconf_data = tconf->data<T>();
+  T* tclass_data = tclass->data<T>();
 
   for (int i = 0; i < n; i++) {
     for (int j = 0; j < b; j++) {
-      if (isZero<T>(gt_box_t(i, j, 2)) && isZero<T>(gt_box_t(i, j, 3))) {
+      int box_idx = (i * b + j) * 4;
+      if (isZero<T>(gt_box_data[box_idx + 2]) &&
+          isZero<T>(gt_box_data[box_idx + 3])) {
         continue;
       }
 
-      int cur_label = gt_label_t(i, j);
-      T gx = gt_box_t(i, j, 0) * grid_size;
-      T gy = gt_box_t(i, j, 1) * grid_size;
-      T gw = gt_box_t(i, j, 2) * input_size;
-      T gh = gt_box_t(i, j, 3) * input_size;
+      int cur_label = gt_label_data[i * b + j];
+      T gx = gt_box_data[box_idx] * grid_size;
+      T gy = gt_box_data[box_idx + 1] * grid_size;
+      T gw = gt_box_data[box_idx + 2] * input_size;
+      T gh = gt_box_data[box_idx + 3] * input_size;
       int gi = static_cast<int>(gx);
       int gj = static_cast<int>(gy);
 
@@ -273,7 +103,7 @@ static void PreProcessGTBox(const Tensor& gt_box, const Tensor& gt_label,
       T iou;
       int best_an_index = -1;
       std::vector<T> gt_box_shape({0, 0, gw, gh});
-      for (int an_idx = 0; an_idx < anchor_num; an_idx++) {
+      for (int an_idx = 0; an_idx < an_num; an_idx++) {
         std::vector<T> anchor_shape({0, 0, static_cast<T>(anchors[2 * an_idx]),
                                      static_cast<T>(anchors[2 * an_idx + 1])});
         iou = CalcBoxIoU<T>(gt_box_shape, anchor_shape);
@@ -282,19 +112,22 @@ static void PreProcessGTBox(const Tensor& gt_box, const Tensor& gt_label,
           best_an_index = an_idx;
         }
         if (iou > ignore_thresh) {
-          conf_mask_t(i, an_idx, gj, gi) = static_cast<T>(0.0);
+          int conf_idx = ((i * an_num + an_idx) * h + gj) * w + gi;
+          conf_mask_data[conf_idx] = static_cast<T>(0.0);
         }
       }
-      conf_mask_t(i, best_an_index, gj, gi) = static_cast<T>(1.0);
-      obj_mask_t(i, best_an_index, gj, gi) = static_cast<T>(1.0);
-      tx_t(i, best_an_index, gj, gi) = gx - gi;
-      ty_t(i, best_an_index, gj, gi) = gy - gj;
-      tw_t(i, best_an_index, gj, gi) = log(gw / anchors[2 * best_an_index]);
-      th_t(i, best_an_index, gj, gi) = log(gh / anchors[2 * best_an_index + 1]);
-      tweight_t(i, best_an_index, gj, gi) =
-          2.0 - gt_box_t(i, j, 2) * gt_box_t(i, j, 3);
-      tclass_t(i, best_an_index, gj, gi, cur_label) = 1;
-      tconf_t(i, best_an_index, gj, gi) = 1;
+      // NOTE(review): gi/gj/best_an_index/cur_label are not range-checked here.
+      int obj_idx = ((i * an_num + best_an_index) * h + gj) * w + gi;
+      conf_mask_data[obj_idx] = static_cast<T>(1.0);
+      obj_mask_data[obj_idx] = static_cast<T>(1.0);
+      tx_data[obj_idx] = gx - gi;
+      ty_data[obj_idx] = gy - gj;
+      tw_data[obj_idx] = log(gw / anchors[2 * best_an_index]);
+      th_data[obj_idx] = log(gh / anchors[2 * best_an_index + 1]);
+      tweight_data[obj_idx] =
+          2.0 - gt_box_data[box_idx + 2] * gt_box_data[box_idx + 3];
+      tconf_data[obj_idx] = static_cast<T>(1.0);
+      tclass_data[obj_idx * class_num + cur_label] = static_cast<T>(1.0);
     }
   }
 }
@@ -427,18 +260,26 @@ static void CalcYolov3Loss(T* loss_data, const Tensor& input, const Tensor& tx,
   const int class_num = tclass.dims()[4];
   const int grid_num = h * w;
 
+  // x/y terms: input_data at offsets 0 and grid_num, both via CalcSCE.
   CalcSCE<T>(loss_data, input_data, tx_data, tweight_data, obj_mask_data, n,
              an_num, grid_num, class_num, 1);
   CalcSCE<T>(loss_data, input_data + grid_num, ty_data, tweight_data,
              obj_mask_data, n, an_num, grid_num, class_num, 1);
+  // w/h terms: offsets 2 * grid_num and 3 * grid_num via CalcL1Loss, with
+  // the same tweight_data / obj_mask_data pair as the x/y calls.
   CalcL1Loss<T>(loss_data, input_data + 2 * grid_num, tw_data, tweight_data,
                 obj_mask_data, n, an_num, grid_num, class_num);
   CalcL1Loss<T>(loss_data, input_data + 3 * grid_num, th_data, tweight_data,
                 obj_mask_data, n, an_num, grid_num, class_num);
+  // Confidence term: offset 4 * grid_num against tconf_data; conf_mask_data
+  // takes both slots that hold tweight_data and obj_mask_data above.
   CalcSCE<T>(loss_data, input_data + 4 * grid_num, tconf_data, conf_mask_data,
              conf_mask_data, n, an_num, grid_num, class_num, 1);
+  // Class term: offset 5 * grid_num against tclass_data; the last argument
+  // is class_num, where the other CalcSCE calls here pass 1.
   CalcSCE<T>(loss_data, input_data + 5 * grid_num, tclass_data, obj_mask_data,
              obj_mask_data, n, an_num, grid_num, class_num, class_num);
+  // All six calls above are handed the same loss_data pointer.
 }
 
 template <typename T>
@@ -488,7 +329,7 @@ static void CalcYolov3LossGrad(T* input_grad_data, const Tensor& loss_grad,
                  obj_mask_data, n, an_num, grid_num, class_num, class_num);
 }
 
-template <typename T>
+template <typename DeviceContext, typename T>
 class Yolov3LossKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
@@ -517,6 +358,27 @@ class Yolov3LossKernel : public framework::OpKernel<T> {
     tweight.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
     tconf.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
     tclass.mutable_data<T>({n, an_num, h, w, class_num}, ctx.GetPlace());
+
+    // Give each scratch tensor a defined starting value before the
+    // PreProcessGTBox call right below writes into them.
+    const auto& dev_ctx = ctx.template device_context<DeviceContext>();
+    const T one = static_cast<T>(1.0);
+    const T zero = static_cast<T>(0.0);
+    math::SetConstant<DeviceContext, T> fill;
+
+    // conf_mask is the only buffer that starts at 1.
+    fill(dev_ctx, &conf_mask, one);
+
+    // Every other buffer starts at 0.
+    fill(dev_ctx, &obj_mask, zero);
+    fill(dev_ctx, &tx, zero);
+    fill(dev_ctx, &ty, zero);
+    fill(dev_ctx, &tw, zero);
+    fill(dev_ctx, &th, zero);
+    fill(dev_ctx, &tweight, zero);
+    fill(dev_ctx, &tconf, zero);
+    fill(dev_ctx, &tclass, zero);
+
     PreProcessGTBox<T>(*gt_box, *gt_label, ignore_thresh, anchors, input_size,
                        h, &conf_mask, &obj_mask, &tx, &ty, &tw, &th, &tweight,
                        &tconf, &tclass);
@@ -528,7 +390,7 @@ class Yolov3LossKernel : public framework::OpKernel<T> {
   }
 };
 
-template <typename T>
+template <typename DeviceContext, typename T>
 class Yolov3LossGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
@@ -559,6 +421,27 @@ class Yolov3LossGradKernel : public framework::OpKernel<T> {
     tweight.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
     tconf.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
     tclass.mutable_data<T>({n, an_num, h, w, class_num}, ctx.GetPlace());
+
+    // Start every scratch tensor from a defined value; the
+    // PreProcessGTBox call right below then writes into them.
+    const auto& dev_ctx = ctx.template device_context<DeviceContext>();
+    const T one = static_cast<T>(1.0);
+    const T zero = static_cast<T>(0.0);
+    math::SetConstant<DeviceContext, T> fill;
+
+    // conf_mask is the only buffer that starts at 1.
+    fill(dev_ctx, &conf_mask, one);
+
+    // Every other buffer starts at 0.
+    fill(dev_ctx, &obj_mask, zero);
+    fill(dev_ctx, &tx, zero);
+    fill(dev_ctx, &ty, zero);
+    fill(dev_ctx, &tw, zero);
+    fill(dev_ctx, &th, zero);
+    fill(dev_ctx, &tweight, zero);
+    fill(dev_ctx, &tconf, zero);
+    fill(dev_ctx, &tclass, zero);
+
     PreProcessGTBox<T>(*gt_box, *gt_label, ignore_thresh, anchors, input_size,
                        h, &conf_mask, &obj_mask, &tx, &ty, &tw, &th, &tweight,
                        &tconf, &tclass);
diff --git a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py
index cf7e2c5289..862e77e663 100644
--- a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py
+++ b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py
@@ -197,12 +197,12 @@ class TestYolov3LossOp(OpTest):
             max_relative_error=0.31)
 
     def initTestCase(self):
-        self.anchors = [12, 12, 11, 13]
+        self.anchors = [12, 12]
         self.class_num = 5
         self.ignore_thresh = 0.5
         self.input_size = 416
-        self.x_shape = (3, len(self.anchors) // 2 * (5 + self.class_num), 5, 5)
-        self.gtbox_shape = (3, 5, 4)
+        self.x_shape = (1, len(self.anchors) // 2 * (5 + self.class_num), 3, 3)
+        self.gtbox_shape = (1, 5, 4)
 
 
 if __name__ == "__main__":