fix get dataset size error

pull/3396/head
panfengfeng 5 years ago
parent 21edd691de
commit 939e612906

@@ -212,12 +212,12 @@ Status DeviceQueueOp::SendDataToGPU() {
       RETURN_IF_NOT_OK(RetryPushGPUData(data_size, curr_row, handle));
       total_batch++;
     }
-    if (!TaskManager::FindMe()->Interrupted())
+    if (!TaskManager::FindMe()->Interrupted() && !GpuBufferMgr::GetInstance().IsClosed())
       RETURN_IF_NOT_OK(GetNextInput(&current_buffer));
     else
       is_break_loop = true;
   }
-  if (!TaskManager::FindMe()->Interrupted())
+  if (!TaskManager::FindMe()->Interrupted() && !GpuBufferMgr::GetInstance().IsClosed())
     RETURN_IF_NOT_OK(GetNextInput(&current_buffer));
   else
     is_break_loop = true;
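
Note: the hunk above is C++; the following is only a rough Python sketch of the new exit condition (the function name is illustrative, not a MindSpore API). The worker now blocks on the next input only while the task is uninterrupted and the GPU buffer manager is still open, so the send loop breaks out instead of hanging once the queue has been closed.

def should_fetch_next(interrupted, gpu_buffer_closed):
    # keep pulling batches only while nothing has shut the pipeline down
    return not interrupted and not gpu_buffer_closed

assert should_fetch_next(False, False) is True   # normal case: fetch the next batch
assert should_fetch_next(False, True) is False   # buffer manager closed -> break
assert should_fetch_next(True, False) is False   # task interrupted -> break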

@@ -2401,7 +2401,7 @@ class TransferDataset(DatasetOp):
         # need to keep iterator alive so the executionTree is not destroyed
         if self._noop_mode():
             return
-        self.iterator = TupleIterator(self, num_epochs=-1)
+        self.iterator = TupleIterator(self, num_epochs=num_epochs)

     def stop_send(self):
         self.iterator.depipeline.StopSend()

@@ -24,13 +24,18 @@ from ..nn.wrap import GetNextSingleOp
 from ..parallel._utils import _get_device_num, _get_global_rank, _need_to_full


-def _send_data(dataset):
+def _send_data(dataset, epoch_num):
     """Engine dataset to write data to tdt queue."""
     if not hasattr(dataset, '__has_sent__'):
         exec_dataset = dataset.__TRANSFER_DATASET__
-        exec_dataset.send()
+        exec_dataset.send(epoch_num)
         dataset.__has_sent__ = True


+def _send_data_no_flag(dataset, epoch_num):
+    """Engine dataset to write data to tdt queue directly."""
+    exec_dataset = dataset.__TRANSFER_DATASET__
+    exec_dataset.send(epoch_num)
+
 class DatasetHelper:
     """
@@ -54,7 +59,7 @@ class DatasetHelper:
         >>> outputs = network(*inputs)
     """

-    def __init__(self, dataset, dataset_sink_mode=True, sink_size=-1):
+    def __init__(self, dataset, dataset_sink_mode=True, sink_size=-1, epoch_num=1):
         check_bool(dataset_sink_mode)
         check_int(sink_size)
         if sink_size < -1 or sink_size == 0:
@@ -74,7 +79,7 @@ class DatasetHelper:
                        iterclass = _DatasetIterMS
                elif context.get_context("device_target") == "CPU":
                    raise RuntimeError("Currently dataset sink mode is not supported when the device target is CPU.")
-            self.iter = iterclass(dataset, sink_size)
+            self.iter = iterclass(dataset, sink_size, epoch_num)
         else:
             iterclass = _DatasetIterNormal
             self.iter = iterclass(dataset)
@@ -98,7 +103,7 @@ class DatasetHelper:

 class _DatasetIter:
     """Base iter for dataset helper"""
-    def __init__(self, dataset, sink_size):
+    def __init__(self, dataset, sink_size, epoch_num):
         self.dataset = dataset
         self.sink_size = sink_size
         self.sink_count = 1
@@ -110,9 +115,9 @@ class _DatasetIter:
             dataset.__ME_INITED__ = dataset.__TRANSFER_DATASET__.queue_name

             if not hasattr(dataset, '__no_send__'):
-                _send_data(dataset)
+                _send_data(dataset, epoch_num)
         else:
-            _send_data(dataset)
+            _send_data_no_flag(dataset, epoch_num)

         self.stop_send = dataset.__TRANSFER_DATASET__.stop_send
         self.dataset_types, self.dataset_shapes = _get_types_and_shapes(dataset)
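
A minimal sketch (fake classes, not the MindSpore source) of how the two helpers above behave: _send_data pushes the pipeline once and marks __has_sent__ so repeated DatasetHelper constructions do not re-push, while _send_data_no_flag always pushes; both now forward epoch_num to send().

class _FakeTransferDataset:
    """Stand-in for dataset.__TRANSFER_DATASET__; records send() calls."""
    def __init__(self):
        self.send_calls = []
    def send(self, num_epochs=-1):
        self.send_calls.append(num_epochs)

class _FakeDataset:
    pass

def _send_data(dataset, epoch_num):
    if not hasattr(dataset, '__has_sent__'):
        dataset.__TRANSFER_DATASET__.send(epoch_num)
        dataset.__has_sent__ = True

def _send_data_no_flag(dataset, epoch_num):
    dataset.__TRANSFER_DATASET__.send(epoch_num)

ds = _FakeDataset()
ds.__TRANSFER_DATASET__ = _FakeTransferDataset()
_send_data(ds, 3)
_send_data(ds, 3)          # skipped: __has_sent__ already set
_send_data_no_flag(ds, 3)  # always pushes
assert ds.__TRANSFER_DATASET__.send_calls == [3, 3]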
@@ -156,8 +161,8 @@ class _DatasetIter:

 class _DatasetIterGE(_DatasetIter):
     """Iter for GE."""
-    def __init__(self, dataset, sink_size):
-        super().__init__(dataset, sink_size)
+    def __init__(self, dataset, sink_size, epoch_num):
+        super().__init__(dataset, sink_size, epoch_num)
         self.sink_count = self.get_sink_count(dataset)
         batch_expand_num = 1
         if _need_to_full():
@@ -172,8 +177,8 @@ class _DatasetIterGE(_DatasetIter):

 class _DatasetIterMSLoopSink(_DatasetIter):
     """Iter for context (device_target=Ascend)"""
-    def __init__(self, dataset, sink_size):
-        super().__init__(dataset, sink_size)
+    def __init__(self, dataset, sink_size, epoch_num):
+        super().__init__(dataset, sink_size, epoch_num)
         self.sink_count = self.get_sink_count(dataset)
         ms_role = os.getenv("MS_ROLE")
         if ms_role in ("MS_PSERVER", "MS_SCHED"):
@@ -193,8 +198,8 @@ class _DatasetIterMSLoopSink(_DatasetIter):

 class _DatasetIterMS(_DatasetIter):
     """Iter for MS(enable_loop_sink=False)."""
-    def __init__(self, dataset, sink_size):
-        super().__init__(dataset, sink_size)
+    def __init__(self, dataset, sink_size, epoch_num):
+        super().__init__(dataset, sink_size, epoch_num)
         if sink_size > 0:
             self.sink_count = sink_size
         else:
@@ -206,8 +211,8 @@ class _DatasetIterMS(_DatasetIter):

 class _DatasetIterPSLite(_DatasetIter):
     """Iter for context (device_target=GPU) on MS_PSERVER or MS_SCHED"""
-    def __init__(self, dataset, sink_size):
-        super().__init__(dataset, sink_size)
+    def __init__(self, dataset, sink_size, epoch_num):
+        super().__init__(dataset, sink_size, epoch_num)
         self.sink_count = 1
         self.sink_size = 1
         self.op = None

@@ -227,7 +227,7 @@ class Model:
             scaling_sens /= self._device_number
         return scaling_sens

-    def _exec_preprocess(self, network, is_train, phase, dataset, dataset_sink_mode, sink_size=-1):
+    def _exec_preprocess(self, network, is_train, phase, dataset, dataset_sink_mode, sink_size=-1, epoch_num=1):
        """Initializes dataset."""
        need_wrap = False
        if dataset_sink_mode:
@@ -239,7 +239,7 @@ class Model:
            if not is_train:
                dataset.__loop_size__ = 1

-        dataset_helper = DatasetHelper(dataset, dataset_sink_mode, sink_size)
+        dataset_helper = DatasetHelper(dataset, dataset_sink_mode, sink_size, epoch_num)

        # remove later to deal with loop sink
        if need_wrap:
@@ -399,12 +399,18 @@ class Model:
            cb_params (_InternalCallbackParam): Callback parameters. Default: None.
            sink_size (int): Control the amount of data each sink. Default: -1.
        """
+        if sink_size == -1:
+            epoch_num = epoch
+        else:
+            epoch_num = epoch * sink_size // train_dataset.get_dataset_size()
+
        dataset_helper, train_network = self._exec_preprocess(self._train_network,
                                                              is_train=True,
                                                              phase='train',
                                                              dataset=train_dataset,
                                                              dataset_sink_mode=True,
-                                                              sink_size=sink_size)
+                                                              sink_size=sink_size,
+                                                              epoch_num=epoch_num)
        self._train_network = train_network
        cb_params.train_network = self._train_network
        cb_params.cur_step_num = 0
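
The if/else added above is the core of the fix: when a sink_size is given, the number of dataset epochs the transfer pipeline must supply is no longer the training epoch count but epoch * sink_size divided by the dataset size. A hedged, self-contained sketch of that arithmetic (compute_epoch_num is an illustrative name, not a Model method):

def compute_epoch_num(epoch, sink_size, dataset_size):
    # sink_size == -1 means "sink the whole dataset each epoch"
    if sink_size == -1:
        return epoch
    # otherwise each training epoch only consumes sink_size batches
    return epoch * sink_size // dataset_size

assert compute_epoch_num(10, -1, 500) == 10   # default sinking: 10 full dataset passes
assert compute_epoch_num(10, 100, 500) == 2   # 10 * 100 steps over 500 batches = 2 passes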
@@ -621,6 +627,8 @@ class Model:
                list_callback.step_end(run_context)
                self._update_metrics(outputs)

+            valid_dataset.reset()
+
            metrics = self._get_metrics()
            cb_params.metrics = metrics
            list_callback.end(run_context)
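
A small illustration (a toy class, not a MindSpore dataset) of why the eval path now calls valid_dataset.reset(): without a reset, an already-consumed pipeline would yield nothing on the next evaluation.

class _CountingDataset:
    def __init__(self, size):
        self.size = size
        self.pos = 0
    def __iter__(self):
        while self.pos < self.size:
            self.pos += 1
            yield self.pos
    def reset(self):
        self.pos = 0

valid_dataset = _CountingDataset(3)
assert list(valid_dataset) == [1, 2, 3]
assert list(valid_dataset) == []         # exhausted: a second eval would see no data
valid_dataset.reset()
assert list(valid_dataset) == [1, 2, 3]  # reset restores a full pass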

@@ -58,7 +58,7 @@ class MindData:
     def create_tuple_iterator(self):
         return self.__iter__()

-    def send(self):
+    def send(self, num_epochs=-1):
         pass

     def stop_send(self):

@@ -15,11 +15,16 @@
 """Dataset help for minddata dataset"""
 from mindspore._checkparam import check_bool
 from mindspore.parallel._utils import _get_device_num, _get_parallel_mode
-from mindspore.train.dataset_helper import _send_data
 from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes, \
     _to_full_shapes
 from mindspore.train.parallel_utils import ParallelMode

+def _send_data(dataset):
+    """Engine dataset to write data to tdt queue."""
+    if not hasattr(dataset, '__has_sent__'):
+        exec_dataset = dataset.__TRANSFER_DATASET__
+        exec_dataset.send()
+        dataset.__has_sent__ = True

 class DatasetHelper:
     """
