SummaryRecord support to record mindexplain data

The SummaryRecord.add_value() method is extended to record the data of
MindExplain.
pull/7534/head
ougongchang 4 years ago
parent d0a1a9b73c
commit 6072b25a07

@ -40,6 +40,8 @@ message Event {
// Summary data // Summary data
Summary summary = 5; Summary summary = 5;
Explain explain = 6;
} }
} }
@ -101,3 +103,50 @@ message Summary {
// Set of values for the summary. // Set of values for the summary.
repeated Value value = 1; repeated Value value = 1;
} }
message Explain {
message Inference{
repeated float ground_truth_prob = 1;
repeated int32 predicted_label = 2;
repeated float predicted_prob = 3;
}
message Explanation{
optional string explain_method = 1;
optional int32 label = 2;
optional bytes heatmap = 3;
}
message Benchmark{
message TotalScore{
optional string benchmark_method = 1;
optional float score = 2;
}
message LabelScore{
repeated float score = 1;
optional string benchmark_method = 2;
}
optional string explain_method = 1;
repeated TotalScore total_score = 2;
repeated LabelScore label_score = 3;
}
message Metadata{
repeated string label = 1;
repeated string explain_method = 2;
repeated string benchmark_method = 3;
}
optional string image_id = 1; // The Metadata and image id must have one fill in
optional bytes image_data = 2;
repeated int32 ground_truth_label = 3;
optional Inference inference = 4;
repeated Explanation explanation = 5;
repeated Benchmark benchmark = 6;
optional Metadata metadata = 7;
optional string status = 8; // enum value: run, end
}

@ -26,7 +26,7 @@ from mindspore import log as logger
from mindspore.common.tensor import Tensor from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter from mindspore.common.parameter import Parameter
from mindspore.train.summary.summary_record import SummaryRecord from mindspore.train.summary.summary_record import SummaryRecord
from mindspore.train.summary.enum import PluginEnum, ModeEnum from mindspore.train.summary.enums import PluginEnum, ModeEnum
from mindspore.train.callback import Callback, ModelCheckpoint from mindspore.train.callback import Callback, ModelCheckpoint
from mindspore.train import lineage_pb2 from mindspore.train import lineage_pb2
from mindspore.train.callback._dataset_graph import DatasetGraph from mindspore.train.callback._dataset_graph import DatasetGraph

