From 06c7c8c80e2c843afb7c5b156766533a5a389be9 Mon Sep 17 00:00:00 2001 From: yangyaming Date: Thu, 26 Oct 2017 11:59:54 +0800 Subject: [PATCH 1/5] Add CPU kernel. --- paddle/operators/precision_recall_op.cc | 118 ++++++++++++++++++ paddle/operators/precision_recall_op.h | 159 ++++++++++++++++++++++++ 2 files changed, 277 insertions(+) create mode 100644 paddle/operators/precision_recall_op.cc create mode 100644 paddle/operators/precision_recall_op.h diff --git a/paddle/operators/precision_recall_op.cc b/paddle/operators/precision_recall_op.cc new file mode 100644 index 0000000000..22eaa3f36e --- /dev/null +++ b/paddle/operators/precision_recall_op.cc @@ -0,0 +1,118 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +namespace paddle { +namespace operators { + +class PrecisionRecallOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + // may contains weights and StatesInfo + PADDLE_ENFORCE(ctx->HasInput("Predictions"), + "Input(Predictions) should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Labels"), + "Input(Labels) should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("BatchMetrics"), + "Output(BatchMetrics) should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("AccumMetrics"), + "Output(AccumMetrics) should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("AccumStatesInfo"), + "Output(AccumStatesInfo) should not be null."); + + auto predictions_dims = ctx->GetInputDim("Predictions"); + auto labels_dims = ctx->GetInputDim("Labels"); + + if (ctx->HasInput("Weights")) { + auto weights_dims = ctx->GetInputDim("Weights"); + PADDLE_ENFORCE_EQ(weights_dims, {predictions_dims[0], 1}, + "The shape of Input(Weights) should be " + "[batch_size, 1]."); + } + if (ctx->HasInput("StatesInfo")) { + auto states_dims = ctx->GetInputDim("StatesInfo"); + PADDLE_ENFORCE_EQ(states_dims, {predictions_dims[1], 4}, + "The shape of Input(StatesInfo) should be " + "[class_number, 4]."); + } + PADDLE_ENFORCE_EQ(predictions_dims[0], labels_dims[0], + "The 1st dimension of Input(Predictions) and " + "Input(Labels) both are batch_size and the shape should " + "be the same."); + PADDLE_ENFORCE_EQ(labels_dims[1], 1, + "The 2nd dimension of Input(Labels) " + "contains instance label and the shape should be equal " + "to 1"); + PADDLE_ENFORCE_GE(predictions_dims[1], 1, + "The shape of Input(Predictions)'s 2nd dimension is " + "equal to class number and should be at least 1."); + + // Layouts of BatchMetrics and AccumMetrics both are: + // [ + // macro average precision, macro average recall, macro average F1 score, + // micro average precision, micro average recall, 
micro average F1 score + // ] + ctx->SetOutputDim("BatchMetrics", {6}); + ctx->SetOutputDim("AccumMetrics", {6}); + // Shape of AccumStatesInfo is [class_number, 4] + // The layout of each row is: + // [ TP, FP, TN, FN ] + ctx->SetOutputDim("AccumStatesInfo", {predictions_dims[1], 4}); + } +}; + +class PrecisionRecallOpMaker : public framework::OpProtoAndCheckerMaker { + public: + PrecisionRecallOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("Predictions", + "(Tensor, default Tensor), a 2-D tensor with shape N x D, " + "where N is the batch size and D is the number of classes. " + "Each row contains probabilities for an instance which computed " + "by the previous operator."); + AddInput("Labels", + "(Tensor, default Tensor), a 2-D tensor with shape N x 1, " + "where N is the batch size. Each element is a label and the " + "value should be in [0, class_number - 1]."); + AddInput("Weights", + "(Tensor, default Tensor), a 2-D tensor with shape N x 1, " + "where N is the batch size. This input is optional. If provided, " + "weight of instance would be considered when computing metrics.") + .AsDispensable(); + AddInput("StatesInfo", + "(Tensor, default Tensor), a 2-D tensor with shape D x 4, " + "where D is the number of classes. This input is optional. 
If " + "provided, current state will be accumulated to this state and " + "the accumulation state will be as the output state.") + .AsDispensable(); + + AddComment(R"DOC( +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_WITHOUT_GRADIENT(precision_recall, ops::PrecisionRecallOp, + ops::PrecisionRecallOpMaker); +REGISTER_OP_CPU_KERNEL( + precision_recall, + ops::PrecisionRecallKernel, + ops::PrecisionRecallKernel, + ops::PrecisionRecallKernel, + ops::PrecisionRecallKernel, diff --git a/paddle/operators/precision_recall_op.h b/paddle/operators/precision_recall_op.h new file mode 100644 index 0000000000..7ed5f2387e --- /dev/null +++ b/paddle/operators/precision_recall_op.h @@ -0,0 +1,159 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +template +using EigenMatrix = framework::EigenMatrix; + +enum StateVariable { TP = 0, FP, TN, FN }; + +template +class PrecisionRecallKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* in0 = ctx.Input("Predictions"); + auto* in1 = ctx.Input("Labels"); + auto* in2 = ctx.Input("Weights"); + auto* in3 = ctx.Input("StatesInfo"); + auto* out0 = ctx.Output("BatchMetrics"); + auto* out1 = ctx.Output("AccumMetrics"); + auto* out2 = ctx.Output("AccumStatesInfo"); + + const T* predictions_data = in0->data(); + const T* labels_data = in1->data(); + const T* weights_data = in2 ? in2->data() : nullptr; + const T* states_data = in3 ? in3->data() : nullptr; + T* batch_metrics_data = out0->mutable_data(ctx.GetPlace()); + T* accum_metrics_data = out1->mutable_data(ctx.GetPlace()); + out2->mutable_data(ctx.GetPlace()); + auto accum_states = EigenMatrix::From(*out2); + accum_states.setZero(); + T* accum_states_data = out2->data(ctx.GetPlace()); + + size_t sample_num = in0->dims()[0]; + size_t class_dim = in0->dims()[1]; + size_t state_var_num = 4; // TP FP TN FN + + // get states info for current batch + for (size_t i = 0; i < sample_num; ++i) { + size_t max_idx = 0; + T max_val = predictions_data[i * class_dim]; + for (size_t j = 1; j < class_dim; ++j) { + if (max_val < predictions_data[i * class_dim + j]) { + max_idx = j; + max_val = predictions_data[i * class_dim + j]; + } + } + + T w = weights_data ? 
weights_data[i] : 1.0; + if (max_idx == labels_data[i]) { + accum_states_data[max_idx * state_var_num + TP] += w; + for (size_t j = 0; j < class_dim; ++j) { + accum_states_data[j * state_var_num + TN] += w; + } + accum_states_data[max_idx * state_var_num + TN] -= w; + } else { + accum_states_data[labels_data[i] * state_var_num + FN] += w; + accum_states_data[max_idx * state_var_num + FP] += w; + for (size_t j = 0; j < class_dim; ++j) { + accum_states_data[j * state_var_num + TN] += w; + } + accum_states_data[max_idx * state_var_num + TN] -= w; + accum_states_data[labels_data[j] * state_var_num + TN] -= w; + } + } + + ComputeMetrics(accum_states_data, batch_metrics_data, state_var_num, + class_dim); + + if (states_data) { + for (size_t i = 0; i < class_dim; ++i) { + for (size_t j = 0; j < state_var_num; ++j) { + size_t idx = i * state_var_num + j; + accum_states_data[idx] += states_data[idx]; + } + } + } + + ComputeMetrics(accum_states_data, accum_metrics_data, state_var_num, + class_dim); + } + + // expose to be reused + static inline T CalcPrecision(T tp_count, T fp_count) { + if (tp_count > 0.0 || fp_count > 0.0) { + return tp_count / (tp_count + fp_count); + } + return 1.0; + } + + static inline T CalcRecall(T tp_count, T fn_count) { + if (tp_count > 0.0 || fn_count > 0.0) { + return tp_count / (tp_count + fn_count); + } + return 1.0 + } + + static inline T CalcF1Score(T precision, T recall) { + if (precision > 0.0 || recall > 0.0) { + return 2 * precision * recall / (precision + recall); + } + return 0.0; + } + + protected: + void ComputeMetrics(const T* states_data, T* metrics_data, + size_t state_var_num, size_t class_dim) { + T total_tp_count = 0; + T total_fp_count = 0; + T total_fn_count = 0; + T macro_avg_precision = 0.0; + T macro_avg_recall = 0.0; + + for (size_t i = 0; i < class_dim; ++i) { + T tp_count = states_data[i * state_var_num + TP]; + T fp_count = states_data[i * state_var_num + FP]; + T fn_count = states_data[i * state_var_num + FN]; + 
total_tp_count += tp_count; + total_fp_count += fp_count; + total_fn_count += fn_count; + macro_avg_precision += CalcPrecision(tp_count, fp_count); + macro_avg_recall += CalcRecall(tp_count, fn_count); + } + macro_avg_precision /= class_dim; + macro_avg_recall /= class_dim; + T macro_f1_score = CalcF1Score(macro_avg_precision, macro_avg_recall); + + T micro_avg_precision = CalcPrecision(total_tp_count, total_fp_count); + T micro_avg_recall = CalcRecall(total_tp_count, total_fn_count); + T micro_f1_score = CalcRecall(micro_avg_precision, micro_avg_recall); + + // fill metrics data + metrics_data[0] = macro_avg_precision; + metrics_data[1] = macro_avg_recall; + metrics_data[2] = macro_f1_score; + metrics_data[3] = micro_avg_precision; + metrics_data[4] = micro_avg_recall; + metrics_data[5] = micro_f1_score; + } +}; + +} // namespace operators +} // namespace paddle From 65dbbd57af4016953338b27e80aa05cfed62c220 Mon Sep 17 00:00:00 2001 From: yangyaming Date: Thu, 26 Oct 2017 22:42:44 +0800 Subject: [PATCH 2/5] Add and pass unittests. --- paddle/operators/precision_recall_op.cc | 21 ++- paddle/operators/precision_recall_op.h | 14 +- .../tests/test_precision_recall_op.py | 164 ++++++++++++++++++ 3 files changed, 188 insertions(+), 11 deletions(-) create mode 100644 python/paddle/v2/framework/tests/test_precision_recall_op.py diff --git a/paddle/operators/precision_recall_op.cc b/paddle/operators/precision_recall_op.cc index 22eaa3f36e..47a16b9461 100644 --- a/paddle/operators/precision_recall_op.cc +++ b/paddle/operators/precision_recall_op.cc @@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ +#include "paddle/operators/precision_recall_op.h" + namespace paddle { namespace operators { @@ -37,13 +39,15 @@ class PrecisionRecallOp : public framework::OperatorWithKernel { if (ctx->HasInput("Weights")) { auto weights_dims = ctx->GetInputDim("Weights"); - PADDLE_ENFORCE_EQ(weights_dims, {predictions_dims[0], 1}, + PADDLE_ENFORCE_EQ(weights_dims, + framework::make_ddim({predictions_dims[0], 1}), "The shape of Input(Weights) should be " "[batch_size, 1]."); } if (ctx->HasInput("StatesInfo")) { auto states_dims = ctx->GetInputDim("StatesInfo"); - PADDLE_ENFORCE_EQ(states_dims, {predictions_dims[1], 4}, + PADDLE_ENFORCE_EQ(states_dims, + framework::make_ddim({predictions_dims[1], 4}), "The shape of Input(StatesInfo) should be " "[class_number, 4]."); } @@ -71,6 +75,12 @@ class PrecisionRecallOp : public framework::OperatorWithKernel { // [ TP, FP, TN, FN ] ctx->SetOutputDim("AccumStatesInfo", {predictions_dims[1], 4}); } + + protected: + framework::DataType IndicateDataType( + const framework::ExecutionContext &ctx) const override { + return framework::ToDataType(ctx.Input("Predictions")->type()); + } }; class PrecisionRecallOpMaker : public framework::OpProtoAndCheckerMaker { @@ -98,6 +108,9 @@ class PrecisionRecallOpMaker : public framework::OpProtoAndCheckerMaker { "provided, current state will be accumulated to this state and " "the accumulation state will be as the output state.") .AsDispensable(); + AddOutput("BatchMetrics", ""); + AddOutput("AccumMetrics", ""); + AddOutput("AccumStatesInfo", ""); AddComment(R"DOC( )DOC"); @@ -113,6 +126,4 @@ REGISTER_OP_WITHOUT_GRADIENT(precision_recall, ops::PrecisionRecallOp, REGISTER_OP_CPU_KERNEL( precision_recall, ops::PrecisionRecallKernel, - ops::PrecisionRecallKernel, - ops::PrecisionRecallKernel, - ops::PrecisionRecallKernel, + ops::PrecisionRecallKernel); diff --git a/paddle/operators/precision_recall_op.h b/paddle/operators/precision_recall_op.h index 7ed5f2387e..3bc638ea44 100644 --- 
a/paddle/operators/precision_recall_op.h +++ b/paddle/operators/precision_recall_op.h @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" namespace paddle { namespace operators { @@ -37,7 +39,7 @@ class PrecisionRecallKernel : public framework::OpKernel { auto* out2 = ctx.Output("AccumStatesInfo"); const T* predictions_data = in0->data(); - const T* labels_data = in1->data(); + const int* labels_data = in1->data(); const T* weights_data = in2 ? in2->data() : nullptr; const T* states_data = in3 ? in3->data() : nullptr; T* batch_metrics_data = out0->mutable_data(ctx.GetPlace()); @@ -45,7 +47,7 @@ class PrecisionRecallKernel : public framework::OpKernel { out2->mutable_data(ctx.GetPlace()); auto accum_states = EigenMatrix::From(*out2); accum_states.setZero(); - T* accum_states_data = out2->data(ctx.GetPlace()); + T* accum_states_data = out2->data(); size_t sample_num = in0->dims()[0]; size_t class_dim = in0->dims()[1]; @@ -76,7 +78,7 @@ class PrecisionRecallKernel : public framework::OpKernel { accum_states_data[j * state_var_num + TN] += w; } accum_states_data[max_idx * state_var_num + TN] -= w; - accum_states_data[labels_data[j] * state_var_num + TN] -= w; + accum_states_data[labels_data[i] * state_var_num + TN] -= w; } } @@ -108,7 +110,7 @@ class PrecisionRecallKernel : public framework::OpKernel { if (tp_count > 0.0 || fn_count > 0.0) { return tp_count / (tp_count + fn_count); } - return 1.0 + return 1.0; } static inline T CalcF1Score(T precision, T recall) { @@ -120,7 +122,7 @@ class PrecisionRecallKernel : public framework::OpKernel { protected: void ComputeMetrics(const T* states_data, T* metrics_data, - size_t state_var_num, size_t class_dim) { + size_t state_var_num, size_t class_dim) const { T total_tp_count = 0; T total_fp_count = 0; T total_fn_count = 0; @@ -143,7 +145,7 @@ class 
PrecisionRecallKernel : public framework::OpKernel { T micro_avg_precision = CalcPrecision(total_tp_count, total_fp_count); T micro_avg_recall = CalcRecall(total_tp_count, total_fn_count); - T micro_f1_score = CalcRecall(micro_avg_precision, micro_avg_recall); + T micro_f1_score = CalcF1Score(micro_avg_precision, micro_avg_recall); // fill metrics data metrics_data[0] = macro_avg_precision; diff --git a/python/paddle/v2/framework/tests/test_precision_recall_op.py b/python/paddle/v2/framework/tests/test_precision_recall_op.py new file mode 100644 index 0000000000..33efd717d1 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_precision_recall_op.py @@ -0,0 +1,164 @@ +import unittest +import numpy as np +from op_test import OpTest + + +def calc_precision(tp_count, fp_count): + if tp_count > 0.0 or fp_count > 0.0: + return tp_count / (tp_count + fp_count) + return 1.0 + + +def calc_recall(tp_count, fn_count): + if tp_count > 0.0 or fn_count > 0.0: + return tp_count / (tp_count + fn_count) + return 1.0 + + +def calc_f1_score(precision, recall): + if precision > 0.0 or recall > 0.0: + return 2 * precision * recall / (precision + recall) + return 0.0 + + +def get_states(predictions, labels, weights=None): + ins_num = predictions.shape[0] + class_num = predictions.shape[1] + # TP FP TN FN + states = np.zeros((class_num, 4)).astype('float32') + for i in xrange(ins_num): + w = weights[i] if weights is not None else 1.0 + max_idx = np.argmax(predictions[i]) + if max_idx == labels[i][0]: + states[max_idx][0] += w + for j in xrange(class_num): + states[j][2] += w + states[max_idx][2] -= w + else: + states[labels[i][0]][3] += w + states[max_idx][1] += w + for j in xrange(class_num): + states[j][2] += w + states[labels[i][0]][2] -= w + states[max_idx][2] -= w + return states + + +def compute_metrics(states): + class_num = states.shape[0] + total_tp_count = 0.0 + total_fp_count = 0.0 + total_fn_count = 0.0 + macro_avg_precision = 0.0 + macro_avg_recall = 0.0 + for i in 
xrange(class_num): + total_tp_count += states[i][0] + total_fp_count += states[i][1] + total_fn_count += states[i][3] + macro_avg_precision += calc_precision(states[i][0], states[i][1]) + macro_avg_recall += calc_recall(states[i][0], states[i][3]) + metrics = [] + macro_avg_precision /= class_num + macro_avg_recall /= class_num + metrics.append(macro_avg_precision) + metrics.append(macro_avg_recall) + metrics.append(calc_f1_score(macro_avg_precision, macro_avg_recall)) + micro_avg_precision = calc_precision(total_tp_count, total_fp_count) + metrics.append(micro_avg_precision) + micro_avg_recall = calc_recall(total_tp_count, total_fn_count) + metrics.append(micro_avg_recall) + metrics.append(calc_f1_score(micro_avg_precision, micro_avg_recall)) + return np.array(metrics).astype('float32') + + +class TestPrecisionRecallOp_0(OpTest): + def setUp(self): + self.op_type = "precision_recall" + ins_num = 64 + class_num = 10 + predictions = np.random.uniform(0, 1.0, + (ins_num, class_num)).astype('float32') + labels = np.random.choice(xrange(class_num), ins_num).reshape( + (ins_num, 1)).astype('int32') + states = get_states(predictions, labels) + metrics = compute_metrics(states) + + self.inputs = {'Predictions': predictions, 'Labels': labels} + + self.outputs = { + 'BatchMetrics': metrics, + 'AccumMetrics': metrics, + 'AccumStatesInfo': states + } + + def test_check_output(self): + self.check_output() + + +class TestPrecisionRecallOp_1(OpTest): + def setUp(self): + self.op_type = "precision_recall" + ins_num = 64 + class_num = 10 + predictions = np.random.uniform(0, 1.0, + (ins_num, class_num)).astype('float32') + weights = np.random.uniform(0, 1.0, (ins_num, 1)).astype('float32') + predictions = np.random.random((ins_num, class_num)).astype('float32') + labels = np.random.choice(xrange(class_num), ins_num).reshape( + (ins_num, 1)).astype('int32') + + states = get_states(predictions, labels, weights) + metrics = compute_metrics(states) + self.inputs = { + 'Predictions': 
predictions, + 'Labels': labels, + 'Weights': weights + } + + self.outputs = { + 'BatchMetrics': metrics, + 'AccumMetrics': metrics, + 'AccumStatesInfo': states + } + + def test_check_output(self): + self.check_output() + + +class TestPrecisionRecallOp_2(OpTest): + def setUp(self): + self.op_type = "precision_recall" + ins_num = 64 + class_num = 10 + predictions = np.random.uniform(0, 1.0, + (ins_num, class_num)).astype('float32') + weights = np.random.uniform(0, 1.0, (ins_num, 1)).astype('float32') + predictions = np.random.random((ins_num, class_num)).astype('float32') + labels = np.random.choice(xrange(class_num), ins_num).reshape( + (ins_num, 1)).astype('int32') + states = np.random.randint(0, 30, (class_num, 4)).astype('float32') + + accum_states = get_states(predictions, labels, weights) + batch_metrics = compute_metrics(accum_states) + accum_states += states + accum_metrics = compute_metrics(accum_states) + + self.inputs = { + 'Predictions': predictions, + 'Labels': labels, + 'Weights': weights, + 'StatesInfo': states + } + + self.outputs = { + 'BatchMetrics': batch_metrics, + 'AccumMetrics': accum_metrics, + 'AccumStatesInfo': accum_states + } + + def test_check_output(self): + self.check_output() + + +if __name__ == '__main__': + unittest.main() From 97bfc0dfae147f5514251b077eb26a4ed831b890 Mon Sep 17 00:00:00 2001 From: yangyaming Date: Fri, 27 Oct 2017 11:05:57 +0800 Subject: [PATCH 3/5] Add comments. 
--- paddle/operators/precision_recall_op.cc | 50 +++++++++++++++++++++++-- 1 file changed, 46 insertions(+), 4 deletions(-) diff --git a/paddle/operators/precision_recall_op.cc b/paddle/operators/precision_recall_op.cc index 47a16b9461..24246907b1 100644 --- a/paddle/operators/precision_recall_op.cc +++ b/paddle/operators/precision_recall_op.cc @@ -22,7 +22,6 @@ class PrecisionRecallOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext *ctx) const override { - // may contains weights and StatesInfo PADDLE_ENFORCE(ctx->HasInput("Predictions"), "Input(Predictions) should not be null."); PADDLE_ENFORCE(ctx->HasInput("Labels"), @@ -108,11 +107,54 @@ class PrecisionRecallOpMaker : public framework::OpProtoAndCheckerMaker { "provided, current state will be accumulated to this state and " "the accumulation state will be as the output state.") .AsDispensable(); - AddOutput("BatchMetrics", ""); - AddOutput("AccumMetrics", ""); - AddOutput("AccumStatesInfo", ""); + AddOutput("BatchMetrics", + "(Tensor, default Tensor), a 1-D tensor with shape {6}. " + "This output tensor contains metrics for current batch data. " + "The layout is [macro average precision, macro average recall, " + "macro f1 score, micro average precision, micro average recall, " + "micro f1 score]"); + AddOutput("AccumMetrics", + "(Tensor, default Tensor), a 1-D tensor with shape {6}. " + "This output tensor contains metrics for accumulated data. " + "The layout is [macro average precision, macro average recall, " + "macro f1 score, micro average precision, micro average recall, " + "micro f1 score]"); + AddOutput("AccumStatesInfo", + "(Tensor, default Tensor), a 2-D tensor with shape D x 4, " + "where D is equal to class number. This output tensor contains " + "accumulated state variables used to compute metrics. 
The layout " + "for each class is [true positives, false positives, " + "true negatives, false negatives]."); AddComment(R"DOC( +When given 'Input(Predictions)' and 'Input(Labels)', this operator can be used +to compute various metrics including: + - macro average precision + - macro average recall + - macro f1 score + - micro average precision + - micro average recall + - micro f1 score + +To compute the above metrics, we need to statistic counts for true positives, +false positives and false negatives. Here count of true negatives is not +necessary, but statisticing it may provide potential usage and the cost is +trivial, so the operator also provides count of true negatives. + +We define state as a 2-D tensor with shape [class number, 4]. Each row of a +state contains statistic variables for corresponding class. Layout of each row +is: TP(true positives), FP(false positives), TN(true negatives), +FN(false negatives). If 'Input(Weights)' provided, TP, FP, TN, FN will be +calculated by given weight instead of instance count. + +This operator also supports metrics computing for cross-batch situation. To +achieve this, 'Input(StatesInfo)' should be provided. State of current batch +data will be accumulated to 'Input(StatesInfo)' and 'Output(AccumStatesInfo)' +is the accumulation state. + +'Output(BatchMetrics)' is metrics of current batch data while +'Output(AccumMetrics)' is metrics of accumulation data. + )DOC"); } }; From d2b10cc0b1b6a3267698f0d63d721ca99dc6ecf6 Mon Sep 17 00:00:00 2001 From: yangyaming Date: Fri, 27 Oct 2017 15:18:28 +0800 Subject: [PATCH 4/5] Refine doc and fix data type of metrics. 
--- paddle/operators/precision_recall_op.cc | 4 ++-- paddle/operators/precision_recall_op.h | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/paddle/operators/precision_recall_op.cc b/paddle/operators/precision_recall_op.cc index 24246907b1..a3f4c07493 100644 --- a/paddle/operators/precision_recall_op.cc +++ b/paddle/operators/precision_recall_op.cc @@ -136,9 +136,9 @@ to compute various metrics including: - micro average recall - micro f1 score -To compute the above metrics, we need to statistic counts for true positives, +To compute the above metrics, we need to do statistics for true positives, false positives and false negatives. Here count of true negatives is not -necessary, but statisticing it may provide potential usage and the cost is +necessary, but counting it may provide potential usage and the cost is trivial, so the operator also provides count of true negatives. We define state as a 2-D tensor with shape [class number, 4]. Each row of a diff --git a/paddle/operators/precision_recall_op.h b/paddle/operators/precision_recall_op.h index 3bc638ea44..2e49bc3bb5 100644 --- a/paddle/operators/precision_recall_op.h +++ b/paddle/operators/precision_recall_op.h @@ -42,8 +42,8 @@ class PrecisionRecallKernel : public framework::OpKernel { const int* labels_data = in1->data(); const T* weights_data = in2 ? in2->data() : nullptr; const T* states_data = in3 ? 
in3->data() : nullptr; - T* batch_metrics_data = out0->mutable_data(ctx.GetPlace()); - T* accum_metrics_data = out1->mutable_data(ctx.GetPlace()); + double* batch_metrics_data = out0->mutable_data(ctx.GetPlace()); + double* accum_metrics_data = out1->mutable_data(ctx.GetPlace()); out2->mutable_data(ctx.GetPlace()); auto accum_states = EigenMatrix::From(*out2); accum_states.setZero(); @@ -121,7 +121,7 @@ class PrecisionRecallKernel : public framework::OpKernel { } protected: - void ComputeMetrics(const T* states_data, T* metrics_data, + void ComputeMetrics(const T* states_data, double* metrics_data, size_t state_var_num, size_t class_dim) const { T total_tp_count = 0; T total_fp_count = 0; From 970613fc152b77a4fa76876c1fb21fc8473affaa Mon Sep 17 00:00:00 2001 From: yangyaming Date: Wed, 1 Nov 2017 23:23:42 +0800 Subject: [PATCH 5/5] Refine and follow comments. --- paddle/operators/precision_recall_op.cc | 62 ++++++------ paddle/operators/precision_recall_op.h | 54 +++++------ .../tests/test_precision_recall_op.py | 97 ++++++++++--------- 3 files changed, 115 insertions(+), 98 deletions(-) diff --git a/paddle/operators/precision_recall_op.cc b/paddle/operators/precision_recall_op.cc index a3f4c07493..39da1e0bf8 100644 --- a/paddle/operators/precision_recall_op.cc +++ b/paddle/operators/precision_recall_op.cc @@ -22,8 +22,10 @@ class PrecisionRecallOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext *ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("Predictions"), - "Input(Predictions) should not be null."); + PADDLE_ENFORCE(ctx->HasInput("MaxProbs"), + "Input(MaxProbs) should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Indices"), + "Input(Indices) should not be null."); PADDLE_ENFORCE(ctx->HasInput("Labels"), "Input(Labels) should not be null."); PADDLE_ENFORCE(ctx->HasOutput("BatchMetrics"), @@ -33,34 +35,36 @@ class PrecisionRecallOp : public 
framework::OperatorWithKernel { PADDLE_ENFORCE(ctx->HasOutput("AccumStatesInfo"), "Output(AccumStatesInfo) should not be null."); - auto predictions_dims = ctx->GetInputDim("Predictions"); + int64_t cls_num = + static_cast(ctx->Attrs().Get("class_number")); + auto max_probs_dims = ctx->GetInputDim("MaxProbs"); auto labels_dims = ctx->GetInputDim("Labels"); + PADDLE_ENFORCE_EQ(max_probs_dims[1], 1, + "Each instance contains one max probability, so the " + "shape of Input(MaxProbs) should be [batch_size, 1]."); + PADDLE_ENFORCE_EQ(ctx->GetInputDim("Indices"), max_probs_dims, + "The shape of Input(Indices) should be [batch_size, 1]."); + PADDLE_ENFORCE_EQ(max_probs_dims[0], labels_dims[0], + "The 1st dimension of Input(MaxProbs) and " + "Input(Labels) both are batch_size and the shape should " + "be the same."); + PADDLE_ENFORCE_EQ(labels_dims[1], 1, + "The 2nd dimension of Input(Labels) contains instance " + "label and the shape should be equal to 1."); if (ctx->HasInput("Weights")) { auto weights_dims = ctx->GetInputDim("Weights"); PADDLE_ENFORCE_EQ(weights_dims, - framework::make_ddim({predictions_dims[0], 1}), + framework::make_ddim({max_probs_dims[0], 1}), "The shape of Input(Weights) should be " "[batch_size, 1]."); } if (ctx->HasInput("StatesInfo")) { auto states_dims = ctx->GetInputDim("StatesInfo"); - PADDLE_ENFORCE_EQ(states_dims, - framework::make_ddim({predictions_dims[1], 4}), + PADDLE_ENFORCE_EQ(states_dims, framework::make_ddim({cls_num, 4}), "The shape of Input(StatesInfo) should be " "[class_number, 4]."); } - PADDLE_ENFORCE_EQ(predictions_dims[0], labels_dims[0], - "The 1st dimension of Input(Predictions) and " - "Input(Labels) both are batch_size and the shape should " - "be the same."); - PADDLE_ENFORCE_EQ(labels_dims[1], 1, - "The 2nd dimension of Input(Labels) " - "contains instance label and the shape should be equal " - "to 1"); - PADDLE_ENFORCE_GE(predictions_dims[1], 1, - "The shape of Input(Predictions)'s 2nd dimension is " - "equal to class 
number and should be at least 1."); // Layouts of BatchMetrics and AccumMetrics both are: // [ @@ -72,13 +76,13 @@ class PrecisionRecallOp : public framework::OperatorWithKernel { // Shape of AccumStatesInfo is [class_number, 4] // The layout of each row is: // [ TP, FP, TN, FN ] - ctx->SetOutputDim("AccumStatesInfo", {predictions_dims[1], 4}); + ctx->SetOutputDim("AccumStatesInfo", {cls_num, 4}); } protected: framework::DataType IndicateDataType( const framework::ExecutionContext &ctx) const override { - return framework::ToDataType(ctx.Input("Predictions")->type()); + return framework::ToDataType(ctx.Input("MaxProbs")->type()); } }; @@ -87,11 +91,15 @@ class PrecisionRecallOpMaker : public framework::OpProtoAndCheckerMaker { PrecisionRecallOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("Predictions", - "(Tensor, default Tensor), a 2-D tensor with shape N x D, " - "where N is the batch size and D is the number of classes. " - "Each row contains probabilities for an instance which computed " - "by the previous operator."); + AddInput("MaxProbs", + "(Tensor, default Tensor), a 2-D tensor with shape N x 1, " + "where N is the batch size. Each row contains the max probability " + "of an instance which computed by the previous top_k (k=1) " + "operator."); + AddInput("Indices", + "(Tensor, default Tensor), a 2-D tensor with shape N x 1, " + "where N is the batch size. Each row contains the corresponding " + "index which computed by the previous top_k (k=1) operator."); AddInput("Labels", "(Tensor, default Tensor), a 2-D tensor with shape N x 1, " "where N is the batch size. Each element is a label and the " @@ -125,9 +133,9 @@ class PrecisionRecallOpMaker : public framework::OpProtoAndCheckerMaker { "accumulated state variables used to compute metrics. 
The layout " "for each class is [true positives, false positives, " "true negatives, false negatives]."); - + AddAttr("class_number", "Number of classes to be evaluated."); AddComment(R"DOC( -When given 'Input(Predictions)' and 'Input(Labels)', this operator can be used +When given 'Input(Indices)' and 'Input(Labels)', this operator can be used to compute various metrics including: - macro average precision - macro average recall @@ -141,7 +149,7 @@ false positives and false negatives. Here count of true negatives is not necessary, but counting it may provide potential usage and the cost is trivial, so the operator also provides count of true negatives. -We define state as a 2-D tensor with shape [class number, 4]. Each row of a +We define state as a 2-D tensor with shape [class_number, 4]. Each row of a state contains statistic variables for corresponding class. Layout of each row is: TP(true positives), FP(false positives), TN(true negatives), FN(false negatives). If 'Input(Weights)' provided, TP, FP, TN, FN will be diff --git a/paddle/operators/precision_recall_op.h b/paddle/operators/precision_recall_op.h index 2e49bc3bb5..4a871ce674 100644 --- a/paddle/operators/precision_recall_op.h +++ b/paddle/operators/precision_recall_op.h @@ -30,7 +30,7 @@ template class PrecisionRecallKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - auto* in0 = ctx.Input("Predictions"); + auto* in0 = ctx.Input("Indices"); auto* in1 = ctx.Input("Labels"); auto* in2 = ctx.Input("Weights"); auto* in3 = ctx.Input("StatesInfo"); @@ -38,8 +38,9 @@ class PrecisionRecallKernel : public framework::OpKernel { auto* out1 = ctx.Output("AccumMetrics"); auto* out2 = ctx.Output("AccumStatesInfo"); - const T* predictions_data = in0->data(); + const int* ids_data = in0->data(); const int* labels_data = in1->data(); + size_t cls_num = static_cast(ctx.Attr("class_number")); const T* weights_data = in2 ? 
in2->data() : nullptr; const T* states_data = in3 ? in3->data() : nullptr; double* batch_metrics_data = out0->mutable_data(ctx.GetPlace()); @@ -50,43 +51,42 @@ class PrecisionRecallKernel : public framework::OpKernel { T* accum_states_data = out2->data(); size_t sample_num = in0->dims()[0]; - size_t class_dim = in0->dims()[1]; size_t state_var_num = 4; // TP FP TN FN // get states info for current batch for (size_t i = 0; i < sample_num; ++i) { - size_t max_idx = 0; - T max_val = predictions_data[i * class_dim]; - for (size_t j = 1; j < class_dim; ++j) { - if (max_val < predictions_data[i * class_dim + j]) { - max_idx = j; - max_val = predictions_data[i * class_dim + j]; - } - } + size_t idx = ids_data[i]; + size_t label = labels_data[i]; + + PADDLE_ENFORCE(idx >= 0 && idx < cls_num, + "Class index of each instance should be in " + "[0, class_number)."); + PADDLE_ENFORCE(label >= 0 && label < cls_num, + "Label of each instance should be in [0, class_number)."); T w = weights_data ? weights_data[i] : 1.0; - if (max_idx == labels_data[i]) { - accum_states_data[max_idx * state_var_num + TP] += w; - for (size_t j = 0; j < class_dim; ++j) { + if (idx == label) { + accum_states_data[idx * state_var_num + TP] += w; + for (size_t j = 0; j < cls_num; ++j) { accum_states_data[j * state_var_num + TN] += w; } - accum_states_data[max_idx * state_var_num + TN] -= w; + accum_states_data[idx * state_var_num + TN] -= w; } else { - accum_states_data[labels_data[i] * state_var_num + FN] += w; - accum_states_data[max_idx * state_var_num + FP] += w; - for (size_t j = 0; j < class_dim; ++j) { + accum_states_data[label * state_var_num + FN] += w; + accum_states_data[idx * state_var_num + FP] += w; + for (size_t j = 0; j < cls_num; ++j) { accum_states_data[j * state_var_num + TN] += w; } - accum_states_data[max_idx * state_var_num + TN] -= w; - accum_states_data[labels_data[i] * state_var_num + TN] -= w; + accum_states_data[idx * state_var_num + TN] -= w; + accum_states_data[label * 
state_var_num + TN] -= w; } } ComputeMetrics(accum_states_data, batch_metrics_data, state_var_num, - class_dim); + cls_num); if (states_data) { - for (size_t i = 0; i < class_dim; ++i) { + for (size_t i = 0; i < cls_num; ++i) { for (size_t j = 0; j < state_var_num; ++j) { size_t idx = i * state_var_num + j; accum_states_data[idx] += states_data[idx]; @@ -95,7 +95,7 @@ class PrecisionRecallKernel : public framework::OpKernel { } ComputeMetrics(accum_states_data, accum_metrics_data, state_var_num, - class_dim); + cls_num); } // expose to be reused @@ -122,14 +122,14 @@ class PrecisionRecallKernel : public framework::OpKernel { protected: void ComputeMetrics(const T* states_data, double* metrics_data, - size_t state_var_num, size_t class_dim) const { + size_t state_var_num, size_t cls_num) const { T total_tp_count = 0; T total_fp_count = 0; T total_fn_count = 0; T macro_avg_precision = 0.0; T macro_avg_recall = 0.0; - for (size_t i = 0; i < class_dim; ++i) { + for (size_t i = 0; i < cls_num; ++i) { T tp_count = states_data[i * state_var_num + TP]; T fp_count = states_data[i * state_var_num + FP]; T fn_count = states_data[i * state_var_num + FN]; @@ -139,8 +139,8 @@ class PrecisionRecallKernel : public framework::OpKernel { macro_avg_precision += CalcPrecision(tp_count, fp_count); macro_avg_recall += CalcRecall(tp_count, fn_count); } - macro_avg_precision /= class_dim; - macro_avg_recall /= class_dim; + macro_avg_precision /= cls_num; + macro_avg_recall /= cls_num; T macro_f1_score = CalcF1Score(macro_avg_precision, macro_avg_recall); T micro_avg_precision = CalcPrecision(total_tp_count, total_fp_count); diff --git a/python/paddle/v2/framework/tests/test_precision_recall_op.py b/python/paddle/v2/framework/tests/test_precision_recall_op.py index 33efd717d1..d3dbdb6e2a 100644 --- a/python/paddle/v2/framework/tests/test_precision_recall_op.py +++ b/python/paddle/v2/framework/tests/test_precision_recall_op.py @@ -21,45 +21,44 @@ def calc_f1_score(precision, recall): return 
0.0 -def get_states(predictions, labels, weights=None): - ins_num = predictions.shape[0] - class_num = predictions.shape[1] +def get_states(idxs, labels, cls_num, weights=None): + ins_num = idxs.shape[0] # TP FP TN FN - states = np.zeros((class_num, 4)).astype('float32') + states = np.zeros((cls_num, 4)).astype('float32') for i in xrange(ins_num): w = weights[i] if weights is not None else 1.0 - max_idx = np.argmax(predictions[i]) - if max_idx == labels[i][0]: - states[max_idx][0] += w - for j in xrange(class_num): + idx = idxs[i][0] + label = labels[i][0] + if idx == label: + states[idx][0] += w + for j in xrange(cls_num): states[j][2] += w - states[max_idx][2] -= w + states[idx][2] -= w else: - states[labels[i][0]][3] += w - states[max_idx][1] += w - for j in xrange(class_num): + states[label][3] += w + states[idx][1] += w + for j in xrange(cls_num): states[j][2] += w - states[labels[i][0]][2] -= w - states[max_idx][2] -= w + states[label][2] -= w + states[idx][2] -= w return states -def compute_metrics(states): - class_num = states.shape[0] +def compute_metrics(states, cls_num): total_tp_count = 0.0 total_fp_count = 0.0 total_fn_count = 0.0 macro_avg_precision = 0.0 macro_avg_recall = 0.0 - for i in xrange(class_num): + for i in xrange(cls_num): total_tp_count += states[i][0] total_fp_count += states[i][1] total_fn_count += states[i][3] macro_avg_precision += calc_precision(states[i][0], states[i][1]) macro_avg_recall += calc_recall(states[i][0], states[i][3]) metrics = [] - macro_avg_precision /= class_num - macro_avg_recall /= class_num + macro_avg_precision /= cls_num + macro_avg_recall /= cls_num metrics.append(macro_avg_precision) metrics.append(macro_avg_recall) metrics.append(calc_f1_score(macro_avg_precision, macro_avg_recall)) @@ -75,15 +74,18 @@ class TestPrecisionRecallOp_0(OpTest): def setUp(self): self.op_type = "precision_recall" ins_num = 64 - class_num = 10 - predictions = np.random.uniform(0, 1.0, - (ins_num, class_num)).astype('float32') - 
labels = np.random.choice(xrange(class_num), ins_num).reshape( + cls_num = 10 + max_probs = np.random.uniform(0, 1.0, (ins_num, 1)).astype('float32') + idxs = np.random.choice(xrange(cls_num), ins_num).reshape( (ins_num, 1)).astype('int32') - states = get_states(predictions, labels) - metrics = compute_metrics(states) + labels = np.random.choice(xrange(cls_num), ins_num).reshape( + (ins_num, 1)).astype('int32') + states = get_states(idxs, labels, cls_num) + metrics = compute_metrics(states, cls_num) + + self.attrs = {'class_number': cls_num} - self.inputs = {'Predictions': predictions, 'Labels': labels} + self.inputs = {'MaxProbs': max_probs, 'Indices': idxs, 'Labels': labels} self.outputs = { 'BatchMetrics': metrics, @@ -99,18 +101,22 @@ class TestPrecisionRecallOp_1(OpTest): def setUp(self): self.op_type = "precision_recall" ins_num = 64 - class_num = 10 - predictions = np.random.uniform(0, 1.0, - (ins_num, class_num)).astype('float32') + cls_num = 10 + max_probs = np.random.uniform(0, 1.0, (ins_num, 1)).astype('float32') + idxs = np.random.choice(xrange(cls_num), ins_num).reshape( + (ins_num, 1)).astype('int32') weights = np.random.uniform(0, 1.0, (ins_num, 1)).astype('float32') - predictions = np.random.random((ins_num, class_num)).astype('float32') - labels = np.random.choice(xrange(class_num), ins_num).reshape( + labels = np.random.choice(xrange(cls_num), ins_num).reshape( (ins_num, 1)).astype('int32') - states = get_states(predictions, labels, weights) - metrics = compute_metrics(states) + states = get_states(idxs, labels, cls_num, weights) + metrics = compute_metrics(states, cls_num) + + self.attrs = {'class_number': cls_num} + self.inputs = { - 'Predictions': predictions, + 'MaxProbs': max_probs, + 'Indices': idxs, 'Labels': labels, 'Weights': weights } @@ -129,22 +135,25 @@ class TestPrecisionRecallOp_2(OpTest): def setUp(self): self.op_type = "precision_recall" ins_num = 64 - class_num = 10 - predictions = np.random.uniform(0, 1.0, - (ins_num, 
class_num)).astype('float32') + cls_num = 10 + max_probs = np.random.uniform(0, 1.0, (ins_num, 1)).astype('float32') + idxs = np.random.choice(xrange(cls_num), ins_num).reshape( + (ins_num, 1)).astype('int32') weights = np.random.uniform(0, 1.0, (ins_num, 1)).astype('float32') - predictions = np.random.random((ins_num, class_num)).astype('float32') - labels = np.random.choice(xrange(class_num), ins_num).reshape( + labels = np.random.choice(xrange(cls_num), ins_num).reshape( (ins_num, 1)).astype('int32') - states = np.random.randint(0, 30, (class_num, 4)).astype('float32') + states = np.random.randint(0, 30, (cls_num, 4)).astype('float32') - accum_states = get_states(predictions, labels, weights) - batch_metrics = compute_metrics(accum_states) + accum_states = get_states(idxs, labels, cls_num, weights) + batch_metrics = compute_metrics(accum_states, cls_num) accum_states += states - accum_metrics = compute_metrics(accum_states) + accum_metrics = compute_metrics(accum_states, cls_num) + + self.attrs = {'class_number': cls_num} self.inputs = { - 'Predictions': predictions, + 'MaxProbs': max_probs, + 'Indices': idxs, 'Labels': labels, 'Weights': weights, 'StatesInfo': states