|
|
|
@ -25,6 +25,7 @@ import mindspore.dataset.transforms.vision.c_transforms as vision
|
|
|
|
|
import numpy as np
|
|
|
|
|
import pytest
|
|
|
|
|
from mindspore._c_dataengine import InterpolationMode
|
|
|
|
|
from mindspore.dataset.transforms.vision import Inter
|
|
|
|
|
from mindspore import log as logger
|
|
|
|
|
|
|
|
|
|
import mindspore.dataset as ds
|
|
|
|
@ -151,6 +152,51 @@ def test_cv_minddataset_dataset_size(add_and_remove_cv_file):
|
|
|
|
|
assert data_set.get_dataset_size() == 3
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_cv_minddataset_repeat_reshuffle(add_and_remove_cv_file):
|
|
|
|
|
"""tutorial for cv minddataset."""
|
|
|
|
|
columns_list = ["data", "label"]
|
|
|
|
|
num_readers = 4
|
|
|
|
|
data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers)
|
|
|
|
|
decode_op = vision.Decode()
|
|
|
|
|
data_set = data_set.map(input_columns=["data"], operations=decode_op, num_parallel_workers=2)
|
|
|
|
|
resize_op = vision.Resize((32, 32), interpolation=Inter.LINEAR)
|
|
|
|
|
data_set = data_set.map(input_columns="data", operations=resize_op, num_parallel_workers=2)
|
|
|
|
|
data_set = data_set.batch(2)
|
|
|
|
|
data_set = data_set.repeat(2)
|
|
|
|
|
num_iter = 0
|
|
|
|
|
labels = []
|
|
|
|
|
for item in data_set.create_dict_iterator():
|
|
|
|
|
logger.info("-------------- get dataset size {} -----------------".format(num_iter))
|
|
|
|
|
logger.info("-------------- item[label]: {} ---------------------".format(item["label"]))
|
|
|
|
|
logger.info("-------------- item[data]: {} ----------------------".format(item["data"]))
|
|
|
|
|
num_iter += 1
|
|
|
|
|
labels.append(item["label"])
|
|
|
|
|
assert num_iter == 10
|
|
|
|
|
logger.info("repeat shuffle: {}".format(labels))
|
|
|
|
|
assert len(labels) == 10
|
|
|
|
|
assert labels[0:5] == labels[0:5]
|
|
|
|
|
assert labels[0:5] != labels[5:5]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_cv_minddataset_batch_size_larger_than_records(add_and_remove_cv_file):
|
|
|
|
|
"""tutorial for cv minddataset."""
|
|
|
|
|
columns_list = ["data", "label"]
|
|
|
|
|
num_readers = 4
|
|
|
|
|
data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers)
|
|
|
|
|
decode_op = vision.Decode()
|
|
|
|
|
data_set = data_set.map(input_columns=["data"], operations=decode_op, num_parallel_workers=2)
|
|
|
|
|
resize_op = vision.Resize((32, 32), interpolation=Inter.LINEAR)
|
|
|
|
|
data_set = data_set.map(input_columns="data", operations=resize_op, num_parallel_workers=2)
|
|
|
|
|
data_set = data_set.batch(32, drop_remainder=True)
|
|
|
|
|
num_iter = 0
|
|
|
|
|
for item in data_set.create_dict_iterator():
|
|
|
|
|
logger.info("-------------- get dataset size {} -----------------".format(num_iter))
|
|
|
|
|
logger.info("-------------- item[label]: {} ---------------------".format(item["label"]))
|
|
|
|
|
logger.info("-------------- item[data]: {} ----------------------".format(item["data"]))
|
|
|
|
|
num_iter += 1
|
|
|
|
|
assert num_iter == 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_cv_minddataset_issue_888(add_and_remove_cv_file):
|
|
|
|
|
"""issue 888 test."""
|
|
|
|
|
columns_list = ["data", "label"]
|
|
|
|
|