diff --git a/mindspore/explainer/benchmark/_attribution/class_sensitivity.py b/mindspore/explainer/benchmark/_attribution/class_sensitivity.py
index bfd983eafc..80f013ebe7 100644
--- a/mindspore/explainer/benchmark/_attribution/class_sensitivity.py
+++ b/mindspore/explainer/benchmark/_attribution/class_sensitivity.py
@@ -22,15 +22,14 @@ from ..._utils import calc_correlation
 
 
 class ClassSensitivity(LabelAgnosticMetric):
-    r"""
+    """
     Class sensitivity metric used to evaluate attribution-based explanations.
 
     Reasonable atrribution-based explainers are expected to generate distinct saliency maps for different labels,
-    especially for labels of highest confidence and low confidence. Class sensitivity evaluates the explainer through
+    especially for labels of highest confidence and low confidence. ClassSensitivity evaluates the explainer through
     computing the correlation between saliency maps of highest-confidence and lowest-confidence labels. Explainer with
     better class sensitivity will receive lower correlation score. To make the evaluation results intuitive, the
     returned score will take negative on correlation and normalize.
-
     """
 
     def evaluate(self, explainer, inputs):
@@ -46,12 +45,19 @@ class ClassSensitivity(LabelAgnosticMetric):
 
         Examples:
             >>> import mindspore as ms
+            >>> import numpy as np
+            >>> from mindspore.explainer.benchmark import ClassSensitivity
             >>> from mindspore.explainer.explanation import Gradient
-            >>> model = resnet(10)
-            >>> gradient = Gradient(model)
-            >>> x = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32)
+            >>> from mindspore.train.serialization import load_checkpoint, load_param_into_net
+            >>> # prepare your network and load the trained checkpoint file, e.g., resnet50.
+            >>> network = resnet50(10)
+            >>> param_dict = load_checkpoint("resnet50.ckpt")
+            >>> load_param_into_net(network, param_dict)
+            >>> # prepare your explainer to be evaluated, e.g., Gradient.
+            >>> gradient = Gradient(network)
+            >>> input_x = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32)
             >>> class_sensitivity = ClassSensitivity()
-            >>> res = class_sensitivity.evaluate(gradient, x)
+            >>> res = class_sensitivity.evaluate(gradient, input_x)
         """
 
         self._check_evaluate_param(explainer, inputs)
diff --git a/mindspore/explainer/benchmark/_attribution/robustness.py b/mindspore/explainer/benchmark/_attribution/robustness.py
index 280043856c..927cc6523b 100644
--- a/mindspore/explainer/benchmark/_attribution/robustness.py
+++ b/mindspore/explainer/benchmark/_attribution/robustness.py
@@ -32,6 +32,7 @@ class Robustness(LabelSensitiveMetric):
         num_labels (int): Number of classes in the dataset.
 
     Examples:
+        >>> # Initialize a Robustness benchmarker passing num_labels of the dataset.
        >>> from mindspore.explainer.benchmark import Robustness
         >>> num_labels = 100
         >>> robustness = Robustness(num_labels)
@@ -41,7 +42,7 @@ class Robustness(LabelSensitiveMetric):
         super().__init__(num_labels)
         self._perturb = RandomPerturb()
-        self._num_perturbations = 100  # number of perturbations used in evaluation
+        self._num_perturbations = 10  # number of perturbations used in evaluation
         self._threshold = 0.1  # threshold to generate perturbation
         self._activation_fn = activation_fn
 
@@ -68,12 +69,19 @@ class Robustness(LabelSensitiveMetric):
             ValueError: If batch_size is larger than 1.
 
         Examples:
-            >>> # init an explainer, the network should contain the output activation function.
             >>> from mindspore.explainer.explanation import Gradient
             >>> from mindspore.explainer.benchmark import Robustness
+            >>> import numpy as np
+            >>> import mindspore as ms
+            >>> from mindspore.train.serialization import load_checkpoint, load_param_into_net
+            >>> # prepare your network and load the trained checkpoint file, e.g., resnet50.
+            >>> network = resnet50(10)
+            >>> param_dict = load_checkpoint("resnet50.ckpt")
+            >>> load_param_into_net(network, param_dict)
+            >>> # prepare your explainer to be evaluated, e.g., Gradient.
             >>> gradient = Gradient(network)
             >>> input_x = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32)
-            >>> target_label = 5
+            >>> target_label = ms.Tensor([0], ms.int32)
             >>> robustness = Robustness(num_labels=10)
             >>> res = robustness.evaluate(gradient, input_x, target_label)
         """
@@ -84,39 +90,48 @@ class Robustness(LabelSensitiveMetric):
         inputs_np = inputs.asnumpy()
         if isinstance(targets, int):
-            targets = ms.Tensor(targets, ms.int32)
+            targets = ms.Tensor([targets], ms.int32)
         if saliency is None:
             saliency = explainer(inputs, targets)
         saliency_np = saliency.asnumpy()
+
         norm = np.sqrt(np.sum(np.square(saliency_np), axis=tuple(range(1, len(saliency_np.shape)))))
-        if norm == 0:
+        if (norm == 0).any():
             log.warning('Get saliency norm equals 0, robustness return NaN for zero-norm saliency currently.')
-            return np.array([np.nan])
-
-        perturbations = []
-        for sample in inputs_np:
-            sample = np.expand_dims(sample, axis=0)
-            perturbations_per_input = []
-            for _ in range(self._num_perturbations):
-                perturbation = self._perturb(sample)
-                perturbations_per_input.append(perturbation)
-            perturbations_per_input = np.vstack(perturbations_per_input)
-            perturbations.append(perturbations_per_input)
-        perturbations = np.stack(perturbations, axis=0)
-
-        perturbations = np.reshape(perturbations, (-1,) + inputs_np.shape[1:])
-        perturbations = ms.Tensor(perturbations, ms.float32)
-
-        repeated_targets = np.repeat(targets.asnumpy(), repeats=self._num_perturbations, axis=0)
-        repeated_targets = ms.Tensor(repeated_targets, ms.int32)
-        saliency_of_perturbations = explainer(perturbations, repeated_targets)
-        perturbations_saliency = saliency_of_perturbations.asnumpy()
-
-        repeated_saliency = np.repeat(saliency_np, repeats=self._num_perturbations, axis=0)
-
-        sensitivities = np.sum((repeated_saliency - perturbations_saliency) ** 2,
-                               axis=tuple(range(1, len(repeated_saliency.shape))))
-
-        max_sensitivity = np.max(sensitivities.reshape((norm.shape[0], -1)), axis=1) / norm
+            norm[norm == 0] = np.nan
+
+        model = nn.SequentialCell([explainer.model, self._activation_fn])
+        original_outputs = model(inputs).asnumpy()
+        sensitivities = []
+        for _ in range(self._num_perturbations):
+            perturbations = []
+            for j, sample in enumerate(inputs_np):
+                perturbation_on_single_sample = self._perturb_with_threshold(model,
+                                                                             np.expand_dims(sample, axis=0),
+                                                                             original_outputs[j])
+                perturbations.append(perturbation_on_single_sample)
+            perturbations = np.vstack(perturbations)
+            perturbations_saliency = explainer(ms.Tensor(perturbations, ms.float32), targets).asnumpy()
+            sensitivity = np.sum((perturbations_saliency - saliency_np) ** 2,
+                                 axis=tuple(range(1, len(saliency_np.shape))))
+            sensitivities.append(sensitivity)
+        sensitivities = np.stack(sensitivities, axis=-1)
+        max_sensitivity = np.max(sensitivities, axis=1) / norm
         robustness_res = 1 / np.exp(max_sensitivity)
         return robustness_res
+
+    def _perturb_with_threshold(self, model: nn.Cell, sample: np.ndarray, original_output: np.ndarray) -> np.ndarray:
+        """
+        Generate the perturbation until the L2-distance between original_output and perturbation_output is lower than
+        the given self._threshold or until the number of attempts reaches max_attempt_time.
+        """
+        # the maximum number of attempts to get a perturbation with perturb_error lower than self._threshold
+        max_attempt_time = 3
+        perturbation = None
+        for _ in range(max_attempt_time):
+            perturbation = self._perturb(sample)
+            perturbation_output = model(ms.Tensor(perturbation, ms.float32)).asnumpy()
+            perturb_error = np.linalg.norm(original_output - perturbation_output)
+            if perturb_error <= self._threshold:
+                return perturbation
+        return perturbation
diff --git a/mindspore/explainer/explanation/_attribution/_perturbation/occlusion.py b/mindspore/explainer/explanation/_attribution/_perturbation/occlusion.py
index 08e3f5d15c..a90d5b06f5 100644
--- a/mindspore/explainer/explanation/_attribution/_perturbation/occlusion.py
+++ b/mindspore/explainer/explanation/_attribution/_perturbation/occlusion.py
@@ -14,14 +14,11 @@
 # ============================================================================
 """Occlusion explainer."""
-import math
-
 import numpy as np
 from numpy.lib.stride_tricks import as_strided
 
 import mindspore as ms
 import mindspore.nn as nn
-from mindspore import Tensor
 
 from .ablation import Ablation
 from .perturbation import PerturbationAttribution
 from .replacement import Constant
@@ -62,8 +59,8 @@ class Occlusion(PerturbationAttribution):
         network (Cell): Specify the black-box model to be explained.
 
     Inputs:
-        inputs (Tensor): The input data to be explained, a 4D tensor of shape :math:`(N, C, H, W)`.
-        targets (Tensor, int): The label of interest. It should be a 1D or 0D tensor, or an integer.
+        - **inputs** (Tensor) - The input data to be explained, a 4D tensor of shape :math:`(N, C, H, W)`.
+        - **targets** (Tensor, int) - The label of interest. It should be a 1D or 0D tensor, or an integer.
           If it is a 1D tensor, its length should be the same as `inputs`.
 
     Outputs:
@@ -72,13 +69,17 @@
     Example:
+        >>> import numpy as np
+        >>> import mindspore as ms
         >>> from mindspore.explainer.explanation import Occlusion
         >>> from mindspore.train.serialization import load_checkpoint, load_param_into_net
+        >>> # prepare your network and load the trained checkpoint file, e.g., resnet50.
         >>> network = resnet50(10)
         >>> param_dict = load_checkpoint("resnet50.ckpt")
         >>> load_param_into_net(network, param_dict)
+        >>> # initialize Occlusion explainer and pass the pretrained model
         >>> occlusion = Occlusion(network)
-        >>> x = Tensor(np.random.rand(1, 3, 224, 224), ms.float32)
-        >>> label = 1
-        >>> saliency = occlusion(x, label)
+        >>> input_x = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32)
+        >>> label = ms.Tensor([1], ms.int32)
+        >>> saliency = occlusion(input_x, label)
     """
 
     def __init__(self, network, activation_fn=nn.Softmax()):
@@ -88,62 +87,63 @@
         self._aggregation_fn = abs_max
         self._get_replacement = Constant(base_value=0.0)
         self._num_sample_per_dim = 32  # specify the number of perturbations each dimension.
-        self._num_per_eval = 32  # number of perturbations each evaluation step.
+        self._num_per_eval = 2  # number of perturbations generated for each sample per evaluation step.
 
     def __call__(self, inputs, targets):
         """Call function for 'Occlusion'."""
         self._verify_data(inputs, targets)
-        inputs = inputs.asnumpy()
-        targets = targets.asnumpy() if isinstance(targets, Tensor) else np.array([targets] * inputs.shape[0], np.int)
+        inputs_np = inputs.asnumpy()
+        targets_np = targets.asnumpy() if isinstance(targets, ms.Tensor) else np.array([targets], np.int)
 
-        # If spatial size of input data is smaller than self._num_sample_per_dim, window_size and strides will set to
-        # `(C, 3, 3)` and `(C, 1, 1)` separately.
-        window_size = tuple(
-            [inputs.shape[1]]
-            + [x % self._num_sample_per_dim if x > self._num_sample_per_dim else 3 for x in inputs.shape[2:]])
-        strides = tuple(
-            [inputs.shape[1]]
-            + [x // self._num_sample_per_dim if x > self._num_sample_per_dim else 1 for x in inputs.shape[2:]])
+        batch_size = inputs_np.shape[0]
+        window_size, strides = self._get_window_size_and_strides(inputs_np)
 
         model = nn.SequentialCell([self._model, self._activation_fn])
-        original_outputs = model(Tensor(inputs, ms.float32)).asnumpy()[np.arange(len(targets)), targets]
+        original_outputs = model(ms.Tensor(inputs_np, ms.float32)).asnumpy()[np.arange(batch_size), targets_np]
 
-        total_attribution = np.zeros_like(inputs)
-        weights = np.ones_like(inputs)
-        masks = Occlusion._generate_masks(inputs, window_size, strides)
+        total_attribution = np.zeros_like(inputs_np)
+        weights = np.ones_like(inputs_np)
+        masks = Occlusion._generate_masks(inputs_np, window_size, strides)
         num_perturbations = masks.shape[1]
-        original_outputs_repeat = np.repeat(original_outputs, repeats=num_perturbations, axis=0)
-
-        reference = self._get_replacement(inputs)
-        occluded_inputs = self._ablation(inputs, reference, masks)
-        targets_repeat = np.repeat(targets, repeats=num_perturbations, axis=0)
-
-        occluded_inputs = occluded_inputs.reshape((-1, *inputs.shape[1:]))
-        if occluded_inputs.shape[0] > self._num_per_eval:
-            cal_time = math.ceil(occluded_inputs.shape[0] / self._num_per_eval)
-            occluded_outputs = []
-            for i in range(cal_time):
-                occluded_input = occluded_inputs[i*self._num_per_eval
-                                                 :min((i+1) * self._num_per_eval, occluded_inputs.shape[0])]
-                target = targets_repeat[i*self._num_per_eval
-                                        :min((i+1) * self._num_per_eval, occluded_inputs.shape[0])]
-                occluded_output = model(Tensor(occluded_input)).asnumpy()[np.arange(target.shape[0]), target]
-                occluded_outputs.append(occluded_output)
-            occluded_outputs = np.concatenate(occluded_outputs)
-        else:
-            occluded_outputs = model(Tensor(occluded_inputs)).asnumpy()[np.arange(len(targets_repeat)), targets_repeat]
-        outputs_diff = original_outputs_repeat - occluded_outputs
-        outputs_diff = outputs_diff.reshape(inputs.shape[0], -1)
-
-        total_attribution += (
-            outputs_diff.reshape(outputs_diff.shape + (1,) * (len(masks.shape) - 2)) * masks).sum(axis=1).clip(1e-6)
-        weights += masks.sum(axis=1)
-
-        attribution = self._aggregation_fn(Tensor(total_attribution / weights))
+        reference = self._get_replacement(inputs_np)
+
+        count = 0
+        while count < num_perturbations:
+            ith_masks = masks[:, count:min(count+self._num_per_eval, num_perturbations)]
+            actual_num_eval = ith_masks.shape[1]
+            num_samples = batch_size * actual_num_eval
+            occluded_inputs = self._ablation(inputs_np, reference, ith_masks)
+            occluded_inputs = occluded_inputs.reshape((-1, *inputs_np.shape[1:]))
+            targets_repeat = np.repeat(targets_np, repeats=actual_num_eval, axis=0)
+            occluded_outputs = model(
+                ms.Tensor(occluded_inputs, ms.float32)).asnumpy()[np.arange(num_samples), targets_repeat]
+            original_outputs_repeat = np.repeat(original_outputs, repeats=actual_num_eval, axis=0)
+            outputs_diff = original_outputs_repeat - occluded_outputs
+            total_attribution += (
+                outputs_diff.reshape(ith_masks.shape[:2] + (1,) * (len(masks.shape) - 2)) * ith_masks).sum(axis=1)
+            weights += ith_masks.sum(axis=1)
+            count += actual_num_eval
+        attribution = self._aggregation_fn(ms.Tensor(total_attribution / weights, ms.float32))
         return attribution
 
+    def _get_window_size_and_strides(self, inputs):
+        """
+        Return window_size and strides.
+
+        If the spatial size of the input data is smaller than self._num_sample_per_dim, window_size and strides will
+        be set to `(C, 3, 3)` and `(C, 1, 1)` respectively. Otherwise, window_size and strides will be generated
+        adaptively to match self._num_sample_per_dim.
+        """
+        window_size = tuple(
+            [inputs.shape[1]]
+            + [x // self._num_sample_per_dim if x > self._num_sample_per_dim else 3 for x in inputs.shape[2:]])
+        strides = tuple(
+            [inputs.shape[1]]
+            + [x // self._num_sample_per_dim if x > self._num_sample_per_dim else 1 for x in inputs.shape[2:]])
+        return window_size, strides
+
     @staticmethod
     def _generate_masks(inputs, window_size, strides):
         """Generate masks to perturb contiguous regions."""
diff --git a/mindspore/explainer/explanation/_attribution/attribution.py b/mindspore/explainer/explanation/_attribution/attribution.py
index 2b89db582b..b72b840675 100644
--- a/mindspore/explainer/explanation/_attribution/attribution.py
+++ b/mindspore/explainer/explanation/_attribution/attribution.py
@@ -72,3 +72,6 @@ class Attribution:
         if len(targets.shape) > 1 or (len(targets.shape) == 1 and len(targets) != len(inputs)):
             raise ValueError('Argument targets must be a 1D or 0D Tensor. If it is a 1D Tensor, '
                              'it should have the same length as inputs.')
+        elif inputs.shape[0] != 1:
+            raise ValueError('If targets have type of int, batch_size of inputs should equal 1. Received batch_size {}'
+                             .format(inputs.shape[0]))
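
The docstring examples above share one setup; the sketch below simply consolidates them into a single script. It assumes, as those examples do, that a trained resnet50 constructor and a "resnet50.ckpt" checkpoint are available, and it is only an illustrative usage sketch.

import numpy as np
import mindspore as ms
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.explainer.explanation import Gradient, Occlusion
from mindspore.explainer.benchmark import ClassSensitivity, Robustness

# Assumed, as in the docstring examples: a resnet50 constructor and a checkpoint file.
network = resnet50(10)
load_param_into_net(network, load_checkpoint("resnet50.ckpt"))

# A single sample (batch size 1) and a Tensor label, matching the target handling above.
input_x = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32)
label = ms.Tensor([1], ms.int32)

# Explainers to be evaluated.
gradient = Gradient(network)
occlusion = Occlusion(network)
saliency = occlusion(input_x, label)

# Benchmarkers: ClassSensitivity needs no labels; Robustness is initialized with num_labels.
class_sensitivity = ClassSensitivity()
cs_score = class_sensitivity.evaluate(gradient, input_x)
robustness = Robustness(num_labels=10)
rb_score = robustness.evaluate(gradient, input_x, label)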