fix cifar stuck problem

pull/4930/head
xiefangqi 5 years ago
parent d00f7d8f74
commit e3e7820413

@ -336,6 +336,9 @@ Status CifarOp::GetCifarFiles() {
std::string err_msg = "Unable to open directory " + dataset_directory.toString(); std::string err_msg = "Unable to open directory " + dataset_directory.toString();
RETURN_STATUS_UNEXPECTED(err_msg); RETURN_STATUS_UNEXPECTED(err_msg);
} }
if (cifar_files_.size() == 0) {
RETURN_STATUS_UNEXPECTED("No .bin files found under " + folder_path_);
}
std::sort(cifar_files_.begin(), cifar_files_.end()); std::sort(cifar_files_.begin(), cifar_files_.end());
return Status::OK(); return Status::OK();
} }

@ -24,6 +24,7 @@ from mindspore import log as logger
DATA_DIR_10 = "../data/dataset/testCifar10Data" DATA_DIR_10 = "../data/dataset/testCifar10Data"
DATA_DIR_100 = "../data/dataset/testCifar100Data" DATA_DIR_100 = "../data/dataset/testCifar100Data"
NO_BIN_DIR = "../data/dataset/testMnistData"
def load_cifar(path, kind="cifar10"): def load_cifar(path, kind="cifar10"):
@ -208,6 +209,12 @@ def test_cifar10_exception():
with pytest.raises(ValueError, match=error_msg_6): with pytest.raises(ValueError, match=error_msg_6):
ds.Cifar10Dataset(DATA_DIR_10, shuffle=False, num_parallel_workers=88) ds.Cifar10Dataset(DATA_DIR_10, shuffle=False, num_parallel_workers=88)
error_msg_7 = "No .bin files found"
with pytest.raises(RuntimeError, match=error_msg_7):
ds1 = ds.Cifar10Dataset(NO_BIN_DIR)
for _ in ds1.__iter__():
pass
def test_cifar10_visualize(plot=False): def test_cifar10_visualize(plot=False):
""" """
@ -352,6 +359,12 @@ def test_cifar100_exception():
with pytest.raises(ValueError, match=error_msg_6): with pytest.raises(ValueError, match=error_msg_6):
ds.Cifar100Dataset(DATA_DIR_100, shuffle=False, num_parallel_workers=88) ds.Cifar100Dataset(DATA_DIR_100, shuffle=False, num_parallel_workers=88)
error_msg_7 = "No .bin files found"
with pytest.raises(RuntimeError, match=error_msg_7):
ds1 = ds.Cifar100Dataset(NO_BIN_DIR)
for _ in ds1.__iter__():
pass
def test_cifar100_visualize(plot=False): def test_cifar100_visualize(plot=False):
""" """

Loading…
Cancel
Save