From 878dd88f6107fb81a9c9db99abad0f770b8c9d1b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=AD=A6=E6=AF=85?= <typhoonzero1986@gmail.com>
Date: Tue, 31 Oct 2017 15:37:23 +0800
Subject: [PATCH] Refine evaluator op types (#5208)

* refine evaluator op types

* update

* follow comments

* update

* fix v2 mnist case

* fix v2 mnist case

* update

* update
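
Overview: this change replaces the evaluator ops' single `Inference` input
with the two outputs of the topk op, `Out` (top-k values) and `Indices`
(their positions), registers the kernels for float/double inference data,
and reads indices and labels as int64. Below is a minimal sketch of the
resulting Python-side wiring, assuming the usual v2 framework `data`/`fc`
layers (illustrative names only, not part of this patch):

    import paddle.v2.framework.layers as layers

    # hypothetical classifier: softmax output and int64 labels
    image = layers.data(name="image", shape=[784], data_type="float32")
    label = layers.data(name="label", shape=[1], data_type="int64")
    predict = layers.fc(input=image, size=10, act="softmax")

    # accuracy() runs topk(k=1) internally and feeds the accuracy op
    # with {"Out": topk_out, "Indices": topk_indices, "Label": label}
    acc = layers.accuracy(input=predict, label=label, k=1)
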
---
 paddle/operators/accuracy_op.cc               | 39 +++++++++++++------
 paddle/operators/accuracy_op.cu               | 24 +++++++-----
 paddle/operators/accuracy_op.h                |  9 +++--
 paddle/operators/auc_op.cc                    | 38 ++++++++++++------
 paddle/operators/auc_op.h                     | 37 ++++++++----------
 python/paddle/v2/framework/layers.py          |  7 +++-
 .../v2/framework/tests/test_accuracy_op.py    | 11 +++---
 .../paddle/v2/framework/tests/test_auc_op.py  | 16 ++++----
 8 files changed, 108 insertions(+), 73 deletions(-)

diff --git a/paddle/operators/accuracy_op.cc b/paddle/operators/accuracy_op.cc
index 88958e1634..2a2a1e9cfd 100644
--- a/paddle/operators/accuracy_op.cc
+++ b/paddle/operators/accuracy_op.cc
@@ -22,23 +22,35 @@ class AccuracyOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
   void InferShape(framework::InferShapeContext *ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("Inference"),
-                   "Input(Inference) of AccuracyOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Out"),
+                   "Input (Out) of accuracy op should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Indices"),
+                   "Input (Indices) of accuracy op should not be null.");
     PADDLE_ENFORCE(ctx->HasInput("Label"),
-                   "Input(Label) of AccuracyOp should not be null.");
+                   "Input (Label) of accuracy op should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput("Accuracy"),
-                   "Output(Accuracy) of AccuracyOp should not be null.");
+                   "Output (Accuracy) of AccuracyOp should not be null.");
 
-    auto inference_dim = ctx->GetInputDim("Inference");
+    auto inference_dim = ctx->GetInputDim("Out");
     auto label_dim = ctx->GetInputDim("Label");
+    // Assume indices has the same shape as inference, because it is
+    // the output of topk.
 
     PADDLE_ENFORCE_EQ(label_dim.size(), 2, "label's rank must be 2.");
     PADDLE_ENFORCE_EQ(label_dim[1], 1, "label's second dimension must be 1");
     PADDLE_ENFORCE_EQ(inference_dim[0], label_dim[0],
-                      "inference size must be the same as label size");
+                      "the inference tensor's num_rows must be"
+                      " the same as label.");
 
     ctx->SetOutputDim("Accuracy", {1});
-    ctx->ShareLoD("Inference", /*->*/ "Accuracy");
+    ctx->ShareLoD("Out", /*->*/ "Accuracy");
+  }
+
+ protected:
+  // The kernel's data type is determined by the inference data ("Out").
+  framework::DataType IndicateDataType(
+      const framework::ExecutionContext &ctx) const override {
+    return framework::ToDataType(ctx.Input<Tensor>("Out")->type());
   }
 };
 
@@ -48,7 +60,8 @@ class AccuracyOpMaker : public framework::OpProtoAndCheckerMaker {
                   framework::OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     // TODO(typhoonzero): support both inference value and indices.
-    AddInput("Inference", "topk(indices) the network output");
+    AddInput("Out", "topk (inferences) the network output");
+    AddInput("Indices", "topk (indices) the network output");
     AddInput("Label", "Label of the training data");
     // TODO(typhoonzero): AddInput("Weight", ...
     AddOutput("Accuracy", "The accuracy of current batch");
@@ -59,7 +72,7 @@ The accuracy is:
 ..  math::
 accuracy = \\frac{NumOfCorrectPredicts}{NumOfAllSamples}
 
-Both the input `Inference` and `Label` can carry the LoD (Level of Details)
+Both inputs `Out` and `Label` can carry the LoD (Level of Details)
 information, or not. But the output only shares the LoD with the input `Out`.
 )DOC");
   }
@@ -71,6 +84,8 @@ information, or not. But the output only shares the LoD with input `Inference`.
 namespace ops = paddle::operators;
 REGISTER_OPERATOR(accuracy, ops::AccuracyOp, ops::AccuracyOpMaker,
                   paddle::framework::EmptyGradOpMaker);
-REGISTER_OP_CPU_KERNEL(
-    accuracy, ops::AccuracyKernel<paddle::platform::CPUPlace, int>,
-    ops::AccuracyKernel<paddle::platform::CPUPlace, int64_t>);
+// FIXME(typhoonzero): the template type T is for inference data;
+// label data is always int64.
+REGISTER_OP_CPU_KERNEL(accuracy,
+                       ops::AccuracyKernel<paddle::platform::CPUPlace, float>,
+                       ops::AccuracyKernel<paddle::platform::CPUPlace, double>);
diff --git a/paddle/operators/accuracy_op.cu b/paddle/operators/accuracy_op.cu
index be58dfbd03..a0483f367e 100644
--- a/paddle/operators/accuracy_op.cu
+++ b/paddle/operators/accuracy_op.cu
@@ -21,9 +21,10 @@ namespace paddle {
 namespace operators {
 using platform::PADDLE_CUDA_NUM_THREADS;
 
-template <typename T, int BlockSize>
-__global__ void AccuracyCudaKernel(const int N, const int D, const T* Xdata,
-                                   const T* labeldata, float* accuracy) {
+template <int BlockSize>
+__global__ void AccuracyCudaKernel(const int N, const int D,
+                                   const int64_t* Xdata,
+                                   const int64_t* labeldata, float* accuracy) {
   int count = 0;
   __shared__ int total[BlockSize];
 
@@ -52,13 +53,14 @@ class AccuracyOpCUDAKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext& ctx) const override {
     PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                    "It must use GPUPlace.");
-    auto* inference = ctx.Input<Tensor>("Inference");
+    auto* inference = ctx.Input<Tensor>("Out");
+    auto* indices = ctx.Input<Tensor>("Indices");
     auto* label = ctx.Input<Tensor>("Label");
     auto* accuracy = ctx.Output<Tensor>("Accuracy");
     // FIXME(typhoonzero): only support indices currently
     // if we add support for output values, how do we detect the data type?
-    const T* inference_data = inference->data<T>();
-    const T* label_data = label->data<T>();
+    const int64_t* indices_data = indices->data<int64_t>();
+    const int64_t* label_data = label->data<int64_t>();
     float* accuracy_data = accuracy->mutable_data<float>(ctx.GetPlace());
 
     size_t num_samples = inference->dims()[0];
@@ -69,11 +71,11 @@ class AccuracyOpCUDAKernel : public framework::OpKernel<T> {
       return;
     }
 
-    AccuracyCudaKernel<T, PADDLE_CUDA_NUM_THREADS><<<
+    AccuracyCudaKernel<PADDLE_CUDA_NUM_THREADS><<<
         1, PADDLE_CUDA_NUM_THREADS, 0,
         reinterpret_cast<const platform::CUDADeviceContext&>(
             ctx.device_context())
-            .stream()>>>(num_samples, infer_width, inference_data, label_data,
+            .stream()>>>(num_samples, infer_width, indices_data, label_data,
                          accuracy_data);
   }
 };
@@ -81,5 +83,7 @@ class AccuracyOpCUDAKernel : public framework::OpKernel<T> {
 }  // namespace operators
 }  // namespace paddle
 
-REGISTER_OP_GPU_KERNEL(accuracy, paddle::operators::AccuracyOpCUDAKernel<int>,
-                       paddle::operators::AccuracyOpCUDAKernel<int64_t>);
+// FIXME(typhoonzero): the template type T is for inference data;
+// label data is always int64.
+REGISTER_OP_GPU_KERNEL(accuracy, paddle::operators::AccuracyOpCUDAKernel<float>,
+                       paddle::operators::AccuracyOpCUDAKernel<double>);
diff --git a/paddle/operators/accuracy_op.h b/paddle/operators/accuracy_op.h
index 12c6b9aac8..1968b53d19 100644
--- a/paddle/operators/accuracy_op.h
+++ b/paddle/operators/accuracy_op.h
@@ -38,14 +38,15 @@ template <typename Place, typename T>
 class AccuracyKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    auto* inference = ctx.Input<Tensor>("Inference");
+    auto* inference = ctx.Input<Tensor>("Out");
+    auto* indices = ctx.Input<Tensor>("Indices");
     auto* label = ctx.Input<Tensor>("Label");
     auto* accuracy = ctx.Output<Tensor>("Accuracy");
 
     float* accuracy_data = accuracy->mutable_data<float>(ctx.GetPlace());
 
-    const T* inference_data = inference->data<T>();
-    const T* label_data = label->data<T>();
+    const int64_t* indices_data = indices->data<int64_t>();
+    const int64_t* label_data = label->data<int64_t>();
 
     size_t num_samples = inference->dims()[0];
     size_t class_dim = inference->dims()[1];
@@ -60,7 +61,7 @@ class AccuracyKernel : public framework::OpKernel<T> {
     for (size_t i = 0; i < num_samples; ++i) {
       PADDLE_ENFORCE_GE(label_data[i], 0, "label must >= 0");
       for (size_t j = 0; j < class_dim; ++j) {
-        if (inference_data[i * class_dim + j] == label_data[i]) {
+        if (indices_data[i * class_dim + j] == label_data[i]) {
           ++num_correct;
           break;
         }
diff --git a/paddle/operators/auc_op.cc b/paddle/operators/auc_op.cc
index cf3dbc5d10..f5784922af 100644
--- a/paddle/operators/auc_op.cc
+++ b/paddle/operators/auc_op.cc
@@ -23,18 +23,26 @@ class AucOp : public framework::OperatorWithKernel {
 
  protected:
   void InferShape(framework::InferShapeContext *ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("Inference"),
-                   "Input of Inference must be initialized.");
+    PADDLE_ENFORCE(ctx->HasInput("Out"), "Input of Out must be initialized.");
+    PADDLE_ENFORCE(ctx->HasInput("Indices"),
+                   "Input of Indices must be initialized.");
     PADDLE_ENFORCE(ctx->HasInput("Label"),
                    "Input of Label must be initialized.");
-    auto inference_dim = ctx->GetInputDim("Inference");
-    auto label_dim = ctx->GetInputDim("Label");
+    auto inference_height = ctx->GetInputDim("Out")[0];
+    auto label_height = ctx->GetInputDim("Label")[0];
 
-    PADDLE_ENFORCE_EQ(inference_dim, label_dim,
-                      "inference and label should have same shape");
+    PADDLE_ENFORCE_EQ(inference_height, label_height,
+                      "Out and Label should have same height.");
 
     ctx->SetOutputDim("AUC", {1});
-    ctx->ShareLoD("Inference", /*->*/ "AUC");
+    ctx->ShareLoD("Out", /*->*/ "AUC");
+  }
+
+ protected:
+  // The kernel's data type is determined by the inference data ("Out").
+  framework::DataType IndicateDataType(
+      const framework::ExecutionContext &ctx) const override {
+    return framework::ToDataType(ctx.Input<Tensor>("Out")->type());
   }
 };
 
@@ -42,12 +50,18 @@ class AucOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   AucOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("Inference",
-             "A floating point tensor of arbitrary shape and whose values"
-             "are in the range [0, 1].");
+    AddInput("Out",
+             "A floating point 2D tensor, values are in the range [0, 1]."
+             "Each row is descend sorted. This input should be the"
+             "output of topk."
+             "Typically, this tensor indicates the probability of each label");
+    AddInput("Indices",
+             "An int 2D tensor, indicating the indices of original"
+             "tensor before sort. Typically, this tensor indicates which label"
+             "the probability stands for.");
     AddInput("Label",
-             "A tensor whose shape matches "
-             "Inference. Will be cast to bool.");
+             "A 2D int tensor indicating the label of the training data."
+             "The height is batch size and width is always 1.");
     // TODO(typhoonzero): support weight input
     AddOutput("AUC",
               "A scalar representing the "
diff --git a/paddle/operators/auc_op.h b/paddle/operators/auc_op.h
index be6ef29d5f..e5ac57b038 100644
--- a/paddle/operators/auc_op.h
+++ b/paddle/operators/auc_op.h
@@ -29,7 +29,7 @@ template <typename Place, typename T>
 class AucKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    auto* inference = ctx.Input<Tensor>("Inference");
+    auto* inference = ctx.Input<Tensor>("Out");
     auto* label = ctx.Input<Tensor>("Label");
     auto* auc = ctx.Output<Tensor>("AUC");
 
@@ -46,18 +46,11 @@ class AucKernel : public framework::OpKernel<T> {
     thresholds_list[0] = 0.0f - kEpsilon;
     thresholds_list[num_thresholds - 1] = 1.0f + kEpsilon;
 
-    size_t num_samples = inference->numel();
+    size_t batch_size = inference->dims()[0];
+    size_t inference_width = inference->dims()[1];
 
     const T* inference_data = inference->data<T>();
-    Tensor label_casted;
-    label_casted.Resize(label->dims());
-    bool* label_casted_data = label_casted.mutable_data<bool>(ctx.GetPlace());
-
-    const int* label_data = label->data<int>();
-    // cast label_data to bool
-    for (size_t i = 0; i < num_samples; i++) {
-      label_casted_data[i] = static_cast<bool>(label_data[i]);
-    }
+    const int64_t* label_data = label->data<int64_t>();
 
     // Create local tensor for storing the curve: TP, FN, TN, FP
     // TODO(typhoonzero): use eigen op to calculate these values.
@@ -68,23 +61,27 @@ class AucKernel : public framework::OpKernel<T> {
     true_negative.Resize({num_thresholds});
     false_positive.Resize({num_thresholds});
 
-    int* tp_data = true_positive.mutable_data<int>(ctx.GetPlace());
-    int* fn_data = false_negative.mutable_data<int>(ctx.GetPlace());
-    int* tn_data = true_negative.mutable_data<int>(ctx.GetPlace());
-    int* fp_data = false_positive.mutable_data<int>(ctx.GetPlace());
+    int64_t* tp_data = true_positive.mutable_data<int64_t>(ctx.GetPlace());
+    int64_t* fn_data = false_negative.mutable_data<int64_t>(ctx.GetPlace());
+    int64_t* tn_data = true_negative.mutable_data<int64_t>(ctx.GetPlace());
+    int64_t* fp_data = false_positive.mutable_data<int64_t>(ctx.GetPlace());
 
     for (int idx_thresh = 0; idx_thresh < num_thresholds; idx_thresh++) {
       // calculate TP, FN, TN, FP for the current threshold
-      int tp = 0, fn = 0, tn = 0, fp = 0;
-      for (size_t i = 0; i < num_samples; i++) {
-        if (label_casted_data[i]) {
-          if (inference_data[i] >= (thresholds_list[idx_thresh])) {
+      int64_t tp = 0, fn = 0, tn = 0, fp = 0;
+      for (size_t i = 0; i < batch_size; i++) {
+        // NOTE: label_data is used as bool; labels > 0 are treated as true.
+        if (label_data[i]) {
+          // use the first (max) value in each row
+          if (inference_data[i * inference_width] >=
+              (thresholds_list[idx_thresh])) {
             tp++;
           } else {
             fn++;
           }
         } else {
-          if (inference_data[i] >= (thresholds_list[idx_thresh])) {
+          if (inference_data[i * inference_width] >=
+              (thresholds_list[idx_thresh])) {
             fp++;
           } else {
             tn++;
diff --git a/python/paddle/v2/framework/layers.py b/python/paddle/v2/framework/layers.py
index 4727d139a2..6451d11e2b 100644
--- a/python/paddle/v2/framework/layers.py
+++ b/python/paddle/v2/framework/layers.py
@@ -243,8 +243,11 @@ def accuracy(input, label, k=1, **kwargs):
     acc_out = helper.create_tmp_variable(dtype=acc_out_dtype)
     helper.append_op(
         type="accuracy",
-        inputs={"Inference": [topk_indices],
-                "Label": [label]},
+        inputs={
+            "Out": [topk_out],
+            "Indices": [topk_indices],
+            "Label": [label]
+        },
         outputs={"Accuracy": [acc_out]})
     return acc_out
 
diff --git a/python/paddle/v2/framework/tests/test_accuracy_op.py b/python/paddle/v2/framework/tests/test_accuracy_op.py
index f17edd44ae..6536c297e8 100644
--- a/python/paddle/v2/framework/tests/test_accuracy_op.py
+++ b/python/paddle/v2/framework/tests/test_accuracy_op.py
@@ -7,13 +7,14 @@ class TestAccuracyOp(OpTest):
     def setUp(self):
         self.op_type = "accuracy"
         n = 8192
-        infer = np.random.randint(0, 2, (n, 1)).astype("int")
-        label = np.random.randint(0, 2, (n, 1)).astype("int")
-        self.inputs = {'Inference': infer, "Label": label}
+        infer = np.random.random((n, 1)).astype("float32")
+        indices = np.random.randint(0, 2, (n, 1))
+        label = np.random.randint(0, 2, (n, 1))
+        self.inputs = {'Out': infer, 'Indices': indices, "Label": label}
         num_correct = 0
         for rowid in xrange(n):
-            for ele in infer[rowid]:
-                if ele == label[rowid][0]:
+            for ele in indices[rowid]:
+                if ele == label[rowid]:
                     num_correct += 1
                     break
         self.outputs = {
diff --git a/python/paddle/v2/framework/tests/test_auc_op.py b/python/paddle/v2/framework/tests/test_auc_op.py
index 65f679cfcc..26ea905d88 100644
--- a/python/paddle/v2/framework/tests/test_auc_op.py
+++ b/python/paddle/v2/framework/tests/test_auc_op.py
@@ -6,10 +6,11 @@ from op_test import OpTest
 class TestAucOp(OpTest):
     def setUp(self):
         self.op_type = "auc"
-        pred = np.random.random((128)).astype("float32")
-        labels = np.random.randint(0, 2, (128, ))
+        pred = np.random.random((128, 2)).astype("float32")
+        indices = np.random.randint(0, 2, (128, 2))
+        labels = np.random.randint(0, 2, (128, 1))
         num_thresholds = 200
-        self.inputs = {'Inference': pred, 'Label': labels}
+        self.inputs = {'Out': pred, 'Indices': indices, 'Label': labels}
         self.attrs = {'curve': 'ROC', 'num_thresholds': num_thresholds}
         # NOTE: sklearn use a different way to generate thresholds
         #       which will cause the result differs slightly:
@@ -31,12 +32,12 @@ class TestAucOp(OpTest):
             tp, fn, tn, fp = 0, 0, 0, 0
             for i, lbl in enumerate(labels):
                 if lbl:
-                    if pred[i] >= thresh:
+                    if pred[i, 0] >= thresh:
                         tp += 1
                     else:
                         fn += 1
                 else:
-                    if pred[i] >= thresh:
+                    if pred[i, 0] >= thresh:
                         fp += 1
                     else:
                         tn += 1
@@ -62,6 +63,5 @@ class TestAucOp(OpTest):
         self.check_output()
 
 
-# TODO(typhoonzero): add this back till we fix it
-#if __name__ == "__main__":
-#    unittest.main()
+if __name__ == "__main__":
+    unittest.main()