|
|
|
@ -14,14 +14,11 @@
|
|
|
|
|
# ============================================================================
|
|
|
|
|
"""Occlusion explainer."""
|
|
|
|
|
|
|
|
|
|
import math
|
|
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
from numpy.lib.stride_tricks import as_strided
|
|
|
|
|
|
|
|
|
|
import mindspore as ms
|
|
|
|
|
import mindspore.nn as nn
|
|
|
|
|
from mindspore import Tensor
|
|
|
|
|
from .ablation import Ablation
|
|
|
|
|
from .perturbation import PerturbationAttribution
|
|
|
|
|
from .replacement import Constant
|
|
|
|
@ -62,8 +59,8 @@ class Occlusion(PerturbationAttribution):
|
|
|
|
|
network (Cell): Specify the black-box model to be explained.
|
|
|
|
|
|
|
|
|
|
Inputs:
|
|
|
|
|
inputs (Tensor): The input data to be explained, a 4D tensor of shape :math:`(N, C, H, W)`.
|
|
|
|
|
targets (Tensor, int): The label of interest. It should be a 1D or 0D tensor, or an integer.
|
|
|
|
|
- **inputs** (Tensor) - The input data to be explained, a 4D tensor of shape :math:`(N, C, H, W)`.
|
|
|
|
|
- **targets** (Tensor, int) - The label of interest. It should be a 1D or 0D tensor, or an integer.
|
|
|
|
|
If it is a 1D tensor, its length should be the same as `inputs`.
|
|
|
|
|
|
|
|
|
|
Outputs:
|
|
|
|
@ -72,13 +69,15 @@ class Occlusion(PerturbationAttribution):
|
|
|
|
|
Example:
|
|
|
|
|
>>> from mindspore.explainer.explanation import Occlusion
|
|
|
|
|
>>> from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
|
|
|
|
>>> # prepare your network and load the trained checkpoint file, e.g., resnet50.
|
|
|
|
|
>>> network = resnet50(10)
|
|
|
|
|
>>> param_dict = load_checkpoint("resnet50.ckpt")
|
|
|
|
|
>>> load_param_into_net(network, param_dict)
|
|
|
|
|
>>> # initialize Occlusion explainer and pass the pretrained model
|
|
|
|
|
>>> occlusion = Occlusion(network)
|
|
|
|
|
>>> x = Tensor(np.random.rand(1, 3, 224, 224), ms.float32)
|
|
|
|
|
>>> label = 1
|
|
|
|
|
>>> saliency = occlusion(x, label)
|
|
|
|
|
>>> input_x = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32)
|
|
|
|
|
>>> label = ms.Tensor([1], ms.int32)
|
|
|
|
|
>>> saliency = occlusion(input_x, label)
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, network, activation_fn=nn.Softmax()):
|
|
|
|
@ -88,62 +87,63 @@ class Occlusion(PerturbationAttribution):
|
|
|
|
|
self._aggregation_fn = abs_max
|
|
|
|
|
self._get_replacement = Constant(base_value=0.0)
|
|
|
|
|
self._num_sample_per_dim = 32 # specify the number of perturbations each dimension.
|
|
|
|
|
self._num_per_eval = 32 # number of perturbations each evaluation step.
|
|
|
|
|
self._num_per_eval = 2 # number of perturbations generate for each sample per evaluation step.
|
|
|
|
|
|
|
|
|
|
def __call__(self, inputs, targets):
    """
    Call function for 'Occlusion'.

    Occludes regions of `inputs` with the replacement value and measures how much the activated
    network output of the target class drops for every occluded region. Perturbations are evaluated
    in chunks of at most `self._num_per_eval` masks per sample, so only one chunk of occluded inputs
    is materialised at a time.

    Args:
        inputs (Tensor): The input data to be explained, a 4D tensor of shape :math:`(N, C, H, W)`.
        targets (Tensor, int): The label of interest. It should be a 1D or 0D tensor, or an integer.
            A single label (integer, 0D tensor or 1D tensor of length 1) is applied to every sample.

    Returns:
        The result of `self._aggregation_fn` applied to the per-element attribution tensor
        (accumulated output drop divided by the accumulated mask weight).
    """
    self._verify_data(inputs, targets)

    inputs_np = inputs.asnumpy()
    batch_size = inputs_np.shape[0]

    # Normalise the targets to a 1D index array with one entry per sample. `np.int` was removed in
    # NumPy 1.24, the builtin `int` is the equivalent dtype.
    if isinstance(targets, ms.Tensor):
        targets_np = targets.asnumpy().reshape(-1)
    else:
        targets_np = np.array([targets], int)
    if targets_np.size == 1:
        # A single label is shared by the whole batch; without this the repeated targets below would
        # not line up with `np.arange(num_samples)` when the batch holds more than one sample.
        targets_np = np.repeat(targets_np, batch_size)

    window_size, strides = self._get_window_size_and_strides(inputs_np)

    model = nn.SequentialCell([self._model, self._activation_fn])

    original_outputs = model(ms.Tensor(inputs_np, ms.float32)).asnumpy()[np.arange(batch_size), targets_np]

    total_attribution = np.zeros_like(inputs_np)
    # Weights start at one so that elements never covered by a mask do not divide by zero.
    weights = np.ones_like(inputs_np)
    masks = Occlusion._generate_masks(inputs_np, window_size, strides)
    num_perturbations = masks.shape[1]

    reference = self._get_replacement(inputs_np)
    # Singleton axes appended to the output differences so they broadcast against the masks.
    trailing_dims = (1,) * (len(masks.shape) - 2)

    for start in range(0, num_perturbations, self._num_per_eval):
        ith_masks = masks[:, start:start + self._num_per_eval]
        actual_num_eval = ith_masks.shape[1]
        num_samples = batch_size * actual_num_eval

        occluded_inputs = self._ablation(inputs_np, reference, ith_masks)
        # Fold the perturbation axis into the batch axis so the model sees a plain batch.
        occluded_inputs = occluded_inputs.reshape((-1, *inputs_np.shape[1:]))

        # Repeat per sample so that entry k matches row k of the flattened occluded batch.
        targets_repeat = np.repeat(targets_np, repeats=actual_num_eval, axis=0)
        occluded_outputs = model(
            ms.Tensor(occluded_inputs, ms.float32)).asnumpy()[np.arange(num_samples), targets_repeat]
        original_outputs_repeat = np.repeat(original_outputs, repeats=actual_num_eval, axis=0)
        outputs_diff = original_outputs_repeat - occluded_outputs

        total_attribution += (
            outputs_diff.reshape(ith_masks.shape[:2] + trailing_dims) * ith_masks).sum(axis=1)
        weights += ith_masks.sum(axis=1)

    attribution = self._aggregation_fn(ms.Tensor(total_attribution / weights, ms.float32))
    return attribution
|
|
|
|
|
|
|
|
|
|
def _get_window_size_and_strides(self, inputs):
|
|
|
|
|
"""
|
|
|
|
|
Return window_size and strides.
|
|
|
|
|
|
|
|
|
|
# If spatial size of input data is smaller than self._num_sample_per_dim, window_size and strides will set to
|
|
|
|
|
# `(C, 3, 3)` and `(C, 1, 1)` separately. Otherwise, the window_size and strides will generated adaptively to
|
|
|
|
|
match self._num_sample_per_dim.
|
|
|
|
|
"""
|
|
|
|
|
window_size = tuple(
|
|
|
|
|
[inputs.shape[1]]
|
|
|
|
|
+ [x // self._num_sample_per_dim if x > self._num_sample_per_dim else 3 for x in inputs.shape[2:]])
|
|
|
|
|
strides = tuple(
|
|
|
|
|
[inputs.shape[1]]
|
|
|
|
|
+ [x // self._num_sample_per_dim if x > self._num_sample_per_dim else 1 for x in inputs.shape[2:]])
|
|
|
|
|
return window_size, strides
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def _generate_masks(inputs, window_size, strides):
|
|
|
|
|
"""Generate masks to perturb contiguous regions."""
|
|
|
|
|