|
|
|
@ -17,8 +17,9 @@ from time import time
|
|
|
|
|
from typing import Tuple, List, Optional
|
|
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
from mindspore.train.summary_pb2 import Explain
|
|
|
|
|
|
|
|
|
|
from mindspore.train._utils import check_value_type
|
|
|
|
|
from mindspore.train.summary_pb2 import Explain
|
|
|
|
|
import mindspore as ms
|
|
|
|
|
import mindspore.dataset as ds
|
|
|
|
|
from mindspore import log
|
|
|
|
@ -71,6 +72,7 @@ class ExplainRunner:
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, summary_dir: Optional[str] = "./"):
|
|
|
|
|
check_value_type("summary_dir", summary_dir, str)
|
|
|
|
|
self._summary_dir = summary_dir
|
|
|
|
|
self._count = 0
|
|
|
|
|
self._classes = None
|
|
|
|
@ -123,14 +125,21 @@ class ExplainRunner:
|
|
|
|
|
for exp in explainers:
|
|
|
|
|
if not isinstance(exp, Attribution) or not isinstance(explainers, list):
|
|
|
|
|
raise TypeError("Argument explainers should be a list of objects of classes in "
|
|
|
|
|
"`mindspore.explainer.explanation._attribution`.")
|
|
|
|
|
"`mindspore.explainer.explanation`.")
|
|
|
|
|
if benchmarkers is not None:
|
|
|
|
|
for bench in benchmarkers:
|
|
|
|
|
if not isinstance(bench, AttributionMetric) or not isinstance(explainers, list):
|
|
|
|
|
raise TypeError("Argument benchmarkers should be a list of objects of classes in explanation"
|
|
|
|
|
"`mindspore.explainer.benchmark._attribution`.")
|
|
|
|
|
"`mindspore.explainer.benchmark`.")
|
|
|
|
|
|
|
|
|
|
self._model = explainers[0].model
|
|
|
|
|
next_element = dataset.create_tuple_iterator().get_next()
|
|
|
|
|
inputs, _, _ = self._unpack_next_element(next_element)
|
|
|
|
|
prop_test = self._model(inputs)
|
|
|
|
|
check_value_type("output of model im explainer", prop_test, ms.Tensor)
|
|
|
|
|
if prop_test.shape[1] > len(self._classes):
|
|
|
|
|
raise ValueError("The dimension of model output should not exceed the length of dataset classes. Please "
|
|
|
|
|
"check dataset classes or the black-box model in the explainer again.")
|
|
|
|
|
|
|
|
|
|
with SummaryRecord(self._summary_dir) as summary:
|
|
|
|
|
print("Start running and writing......")
|
|
|
|
|