@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -16,13 +16,14 @@
 import mindspore.nn as nn
 import mindspore.ops.operations as op
 
-from .gradient import Gradient
-from ...._utils import (
+from mindspore.explainer._utils import (
     unify_inputs,
     unify_targets,
 )
 
+from .backprop_utils import GradNet, get_bp_weights
+from .gradient import Gradient
 
 
 class ModifiedReLU(Gradient):
     """Basic class for modified ReLU explanation."""
@@ -30,7 +31,8 @@ class ModifiedReLU(Gradient):
     def __init__(self, network, use_relu_backprop=False):
         super(ModifiedReLU, self).__init__(network)
         self.use_relu_backprop = use_relu_backprop
-        self.hooked_list = []
+        self._hook_relu_backward()
+        self._grad_net = GradNet(self._backward_model)
 
     def __call__(self, inputs, targets):
         """
@@ -56,8 +58,8 @@ class ModifiedReLU(Gradient):
         inputs = unify_inputs(inputs)
         targets = unify_targets(targets)
 
-        self._hook_relu_backward()
-        gradients = self._grad_op(self._backward_model, inputs, targets)
+        weights = get_bp_weights(self._backward_model, inputs, targets)
+        gradients = self._grad_net(*inputs, weights)
         saliency = self._aggregation_fn(gradients)
 
         return saliency
@@ -67,7 +69,6 @@ class ModifiedReLU(Gradient):
         for _, cell in self._backward_model.cells_and_names():
             if isinstance(cell, nn.ReLU):
                 cell.register_backward_hook(self._backward_hook)
-                self.hooked_list.append(cell)
 
     def _backward_hook(self, _, grad_inputs, grad_outputs):
         """Hook function for ReLU layers."""
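
For context on how the refactored path is exercised end to end: __init__ now registers the ReLU backward hooks once and builds a reusable GradNet, and __call__ computes back-propagation weights with get_bp_weights before running that gradient network. Below is a minimal usage sketch, assuming the public GuidedBackprop explainer (a ModifiedReLU subclass exported from mindspore.explainer.explanation) and an illustrative toy network; the network, shapes, and class index here are not part of this change.

import numpy as np

import mindspore as ms
import mindspore.nn as nn
from mindspore import context
from mindspore.explainer.explanation import GuidedBackprop

# Assumption for this sketch: the explainers are run in PyNative mode.
context.set_context(mode=context.PYNATIVE_MODE)


class ToyNet(nn.Cell):
    """Tiny network with a ReLU layer so the backward hooks have something to modify."""

    def __init__(self, num_classes=10):
        super(ToyNet, self).__init__()
        self.conv = nn.Conv2d(3, 8, 3)
        self.relu = nn.ReLU()
        self.flatten = nn.Flatten()
        self.fc = nn.Dense(8 * 32 * 32, num_classes)

    def construct(self, x):
        x = self.relu(self.conv(x))
        return self.fc(self.flatten(x))


net = ToyNet()
explainer = GuidedBackprop(net)  # __init__ hooks the ReLU layers and builds the GradNet once

inputs = ms.Tensor(np.random.rand(1, 3, 32, 32), ms.float32)
targets = 5  # class index to explain (illustrative)
saliency = explainer(inputs, targets)  # __call__ uses get_bp_weights + GradNet internally
print(saliency.shape)  # expected (1, 1, 32, 32) for this toy setup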