!9664 add skip to ImageClassificationRunner when returning NaN and fix bug for small-size ablation.

From: @yuhanshi
Reviewed-by: @wuxuejian,@wenkai_dist
Signed-off-by: @wenkai_dist
pull/9664/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit e00570f0a8

@ -418,29 +418,24 @@ class ImageClassificationRunner:
inputs, labels, _ = self._unpack_next_element(next_element)
for idx, inp in enumerate(inputs):
inp = _EXPAND_DIMS(inp, 0)
saliency_dict = saliency_dict_lst[idx]
for label, saliency in saliency_dict.items():
if isinstance(benchmarker, Localization):
_, _, bboxes = self._unpack_next_element(next_element, True)
if label in labels[idx]:
res = benchmarker.evaluate(explainer, inp, targets=label, mask=bboxes[idx][label],
saliency=saliency)
if np.any(res == np.nan):
res = np.zeros_like(res)
if isinstance(benchmarker, LabelAgnosticMetric):
res = benchmarker.evaluate(explainer, inp)
benchmarker.aggregate(res)
else:
saliency_dict = saliency_dict_lst[idx]
for label, saliency in saliency_dict.items():
if isinstance(benchmarker, Localization):
_, _, bboxes = self._unpack_next_element(next_element, True)
if label in labels[idx]:
res = benchmarker.evaluate(explainer, inp, targets=label, mask=bboxes[idx][label],
saliency=saliency)
benchmarker.aggregate(res, label)
elif isinstance(benchmarker, LabelSensitiveMetric):
res = benchmarker.evaluate(explainer, inp, targets=label, saliency=saliency)
benchmarker.aggregate(res, label)
elif isinstance(benchmarker, LabelSensitiveMetric):
res = benchmarker.evaluate(explainer, inp, targets=label, saliency=saliency)
if np.any(res == np.nan):
res = np.zeros_like(res)
benchmarker.aggregate(res, label)
elif isinstance(benchmarker, LabelAgnosticMetric):
res = benchmarker.evaluate(explainer, inp)
if np.any(res == np.nan):
res = np.zeros_like(res)
benchmarker.aggregate(res)
else:
raise TypeError('Benchmarker must be one of LabelSensitiveMetric or LabelAgnosticMetric, but'
'receive {}'.format(type(benchmarker)))
else:
raise TypeError('Benchmarker must be one of LabelSensitiveMetric or LabelAgnosticMetric, but'
'receive {}'.format(type(benchmarker)))
def _verify_data(self):
"""Verify dataset and labels."""

@ -382,8 +382,6 @@ class Faithfulness(LabelSensitiveMetric):
perturb_percent = 0.5 # ratio of pixels to be perturbed, future argument
perturb_method = "Constant" # perturbation method, all the perturbed pixels will be set to constant
num_perturb_pixel_per_step = None # number of pixels for each perturbation step
num_perturb_steps = 100 # separate the perturbation progress in to 100 steps.
base_value = 0.0 # the pixel value set for the perturbed pixels
check_value_type("activation_fn", activation_fn, nn.Cell)
@ -395,8 +393,6 @@ class Faithfulness(LabelSensitiveMetric):
self._faithfulness_helper = method(
perturb_percent=perturb_percent,
perturb_method=perturb_method,
perturb_pixel_per_step=num_perturb_pixel_per_step,
num_perturbations=num_perturb_steps,
base_value=base_value
)

