!1935 Summary callback as collector for summary and lineage
Merge pull request !1935 from 李鸿章/policy_writerpull/1935/MERGE
commit
8867c67d61
@ -0,0 +1,129 @@
|
||||
// Copyright 2020 Huawei Technologies Co., Ltd.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
syntax = "proto2";
|
||||
|
||||
package mindspore.irpb;
|
||||
option cc_enable_arenas = true;
|
||||
|
||||
|
||||
// Event protocol buffer, top-level definition.
message LineageEvent {
  // Timestamp of the event.
  required double wall_time = 1;

  // The training step.
  optional int64 step = 2;

  // Payload of the event; at most one member is set at a time.
  // Field numbers 4, 5 and 8 are not used by this message.
  oneof what {
    // An event file was started, with the specified version.
    // The current version is "Mindspore.Event:1".
    string version = 3;

    // Train lineage.
    TrainLineage train_lineage = 6;

    // Evaluation lineage.
    EvaluationLineage evaluation_lineage = 7;

    // Dataset graph.
    DatasetGraph dataset_graph = 9;

    // User defined info.
    UserDefinedInfo user_defined_info = 10;
  }
}
|
||||
|
||||
// User defined info. The message is recursive: it can hold a list of
// further UserDefinedInfo entries as well as keyed entries.
message UserDefinedInfo {
  // Repeated user defined info.
  repeated UserDefinedInfo user_info = 1;

  // Key/value entries. Values are either nested dicts (map_dict) or
  // scalars (map_int32, map_str, map_double).
  map<string, UserDefinedInfo> map_dict = 2;
  map<string, int32> map_int32 = 3;
  map<string, string> map_str = 4;
  map<string, double> map_double = 5;
}
|
||||
|
||||
// TrainLineage records information about a training run.
message TrainLineage {
  // Hyper-parameters of the training run.
  // Field number 7 is not used by this message.
  message HyperParameters {
    optional string optimizer = 1;
    optional float learning_rate = 2;
    optional string loss_function = 3;
    optional int32 epoch = 4;
    optional string parallel_mode = 5;
    optional int32 device_num = 6;
    optional int32 batch_size = 8;
  }

  // The dataset used for training.
  message TrainDataset {
    optional string train_dataset_path = 1;
    optional int32 train_dataset_size = 2;
  }

  // The network and its loss.
  message Algorithm {
    optional string network = 1;
    optional float loss = 2;
  }

  // The produced model file.
  // Field numbers start at 3; numbers 1 and 2 are not used by this message.
  message Model {
    optional string path = 3;
    optional int64 size = 4;
  }

  optional HyperParameters hyper_parameters = 1;
  optional TrainDataset train_dataset = 2;
  optional Algorithm algorithm = 3;
  optional Model model = 4;
}
|
||||
|
||||
// EvaluationLineage records information about an evaluation run.
// Field numbers start at 2; number 1 is not used by this message.
message EvaluationLineage {
  // The dataset used for validation.
  message ValidDataset {
    optional string valid_dataset_path = 1;
    optional int32 valid_dataset_size = 2;
  }

  // Evaluation metric, stored as a string.
  optional string metric = 2;

  optional ValidDataset valid_dataset = 3;
}
|
||||
|
||||
|
||||
// DatasetGraph is a recursive description of a dataset: each node holds
// its own parameter and operations plus any number of child nodes.
message DatasetGraph {
  // Child nodes of this node.
  repeated DatasetGraph children = 1;

  // Parameters of this node.
  optional OperationParameter parameter = 2;

  // Operations attached to this node.
  repeated Operation operations = 3;

  // Sampler of this node, described as an Operation.
  optional Operation sampler = 4;
}
|
||||
|
||||
// Operation describes a single operation of a DatasetGraph node.
message Operation {
  // Parameters of the operation. The camelCase name is kept as declared
  // because generated code exposes it under this name.
  optional OperationParameter operationParam = 1;

  repeated int32 size = 2;
  repeated float weights = 3;
}
|
||||
|
||||
// OperationParameter holds named parameters, with one map per value type.
message OperationParameter {
  // String-valued parameters.
  map<string, string> mapStr = 1;

  // Parameters whose value is a list of strings.
  map<string, StrList> mapStrList = 2;

  // Boolean-valued parameters.
  map<string, bool> mapBool = 3;

  // Integer-valued parameters.
  map<string, int32> mapInt = 4;

  // Double-valued parameters.
  map<string, double> mapDouble = 5;
}
|
||||
|
||||
// StrList wraps a list of strings so that it can be used as a map value
// (see OperationParameter.mapStrList).
message StrList {
  repeated string strValue = 1;
}
|
@ -1,88 +0,0 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Writes events to disk in a logdir."""
|
||||
import os
|
||||
import stat
|
||||
from collections import deque
|
||||
from multiprocessing import Pool, Process, Queue, cpu_count
|
||||
|
||||
from ..._c_expression import EventWriter_
|
||||
from ._summary_adapter import package_summary_event
|
||||
|
||||
|
||||
def _pack(result, step):
    """Package `result` and `step` into a summary event and return it serialized."""
    return package_summary_event(result, step).SerializeToString()
|
||||
|
||||
|
||||
class EventWriter(Process):
    """
    Creates a `EventWriter` and write event to file.

    The constructor creates (or truncates) the file, restricts it to owner
    read/write and starts the process right away. Requests are passed to the
    process as ``(action, data)`` tuples through a bounded queue.

    Args:
        filepath (str): Summary event file path and file name.
        flush_interval (int): Accepted for interface compatibility; the value is not used.
    """

    # Seconds to wait for a new request while packed results are still pending.
    _POLL_INTERVAL = 0.01

    def __init__(self, filepath: str, flush_interval: int) -> None:
        super().__init__()
        _ = flush_interval
        with open(filepath, 'w'):
            os.chmod(filepath, stat.S_IWUSR | stat.S_IRUSR)
        self._writer = EventWriter_(filepath)
        self._queue = Queue(cpu_count() * 2)
        self.start()

    def run(self):
        """
        Serve queued requests until an 'END' action arrives.

        'WRITE' data that is not `str`/`bytes` is packed asynchronously in a
        process pool and written in submission order once ready; `str`/`bytes`
        data is written directly. After 'END', the remaining packed results
        are written and the writer is shut.
        """
        from queue import Empty  # raised by Queue.get() when the timeout expires

        with Pool(min(cpu_count(), 32)) as pool:
            deq = deque()
            while True:
                # Write out finished results, keeping submission order.
                while deq and deq[0].ready():
                    self._writer.Write(deq.popleft().get())

                # Block while idle instead of spinning on `empty()`; while
                # results are pending only wait briefly so they get written
                # as soon as they are ready.
                try:
                    action, data = self._queue.get(timeout=self._POLL_INTERVAL if deq else None)
                except Empty:
                    continue

                if action == 'WRITE':
                    if not isinstance(data, (str, bytes)):
                        # `data` is unpacked as the positional arguments of `_pack`.
                        deq.append(pool.apply_async(_pack, data))
                    else:
                        self._writer.Write(data)
                elif action == 'FLUSH':
                    self._writer.Flush()
                elif action == 'END':
                    break

            # Drain whatever is still being packed before shutting down.
            for res in deq:
                self._writer.Write(res.get())

            self._writer.Shut()

    def write(self, data) -> None:
        """
        Write the event to file.

        Args:
            data (Union[str, bytes, Tuple[list, int]]): The data to write. `str`/`bytes`
                is written as is; anything else is passed to `_pack` as its arguments.
        """
        self._queue.put(('WRITE', data))

    def flush(self):
        """Flush the writer."""
        self._queue.put(('FLUSH', None))

    def close(self) -> None:
        """Close the writer and wait for the process to finish."""
        self._queue.put(('END', None))
        self.join()
|
@ -0,0 +1,39 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Generate the lineage event which conform to proto format."""
|
||||
import time
|
||||
|
||||
from ..lineage_pb2 import LineageEvent
|
||||
|
||||
|
||||
def serialize_to_lineage_event(name, value):
    """
    Serialize value to lineage event.

    Args:
        name (str): Plugin name selecting the event field, see `_get_lineage_content`.
        value (bytes): Serialized message that is parsed into the selected field.

    Returns:
        bytes, the serialized `LineageEvent`.
    """
    lineage_event = LineageEvent()
    lineage_event.wall_time = time.time()
    _get_lineage_content(name, lineage_event).ParseFromString(value)
    return lineage_event.SerializeToString()
