diff --git a/mindspore/train/model.py b/mindspore/train/model.py index a8f142c9dd..adf7495eb8 100755 --- a/mindspore/train/model.py +++ b/mindspore/train/model.py @@ -549,6 +549,8 @@ class Model: >>> model.train(2, dataset) """ check_bool(dataset_sink_mode) + if sink_size == -1: + sink_size = train_dataset.get_dataset_size() check_int(sink_size) if sink_size < -1 or sink_size == 0: raise ValueError("The sink_size must be -1 or positive, but got sink_size {}.".format(sink_size)) diff --git a/tests/dataset_mock.py b/tests/dataset_mock.py index ad6fd0f12c..988f15dd8c 100644 --- a/tests/dataset_mock.py +++ b/tests/dataset_mock.py @@ -21,7 +21,7 @@ from mindspore import Tensor class MindData: """ Stub for MindData """ - def __init__(self, size=None, batch_size=None, repeat_count=1, + def __init__(self, size=1, batch_size=None, repeat_count=1, np_types=None, output_shapes=None, input_indexs=()): self._size = size self._batch_size = batch_size diff --git a/tests/ut/python/parallel/test_loss_scale.py b/tests/ut/python/parallel/test_loss_scale.py index 88997160f8..9649679474 100644 --- a/tests/ut/python/parallel/test_loss_scale.py +++ b/tests/ut/python/parallel/test_loss_scale.py @@ -113,7 +113,7 @@ class TrainOneStepWithLossScaleCell(nn.Cell): class DatasetLenet(MindData): def __init__(self, predict, label, length=3): - super(DatasetLenet, self).__init__() + super(DatasetLenet, self).__init__(size=length) self.predict = predict self.label = label self.index = 0