@ -15,6 +15,7 @@
"""Base class for XAI metrics."""
import copy
import math
from typing import Callable
import numpy as np
@ -88,11 +89,12 @@ class LabelAgnosticMetric(AttributionMetric):
Return:
float, averaged result. If no result is aggregate in the global_results, 0.0 will be returned.
"""
if not self._global_results:
return 0.0
results_sum = sum(self._global_results)
count = len(self._global_results)
return results_sum / count
result_sum, count = 0, 0
for res in self._global_results:
if math.isfinite(res):
result_sum += res
count += 1
return 0. if count == 0 else result_sum / count
def aggregate(self, result):
"""Aggregate single evaluation result to global results."""
@ -100,7 +102,7 @@ class LabelAgnosticMetric(AttributionMetric):
self._global_results.append(result)
elif isinstance(result, (ms.Tensor, np.ndarray)):
result = format_tensor_to_ndarray(result)
self._global_results.append(float(result))
self._global_results.extend([float(res) for res in result.reshape(-1)])
else:
raise TypeError('result should have type of float, ms.Tensor or np.ndarray, but receive %s' % type(result))
@ -130,10 +132,12 @@ class LabelSensitiveMetric(AttributionMetric):
@property
def num_labels(self):
"""Number of labels used in evaluation."""
return self._num_labels
@staticmethod
def _verify_params(num_labels):
"""Checks whether num_labels is valid."""
check_value_type("num_labels", num_labels, int)
if num_labels < 1:
raise ValueError("Argument num_labels must be parsed with a integer > 0.")
@ -147,17 +151,19 @@ class LabelSensitiveMetric(AttributionMetric):
target_np = format_tensor_to_ndarray(targets)
if len(target_np) > 1:
raise ValueError("One result can not be aggreated to multiple targets.")
else:
result_np = format_tensor_to_ndarray(result)
elif isinstance(result, (ms.Tensor, np.ndarray)):
result_np = format_tensor_to_ndarray(result).reshape(-1)
if isinstance(targets, int):
for res in result_np:
self._global_results[targets].append(float(res))
else:
target_np = format_tensor_to_ndarray(targets)
target_np = format_tensor_to_ndarray(targets).reshape(-1)
if len(target_np) != len(result_np):
raise ValueError("Length of result does not match with length of targets.")
for tar, res in zip(target_np, result_np):
self._global_results[int(tar)].append(float(res))
else:
raise TypeError('Result should have type of float, ms.Tensor or np.ndarray, but receive %s' % type(result))
def reset(self):
"""Resets global_result."""
@ -168,16 +174,18 @@ class LabelSensitiveMetric(AttributionMetric):
"""
Get the class performances by global result.
Returns:
(:class:`np.ndarray`): :attr:`num_labels`-dimensional vector
containing per-class performance.
(:class:`list`): a list of performances where each value is the average score of specific class.
"""
count = np.array(
[len(self._global_results[i]) for i in range(self._num_labels)])
result_sum = np.array(
[sum(self._global_results[i]) for i in range(self._num_labels)])
return result_sum / count.clip(min=1)
results_on_labels = []
for label_id in range(self._num_labels):
sum_of_label, count_of_label = 0, 0
for res in self._global_results[label_id]:
if math.isfinite(res):
sum_of_label += res
count_of_label += 1
results_on_labels.append(0. if count_of_label == 0 else sum_of_label / count_of_label)
return results_on_labels
@property
def performance(self):
@ -187,13 +195,13 @@ class LabelSensitiveMetric(AttributionMetric):
Returns:
(:class:`float`): mean performance.
"""
count = sum(
[len(self._global_results[i]) for i in range(self._num_labels)])
result_sum = sum(
[sum(self._global_results[i]) for i in range(self._num_labels)])
if count == 0:
return 0
return result_sum / count
result_sum, count = 0, 0
for label_id in range(self._num_labels):
for res in self._global_results[label_id]:
if math.isfinite(res):
result_sum += res
count += 1
return 0. if count == 0 else result_sum / count
def get_results(self):
"""Global result of the metric can be return"""

@ -122,8 +122,8 @@ class Robustness(LabelSensitiveMetric):
perturbations.append(perturbation_on_single_sample)
perturbations = np.vstack(perturbations)
perturbations_saliency = explainer(ms.Tensor(perturbations, ms.float32), targets).asnumpy()
sensitivity = np.sum((perturbations_saliency - saliency_np) ** 2,
axis=tuple(range(1, len(saliency_np.shape))))
sensitivity = np.sqrt(np.sum((perturbations_saliency - saliency_np) ** 2,
axis=tuple(range(1, len(saliency_np.shape)))))
sensitivities.append(sensitivity)
sensitivities = np.stack(sensitivities, axis=-1)
max_sensitivity = np.max(sensitivities, axis=1) / norm

