diff --git a/mindspore/train/callback/_summary_collector.py b/mindspore/train/callback/_summary_collector.py
index 51aba7cb22..a3e5a2245c 100644
--- a/mindspore/train/callback/_summary_collector.py
+++ b/mindspore/train/callback/_summary_collector.py
@@ -67,7 +67,7 @@ class SummaryCollector(Callback):
     SummaryCollector can help you to collect some common information.
 
     It can help you to collect loss, learning late, computational graph and so on.
-    SummaryCollector also enables the summary operator to collect data from a summary file.
+    SummaryCollector also enables the summary operator to collect data to summary files.
 
     Note:
         1. Multiple SummaryCollector instances in callback list are not allowed.
@@ -367,6 +367,7 @@ class SummaryCollector(Callback):
                              'but got `{cb_params.mode}` mode.')
 
         self._record.set_mode(cb_params.mode)
+        self._dataset_sink_mode = cb_params.dataset_sink_mode
 
     def step_end(self, run_context):
         cb_params = run_context.original_args()
@@ -386,8 +387,6 @@ class SummaryCollector(Callback):
             self._record.record(cb_params.cur_step_num)
 
         if self._first_step:
-            # Notice: This way of determining whether dataset sink mode is True does not work in the eval scenario
-            self._dataset_sink_mode = cb_params.cur_step_num == cb_params.batch_num
             self._tensor_collect_range = self._get_tensor_collect_range(cb_params, self._dataset_sink_mode)
             self._collect_at_step_end(cb_params, plugin_filter=None)
             self._first_step = False
@@ -480,34 +479,44 @@ class SummaryCollector(Callback):
 
     def _collect_input_data(self, cb_params):
         """Only support to collect image data."""
-        if not self._collect_specified_data.get('collect_input_data'):
+        if not self._is_allowed_to_collect_input_data(cb_params):
             return
 
+        input_data = getattr(cb_params, 'train_dataset_element', None)
+        if isinstance(input_data, (list, tuple)) and input_data:
+            input_data = input_data[0]
+        try:
+            self._record.add_value(PluginEnum.IMAGE.value, 'input_data/auto', input_data)
+        except (TypeError, ValueError):
+            logger.warning('The input data of network are not image, so will not collect by SummaryCollector.')
+            self._collect_specified_data['collect_input_data'] = False
+            return
+
+    def _is_allowed_to_collect_input_data(self, cb_params):
+        """Check if the input data is allowed to be collected."""
+        if not self._collect_specified_data.get('collect_input_data'):
+            return False
+
+        if self._dataset_sink_mode and (context.get_context('device_target') in ('Ascend', 'GPU')):
+            logger.warning("On Ascend or GPU device, SummaryCollector is not supported to "
+                           "record input data in dataset sink mode.")
+            self._collect_specified_data['collect_input_data'] = False
+            return False
+
         input_data = getattr(cb_params, 'train_dataset_element', None)
         if not isinstance(input_data, (Tensor, list, tuple)):
             self._collect_specified_data['collect_input_data'] = False
             logger.warning("The type of input data is not Tensor/list/tuple, "
                            "so SummaryCollector will not collect input data.")
-            return
+            return False
 
         if not isinstance(input_data, Tensor) and not input_data:
             self._collect_specified_data['collect_input_data'] = False
             logger.warning("The 'train_dataset_element' in cb_params is empty, "
-                           "so SummaryCollector will not record the input data.")
+                           "so SummaryCollector will not record the input data.")
+            return False
 
-        if self._dataset_sink_mode and context.get_context('device_target') == 'Ascend':
-            logger.warning('On Ascend device, SummaryCollector is not supported to record input data '
-                           'in dataset sink mode.')
-            return
-
-        if isinstance(input_data, (list, tuple)) and input_data:
-            input_data = input_data[0]
-        try:
-            self._record.add_value(PluginEnum.IMAGE.value, 'input_data/auto', input_data)
-        except (TypeError, ValueError):
-            logger.warning('The input data of network are not image, so will not collect by SummaryCollector.')
-            self._collect_specified_data['collect_input_data'] = False
-            return
+        return True
 
     def _collect_dataset_graph(self, cb_params):
         """Only collect train dataset graph."""
diff --git a/mindspore/train/model.py b/mindspore/train/model.py
index 1c07768225..5c2f8ed3d5 100755
--- a/mindspore/train/model.py
+++ b/mindspore/train/model.py
@@ -430,6 +430,7 @@ class Model:
             train_dataset.__total_batch__ = epoch * sink_size
 
         cb_params.cur_step_num = 0
+        cb_params.dataset_sink_mode = True
 
         run_context = RunContext(cb_params)
         list_callback.begin(run_context)
@@ -497,11 +498,11 @@ class Model:
                                                          dataset_sink_mode=False,
                                                          epoch_num=epoch)
         cb_params.cur_step_num = 0
+        cb_params.dataset_sink_mode = False
         run_context = RunContext(cb_params)
         list_callback.begin(run_context)
         # used to stop training for early stop, such as stopAtTIme or stopATStep
         should_stop = False
-
         for i in range(epoch):
             cb_params.cur_epoch_num = i + 1
             list_callback.epoch_begin(run_context)
@@ -623,8 +624,8 @@ class Model:
                                                  dataset_sink_mode=True)
         self._eval_network = eval_network
         cb_params.eval_network = self._eval_network
+        cb_params.dataset_sink_mode = True
         list_callback.begin(run_context)
-
         for inputs in dataset_helper:
             cb_params.cur_step_num += 1
             list_callback.step_begin(run_context)
@@ -654,8 +655,8 @@ class Model:
             Dict, which returns the loss value and metrics values for the model in the test mode.
         """
         run_context = RunContext(cb_params)
+        cb_params.dataset_sink_mode = False
         list_callback.begin(run_context)
-
         dataset_helper, _ = self._exec_preprocess(self._eval_network,
                                                   is_train=False,
                                                   phase='eval',
diff --git a/mindspore/train/summary/_writer_pool.py b/mindspore/train/summary/_writer_pool.py
index 667f90f0b7..84602738d4 100644
--- a/mindspore/train/summary/_writer_pool.py
+++ b/mindspore/train/summary/_writer_pool.py
@@ -187,7 +187,7 @@ class WriterPool(ctx.Process):
         """Check if the summary process should survive."""
         is_exit = False
         if not psutil.pid_exists(self._training_pid):
-            logger.warning("The training process %d is killed, summary process will exit.", self._training_pid)
+            logger.warning("The training process %d has exited, summary process will exit.", self._training_pid)
             is_exit = True
 
         if not self._writers: