From 640f7194b9a10187b96cfc14fdbea6b59740dd70 Mon Sep 17 00:00:00 2001 From: huangxinjing Date: Wed, 18 Nov 2020 17:04:59 +0800 Subject: [PATCH] Fix full batch error --- .../official/recommend/wide_and_deep/src/metrics.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/model_zoo/official/recommend/wide_and_deep/src/metrics.py b/model_zoo/official/recommend/wide_and_deep/src/metrics.py index c89e948405..088ecb90b4 100644 --- a/model_zoo/official/recommend/wide_and_deep/src/metrics.py +++ b/model_zoo/official/recommend/wide_and_deep/src/metrics.py @@ -18,9 +18,7 @@ Area under cure metric """ from sklearn.metrics import roc_auc_score -from mindspore import context from mindspore.nn.metrics import Metric -from mindspore.communication.management import get_rank, get_group_size class AUCMetric(Metric): """ @@ -30,7 +28,6 @@ class AUCMetric(Metric): def __init__(self): super(AUCMetric, self).__init__() self.clear() - self.full_batch = context.get_auto_parallel_context("full_batch") def clear(self): """Clear the internal evaluation result.""" @@ -42,13 +39,7 @@ class AUCMetric(Metric): all_predict = inputs[1].asnumpy().flatten().tolist() # predict all_label = inputs[2].asnumpy().flatten().tolist() # label self.pred_probs.extend(all_predict) - if self.full_batch: - rank_id = get_rank() - group_size = get_group_size() - gap = len(all_label) // group_size - self.true_labels.extend(all_label[rank_id*gap: (rank_id+1)*gap]) - else: - self.true_labels.extend(all_label) + self.true_labels.extend(all_label) def eval(self): if len(self.true_labels) != len(self.pred_probs):