|
|
|
@ -42,8 +42,8 @@ class PrecisionRecallKernel : public framework::OpKernel<T> {
|
|
|
|
|
const int* labels_data = in1->data<int>();
|
|
|
|
|
const T* weights_data = in2 ? in2->data<T>() : nullptr;
|
|
|
|
|
const T* states_data = in3 ? in3->data<T>() : nullptr;
|
|
|
|
|
T* batch_metrics_data = out0->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
T* accum_metrics_data = out1->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
double* batch_metrics_data = out0->mutable_data<double>(ctx.GetPlace());
|
|
|
|
|
double* accum_metrics_data = out1->mutable_data<double>(ctx.GetPlace());
|
|
|
|
|
out2->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
auto accum_states = EigenMatrix<T>::From(*out2);
|
|
|
|
|
accum_states.setZero();
|
|
|
|
@ -121,7 +121,7 @@ class PrecisionRecallKernel : public framework::OpKernel<T> {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
void ComputeMetrics(const T* states_data, T* metrics_data,
|
|
|
|
|
void ComputeMetrics(const T* states_data, double* metrics_data,
|
|
|
|
|
size_t state_var_num, size_t class_dim) const {
|
|
|
|
|
T total_tp_count = 0;
|
|
|
|
|
T total_fp_count = 0;
|
|
|
|
|