@ -0,0 +1,48 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Generate the explain event which conform to proto format."""
import time
from ..summary_pb2 import Event, Explain
def check_explain_proto(explain):
"""
Package the explain event.
Args:
explain (Explain): The object of summary_pb2.Explain.
"""
if not isinstance(explain, Explain):
raise TypeError(f'Plugin explainer expects a {Explain.__name__} value.')
if not explain.image_id and not explain.metadata.label and not explain.benchmark:
raise ValueError(f'The Metadata and image id and benchmark must have one fill in.')
def package_explain_event(explain_str):
"""
Package the explain event.
Args:
explain_str (string): The serialize string of summary_pb2.Explain.
Returns:
Event, event object.
"""
event = Event()
event.wall_time = time.time()
event.explain.ParseFromString(explain_str)
return event.SerializeToString()

@ -21,7 +21,8 @@ import mindspore.log as logger
from ._lineage_adapter import serialize_to_lineage_event from ._lineage_adapter import serialize_to_lineage_event
from ._summary_adapter import package_graph_event, package_summary_event from ._summary_adapter import package_graph_event, package_summary_event
from ._summary_writer import LineageWriter, SummaryWriter from ._explain_adapter import package_explain_event
from .writer import LineageWriter, SummaryWriter, ExplainWriter
try: try:
from multiprocessing import get_context from multiprocessing import get_context
@ -42,6 +43,8 @@ def _pack_data(datadict, wall_time):
elif plugin in ('scalar', 'tensor', 'histogram', 'image'): elif plugin in ('scalar', 'tensor', 'histogram', 'image'):
summaries.append({'_type': plugin.title(), 'name': data.get('tag'), 'data': data.get('value')}) summaries.append({'_type': plugin.title(), 'name': data.get('tag'), 'data': data.get('value')})
step = data.get('step') step = data.get('step')
elif plugin == 'explainer':
result.append([plugin, package_explain_event(data.get('value'))])
if summaries: if summaries:
result.append(['summary', package_summary_event(summaries, step, wall_time).SerializeToString()]) result.append(['summary', package_summary_event(summaries, step, wall_time).SerializeToString()])
return result return result
@ -98,6 +101,8 @@ class WriterPool(ctx.Process):
self._writers_.append(SummaryWriter(filepath, self._max_file_size)) self._writers_.append(SummaryWriter(filepath, self._max_file_size))
elif plugin == 'lineage': elif plugin == 'lineage':
self._writers_.append(LineageWriter(filepath, self._max_file_size)) self._writers_.append(LineageWriter(filepath, self._max_file_size))
elif plugin == 'explainer':
self._writers_.append(ExplainWriter(filepath, self._max_file_size))
return self._writers_ return self._writers_
def _write(self, plugin, data): def _write(self, plugin, data):
@ -125,7 +130,6 @@ class WriterPool(ctx.Process):
Write the event to file. Write the event to file.
Args: Args:
name (str): The key of a specified file.
data (Optional[str, Tuple[list, int]]): The data to write. data (Optional[str, Tuple[list, int]]): The data to write.
""" """
self._queue.put(('WRITE', data)) self._queue.put(('WRITE', data))

@ -17,6 +17,7 @@ import atexit
import os import os
import re import re
import threading import threading
from collections import defaultdict
from mindspore import log as logger from mindspore import log as logger
@ -24,6 +25,7 @@ from ..._c_expression import Tensor
from ..._checkparam import Validator from ..._checkparam import Validator
from .._utils import _check_lineage_value, _check_to_numpy, _make_directory from .._utils import _check_lineage_value, _check_to_numpy, _make_directory
from ._summary_adapter import get_event_file_name, package_graph_event from ._summary_adapter import get_event_file_name, package_graph_event
from ._explain_adapter import check_explain_proto
from ._writer_pool import WriterPool from ._writer_pool import WriterPool
# for the moment, this lock is for caution's sake, # for the moment, this lock is for caution's sake,
@ -55,7 +57,6 @@ def _get_summary_tensor_data():
def _dictlist(): def _dictlist():
from collections import defaultdict
return defaultdict(list) return defaultdict(list)
@ -133,7 +134,8 @@ class SummaryRecord:
self._event_writer = WriterPool(log_dir, self._event_writer = WriterPool(log_dir,
max_file_size, max_file_size,
summary=self.full_file_name, summary=self.full_file_name,
lineage=get_event_file_name(self.prefix, '_lineage')) lineage=get_event_file_name(self.prefix, '_lineage'),
explainer=get_event_file_name(self.prefix, '_explain'))
_get_summary_tensor_data() _get_summary_tensor_data()
atexit.register(self.close) atexit.register(self.close)
@ -149,10 +151,11 @@ class SummaryRecord:
def set_mode(self, mode): def set_mode(self, mode):
""" """
Set the mode for the recorder to be aware. The mode is set to 'train' by default. Sets the training phase. Different training phases affect data recording.
Args: Args:
mode (str): The mode to be set, which should be 'train' or 'eval'. mode (str): The mode to be set, which should be 'train' or 'eval'. When the mode is 'eval',
summary_record will not record the data of summary operators.
Raises: Raises:
ValueError: When the mode is not recognized. ValueError: When the mode is not recognized.
@ -170,29 +173,26 @@ class SummaryRecord:
""" """
Add value to be recorded later. Add value to be recorded later.
When the plugin is 'tensor', 'scalar', 'image' or 'histogram',
the name should be the tag name, and the value should be a Tensor.
When the plugin is 'graph', the value should be a GraphProto.
When the plugin is 'dataset_graph', 'train_lineage', 'eval_lineage',
or 'custom_lineage_data', the value should be a proto message.
Args: Args:
plugin (str): The value of the plugin. plugin (str): The value of the plugin.
name (str): The value of the name. name (str): The value of the name.
value (Union[Tensor, GraphProto, TrainLineage, EvaluationLineage, DatasetGraph, UserDefinedInfo]): \ value (Union[Tensor, GraphProto, TrainLineage, EvaluationLineage, DatasetGraph, UserDefinedInfo]): \
The value to store. The value to store.
- The data type of value should be 'GraphProto' when the plugin is 'graph'. - The data type of value should be 'GraphProto' (see mindspore/ccsrc/anf_ir.proto) object
- The data type of value should be 'Tensor' when the plugin is 'scalar', 'image', 'tensor' when the plugin is 'graph'.
- The data type of value should be 'Tensor' object when the plugin is 'scalar', 'image', 'tensor'
or 'histogram'. or 'histogram'.
- The data type of value should be 'TrainLineage' when the plugin is 'train_lineage'. - The data type of value should be a 'TrainLineage' object when the plugin is 'train_lineage',
- The data type of value should be 'EvaluationLineage' when the plugin is 'eval_lineage'. see mindspore/ccsrc/lineage.proto.
- The data type of value should be 'DatasetGraph' when the plugin is 'dataset_graph'. - The data type of value should be a 'EvaluationLineage' object when the plugin is 'eval_lineage',
- The data type of value should be 'UserDefinedInfo' when the plugin is 'custom_lineage_data'. see mindspore/ccsrc/lineage.proto.
- The data type of value should be a 'DatasetGraph' object when the plugin is 'dataset_graph',
see mindspore/ccsrc/lineage.proto.
- The data type of value should be a 'UserDefinedInfo' object when the plugin is 'custom_lineage_data',
see mindspore/ccsrc/lineage.proto.
- The data type of value should be a 'Explain' object when the plugin is 'explainer',
see mindspore/ccsrc/summary.proto.
Raises: Raises:
ValueError: When the name is not valid. ValueError: When the name is not valid.
TypeError: When the value is not a Tensor. TypeError: When the value is not a Tensor.
@ -218,6 +218,9 @@ class SummaryRecord:
elif plugin == 'graph': elif plugin == 'graph':
package_graph_event(value) package_graph_event(value)
self._data_pool[plugin].append(dict(value=value)) self._data_pool[plugin].append(dict(value=value))
elif plugin == 'explainer':
check_explain_proto(value)
self._data_pool[plugin].append(dict(value=value.SerializeToString()))
else: else:
raise ValueError(f'No such plugin of {repr(plugin)}') raise ValueError(f'No such plugin of {repr(plugin)}')

@ -94,3 +94,12 @@ class LineageWriter(BaseWriter):
"""Write data to file.""" """Write data to file."""
if plugin in ('dataset_graph', 'train_lineage', 'eval_lineage', 'custom_lineage_data'): if plugin in ('dataset_graph', 'train_lineage', 'eval_lineage', 'custom_lineage_data'):
super().write(plugin, data) super().write(plugin, data)
class ExplainWriter(BaseWriter):
"""ExplainWriter for write explain data."""
def write(self, plugin, data):
"""Write data to file."""
if plugin == 'explainer':
super().write(plugin, data)

@ -26,7 +26,7 @@ from mindspore import Tensor
from mindspore import Parameter from mindspore import Parameter
from mindspore.train.callback import SummaryCollector from mindspore.train.callback import SummaryCollector
from mindspore.train.callback import _InternalCallbackParam from mindspore.train.callback import _InternalCallbackParam
from mindspore.train.summary.enum import ModeEnum, PluginEnum from mindspore.train.summary.enums import ModeEnum, PluginEnum
from mindspore.train.summary import SummaryRecord from mindspore.train.summary import SummaryRecord
from mindspore.nn import Cell from mindspore.nn import Cell
from mindspore.nn.optim.optimizer import Optimizer from mindspore.nn.optim.optimizer import Optimizer

Loading…
Cancel
Save