|
|
|
@ -15,7 +15,6 @@
|
|
|
|
|
"""Occlusion explainer."""
|
|
|
|
|
|
|
|
|
|
import math
|
|
|
|
|
from typing import Tuple, Union
|
|
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
from numpy.lib.stride_tricks import as_strided
|
|
|
|
@ -23,15 +22,11 @@ from numpy.lib.stride_tricks import as_strided
|
|
|
|
|
import mindspore as ms
|
|
|
|
|
import mindspore.nn as nn
|
|
|
|
|
from mindspore import Tensor
|
|
|
|
|
from mindspore.nn import Cell
|
|
|
|
|
from .ablation import Ablation
|
|
|
|
|
from .perturbation import PerturbationAttribution
|
|
|
|
|
from .replacement import Constant
|
|
|
|
|
from ...._utils import abs_max
|
|
|
|
|
|
|
|
|
|
_Array = np.ndarray
|
|
|
|
|
_Label = Union[int, Tensor]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _generate_patches(array, window_size, stride):
|
|
|
|
|
"""View as windows."""
|
|
|
|
@ -76,16 +71,17 @@ class Occlusion(PerturbationAttribution):
|
|
|
|
|
|
|
|
|
|
Example:
|
|
|
|
|
>>> from mindspore.explainer.explanation import Occlusion
|
|
|
|
|
>>> net = resnet50(10)
|
|
|
|
|
>>> from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
|
|
|
|
>>> network = resnet50(10)
|
|
|
|
|
>>> param_dict = load_checkpoint("resnet50.ckpt")
|
|
|
|
|
>>> load_param_into_net(net, param_dict)
|
|
|
|
|
>>> occlusion = Occlusion(net)
|
|
|
|
|
>>> x = ms.Tensor(np.random.rand([1, 3, 224, 224]), ms.float32)
|
|
|
|
|
>>> load_param_into_net(network, param_dict)
|
|
|
|
|
>>> occlusion = Occlusion(network)
|
|
|
|
|
>>> x = Tensor(np.random.rand([1, 3, 224, 224]), ms.float32)
|
|
|
|
|
>>> label = 1
|
|
|
|
|
>>> saliency = occlusion(x, label)
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, network: Cell, activation_fn: Cell = nn.Softmax()):
|
|
|
|
|
def __init__(self, network, activation_fn=nn.Softmax()):
|
|
|
|
|
super().__init__(network, activation_fn)
|
|
|
|
|
|
|
|
|
|
self._ablation = Ablation(perturb_mode='Deletion')
|
|
|
|
@ -94,7 +90,7 @@ class Occlusion(PerturbationAttribution):
|
|
|
|
|
self._num_sample_per_dim = 32 # specify the number of perturbations each dimension.
|
|
|
|
|
self._num_per_eval = 32 # number of perturbations each evaluation step.
|
|
|
|
|
|
|
|
|
|
def __call__(self, inputs: Tensor, targets: _Label) -> Tensor:
|
|
|
|
|
def __call__(self, inputs, targets):
|
|
|
|
|
"""Call function for 'Occlusion'."""
|
|
|
|
|
self._verify_data(inputs, targets)
|
|
|
|
|
|
|
|
|
@ -145,11 +141,11 @@ class Occlusion(PerturbationAttribution):
|
|
|
|
|
outputs_diff.reshape(outputs_diff.shape + (1,) * (len(masks.shape) - 2)) * masks).sum(axis=1).clip(1e-6)
|
|
|
|
|
weights += masks.sum(axis=1)
|
|
|
|
|
|
|
|
|
|
attribution = self._aggregation_fn(ms.Tensor(total_attribution / weights))
|
|
|
|
|
attribution = self._aggregation_fn(Tensor(total_attribution / weights))
|
|
|
|
|
return attribution
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def _generate_masks(inputs: Tensor, window_size: Tuple[int, ...], strides: Tuple[int, ...]) -> _Array:
|
|
|
|
|
def _generate_masks(inputs, window_size, strides):
|
|
|
|
|
"""Generate masks to perturb contiguous regions."""
|
|
|
|
|
total_dim = np.prod(inputs.shape[1:]).item()
|
|
|
|
|
template = np.arange(total_dim).reshape(inputs.shape[1:])
|
|
|
|
|