|
|
|
@ -13,6 +13,7 @@
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
# ============================================================================
|
|
|
|
|
"""Faithfulness."""
|
|
|
|
|
from decimal import Decimal
|
|
|
|
|
from typing import Callable, Optional, Union
|
|
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
@ -147,8 +148,8 @@ class NaiveFaithfulness(_FaithfulnessHelper):
|
|
|
|
|
- faithfulness (np.ndarray): faithfulness score
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
if not np.count_nonzero(saliency):
|
|
|
|
|
log.warning("The saliency map is zero everywhere. The correlation will be set to zero.")
|
|
|
|
|
if Decimal(str(saliency.max())) == Decimal(str(saliency.min())):
|
|
|
|
|
log.warning("The saliency map is uniform everywhere. The correlation will be set to zero.")
|
|
|
|
|
correlation = 0
|
|
|
|
|
return np.array([correlation], np.float)
|
|
|
|
|
|
|
|
|
@ -163,6 +164,11 @@ class NaiveFaithfulness(_FaithfulnessHelper):
|
|
|
|
|
predictions = model(perturbations)[:, targets].asnumpy()
|
|
|
|
|
predictions = predictions.reshape(*feature_importance.shape)
|
|
|
|
|
|
|
|
|
|
if Decimal(str(predictions.max())) == Decimal(str(predictions.min())):
|
|
|
|
|
log.warning("The perturbations do not affect the predictions. The correlation will be set to zero.")
|
|
|
|
|
correlation = 0
|
|
|
|
|
return np.array([correlation], np.float)
|
|
|
|
|
|
|
|
|
|
faithfulness = -np.corrcoef(feature_importance, predictions)
|
|
|
|
|
faithfulness = np.diag(faithfulness[:batch_size, batch_size:])
|
|
|
|
|
return faithfulness
|
|
|
|
|