diff --git a/example/resnet50_imagenet2012_THOR/model/dataset_helper.py b/example/resnet50_imagenet2012_THOR/model/dataset_helper.py index 474bccf42f..77f67344c2 100644 --- a/example/resnet50_imagenet2012_THOR/model/dataset_helper.py +++ b/example/resnet50_imagenet2012_THOR/model/dataset_helper.py @@ -15,6 +15,7 @@ """Dataset help for minddata dataset""" from mindspore._checkparam import check_bool from mindspore.parallel._utils import _get_device_num, _get_parallel_mode +from mindspore.train.dataset_helper import _send_data from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes, \ _to_full_shapes from mindspore.train.parallel_utils import ParallelMode @@ -67,7 +68,13 @@ class _DatasetIter: self.loop_size = dataset.get_dataset_size() else: self.loop_size = dataset.__loop_size__ - dataset.__ME_INITED__ = _exec_datagraph(dataset, self.loop_size).queue_name + dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.loop_size) + dataset.__ME_INITED__ = dataset.__TRANSFER_DATASET__.queue_name + + if not hasattr(dataset, '__no_send__'): + _send_data(dataset) + else: + _send_data(dataset) self.ind = 0 self.dataset = dataset diff --git a/mindspore/train/_utils.py b/mindspore/train/_utils.py index 958ea7e2c2..e6f0a3b71d 100644 --- a/mindspore/train/_utils.py +++ b/mindspore/train/_utils.py @@ -16,11 +16,10 @@ import os import numpy as np from mindspore.common.tensor import Tensor -from mindspore.common.dtype import dtype_to_nptype +from mindspore.common.dtype import dtype_to_nptype, pytype_to_dtype from mindspore.common import dtype as mstype from mindspore import log as logger from mindspore.common.api import _executor -from mindspore.common.dtype import pytype_to_dtype def _convert_type(types): @@ -64,8 +63,6 @@ def _exec_datagraph(exec_dataset, dataset_size, phase='dataset'): input_indexs, phase=phase) - # engine dataset to write data to tdt queue - exec_dataset.send() return exec_dataset diff --git a/mindspore/train/dataset_helper.py b/mindspore/train/dataset_helper.py index 6cee80cabb..cf09e3a067 100644 --- a/mindspore/train/dataset_helper.py +++ b/mindspore/train/dataset_helper.py @@ -23,6 +23,14 @@ from ..nn.wrap import GetNextSingleOp from ..parallel._utils import _get_device_num, _get_global_rank, _need_to_full +def _send_data(dataset): + """Engine dataset to write data to tdt queue.""" + if not hasattr(dataset, '__has_sent__'): + exec_dataset = dataset.__TRANSFER_DATASET__ + exec_dataset.send() + dataset.__has_sent__ = True + + class DatasetHelper: """ Help function to use the Minddata dataset. @@ -81,7 +89,13 @@ class _DatasetIter: self.loop_size = dataset.get_dataset_size() else: self.loop_size = dataset.__loop_size__ - dataset.__ME_INITED__ = _exec_datagraph(dataset, self.loop_size).queue_name + dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.loop_size) + dataset.__ME_INITED__ = dataset.__TRANSFER_DATASET__.queue_name + + if not hasattr(dataset, '__no_send__'): + _send_data(dataset) + else: + _send_data(dataset) self.ind = 0 self.dataset = dataset diff --git a/mindspore/train/model.py b/mindspore/train/model.py index 68042d8d0a..ef3d572e3e 100755 --- a/mindspore/train/model.py +++ b/mindspore/train/model.py @@ -285,7 +285,7 @@ class Model: if self._parameter_broadcast: self._train_network.set_broadcast_flag() - + train_dataset.__no_send__ = True train_dataset_helper, train_network = self._exec_preprocess(self._train_network, is_train=True, phase='train', @@ -302,6 +302,7 @@ class Model: self._eval_network.set_train(False) self._eval_network.phase = 'eval' + valid_dataset.__no_send__ = True valid_dataset_helper, eval_network = self._exec_preprocess(self._eval_network, is_train=False, phase='eval', diff --git a/tests/st/networks/models/resnet50/src_thor/dataset_helper.py b/tests/st/networks/models/resnet50/src_thor/dataset_helper.py index e02dcc6acb..1ca4d388f7 100644 --- a/tests/st/networks/models/resnet50/src_thor/dataset_helper.py +++ b/tests/st/networks/models/resnet50/src_thor/dataset_helper.py @@ -15,6 +15,7 @@ """Dataset help for minddata dataset""" from mindspore._checkparam import check_bool from mindspore.parallel._utils import _get_device_num, _get_parallel_mode +from mindspore.train.dataset_helper import _send_data from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes, \ _to_full_shapes from mindspore.train.parallel_utils import ParallelMode @@ -69,7 +70,13 @@ class _DatasetIter: self.loop_size = dataset.get_dataset_size() else: self.loop_size = dataset.__loop_size__ - dataset.__ME_INITED__ = _exec_datagraph(dataset, self.loop_size).queue_name + dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.loop_size) + dataset.__ME_INITED__ = dataset.__TRANSFER_DATASET__.queue_name + + if not hasattr(dataset, '__no_send__'): + _send_data(dataset) + else: + _send_data(dataset) self.ind = 0 self.dataset = dataset