diff --git a/mindspore/train/callback/_summary_collector.py b/mindspore/train/callback/_summary_collector.py index 9137a7b671..9d13c3adbc 100644 --- a/mindspore/train/callback/_summary_collector.py +++ b/mindspore/train/callback/_summary_collector.py @@ -24,6 +24,7 @@ from importlib import import_module import numpy as np from mindspore import log as logger +from mindspore import context from mindspore.common.tensor import Tensor from mindspore.common.parameter import Parameter from mindspore.train.summary.summary_record import SummaryRecord @@ -453,9 +454,10 @@ class SummaryCollector(Callback): if not self._collect_specified_data.get('collect_input_data'): return - if self._dataset_sink_mode: + if self._dataset_sink_mode and context.get_context('device_target') == 'Ascend': self._collect_specified_data['collect_input_data'] = False - logger.warning('SummaryCollector is not supported to record input data in dataset sink mode.') + logger.warning('On Ascend device, SummaryCollector is not supported to record input data ' + 'in dataset sink mode.') return input_data = getattr(cb_params, 'train_dataset_element', None)