!8876 add feature/explainer/ClassSensitivity, Occlusion, Robustness
From: @yuhanshi Reviewed-by: Signed-off-by:pull/8876/MERGE
commit
452cb0dd4e
@ -0,0 +1,73 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Class Sensitivity."""
|
||||
|
||||
import numpy as np
|
||||
|
||||
from mindspore import Tensor
|
||||
from .metric import LabelAgnosticMetric
|
||||
from ... import _operators as ops
|
||||
from ...explanation._attribution.attribution import Attribution
|
||||
from ..._utils import calc_correlation
|
||||
|
||||
|
||||
class ClassSensitivity(LabelAgnosticMetric):
|
||||
r"""
|
||||
Class sensitivity metric used to evaluate attribution-based explanations.
|
||||
|
||||
Reasonable atrribution-based explainers are expected to generate distinct saliency maps for different labels,
|
||||
especially for labels of highest confidence and low confidence. Class sensitivity evaluates the explainer through
|
||||
computing the correlation between saliency maps of highest-confidence and lowest-confidence labels. Explainer with
|
||||
better class sensitivity will receive lower correlation score. To make the evaluation results intuitive, the
|
||||
returned score will take negative on correlation and normalize.
|
||||
|
||||
"""
|
||||
|
||||
def evaluate(self, explainer: Attribution, inputs: Tensor) -> np.ndarray:
|
||||
"""
|
||||
Evaluate class sensitivity on a single data sample.
|
||||
|
||||
Args:
|
||||
explainer (Attribution): The explainer to be evaluated, see `mindspore.explainer.explanation`.
|
||||
inputs (Tensor): A data sample, a 4D tensor of shape :math:`(N, C, H, W)`.
|
||||
|
||||
Returns:
|
||||
numpy.ndarray, 1D array of shape :math:`(N,)`, result of class sensitivity evaluated on `explainer`.
|
||||
|
||||
Examples:
|
||||
>>> import mindspore as ms
|
||||
>>> from mindspore.explainer.explanation import Gradient
|
||||
>>> gradient = Gradient()
|
||||
>>> x = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32)
|
||||
>>> class_sensitivity = ClassSensitivity()
|
||||
>>> res = class_sensitivity.evaluate(gradient, x)
|
||||
"""
|
||||
self._check_evaluate_param(explainer, inputs)
|
||||
|
||||
outputs = explainer.model(inputs)
|
||||
|
||||
max_confidence_label = ops.argmax(outputs)
|
||||
min_confidence_label = ops.argmin(outputs)
|
||||
|
||||
max_confidence_saliency = explainer(inputs, max_confidence_label).asnumpy()
|
||||
min_confidence_saliency = explainer(inputs, min_confidence_label).asnumpy()
|
||||
|
||||
correlations = []
|
||||
for i in range(inputs.shape[0]):
|
||||
correlation = calc_correlation(max_confidence_saliency[i].reshape(-1),
|
||||
min_confidence_saliency[i].reshape(-1))
|
||||
normalized_correlation = (-correlation + 1) / 2
|
||||
correlations.append(normalized_correlation)
|
||||
return np.array(correlations, np.float)
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,134 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Robustness."""
|
||||
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore import log
|
||||
from .metric import LabelSensitiveMetric
|
||||
from ...explanation._attribution import Attribution
|
||||
from ...explanation._attribution._perturbation.replacement import RandomPerturb
|
||||
|
||||
_Array = np.ndarray
|
||||
_Label = Union[ms.Tensor, int]
|
||||
|
||||
|
||||
class Robustness(LabelSensitiveMetric):
|
||||
"""
|
||||
Robustness perturbs the inputs by adding random noise and choose the maximum sensitivity as evaluation score from
|
||||
the perturbations.
|
||||
|
||||
Args:
|
||||
num_labels (int): Number of classes in the dataset.
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.explainer.benchmark import Robustness
|
||||
>>> num_labels = 100
|
||||
>>> robustness = Robustness(num_labels)
|
||||
"""
|
||||
|
||||
def __init__(self, num_labels: int, activation_fn=nn.Softmax()):
|
||||
super().__init__(num_labels)
|
||||
|
||||
self._perturb = RandomPerturb()
|
||||
self._num_perturbations = 100 # number of perturbations used in evaluation
|
||||
self._threshold = 0.1 # threshold to generate perturbation
|
||||
self._activation_fn = activation_fn
|
||||
|
||||
def evaluate(self,
|
||||
explainer: Attribution,
|
||||
inputs: Tensor,
|
||||
targets: _Label,
|
||||
saliency: Optional[Tensor] = None
|
||||
) -> _Array:
|
||||
"""
|
||||
Evaluate robustness on single sample.
|
||||
|
||||
Note:
|
||||
Currently only single sample (:math:`N=1`) at each call is supported.
|
||||
|
||||
Args:
|
||||
explainer (Explanation): The explainer to be evaluated, see `mindspore.explainer.explanation`.
|
||||
inputs (Tensor): A data sample, a 4D tensor of shape :math:`(N, C, H, W)`.
|
||||
targets (Tensor, int): The label of interest. It should be a 1D or 0D tensor, or an integer.
|
||||
If `targets` is a 1D tensor, its length should be the same as `inputs`.
|
||||
saliency (Tensor, optional): The saliency map to be evaluated, a 4D tensor of shape :math:`(N, 1, H, W)`.
|
||||
If it is None, the parsed `explainer` will generate the saliency map with `inputs` and `targets` and
|
||||
continue the evaluation. Default: None.
|
||||
|
||||
Returns:
|
||||
numpy.ndarray, 1D array of shape :math:`(N,)`, result of localization evaluated on `explainer`.
|
||||
|
||||
Raises:
|
||||
ValueError: If batch_size is larger than 1.
|
||||
|
||||
Examples:
|
||||
>>> # init an explainer, the network should contain the output activation function.
|
||||
>>> from mindspore.explainer.explanation import Gradient
|
||||
>>> from mindspore.explainer.benchmark import Robustness
|
||||
>>> gradient = Gradient(network)
|
||||
>>> input_x = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32)
|
||||
>>> target_label = 5
|
||||
>>> robustness = Robustness(num_labels=10)
|
||||
>>> res = robustness.evaluate(gradient, input_x, target_label)
|
||||
"""
|
||||
|
||||
self._check_evaluate_param(explainer, inputs, targets, saliency)
|
||||
if inputs.shape[0] > 1:
|
||||
raise ValueError('Robustness only support a sample each time, but receive {}'.format(inputs.shape[0]))
|
||||
|
||||
inputs_np = inputs.asnumpy()
|
||||
if isinstance(targets, int):
|
||||
targets = ms.Tensor(targets, ms.int32)
|
||||
if saliency is None:
|
||||
saliency = explainer(inputs, targets)
|
||||
saliency_np = saliency.asnumpy()
|
||||
norm = np.sqrt(np.sum(np.square(saliency_np), axis=tuple(range(1, len(saliency_np.shape)))))
|
||||
if norm == 0:
|
||||
log.warning('Get saliency norm equals 0, robustness return NaN for zero-norm saliency currently.')
|
||||
return np.array([np.nan])
|
||||
|
||||
perturbations = []
|
||||
for sample in inputs_np:
|
||||
sample = np.expand_dims(sample, axis=0)
|
||||
perturbations_per_input = []
|
||||
for _ in range(self._num_perturbations):
|
||||
perturbation = self._perturb(sample)
|
||||
perturbations_per_input.append(perturbation)
|
||||
perturbations_per_input = np.vstack(perturbations_per_input)
|
||||
perturbations.append(perturbations_per_input)
|
||||
perturbations = np.stack(perturbations, axis=0)
|
||||
|
||||
perturbations = np.reshape(perturbations, (-1,) + inputs_np.shape[1:])
|
||||
perturbations = ms.Tensor(perturbations, ms.float32)
|
||||
|
||||
repeated_targets = np.repeat(targets.asnumpy(), repeats=self._num_perturbations, axis=0)
|
||||
repeated_targets = ms.Tensor(repeated_targets, ms.int32)
|
||||
saliency_of_perturbations = explainer(perturbations, repeated_targets)
|
||||
perturbations_saliency = saliency_of_perturbations.asnumpy()
|
||||
|
||||
repeated_saliency = np.repeat(saliency_np, repeats=self._num_perturbations, axis=0)
|
||||
|
||||
sensitivities = np.sum((repeated_saliency - perturbations_saliency) ** 2,
|
||||
axis=tuple(range(1, len(repeated_saliency.shape))))
|
||||
|
||||
max_sensitivity = np.max(sensitivities.reshape((norm.shape[0], -1)), axis=1) / norm
|
||||
robustness_res = 1 / np.exp(max_sensitivity)
|
||||
return robustness_res
|
@ -0,0 +1,182 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Modules to ablate images."""
|
||||
|
||||
__all__ = [
|
||||
'Ablation',
|
||||
'AblationWithSaliency',
|
||||
]
|
||||
|
||||
import math
|
||||
from functools import reduce
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .replacement import Constant
|
||||
from ...._utils import rank_pixels
|
||||
|
||||
|
||||
class Ablation:
|
||||
"""Base class to ablate image based on given replacement."""
|
||||
|
||||
def __init__(self, perturb_mode: str):
|
||||
self._perturb_mode = perturb_mode
|
||||
|
||||
def __call__(self,
|
||||
inputs: np.array,
|
||||
reference: Union[np.array, float],
|
||||
masks: np.array
|
||||
) -> np.array:
|
||||
|
||||
"""
|
||||
Generate perturbations of given array.
|
||||
|
||||
Args:
|
||||
inputs (np.ndarray): Input array to perturb. The first dim of inputs is assumed to be the batch size, i.e.,
|
||||
number of samples.
|
||||
reference (np.ndarray or float): Array of values to replace the elements in the original inputs. The shape
|
||||
of reference must math the inputs. If scalar is provided, the perturbed elements will be assigned the
|
||||
given value..
|
||||
masks (np.ndarray): Several boolean array to mark the perturbed positions. True marks the pixels to be
|
||||
perturbed, otherwise the pixels will be kept. The shape of masks is assumed to be
|
||||
[batch_size, num_perturbations, inputs_shape[1:]].
|
||||
|
||||
Return:
|
||||
perturbations (np.ndarray)
|
||||
"""
|
||||
if isinstance(reference, float):
|
||||
reference = Constant(base_value=reference)(inputs)
|
||||
|
||||
if not np.array_equal(inputs.shape, reference.shape):
|
||||
raise ValueError('reference must have the same shape as inputs.')
|
||||
|
||||
num_perturbations = masks.shape[1]
|
||||
|
||||
if self._perturb_mode == 'Insertion':
|
||||
inputs, reference = reference, inputs
|
||||
|
||||
perturbations = np.repeat(inputs[:, None, :], num_perturbations, 1)
|
||||
reference = np.repeat(reference[:, None, :], num_perturbations, 1)
|
||||
Ablation._assign(perturbations, reference, masks)
|
||||
|
||||
return perturbations
|
||||
|
||||
@staticmethod
|
||||
def _assign(original_array: np.ndarray, replacement: np.ndarray, masks: np.ndarray):
|
||||
"""Assign values to perturb pixels on perturbations."""
|
||||
if masks.dtype != bool:
|
||||
raise TypeError('The param "masks" should be an array of bool, but receive {}'.format(masks.dtype))
|
||||
|
||||
if not np.array_equal(original_array.shape, masks.shape):
|
||||
raise ValueError('masks must have the shape {} same as [batch_size, num_perturbations, inputs.shape[1:],'
|
||||
'but receive {}.'.format(original_array.shape, masks.shape))
|
||||
|
||||
original_array[masks] = replacement[masks]
|
||||
|
||||
|
||||
class AblationWithSaliency(Ablation):
|
||||
"""
|
||||
Perturbation generator to generate perturbations for a given array.
|
||||
|
||||
Args:
|
||||
perturb_percent (float): percentage of pixels to perturb
|
||||
perturb_mode (str): specify perturbing mode, through deleting or
|
||||
inserting pixels. Current support: ['Deletion', 'Insertion'].
|
||||
is_accumulate (bool): whether to accumulate the former perturbations to
|
||||
the later perturbations.
|
||||
perturb_pixel_per_step (int, optional): number of pixel to perturb
|
||||
for each perturbation. If perturb_pixel_per_step is None, actual
|
||||
perturb_pixel_per_step will be calculate by:
|
||||
num_image_pixel * perturb_percent / num_perturb_steps.
|
||||
Default: None
|
||||
num_perturbations (int, optional): number of perturbations. If
|
||||
num_perturbations if None, it will be calculated by:
|
||||
num_image_pixel * perturb_percent / perturb_pixel_per_step.
|
||||
Default: None
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
perturb_mode: str,
|
||||
perturb_percent: float = 1.0,
|
||||
is_accumulate: bool = False,
|
||||
perturb_pixel_per_step: Optional[int] = None,
|
||||
num_perturbations: Optional[int] = None):
|
||||
super().__init__(perturb_mode)
|
||||
self._perturb_percent = perturb_percent
|
||||
self._perturb_mode = perturb_mode
|
||||
self._pixel_per_step = perturb_pixel_per_step
|
||||
self._num_perturbations = num_perturbations
|
||||
self._is_accumulate = is_accumulate
|
||||
|
||||
def generate_mask(self,
|
||||
saliency: np.ndarray,
|
||||
num_channels: Optional[int] = None
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Generate mask for perturbations based on given saliency ranks.
|
||||
|
||||
Args:
|
||||
saliency (np.ndarray): Perturbing masks will be generated based on the given saliency map. The shape of
|
||||
saliency is expected to be: [batch_size, optional(num_channels), *spatial_size]. If multi-channel
|
||||
saliency is provided, an averaged saliency will be taken to calculate pixel order in spatial dimension.
|
||||
num_channels (optional[int]): Number of channels of the input data. In order to match the shape of inputs,
|
||||
num_channels should be provided when input data have channels dimension, even if num_channel. If None is
|
||||
provided, the inputs is assumed to be no-channel data, and the generated mask will have no channel
|
||||
dimension. Default: None.
|
||||
|
||||
Return:
|
||||
mask (np.ndarray): boolen mask for generate perturbations.
|
||||
"""
|
||||
|
||||
batch_size = saliency.shape[0]
|
||||
expected_num_dim = len(saliency.shape) + 1
|
||||
has_channel = num_channels is not None
|
||||
num_channels = 1 if num_channels is None else num_channels
|
||||
|
||||
if has_channel:
|
||||
saliency = saliency.mean(axis=1)
|
||||
saliency_rank = rank_pixels(saliency, descending=True)
|
||||
|
||||
num_pixels = reduce(lambda x, y: x * y, saliency.shape[1:])
|
||||
|
||||
if self._pixel_per_step:
|
||||
pixel_per_step = self._pixel_per_step
|
||||
num_perturbations = math.floor(num_pixels * self._perturb_percent / self._pixel_per_step)
|
||||
elif self._num_perturbations:
|
||||
pixel_per_step = math.floor(num_pixels * self._perturb_percent / self._num_perturbations)
|
||||
num_perturbations = self._num_perturbations
|
||||
else:
|
||||
raise ValueError("Must provide either pixel_per_step or num_perturbations.")
|
||||
|
||||
masks = np.zeros((batch_size, num_perturbations, num_channels, saliency_rank.shape[1], saliency_rank.shape[2]),
|
||||
dtype=np.bool)
|
||||
|
||||
factor = 0 if self._is_accumulate else 1
|
||||
|
||||
for i in range(batch_size):
|
||||
low_bound = 0
|
||||
up_bound = low_bound + pixel_per_step
|
||||
for j in range(num_perturbations):
|
||||
masks[i, j, :, ((saliency_rank[i] >= low_bound) & (saliency_rank[i] < up_bound))] = True
|
||||
low_bound = up_bound + factor
|
||||
up_bound += pixel_per_step
|
||||
|
||||
masks = masks if has_channel else np.squeeze(masks, axis=2)
|
||||
|
||||
if len(masks.shape) == expected_num_dim:
|
||||
return masks
|
||||
raise ValueError(f'Invalid masks shape {len(masks.shape)}, expect {expected_num_dim}-dim.')
|
@ -0,0 +1,166 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Occlusion explainer."""
|
||||
|
||||
import math
|
||||
from typing import Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
from numpy.lib.stride_tricks import as_strided
|
||||
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.nn import Cell
|
||||
from .ablation import Ablation
|
||||
from .perturbation import PerturbationAttribution
|
||||
from .replacement import Constant
|
||||
from ...._utils import abs_max
|
||||
|
||||
_Array = np.ndarray
|
||||
_Label = Union[int, Tensor]
|
||||
|
||||
|
||||
def _generate_patches(array, window_size, stride):
|
||||
"""View as windows."""
|
||||
if not isinstance(array, np.ndarray):
|
||||
raise TypeError("`array` must be a numpy ndarray")
|
||||
|
||||
arr_shape = np.array(array.shape)
|
||||
window_size = np.array(window_size, dtype=arr_shape.dtype)
|
||||
|
||||
slices = tuple(slice(None, None, st) for st in stride)
|
||||
window_strides = np.array(array.strides)
|
||||
|
||||
indexing_strides = array[slices].strides
|
||||
win_indices_shape = (((np.array(array.shape) - np.array(window_size)) // np.array(stride)) + 1)
|
||||
|
||||
new_shape = tuple(list(win_indices_shape) + list(window_size))
|
||||
strides = tuple(list(indexing_strides) + list(window_strides))
|
||||
|
||||
patches = as_strided(array, shape=new_shape, strides=strides)
|
||||
return patches
|
||||
|
||||
|
||||
class Occlusion(PerturbationAttribution):
|
||||
r"""
|
||||
Occlusion uses a sliding window to replace the pixels with a reference value (e.g. constant value), and computes
|
||||
the output difference w.r.t the original output. The output difference caused by perturbed pixels are assigned as
|
||||
feature importance to those pixels. For pixels involved in multiple sliding windows, the feature importance is the
|
||||
averaged differences from multiple sliding windows.
|
||||
|
||||
For more details, please refer to the original paper via: `<https://arxiv.org/abs/1311.2901>`_.
|
||||
|
||||
Args:
|
||||
network (Cell): Specify the black-box model to be explained.
|
||||
|
||||
Inputs:
|
||||
inputs (Tensor): The input data to be explained, a 4D tensor of shape :math:`(N, C, H, W)`.
|
||||
targets (Tensor, int): The label of interest. It should be a 1D or 0D tensor, or an integer.
|
||||
If it is a 1D tensor, its length should be the same as `inputs`.
|
||||
|
||||
Outputs:
|
||||
Tensor, a 4D tensor of shape :math:`(N, 1, H, W)`.
|
||||
|
||||
Example:
|
||||
>>> from mindspore.explainer.explanation import Occlusion
|
||||
>>> net = resnet50(10)
|
||||
>>> param_dict = load_checkpoint("resnet50.ckpt")
|
||||
>>> load_param_into_net(net, param_dict)
|
||||
>>> occlusion = Occlusion(net)
|
||||
>>> x = ms.Tensor(np.random.rand([1, 3, 224, 224]), ms.float32)
|
||||
>>> label = 1
|
||||
>>> saliency = occlusion(x, label)
|
||||
"""
|
||||
|
||||
def __init__(self, network: Cell, activation_fn: Cell = nn.Softmax()):
|
||||
super().__init__(network, activation_fn)
|
||||
|
||||
self._ablation = Ablation(perturb_mode='Deletion')
|
||||
self._aggregation_fn = abs_max
|
||||
self._get_replacement = Constant(base_value=0.0)
|
||||
self._num_sample_per_dim = 32 # specify the number of perturbations each dimension.
|
||||
self._num_per_eval = 32 # number of perturbations each evaluation step.
|
||||
|
||||
def __call__(self, inputs: Tensor, targets: _Label) -> Tensor:
|
||||
"""Call function for 'Occlusion'."""
|
||||
self._verify_data(inputs, targets)
|
||||
|
||||
inputs = inputs.asnumpy()
|
||||
targets = targets.asnumpy() if isinstance(targets, Tensor) else np.array([targets] * inputs.shape[0], np.int)
|
||||
|
||||
# If spatial size of input data is smaller than self._num_sample_per_dim, window_size and strides will set to
|
||||
# `(C, 3, 3)` and `(C, 1, 1)` separately.
|
||||
window_size = tuple(
|
||||
[inputs.shape[1]]
|
||||
+ [x % self._num_sample_per_dim if x > self._num_sample_per_dim else 3 for x in inputs.shape[2:]])
|
||||
strides = tuple(
|
||||
[inputs.shape[1]]
|
||||
+ [x // self._num_sample_per_dim if x > self._num_sample_per_dim else 1 for x in inputs.shape[2:]])
|
||||
|
||||
model = nn.SequentialCell([self._model, self._activation_fn])
|
||||
|
||||
original_outputs = model(Tensor(inputs, ms.float32)).asnumpy()[np.arange(len(targets)), targets]
|
||||
|
||||
total_attribution = np.zeros_like(inputs)
|
||||
weights = np.ones_like(inputs)
|
||||
masks = Occlusion._generate_masks(inputs, window_size, strides)
|
||||
num_perturbations = masks.shape[1]
|
||||
original_outputs_repeat = np.repeat(original_outputs, repeats=num_perturbations, axis=0)
|
||||
|
||||
reference = self._get_replacement(inputs)
|
||||
occluded_inputs = self._ablation(inputs, reference, masks)
|
||||
targets_repeat = np.repeat(targets, repeats=num_perturbations, axis=0)
|
||||
|
||||
occluded_inputs = occluded_inputs.reshape((-1, *inputs.shape[1:]))
|
||||
if occluded_inputs.shape[0] > self._num_per_eval:
|
||||
cal_time = math.ceil(occluded_inputs.shape[0] / self._num_per_eval)
|
||||
occluded_outputs = []
|
||||
for i in range(cal_time):
|
||||
occluded_input = occluded_inputs[i*self._num_per_eval
|
||||
:min((i+1) * self._num_per_eval, occluded_inputs.shape[0])]
|
||||
target = targets_repeat[i*self._num_per_eval
|
||||
:min((i+1) * self._num_per_eval, occluded_inputs.shape[0])]
|
||||
occluded_output = model(Tensor(occluded_input)).asnumpy()[np.arange(target.shape[0]), target]
|
||||
occluded_outputs.append(occluded_output)
|
||||
occluded_outputs = np.concatenate(occluded_outputs)
|
||||
else:
|
||||
occluded_outputs = model(Tensor(occluded_inputs)).asnumpy()[np.arange(len(targets_repeat)), targets_repeat]
|
||||
outputs_diff = original_outputs_repeat - occluded_outputs
|
||||
outputs_diff = outputs_diff.reshape(inputs.shape[0], -1)
|
||||
|
||||
total_attribution += (
|
||||
outputs_diff.reshape(outputs_diff.shape + (1,) * (len(masks.shape) - 2)) * masks).sum(axis=1).clip(1e-6)
|
||||
weights += masks.sum(axis=1)
|
||||
|
||||
attribution = self._aggregation_fn(ms.Tensor(total_attribution / weights))
|
||||
return attribution
|
||||
|
||||
@staticmethod
|
||||
def _generate_masks(inputs: Tensor, window_size: Tuple[int, ...], strides: Tuple[int, ...]) -> _Array:
|
||||
"""Generate masks to perturb contiguous regions."""
|
||||
total_dim = np.prod(inputs.shape[1:]).item()
|
||||
template = np.arange(total_dim).reshape(inputs.shape[1:])
|
||||
indices = _generate_patches(template, window_size, strides)
|
||||
num_perturbations = indices.reshape((-1,) + window_size).shape[0]
|
||||
indices = indices.reshape(num_perturbations, -1)
|
||||
|
||||
mask = np.zeros((num_perturbations, total_dim), dtype=np.bool)
|
||||
for i in range(num_perturbations):
|
||||
mask[i, indices[i]] = True
|
||||
mask = mask.reshape((num_perturbations,) + inputs.shape[1:])
|
||||
|
||||
masks = np.tile(mask, reps=(inputs.shape[0],) + (1,) * len(mask.shape))
|
||||
return masks
|
@ -0,0 +1,85 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Modules to generate perturbations."""
|
||||
|
||||
import numpy as np
|
||||
from scipy.ndimage.filters import gaussian_filter
|
||||
|
||||
_Array = np.ndarray
|
||||
|
||||
__all__ = [
|
||||
'BaseReplacement',
|
||||
'Constant',
|
||||
'GaussianBlur',
|
||||
'RandomPerturb',
|
||||
]
|
||||
|
||||
|
||||
class BaseReplacement:
|
||||
"""
|
||||
Base class of generator for generating different replacement for perturbations.
|
||||
|
||||
Args:
|
||||
kwargs: Optional args for generating replacement. Derived class need to
|
||||
add necessary arg names and default value to '_necessary_args'.
|
||||
If the argument has no default value, the value should be set to
|
||||
'EMPTY' to mark the required args. Initializing an object will
|
||||
check the given kwargs w.r.t '_necessary_args'.
|
||||
|
||||
Raise:
|
||||
ValueError: Raise when provided kwargs not contain necessary arg names with 'EMPTY' mark.
|
||||
"""
|
||||
_necessary_args = {}
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self._replace_args = self._necessary_args.copy()
|
||||
for key, value in self._replace_args.items():
|
||||
if key in kwargs.keys():
|
||||
self._replace_args[key] = kwargs[key]
|
||||
elif key not in kwargs.keys() and value == 'EMPTY':
|
||||
raise ValueError(f"Missing keyword arg {key} for {self.__class__.__name__}.")
|
||||
|
||||
def __call__(self, inputs):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class Constant(BaseReplacement):
|
||||
"""Generator to provide constant-value replacement for perturbations."""
|
||||
_necessary_args = {'base_value': 'EMPTY'}
|
||||
|
||||
def __call__(self, inputs: _Array) -> _Array:
|
||||
replacement = np.ones_like(inputs, dtype=np.float32)
|
||||
replacement *= self._replace_args['base_value']
|
||||
return replacement
|
||||
|
||||
|
||||
class GaussianBlur(BaseReplacement):
|
||||
"""Generator to provided gaussian blurred inputs for perturbation"""
|
||||
_necessary_args = {'sigma': 0.7}
|
||||
|
||||
def __call__(self, inputs: _Array) -> _Array:
|
||||
sigma = self._replace_args['sigma']
|
||||
replacement = gaussian_filter(inputs, sigma=sigma)
|
||||
return replacement
|
||||
|
||||
|
||||
class RandomPerturb(BaseReplacement):
|
||||
"""Generator to provide replacement by randomly adding noise."""
|
||||
_necessary_args = {'radius': 0.2}
|
||||
|
||||
def __call__(self, inputs: _Array) -> _Array:
|
||||
radius = self._replace_args['radius']
|
||||
outputs = inputs + (2 * np.random.rand(*inputs.shape) - 1) * radius
|
||||
return outputs
|
Loading…
Reference in new issue