parent efc2464f6c
commit 06c7c8c80e
@ -0,0 +1,118 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

// Declares PrecisionRecallKernel used in the registrations below
// (include path assumed from the companion header in this commit).
#include "paddle/operators/precision_recall_op.h"

namespace paddle {
namespace operators {

class PrecisionRecallOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
    // May contain Weights and StatesInfo.
    PADDLE_ENFORCE(ctx->HasInput("Predictions"),
                   "Input(Predictions) should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("Labels"),
                   "Input(Labels) should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("BatchMetrics"),
                   "Output(BatchMetrics) should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("AccumMetrics"),
                   "Output(AccumMetrics) should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("AccumStatesInfo"),
                   "Output(AccumStatesInfo) should not be null.");

    auto predictions_dims = ctx->GetInputDim("Predictions");
    auto labels_dims = ctx->GetInputDim("Labels");

    if (ctx->HasInput("Weights")) {
      auto weights_dims = ctx->GetInputDim("Weights");
      PADDLE_ENFORCE_EQ(weights_dims,
                        framework::make_ddim({predictions_dims[0], 1}),
                        "The shape of Input(Weights) should be "
                        "[batch_size, 1].");
    }
    if (ctx->HasInput("StatesInfo")) {
      auto states_dims = ctx->GetInputDim("StatesInfo");
      PADDLE_ENFORCE_EQ(states_dims,
                        framework::make_ddim({predictions_dims[1], 4}),
                        "The shape of Input(StatesInfo) should be "
                        "[class_number, 4].");
    }
    PADDLE_ENFORCE_EQ(predictions_dims[0], labels_dims[0],
                      "The 1st dimension of Input(Predictions) and "
                      "Input(Labels) is the batch_size and should be "
                      "the same for both.");
    PADDLE_ENFORCE_EQ(labels_dims[1], 1,
                      "The 2nd dimension of Input(Labels) contains the "
                      "instance label and should be equal to 1.");
    PADDLE_ENFORCE_GE(predictions_dims[1], 1,
                      "The 2nd dimension of Input(Predictions) is the class "
                      "number and should be at least 1.");

    // Layouts of BatchMetrics and AccumMetrics both are:
    // [
    //   macro average precision, macro average recall, macro average F1 score,
    //   micro average precision, micro average recall, micro average F1 score
    // ]
    ctx->SetOutputDim("BatchMetrics", {6});
    ctx->SetOutputDim("AccumMetrics", {6});
    // Shape of AccumStatesInfo is [class_number, 4].
    // The layout of each row is:
    // [ TP, FP, TN, FN ]
    ctx->SetOutputDim("AccumStatesInfo", {predictions_dims[1], 4});
  }
};

class PrecisionRecallOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  PrecisionRecallOpMaker(framework::OpProto *proto,
                         framework::OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("Predictions",
             "(Tensor, default Tensor<float>), a 2-D tensor with shape N x D, "
             "where N is the batch size and D is the number of classes. "
             "Each row contains the probabilities for one instance, as "
             "computed by the previous operator.");
    AddInput("Labels",
             "(Tensor, default Tensor<int>), a 2-D tensor with shape N x 1, "
             "where N is the batch size. Each element is a label and its "
             "value should be in [0, class_number - 1].");
    AddInput("Weights",
             "(Tensor, default Tensor<float>), a 2-D tensor with shape N x 1, "
             "where N is the batch size. This input is optional. If provided, "
             "the weight of each instance is taken into account when "
             "computing the metrics.")
        .AsDispensable();
    AddInput("StatesInfo",
             "(Tensor, default Tensor<float>), a 2-D tensor with shape D x 4, "
             "where D is the number of classes. This input is optional. If "
             "provided, the statistics of the current batch are accumulated "
             "into it and the accumulated result is written to "
             "Output(AccumStatesInfo).")
        .AsDispensable();
    // Outputs enforced in PrecisionRecallOp::InferShape above.
    AddOutput("BatchMetrics",
              "(Tensor, default Tensor<float>), a 1-D tensor with shape {6}, "
              "containing macro average precision, macro average recall, "
              "macro average F1 score, micro average precision, micro "
              "average recall and micro average F1 score for the current "
              "mini-batch.");
    AddOutput("AccumMetrics",
              "(Tensor, default Tensor<float>), a 1-D tensor with shape {6}, "
              "containing the same six metrics computed from the accumulated "
              "states.");
    AddOutput("AccumStatesInfo",
              "(Tensor, default Tensor<float>), a 2-D tensor with shape D x 4, "
              "where D is the number of classes. Each row holds the "
              "accumulated counts [TP, FP, TN, FN] for one class.");

    AddComment(R"DOC(
Precision Recall Operator.

This operator takes the top-1 prediction of Input(Predictions) for every
instance and compares it with Input(Labels). For each class it maintains
one-vs-rest counts of true positives (TP), false positives (FP), true
negatives (TN) and false negatives (FN), optionally weighted by
Input(Weights) and accumulated with Input(StatesInfo). From these counts it
reports macro-averaged and micro-averaged precision, recall and F1 score for
both the current batch and the accumulated states.
)DOC");
  }
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(precision_recall, ops::PrecisionRecallOp,
                             ops::PrecisionRecallOpMaker);
// The metric arithmetic requires a floating-point type, so only float and
// double kernels are registered.
REGISTER_OP_CPU_KERNEL(
    precision_recall,
    ops::PrecisionRecallKernel<paddle::platform::CPUPlace, float>,
    ops::PrecisionRecallKernel<paddle::platform::CPUPlace, double>);
@ -0,0 +1,159 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
// Framework headers for Tensor, EigenMatrix and OpKernel (include paths
// assumed to match the .cc file above).
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;

enum StateVariable { TP = 0, FP, TN, FN };
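// Each class keeps a one-vs-rest confusion row laid out as [TP, FP, TN, FN].
// For example (hypothetical numbers), with 3 classes and a correctly
// predicted instance of class 1, the kernel below adds 1 to TP of class 1
// and 1 to TN of classes 0 and 2.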

template <typename Place, typename T>
class PrecisionRecallKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* in0 = ctx.Input<Tensor>("Predictions");
    auto* in1 = ctx.Input<Tensor>("Labels");
    auto* in2 = ctx.Input<Tensor>("Weights");
    auto* in3 = ctx.Input<Tensor>("StatesInfo");
    auto* out0 = ctx.Output<Tensor>("BatchMetrics");
    auto* out1 = ctx.Output<Tensor>("AccumMetrics");
    auto* out2 = ctx.Output<Tensor>("AccumStatesInfo");

    const T* predictions_data = in0->data<T>();
    // Labels are integer class indices, see Input(Labels) in the op maker.
    const int* labels_data = in1->data<int>();
    const T* weights_data = in2 ? in2->data<T>() : nullptr;
    const T* states_data = in3 ? in3->data<T>() : nullptr;
    T* batch_metrics_data = out0->mutable_data<T>(ctx.GetPlace());
    T* accum_metrics_data = out1->mutable_data<T>(ctx.GetPlace());
    out2->mutable_data<T>(ctx.GetPlace());
    auto accum_states = EigenMatrix<T>::From(*out2);
    accum_states.setZero();
    T* accum_states_data = out2->data<T>();

    size_t sample_num = in0->dims()[0];
    size_t class_dim = in0->dims()[1];
    size_t state_var_num = 4;  // TP FP TN FN

    // Gather the state info (confusion counts) for the current batch.
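    // Bookkeeping rule used below, with a hypothetical example: suppose an
    // instance has true label 2 but the arg-max prediction is class 0, with
    // weight w. Then class 0 gets FP += w, class 2 gets FN += w, and every
    // other class gets TN += w. If instead the prediction matches the label,
    // that class gets TP += w and all remaining classes get TN += w.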
    for (size_t i = 0; i < sample_num; ++i) {
      size_t max_idx = 0;
      T max_val = predictions_data[i * class_dim];
      for (size_t j = 1; j < class_dim; ++j) {
        if (max_val < predictions_data[i * class_dim + j]) {
          max_idx = j;
          max_val = predictions_data[i * class_dim + j];
        }
      }

      T w = weights_data ? weights_data[i] : 1.0;
      if (max_idx == static_cast<size_t>(labels_data[i])) {
        accum_states_data[max_idx * state_var_num + TP] += w;
        for (size_t j = 0; j < class_dim; ++j) {
          accum_states_data[j * state_var_num + TN] += w;
        }
        accum_states_data[max_idx * state_var_num + TN] -= w;
      } else {
        accum_states_data[labels_data[i] * state_var_num + FN] += w;
        accum_states_data[max_idx * state_var_num + FP] += w;
        for (size_t j = 0; j < class_dim; ++j) {
          accum_states_data[j * state_var_num + TN] += w;
        }
        accum_states_data[max_idx * state_var_num + TN] -= w;
        // Use the label of the current instance (index i), not the loop index.
        accum_states_data[labels_data[i] * state_var_num + TN] -= w;
      }
    }

    ComputeMetrics(accum_states_data, batch_metrics_data, state_var_num,
                   class_dim);

    if (states_data) {
      for (size_t i = 0; i < class_dim; ++i) {
        for (size_t j = 0; j < state_var_num; ++j) {
          size_t idx = i * state_var_num + j;
          accum_states_data[idx] += states_data[idx];
        }
      }
    }

    ComputeMetrics(accum_states_data, accum_metrics_data, state_var_num,
                   class_dim);
  }

  // Exposed publicly so they can be reused elsewhere.
  static inline T CalcPrecision(T tp_count, T fp_count) {
    if (tp_count > 0.0 || fp_count > 0.0) {
      return tp_count / (tp_count + fp_count);
    }
    return 1.0;
  }

  static inline T CalcRecall(T tp_count, T fn_count) {
    if (tp_count > 0.0 || fn_count > 0.0) {
      return tp_count / (tp_count + fn_count);
    }
    return 1.0;
  }

  static inline T CalcF1Score(T precision, T recall) {
    if (precision > 0.0 || recall > 0.0) {
      return 2 * precision * recall / (precision + recall);
    }
    return 0.0;
  }

 protected:
  void ComputeMetrics(const T* states_data, T* metrics_data,
                      size_t state_var_num, size_t class_dim) const {
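    // Macro averaging computes precision/recall per class and then averages
    // the per-class values; micro averaging first sums TP/FP/FN over all
    // classes and computes the metrics once from the totals. Hypothetical
    // example with two classes: class 0 has TP=1, FP=0 and class 1 has TP=0,
    // FP=9, so the macro precision is (1.0 + 0.0) / 2 = 0.5 while the micro
    // precision is 1 / (1 + 9) = 0.1.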
    T total_tp_count = 0;
    T total_fp_count = 0;
    T total_fn_count = 0;
    T macro_avg_precision = 0.0;
    T macro_avg_recall = 0.0;

    for (size_t i = 0; i < class_dim; ++i) {
      T tp_count = states_data[i * state_var_num + TP];
      T fp_count = states_data[i * state_var_num + FP];
      T fn_count = states_data[i * state_var_num + FN];
      total_tp_count += tp_count;
      total_fp_count += fp_count;
      total_fn_count += fn_count;
      macro_avg_precision += CalcPrecision(tp_count, fp_count);
      macro_avg_recall += CalcRecall(tp_count, fn_count);
    }
    macro_avg_precision /= class_dim;
    macro_avg_recall /= class_dim;
    T macro_f1_score = CalcF1Score(macro_avg_precision, macro_avg_recall);

    T micro_avg_precision = CalcPrecision(total_tp_count, total_fp_count);
    T micro_avg_recall = CalcRecall(total_tp_count, total_fn_count);
    // The F1 score is derived from precision and recall, not via CalcRecall.
    T micro_f1_score = CalcF1Score(micro_avg_precision, micro_avg_recall);

    // Fill the metrics data.
    metrics_data[0] = macro_avg_precision;
    metrics_data[1] = macro_avg_recall;
    metrics_data[2] = macro_f1_score;
    metrics_data[3] = micro_avg_precision;
    metrics_data[4] = micro_avg_recall;
    metrics_data[5] = micro_f1_score;
  }
};

} // namespace operators
} // namespace paddle