From 6072b25a079e2decee9619c230d7ec89939acccd Mon Sep 17 00:00:00 2001 From: ougongchang Date: Tue, 20 Oct 2020 22:17:28 +0800 Subject: [PATCH] SummaryRecord support to record mindexplain data The SummaryRecord.add_value() method is extended to record the data of MindExplain. --- mindspore/ccsrc/utils/summary.proto | 49 +++++++++++++++++++ .../train/callback/_summary_collector.py | 2 +- mindspore/train/summary/_explain_adapter.py | 48 ++++++++++++++++++ mindspore/train/summary/_writer_pool.py | 8 ++- mindspore/train/summary/{enum.py => enums.py} | 0 mindspore/train/summary/summary_record.py | 43 ++++++++-------- .../summary/{_summary_writer.py => writer.py} | 9 ++++ .../train/summary/test_summary_collector.py | 2 +- 8 files changed, 137 insertions(+), 24 deletions(-) create mode 100644 mindspore/train/summary/_explain_adapter.py rename mindspore/train/summary/{enum.py => enums.py} (100%) rename mindspore/train/summary/{_summary_writer.py => writer.py} (93%) diff --git a/mindspore/ccsrc/utils/summary.proto b/mindspore/ccsrc/utils/summary.proto index f4a2ce957b..32043ea398 100644 --- a/mindspore/ccsrc/utils/summary.proto +++ b/mindspore/ccsrc/utils/summary.proto @@ -40,6 +40,8 @@ message Event { // Summary data Summary summary = 5; + + Explain explain = 6; } } @@ -101,3 +103,50 @@ message Summary { // Set of values for the summary. repeated Value value = 1; } + +message Explain { + message Inference{ + repeated float ground_truth_prob = 1; + repeated int32 predicted_label = 2; + repeated float predicted_prob = 3; + } + + message Explanation{ + optional string explain_method = 1; + optional int32 label = 2; + optional bytes heatmap = 3; + } + + message Benchmark{ + message TotalScore{ + optional string benchmark_method = 1; + optional float score = 2; + } + message LabelScore{ + repeated float score = 1; + optional string benchmark_method = 2; + } + + optional string explain_method = 1; + repeated TotalScore total_score = 2; + repeated LabelScore label_score = 3; + } + + message Metadata{ + repeated string label = 1; + repeated string explain_method = 2; + repeated string benchmark_method = 3; + } + + optional string image_id = 1; // The Metadata and image id must have one fill in + optional bytes image_data = 2; + repeated int32 ground_truth_label = 3; + + + optional Inference inference = 4; + repeated Explanation explanation = 5; + repeated Benchmark benchmark = 6; + + optional Metadata metadata = 7; + optional string status = 8; // enum value: run, end +} \ No newline at end of file diff --git a/mindspore/train/callback/_summary_collector.py b/mindspore/train/callback/_summary_collector.py index bfd10bb23f..3ce324bb66 100644 --- a/mindspore/train/callback/_summary_collector.py +++ b/mindspore/train/callback/_summary_collector.py @@ -26,7 +26,7 @@ from mindspore import log as logger from mindspore.common.tensor import Tensor from mindspore.common.parameter import Parameter from mindspore.train.summary.summary_record import SummaryRecord -from mindspore.train.summary.enum import PluginEnum, ModeEnum +from mindspore.train.summary.enums import PluginEnum, ModeEnum from mindspore.train.callback import Callback, ModelCheckpoint from mindspore.train import lineage_pb2 from mindspore.train.callback._dataset_graph import DatasetGraph diff --git a/mindspore/train/summary/_explain_adapter.py b/mindspore/train/summary/_explain_adapter.py new file mode 100644 index 0000000000..156aae530a --- /dev/null +++ b/mindspore/train/summary/_explain_adapter.py @@ -0,0 +1,48 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Generate the explain event which conform to proto format.""" +import time + +from ..summary_pb2 import Event, Explain + + +def check_explain_proto(explain): + """ + Package the explain event. + + Args: + explain (Explain): The object of summary_pb2.Explain. + """ + if not isinstance(explain, Explain): + raise TypeError(f'Plugin explainer expects a {Explain.__name__} value.') + + if not explain.image_id and not explain.metadata.label and not explain.benchmark: + raise ValueError(f'The Metadata and image id and benchmark must have one fill in.') + + +def package_explain_event(explain_str): + """ + Package the explain event. + + Args: + explain_str (string): The serialize string of summary_pb2.Explain. + + Returns: + Event, event object. + """ + event = Event() + event.wall_time = time.time() + event.explain.ParseFromString(explain_str) + return event.SerializeToString() diff --git a/mindspore/train/summary/_writer_pool.py b/mindspore/train/summary/_writer_pool.py index c27c696103..9b828b8665 100644 --- a/mindspore/train/summary/_writer_pool.py +++ b/mindspore/train/summary/_writer_pool.py @@ -21,7 +21,8 @@ import mindspore.log as logger from ._lineage_adapter import serialize_to_lineage_event from ._summary_adapter import package_graph_event, package_summary_event -from ._summary_writer import LineageWriter, SummaryWriter +from ._explain_adapter import package_explain_event +from .writer import LineageWriter, SummaryWriter, ExplainWriter try: from multiprocessing import get_context @@ -42,6 +43,8 @@ def _pack_data(datadict, wall_time): elif plugin in ('scalar', 'tensor', 'histogram', 'image'): summaries.append({'_type': plugin.title(), 'name': data.get('tag'), 'data': data.get('value')}) step = data.get('step') + elif plugin == 'explainer': + result.append([plugin, package_explain_event(data.get('value'))]) if summaries: result.append(['summary', package_summary_event(summaries, step, wall_time).SerializeToString()]) return result @@ -98,6 +101,8 @@ class WriterPool(ctx.Process): self._writers_.append(SummaryWriter(filepath, self._max_file_size)) elif plugin == 'lineage': self._writers_.append(LineageWriter(filepath, self._max_file_size)) + elif plugin == 'explainer': + self._writers_.append(ExplainWriter(filepath, self._max_file_size)) return self._writers_ def _write(self, plugin, data): @@ -125,7 +130,6 @@ class WriterPool(ctx.Process): Write the event to file. Args: - name (str): The key of a specified file. data (Optional[str, Tuple[list, int]]): The data to write. """ self._queue.put(('WRITE', data)) diff --git a/mindspore/train/summary/enum.py b/mindspore/train/summary/enums.py similarity index 100% rename from mindspore/train/summary/enum.py rename to mindspore/train/summary/enums.py diff --git a/mindspore/train/summary/summary_record.py b/mindspore/train/summary/summary_record.py index 6f3cd7574d..d84df9f5e1 100644 --- a/mindspore/train/summary/summary_record.py +++ b/mindspore/train/summary/summary_record.py @@ -17,6 +17,7 @@ import atexit import os import re import threading +from collections import defaultdict from mindspore import log as logger @@ -24,6 +25,7 @@ from ..._c_expression import Tensor from ..._checkparam import Validator from .._utils import _check_lineage_value, _check_to_numpy, _make_directory from ._summary_adapter import get_event_file_name, package_graph_event +from ._explain_adapter import check_explain_proto from ._writer_pool import WriterPool # for the moment, this lock is for caution's sake, @@ -55,7 +57,6 @@ def _get_summary_tensor_data(): def _dictlist(): - from collections import defaultdict return defaultdict(list) @@ -133,7 +134,8 @@ class SummaryRecord: self._event_writer = WriterPool(log_dir, max_file_size, summary=self.full_file_name, - lineage=get_event_file_name(self.prefix, '_lineage')) + lineage=get_event_file_name(self.prefix, '_lineage'), + explainer=get_event_file_name(self.prefix, '_explain')) _get_summary_tensor_data() atexit.register(self.close) @@ -149,10 +151,11 @@ class SummaryRecord: def set_mode(self, mode): """ - Set the mode for the recorder to be aware. The mode is set to 'train' by default. + Sets the training phase. Different training phases affect data recording. Args: - mode (str): The mode to be set, which should be 'train' or 'eval'. + mode (str): The mode to be set, which should be 'train' or 'eval'. When the mode is 'eval', + summary_record will not record the data of summary operators. Raises: ValueError: When the mode is not recognized. @@ -170,29 +173,26 @@ class SummaryRecord: """ Add value to be recorded later. - When the plugin is 'tensor', 'scalar', 'image' or 'histogram', - the name should be the tag name, and the value should be a Tensor. - - When the plugin is 'graph', the value should be a GraphProto. - - When the plugin is 'dataset_graph', 'train_lineage', 'eval_lineage', - or 'custom_lineage_data', the value should be a proto message. - - Args: plugin (str): The value of the plugin. name (str): The value of the name. value (Union[Tensor, GraphProto, TrainLineage, EvaluationLineage, DatasetGraph, UserDefinedInfo]): \ The value to store. - - The data type of value should be 'GraphProto' when the plugin is 'graph'. - - The data type of value should be 'Tensor' when the plugin is 'scalar', 'image', 'tensor' + - The data type of value should be 'GraphProto' (see mindspore/ccsrc/anf_ir.proto) object + when the plugin is 'graph'. + - The data type of value should be 'Tensor' object when the plugin is 'scalar', 'image', 'tensor' or 'histogram'. - - The data type of value should be 'TrainLineage' when the plugin is 'train_lineage'. - - The data type of value should be 'EvaluationLineage' when the plugin is 'eval_lineage'. - - The data type of value should be 'DatasetGraph' when the plugin is 'dataset_graph'. - - The data type of value should be 'UserDefinedInfo' when the plugin is 'custom_lineage_data'. - + - The data type of value should be a 'TrainLineage' object when the plugin is 'train_lineage', + see mindspore/ccsrc/lineage.proto. + - The data type of value should be a 'EvaluationLineage' object when the plugin is 'eval_lineage', + see mindspore/ccsrc/lineage.proto. + - The data type of value should be a 'DatasetGraph' object when the plugin is 'dataset_graph', + see mindspore/ccsrc/lineage.proto. + - The data type of value should be a 'UserDefinedInfo' object when the plugin is 'custom_lineage_data', + see mindspore/ccsrc/lineage.proto. + - The data type of value should be a 'Explain' object when the plugin is 'explainer', + see mindspore/ccsrc/summary.proto. Raises: ValueError: When the name is not valid. TypeError: When the value is not a Tensor. @@ -218,6 +218,9 @@ class SummaryRecord: elif plugin == 'graph': package_graph_event(value) self._data_pool[plugin].append(dict(value=value)) + elif plugin == 'explainer': + check_explain_proto(value) + self._data_pool[plugin].append(dict(value=value.SerializeToString())) else: raise ValueError(f'No such plugin of {repr(plugin)}') diff --git a/mindspore/train/summary/_summary_writer.py b/mindspore/train/summary/writer.py similarity index 93% rename from mindspore/train/summary/_summary_writer.py rename to mindspore/train/summary/writer.py index a6874ce713..1a8e424473 100644 --- a/mindspore/train/summary/_summary_writer.py +++ b/mindspore/train/summary/writer.py @@ -94,3 +94,12 @@ class LineageWriter(BaseWriter): """Write data to file.""" if plugin in ('dataset_graph', 'train_lineage', 'eval_lineage', 'custom_lineage_data'): super().write(plugin, data) + + +class ExplainWriter(BaseWriter): + """ExplainWriter for write explain data.""" + + def write(self, plugin, data): + """Write data to file.""" + if plugin == 'explainer': + super().write(plugin, data) diff --git a/tests/ut/python/train/summary/test_summary_collector.py b/tests/ut/python/train/summary/test_summary_collector.py index f802f58bcc..3349cf8287 100644 --- a/tests/ut/python/train/summary/test_summary_collector.py +++ b/tests/ut/python/train/summary/test_summary_collector.py @@ -26,7 +26,7 @@ from mindspore import Tensor from mindspore import Parameter from mindspore.train.callback import SummaryCollector from mindspore.train.callback import _InternalCallbackParam -from mindspore.train.summary.enum import ModeEnum, PluginEnum +from mindspore.train.summary.enums import ModeEnum, PluginEnum from mindspore.train.summary import SummaryRecord from mindspore.nn import Cell from mindspore.nn.optim.optimizer import Optimizer