Fix bug I2CZVJ: shape error during ReLU gradients

pull/11400/head
lixiaohui 4 years ago
parent fa3638ad6b
commit 7284308828

@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,12 +14,12 @@
 # ============================================================================
 """Providing utility functions."""
+from mindspore.nn import Cell
 from mindspore.ops.composite import GradOperation
-from ...._utils import unify_inputs, unify_targets, generate_one_hot
+from mindspore.explainer._utils import unify_inputs, unify_targets, generate_one_hot
-def compute_gradients(model, inputs, targets=None, weights=None):
+def get_bp_weights(model, inputs, targets=None, weights=None):
     r"""
     Compute the gradient of output w.r.t input.
@@ -42,8 +42,28 @@ def compute_gradients(model, inputs, targets=None, weights=None):
         output = model(*inputs)
         num_categories = output.shape[-1]
         weights = generate_one_hot(targets, num_categories)
-    grad_op = GradOperation(
-        get_all=True, get_by_list=False, sens_param=True)(model)
-    gradients = grad_op(*inputs, weights)
-    return gradients[0]
+    return weights
+class GradNet(Cell):
+    """
+    Network for gradient calculation.
+    Args:
+        network (Cell): The network to generate backpropagated gradients.
+    """
+    def __init__(self, network):
+        super(GradNet, self).__init__()
+        self.network = network
+        self.grad = GradOperation(get_all=True, sens_param=True)(network)
+    def construct(self, *input_data):
+        """
+        Get backpropagated gradients.
+        Returns:
+            Tensor, output gradients.
+        """
+        gout = self.grad(*input_data)[0]
+        return gout

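The split above makes the gradient path explicit: get_bp_weights turns the target class indices into a one-hot sensitivity tensor with the same shape as the model output, and GradNet feeds that tensor to a sens_param GradOperation so the returned gradient is taken w.r.t. the selected logits only. A minimal, self-contained sketch of the same pattern (the toy two-layer network, shapes, and class indices below are illustrative assumptions, not code from this commit):

import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore.ops.composite import GradOperation

class ToyNet(nn.Cell):
    """Tiny stand-in classifier: 4 features -> 3 categories."""
    def __init__(self):
        super(ToyNet, self).__init__()
        self.fc1 = nn.Dense(4, 8)
        self.relu = nn.ReLU()
        self.fc2 = nn.Dense(8, 3)

    def construct(self, x):
        return self.fc2(self.relu(self.fc1(x)))

net = ToyNet()
x = ms.Tensor(np.random.rand(2, 4).astype(np.float32))      # batch of 2 samples

# Equivalent of get_bp_weights: one-hot "sensitivity" selecting each sample's target logit.
targets = np.array([0, 2])
weights = ms.Tensor(np.eye(3, dtype=np.float32)[targets])   # shape (2, 3), matches net(x)

# Equivalent of GradNet: gradients of the selected logits w.r.t. all network inputs.
grad_fn = GradOperation(get_all=True, sens_param=True)(net)
input_grad = grad_fn(x, weights)[0]                          # first (and only) input
print(input_grad.shape)                                      # (2, 4), same as x
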
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -16,10 +16,10 @@
 """GradCAM."""
 from mindspore.ops import operations as op
-from .backprop_utils import compute_gradients
+from mindspore.explainer._utils import ForwardProbe, retrieve_layer, unify_inputs, unify_targets
+from .backprop_utils import get_bp_weights, GradNet
 from .intermediate_layer import IntermediateLayerAttribution
-from ...._utils import ForwardProbe, retrieve_layer, unify_inputs, unify_targets
 def _gradcam_aggregation(attributions):
@@ -123,8 +123,9 @@ class GradCAM(IntermediateLayerAttribution):
         inputs = unify_inputs(inputs)
         targets = unify_targets(targets)
-        gradients = compute_gradients(self._backward_model, *inputs, targets)
+        weights = get_bp_weights(self._backward_model, *inputs, targets)
+        grad_net = GradNet(self._backward_model)
+        gradients = grad_net(*inputs, weights)
         # get intermediate activation
         activation = (probe.value,)

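Broadly, Grad-CAM pairs the probed intermediate activation (probe.value) with the gradient flowing back through that layer: the spatial mean of each channel's gradient becomes that channel's weight, and the saliency map is the ReLU of the weighted sum of activation maps. A small NumPy sketch of that combination step, with made-up array shapes (this is the standard Grad-CAM aggregation, not the file's exact _gradcam_aggregation code):

import numpy as np

def gradcam_combine(activation, gradient):
    """activation, gradient: arrays of shape (N, C, H, W) from the probed layer."""
    channel_weights = gradient.mean(axis=(2, 3), keepdims=True)   # (N, C, 1, 1)
    cam = (channel_weights * activation).sum(axis=1)              # (N, H, W)
    return np.maximum(cam, 0)                                     # keep positive evidence only

activation = np.random.rand(1, 8, 7, 7).astype(np.float32)
gradient = np.random.rand(1, 8, 7, 7).astype(np.float32)
print(gradcam_combine(activation, gradient).shape)               # (1, 7, 7)
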
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -17,10 +17,11 @@ from copy import deepcopy
 from mindspore import nn
 from mindspore.train._utils import check_value_type
-from ...._operators import reshape, sqrt, Tensor
-from ..attribution import Attribution
-from .backprop_utils import compute_gradients
-from ...._utils import abs_max, unify_inputs, unify_targets
+from mindspore.explainer._operators import reshape, sqrt, Tensor
+from mindspore.explainer._utils import abs_max, unify_inputs, unify_targets
+from .. import Attribution
+from .backprop_utils import get_bp_weights, GradNet
 def _get_hook(bntype, cache):
@@ -88,7 +89,7 @@ class Gradient(Attribution):
         self._backward_model.set_train(False)
         self._backward_model.set_grad(False)
         self._hook_bn()
-        self._grad_op = compute_gradients
+        self._grad_net = GradNet(self._backward_model)
         self._aggregation_fn = abs_max
     def __call__(self, inputs, targets):
@@ -97,7 +98,8 @@ class Gradient(Attribution):
         inputs = unify_inputs(inputs)
         targets = unify_targets(targets)
-        gradient = self._grad_op(self._backward_model, *inputs, targets)
+        weights = get_bp_weights(self._backward_model, *inputs, targets)
+        gradient = self._grad_net(*inputs, weights)
         saliency = self._aggregation_fn(gradient)
         return saliency

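The key difference from the old self._grad_op = compute_gradients is that the GradOperation-wrapped network is now built once, in __init__, and reused; each __call__ only builds the one-hot weights and runs the prebuilt GradNet. A simplified sketch of that call pattern (the SimpleGradient class, toy network, and abs-max reduction below are illustrative assumptions, not the library's Gradient class):

import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore.ops.composite import GradOperation

class GradNet(nn.Cell):
    """Wraps a network so calling it returns gradients w.r.t. its first input."""
    def __init__(self, network):
        super(GradNet, self).__init__()
        self.grad = GradOperation(get_all=True, sens_param=True)(network)

    def construct(self, *inputs):
        return self.grad(*inputs)[0]

class SimpleGradient:
    def __init__(self, network, num_classes):
        network.set_train(False)
        self._num_classes = num_classes
        self._grad_net = GradNet(network)        # built once, reused on every call

    def __call__(self, inputs, targets):
        weights = ms.Tensor(np.eye(self._num_classes, dtype=np.float32)[targets])
        gradient = self._grad_net(inputs, weights)
        return np.abs(gradient.asnumpy()).max(axis=1)   # abs-max style aggregation

explainer = SimpleGradient(nn.Dense(4, 3), num_classes=3)
saliency = explainer(ms.Tensor(np.random.rand(2, 4).astype(np.float32)), [1, 2])
print(saliency.shape)   # (2,)
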
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -15,8 +15,9 @@
 """Base class IntermediateLayerAttribution"""
+from mindspore.explainer._utils import resize as resize_fn
 from .gradient import Gradient
-from ...._utils import resize as resize_fn
 class IntermediateLayerAttribution(Gradient):

@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -16,13 +16,14 @@
 import mindspore.nn as nn
 import mindspore.ops.operations as op
-from .gradient import Gradient
-from ...._utils import (
+from mindspore.explainer._utils import (
     unify_inputs,
     unify_targets,
 )
+from .backprop_utils import GradNet, get_bp_weights
+from .gradient import Gradient
 class ModifiedReLU(Gradient):
     """Basic class for modified ReLU explanation."""
@@ -30,7 +31,8 @@ class ModifiedReLU(Gradient):
     def __init__(self, network, use_relu_backprop=False):
         super(ModifiedReLU, self).__init__(network)
         self.use_relu_backprop = use_relu_backprop
-        self.hooked_list = []
+        self._hook_relu_backward()
+        self._grad_net = GradNet(self._backward_model)
     def __call__(self, inputs, targets):
         """
@@ -56,8 +58,8 @@ class ModifiedReLU(Gradient):
         inputs = unify_inputs(inputs)
         targets = unify_targets(targets)
-        self._hook_relu_backward()
-        gradients = self._grad_op(self._backward_model, inputs, targets)
+        weights = get_bp_weights(self._backward_model, inputs, targets)
+        gradients = self._grad_net(*inputs, weights)
         saliency = self._aggregation_fn(gradients)
         return saliency
@@ -67,7 +69,6 @@ class ModifiedReLU(Gradient):
         for _, cell in self._backward_model.cells_and_names():
             if isinstance(cell, nn.ReLU):
                 cell.register_backward_hook(self._backward_hook)
-                self.hooked_list.append(cell)
     def _backward_hook(self, _, grad_inputs, grad_outputs):
         """Hook function for ReLU layers."""

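The fix above moves self._hook_relu_backward() out of __call__ and into __init__: the backward hooks on the ReLU cells are now registered exactly once per explainer, before the GradNet is built, instead of being added again on every call. A stripped-down sketch of that structure (toy network and the observing hook body are illustrative assumptions; PyNative mode is assumed because Cell backward hooks run there):

import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import context
from mindspore.ops.composite import GradOperation

context.set_context(mode=context.PYNATIVE_MODE)   # Cell backward hooks need PyNative mode

class HookedExplainer:
    """Registers a backward hook on every ReLU cell once, at construction time."""
    def __init__(self, network):
        self._model = network
        self._hook_relu_backward()                 # once per explainer, not once per call
        self._grad_fn = GradOperation(get_all=True, sens_param=True)(network)

    def _hook_relu_backward(self):
        for _, cell in self._model.cells_and_names():
            if isinstance(cell, nn.ReLU):
                cell.register_backward_hook(self._backward_hook)

    def _backward_hook(self, _, grad_inputs, grad_outputs):
        # Placeholder hook body: just observe the gradient passing through the
        # ReLU; returning None leaves it unchanged.
        return None

    def __call__(self, x, weights):
        return self._grad_fn(x, weights)[0]

net = nn.SequentialCell([nn.Dense(4, 8), nn.ReLU(), nn.Dense(8, 3)])
explainer = HookedExplainer(net)
x = ms.Tensor(np.random.rand(2, 4).astype(np.float32))
weights = ms.Tensor(np.eye(3, dtype=np.float32)[[0, 2]])
print(explainer(x, weights).shape)   # (2, 4)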