From 501e01470e5abb360b6e23115ba17e43070a9759 Mon Sep 17 00:00:00 2001 From: wuweikang Date: Tue, 22 Sep 2020 23:03:50 +0800 Subject: [PATCH] revert tile.py & sink_size = get_dataset_size() --- mindspore/ops/_op_impl/tbe/tile.py | 5 +++-- mindspore/train/model.py | 2 ++ tests/dataset_mock.py | 2 +- tests/ut/python/parallel/test_loss_scale.py | 2 +- 4 files changed, 7 insertions(+), 4 deletions(-) diff --git a/mindspore/ops/_op_impl/tbe/tile.py b/mindspore/ops/_op_impl/tbe/tile.py index 178a623091..d55d427266 100644 --- a/mindspore/ops/_op_impl/tbe/tile.py +++ b/mindspore/ops/_op_impl/tbe/tile.py @@ -26,8 +26,9 @@ tile_op_info = TBERegOp("Tile") \ .attr("multiples", "optional", "listInt", "all")\ .input(0, "x1", False, "required", "all") \ .output(0, "y", False, "required", "all") \ - .op_pattern("dynamicFormat") \ - .dtype_format(DataType.None_None, DataType.None_None) \ + .dtype_format(DataType.I32_Default, DataType.I32_Default) \ + .dtype_format(DataType.I16_Default, DataType.I16_Default) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default) \ .get_op_info() diff --git a/mindspore/train/model.py b/mindspore/train/model.py index 32da5a9e01..dd2fdfab67 100755 --- a/mindspore/train/model.py +++ b/mindspore/train/model.py @@ -548,6 +548,8 @@ class Model: >>> model.train(2, dataset) """ check_bool(dataset_sink_mode) + if sink_size == -1: + sink_size = train_dataset.get_dataset_size() check_int(sink_size) if sink_size < -1 or sink_size == 0: raise ValueError("The sink_size must be -1 or positive, but got sink_size {}.".format(sink_size)) diff --git a/tests/dataset_mock.py b/tests/dataset_mock.py index 0ee0a90813..ee9af3c40f 100644 --- a/tests/dataset_mock.py +++ b/tests/dataset_mock.py @@ -21,7 +21,7 @@ from mindspore import Tensor class MindData: """ Stub for MindData """ - def __init__(self, size=None, batch_size=None, repeat_count=1, + def __init__(self, size=1, batch_size=None, repeat_count=1, np_types=None, output_shapes=None, input_indexs=()): self._size = size self._batch_size = batch_size diff --git a/tests/ut/python/parallel/test_loss_scale.py b/tests/ut/python/parallel/test_loss_scale.py index 88997160f8..9649679474 100644 --- a/tests/ut/python/parallel/test_loss_scale.py +++ b/tests/ut/python/parallel/test_loss_scale.py @@ -113,7 +113,7 @@ class TrainOneStepWithLossScaleCell(nn.Cell): class DatasetLenet(MindData): def __init__(self, predict, label, length=3): - super(DatasetLenet, self).__init__() + super(DatasetLenet, self).__init__(size=length) self.predict = predict self.label = label self.index = 0