!2616 Decide whether to collect data by dataset sink mode and current step in SummaryCollector

Merge pull request !2616 from ougongchang/fix_collect_freq
pull/2616/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 19f79cd744

@ -166,8 +166,11 @@ class SummaryCollector(Callback):
self._has_saved_custom_data = False self._has_saved_custom_data = False
self._is_parse_loss_success = True self._is_parse_loss_success = True
self._first_step = True self._first_step = True
self._dataset_sink_mode = True
def __enter__(self): def __enter__(self):
self._first_step = True
self._dataset_sink_mode = True
self._record = SummaryRecord(log_dir=self._summary_dir) self._record = SummaryRecord(log_dir=self._summary_dir)
return self return self
@ -279,15 +282,15 @@ class SummaryCollector(Callback):
def step_end(self, run_context): def step_end(self, run_context):
cb_params = run_context.original_args() cb_params = run_context.original_args()
if self._first_step:
# Notice: This way of determining whether dataset sink mode is True does not work in the eval scenario
self._dataset_sink_mode = bool(cb_params.cur_step_num == cb_params.batch_num)
if cb_params.mode == ModeEnum.TRAIN.value: if cb_params.mode == ModeEnum.TRAIN.value:
# Make sure the first step data is recorded if not self._is_collect_this_step(cb_params):
if not self._first_step and cb_params.cur_step_num % self._collect_freq:
return return
self._first_step = False
if not self._has_saved_train_network: if not self._has_saved_train_network:
self._collect_graphs(cb_params) self._collect_graphs(cb_params)
@ -295,6 +298,7 @@ class SummaryCollector(Callback):
self._collect_metric(cb_params) self._collect_metric(cb_params)
self._collect_histogram(cb_params) self._collect_histogram(cb_params)
self._first_step = False
self._record.record(cb_params.cur_step_num) self._record.record(cb_params.cur_step_num)
def end(self, run_context): def end(self, run_context):
@ -320,6 +324,18 @@ class SummaryCollector(Callback):
raise ValueError(f"There are more than one {self.__class__.__name__} instance in callback list," raise ValueError(f"There are more than one {self.__class__.__name__} instance in callback list,"
f"but expected only one {self.__class__.__name__} instance.") f"but expected only one {self.__class__.__name__} instance.")
def _is_collect_this_step(self, cb_params):
"""Decide whether to collect data for the current step."""
# Make sure the first step data is recorded
if not self._first_step:
if self._dataset_sink_mode:
if cb_params.cur_epoch_num % self._collect_freq:
return False
else:
if cb_params.cur_step_num % self._collect_freq:
return False
return True
@staticmethod @staticmethod
def _package_custom_lineage_data(custom_lineage_data): def _package_custom_lineage_data(custom_lineage_data):
""" """

Loading…
Cancel
Save