@ -89,7 +89,7 @@ class Ablation:
class AblationWithSaliency(Ablation):
"""
Perturbation generator to generate perturbations for a given array.
Perturbation generator to generate perturbations w.r.t a given saliency map.
Args:
perturb_percent (float): percentage of pixels to perturb
@ -143,28 +143,20 @@ class AblationWithSaliency(Ablation):
"""
batch_size = saliency.shape[0]
expected_num_dim = len(saliency.shape) + 1
has_channel = num_channels is not None
num_channels = 1 if num_channels is None else num_channels
if has_channel:
saliency = saliency.mean(axis=1)
saliency_rank = rank_pixels(saliency, descending=True)
num_pixels = reduce(lambda x, y: x * y, saliency.shape[1:])
if self._pixel_per_step:
pixel_per_step = self._pixel_per_step
num_perturbations = math.floor(num_pixels * self._perturb_percent / self._pixel_per_step)
elif self._num_perturbations:
pixel_per_step = math.floor(num_pixels * self._perturb_percent / self._num_perturbations)
num_perturbations = self._num_perturbations
else:
raise ValueError("Must provide either pixel_per_step or num_perturbations.")
pixel_per_step, num_perturbations = self._check_and_format_perturb_param(num_pixels)
masks = np.zeros((batch_size, num_perturbations, num_channels, saliency_rank.shape[1], saliency_rank.shape[2]),
dtype=np.bool)
# If the perturbation is added accumulately, the factor should be 0 to preserve the low bound of indexing.
factor = 0 if self._is_accumulate else 1
for i in range(batch_size):
@ -176,7 +168,23 @@ class AblationWithSaliency(Ablation):
up_bound += pixel_per_step
masks = masks if has_channel else np.squeeze(masks, axis=2)
return masks
def _check_and_format_perturb_param(self, num_pixels):
"""
Check whether the self._pixel_per_step and self._num_perturbation is valid. If the parameters are unreasonable,
this function will try to reassign the parameters and raise ValueError when reassignment is failed.
"""
if self._pixel_per_step:
pixel_per_step = self._pixel_per_step
num_perturbations = math.floor(num_pixels * self._perturb_percent / self._pixel_per_step)
elif self._num_perturbations:
pixel_per_step = math.floor(num_pixels * self._perturb_percent / self._num_perturbations)
num_perturbations = self._num_perturbations
else:
# If neither pixel_per_step or num_perturbations is provided, num_perturbations is determined by the square
# root of product from the spatial size of saliency map.
num_perturbations = math.floor(np.sqrt(num_pixels))
pixel_per_step = math.floor(num_pixels * self._perturb_percent / num_perturbations)
if len(masks.shape) == expected_num_dim:
return masks
raise ValueError(f'Invalid masks shape {len(masks.shape)}, expect {expected_num_dim}-dim.')
return pixel_per_step, num_perturbations

@ -14,8 +14,9 @@
# ============================================================================
"""Occlusion explainer."""
from typing import Tuple
import numpy as np
from numpy.lib.stride_tricks import as_strided
import mindspore as ms
import mindspore.nn as nn
@ -25,24 +26,17 @@ from .replacement import Constant
from ...._utils import abs_max
def _generate_patches(array, window_size, stride):
"""View as windows."""
if not isinstance(array, np.ndarray):
raise TypeError("`array` must be a numpy ndarray")
arr_shape = np.array(array.shape)
window_size = np.array(window_size, dtype=arr_shape.dtype)
slices = tuple(slice(None, None, st) for st in stride)
window_strides = np.array(array.strides)
def _generate_patches(array, window_size: Tuple, strides: Tuple):
"""Generate patches from image w.r.t given window_size and strides."""
window_strides = array.strides
slices = tuple(slice(None, None, stride) for stride in strides)
indexing_strides = array[slices].strides
win_indices_shape = (((np.array(array.shape) - np.array(window_size)) // np.array(stride)) + 1)
new_shape = tuple(list(win_indices_shape) + list(window_size))
strides = tuple(list(indexing_strides) + list(window_strides))
win_indices_shape = (np.array(array.shape) - np.array(window_size)) // np.array(strides) + 1
patches = as_strided(array, shape=new_shape, strides=strides)
patches_shape = tuple(win_indices_shape) + window_size
strides_in_memory = indexing_strides + window_strides
patches = np.lib.stride_tricks.as_strided(array, shape=patches_shape, strides=strides_in_memory, writeable=False)
patches = patches.reshape((-1,) + window_size)
return patches
@ -159,7 +153,7 @@ class Occlusion(PerturbationAttribution):
total_dim = np.prod(inputs.shape[1:]).item()
template = np.arange(total_dim).reshape(inputs.shape[1:])
indices = _generate_patches(template, window_size, strides)
num_perturbations = indices.reshape((-1,) + window_size).shape[0]
num_perturbations = indices.shape[0]
indices = indices.reshape(num_perturbations, -1)
mask = np.zeros((num_perturbations, total_dim), dtype=np.bool)

Loading…
Cancel
Save