Refine doc and fix data type of metrics.

fix-typo
yangyaming 7 years ago
parent 97bfc0dfae
commit d2b10cc0b1

@ -136,9 +136,9 @@ to compute various metrics including:
- micro average recall
- micro f1 score
To compute the above metrics, we need to statistic counts for true positives,
To compute the above metrics, we need to do statistics for true positives,
false positives and false negatives. Here count of true negatives is not
necessary, but statisticing it may provide potential usage and the cost is
necessary, but counting it may provide potential usage and the cost is
trivial, so the operator also provides count of true negatives.
We define state as a 2-D tensor with shape [class number, 4]. Each row of a

@ -42,8 +42,8 @@ class PrecisionRecallKernel : public framework::OpKernel<T> {
const int* labels_data = in1->data<int>();
const T* weights_data = in2 ? in2->data<T>() : nullptr;
const T* states_data = in3 ? in3->data<T>() : nullptr;
T* batch_metrics_data = out0->mutable_data<T>(ctx.GetPlace());
T* accum_metrics_data = out1->mutable_data<T>(ctx.GetPlace());
double* batch_metrics_data = out0->mutable_data<double>(ctx.GetPlace());
double* accum_metrics_data = out1->mutable_data<double>(ctx.GetPlace());
out2->mutable_data<T>(ctx.GetPlace());
auto accum_states = EigenMatrix<T>::From(*out2);
accum_states.setZero();
@ -121,7 +121,7 @@ class PrecisionRecallKernel : public framework::OpKernel<T> {
}
protected:
void ComputeMetrics(const T* states_data, T* metrics_data,
void ComputeMetrics(const T* states_data, double* metrics_data,
size_t state_var_num, size_t class_dim) const {
T total_tp_count = 0;
T total_fp_count = 0;

Loading…
Cancel
Save