|
|
@ -50,7 +50,11 @@ class AccuracyKernel : public framework::OpKernel<T> {
|
|
|
|
int num_correct = 0;
|
|
|
|
int num_correct = 0;
|
|
|
|
// assume inference is already the topk of the output
|
|
|
|
// assume inference is already the topk of the output
|
|
|
|
for (size_t i = 0; i < num_samples; ++i) {
|
|
|
|
for (size_t i = 0; i < num_samples; ++i) {
|
|
|
|
PADDLE_ENFORCE_GE(label_data[i], 0, "label must >= 0");
|
|
|
|
PADDLE_ENFORCE_GE(
|
|
|
|
|
|
|
|
label_data[i], 0,
|
|
|
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
|
|
|
"label of AccuracyOp must >= 0, But received label[%d] is %d", i,
|
|
|
|
|
|
|
|
label_data[i]));
|
|
|
|
for (size_t j = 0; j < class_dim; ++j) {
|
|
|
|
for (size_t j = 0; j < class_dim; ++j) {
|
|
|
|
if (indices_data[i * class_dim + j] == label_data[i]) {
|
|
|
|
if (indices_data[i * class_dim + j] == label_data[i]) {
|
|
|
|
++num_correct;
|
|
|
|
++num_correct;
|
|
|
|