|
||||
|
||||
|
||||
def _get_lineage_content(name, event):
|
||||
if name == 'dataset_graph':
|
||||
return event.dataset_graph
|
||||
if name == 'eval_lineage':
|
||||
return event.evaluation_lineage
|
||||
if name == 'train_lineage':
|
||||
return event.train_lineage
|
||||
if name == 'custom_lineage_data':
|
||||
return event.user_defined_info
|
||||
raise KeyError(f'No such field in LineageEvent')
|
@ -0,0 +1,79 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Writes events to disk in a logdir."""
|
||||
import os
|
||||
import stat
|
||||
|
||||
from ..._c_expression import EventWriter_
|
||||
from ._summary_adapter import package_init_event
|
||||
|
||||
|
||||
class BaseWriter:
    """
    BaseWriter to be subclass.

    The underlying `EventWriter_` is created on first access of `writer`;
    until then `flush` and `close` do nothing.
    """

    def __init__(self, filepath) -> None:
        self._writer: EventWriter_ = None
        self._filepath = filepath

    def init_writer(self):
        """Write some metadata etc. Called once, right after the writer is created."""

    @property
    def writer(self) -> EventWriter_:
        """Get the writer, creating the file and the writer on first access."""
        if self._writer is None:
            # Create (or truncate) the file and limit it to owner read/write.
            with open(self._filepath, 'w'):
                os.chmod(self._filepath, stat.S_IWUSR | stat.S_IRUSR)
            self._writer = EventWriter_(self._filepath)
            self.init_writer()
        return self._writer

    def write(self, plugin, mode, data):
        """Write data to file. Must be implemented by subclasses."""
        raise NotImplementedError()

    def flush(self):
        """Flush the writer if it has been created."""
        if self._writer is None:
            return
        self._writer.Flush()

    def close(self):
        """Close the writer if it has been created."""
        if self._writer is None:
            return
        self._writer.Shut()
|
||||
|
||||
|
||||
class SummaryWriter(BaseWriter):
    """SummaryWriter for write summaries."""

    def init_writer(self):
        """Write the serialized init event produced by `package_init_event`."""
        init_event = package_init_event()
        self.writer.Write(init_event.SerializeToString())

    def write(self, plugin, mode, data):
        """Write data to file when `plugin` is 'summary' or 'graph'; ignore it otherwise."""
        if plugin not in ('summary', 'graph'):
            return
        self.writer.Write(data)
|
||||
|
||||
|
||||
class LineageWriter(BaseWriter):
    """LineageWriter for write lineage."""

    def write(self, plugin, mode, data):
        """Write data to file when `plugin` is a lineage plugin; ignore it otherwise."""
        lineage_plugins = ('dataset_graph', 'train_lineage', 'eval_lineage', 'custom_lineage_data')
        if plugin not in lineage_plugins:
            return
        self.writer.Write(data)
|
@ -0,0 +1,114 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Write events to disk in a base directory."""
|
||||
import os
|
||||
from collections import deque
|
||||
from multiprocessing import Pool, Process, Queue, cpu_count
|
||||
|
||||
from ._lineage_adapter import serialize_to_lineage_event
|
||||
from ._summary_adapter import package_graph_event, package_summary_event
|
||||
from ._summary_writer import SummaryWriter, LineageWriter
|
||||
|
||||
|
||||
def _pack_data(datadict):
|
||||
"""Pack data according to which plugin."""
|
||||
result = []
|
||||
summaries, step, mode = [], None, None
|
||||
for plugin, datalist in datadict.items():
|
||||
for data in datalist:
|
||||
if plugin == 'graph':
|
||||
result.append([plugin, data.get('mode'), package_graph_event(data.get('value')).SerializeToString()])
|
||||
elif plugin in ('train_lineage', 'eval_lineage', 'custom_lineage_data', 'dataset_graph'):
|
||||
result.append([plugin, data.get('mode'), serialize_to_lineage_event(plugin, data.get('value'))])
|
||||
elif plugin in ('scalar', 'tensor', 'histogram', 'image'):
|
||||
summaries.append({'_type': plugin.title(), 'name': data.get('tag'), 'data': data.get('value')})
|
||||
step = data.get('step')
|
||||
mode = data.get('mode')
|
||||
if summaries:
|
||||
result.append(['summary', mode, package_summary_event(summaries, step).SerializeToString()])
|
||||
return result
|
||||
|
||||
|
||||
class WriterPool(Process):
    """
    Use a set of pooled resident processes for writing a list of file.

    The process is started by the constructor. Requests are passed to it as
    ``(action, data)`` tuples through a bounded queue.

    Args:
        base_dir (str): The base directory to hold all the files.
        filedict (dict): Keyword arguments mapping a writer kind to a filename
            inside `base_dir`. The kinds 'summary' and 'lineage' create a
            writer; other kinds are ignored.
    """

    # Seconds to wait for a new request while packed results are still pending.
    _POLL_INTERVAL = 0.01

    def __init__(self, base_dir, **filedict) -> None:
        super().__init__()
        self._base_dir, self._filedict = base_dir, filedict
        self._queue = Queue(cpu_count() * 2)
        self.start()

    def run(self):
        """
        Serve queued requests until an 'END' action arrives.

        'WRITE' data is packed asynchronously by `_pack_data` in a process
        pool; the resulting records are offered to every writer in submission
        order. After 'END', the remaining results are written and all writers
        are closed.
        """
        from queue import Empty  # raised by Queue.get() when the timeout expires

        writers = self._get_writers()

        with Pool() as pool:
            deq = deque()
            while True:
                # Write out finished results, keeping submission order.
                while deq and deq[0].ready():
                    self._write_records(writers, deq.popleft().get())

                # Block while idle instead of spinning on `empty()`; while
                # results are pending only wait briefly so they get written
                # as soon as they are ready.
                try:
                    action, data = self._queue.get(timeout=self._POLL_INTERVAL if deq else None)
                except Empty:
                    continue

                if action == 'WRITE':
                    deq.append(pool.apply_async(_pack_data, (data,)))
                elif action == 'FLUSH':
                    for writer in writers:
                        writer.flush()
                elif action == 'END':
                    break

            # Drain whatever is still being packed before shutting down.
            for result in deq:
                self._write_records(writers, result.get())

            for writer in writers:
                writer.close()

    @staticmethod
    def _write_records(writers, records):
        """Offer every ``(plugin, mode, data)`` record to every writer."""
        for plugin, mode, data in records:
            for writer in writers:
                writer.write(plugin, mode, data)

    def _get_writers(self):
        """Create one writer per 'summary'/'lineage' entry of the file mapping."""
        writers = []
        for plugin, filename in self._filedict.items():
            filepath = os.path.join(self._base_dir, filename)
            if plugin == 'summary':
                writers.append(SummaryWriter(filepath))
            elif plugin == 'lineage':
                writers.append(LineageWriter(filepath))
        return writers

    def write(self, data) -> None:
        """
        Write the event to file.

        Args:
            data (dict): The collected data, passed to `_pack_data` in the writer process.
        """
        self._queue.put(('WRITE', data))

    def flush(self):
        """Flush the writer and sync data to disk."""
        self._queue.put(('FLUSH', None))

    def close(self) -> None:
        """Close the writer and wait for the process to finish."""
        self._queue.put(('END', None))
        self.join()
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in new issue