This commit provides APIs for user to use the widely used attribution methods to explain DL models and the evaluation methods to quantify the explanations. With combination of MindInsight, the user can have a friendly visualization on their models.pull/7656/head
parent
a418280659
commit
744f094add
@ -0,0 +1,19 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Provide ExplainRunner High-level API."""
|
||||||
|
|
||||||
|
from ._runner import ExplainRunner
|
||||||
|
|
||||||
|
__all__ = ['ExplainRunner']
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,23 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Predefined XAI metrics."""
|
||||||
|
|
||||||
|
from ._attribution.faithfulness import Faithfulness
|
||||||
|
from ._attribution.localization import Localization
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"Faithfulness",
|
||||||
|
"Localization"
|
||||||
|
]
|
@ -0,0 +1,23 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Predefined XAI metrics"""
|
||||||
|
|
||||||
|
from .faithfulness import Faithfulness
|
||||||
|
from .localization import Localization
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"Faithfulness",
|
||||||
|
"Localization"
|
||||||
|
]
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,146 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Localization metrics."""
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from mindspore.train._utils import check_value_type
|
||||||
|
from .metric import AttributionMetric
|
||||||
|
from ..._operators import maximum, reshape, Tensor
|
||||||
|
from ..._utils import format_tensor_to_ndarray
|
||||||
|
|
||||||
|
|
||||||
|
def _get_max_position(saliency):
|
||||||
|
"""Get the position of the max pixel of the saliency map."""
|
||||||
|
saliency = saliency.asnumpy()
|
||||||
|
w = saliency.shape[3]
|
||||||
|
saliency = np.reshape(saliency, (len(saliency), -1))
|
||||||
|
max_arg = np.argmax(saliency, axis=1)
|
||||||
|
return max_arg // w, max_arg - (max_arg // w) * w
|
||||||
|
|
||||||
|
|
||||||
|
def _mask_out_saliency(saliency, threshold):
|
||||||
|
"""Keep the saliency map with value greater than threshold."""
|
||||||
|
max_value = maximum(saliency)
|
||||||
|
mask_out = saliency > (reshape(max_value, (len(saliency), -1, 1, 1)) * threshold)
|
||||||
|
return mask_out
|
||||||
|
|
||||||
|
|
||||||
|
class Localization(AttributionMetric):
|
||||||
|
"""
|
||||||
|
Provides evaluation on the localization capability of XAI methods.
|
||||||
|
|
||||||
|
We support two metrics for the evaluation os localization capability: "PointingGame" and "IoSR".
|
||||||
|
For metric "PointingGame", the localization capability is calculated as the ratio of data in which the max position
|
||||||
|
of their saliency maps lies within the bounding boxes. Specifically, for a single datum, given the saliency map and
|
||||||
|
its bounding box, if the max point of its saliency map lies within the bounding box, the evaluation result is 1
|
||||||
|
otherwise 0.
|
||||||
|
|
||||||
|
For metric "IoSR" (Intersection over Salient Region), the localization capability is calculated as the intersection
|
||||||
|
of the bounding box and the salient region over the area of the salient region.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_labels (int): number of classes in the dataset.
|
||||||
|
metric (str): specific metric to calculate localization capability.
|
||||||
|
Options: "PointingGame", "IoSR".
|
||||||
|
Default: "PointingGame".
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> from mindspore.explainer.benchmark import Localization
|
||||||
|
>>> num_labels = 100
|
||||||
|
>>> localization = Localization(num_labels, "PointingGame")
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
num_labels,
|
||||||
|
metric="PointingGame"
|
||||||
|
):
|
||||||
|
super(Localization, self).__init__(num_labels)
|
||||||
|
self._verify_metrics(metric)
|
||||||
|
self._metric = metric
|
||||||
|
|
||||||
|
# Arg for specific metric, for "PointingGame" it should be an integer indicating the tolerance
|
||||||
|
# of "PointingGame", while for "IoSR" it should be a float number
|
||||||
|
# indicating the threshold to choose salient region. Default: 25.
|
||||||
|
if self._metric == "PointingGame":
|
||||||
|
self._metric_arg = 15
|
||||||
|
else:
|
||||||
|
self._metric_arg = 0.5
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _verify_metrics(metric):
|
||||||
|
"""Verify the user defined metric."""
|
||||||
|
supports = ["PointingGame", "IoSR"]
|
||||||
|
if metric not in supports:
|
||||||
|
raise ValueError("Metric should be one of {}".format(supports))
|
||||||
|
|
||||||
|
def evaluate(self, explainer, inputs, targets, saliency=None, mask=None):
|
||||||
|
"""
|
||||||
|
Evaluate localization on a single data sample.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
explainer (Explanation): The explainer to be evaluated, see `mindspore/explainer/explanation`.
|
||||||
|
inputs (Tensor): data sample. Currently only support single sample at each call.
|
||||||
|
targets (int): target label to evaluate on.
|
||||||
|
saliency (Tensor): A saliency tensor.
|
||||||
|
mask (Union[Tensor, np.ndarray]): ground truth bounding box/masks for the inputs w.r.t targets.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
np.ndarray, result of localization evaluated on explainer
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> # init an explainer, the network should contain the output activation function.
|
||||||
|
>>> gradient = Gradient(network)
|
||||||
|
>>> inputs = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32)
|
||||||
|
>>> masks = np.zeros(1, 1, 224, 224)
|
||||||
|
>>> masks[:, :, 65: 100, 65: 100] = 1
|
||||||
|
>>> targets = 5
|
||||||
|
>>> # usage 1: input the explainer and the data to be explained,
|
||||||
|
>>> # calculate the faithfulness with the specified metric
|
||||||
|
>>> res = localization.evaluate(gradient, inputs, targets, mask=masks)
|
||||||
|
>>> # usage 2: input the generated saliency map
|
||||||
|
>>> saliency = gradient(inputs, targets)
|
||||||
|
>>> res = localization.evaluate(gradient, inputs, targets, saliency, mask=masks)
|
||||||
|
"""
|
||||||
|
self._check_evaluate_param(explainer, inputs, targets, saliency)
|
||||||
|
|
||||||
|
mask_np = format_tensor_to_ndarray(mask)[0]
|
||||||
|
|
||||||
|
if saliency is None:
|
||||||
|
saliency = explainer(inputs, targets)
|
||||||
|
|
||||||
|
if self._metric == "PointingGame":
|
||||||
|
point = _get_max_position(saliency)
|
||||||
|
|
||||||
|
x, y = np.meshgrid(
|
||||||
|
(np.arange(mask_np.shape[1]) - point[0]) ** 2,
|
||||||
|
(np.arange(mask_np.shape[2]) - point[1]) ** 2)
|
||||||
|
max_region = (x + y) < self._metric_arg ** 2
|
||||||
|
|
||||||
|
# if max_region has overlap with mask_np return 1 otherwise 0.
|
||||||
|
result = 1 if (mask_np.astype(bool) & max_region).any() else 0
|
||||||
|
|
||||||
|
elif self._metric == "IoSR":
|
||||||
|
mask_out = _mask_out_saliency(saliency, self._metric_arg)
|
||||||
|
mask_out_np = format_tensor_to_ndarray(mask_out)
|
||||||
|
overlap = np.sum(mask_np.astype(bool) & mask_out_np.astype(bool))
|
||||||
|
saliency_area = np.sum(mask_out_np)
|
||||||
|
result = overlap / saliency_area.clip(min=1e-10)
|
||||||
|
return np.array([result], np.float)
|
||||||
|
|
||||||
|
def _check_evaluate_param_with_mask(self, explainer, inputs, targets, saliency, mask):
|
||||||
|
self._check_evaluate_param(explainer, inputs, targets, saliency)
|
||||||
|
check_value_type('mask', mask, (Tensor, np.ndarray))
|
||||||
|
if len(inputs.shape) != 4:
|
||||||
|
raise ValueError('Argument mask must be 4D Tensor')
|
@ -0,0 +1,123 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Base class for XAI metrics."""
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from mindspore.train._utils import check_value_type
|
||||||
|
from ..._operators import Tensor
|
||||||
|
from ..._utils import format_tensor_to_ndarray
|
||||||
|
from ...explanation._attribution._attribution import Attribution
|
||||||
|
|
||||||
|
|
||||||
|
def verify_argument(inputs, arg_name):
|
||||||
|
"""Verify the validity of the parsed arguments."""
|
||||||
|
check_value_type(arg_name, inputs, Tensor)
|
||||||
|
if len(inputs.shape) != 4:
|
||||||
|
raise ValueError('Argument {} must be a 4D Tensor.'.format(arg_name))
|
||||||
|
if len(inputs) > 1:
|
||||||
|
raise ValueError('Support single data evaluation only, but got {}.'.format(len(inputs)))
|
||||||
|
|
||||||
|
|
||||||
|
def verify_targets(targets, num_labels):
|
||||||
|
"""Verify the validity of the parsed targets."""
|
||||||
|
check_value_type('targets', targets, (int, Tensor))
|
||||||
|
|
||||||
|
if isinstance(targets, Tensor):
|
||||||
|
if len(targets.shape) > 1 or (len(targets.shape) == 1 and len(targets) != 1):
|
||||||
|
raise ValueError('Argument targets must be a 1D or 0D Tensor. If it is a 1D Tensor, '
|
||||||
|
'it should have the length = 1 as we only support single evaluation now.')
|
||||||
|
targets = int(targets.asnumpy()[0]) if len(targets.shape) == 1 else int(targets.asnumpy())
|
||||||
|
if targets > num_labels - 1 or targets < 0:
|
||||||
|
raise ValueError('Parsed targets exceed the label range.')
|
||||||
|
|
||||||
|
|
||||||
|
class AttributionMetric:
|
||||||
|
"""Super class of XAI metric class used in classification scenarios."""
|
||||||
|
|
||||||
|
def __init__(self, num_labels=None):
|
||||||
|
self._num_labels = num_labels
|
||||||
|
self._global_results = {i: [] for i in range(num_labels)}
|
||||||
|
|
||||||
|
def evaluate(self, explainer, inputs, targets, saliency=None):
|
||||||
|
"""This function evaluates on a single sample and return the result."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def aggregate(self, result, targets):
|
||||||
|
"""Aggregates single result to global_results."""
|
||||||
|
if isinstance(result, float):
|
||||||
|
if isinstance(targets, int):
|
||||||
|
self._global_results[targets].append(result)
|
||||||
|
else:
|
||||||
|
target_np = format_tensor_to_ndarray(targets)
|
||||||
|
if len(target_np) > 1:
|
||||||
|
raise ValueError("One result can not be aggreated to multiple targets.")
|
||||||
|
else:
|
||||||
|
result_np = format_tensor_to_ndarray(result)
|
||||||
|
if isinstance(targets, int):
|
||||||
|
for res in result_np:
|
||||||
|
self._global_results[targets].append(float(res))
|
||||||
|
else:
|
||||||
|
target_np = format_tensor_to_ndarray(targets)
|
||||||
|
if len(target_np) != len(result_np):
|
||||||
|
raise ValueError("Length of result does not match with length of targets.")
|
||||||
|
for tar, res in zip(target_np, result_np):
|
||||||
|
self._global_results[int(tar)].append(float(res))
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
"""Resets global_result."""
|
||||||
|
self._global_results = {i: [] for i in range(self._num_labels)}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def class_performances(self):
|
||||||
|
"""
|
||||||
|
Get the class performances by global result.
|
||||||
|
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(:class:`np.ndarray`): :attr:`num_labels`-dimensional vector
|
||||||
|
containing per-class performance.
|
||||||
|
"""
|
||||||
|
count = np.array(
|
||||||
|
[len(self._global_results[i]) for i in range(self._num_labels)])
|
||||||
|
result_sum = np.array(
|
||||||
|
[sum(self._global_results[i]) for i in range(self._num_labels)])
|
||||||
|
return result_sum / count.clip(min=1)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def performance(self):
|
||||||
|
"""
|
||||||
|
Get the performance by global result.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(:class:`float`): mean performance.
|
||||||
|
"""
|
||||||
|
count = sum(
|
||||||
|
[len(self._global_results[i]) for i in range(self._num_labels)])
|
||||||
|
result_sum = sum(
|
||||||
|
[sum(self._global_results[i]) for i in range(self._num_labels)])
|
||||||
|
if count == 0:
|
||||||
|
return 0
|
||||||
|
return result_sum / count
|
||||||
|
|
||||||
|
def get_results(self):
|
||||||
|
"""Global result of the metric can be return"""
|
||||||
|
return self._global_results
|
||||||
|
|
||||||
|
def _check_evaluate_param(self, explainer, inputs, targets, saliency):
|
||||||
|
"""Check the evaluate parameters."""
|
||||||
|
check_value_type('explainer', explainer, Attribution)
|
||||||
|
verify_argument(inputs, 'inputs')
|
||||||
|
verify_targets(targets, self._num_labels)
|
||||||
|
check_value_type('saliency', saliency, (Tensor, type(None)))
|
@ -0,0 +1,26 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Predefined Attribution explainers."""
|
||||||
|
|
||||||
|
from ._attribution._backprop.gradcam import GradCAM
|
||||||
|
from ._attribution._backprop.gradient import Gradient
|
||||||
|
from ._attribution._backprop.modified_relu import Deconvolution, GuidedBackprop
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'Gradient',
|
||||||
|
'Deconvolution',
|
||||||
|
'GuidedBackprop',
|
||||||
|
'GradCAM',
|
||||||
|
]
|
@ -0,0 +1,25 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Predefined Attribution explainers."""
|
||||||
|
from ._backprop.gradcam import GradCAM
|
||||||
|
from ._backprop.gradient import Gradient
|
||||||
|
from ._backprop.modified_relu import Deconvolution, GuidedBackprop
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'Gradient',
|
||||||
|
'Deconvolution',
|
||||||
|
'GuidedBackprop',
|
||||||
|
'GradCAM',
|
||||||
|
]
|
@ -0,0 +1,60 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Attribution."""
|
||||||
|
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
import mindspore as ms
|
||||||
|
|
||||||
|
class Attribution:
|
||||||
|
r"""
|
||||||
|
Basic class of attributing the salient score
|
||||||
|
|
||||||
|
The explainers which explanation through attributing the relevance scores
|
||||||
|
should inherit this class.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
network (ms.nn.Cell): The black-box model to explanation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, network):
|
||||||
|
self._verify_model(network)
|
||||||
|
self._model = network
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _verify_model(model):
|
||||||
|
"""
|
||||||
|
Verify the input `network` for __init__ function.
|
||||||
|
"""
|
||||||
|
if not isinstance(model, ms.nn.Cell):
|
||||||
|
raise TypeError("The parsed `network` must be a `mindspore.nn.Cell` object.")
|
||||||
|
|
||||||
|
|
||||||
|
__call__: Callable
|
||||||
|
"""
|
||||||
|
The explainers return the explanations by calling directly on the explanation.
|
||||||
|
Derived class should overwrite this implementations for different
|
||||||
|
algorithms.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input (ms.Tensor): Input tensor to be explained.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
- saliency map (ms.Tensor): saliency map of the input.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def model(self):
|
||||||
|
return self._model
|
@ -0,0 +1,24 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Backprop-base _attribution explainer."""
|
||||||
|
|
||||||
|
from .gradient import Gradient
|
||||||
|
from .gradcam import GradCAM
|
||||||
|
from .modified_relu import Deconvolution, GuidedBackprop
|
||||||
|
|
||||||
|
__all__ = ['Gradient',
|
||||||
|
'GradCAM',
|
||||||
|
'Deconvolution',
|
||||||
|
'GuidedBackprop']
|
@ -0,0 +1,49 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Providing utility functions."""
|
||||||
|
|
||||||
|
from mindspore.ops.composite import GradOperation
|
||||||
|
|
||||||
|
from ...._utils import unify_inputs, unify_targets, generate_one_hot
|
||||||
|
|
||||||
|
|
||||||
|
def compute_gradients(model, inputs, targets=None, weights=None):
|
||||||
|
r"""
|
||||||
|
Compute the gradient of output w.r.t input.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (`ms.nn.Cell`): Differentiable black-box model.
|
||||||
|
inputs (`ms.Tensor`): Input to calculate gradient and explanation.
|
||||||
|
targets (int, optional): Target label id specifying which category to compute gradient. Default: None.
|
||||||
|
weights (`ms.Tensor`, optional): Custom weights for computing gradients. The shape of weights should match the
|
||||||
|
model outputs. If None is provided, an one-hot weights with one in targets positions will be used instead.
|
||||||
|
Default: None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
saliency map (ms.Tensor): Gradient back-propagated to the input.
|
||||||
|
"""
|
||||||
|
inputs = unify_inputs(inputs)
|
||||||
|
if targets is None and weights is None:
|
||||||
|
raise ValueError('Must provide one of targets or weights')
|
||||||
|
if weights is None:
|
||||||
|
targets = unify_targets(targets)
|
||||||
|
output = model(*inputs).asnumpy()
|
||||||
|
num_categories = output.shape[-1]
|
||||||
|
weights = generate_one_hot(targets, num_categories)
|
||||||
|
|
||||||
|
grad_op = GradOperation(
|
||||||
|
get_all=True, get_by_list=False, sens_param=True)(model)
|
||||||
|
gradients = grad_op(*inputs, weights)
|
||||||
|
return gradients[0]
|
@ -0,0 +1,141 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
""" GradCAM and GuidedGradCAM. """
|
||||||
|
|
||||||
|
from mindspore.ops import operations as op
|
||||||
|
|
||||||
|
from .backprop_utils import compute_gradients
|
||||||
|
from .intermediate_layer import IntermediateLayerAttribution
|
||||||
|
from ...._utils import ForwardProbe, retrieve_layer, unify_inputs, unify_targets
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def _gradcam_aggregation(attributions):
|
||||||
|
"""
|
||||||
|
Aggregate the gradient and activation to get the final _attribution.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
attributions (Tensor): the _attribution with channel dimension.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: the _attribution with channel dimension aggregated.
|
||||||
|
"""
|
||||||
|
sum_ = op.ReduceSum(keep_dims=True)
|
||||||
|
relu_ = op.ReLU()
|
||||||
|
attributions = relu_(sum_(attributions, 1))
|
||||||
|
return attributions
|
||||||
|
|
||||||
|
|
||||||
|
class GradCAM(IntermediateLayerAttribution):
|
||||||
|
r"""
|
||||||
|
Provides GradCAM explanation method.
|
||||||
|
|
||||||
|
GradCAM generates saliency map at intermediate layer.
|
||||||
|
..math:
|
||||||
|
\alpha_k^c = 1/Z \sum_i \sum_j \div{\partial{y^c}}{\partial{A_{i,j}^k}}
|
||||||
|
L_{GradCAM} = ReLu(\sum_k \alpha_k^c A^k)
|
||||||
|
For more details, please refer to the original paper: GradCAM
|
||||||
|
[https://openaccess.thecvf.com/content_ICCV_2017/papers/Selvaraju_Grad-CAM_Visual_Explanations_ICCV_2017_paper.pdf]
|
||||||
|
|
||||||
|
Args:
|
||||||
|
network (Cell): The black-box model to be explained.
|
||||||
|
layer (str): The layer name to generate the explanation at. Default: ''.
|
||||||
|
If default, the explantion will be generated at the input layer.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> net = resnet50(10)
|
||||||
|
>>> param_dict = load_checkpoint("resnet50.ckpt")
|
||||||
|
>>> load_param_into_net(net, param_dict)
|
||||||
|
>>> # bind net with its output activation if you wish, e.g. nn.Sigmoid(),
|
||||||
|
>>> # you may also use the net itself.
|
||||||
|
>>> net = nn.SequentialCell([net, nn.Sigmoid()])
|
||||||
|
>>> # specify a layer name to generate explanation, usually the layer can be set as the last conv layer.
|
||||||
|
>>> layer_name = '0.layer4'
|
||||||
|
>>> # init GradCAM with a trained network and specify the layer to obtain
|
||||||
|
>>> gradcam = GradCAM(net, layer=layer_name)
|
||||||
|
>>> # parse data and the target label to be explained and get the saliency map
|
||||||
|
>>> inputs = ms.Tensor(np.random.rand([1, 3, 224, 224]), ms.float32)
|
||||||
|
>>> label = 5
|
||||||
|
>>> saliency = gradcam(inputs, label)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
network,
|
||||||
|
layer=""):
|
||||||
|
super(GradCAM, self).__init__(network, layer)
|
||||||
|
|
||||||
|
self._saliency_cell = retrieve_layer(self._backward_model, target_layer=layer)
|
||||||
|
self._avgpool = op.ReduceMean(keep_dims=True)
|
||||||
|
self._intermediate_grad = None
|
||||||
|
self._aggregation_fn = _gradcam_aggregation
|
||||||
|
self._resize_mode = 'bilinear'
|
||||||
|
|
||||||
|
def _hook_cell(self):
|
||||||
|
if self._saliency_cell:
|
||||||
|
self._saliency_cell.register_backward_hook(self._cell_hook_fn)
|
||||||
|
self._saliency_cell.enable_hook = True
|
||||||
|
self._intermediate_grad = None
|
||||||
|
|
||||||
|
def _cell_hook_fn(self, _, grad_input, grad_output):
|
||||||
|
"""
|
||||||
|
Hook function to deal with the backward gradient.
|
||||||
|
|
||||||
|
The arguments are set as required by Cell.register_back_hook
|
||||||
|
"""
|
||||||
|
self._intermediate_grad = grad_input
|
||||||
|
|
||||||
|
def __call__(self, inputs, targets):
|
||||||
|
"""
|
||||||
|
Call function for `GradCAM`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
inputs (Tensor): The input data to be explained, 4D Tensor.
|
||||||
|
targets (Union[Tensor, int]): The label of interest. It should be a 1D or 0D Tensor, or an integer.
|
||||||
|
If `targets` is a 1D Tensor, its length should be the same as `inputs`.
|
||||||
|
"""
|
||||||
|
self._verify_data(inputs, targets)
|
||||||
|
self._hook_cell()
|
||||||
|
|
||||||
|
with ForwardProbe(self._saliency_cell) as probe:
|
||||||
|
|
||||||
|
inputs = unify_inputs(inputs)
|
||||||
|
targets = unify_targets(targets)
|
||||||
|
|
||||||
|
gradients = compute_gradients(self._backward_model, *inputs, targets)
|
||||||
|
|
||||||
|
# get intermediate activation
|
||||||
|
activation = (probe.value,)
|
||||||
|
|
||||||
|
if self._layer == "":
|
||||||
|
activation = inputs
|
||||||
|
self._intermediate_grad = unify_inputs(gradients)
|
||||||
|
if self._intermediate_grad is not None:
|
||||||
|
# average pooling on gradients
|
||||||
|
intermediate_grad = unify_inputs(
|
||||||
|
self._avgpool(self._intermediate_grad[0], (2, 3)))
|
||||||
|
else:
|
||||||
|
raise ValueError("Gradient for intermediate layer is not "
|
||||||
|
"obtained")
|
||||||
|
mul = op.Mul()
|
||||||
|
attribution = self._aggregation_fn(
|
||||||
|
mul(*intermediate_grad, *activation))
|
||||||
|
if self._resize:
|
||||||
|
attribution = self._resize_fn(attribution, *inputs,
|
||||||
|
mode=self._resize_mode)
|
||||||
|
self._intermediate_grad = None
|
||||||
|
|
||||||
|
return attribution
|
@ -0,0 +1,129 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Gradient explainer."""
|
||||||
|
from copy import deepcopy
|
||||||
|
|
||||||
|
from mindspore import nn
|
||||||
|
from mindspore.ops import operations as op
|
||||||
|
from mindspore.train._utils import check_value_type
|
||||||
|
from ...._operators import reshape, sqrt, Tensor
|
||||||
|
from .._attribution import Attribution
|
||||||
|
from .backprop_utils import compute_gradients
|
||||||
|
from ...._utils import unify_inputs, unify_targets
|
||||||
|
|
||||||
|
|
||||||
|
def _get_hook(bntype, cache):
|
||||||
|
"""Provide backward hook function for BatchNorm layer in eval mode."""
|
||||||
|
var, gamma, eps = cache
|
||||||
|
if bntype == "2d":
|
||||||
|
var = reshape(var, (1, -1, 1, 1))
|
||||||
|
gamma = reshape(gamma, (1, -1, 1, 1))
|
||||||
|
elif bntype == "1d":
|
||||||
|
var = reshape(var, (1, -1, 1))
|
||||||
|
gamma = reshape(gamma, (1, -1, 1))
|
||||||
|
|
||||||
|
def reset_gradient(_, grad_input, grad_output):
|
||||||
|
grad_output = grad_input[0] * gamma / sqrt(var + eps)
|
||||||
|
return grad_output
|
||||||
|
|
||||||
|
return reset_gradient
|
||||||
|
|
||||||
|
|
||||||
|
def _abs_max(gradients):
|
||||||
|
"""
|
||||||
|
Transform gradients to saliency through abs then take max along
|
||||||
|
channels.
|
||||||
|
"""
|
||||||
|
gradients = op.Abs()(gradients)
|
||||||
|
saliency = op.ReduceMax(keep_dims=True)(gradients, axis=1)
|
||||||
|
return saliency
|
||||||
|
|
||||||
|
|
||||||
|
class Gradient(Attribution):
|
||||||
|
r"""
|
||||||
|
Provides Gradient explanation method.
|
||||||
|
|
||||||
|
Gradient is the simplest attribution method which uses the naive gradients of outputs w.r.t inputs as the
|
||||||
|
explanation.
|
||||||
|
|
||||||
|
.. math::
|
||||||
|
_attribution = \div{\delta{y}, \delta{x}}
|
||||||
|
|
||||||
|
Args:
|
||||||
|
network (Cell): The black-box model to be explained.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> net = resnet50(10)
|
||||||
|
>>> param_dict = load_checkpoint("resnet50.ckpt")
|
||||||
|
>>> load_param_into_net(net, param_dict)
|
||||||
|
>>> # bind net with its output activation if you wish, e.g. nn.Sigmoid(),
|
||||||
|
>>> # you may also use the net itself. The saliency map might be slightly different for softmax activation.
|
||||||
|
>>> net = nn.SequentialCell([net, nn.Sigmoid()])
|
||||||
|
>>> # init Gradient with a trained network.
|
||||||
|
>>> gradient = Gradient(net)
|
||||||
|
>>> # parse data and the target label to be explained and get the saliency map
|
||||||
|
>>> inputs = ms.Tensor(np.random.rand([1, 3, 224, 224]), ms.float32)
|
||||||
|
>>> label = 5
|
||||||
|
>>> saliency = gradient(inputs, label)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, network):
|
||||||
|
super(Gradient, self).__init__(network)
|
||||||
|
self._backward_model = deepcopy(network)
|
||||||
|
self._backward_model.set_train(False)
|
||||||
|
self._backward_model.set_grad(False)
|
||||||
|
self._hook_bn()
|
||||||
|
self._grad_op = compute_gradients
|
||||||
|
self._aggregation_fn = _abs_max
|
||||||
|
|
||||||
|
|
||||||
|
def __call__(self, inputs, targets):
|
||||||
|
"""
|
||||||
|
Call function for `Gradient`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
inputs (Tensor): The input data to be explained, 4D Tensor.
|
||||||
|
targets (Union[Tensor, int]): The label of interest. It should be a 1D or 0D Tensor, or an integer.
|
||||||
|
If `targets` is a 1D `Tensor`, its length should be the same as `inputs`.
|
||||||
|
"""
|
||||||
|
self._verify_data(inputs, targets)
|
||||||
|
inputs = unify_inputs(inputs)
|
||||||
|
targets = unify_targets(targets)
|
||||||
|
|
||||||
|
gradient = self._grad_op(self._backward_model, *inputs, targets)
|
||||||
|
saliency = self._aggregation_fn(gradient)
|
||||||
|
return saliency
|
||||||
|
|
||||||
|
def _hook_bn(self):
|
||||||
|
"""Hook BatchNorm layer for `self._backward_model.`"""
|
||||||
|
for _, cell in self._backward_model.cells_and_names():
|
||||||
|
if isinstance(cell, nn.BatchNorm2d):
|
||||||
|
cache = (cell.moving_variance, cell.gamma, cell.eps)
|
||||||
|
cell.register_backward_hook(_get_hook("2d", cache=cache))
|
||||||
|
elif isinstance(cell, nn.BatchNorm1d):
|
||||||
|
cache = (cell.moving_variance, cell.gamma, cell.eps)
|
||||||
|
cell.register_backward_hook(_get_hook("1d", cache=cache))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _verify_data(inputs, targets):
|
||||||
|
"""Verify the validity of the parsed inputs."""
|
||||||
|
check_value_type('inputs', inputs, Tensor)
|
||||||
|
if len(inputs.shape) != 4:
|
||||||
|
raise ValueError('Argument inputs must be 4D Tensor')
|
||||||
|
check_value_type('targets', targets, (Tensor, int))
|
||||||
|
if isinstance(targets, Tensor):
|
||||||
|
if len(targets.shape) > 1 or (len(targets.shape) == 1 and len(targets) != len(inputs)):
|
||||||
|
raise ValueError('Argument targets must be a 1D or 0D Tensor. If it is a 1D Tensor, '
|
||||||
|
'it should have the same length as inputs.')
|
@ -0,0 +1,47 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""Base class IntermediateLayerAttribution"""
|
||||||
|
|
||||||
|
from .gradient import Gradient
|
||||||
|
from ...._utils import resize as resize_fn
|
||||||
|
|
||||||
|
|
||||||
|
class IntermediateLayerAttribution(Gradient):
|
||||||
|
"""
|
||||||
|
Base class for generating _attribution map at intermediate layer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
network (nn.Cell): DNN model to be explained.
|
||||||
|
layer (str, optional): string that specifies the layer to generate
|
||||||
|
intermediate _attribution. When using default value, the input layer
|
||||||
|
will be specified. Default: ''.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, network, layer=''):
|
||||||
|
super(IntermediateLayerAttribution, self).__init__(network)
|
||||||
|
|
||||||
|
# Whether resize the _attribution layer to the input size.
|
||||||
|
self._resize = True
|
||||||
|
# string that specifies the resize mode. Default: 'nearest_neighbor'.
|
||||||
|
self._resize_mode = 'nearest_neighbor'
|
||||||
|
|
||||||
|
self._layer = layer
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _resize_fn(attributions, inputs, mode):
|
||||||
|
"""Resize the intermediate layer _attribution to the same size as inputs."""
|
||||||
|
height, width = inputs.shape[2], inputs.shape[3]
|
||||||
|
return resize_fn(attributions, (height, width), mode)
|
@ -0,0 +1,117 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Explainer with modified ReLU."""
|
||||||
|
|
||||||
|
import mindspore.nn as nn
|
||||||
|
import mindspore.ops.operations as op
|
||||||
|
|
||||||
|
from .gradient import Gradient
|
||||||
|
from ...._utils import (
|
||||||
|
unify_inputs,
|
||||||
|
unify_targets,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ModifiedReLU(Gradient):
|
||||||
|
"""Basic class for modified ReLU explanation."""
|
||||||
|
|
||||||
|
def __init__(self, network, use_relu_backprop=False):
|
||||||
|
super(ModifiedReLU, self).__init__(network)
|
||||||
|
self.use_relu_backprop = use_relu_backprop
|
||||||
|
self.hooked_list = []
|
||||||
|
|
||||||
|
def __call__(self, inputs, targets):
|
||||||
|
self._verify_data(inputs, targets)
|
||||||
|
inputs = unify_inputs(inputs)
|
||||||
|
targets = unify_targets(targets)
|
||||||
|
|
||||||
|
self._hook_relu_backward()
|
||||||
|
gradients = self._grad_op(self._backward_model, inputs, targets)
|
||||||
|
saliency = self._aggregation_fn(gradients)
|
||||||
|
|
||||||
|
return saliency
|
||||||
|
|
||||||
|
def _hook_relu_backward(self):
|
||||||
|
"""Set backward hook for ReLU layers."""
|
||||||
|
for _, cell in self._backward_model.cells_and_names():
|
||||||
|
if isinstance(cell, nn.ReLU):
|
||||||
|
cell.register_backward_hook(self._backward_hook)
|
||||||
|
self.hooked_list.append(cell)
|
||||||
|
|
||||||
|
def _backward_hook(self, _, grad_inputs, grad_outputs):
|
||||||
|
"""Hook function for ReLU layers."""
|
||||||
|
inputs = grad_inputs if self.use_relu_backprop else grad_outputs
|
||||||
|
relu = op.ReLU()
|
||||||
|
if isinstance(inputs, tuple):
|
||||||
|
return relu(*inputs)
|
||||||
|
return relu(inputs)
|
||||||
|
|
||||||
|
|
||||||
|
class Deconvolution(ModifiedReLU):
|
||||||
|
"""
|
||||||
|
Deconvolution explanation.
|
||||||
|
|
||||||
|
To use `Deconvolution`, the `ReLU` operations in the network must be implemented with `mindspore.nn.Cell` object
|
||||||
|
rather than `mindspore.ops.Operations.ReLU`. Otherwise, the results will not be correct.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
network (Cell): The black-box model to be explained.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> net = resnet50(10)
|
||||||
|
>>> param_dict = load_checkpoint("resnet50.ckpt")
|
||||||
|
>>> load_param_into_net(net, param_dict)
|
||||||
|
>>> # bind net with its output activation if you wish, e.g. nn.Sigmoid(),
|
||||||
|
>>> # you may also use the net itself. The saliency map might be slightly different for softmax activation.
|
||||||
|
>>> net = nn.SequentialCell([net, nn.Sigmoid()])
|
||||||
|
>>> # init Gradient with a trained network.
|
||||||
|
>>> deconvolution = Deconvolution(net)
|
||||||
|
>>> # parse data and the target label to be explained and get the saliency map
|
||||||
|
>>> inputs = ms.Tensor(np.random.rand([1, 3, 224, 224]), ms.float32)
|
||||||
|
>>> label = 5
|
||||||
|
>>> saliency = deconvolution(inputs, label)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, network):
|
||||||
|
super(Deconvolution, self).__init__(network, use_relu_backprop=True)
|
||||||
|
|
||||||
|
|
||||||
|
class GuidedBackprop(ModifiedReLU):
|
||||||
|
"""
|
||||||
|
Guided-Backpropation explanation.
|
||||||
|
|
||||||
|
To use `GuidedBackprop`, the `ReLU` operations in the network must be implemented with `mindspore.nn.Cell` object
|
||||||
|
rather than `mindspore.ops.Operations.ReLU`. Otherwise, the results will not be correct.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
network (Cell): The black-box model to be explained.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> net = resnet50(10)
|
||||||
|
>>> param_dict = load_checkpoint("resnet50.ckpt")
|
||||||
|
>>> load_param_into_net(net, param_dict)
|
||||||
|
>>> # bind net with its output activation if you wish, e.g. nn.Sigmoid(),
|
||||||
|
>>> # you may also use the net itself. The saliency map might be slightly different for softmax activation.
|
||||||
|
>>> net = nn.SequentialCell([net, nn.Sigmoid()])
|
||||||
|
>>> # init Gradient with a trained network.
|
||||||
|
>>> gbp = GuidedBackprop(net)
|
||||||
|
>>> # parse data and the target label to be explained and get the saliency map
|
||||||
|
>>> inputs = ms.Tensor(np.random.rand([1, 3, 224, 224]), ms.float32)
|
||||||
|
>>> label = 5
|
||||||
|
>>> saliency = gbp(inputs, label)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, network):
|
||||||
|
super(GuidedBackprop, self).__init__(network, use_relu_backprop=False)
|
Loading…
Reference in new issue