|
|
|
@ -1,4 +1,4 @@
|
|
|
|
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
|
|
|
|
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
|
|
|
|
#
|
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
@ -62,23 +62,22 @@ class RISE(PerturbationAttribution):
|
|
|
|
|
>>> import numpy as np
|
|
|
|
|
>>> import mindspore as ms
|
|
|
|
|
>>> from mindspore.explainer.explanation import RISE
|
|
|
|
|
>>> from mindspore.nn import Sigmoid
|
|
|
|
|
>>> from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
|
|
|
|
>>> # prepare your network and load the trained checkpoint file, e.g., resnet50.
|
|
|
|
|
>>> network = resnet50(10)
|
|
|
|
|
>>> param_dict = load_checkpoint("resnet50.ckpt")
|
|
|
|
|
>>> load_param_into_net(network, param_dict)
|
|
|
|
|
>>>
|
|
|
|
|
>>> # The detail of LeNet5 is shown in model_zoo.official.cv.lenet.src.lenet.py
|
|
|
|
|
>>> net = LeNet5(10, num_channel=3)
|
|
|
|
|
>>> # initialize RISE explainer with the pretrained model and activation function
|
|
|
|
|
>>> activation_fn = ms.nn.Softmax() # softmax layer is applied to transform logits to probabilities
|
|
|
|
|
>>> rise = RISE(network, activation_fn=activation_fn)
|
|
|
|
|
>>> rise = RISE(net, activation_fn=activation_fn)
|
|
|
|
|
>>> # given an instance of RISE, saliency map can be generate
|
|
|
|
|
>>> inputs = ms.Tensor(np.random.rand(2, 3, 224, 224), ms.float32)
|
|
|
|
|
>>> inputs = ms.Tensor(np.random.rand(2, 3, 32, 32), ms.float32)
|
|
|
|
|
>>> # when `targets` is an integer
|
|
|
|
|
>>> targets = 5
|
|
|
|
|
>>> saliency = rise(inputs, targets)
|
|
|
|
|
>>> print(saliency.shape)
|
|
|
|
|
>>> # `targets` can also be a 2D tensor
|
|
|
|
|
>>> targets = ms.Tensor([[5], [1]], ms.int32)
|
|
|
|
|
>>> saliency = rise(inputs, targets)
|
|
|
|
|
>>> print(saliency.shape)
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self,
|
|
|
|
@ -88,7 +87,7 @@ class RISE(PerturbationAttribution):
|
|
|
|
|
super(RISE, self).__init__(network, activation_fn, perturbation_per_eval)
|
|
|
|
|
|
|
|
|
|
self._num_masks = 6000 # number of masks to be sampled
|
|
|
|
|
self._mask_probability = 0.2 # ratio of inputs to be masked
|
|
|
|
|
self._mask_probability = 0.5 # ratio of inputs to be masked
|
|
|
|
|
self._down_sample_size = 10 # the original size of binary masks
|
|
|
|
|
self._resize_mode = 'bilinear' # mode choice to resize the down-sized binary masks to size of the inputs
|
|
|
|
|
self._perturbation_mode = 'constant' # setting the perturbed pixels to a constant value
|
|
|
|
@ -127,7 +126,9 @@ class RISE(PerturbationAttribution):
|
|
|
|
|
self._num_classes = num_classes
|
|
|
|
|
|
|
|
|
|
# Due to the unsupported Op of slice assignment, we use numpy array here
|
|
|
|
|
attr_np = np.zeros(shape=(batch_size, self._num_classes, height, width))
|
|
|
|
|
targets = self._unify_targets(inputs, targets)
|
|
|
|
|
|
|
|
|
|
attr_np = np.zeros(shape=(batch_size, targets.shape[1], height, width))
|
|
|
|
|
|
|
|
|
|
cal_times = math.ceil(self._num_masks / self._perturbation_per_eval)
|
|
|
|
|
|
|
|
|
@ -143,24 +144,21 @@ class RISE(PerturbationAttribution):
|
|
|
|
|
weights = self._activation_fn(self.network(masked_input))
|
|
|
|
|
while len(weights.shape) > 2:
|
|
|
|
|
weights = op.mean(weights, axis=2)
|
|
|
|
|
weights = op.reshape(weights,
|
|
|
|
|
(bs, self._num_classes, 1, 1))
|
|
|
|
|
|
|
|
|
|
attr_np[idx] += op.summation(weights * masks, axis=0).asnumpy()
|
|
|
|
|
weights = np.expand_dims(np.expand_dims(weights.asnumpy()[:, targets[idx]], 2), 3)
|
|
|
|
|
|
|
|
|
|
attr_np = attr_np / self._num_masks
|
|
|
|
|
targets = self._unify_targets(inputs, targets)
|
|
|
|
|
attr_np[idx] += np.sum(weights * masks.asnumpy(), axis=0)
|
|
|
|
|
|
|
|
|
|
attr_classes = [att_i[target] for att_i, target in zip(attr_np, targets)]
|
|
|
|
|
attr_np = attr_np / self._num_masks
|
|
|
|
|
|
|
|
|
|
return op.Tensor(attr_classes, dtype=inputs.dtype)
|
|
|
|
|
return op.Tensor(attr_np, dtype=inputs.dtype)
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def _verify_data(inputs, targets):
|
|
|
|
|
"""Verify the validity of the parsed inputs."""
|
|
|
|
|
check_value_type('inputs', inputs, Tensor)
|
|
|
|
|
if len(inputs.shape) != 4:
|
|
|
|
|
raise ValueError('Argument inputs must be 4D Tensor')
|
|
|
|
|
raise ValueError(f'Argument inputs must be 4D Tensor, but got {len(inputs.shape)}D Tensor.')
|
|
|
|
|
check_value_type('targets', targets, (Tensor, int, tuple, list))
|
|
|
|
|
if isinstance(targets, Tensor):
|
|
|
|
|
if len(targets.shape) > 2:
|
|
|
|
@ -168,7 +166,7 @@ class RISE(PerturbationAttribution):
|
|
|
|
|
'But got {}D.'.format(len(targets.shape)))
|
|
|
|
|
if targets.shape and len(targets) != len(inputs):
|
|
|
|
|
raise ValueError(
|
|
|
|
|
'If `targets` is a 2D, 1D Tensor, it should have the same length as inputs {}. But got {}'.format(
|
|
|
|
|
'If `targets` is a 2D, 1D Tensor, it should have the same length as inputs {}. But got {}.'.format(
|
|
|
|
|
len(inputs), len(targets)))
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|